
Karen/contrastive train #10

Open
kar-m wants to merge 7 commits into main from karen/contrastive-train

Conversation

@kar-m (Contributor) commented Mar 12, 2026

Steps to replicate training:

```bash
# 1. Create virtualenv and install (TPU)
python3.11 -m venv .venv
.venv/bin/pip install -e ".[tpu]"

# 2. Authenticate GCS (needed for dataset, mel files, tokenizer, cache)
gcloud auth application-default login
# or, if using a service account:
# export GOOGLE_APPLICATION_CREDENTIALS=/path/to/key.json

# 3. Set up ramdisk paths (use /dev/shm for speed, or any fast disk)
mkdir -p /dev/shm/needle_cache
mkdir -p /dev/shm/needle_tokenizer
# The tool-call dataset is downloaded automatically on first run,
# or copy it manually:
# gcloud storage cp -r gs://cactus-dataset/tool_calls /dev/shm/needle_tool_calls_unified

# 4. Set environment variables (add to ~/.bashrc to persist)
export NEEDLE_CACHE_DIR=/dev/shm/needle_cache
export NEEDLE_TOKENIZER_DIR=/dev/shm/needle_tokenizer
export NEEDLE_LOCAL_UNIFIED_DIR=/dev/shm/needle_tool_calls_unified
```
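
Before kicking off a long run, it can be worth confirming that the application-default credentials can actually read the dataset bucket. A minimal sketch, assuming the google-cloud-storage package is available in the environment (it is not necessarily pulled in by the `.[tpu]` extra); the bucket and prefix come from the copy command above:

```python
# Sanity-check GCS read access with application-default credentials.
# Assumption: google-cloud-storage is installed (pip install google-cloud-storage).
from google.cloud import storage

client = storage.Client()
# List a handful of objects from the tool-call dataset; getting results back
# (rather than a permissions error) confirms the credentials work.
for blob in client.list_blobs("cactus-dataset", prefix="tool_calls/", max_results=5):
    print(blob.name)
```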

---

Training pipeline

```bash
cd /path/to/needle
export NEEDLE_CACHE_DIR=/dev/shm/needle_cache
export NEEDLE_TOKENIZER_DIR=/dev/shm/needle_tokenizer
export NEEDLE_LOCAL_UNIFIED_DIR=/dev/shm/needle_tool_calls_unified
```

Step 1: Tokenize. Trains a new SentencePiece BPE tokenizer (vocab=8192) on the combined data, tokenizes 500k Emilia speech + 500k Toucan tool-call examples, shuffles before the 90/10 train/val split, and uploads the tokenizer to GCS.

```bash
.venv/bin/needle tokenize \
  --max-speech-samples 500000 \
  --toucan-max-samples 500000 \
  --shuffle-before-split \
  --overwrite-gcs-tokenizer \
  --batch-size 10000
```
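
A quick way to confirm the tokenizer came out as expected is to load it back and check the vocab size plus an encode/decode round trip. A minimal sketch; the `tokenizer.model` filename under NEEDLE_TOKENIZER_DIR is an assumption, so adjust it to whatever `needle tokenize` actually writes:

```python
# Load the freshly trained SentencePiece model and sanity-check it.
import os
import sentencepiece as spm

# Assumed filename; check NEEDLE_TOKENIZER_DIR for the real one.
model_path = os.path.join(os.environ["NEEDLE_TOKENIZER_DIR"], "tokenizer.model")
sp = spm.SentencePieceProcessor(model_file=model_path)

assert sp.vocab_size() == 8192, sp.vocab_size()  # vocab=8192 per the step above

ids = sp.encode("set a timer for five minutes")
print(ids)             # token ids
print(sp.decode(ids))  # should round-trip to (roughly) the input text
```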

Step 2: Stage 1 pretrain (~2 hours on a v6e-8 TPU). Trains all parameters (encoder + decoder) on speech transcription + audio-text contrastive + tool contrastive losses.

```bash
.venv/bin/needle pretrain \
  --epochs 5 \
  --batch-size 8 \
  --max-speech-samples 500000 \
  --toucan-max-samples 500000 \
  --eval-every 2000 \
  --max-eval-samples 500 \
  --dropout 0.1 \
  --checkpoint-dir /dev/shm/checkpoints_stage1_v2 \
  --wandb
```
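
For readers unfamiliar with the contrastive terms: both are variants of the symmetric InfoNCE objective, where matched (audio, text) pairs in a batch are pulled together and every other pairing acts as a negative. A minimal NumPy sketch of the idea; this is illustrative, not needle's actual implementation, and the temperature value is an arbitrary example:

```python
import numpy as np

def log_softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def info_nce(audio_emb, text_emb, temperature=0.07):
    """Symmetric contrastive loss; row i of each input is a matched pair."""
    a = audio_emb / np.linalg.norm(audio_emb, axis=1, keepdims=True)
    t = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = a @ t.T / temperature          # (batch, batch) cosine similarities
    idx = np.arange(logits.shape[0])
    loss_a2t = -log_softmax(logits, axis=1)[idx, idx].mean()  # audio -> text
    loss_t2a = -log_softmax(logits, axis=0)[idx, idx].mean()  # text -> audio
    return 0.5 * (loss_a2t + loss_t2a)

rng = np.random.default_rng(0)
print(info_nce(rng.normal(size=(8, 64)), rng.normal(size=(8, 64))))
```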

Step 3: Stage 2 finetune (~17 minutes on a v6e-8 TPU). The decoder is reinitialized from scratch, the encoder is frozen, and training runs on tool-call data only. Use the last-epoch checkpoint from stage 1 (pretrain writes no *_best.pkl).

```bash
STAGE1_CKPT=$(ls -t /dev/shm/checkpoints_stage1_v2/*.pkl | head -1)

.venv/bin/needle train \
  --checkpoint "$STAGE1_CKPT" \
  --reinit-decoder \
  --epochs 20 \
  --batch-size 8 \
  --dropout 0.1 \
  --eval-every 1000 \
  --max-eval-samples 500 \
  --checkpoint-dir /dev/shm/checkpoints_stage2_v2 \
  --wandb
```
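
As a side note on "encoder frozen": in a JAX/optax setup this is commonly done by partitioning the parameter tree and routing the frozen branch to a zero update. Whether needle implements freezing this way is an assumption; the sketch below only shows the common optax pattern, with the "encoder"/"decoder" key names as hypothetical placeholders:

```python
import jax
import optax

def label_params(params):
    # Assumes a nested dict with hypothetical top-level "encoder"/"decoder" keys.
    return {
        name: jax.tree_util.tree_map(
            lambda _: "frozen" if name == "encoder" else "trainable", subtree
        )
        for name, subtree in params.items()
    }

# Frozen parameters get a zero update; everything else trains with AdamW.
tx = optax.multi_transform(
    {"trainable": optax.adamw(1e-4), "frozen": optax.set_to_zero()},
    label_params,
)
```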







Results of 500k pretrain + 90k posttrain

Stage 1 (pretrain)

| Epoch | Train Loss | Val PPL (speech transcription) |
|-------|------------|--------------------------------|
| 1     | 5.1950     | 34.91                          |
| 2     | 3.8850     | 27.76                          |
| 3     | 3.6848     | 24.04                          |
| 4     | 3.4356     | 16.56                          |
| 5     | 3.0279     | 12.32                          |
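
Assuming Val PPL here is the usual exp of mean per-token cross-entropy, the epoch-5 value converts back to about 2.51 nats of CE; the gap to the 3.0279 train loss is expected, since the train loss also includes the two contrastive terms. A quick conversion:

```python
import math

print(math.log(12.32))   # ~2.51 nats of CE per token at epoch 5
print(math.exp(3.0279))  # ~20.7: the PPL the raw train loss would imply if it
                         # were pure cross-entropy; the gap reflects the
                         # contrastive loss terms
```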

---

Stage 2 (tool-call finetune)

| Epoch | Train Loss | Val PPL (all) | ne_ppl | aud_ppl     |
|-------|------------|---------------|--------|-------------|
| 1     | 14.212     | 88.83         | 93.42  | 103.16      |
| 2     | 3.397      | 28.32         | 29.94  | 53.65       |
| 3     | 2.363      | 20.04         | 22.78  | 4,107       |
| 4     | 2.039      | 9.21          | 10.94  | 854         |
| 5     | 1.823      | 5.34          | 6.34   | 131         |
| 6     | 1.691      | 6.41          | 7.48   | 357         |
| 7     | 1.701      | 3.73          | 4.22   | 48,748      |
| 8     | 1.757      | 5.37          | 6.27   | 485,165,195 |
| 9     | 1.756      | 3.23          | 3.73   | 1,472,930   |
| 10    | 1.750      | 2.90          | 3.25   | 10.03       |
| 11    | 1.745      | 3.57          | 3.91   | 3.72        |
| 12    | 1.762      | 3.60          | 3.99   | 3.91        |
| 13    | 1.721      | 3.18          | 3.49   | 3.44        |
| 14    | 1.742      | 3.54          | 3.89   | 3.88        |
| 15    | 1.710      | 2.98          | 3.31   | 3.31        |
| 16    | 1.727      | 3.04          | 3.38   | 3.38        |
| 17    | 1.728      | 3.03          | 3.40   | 3.41        |
| 18    | 1.689      | 2.95          | 3.30   | 3.30        |
| 19    | 1.585      | 2.74          | 3.03   | 3.03        |
| 20    | 1.488      | 2.52          | 2.76   | 2.77        |
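
The aud_ppl column is wildly unstable through epoch 9 (peaking near 485M at epoch 8) before settling in line with the other metrics from epoch 10 onward. Spikes like these are easy to miss when scanning logs; below is a minimal sketch of an automatic guard that flags a metric jumping far above its running median (illustrative, not part of needle; the factor and warmup values are arbitrary):

```python
import statistics

def flag_spikes(values, factor=10.0, warmup=2):
    """Return 0-indexed positions where a value exceeds factor x running median."""
    flagged = []
    for i, v in enumerate(values):
        if i >= warmup and v > factor * statistics.median(values[:i]):
            flagged.append(i)
    return flagged

aud_ppl = [103.16, 53.65, 4107, 854, 131, 357, 48748, 485165195, 1472930, 10.03]
print(flag_spikes(aud_ppl))  # [2, 6, 7, 8] -> the epoch 3, 7, 8, 9 spikes above
```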

