
Convert 'Vision Transformer without Attention' to Keras 3. #1855

Draft
wants to merge 19 commits into base: master

Conversation

fkouteib (Contributor) commented on May 4, 2024

TensorFlow and PyTorch compatibility only.

fkouteib (Contributor, Author) commented on May 4, 2024

On TensorFlow, I am able to train and test the model, but I hit this issue when loading the saved model to run inference on it. It may be the same issue as keras-team/keras#19492, but I am not 100% sure.

$HOME/.tf_venv/lib/python3.10/site-packages/keras/src/saving/saving_lib.py:418: UserWarning: Skipping variable loading for optimizer 'adamw', because it has 1 variables whereas the saved optimizer has 219 variables.
trackable.load_own_variables(weights_store.get(inner_path))
Traceback (most recent call last):
File "$HOME/keras-io_rw/examples/vision/shiftvit.py", line 1092, in
probabilities = predict(predict_ds)
File "$HOME/keras-io_rw/examples/vision/shiftvit.py", line 1062, in predict
logits = saved_model.predict(predict_ds)
File "$HOME/.tf_venv/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "$HOME/keras-io_rw/examples/vision/shiftvit.py", line 720, in call
augmented_images = self.data_augmentation(images)
TypeError: Exception encountered when calling ShiftViTModel.call().
'TrackedDict' object is not callable
Arguments received by ShiftViTModel.call():
• images=tf.Tensor(shape=(10, 32, 32, 3), dtype=uint8)

fkouteib marked this pull request as draft on May 4, 2024 19:36
fkouteib (Contributor, Author) commented on May 4, 2024

On PyTorch, I am hitting this issue when compiling the initial model before training starts.

File "$HOME/keras-io_rw/examples/vision/shiftvit.py", line 937, in
model(sample_ds, training=False)
File "/$HOME/.torch_venv/lib/python3.10/site-packages/keras/src/utils/traceback_utils.py", line 122, in error_handler
raise e.with_traceback(filtered_tb) from None
File "$HOME/.torch_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/$HOME/.torch_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "$HOMEkeras-io_rw/examples/vision/shiftvit.py", line 723, in call
x = stage(x, training=False)
File "$HOME/.torch_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/$HOME/.torch_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "/$HOME/keras-io_rw/examples/vision/shiftvit.py", line 569, in call
x = shift_block(x, training=training)
File "$HOME/.torch_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "$HOME/.torch_venv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
return forward_call(*args, **kwargs)
File "$HOME/keras-io_rw/examples/vision/shiftvit.py", line 429, in call
x_splits[0] = self.get_shift_pad(x_splits[0], mode="left")
TypeError: Exception encountered when calling ShiftViTBlock.call().
'tuple' object does not support item assignment
Arguments received by ShiftViTBlock.call():
• x=torch.Tensor(shape=torch.Size([256, 12, 12, 96]), dtype=float32)
• training=False

fchollet (Member) left a comment

Thanks for the PR!

I believe you may be able to make the code fully backend-agnostic without implementing backend-specific train_step methods. Instead, you could override compute_loss() and make it work with all backends. The train step is generic, and only the loss computation appears to be custom.
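For illustration, a minimal sketch of that direction, assuming a sparse-categorical classification setup (the class name, layers, and loss here are placeholders, not the PR's code):

import keras
from keras import layers


class ShiftViTLikeModel(keras.Model):
    # Hypothetical stand-in for ShiftViTModel; the real model has its own layers.
    def __init__(self, num_classes=10, **kwargs):
        super().__init__(**kwargs)
        self.backbone = keras.Sequential(
            [layers.Conv2D(16, 3, activation="relu"),
             layers.GlobalAveragePooling2D()]
        )
        self.classifier = layers.Dense(num_classes)

    def call(self, images, training=False):
        return self.classifier(self.backbone(images, training=training))

    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None):
        # The generic train/test steps call this on every backend, so no
        # backend-specific train_step override is needed.
        loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        return loss_fn(y, y_pred, sample_weight=sample_weight)


# model = ShiftViTLikeModel()
# model.compile(optimizer="adamw", metrics=["accuracy"])
# model.fit(train_ds, validation_data=val_ds, epochs=1)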

fkouteib (Contributor, Author) commented on May 5, 2024

Thanks for the review and suggestion, Francois! I dropped the custom train and test steps. The combination of overriding the call() method and using the native compute_loss() method is equivalent to the custom loss method.

Current issues I am debugging:

  • TensorFlow: passes the train and test steps. Inference fails when switching to a loaded *.keras model with custom objects (see the first comment for the full error).
  • PyTorch: fails the train step on the first epoch with "'tuple' object does not support item assignment" (see the second comment above for the full error).
  • JAX: fails the train step on the first epoch with the error below.

Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was wrapped_fn at /home/faycel.kouteib/.tf_jax_venv/lib/python3.10/site-packages/keras/src/backend/jax/core.py:153 traced for make_jaxpr.
The leaked intermediate value was created on line /home/faycel.kouteib/keras-io_rw/examples/vision/shiftvit.py:535 ().
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
$HOME/.tf_jax_venv/lib/python3.10/site-packages/keras/src/layers/layer.py:771 (call)
$HOME/.tf_jax_venv/lib/python3.10/site-packages/keras/src/layers/layer.py:1279 (_maybe_build)
$HOME/.tf_jax_venv/lib/python3.10/site-packages/keras/src/layers/layer.py:223 (build_wrapper)
$HOME/keras-io_rw/examples/vision/shiftvit.py:535 (build)
$HOME/keras-io_rw/examples/vision/shiftvit.py:535 ()


fchollet (Member) left a comment

Thanks for the updates.

torch issue

This caught a bug: ops.split is supposed to return a list, but for torch it returns a tuple (same as torch.split). I fixed it. You can route around it by creating an output list and appending elements to it. Once done, the code runs with torch.
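A minimal sketch of that workaround, assuming the block splits along the channel axis as in the traceback (the helper name shift_pad and the split count are placeholders):

from keras import ops


def shifted_splits(x, shift_pad, num_div=12):
    # On the torch backend, ops.split can return a tuple (like torch.split),
    # and tuples do not support item assignment. Copying the pieces into a
    # list first makes the per-split updates work on every backend.
    x_splits = list(ops.split(x, num_div, axis=-1))
    x_splits[0] = shift_pad(x_splits[0], mode="left")
    x_splits[1] = shift_pad(x_splits[1], mode="right")
    return ops.concatenate(x_splits, axis=-1)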

tf issue with deserialization

You need to call deserialize_keras_object on the models/layers passed to constructors, to enable deserialization / model loading.

e.g.

 self.data_augmentation = keras.saving.deserialize_keras_object(data_augmentation)
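For context, a sketch of the full round trip implied by this suggestion (illustrative only; the real constructor takes more sub-models than shown here):

import keras


@keras.saving.register_keras_serializable()
class ShiftViTModel(keras.Model):
    def __init__(self, data_augmentation, **kwargs):
        super().__init__(**kwargs)
        # Rebuild nested layers/models from their serialized config so that
        # keras.models.load_model() restores callable objects instead of a
        # TrackedDict of raw configuration.
        self.data_augmentation = keras.saving.deserialize_keras_object(
            data_augmentation
        )

    def get_config(self):
        config = super().get_config()
        config["data_augmentation"] = keras.saving.serialize_keras_object(
            self.data_augmentation
        )
        return config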

jax issue

This one has to do with tracer leaks. Those problems are unique to JAX and can be tricky to debug. A first problem is using ops.linspace instead of np.linspace in build(). There are further issues down the line, however.
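As a sketch of the first fix, assuming the schedule only needs static values at build time (the layer and attribute names mirror the example but are assumptions here):

import numpy as np
from keras import layers


class StackedShiftBlocks(layers.Layer):
    def __init__(self, num_blocks, stochastic_depth_rate=0.2, **kwargs):
        super().__init__(**kwargs)
        self.num_blocks = num_blocks
        self.stochastic_depth_rate = stochastic_depth_rate

    def build(self, input_shape):
        # np.linspace yields concrete floats at build time; ops.linspace on
        # the JAX backend produces traced values that leak out of the traced
        # function when stored on the layer.
        self.drop_path_rates = np.linspace(
            0.0, self.stochastic_depth_rate, self.num_blocks
        )
        super().build(input_shape)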

        # Update the metrics
        self.compiled_metrics.update_state(labels, logits)
        return {m.name: m.result() for m in self.metrics}

    def call(self, images):
        augmented_images = self.data_augmentation(images)
Member commented:

Surely this should only be applied at training time? Also, we may consider moving it to the data pipeline instead of inside the model.

fkouteib (Contributor, Author) commented on May 15, 2024

Correct. It is only active at training time. See here for full context, but the block-level comment summarizes this well.

The augmentation pipeline consists of:
  • Rescaling
  • Resizing
  • Random cropping
  • Random horizontal flipping
Note: The image data augmentation layers do not apply data transformations at inference time. This means that when these layers are called with training=False they behave differently. Refer to the documentation (https://keras.io/api/layers/preprocessing_layers/image_augmentation/) for more details.
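A sketch of the two options being weighed here (the layer settings are placeholders for the quoted pipeline, not the example's exact values):

import keras
from keras import layers

# Option 1: keep augmentation inside the model. The random layers are inert
# when the model is called with training=False, so inference only rescales
# and resizes.
data_augmentation = keras.Sequential(
    [
        layers.Rescaling(1 / 255.0),
        layers.Resizing(72, 72),
        layers.RandomCrop(64, 64),
        layers.RandomFlip("horizontal"),
    ],
    name="data_augmentation",
)

# Option 2 (suggested in review): apply it in the input pipeline instead,
# e.g. train_ds.map(lambda x, y: (data_augmentation(x, training=True), y)),
# and feed the model already-augmented batches.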
