Skip to content

Commit 6bb8ae0

Browse files
committed
Use old code for preprocessing
1 parent 2ee8f80 commit 6bb8ae0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def sparsecore_build(
442442

443443
# Collect all stacked tables.
444444
table_specs = embedding.get_table_specs(feature_specs)
445-
table_stacks = embedding_utils.get_table_stacks(table_specs)
445+
table_stacks = jte_table_stacking.get_table_stacks(table_specs)
446446

447447
# Create variables for all stacked tables and slot variables.
448448
with sparsecore_distribution.scope():
@@ -518,7 +518,7 @@ def _sparsecore_symbolic_preprocess(
518518
table_specs = embedding.get_table_specs(
519519
self._config.feature_specs
520520
)
521-
table_stacks = embedding_utils.get_table_stacks(table_specs)
521+
table_stacks = jte_table_stacking.get_table_stacks(table_specs)
522522
stacked_table_specs = {
523523
stack_name: stack[0].stacked_table_spec
524524
for stack_name, stack in table_stacks.items()

0 commit comments

Comments
 (0)