
Convnext architecture dev #356

Open

wants to merge 9 commits into base: development

Conversation

Lorenzobattistela (Contributor):

Implementing the ConvNeXt architecture described in this paper and #353.

Reviewer: @owenvallis

This PR re-implements #354, but based on the development branch, to fix some test and formatting issues.

@Lorenzobattistela (Contributor, Author):

Test errors on CI: ImportError: cannot import name 'convnext' from 'tensorflow.keras.applications'
Maybe convnext wasn't a keras.applications module in the TensorFlow version CI is using.
(https://www.tensorflow.org/api_docs/python/tf/keras/applications/convnext)

Locally, all tests pass. My TF version is 2.13.0.

@owenvallis (Collaborator):

Hi @Lorenzobattistela, looks like convnext wasn't introduced until TF v2.10. I wonder if we can do a TF version check for this? The other option is we could provide a general architecture class that accepts a tf.keras.application as input and wraps it. I'm not sure if it would be simple to apply to all applications, but would be cleaner for the package and avoids the version issue altogether.
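A minimal sketch of the version-gate idea (the function name and error message are hypothetical, not from the PR): fail with a clear message when TF predates 2.10, the release where tf.keras.applications gained ConvNeXt.

```python
# Hypothetical helper sketching the version-check idea: instead of letting
# an ImportError surface, raise an explicit error on TF < 2.10.
def check_convnext_support(tf_version: str) -> None:
    """Raise RuntimeError if this TF version lacks keras.applications ConvNeXt."""
    major, minor = (int(p) for p in tf_version.split(".")[:2])
    if (major, minor) < (2, 10):
        raise RuntimeError(
            f"ConvNeXt requires TensorFlow >= 2.10, found {tf_version}. "
            "Upgrade TensorFlow or choose another backbone."
        )

check_convnext_support("2.13.0")  # ok, returns None
```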

@Lorenzobattistela (Contributor, Author):

So, @owenvallis, I thought about what you said. For now, I updated the tests to do version checking (skip if TF < 2.10), and maybe we can also add a version check that raises a more useful error when a user tries this with an older TF version.

However, I think refactoring toward a wrapper for keras applications is the best approach. I'm willing to work on this and will start refactoring.

It's up to you whether to merge this; maybe we can use it as a "hotfix" and then refactor to something better. Thanks for the review.

@Lorenzobattistela (Contributor, Author):

@owenvallis something is going wrong with isort, but I did run it.

convnext.trainable = True
for layer in convnext.layers:
    # freeze all layers before the last 3 blocks
    if not re.search("^block[5,6,7]|^top", layer.name):
        layer.trainable = False
@erikreed (Sep 26, 2023):
I'm also trying out this architecture, but does this EfficientNetV2 layer naming apply to ConvNeXt?

model = tf.keras.applications.ConvNeXtBase()
[l.name for l in model.layers if re.search("^block[5,6,7]|^top", l.name)]
# this outputs []

The test also suggests partial is not being applied as expected since the number of trainable layers is 0 with partial.


edit: another candidate might be "convnext_base_stage_3_block_2", also unfreezing the last layer norm since it comes after the final block.

model.trainable = True
for layer in model.layers:
    # freeze all layers before the last block
    if not re.search("^convnext_base_stage_3_block_2", layer.name):
        layer.trainable = False
model.layers[-1].trainable = True

This results in about 10% of weights being unfrozen and only the final block [1].

Total params: 87566464 (334.04 MB)
Trainable params: 8450048 (32.23 MB)
Non-trainable params: 79116416 (301.81 MB)

[1] (image attachment: model summary screenshot)
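A quick check of the two patterns against example layer names (the names here are assumptions drawn from this thread, not dumped from the models) illustrates the mismatch:

```python
import re

# Example layer names, assumed from this thread: EfficientNetV2-style names
# like "block5a_..." vs. ConvNeXtBase-style names like
# "convnext_base_stage_3_block_2_...".
effnet_names = ["block5a_expand_conv", "block6b_project_conv", "top_conv"]
convnext_names = [
    "convnext_base_stage_3_block_2_depthwise_conv",
    "convnext_base_head_layernorm",
]

effnet_pattern = r"^block[5,6,7]|^top"  # pattern from the PR diff
convnext_pattern = r"^convnext_base_stage_3_block_2"

# The EfficientNetV2 pattern matches no ConvNeXt layer name, so nothing
# stays trainable with it.
print([n for n in convnext_names if re.search(effnet_pattern, n)])  # []
```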
