Skip to content

Commit 3d0b640

Browse files
committed
Add debugging statements
1 parent deaf35a commit 3d0b640

File tree

2 files changed

+1
-0
lines changed

2 files changed

+1
-0
lines changed

keras_rs/.DS_Store

6 KB
Binary file not shown.

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,7 @@ def pmax_aggregate(x: Any) -> Any:
615615
x = np.array(x)
616616
jax.debug.print("--> x.shape={}", x.shape)
617617
tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
618+
jax.debug.print("--> tiled_x.shape={}", tiled_x.shape)
618619
return jax.pmap(
619620
lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call]
620621
axis_name="all_cpus",

0 commit comments

Comments
 (0)