File tree Expand file tree Collapse file tree 2 files changed +23
-0
lines changed
keras_rs/src/layers/embedding/jax Expand file tree Collapse file tree 2 files changed +23
-0
lines changed Original file line number Diff line number Diff line change @@ -144,6 +144,22 @@ gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
144144 --worker=all \
145145 --command=" source .keras-env/bin/activate && pip install -U 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
146146
147+ # ==============================================================================
148+ # Kill Previous Training Processes
149+ # ==============================================================================
150+ # echo ">>> Listing matching processes..."
151+ # gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
152+ # --project ${PROJECT} \
153+ # --zone ${ZONE} \
154+ # --worker=all \
155+ # --command="ps aux | grep '[e]xamples.ml_perf.main' || true"
156+
157+ # echo ">>> Terminating any existing training processes..."
158+ # gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
159+ # --project ${PROJECT} \
160+ # --zone ${ZONE} \
161+ # --worker=all \
162+ # --command="pkill -9 -f 'python3.12 -m examples.ml_perf.[m]ain.*' || true"
147163
148164# ==============================================================================
149165# Verify Installation
Original file line number Diff line number Diff line change @@ -580,12 +580,17 @@ def _sparsecore_preprocess(
580580 )
581581
582582 layout = self ._sparsecore_layout
583+ print (f"-->{ layout = } " )
583584 mesh = layout .device_mesh .backend_mesh
585+ print (f"-->{ mesh = } " )
584586 global_device_count = mesh .devices .size
587+ print (f"-->{ global_device_count = } " )
585588 local_device_count = mesh .local_mesh .devices .size
589+ print (f"{ local_device_count = } " )
586590 num_sc_per_device = jte_utils .num_sparsecores_per_device (
587591 mesh .devices .item (0 )
588592 )
593+ print (f"-->{ num_sc_per_device = } " )
589594
590595 preprocessed , stats = embedding_utils .stack_and_shard_samples (
591596 self ._config .feature_specs ,
@@ -594,13 +599,15 @@ def _sparsecore_preprocess(
594599 global_device_count ,
595600 num_sc_per_device ,
596601 )
602+ print (f"-->{ stats = } " )
597603
598604 if training :
599605 # Synchronize input statistics across all devices and update the
600606 # underlying stacked tables specs in the feature specs.
601607
602608 # Aggregate stats across all processes/devices via pmax.
603609 num_local_cpu_devices = jax .local_device_count ("cpu" )
610+ print (f"-->{ num_local_cpu_devices = } " )
604611
605612 def pmax_aggregate (x : Any ) -> Any :
606613 if not hasattr (x , "ndim" ):
You can’t perform that action at this time.
0 commit comments