|
| 1 | +#!/usr/bin/env bash |
| 2 | +# Sequential RL training: run each job, wait for its output JSON, then start next. |
| 3 | +# Designed to survive SSH session loss (run with setsid). |
| 4 | +set -uo pipefail |
| 5 | + |
| 6 | +PY=/data/venv/bin/python |
| 7 | +ROOT=/data/speech2text/Qwen3-ASR/finetuning |
| 8 | +OUT=${ROOT}/outputs |
| 9 | +DATA=${ROOT}/data |
| 10 | +LOG_DIR=${OUT}/logs |
| 11 | +ADAPTERS=${OUT}/adapters |
| 12 | +export HF_HOME=/data/speech2text/outputs/cache |
| 13 | +export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True |
| 14 | + |
| 15 | +mkdir -p "${LOG_DIR}" "${ADAPTERS}" |
| 16 | + |
| 17 | +run_job() { |
| 18 | + local algo="$1" lang="$2" profile="$3" size="$4" epochs="$5" extra="${6:-}" |
| 19 | + local lang_short="${lang,,}"; lang_short="${lang_short:0:2}" |
| 20 | + local tag="qwen3_${size}_${algo}_${lang_short}_dev100" |
| 21 | + local out_json="${OUT}/${tag}.json" |
| 22 | + local adapter_dir="${ADAPTERS}/qwen3_${size}_${algo}_${lang_short}" |
| 23 | + local log="${LOG_DIR}/${tag}_$(date -u +%Y%m%d_%H%M%S).log" |
| 24 | + |
| 25 | + if [[ -f "${out_json}" ]]; then |
| 26 | + echo "[skip] ${tag} already done → ${out_json}" |
| 27 | + return 0 |
| 28 | + fi |
| 29 | + |
| 30 | + echo "[start] ${tag} $(date -u)" |
| 31 | + "${PY}" "${ROOT}/qwen3_asr_${algo}.py" \ |
| 32 | + --model_path "Qwen/Qwen3-ASR-0.6B" \ |
| 33 | + --train_file "${DATA}/${profile}/train/train.jsonl" \ |
| 34 | + --eval_file "${DATA}/${profile}/dev/dev.jsonl" \ |
| 35 | + --output_dir "${adapter_dir}" \ |
| 36 | + --tag "${tag}" \ |
| 37 | + --language "${lang}" \ |
| 38 | + --epochs "${epochs}" \ |
| 39 | + --grad_acc 4 --lr 5e-6 \ |
| 40 | + --log_steps 25 --eval_steps 0 \ |
| 41 | + --eval_out_dir "${OUT}" \ |
| 42 | + ${extra} \ |
| 43 | + > "${log}" 2>&1 |
| 44 | + local ec=$? |
| 45 | + echo "[done] ${tag} exit=${ec} $(date -u)" |
| 46 | + if [[ -f "${out_json}" ]]; then |
| 47 | + echo "[ok] result saved to ${out_json}" |
| 48 | + cat "${out_json}" | python3 -c "import json,sys; d=json.load(sys.stdin); print(f' WER={d.get(\"wer\",\"n/a\")} CER={d.get(\"cer\",\"n/a\")} n={d.get(\"n\",\"?\")}')"; |
| 49 | + else |
| 50 | + echo "[warn] no output JSON at ${out_json}" |
| 51 | + fi |
| 52 | +} |
| 53 | + |
| 54 | +# Skip GSPO-FR — already running (PID 1304055) |
| 55 | +# Queue: MWER-FR, MWER-ZH, GSPO-ZH |
| 56 | +# Wait for GSPO-FR to finish first (it holds 11GB GPU) |
| 57 | +echo "Waiting for GSPO-FR (PID 1304055) to finish..." |
| 58 | +while kill -0 1304055 2>/dev/null; do sleep 30; done |
| 59 | +echo "GSPO-FR done. Starting MWER-FR..." |
| 60 | + |
| 61 | +run_job mwer French fleurs-fr 0p6b 0.25 "--n_best 4 --mwer_batch_size 1" |
| 62 | +run_job mwer Chinese cv21-zh 0p6b 0.25 "--n_best 4 --mwer_batch_size 1" |
| 63 | +run_job gspo Chinese cv21-zh 0p6b 0.25 "--group_size 4" |
| 64 | + |
| 65 | +echo "All jobs complete." |
0 commit comments