Wrong weight names after deserializing model from config #36

Open
@shkarupa-alex

System information.

  • Have I written custom code (as opposed to using a stock example script provided in Keras): no
  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): MacOS 13.4.1
  • TensorFlow installed from (source or binary): binary
  • TensorFlow version (use command below): v2.13.0-rc2-7-g1cb1a030a62 2.13.0
  • Python version: 3.11
  • GPU model and memory: no
  • Do you want to contribute a PR? (yes/no): no

Describe the problem.

Models with custom layers lose their weight names after being restored from config.

Describe the current behavior.

See the example below. When BatchNormalization is replaced with a subclass of it (the only change is built-in vs. custom class), the weight names are wrong after the model is restored from config: the layer-name prefix is dropped.

Describe the expected behavior.

Weight names of custom layers should be preserved just like built-in ones.

Standalone code to reproduce the issue.

import tensorflow as tf
from keras import layers, models
from keras.saving import register_keras_serializable

# ========= Working case (built-in layer)
inputs = layers.Input(shape=[None, None, 3], dtype='float32')
x = layers.Conv2D(32, 3, padding='same', name='conv')(inputs)
x = layers.BatchNormalization(name='bn')(x)  # !!! <-- core class
model = models.Model(inputs=inputs, outputs=x)
print([w.name for w in model.weights])

model2 = models.Model.from_config(model.get_config())
print([w.name for w in model2.weights])

# ========= Failure case
@register_keras_serializable(package='MyPackage>Normalization')
class CustomBatchNormalization(layers.BatchNormalization):
    pass

inputs = layers.Input(shape=[None, None, 3], dtype='float32')
x = layers.Conv2D(32, 3, padding='same', name='conv')(inputs)
x = CustomBatchNormalization(name='cbn')(x)  # !!! <-- custom class
model = models.Model(inputs=inputs, outputs=x)
print([w.name for w in model.weights])

model2 = models.Model.from_config(model.get_config())
print([w.name for w in model2.weights])

Source code / logs.

Code posted above will print:

['conv/kernel:0', 'conv/bias:0', 'bn/gamma:0', 'bn/beta:0', 'bn/moving_mean:0', 'bn/moving_variance:0']
['conv/kernel:0', 'conv/bias:0', 'bn/gamma:0', 'bn/beta:0', 'bn/moving_mean:0', 'bn/moving_variance:0']
['conv/kernel:0', 'conv/bias:0', 'cbn/gamma:0', 'cbn/beta:0', 'cbn/moving_mean:0', 'cbn/moving_variance:0']
['conv/kernel:0', 'conv/bias:0', 'gamma:0', 'beta:0', 'moving_mean:0', 'moving_variance:0']
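
To make the difference between the last two lines easier to spot, here is a minimal sketch that compares two lists of weight names position by position. The helper name `diff_weight_names` is hypothetical (it is not part of Keras), and the two lists are copied from the output above, so the snippet runs without TensorFlow.

```python
def diff_weight_names(original, restored):
    """Return (original, restored) name pairs that differ, matched by position.

    Hypothetical helper for illustration only; not a Keras API.
    """
    if len(original) != len(restored):
        raise ValueError("weight lists differ in length")
    return [(o, r) for o, r in zip(original, restored) if o != r]


# Names printed for the custom-layer model before and after from_config
# (copied from the output above).
before = ['conv/kernel:0', 'conv/bias:0', 'cbn/gamma:0', 'cbn/beta:0',
          'cbn/moving_mean:0', 'cbn/moving_variance:0']
after = ['conv/kernel:0', 'conv/bias:0', 'gamma:0', 'beta:0',
         'moving_mean:0', 'moving_variance:0']

mismatches = diff_weight_names(before, after)
for old_name, new_name in mismatches:
    print(f"{old_name} -> {new_name}")
```

With real models, the same check can be run on `[w.name for w in model.weights]` and `[w.name for w in model2.weights]`. An empty result means the names were preserved.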
