Skip to content

Commit 760639a

Browse files
committed
Add debugging statements
1 parent 8b6e300 commit 760639a

File tree

2 files changed

+23
-0
lines changed

2 files changed

+23
-0
lines changed

examples/ml_perf/run.sh

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -144,6 +144,22 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
144144
--worker=all \
145145
--command="source .keras-env/bin/activate && pip install -U 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
146146

147+
# ==============================================================================
148+
# Kill Previous Training Processes
149+
# ==============================================================================
150+
# echo ">>> Listing matching processes..."
151+
# gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
152+
# --project ${PROJECT} \
153+
# --zone ${ZONE} \
154+
# --worker=all \
155+
# --command="ps aux | grep '[e]xamples.ml_perf.main' || true"
156+
157+
# echo ">>> Terminating any existing training processes..."
158+
# gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
159+
# --project ${PROJECT} \
160+
# --zone ${ZONE} \
161+
# --worker=all \
162+
# --command="pkill -9 -f 'python3.12 -m examples.ml_perf.[m]ain.*' || true"
147163

148164
# ==============================================================================
149165
# Verify Installation

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -580,12 +580,17 @@ def _sparsecore_preprocess(
580580
)
581581

582582
layout = self._sparsecore_layout
583+
print(f"-->{layout=}")
583584
mesh = layout.device_mesh.backend_mesh
585+
print(f"-->{mesh=}")
584586
global_device_count = mesh.devices.size
587+
print(f"-->{global_device_count=}")
585588
local_device_count = mesh.local_mesh.devices.size
589+
print(f"{local_device_count=}")
586590
num_sc_per_device = jte_utils.num_sparsecores_per_device(
587591
mesh.devices.item(0)
588592
)
593+
print(f"-->{num_sc_per_device=}")
589594

590595
preprocessed, stats = embedding_utils.stack_and_shard_samples(
591596
self._config.feature_specs,
@@ -594,13 +599,15 @@ def _sparsecore_preprocess(
594599
global_device_count,
595600
num_sc_per_device,
596601
)
602+
print(f"-->{stats=}")
597603

598604
if training:
599605
# Synchronize input statistics across all devices and update the
600606
# underlying stacked tables specs in the feature specs.
601607

602608
# Aggregate stats across all processes/devices via pmax.
603609
num_local_cpu_devices = jax.local_device_count("cpu")
610+
print(f"-->{num_local_cpu_devices=}")
604611

605612
def pmax_aggregate(x: Any) -> Any:
606613
if not hasattr(x, "ndim"):

0 commit comments

Comments
 (0)