Skip to content

Commit deaf35a

Browse files
committed
Add debugging statements
1 parent 760639a commit deaf35a

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,7 @@ def _sparsecore_preprocess(
591591
mesh.devices.item(0)
592592
)
593593
print(f"-->{num_sc_per_device=}")
594+
print(f"-->{jax.process_count()=}")
594595

595596
preprocessed, stats = embedding_utils.stack_and_shard_samples(
596597
self._config.feature_specs,
@@ -612,6 +613,7 @@ def _sparsecore_preprocess(
612613
def pmax_aggregate(x: Any) -> Any:
613614
if not hasattr(x, "ndim"):
614615
x = np.array(x)
616+
jax.debug.print("--> x.shape={}", x.shape)
615617
tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
616618
return jax.pmap(
617619
lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call]

0 commit comments

Comments
 (0)