Re-updating OCR model for reading Captchas Keras 3 example (TF-Only) #1843

Open
wants to merge 2 commits into
base: master

Conversation

sitamgithub-MSIT
Contributor

@sitamgithub-MSIT sitamgithub-MSIT commented Apr 25, 2024

This is a follow-up to PR #1788.

This PR updates the "OCR model for reading Captchas" Keras 3 example (TF-only backend). The recently released Keras op ctc_decode replaces the custom decoding function.

An example notebook is available here:
https://colab.research.google.com/drive/1vCDb45wLmSI3iBI2_BfDDYDSztgxZ4Qp?usp=sharing

cc: @fchollet

Diff of the changed files:

diff --git a/examples/vision/captcha_ocr.py b/examples/vision/captcha_ocr.py
index a6bac599..6ebf1a7a 100644
--- a/examples/vision/captcha_ocr.py
+++ b/examples/vision/captcha_ocr.py
@@ -359,30 +359,6 @@ and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io
 """
 
 
-def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
-    input_shape = ops.shape(y_pred)
-    num_samples, num_steps = input_shape[0], input_shape[1]
-    y_pred = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.backend.epsilon())
-    input_length = ops.cast(input_length, dtype="int32")
-
-    if greedy:
-        (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
-            inputs=y_pred, sequence_length=input_length
-        )
-    else:
-        (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(
-            inputs=y_pred,
-            sequence_length=input_length,
-            beam_width=beam_width,
-            top_paths=top_paths,
-        )
-    decoded_dense = []
-    for st in decoded:
-        st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))
-        decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
-    return (decoded_dense, log_prob)
-
-
 # Get the prediction model by extracting layers till the output layer
 prediction_model = keras.models.Model(
     model.input[0], model.get_layer(name="dense2").output
@@ -394,9 +370,11 @@ prediction_model.summary()
 def decode_batch_predictions(pred):
     input_len = np.ones(pred.shape[0]) * pred.shape[1]
     # Use greedy search. For complex tasks, you can use beam search
-    results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
-        :, :max_length
-    ]
+    results = ops.ctc_decode(pred, sequence_lengths=input_len, strategy="greedy")[0][0]
+    # Convert the SparseTensor to a dense tensor
+    dense_results = tf.sparse.to_dense(results, default_value=-1)
+    # Slice the dense tensor to keep only up to max_length
+    dense_results = dense_results[:, :max_length]
     # Iterate over the results and get back the text
     output_text = []
     for res in results:
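
For readers unfamiliar with the greedy strategy used in the diff above, here is a minimal pure-Python sketch of greedy CTC decoding. It is not the Keras or TensorFlow implementation. The function names `greedy_ctc_decode` and `labels_to_text`, the explicit `blank_index` parameter, and the toy vocabulary are assumptions made for illustration; the `-1` padding and the slice to `max_length` mirror what the example script does.

```python
def greedy_ctc_decode(probs, blank_index, pad_value=-1):
    """Greedy CTC decoding over a batch of per-time-step class scores.

    probs: list of sequences, each a list of per-class score lists.
    blank_index: class index reserved for the CTC blank (an assumption here;
        check which index your decoder treats as blank).
    Returns one label row per sequence, padded with pad_value to num_steps.
    """
    decoded = []
    for sequence in probs:
        # Pick the highest-scoring class at every time step.
        best_path = [max(range(len(step)), key=step.__getitem__) for step in sequence]
        labels = []
        previous = None
        for index in best_path:
            # Merge consecutive repeats, then drop blanks.
            if index != previous and index != blank_index:
                labels.append(index)
            previous = index
        decoded.append(labels)
    num_steps = max((len(sequence) for sequence in probs), default=0)
    return [labels + [pad_value] * (num_steps - len(labels)) for labels in decoded]


def labels_to_text(rows, vocabulary, max_length, pad_value=-1):
    """Map label rows back to strings, keeping at most max_length labels."""
    return [
        "".join(vocabulary[i] for i in row[:max_length] if i != pad_value)
        for row in rows
    ]


# Toy batch: 2 sequences, 5 time steps, 3 classes (class 2 is the blank).
batch = [
    [[0.1, 0.8, 0.1], [0.1, 0.8, 0.1], [0.1, 0.1, 0.8], [0.1, 0.8, 0.1], [0.7, 0.2, 0.1]],
    [[0.9, 0.05, 0.05], [0.9, 0.05, 0.05], [0.1, 0.1, 0.8], [0.1, 0.1, 0.8], [0.1, 0.1, 0.8]],
]
rows = greedy_ctc_decode(batch, blank_index=2)
texts = labels_to_text(rows, vocabulary="ab", max_length=3)
```

In the first toy sequence the blank between the repeated class keeps both occurrences, which is why CTC can emit doubled characters.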
Member

@fchollet fchollet left a comment

Thanks for the PR! Are you also able to replace ctc_batch_cost with keras.losses.CTC?

@sitamgithub-MSIT
Contributor Author

> Thanks for the PR! Are you also able to replace ctc_batch_cost with keras.losses.CTC?

I had thought about this and was about to raise it with you. We have two OCR examples: this one and the handwriting recognition one. For background, see the original captcha PR and the handwriting PR. Both examples use the same architecture and come from the same authors, who explained why they added CTC as a layer rather than passing a loss directly to model.compile().

The Keras 3 update PR targeted the TF backend and still used tf.compat.v1 and similar TensorFlow-specific calls. More recently, this draft PR showed CTC implemented as a loss without tf.compat.v1 or similar calls, using keras.losses.CTC and ctc_decode.

To answer your question: switching to the Keras CTC loss should not be a problem, since both examples use the same architecture. Would you like me to update both OCR examples to use the CTC loss instead of the CTC layer, together with the ctc_decode op?

@sitamgithub-MSIT
Contributor Author

I tried a keras.losses.CTC() implementation similar to what is done here, and model.fit() raises this error:

ValueError: None values not supported.

cc: @fchollet
