# Elastic Training with Pathways

This document demonstrates how to leverage the elasticity primitives within `pathwaysutils.elastic` to create a resilient JAX training loop that can handle hardware failures gracefully. We illustrate this using an example based on the MaxText training loop running on TPUs provisioned by GKE via the `PathwaysJob` API.

## Overview

Distributed training jobs, especially long-running ones, are susceptible to various failures, such as machine preemptions and hardware issues. Elasticity allows a training job to adapt to changes in the number of available accelerators without crashing. It typically involves:

1. **Training State Management**: Regularly snapshotting the training state (model params, optimizer state, data iterator state).
1. **Failure Detection**: The Pathways Resource Manager detects when workers join or leave.
1. **Failure Propagation**: The Pathways runtime propagates the error to the JAX client.
1. **Training Reconfiguration**: Adapting the training computation distribution to the current set of healthy workers.
1. **Resumption**: Continuing training from the last valid snapshot with the new configuration.

The `pathwaysutils.elastic` primitives provide elasticity building blocks to use within your JAX training loop when using the Pathways `Proxy` JAX backend.

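Conceptually, the snapshot/detect/reconfigure/resume cycle above can be sketched in plain Python. This is a toy illustration only: the function and state names here are hypothetical, and the real integration with `pathwaysutils.elastic` lives in the MaxText elastic training script.

```python
# Toy sketch of the snapshot/detect/reconfigure/resume cycle described above.
# All names here are hypothetical illustrations, not the pathwaysutils.elastic API.

def train_elastically(num_steps, initial_state, failing_steps=()):
    """Run a toy training loop that restores a snapshot when a step fails."""
    state = dict(initial_state)   # e.g. params + optimizer state
    snapshot = dict(state)        # last known-good training state
    failed_once = set()
    step = 0
    while step < num_steps:
        try:
            # Simulate a hardware failure the first time a "bad" step runs.
            if step in failing_steps and step not in failed_once:
                failed_once.add(step)
                raise RuntimeError("simulated slice failure")
            state["params"] += 1      # stand-in for one real train step
            state["step"] = step
            snapshot = dict(state)    # snapshot after each successful step
            step += 1
        except RuntimeError:
            # Reconfigure to the healthy slices, restore the last snapshot,
            # and resume from the step after the last completed one.
            state = dict(snapshot)
            step = state.get("step", -1) + 1
    return state
```

The key property is that a failure mid-step never corrupts the resumable state: only a completed step updates the snapshot.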
## Prerequisites

* A [Pathways compatible GKE cluster](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/create-gke-cluster) with TPU and CPU nodepools.
* `kubectl` configured to interact with your cluster.
* Access to a container image containing JAX, your model code (e.g., MaxText), and the `pathwaysutils` package with elasticity features integrated.

## Elastic MaxText Training with Pathways on GKE

This example demonstrates running an elastic MaxText job on 3 x v5e-32 slices using Pathways. See the [PathwaysJob docs](https://cloud.google.com/ai-hypercomputer/docs/workloads/pathways-on-cloud/pathways-intro#pathwaysjob_api) for more details about the various attributes set in the YAML below.

### 1. Elastic PathwaysJob Definition (`pathwaysjob-elastic.yaml`)
Please set the variables marked with `<>` below before executing the script.
```yaml
apiVersion: pathways-job.pathways.domain/v1
kind: PathwaysJob
metadata:
  name: pathways-<USER>
spec:
  maxRestarts: 0
  workers:
  - type: ct5lp-hightpu-4t
    topology: 4x8
    numSlices: 3
    maxSliceRestarts: 2
  pathwaysDir: "gs://<BUCKET>" # Pre-create this bucket.
  controller:
    deploymentMode: default
    elasticSlices: 1
    template:
      spec:
        containers:
        - name: main
          image: <MAXTEXT_IMAGE>
          imagePullPolicy: Always
          command:
          - bash
          - -c
          - >
            python3 -m MaxText.elastic_train MaxText/configs/base.yml
            base_output_directory=gs://<BUCKET>
            per_device_batch_size=4
            enable_checkpointing=false
            remat_policy=full
            global_parameter_scale=8
            steps=50
            max_target_length=2048
            use_iota_embed=true
            reuse_example_batch=1
            dataset_type=synthetic
            attention=flash
            gcs_metrics=True
            enable_pathways_goodput=True
            run_name=pathways-<USER>
```
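The `<>` placeholders in the manifest must be filled in before it is applied. A minimal way to do that programmatically (a convenience sketch only; any templating approach works, and the placeholder names are simply the ones used in the YAML above):

```python
# Substitute <NAME> placeholders (e.g. <USER>, <BUCKET>) in a manifest string.
# This is a convenience sketch, not part of any Pathways tooling.

def fill_placeholders(manifest_text, values):
    """Replace each <NAME> placeholder with its value from the mapping."""
    for name, value in values.items():
        manifest_text = manifest_text.replace("<%s>" % name, value)
    return manifest_text
```

For example, `fill_placeholders(yaml_text, {"USER": "alice", "BUCKET": "my-bucket"})` produces a manifest ready for `kubectl apply -f -`.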
The MaxText elastic training [script](https://github.com/AI-Hypercomputer/maxtext/blob/main/MaxText/elastic_train.py) invoked by the `main` container above is integrated with the `pathwaysutils.elastic` primitives.

### 2. Running the Elastic Training Loop and Simulating Hardware Failures

The following bash script demonstrates launching the above elastic MaxText job with Pathways, monitoring its progress, simulating a hardware failure by cordoning a randomly selected TPU node and killing the Pathways worker on it, and observing the recovery. Please set the variables marked with `<>` below before executing the script. At the end of the script, we verify that elasticity worked as expected.

```bash
#!/bin/bash
WORKING_DIR=</LOCAL/DIRECTORY/PATH>
USER_LABEL_SELECTOR="<USER>"
LOG_DIR="${LOG_DIR:-${WORKING_DIR}/logs}"
RUN_ID=pathways-${USER_LABEL_SELECTOR}
LOG_FILE="${LOG_DIR}/logs_${RUN_ID}.log"
JOB_DEFINITION_FILE="${WORKING_DIR}/pathwaysjob-elastic.yaml" # Copy the above YAML into this file

mkdir -p "${LOG_DIR}"

echo "Running Elastic MaxText with Run ID: ${RUN_ID}"

# 1. Launch the PathwaysJob
kubectl apply -f "$JOB_DEFINITION_FILE"

# 2. Monitor the PathwaysJob
echo "Waiting for pods to start..."
head_pod=""
for i in $(seq 1 10)
do
  head_pod=$(kubectl get pods -o=name --field-selector='status.phase==Running' | grep "$USER_LABEL_SELECTOR" | grep 'head' | head -n 1)
  if [ -n "$head_pod" ]; then
    echo "Found head pod: $head_pod"
    break
  fi
  echo "Head pod not found yet, retrying..."
  sleep 10s
done

if [ -z "$head_pod" ]; then
  echo "Error: Could not find running head pod after multiple attempts. Cleaning up..." 1>&2
  kubectl delete -f "$JOB_DEFINITION_FILE"
  exit 1
fi

echo "Streaming logs from $head_pod to ${LOG_FILE}"
kubectl logs -f "$head_pod" >> "${LOG_FILE}" &
logs_pid=$!
echo "Waiting for job to start making progress..."
sleep 90s

# 3. Simulate failure: evict a worker pod
echo "Randomly selecting a worker pod to disrupt..."
read -r node_name pod_name <<<"$(kubectl get pods -o wide --field-selector='status.phase==Running' | grep "$USER_LABEL_SELECTOR" | grep worker | shuf | head -n 1 | awk '{print $7, $1}')"

if [ -z "$pod_name" ] || [ -z "$node_name" ]; then
  echo "Warning: Could not find a running worker pod to disrupt. Skipping disruption."
else
  echo "Attempting to cordon '$node_name' and kill pod '$pod_name'..."
  kubectl cordon "$node_name"
  kubectl exec -it "$pod_name" -c pathways-worker -- /bin/sh -c "kill -s SIGILL 1"
  echo "Node cordoned. Waiting briefly for training to reconfigure to N-1 slices..."
  sleep 90s

  # 4. Allow recovery: uncordon the node
  echo "Uncordoning node '$node_name' to allow scheduling again."
  kubectl uncordon "$node_name"
fi

# 5. Wait for training to resume on all slices
sleep 90s

# 6. Terminate the job and clean up
echo "Terminating Run ID ${RUN_ID}"
kubectl delete -f "$JOB_DEFINITION_FILE"
# Ensure the log streaming process is killed
kill "$logs_pid" 2>/dev/null
echo "Completed Run ID ${RUN_ID}."

# 7. Verify by printing the steps where training reconfigured from N to N-1 slices and later back to N slices
# Expect output like:
#   Step: 5, Old Slice Count: 3, New Slice Count: 2   (3 -> 2 slices)
#   Step: 17, Old Slice Count: 2, New Slice Count: 3  (2 -> 3 slices)
awk '
  /step=/ && /elastic_manager\.elastic_down_event_count=/ {
    split($0, fields, " ")
    step = ""
    good_slice_count = ""
    for (i in fields) {
      split(fields[i], kv, "=")
      if (kv[1] == "step") {
        step = kv[2]
      } else if (kv[1] == "elastic_manager.good_slice_count") {
        good_slice_count = kv[2]
      }
    }
    if (prev_good_slice_count != "" && prev_good_slice_count != good_slice_count) {
      print "Step: " step ", Old Slice Count: " prev_good_slice_count ", New Slice Count: " good_slice_count
    }
    prev_step = step
    prev_good_slice_count = good_slice_count
  }
' "${LOG_FILE}"
```
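The same transition check can be done in Python, which may be convenient when post-processing logs offline. The log format assumed here (space-separated `key=value` pairs including `step`, `elastic_manager.elastic_down_event_count`, and `elastic_manager.good_slice_count`) mirrors what the awk filter above expects; treat it as an assumption about the MaxText output rather than a stable interface.

```python
# Python equivalent of the awk verification step above. Assumes log lines of
# space-separated key=value pairs, matching the format the awk filter expects.

def slice_count_transitions(log_lines):
    """Yield (step, old_count, new_count) whenever good_slice_count changes."""
    prev = None
    for line in log_lines:
        # Mirror the awk guard: only consider elastic-manager progress lines.
        if "step=" not in line or "elastic_manager.elastic_down_event_count=" not in line:
            continue
        fields = dict(kv.split("=", 1) for kv in line.split() if "=" in kv)
        step = fields.get("step")
        count = fields.get("elastic_manager.good_slice_count")
        if prev is not None and count is not None and prev != count:
            yield step, prev, count
        if count is not None:
            prev = count
```

For example, feeding it the streamed log file (`open(log_file)` iterates line by line) prints each reconfiguration from N to N-1 slices and back.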