# Elastic Training with Pathways

This document demonstrates how to use the elasticity primitives in `manager.py` to build a resilient JAX training loop that handles hardware failures gracefully. We illustrate this with an example based on the MaxText training loop running on TPUs provisioned by GKE via the `PathwaysJob` API.

## Overview

Distributed training jobs, especially long-running ones, are susceptible to various failures, such as machine preemptions or hardware issues. Elasticity allows a training job to adapt to changes in the number of available accelerators without crashing. It typically involves:

1. **Training State Management:** Regularly snapshotting the training state (model parameters, optimizer state, data iterator state).
2. **Failure Detection:** The Pathways resource manager detecting when workers join or leave.
3. **Failure Propagation:** The Pathways runtime propagating the error to the JAX client.
4. **Training Reconfiguration:** Adapting the distribution of the training computation to the current set of healthy workers.
5. **Resumption:** Continuing training from the last valid snapshot with the new configuration.

The `pathwaysutils.elastic` primitives provide building blocks to integrate this logic into JAX training loops run with the Pathways `Proxy` JAX backend.
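
Taken together, these steps form a snapshot/reconfigure/resume loop. The Python sketch below illustrates the control flow only; `SliceFailure`, `run_elastic_training`, and the dictionary-based state are illustrative stand-ins, not part of the `pathwaysutils.elastic` API.

```python
import copy

class SliceFailure(Exception):
    """Stand-in for the error the runtime surfaces when a slice is lost."""

def run_elastic_training(train_step, total_steps, initial_slices, snapshot_every=2):
    """Sketch of the snapshot -> detect -> reconfigure -> resume cycle."""
    state = {"step": 0}
    slices = initial_slices
    snapshot = copy.deepcopy(state)           # 1. snapshot the training state
    while state["step"] < total_steps:
        try:
            state = train_step(state, slices)
            if state["step"] % snapshot_every == 0:
                snapshot = copy.deepcopy(state)
        except SliceFailure:                  # 2./3. failure detected and propagated
            slices -= 1                       # 4. reconfigure to the healthy slices
            state = copy.deepcopy(snapshot)   # 5. resume from the last valid snapshot
    return state, slices
```

In the real integration the snapshot lives in host memory, reconfiguration rebuilds the device mesh, and the runtime (not an exception from user code) reports the lost slice; the loop structure is the same.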

## Prerequisites

* A [Pathways compatible GKE cluster](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster) with TPU and CPU node pools.
* `kubectl` configured to interact with your cluster.
* Access to a container image containing JAX, your model code (e.g., MaxText), and the `pathwaysutils` library with elasticity features integrated.

## Elastic MaxText Training with Pathways on GKE

This example demonstrates running an elastic MaxText job on 3 x v5e-32 slices using Pathways. See the [PathwaysJob docs](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro#pathwaysjob_api) for more details about the various attributes set in the YAML below.
### 1. Elastic PathwaysJob Definition (`pathwaysjob-elastic.yaml`)
```yaml
apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-<USER>
spec:
  maxRestarts: 0
  workers:
  - type: ct5lp-hightpu-4t
    topology: 4x8
    numSlices: 3
    maxSliceRestarts: 2
  pathwaysDir: "gs://<BUCKET>" # Pre-create this bucket.
  controller:
    deploymentMode: default
    elasticSlices: 1
    template:
      spec:
        containers:
        - name: main
          image: <MAXTEXT_IMAGE>
          imagePullPolicy: Always
          command:
          - bash
          - -c
          - |
            python3 -m MaxText.elastic_train MaxText/configs/base.yml base_output_directory=gs://<BUCKET> per_device_batch_size=4 enable_checkpointing=false remat_policy=full global_parameter_scale=8 steps=50 max_target_length=2048 use_iota_embed=true reuse_example_batch=1 dataset_type=synthetic attention=flash gcs_metrics=True run_name=pathways-<USER> enable_pathways_goodput=True
```
The MaxText elastic training [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/elastic_train.py) invoked by the `main` container above is integrated with the `pathwaysutils.elastic` primitives.

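Note that when the job reconfigures from N to N-1 slices, the global batch size scales down with the number of healthy devices, assuming the per-device batch size is held fixed across reconfigurations. A quick sketch of that arithmetic for the configuration above (3 x v5e-32 slices, `per_device_batch_size=4`); the function name is illustrative:

```python
PER_DEVICE_BATCH_SIZE = 4  # from the MaxText command line above
CHIPS_PER_SLICE = 32       # one v5e-32 slice has 32 chips

def global_batch_size(good_slice_count):
    """Global batch tracks the devices that are currently healthy."""
    return PER_DEVICE_BATCH_SIZE * CHIPS_PER_SLICE * good_slice_count

print(global_batch_size(3))  # 384 with all three slices healthy
print(global_batch_size(2))  # 256 after losing one slice
```
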
### 2. Running the Elastic Training Loop and Simulating Hardware Failures

The following bash script launches the elastic MaxText job with Pathways, monitors its progress, simulates a worker failure by sending `SIGILL` to a Pathways worker pod, and observes the recovery. Set the variables marked with `<>` below before executing the script. At the end, the script verifies that elasticity worked as expected.

```bash
#!/bin/bash
WORKING_DIR=</LOCAL/DIRECTORY/PATH>
USER_LABEL_SELECTOR="<USER>"
LOG_DIR="${WORKING_DIR}/logs"
JOB_DEFINITION_FILE="${WORKING_DIR}/pathwaysjob-elastic.yaml" # Copy the above YAML into this file

mkdir -p "${LOG_DIR}"

run_id=$(date +"%s")
echo "Running Elastic MaxText with Run ID: $run_id"

# 1. Launch the PathwaysJob
kubectl apply -f "$JOB_DEFINITION_FILE"
if [ $? -ne 0 ]; then
  echo "Error: Failed to apply job definition."
  exit 1
fi

# 2. Monitor the PathwaysJob
echo "Waiting for pods to start..."
head_pod=""
for i in $(seq 1 10)
do
  head_pod=$(kubectl get pods | grep "$USER_LABEL_SELECTOR" | grep 'head' | grep 'Running' | awk '{print $1}' | head -n 1)
  if [ -n "$head_pod" ]; then
    echo "Found head pod: $head_pod"
    break
  fi
  echo "Head pod not found yet, retrying..."
  sleep 10s
done

if [ -z "$head_pod" ]; then
  echo "Error: Could not find running head pod after multiple attempts. Cleaning up..."
  kubectl delete -f "$JOB_DEFINITION_FILE"
  exit 1
fi

log_file="${LOG_DIR}/logs_${run_id}.log"
echo "Streaming logs from $head_pod to $log_file"
kubectl logs -f "$head_pod" >> "${log_file}" &
logs_pid=$!
echo "Waiting for job to start making progress..."
sleep 90s

# 3. Simulate Failure: Evict a Worker Pod
echo "Randomly selecting a worker pod to disrupt..."
read -r node_name pod_name <<<$(kubectl get pods -o wide | grep "$USER_LABEL_SELECTOR" | grep 'worker-[0-9]-0-' | grep 'Running' | shuf | head -n 1 | awk '{print $7, $1}')

if [ -z "$pod_name" ] || [ -z "$node_name" ]; then
  echo "Warning: Could not find a running worker pod to disrupt. Skipping disruption."
else
  echo "Attempting to cordon '$node_name' and kill pod '$pod_name'..."
  kubectl cordon "$node_name"
  # No -it flags: this runs non-interactively from a script
  kubectl exec "$pod_name" -c pathways-worker -- /bin/sh -c "kill -s SIGILL 1"
  echo "Node cordoned. Waiting briefly for training to reconfigure to N-1 slices..."
  sleep 90s

  # 4. Allow Recovery: Uncordon the Node
  echo "Uncordoning node '$node_name' to allow scheduling again."
  kubectl uncordon "$node_name"
fi

# 5. Wait for training to resume on all slices
sleep 90s

# 6. Terminate the Job and Clean Up
echo "Terminating Run ID $run_id"
kubectl delete -f "$JOB_DEFINITION_FILE"
# Ensure the log streaming process is killed
kill "$logs_pid" 2>/dev/null
echo "Completed Run ID $run_id."

# 7. Verify by printing the steps where training reconfigured from N to N-1 slices and later back to N slices
# Expect output like:
#   Step: 5, Old Slice Count: 3, New Slice Count: 2
#   Step: 17, Old Slice Count: 2, New Slice Count: 3
awk '
  /step=/ && /elastic_manager\.good_slice_count=/ {
    split($0, fields, " ")
    step = ""
    good_slice_count = ""
    for (i in fields) {
      split(fields[i], kv, "=")
      if (kv[1] == "step") {
        step = kv[2]
      } else if (kv[1] == "elastic_manager.good_slice_count") {
        good_slice_count = kv[2]
      }
    }
    if (prev_good_slice_count != "" && prev_good_slice_count != good_slice_count) {
      print "Step: " step ", Old Slice Count: " prev_good_slice_count ", New Slice Count: " good_slice_count
    }
    prev_good_slice_count = good_slice_count
  }
' "$log_file"
```
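
The awk block above assumes log lines containing space-separated `step=<n>` and `elastic_manager.good_slice_count=<m>` tokens. The same transition check can be written in Python for that assumed format (the regex here additionally assumes `step=` appears before `good_slice_count=` on the line):

```python
import re

def slice_count_transitions(lines):
    """Yield (step, old_count, new_count) whenever good_slice_count changes."""
    pattern = re.compile(r"step=(\d+).*elastic_manager\.good_slice_count=(\d+)")
    prev = None
    for line in lines:
        match = pattern.search(line)
        if not match:
            continue  # skip lines without both tokens
        step, count = int(match.group(1)), int(match.group(2))
        if prev is not None and prev != count:
            yield step, prev, count
        prev = count

log = [
    "step=4 elastic_manager.good_slice_count=3",
    "step=5 elastic_manager.good_slice_count=2",
    "step=17 elastic_manager.good_slice_count=3",
]
print(list(slice_count_transitions(log)))  # [(5, 3, 2), (17, 2, 3)]
```
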