Skip to content

Commit

Permalink
Make clone_initializer work better with type checking. (#18)
Browse files Browse the repository at this point in the history
Also, `str` is preferred to `Text` for type hints.
  • Loading branch information
hertschuh authored Jan 28, 2025
1 parent d500b4d commit ab0e362
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions keras_rs/src/utils/keras_utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
from typing import Text, Union
from typing import Union

import keras


def clone_initializer(
initializer: Union[Text, keras.initializers.Initializer],
initializer: Union[str, keras.initializers.Initializer],
) -> keras.initializers.Initializer:
"""Clones an initializer to ensure a new seed.
Args:
initializer: The initializer to clone.
Returns:
A cloned initializer if it is clonable, otherwise the original one.
As of tensorflow 2.10, we need to clone user passed initializers when
invoking them twice to avoid creating the same randomized initialization.
"""
if isinstance(initializer, keras.initializers.Initializer):
config = initializer.get_config()
initializer_class: type[keras.initializers.Initializer] = (
initializer.__class__
)
return initializer_class.from_config(config)
# If we get a string or dict, just return as we cannot and should not clone.
if not isinstance(initializer, keras.initializers.Initializer):
return initializer
config = initializer.get_config()
return initializer.__class__.from_config(config)
return initializer

0 comments on commit ab0e362

Please sign in to comment.