Skip to content

Commit ab0e362

Browse files
authored
Make clone_initializer work better with type checking. (#18)
Also, `str` is preferred to `Text` for type hints.
1 parent d500b4d commit ab0e362

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

keras_rs/src/utils/keras_utils.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
1-
from typing import Text, Union
1+
from typing import Union
22

33
import keras
44

55

66
def clone_initializer(
7-
initializer: Union[Text, keras.initializers.Initializer],
7+
initializer: Union[str, keras.initializers.Initializer],
88
) -> keras.initializers.Initializer:
99
"""Clones an initializer to ensure a new seed.
1010
11+
Args:
12+
initializer: The initializer to clone.
13+
14+
Returns:
15+
A cloned initializer if it is clonable, otherwise the original one.
16+
1117
As of tensorflow 2.10, we need to clone user passed initializers when
1218
invoking them twice to avoid creating the same randomized initialization.
1319
"""
20+
if isinstance(initializer, keras.initializers.Initializer):
21+
config = initializer.get_config()
22+
initializer_class: type[keras.initializers.Initializer] = (
23+
initializer.__class__
24+
)
25+
return initializer_class.from_config(config)
1426
# If we get a string or dict, just return as we cannot and should not clone.
15-
if not isinstance(initializer, keras.initializers.Initializer):
16-
return initializer
17-
config = initializer.get_config()
18-
return initializer.__class__.from_config(config)
27+
return initializer

0 commit comments

Comments
 (0)