`pathwaysutils/elastic/README.md`
# Elastic Training with Pathways
This document demonstrates how to leverage the elasticity primitives within `pathwaysutils.elastic` to create a resilient JAX training loop that can handle hardware failures gracefully. We illustrate this using an example based on the MaxText training loop running on TPUs provisioned by GKE via the `PathwaysJob` API.
## Overview
Distributed training jobs, especially long-running ones, are susceptible to various failures, such as machine preemptions and hardware issues. Elasticity allows a training job to adapt to changes in the number of available accelerators without crashing. It typically involves:

1. **Training State Management**: Regularly snapshotting the training state (model params, optimizer state, data iterator state).
1. **Failure Detection**: The Pathways Resource Manager detects when workers join or leave.
1. **Failure Propagation**: The Pathways runtime propagates the error to the JAX client.
1. **Training Reconfiguration**: Adapting the training computation distribution to the current set of healthy workers.
1. **Resumption**: Continuing training from the last valid snapshot with the new configuration.

The `pathwaysutils.elastic` primitives provide building blocks to integrate this logic into JAX training loops run using the Pathways `Proxy` JAX backend.
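
To make the flow concrete, below is a minimal, self-contained sketch of that loop structure. It is not the `pathwaysutils.elastic` API: the toy `train_step`, the snapshot tuple, and the exception-driven rollback are illustrative stand-ins for what the elastic manager and the MaxText integration provide around a real jitted, multi-slice train step.

```python
# Conceptual sketch only -- it illustrates the five steps above, not the actual
# pathwaysutils.elastic API. See pathwaysutils/elastic/manager.py and MaxText's
# elastic_train.py for the real primitives.
import jax
import jax.numpy as jnp


def train_step(params, step):
    """Stand-in for a jitted MaxText train step."""
    new_params = jax.tree_util.tree_map(lambda p: p - 1e-3, params)
    return new_params, step + 1


params = {"w": jnp.zeros((8, 8))}          # toy training state
step = 0
snapshot = (jax.device_get(params), step)  # last known-good snapshot
SNAPSHOT_EVERY = 50

while step < 1000:
    try:
        if step % SNAPSHOT_EVERY == 0:
            # 1. Training state management: keep a recent snapshot that can be
            #    resharded onto whichever slices remain healthy.
            snapshot = (jax.device_get(params), step)
        params, step = train_step(params, step)
    except jax.errors.JaxRuntimeError:
        # 2./3. Failure detection and propagation: the Pathways runtime surfaces
        #       lost workers to the JAX client as a runtime error.
        # 4. Reconfiguration: the real primitives wait for healthy slices and
        #    rebuild meshes/shardings here.
        # 5. Resumption: roll back to the last good snapshot and continue.
        params, step = jax.device_put(snapshot[0]), snapshot[1]
```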
## Prerequisites
* A [Pathways compatible GKE cluster](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster) with TPU and CPU nodepools.
* `kubectl` configured to interact with your cluster.
* Access to a container image containing JAX, your model code (e.g., MaxText), and the `pathwaysutils` package with elasticity features integrated.
## Elastic MaxText Training with Pathways on GKE
This example demonstrates running an elastic MaxText job on 3 x v5e-32 slices using Pathways. See the [PathwaysJob docs](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro#pathwaysjob_api) for more details about the various attributes set in the YAML below.

The MaxText elastic training [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/elastic_train.py) invoked by the `main` container above is integrated with `pathwaysutils.elastic` primitives.
### 2. Running the Elastic Training Loop and Simulating Hardware Failures
The following bash script demonstrates launching the above elastic MaxText job with Pathways, monitoring its progress, simulating a hardware failure by issuing a `kubectl drain` to a randomly selected TPU node, and observing the recovery. Please set the variables marked as `<>` below before executing the script. At the end of the script, we verify that elasticity worked as expected.
```bash
#!/bin/bash
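# Set WORKING_DIR and USER_LABEL_SELECTOR (the values marked with <>) before running.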
WORKING_DIR=</LOCAL/DIRECTORY/PATH>
USER_LABEL_SELECTOR="<USER>"
LOG_DIR="${WORKING_DIR}/logs"
RUN_ID=pathways-${USER_LABEL_SELECTOR}
LOG_FILE="${LOG_DIR}/logs_${RUN_ID}.log"
JOB_DEFINITION_FILE="${WORKING_DIR}/pathwaysjob-elastic.yaml" # Copy the above yaml into this file
mkdir -p ${LOG_DIR}
echo "Running Elastic MaxText with Run ID: ${RUN_ID}"
# 1. Launch the PathwaysJob
kubectl apply -f "$JOB_DEFINITION_FILE"
# 2. Monitor the PathwaysJob
echo "Waiting for pods to start..."
head_pod=""
for i in $(seq 1 10)
do
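  # Look for a Running Pathways head pod whose name matches the user label selector.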
  head_pod=$(kubectl get pods -o=name --field-selector='status.phase==Running' | grep "$USER_LABEL_SELECTOR" | grep 'head' | head -n 1)
  if [ -n "$head_pod" ]; then
echo "Found head pod: $head_pod"
89
101
    break
  fi
  sleep 10
done
if [ -z "$head_pod" ]; then
  echo "Error: Could not find running head pod after multiple attempts. Cleaning up..." 1>&2
  kubectl delete -f "$JOB_DEFINITION_FILE"
  exit 1
fi
echo "Streaming logs from $head_pod to ${LOG_FILE}"
kubectl logs -f "$head_pod" >> "${LOG_FILE}" &
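# $! holds the PID of the background `kubectl logs` process started above.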
logs_pid=$!
echo "Waiting for job to start making progress..."