
Commit 3519afc

Update README.md with review comments 1
1 parent b32730f

1 file changed: pathwaysutils/elastic/README.md (+40 -29)
@@ -1,30 +1,31 @@
 # Elastic Training with Pathways
 
-This document demonstrates how to leverage the elasticity primitives within `manager.py` to create resilient JAX training loop that can handle hardware failures gracefully. We illustrate this using an example based on the MaxText training loop running on TPUs provisioned by GKE via `PathwaysJob` API.
+This document demonstrates how to leverage the elasticity primitives within `pathwaysutils.elastic` to create a resilient JAX training loop that can handle hardware failures gracefully. We illustrate this using an example based on the MaxText training loop running on TPUs provisioned by GKE via the `PathwaysJob` API.
 
 ## Overview
 
-Distributed training jobs, especially long-running ones, are susceptible to various failures, such as machine preemptions or hardware issues. Elasticity allows a training job to adapt to changes in the number of available accelerators without crashing. It typically involves:
+Distributed training jobs, especially long-running ones, are susceptible to various failures, such as machine preemptions and hardware issues. Elasticity allows a training job to adapt to changes in the number of available accelerators without crashing. It typically involves:
 
-1. **Training State Management:** Regularly snapshotting the training state (model params, optimizer state, data iterator state).
-2. **Failure Detection:** Pathways Resource Manager detecting when workers join or leave.
-3. **Failure Propogation:** Pathways runtime propogates the error to JAX client.
-4. **Training Reconfiguration:** Adapting the training computation distribution to the current set of healthy workers.
-5. **Resumption:** Continuing training from the last valid snapshot with the new configuration.
+1. **Training State Management**: Regularly snapshotting the training state (model params, optimizer state, data iterator state).
+1. **Failure Detection**: Pathways Resource Manager detects when workers join or leave.
+1. **Failure Propagation**: Pathways runtime propagates the error to the JAX client.
+1. **Training Reconfiguration**: Adapting the training computation distribution to the current set of healthy workers.
+1. **Resumption**: Continuing training from the last valid snapshot with the new configuration.
 
 The `pathwaysutils.elastic` primitives provide building blocks to integrate this logic into JAX training loops run using the Pathways `Proxy` JAX backend.
 
 ## Prerequisites
 
 * A [Pathways compatible GKE cluster](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster) with TPU and CPU nodepools.
 * `kubectl` configured to interact with your cluster.
-* Access to a container image containing JAX, your model code (e.g., MaxText), and the `pathwaysutils` library with elasticity features integrated.
+* Access to a container image containing JAX, your model code (e.g., MaxText), and the `pathwaysutils` package with elasticity features integrated.
 
 ## Elastic MaxText Training with Pathways on GKE
 
-This example demonstrates running an elastic MaxText job on 3 x v5e-32 slices using Pathways. See the [PathwaysJob docs](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro#pathwaysjob_api) for more details about the various attributes set in the YAML below.
+This example demonstrates running an elastic MaxText job on 3 x v5e-32 slices using Pathways. See the [PathwaysJob docs](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro#pathwaysjob_api) for more details about the various attributes set in the YAML below.
 
 ### 1. Elastic PathwaysJob Definition (`pathwaysjob-elastic.yaml`)
+Please set the variables marked with `<>` below before executing the script.
 ```yaml
 apiVersion: pathways-job.pathways.domain/v1
 kind: PathwaysJob
@@ -34,7 +35,7 @@ spec:
   maxRestarts: 0
   workers:
   - type: ct5lp-hightpu-4t
-    topology: 4x8
+    topology: 4x8
     numSlices: 3
     maxSliceRestarts: 2
   pathwaysDir: "gs://<BUCKET>" # Pre-create this bucket.
@@ -50,40 +51,51 @@ spec:
         command:
         - bash
        - -c
-        - |
-          python3 -m MaxText.elastic_train MaxText/configs/base.yml base_output_directory=gs://<BUCKET> per_device_batch_size=4 enable_checkpointing=false remat_policy=full global_parameter_scale=8 steps=50 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True run_name=pathways-<USER> enable_pathways_goodput=True
+        - >
+          python3 -m MaxText.elastic_train MaxText/configs/base.yml
+          base_output_directory=gs://<BUCKET>
+          per_device_batch_size=4
+          enable_checkpointing=false
+          remat_policy=full
+          global_parameter_scale=8
+          steps=50
+          max_target_length=2048
+          use_iota_embed=true
+          reuse_example_batch=1
+          dataset_type=synthetic
+          attention=flash
+          gcs_metrics=True
+          enable_pathways_goodput=True
+          run_name=pathways-<USER>
 ```
-The MaxText elastic training [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/elastic_train.py) invoked by the `main` container above is integrated with `pathwaysutils.elastic` primitives.
+The MaxText elastic training [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/elastic_train.py) invoked by the `main` container above is integrated with `pathwaysutils.elastic` primitives.
 
 ### 2. Running the Elastic Training Loop and Simulating Hardware Failures
 
-The following bash script demonstrates launching the above elastic maxtext job with Pathways, monitoring its progress, simulating a worker failure by issuing a `SIGILL` to a Pathways worker pod, and observing the recovery. Please set the variables marked as `<>` below before executing the script. At the end of the script, we verify elasticity worked as expected.
+The following bash script demonstrates launching the above elastic MaxText job with Pathways, monitoring its progress, simulating a hardware failure by issuing a `kubectl drain` to a randomly selected TPU node, and observing the recovery. Please set the variables marked with `<>` below before executing the script. At the end of the script, we verify that elasticity worked as expected.
 
 ```bash
 #!/bin/bash
 WORKING_DIR=</LOCAL/DIRECTORY/PATH>
 USER_LABEL_SELECTOR="<USER>"
 LOG_DIR="${WORKING_DIR}/logs"
+RUN_ID=pathways-${USER_LABEL_SELECTOR}
+LOG_FILE="${LOG_DIR}/logs_${RUN_ID}.log"
 JOB_DEFINITION_FILE="${WORKING_DIR}/pathwaysjob-elastic.yaml" # Copy the above yaml into this file
 
 mkdir -p ${LOG_DIR}
 
-run_id=$(date +"%s")
-echo "Running Elastic MaxText with Run ID: $run_id"
+echo "Running Elastic MaxText with Run ID: ${RUN_ID}"
 
 # 1. Launch the PathwaysJob
 kubectl apply -f "$JOB_DEFINITION_FILE"
-if [ $? -ne 0 ]; then
-  echo "Error: Failed to apply job definition."
-  exit 1
-fi
 
 # 2. Monitor the PathwaysJob
 echo "Waiting for pods to start..."
 head_pod=""
 for i in $(seq 1 10)
 do
-  head_pod=$(kubectl get pods | grep "$USER_LABEL_SELECTOR" | grep 'head' | grep 'Running' | awk '{print $1}' | head -n 1)
+  head_pod=$(kubectl get pods -o=name --field-selector='status.phase==Running' | grep "$USER_LABEL_SELECTOR" | grep 'head' | head -n 1)
   if [ -n "$head_pod" ]; then
     echo "Found head pod: $head_pod"
     break
@@ -93,21 +105,20 @@ do
 done
 
 if [ -z "$head_pod" ]; then
-  echo "Error: Could not find running head pod after multiple attempts. Cleaning up..."
+  echo "Error: Could not find running head pod after multiple attempts. Cleaning up..." 1>&2
   kubectl delete -f "$JOB_DEFINITION_FILE"
   exit 1
 fi
 
-log_file="${LOG_DIR}/logs_${run_id}.log"
-echo "Streaming logs from $head_pod to $log_file"
-kubectl logs -f "$head_pod" >> "${log_file}" &
+echo "Streaming logs from $head_pod to ${LOG_FILE}"
+kubectl logs -f "$head_pod" >> "${LOG_FILE}" &
 logs_pid=$!
 echo "Waiting for job to start making progress..."
 sleep 90s
 
 # 3. Simulate Failure: Evict a Worker Pod
 echo "Randomly select a worker pod to disrupt..."
-read -r node_name pod_name <<<$(kubectl get pods -o wide | grep "$USER_LABEL_SELECTOR" | grep 'worker-[0-9]-0-' | grep 'Running' | shuf | head -n 1 | awk '{print $7, $1}')
+read -r node_name pod_name <<<$(kubectl get pods -o wide --field-selector='status.phase==Running' | grep "$USER_LABEL_SELECTOR" | grep worker | shuf | head -n 1 | awk '{print $7, $1}')
 
 if [ -z "$pod_name" ] || [ -z "$node_name" ]; then
   echo "Warning: Could not find a running worker pod to disrupt. Skipping disruption."
@@ -127,11 +138,11 @@ fi
 sleep 90s
 
 # 6. Terminate the Job and Cleanup
-echo "Terminating Run ID $run_id"
+echo "Terminating Run ID ${RUN_ID}"
 kubectl delete -f "$JOB_DEFINITION_FILE"
 # Ensure log streaming process is killed
 kill "$logs_pid" 2>/dev/null
-echo "Completed Run ID $run_id."
+echo "Completed Run ID ${RUN_ID}."
 
 # 7. Verify by printing steps where training reconfigured from N to N-1 slices and later back to N slices
 # Expect output like:
@@ -156,5 +167,5 @@ awk '
   prev_step = step
   prev_good_slice_count = good_slice_count
 }
-' "$log_file"
+' "${LOG_FILE}"
 ```
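The Overview section of the README above describes a snapshot / detect / reconfigure / resume cycle. Below is a minimal, self-contained sketch of that control flow only. `ElasticManager`, `maybe_snapshot`, and `reconfigure` are hypothetical names introduced for illustration, not the actual `pathwaysutils.elastic` API; see the MaxText elastic training script linked above for the real integration.

```python
"""Illustrative elastic-training control flow (hypothetical helper names)."""
import copy
import random


class HardwareFailure(Exception):
    """Stand-in for the error surfaced to the client when a slice is lost."""


class ElasticManager:
    """Hypothetical helper tracking healthy slices and the latest snapshot."""

    def __init__(self, total_slices: int, snapshot_every: int):
        self.good_slices = total_slices
        self.snapshot_every = snapshot_every
        self._snapshot = None

    def maybe_snapshot(self, step: int, state: dict) -> None:
        # Step 1 (Training State Management): keep a recent copy of params,
        # optimizer state, and data-iterator position.
        if step % self.snapshot_every == 0:
            self._snapshot = (step, copy.deepcopy(state))

    def reconfigure(self) -> tuple[int, dict]:
        # Steps 2-5: a failure was detected and propagated; shrink to the
        # currently healthy slices and resume from the last snapshot.
        self.good_slices = max(1, self.good_slices - 1)
        step, state = self._snapshot
        return step, copy.deepcopy(state)


def train_step(state: dict) -> dict:
    # Placeholder for the jitted MaxText train step; fails occasionally to
    # mimic a lost TPU slice.
    if random.random() < 0.05:
        raise HardwareFailure("slice became unhealthy")
    state["params"] += 1  # pretend parameter update
    return state


def train(num_steps: int = 50) -> dict:
    manager = ElasticManager(total_slices=3, snapshot_every=5)
    state = {"params": 0}
    step = 0
    while step < num_steps:
        try:
            manager.maybe_snapshot(step, state)
            state = train_step(state)
            step += 1
        except HardwareFailure:
            # Roll back to the last snapshot and continue on fewer slices.
            step, state = manager.reconfigure()
            print(f"reconfigured to {manager.good_slices} slices at step {step}")
    return state


if __name__ == "__main__":
    print(train())
```

In the real setup, the job also scales back up to the full slice count once the disrupted capacity returns, which is what the verification step at the end of the bash script checks for.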

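The hunks above elide the script's actual disruption and recovery steps. Since the section text states that the failure is simulated with `kubectl drain` on the selected node, those steps presumably look roughly like the following; this is an assumption for illustration, not the exact commands from the MaxText example.

```bash
# Hypothetical sketch of the elided disruption/recovery steps, assuming the
# node selected in step 3 above is stored in "$node_name".

# Simulate a hardware failure by draining the TPU node hosting the worker.
kubectl drain "$node_name" --ignore-daemonsets --delete-emptydir-data --force

# Give the job time to detect the failure and reconfigure to fewer slices.
sleep 90s

# Make the node schedulable again so the job can recover to the full slice count.
kubectl uncordon "$node_name"
```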