Skip to content

Commit ce15ed7

Browse files
committed
Use old code for preprocessing
1 parent 9ced4ed commit ce15ed7

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ def _sparsecore_set_tables(self, tables: Mapping[str, ArrayLike]) -> None:
718718

719719
config = self._config
720720
num_table_shards = config.mesh.devices.size * config.num_sc_per_device
721-
table_specs = embedding_utils.get_table_specs(config.feature_specs)
721+
table_specs = embedding.get_table_specs(config.feature_specs)
722722
sharded_tables = embedding_utils.stack_and_shard_tables(
723723
table_specs,
724724
tables,

0 commit comments

Comments
 (0)