File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
keras_rs/src/layers/embedding/jax Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -442,7 +442,7 @@ def sparsecore_build(
442442
443443 # Collect all stacked tables.
444444 table_specs = embedding .get_table_specs (feature_specs )
445- table_stacks = embedding_utils .get_table_stacks (table_specs )
445+ table_stacks = jte_table_stacking .get_table_stacks (table_specs )
446446
447447 # Create variables for all stacked tables and slot variables.
448448 with sparsecore_distribution .scope ():
@@ -518,7 +518,7 @@ def _sparsecore_symbolic_preprocess(
518518 table_specs = embedding .get_table_specs (
519519 self ._config .feature_specs
520520 )
521- table_stacks = embedding_utils .get_table_stacks (table_specs )
521+ table_stacks = jte_table_stacking .get_table_stacks (table_specs )
522522 stacked_table_specs = {
523523 stack_name : stack [0 ].stacked_table_spec
524524 for stack_name , stack in table_stacks .items ()
You can’t perform that action at this time.
0 commit comments