
Commit a4f17a3

Create README.md for running a JAX workload integrated with elasticity primitives
1 parent 42640da commit a4f17a3

File tree

1 file changed (+171, -0)


pathwaysutils/elastic/README.md

# Elastic Training with Pathways

This document demonstrates how to leverage the elasticity primitives within `pathwaysutils.elastic` to create a resilient JAX training loop that can handle hardware failures gracefully. We illustrate this using an example based on the MaxText training loop running on TPUs provisioned by GKE via the `PathwaysJob` API.

## Overview

Distributed training jobs, especially long-running ones, are susceptible to various failures, such as machine preemptions and hardware issues. Elasticity allows a training job to adapt to changes in the number of available accelerators without crashing. It typically involves:

1. **Training State Management**: Regularly snapshotting the training state (model params, optimizer state, data iterator state).
1. **Failure Detection**: The Pathways Resource Manager detects when workers join or leave.
1. **Failure Propagation**: The Pathways runtime propagates the error to the JAX client.
1. **Training Reconfiguration**: Adapting the training computation distribution to the current set of healthy workers.
1. **Resumption**: Continuing training from the last valid snapshot with the new configuration.

The `pathwaysutils.elastic` primitives provide elasticity building blocks to use within your JAX training loop when using Pathways' `Proxy` JAX backend.

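To make the pattern concrete, here is a minimal, self-contained sketch of the loop structure these primitives enable. It is **not** the `pathwaysutils.elastic` API: the helper `build_mesh`, the toy `train_step`, and constants such as `snapshot_period` are illustrative assumptions, and the concrete exception raised when a slice is lost depends on the runtime. See the MaxText elastic training script referenced later in this document for the real integration.

```python
# Illustrative sketch only; not the pathwaysutils.elastic API.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def build_mesh():
    # Training reconfiguration: rebuild the mesh from the devices that are
    # currently healthy. With Pathways, lost slices drop out of jax.devices().
    return Mesh(jax.devices(), axis_names=("data",))


@jax.jit
def train_step(params, batch):
    # Toy SGD step standing in for the real model update.
    grads = jax.grad(lambda p, x: jnp.mean((x @ p) ** 2))(params, batch)
    return params - 1e-3 * grads


params = jnp.ones((16, 1))
snapshot = (0, jax.device_get(params))  # Training state management: (step, host copy).
snapshot_period, total_steps = 10, 100
step = 0

while step < total_steps:
    try:
        mesh = build_mesh()
        data_sharding = NamedSharding(mesh, P("data"))    # batch sharded across devices
        replicated = NamedSharding(mesh, P())             # params replicated on the new mesh
        # Resumption: restart from the last valid snapshot under the new configuration.
        step, params = snapshot[0], jax.device_put(snapshot[1], replicated)
        while step < total_steps:
            batch = jax.device_put(jnp.ones((8 * mesh.devices.size, 16)), data_sharding)
            params = train_step(params, batch)
            step += 1
            if step % snapshot_period == 0:
                # Keep a host-side copy that survives losing a slice.
                snapshot = (step, jax.device_get(params))
    except Exception:
        # Failure propagation: the runtime surfaces the lost slice as an error in
        # the JAX client; fall back to the snapshot, reconfigure, and retry.
        continue
```

In the real MaxText integration this bookkeeping (snapshots, slice-count tracking, reshard retries) is handled by an elastic manager object, whose counters such as `elastic_manager.good_slice_count` appear in the training logs parsed at the end of this document.
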
## Prerequisites

* A [Pathways compatible GKE cluster](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster) with TPU and CPU nodepools.
* `kubectl` configured to interact with your cluster.
* Access to a container image containing JAX, your model code (e.g., MaxText), and the `pathwaysutils` package with elasticity features integrated (a quick sanity check follows this list).

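Optionally, you can sanity-check the last prerequisite before launching a full job by running a tiny JAX program from the `main` container of a PathwaysJob to confirm it reaches the Pathways `Proxy` backend. This is a hedged sketch: `pathwaysutils.initialize()` is assumed to be available in recent `pathwaysutils` releases (older releases configure the backend as a side effect of the import).

```python
# Optional sanity check (assumption-laden sketch): run from the PathwaysJob's
# main container, where the Proxy backend configuration is injected.
import jax
import pathwaysutils

# Assumption: recent pathwaysutils releases expose initialize(); older releases
# set up the JAX "proxy" backend when the module is imported.
pathwaysutils.initialize()

# Should report one device per TPU chip across all attached slices.
print(jax.device_count(), jax.devices()[:4])
```
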
## Elastic MaxText Training with Pathways on GKE

This example demonstrates running an elastic MaxText job on 3 x v5e-32 slices using Pathways. See the [PathwaysJob docs](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro#pathwaysjob_api) for more details about the various attributes set in the YAML below.

### 1. Elastic PathwaysJob Definition (`pathwaysjob-elastic.yaml`)

Please set the variables marked with `<>` below before applying the manifest.

```yaml
apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-<USER>
spec:
  maxRestarts: 0
  workers:
  - type: ct5lp-hightpu-4t
    topology: 4x8
    numSlices: 3
    maxSliceRestarts: 2
  pathwaysDir: "gs://<BUCKET>" # Pre-create this bucket.
  controller:
    deploymentMode: default
    elasticSlices: 1
    template:
      spec:
        containers:
        - name: main
          image: <MAXTEXT_IMAGE>
          imagePullPolicy: Always
          command:
          - bash
          - -c
          - >
            python3 -m MaxText.elastic_train MaxText/configs/base.yml
            base_output_directory=gs://<BUCKET>
            per_device_batch_size=4
            enable_checkpointing=false
            remat_policy=full
            global_parameter_scale=8
            steps=50
            max_target_length=2048
            use_iota_embed=true
            reuse_example_batch=1
            dataset_type=synthetic
            attention=flash
            gcs_metrics=True
            enable_pathways_goodput=True
            run_name=pathways-<USER>
```

The MaxText elastic training [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/elastic_train.py) invoked by the `main` container above is integrated with the `pathwaysutils.elastic` primitives.

### 2. Running the Elastic Training Loop and Simulating Hardware Failures

The following bash script launches the elastic MaxText job defined above with Pathways, monitors its progress, simulates a hardware failure by cordoning a randomly selected TPU node and killing its Pathways worker, and observes the recovery. Please set the variables marked with `<>` below before executing the script. At the end of the script, we verify that elasticity worked as expected.

```bash
#!/bin/bash
WORKING_DIR=</LOCAL/DIRECTORY/PATH>
USER_LABEL_SELECTOR="<USER>"
LOG_DIR="${WORKING_DIR}/logs"
RUN_ID=pathways-${USER_LABEL_SELECTOR}
LOG_FILE="${LOG_DIR}/logs_${RUN_ID}.log"
JOB_DEFINITION_FILE="${WORKING_DIR}/pathwaysjob-elastic.yaml" # Copy the above YAML into this file

mkdir -p "${LOG_DIR}"

echo "Running Elastic MaxText with Run ID: ${RUN_ID}"

# 1. Launch the PathwaysJob
kubectl apply -f "$JOB_DEFINITION_FILE"

# 2. Monitor the PathwaysJob
echo "Waiting for pods to start..."
head_pod=""
for i in $(seq 1 10)
do
  head_pod=$(kubectl get pods -o=name --field-selector='status.phase==Running' | grep "$USER_LABEL_SELECTOR" | grep 'head' | head -n 1)
  if [ -n "$head_pod" ]; then
    echo "Found head pod: $head_pod"
    break
  fi
  echo "Head pod not found yet, retrying..."
  sleep 10s
done

if [ -z "$head_pod" ]; then
  echo "Error: Could not find running head pod after multiple attempts. Cleaning up..." 1>&2
  kubectl delete -f "$JOB_DEFINITION_FILE"
  exit 1
fi

echo "Streaming logs from $head_pod to ${LOG_FILE}"
kubectl logs -f "$head_pod" >> "${LOG_FILE}" &
logs_pid=$!
echo "Waiting for job to start making progress..."
sleep 90s

# 3. Simulate Failure: Evict a Worker Pod
echo "Randomly selecting a worker pod to disrupt..."
read -r node_name pod_name <<<$(kubectl get pods -o wide --field-selector='status.phase==Running' | grep "$USER_LABEL_SELECTOR" | grep worker | shuf | head -n 1 | awk '{print $7, $1}')

if [ -z "$pod_name" ] || [ -z "$node_name" ]; then
  echo "Warning: Could not find a running worker pod to disrupt. Skipping disruption."
else
  echo "Attempting to cordon '$node_name' and kill pod '$pod_name'..."
  kubectl cordon "$node_name"
  kubectl exec -it "$pod_name" -c pathways-worker -- /bin/sh -c "kill -s SIGILL 1"
  echo "Node cordoned. Waiting briefly for training to reconfigure to N-1 slices..."
  sleep 90s

  # 4. Allow Recovery: Uncordon the Node
  echo "Uncordoning node '$node_name' to allow scheduling again."
  kubectl uncordon "$node_name"
fi

# 5. Wait for training to resume on all slices
sleep 90s

# 6. Terminate the Job and Clean Up
echo "Terminating Run ID ${RUN_ID}"
kubectl delete -f "$JOB_DEFINITION_FILE"
# Ensure the log streaming process is killed
kill "$logs_pid" 2>/dev/null
echo "Completed Run ID ${RUN_ID}."

# 7. Verify by printing the steps where training reconfigured from N to N-1 slices and later back to N slices
# Expect output like:
# Step: 5, Old Slice Count: 3, New Slice Count: 2   (3 -> 2 slices)
# Step: 17, Old Slice Count: 2, New Slice Count: 3  (2 -> 3 slices)
awk '
  /step=/ && /elastic_manager\.elastic_down_event_count=/ {
    split($0, fields, " ")
    step = ""
    good_slice_count = ""
    for (i in fields) {
      split(fields[i], kv, "=")
      if (kv[1] == "step") {
        step = kv[2]
      } else if (kv[1] == "elastic_manager.good_slice_count") {
        good_slice_count = kv[2]
      }
    }
    if (prev_good_slice_count != "" && prev_good_slice_count != good_slice_count) {
      print "Step: " step ", Old Slice Count: " prev_good_slice_count ", New Slice Count: " good_slice_count
    }
    prev_step = step
    prev_good_slice_count = good_slice_count
  }
' "${LOG_FILE}"
```
