We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 760639a commit deaf35aCopy full SHA for deaf35a
keras_rs/src/layers/embedding/jax/distributed_embedding.py
@@ -591,6 +591,7 @@ def _sparsecore_preprocess(
591
mesh.devices.item(0)
592
)
593
print(f"-->{num_sc_per_device=}")
594
+ print(f"-->{jax.process_count()=}")
595
596
preprocessed, stats = embedding_utils.stack_and_shard_samples(
597
self._config.feature_specs,
@@ -612,6 +613,7 @@ def _sparsecore_preprocess(
612
613
def pmax_aggregate(x: Any) -> Any:
614
if not hasattr(x, "ndim"):
615
x = np.array(x)
616
+ jax.debug.print("--> x.shape={}", x.shape)
617
tiled_x = np.tile(x, (num_local_cpu_devices, *([1] * x.ndim)))
618
return jax.pmap(
619
lambda y: jax.lax.pmax(y, "all_cpus"), # type: ignore[no-untyped-call]
0 commit comments