Commit b32730f

Create README.md for running a JAX workload integrated with elasticity primitives

1 parent 42640da commit b32730f


pathwaysutils/elastic/README.md

# Elastic Training with Pathways

This document demonstrates how to leverage the elasticity primitives within `manager.py` to create a resilient JAX training loop that can handle hardware failures gracefully. We illustrate this with an example based on the MaxText training loop running on TPUs provisioned by GKE via the `PathwaysJob` API.

## Overview

Distributed training jobs, especially long-running ones, are susceptible to various failures, such as machine preemptions or hardware issues. Elasticity allows a training job to adapt to changes in the number of available accelerators without crashing. It typically involves:

1. **Training State Management:** Regularly snapshotting the training state (model params, optimizer state, data iterator state).
2. **Failure Detection:** The Pathways Resource Manager detecting when workers join or leave.
3. **Failure Propagation:** The Pathways runtime propagating the error to the JAX client.
4. **Training Reconfiguration:** Adapting the training computation distribution to the current set of healthy workers.
5. **Resumption:** Continuing training from the last valid snapshot with the new configuration.

The `pathwaysutils.elastic` primitives provide building blocks to integrate this logic into JAX training loops run using Pathways' `Proxy` JAX backend.

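To make the flow above concrete, here is a minimal, hypothetical sketch of how these pieces fit into a JAX training loop. The `HypotheticalElasticManager` class and its `snapshot`/`restore`/`reconfigure` methods are illustrative stand-ins, not the actual `pathwaysutils.elastic.manager` API; see `manager.py` and the MaxText elastic training script referenced below for the real integration.

```python
# Hypothetical sketch only: the elastic-manager interface below is illustrative,
# not the real pathwaysutils.elastic API.
import jax
import jax.numpy as jnp


class HypotheticalElasticManager:
    """Stand-in for the elasticity building blocks described above."""

    def __init__(self):
        self._snapshot = None

    def snapshot(self, step, params, opt_state):
        # (1) Training state management: keep the latest known-good state on host.
        self._snapshot = (step, jax.device_get(params), jax.device_get(opt_state))

    def restore(self):
        # (5) Resumption: hand back the last valid snapshot.
        return self._snapshot

    def reconfigure(self):
        # (2)-(4) Detection, propagation, reconfiguration: in a real run the
        # Pathways resource manager reports the healthy slices and the device
        # mesh is rebuilt here. This placeholder just re-queries the devices.
        return jax.devices()


def train(num_steps=10):
    manager = HypotheticalElasticManager()
    params, opt_state = jnp.zeros((8, 8)), jnp.zeros((8, 8))
    step = 0
    while step < num_steps:
        try:
            params = params + 1.0  # placeholder for a real training step
            manager.snapshot(step, params, opt_state)
            step += 1
        except jax.errors.JaxRuntimeError:
            # This sketch assumes failures propagated to the JAX client arrive
            # as runtime errors; reconfigure to the healthy slices and resume
            # from the last snapshot.
            devices = manager.reconfigure()
            step, params, opt_state = manager.restore()
            print(f"Reconfigured onto {len(devices)} devices; resuming at step {step}")
    return params


if __name__ == "__main__":
    train()
```
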
## Prerequisites

* A [Pathways compatible GKE cluster](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster) with TPU and CPU nodepools.
* `kubectl` configured to interact with your cluster.
* Access to a container image containing JAX, your model code (e.g., MaxText), and the `pathwaysutils` library with elasticity features integrated.

## Elastic MaxText Training with Pathways on GKE

This example demonstrates running an elastic MaxText job on 3 x v5e-32 slices using Pathways. See the [PathwaysJob docs](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro#pathwaysjob_api) for more details about the various attributes set in the YAML below.

### 1. Elastic PathwaysJob Definition (`pathwaysjob-elastic.yaml`)

```yaml
apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-<USER>
spec:
  maxRestarts: 0
  workers:
  - type: ct5lp-hightpu-4t
    topology: 4x8
    numSlices: 3
    maxSliceRestarts: 2
  pathwaysDir: "gs://<BUCKET>" # Pre-create this bucket.
  controller:
    deploymentMode: default
    elasticSlices: 1
    template:
      spec:
        containers:
        - name: main
          image: <MAXTEXT_IMAGE>
          imagePullPolicy: Always
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.elastic_train MaxText/configs/base.yml base_output_directory=gs://<BUCKET> per_device_batch_size=4 enable_checkpointing=false remat_policy=full global_parameter_scale=8 steps=50 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True run_name=pathways-<USER> enable_pathways_goodput=True
```

The MaxText elastic training [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/elastic_train.py) invoked by the `main` container above is integrated with the `pathwaysutils.elastic` primitives.

### 2. Running the Elastic Training Loop and Simulating Hardware Failures

The following bash script demonstrates launching the above elastic MaxText job with Pathways, monitoring its progress, simulating a worker failure by issuing a `SIGILL` to a Pathways worker pod, and observing the recovery. Please set the variables marked as `<>` below before executing the script. At the end of the script, we verify that elasticity worked as expected.

```bash
#!/bin/bash
WORKING_DIR=</LOCAL/DIRECTORY/PATH>
USER_LABEL_SELECTOR="<USER>"
LOG_DIR="${WORKING_DIR}/logs"
JOB_DEFINITION_FILE="${WORKING_DIR}/pathwaysjob-elastic.yaml" # Copy the above yaml into this file

mkdir -p ${LOG_DIR}

run_id=$(date +"%s")
echo "Running Elastic MaxText with Run ID: $run_id"

# 1. Launch the PathwaysJob
kubectl apply -f "$JOB_DEFINITION_FILE"
if [ $? -ne 0 ]; then
  echo "Error: Failed to apply job definition."
  exit 1
fi

# 2. Monitor the PathwaysJob
echo "Waiting for pods to start..."
head_pod=""
for i in $(seq 1 10)
do
  head_pod=$(kubectl get pods | grep "$USER_LABEL_SELECTOR" | grep 'head' | grep 'Running' | awk '{print $1}' | head -n 1)
  if [ -n "$head_pod" ]; then
    echo "Found head pod: $head_pod"
    break
  fi
  echo "Head pod not found yet, retrying..."
  sleep 10s
done

if [ -z "$head_pod" ]; then
  echo "Error: Could not find running head pod after multiple attempts. Cleaning up..."
  kubectl delete -f "$JOB_DEFINITION_FILE"
  exit 1
fi

log_file="${LOG_DIR}/logs_${run_id}.log"
echo "Streaming logs from $head_pod to $log_file"
kubectl logs -f "$head_pod" >> "${log_file}" &
logs_pid=$!
echo "Waiting for job to start making progress..."
sleep 90s

# 3. Simulate Failure: Evict a Worker Pod
echo "Randomly select a worker pod to disrupt..."
read -r node_name pod_name <<<$(kubectl get pods -o wide | grep "$USER_LABEL_SELECTOR" | grep 'worker-[0-9]-0-' | grep 'Running' | shuf | head -n 1 | awk '{print $7, $1}')

if [ -z "$pod_name" ] || [ -z "$node_name" ]; then
  echo "Warning: Could not find a running worker pod to disrupt. Skipping disruption."
else
  echo "Attempting to cordon '$node_name' and kill pod '$pod_name'..."
  kubectl cordon "$node_name"
  kubectl exec -it "$pod_name" -c pathways-worker -- /bin/sh -c "kill -s SIGILL 1"
  echo "Node cordoned. Waiting briefly for training to reconfigure to N-1 slices..."
  sleep 90s

  # 4. Allow Recovery: Uncordon the Node
  echo "Uncordoning node '$node_name' to allow scheduling again."
  kubectl uncordon "$node_name"
fi

# 5. Wait for Training to resume on all slices
sleep 90s

# 6. Terminate the Job and Cleanup
echo "Terminating Run ID $run_id"
kubectl delete -f "$JOB_DEFINITION_FILE"
# Ensure log streaming process is killed
kill "$logs_pid" 2>/dev/null
echo "Completed Run ID $run_id."

# 7. Verify by printing steps where training reconfigured from N to N-1 slices and later back to N slices
# Expect output like:
# Step: 5, Old Slice Count: 3, New Slice Count: 2 (3 -> 2 slices)
# Step: 17, Old Slice Count: 2, New Slice Count: 3 (2 -> 3 slices)
awk '
/step=/ && /elastic_manager\.good_slice_count=/ {
  split($0, fields, " ")
  step = ""
  good_slice_count = ""
  for (i in fields) {
    split(fields[i], kv, "=")
    if (kv[1] == "step") {
      step = kv[2]
    } else if (kv[1] == "elastic_manager.good_slice_count") {
      good_slice_count = kv[2]
    }
  }
  if (prev_good_slice_count != "" && prev_good_slice_count != good_slice_count) {
    print "Step: " step ", Old Slice Count: " prev_good_slice_count ", New Slice Count: " good_slice_count
  }
  prev_step = step
  prev_good_slice_count = good_slice_count
}
' "$log_file"
```
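
If you prefer Python over `awk` for step 7, the sketch below performs the same verification. It assumes only what the `awk` program above assumes: that the streamed log lines contain `step=<value>` and `elastic_manager.good_slice_count=<value>` tokens. The file argument is whatever `$log_file` pointed to.

```python
# Print the steps at which the good-slice count changed, mirroring the awk script above.
import re
import sys

prev_count = None
with open(sys.argv[1]) as log:
    for line in log:
        step = re.search(r"\bstep=(\S+)", line)
        count = re.search(r"elastic_manager\.good_slice_count=(\S+)", line)
        if not (step and count):
            continue
        if prev_count is not None and prev_count != count.group(1):
            print(f"Step: {step.group(1)}, "
                  f"Old Slice Count: {prev_count}, "
                  f"New Slice Count: {count.group(1)}")
        prev_count = count.group(1)
```

For example: `python3 verify_elasticity.py "${LOG_DIR}/logs_${run_id}.log"` (the script name here is arbitrary).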
