Some custom objects are not being serialized with push_to_hub_keras #595
Comments
Similarly with
I was successful in uploading the custom objects with:
import math
import tensorflow as tf

# NUM_PATCHES (the number of image patches) is defined earlier in the notebook.

class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # The trainable temperature term. The initial value is
        # the square root of the key dimension.
        self.tau = tf.Variable(
            math.sqrt(float(self._key_dim)),
            trainable=True
        )
        # Build the diagonal attention mask
        diag_attn_mask = 1 - tf.eye(NUM_PATCHES)
        self.diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)

    def get_config(self):
        config = super().get_config()
        config.update({
            "tau": self.tau.numpy(),  # <---- IMPORTANT
            "diag_attn_mask": self.diag_attn_mask.numpy(),  # <---- IMPORTANT
        })
        return config
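For reference, a minimal sketch of how a model containing this layer can then be loaded back. The custom_object_scope mapping is my assumption about how the custom class is supplied at load time; it is not shown in the original notebook:

import tensorflow as tf
from huggingface_hub import from_pretrained_keras

# Make the custom class resolvable by name while the saved model is deserialized.
# MultiHeadAttentionLSA is the class defined in the snippet above.
with tf.keras.utils.custom_object_scope({"MultiHeadAttentionLSA": MultiHeadAttentionLSA}):
    loaded_model = from_pretrained_keras("keras-io/vit-small-ds")
loaded_model.summary()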
Colab Notebook used
Usage of the pre-trained model
Hi @ariG23498! This is great and very insightful! 🤗 I wonder if there's any way to achieve this programmatically instead of expecting users to implement it themselves (cc @Rocketknight1 or @gante, who might have some ideas). Worst case, we could throw a warning when custom layers are being saved and point to the documentation. WDYT?
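As an illustration of that idea, a rough sketch of such a warning; warn_on_custom_layers is a hypothetical helper, not an existing huggingface_hub function:

import warnings
import tensorflow as tf

def warn_on_custom_layers(model: tf.keras.Model) -> None:
    # Flag layers whose class is not defined inside Keras/TensorFlow, since they
    # will need get_config()/custom_objects handling when the model is reloaded.
    for layer in model.layers:
        module = type(layer).__module__
        if not module.startswith(("keras", "tensorflow.")):
            warnings.warn(
                f"Layer '{layer.name}' ({type(layer).__name__}) is a custom layer; "
                "it must implement get_config() and be supplied via custom_objects "
                "(or be registered) when the model is loaded back."
            )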
@osanseviero I feel like this might not be specific to 2.7. I saw @ariG23498 downgraded TF to 2.6 to save the model, so the model was still saved (his first issue was related to that and we fixed it that way), but the custom layer still needed to be registered for us to load the model. The error we got about AdamW was related to that (see below) and has nothing to do with 2.7.
Yep. This is the only way to go. Even TensorFlow throws an error while trying to save a custom model. It is a platform-specific error that needs to be checked by the user and not by the HF team, IMO.
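For completeness, the registration mentioned above can be made explicit with a one-line decorator; a sketch, assuming TF >= 2.1 where tf.keras.utils.register_keras_serializable is available, and with an illustrative package name:

import tensorflow as tf

# Registering the class lets Keras resolve it by name at load time, so users do
# not have to pass custom_objects explicitly (the defining module must still be
# imported before loading). "vit-lsa" is an arbitrary package name.
@tf.keras.utils.register_keras_serializable(package="vit-lsa")
class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
    ...  # same implementation (including get_config) as in the snippet above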
I agree with what @ariG23498 wrote about custom layers; their flexibility (which makes them hard to automate) can be a boon. And the creators of custom layers are mostly power users anyway, in my experience :D For completeness, there is one point yet to be addressed in this discussion. The error that @osanseviero originally pointed at can be avoided by importing tensorflow_addons before loading the model:
import tensorflow_addons as tfa
from huggingface_hub import from_pretrained_keras
loaded_model = from_pretrained_keras("keras-io/vit-small-ds")
loaded_model.summary()
as opposed to the same snippet without the import, which fails:

from huggingface_hub import from_pretrained_keras
loaded_model = from_pretrained_keras("keras-io/vit-small-ds")
loaded_model.summary()

Digging deeper, we can see that importing tensorflow_addons is what registers AdamW with Keras's custom-object machinery, so the optimizer can only be deserialized after that import. What do you think?

EDIT: at the very least, we can throw a warning when the user is pushing a model to the hub with these kinds of optimizers.
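A sketch of what that warning could look like; warn_on_third_party_optimizer is hypothetical and not part of huggingface_hub:

import warnings
import tensorflow as tf

def warn_on_third_party_optimizer(model: tf.keras.Model) -> None:
    # If the compiled optimizer comes from outside Keras/TensorFlow (e.g.
    # tensorflow_addons' AdamW), deserialization will fail unless that package
    # has been imported first, so warn the user at push time.
    optimizer = getattr(model, "optimizer", None)
    if optimizer is None:
        return
    module = type(optimizer).__module__
    if not module.startswith(("keras", "tensorflow.")):
        warnings.warn(
            f"The model was compiled with {type(optimizer).__name__} from '{module}'. "
            "Import that package before calling from_pretrained_keras, otherwise "
            "the optimizer cannot be deserialized."
        )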
@gante @ariG23498 @osanseviero
For this one I'm planning to test it. BTW, weirdly enough, when you import AdamW without actually compiling the model again, this issue does not come up.
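A related workaround (my assumption, not something verified in this thread) is to skip optimizer deserialization entirely by loading with compile=False, in which case tensorflow_addons does not need to be importable; the model then has to be re-compiled before further training:

import tensorflow as tf
from huggingface_hub import snapshot_download

# Download the repo locally, then load the model without restoring the
# optimizer state (compile=False skips optimizer deserialization).
local_dir = snapshot_download("keras-io/vit-small-ds")
model = tf.keras.models.load_model(local_dir, compile=False)
model.summary()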