p(doom)

omegalax: A JAX-based training codebase for LLMs/VLMs.

Overview

  • Qwen3 dense and MoE (omegalax/models/qwen3) with cache-aware decode in omegalax/text/api.py.
  • Qwen3.5 MoE and Qwen3-VL (omegalax/models/qwen3_5, omegalax/models/qwen3_vl, omegalax/vlm/api.py).
  • HuggingFace safetensor loaders for all architectures: create_qwen3_from_safetensors, create_qwen3_5_from_safetensors, and create_qwen3_vl_from_safetensors.
  • Supported models:
    • Qwen3 dense: Qwen/Qwen3-0.6B, Qwen/Qwen3-1.7B, Qwen/Qwen3-4B, Qwen/Qwen3-8B, Qwen/Qwen3-14B, Qwen/Qwen3-32B.
    • Qwen3 MoE: Qwen/Qwen3-30B-A3B-Instruct-2507.
    • Qwen3.5: Qwen/Qwen3.5-397B-A17B.
    • Qwen3-VL: Qwen/Qwen3-VL-2B-Instruct.

Tensor naming convention

All tensor variables use Shazeer's shape-suffix notation. The full dimension key lives in the omegalax.models package docstring (omegalax/models/__init__.py).
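As a minimal illustration of the convention (the dimension letters here — B for batch, T for tokens, D for model dim, V for vocab — are examples; the authoritative key is the one in omegalax/models/__init__.py):

```python
import numpy as np

# Illustrative dimension key (the real key lives in omegalax/models/__init__.py):
# B = batch, T = tokens, D = model dim, V = vocab.
B, T, D, V = 2, 8, 16, 32

x_BTD = np.zeros((B, T, D), dtype=np.float32)   # activations
w_unemb_DV = np.ones((D, V), dtype=np.float32)  # unembedding weights
logits_BTV = x_BTD @ w_unemb_DV                 # result suffix reads off
                                                # the contraction: BTD,DV->BTV
assert logits_BTV.shape == (B, T, V)
```

The payoff is that a shape mismatch is visible in the variable name itself, before the code is ever run.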

Install

Use Python 3.11+ with a JAX build that matches your accelerator (e.g., jax[cuda12] for CUDA 12):

uv sync

Quickstart (language-only)

Create a Qwen3 text model and run a forward+decode step:

import jax
import jax.numpy as jnp
from omegalax.text import api

rng = jax.random.key(0)
model, cfg = api.init_model("Qwen/Qwen3-0.6B", rng, tp_size=1, fsdp_size=1)
tokens = jax.random.randint(rng, (2, 32), 0, cfg.vocab_size, jnp.int32)
logits, aux_loss = api.forward(model, tokens, pad_id=0, cfg=cfg)
cache = api.make_cache(cfg, batch_size=2, token_len=32, generate_steps=8)
next_logits, cache, aux_loss = api.decode(model, cache, tokens, pad_id=0, cfg=cfg)

Training

The expected flow is:

  1. Start from a raw JSONL file where each line is a session with a session_id and messages.
  2. Compile it into canonical payload-block ArrayRecord shards.
  3. Build a chunk-index dataset at the target sequence length.
  4. Train from the chunk-index dataset.

Example raw JSONL row:

{"session_id":"demo-0","messages":[{"role":"user","content":"Hello"},{"role":"assistant","content":"Hi there"}]}
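The raw format can be read with nothing but the standard library. The sketch below is a hypothetical validator, not part of omegalax: the field names (session_id, messages, role, content) come from the format above, but the checks themselves are illustrative assumptions.

```python
import io
import json

# Hypothetical validator for the raw JSONL session format above.
# Field names match the README example; the checks are illustrative.
def iter_sessions(fp):
    for line in fp:
        session = json.loads(line)
        assert "session_id" in session and "messages" in session
        for msg in session["messages"]:
            assert msg["role"] in ("system", "user", "assistant")
            assert isinstance(msg["content"], str)
        yield session

raw = ('{"session_id":"demo-0","messages":'
       '[{"role":"user","content":"Hello"},'
       '{"role":"assistant","content":"Hi there"}]}\n')
sessions = list(iter_sessions(io.StringIO(raw)))
print(sessions[0]["session_id"])  # demo-0
```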

Compile a raw SFT dataset into Grain payload shards:

uv run scripts/compile_sft_dataset.py \
  --data-path /path/to/train.jsonl \
  --out-dir /path/to/train_payload \
  --messages-per-record 128

Build a chunk index for text SFT:

uv run scripts/build_sft_chunk_index.py \
  --data-path /path/to/train_payload \
  --out-dir /path/to/train_chunks \
  --model-id qwen3-smoke \
  --max-length 512

If the dataset contains image content, also pass --processor (and optionally --preprocessor-config) when building the chunk index.
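Conceptually, --max-length controls how tokenized sessions are sliced into fixed-size training chunks. The real index format is defined by scripts/build_sft_chunk_index.py; the sketch below only shows the slicing idea, under that stated assumption.

```python
# Illustrative only: the real chunk-index format is defined by
# scripts/build_sft_chunk_index.py. This sketches the idea behind
# --max-length: slicing a token stream into fixed-size chunks.
def chunk_tokens(tokens, max_length):
    return [tokens[i:i + max_length] for i in range(0, len(tokens), max_length)]

chunks = chunk_tokens(list(range(1300)), max_length=512)
print([len(c) for c in chunks])  # [512, 512, 276]
```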

Run text SFT from the compiled Grain chunk-index dataset:

uv run scripts/train_text_sft.py \
  --model-id qwen3-smoke \
  --data-path /path/to/train_chunks \
  --max-length 512 \
  --batch-size 8 \
  --tp-size 1 \
  --fsdp-size 1

Run VLM SFT from the compiled Grain chunk-index dataset:

uv run scripts/train_vlm_sft.py \
  --model-id qwen3-vl-smoke \
  --data-path /path/to/train_chunks \
  --processor Qwen/Qwen3-VL-2B-Instruct \
  --max-length 512 \
  --batch-size 4 \
  --tp-size 1 \
  --fsdp-size 1

Resume from the latest checkpoint with --resume. Training checkpoints also persist the Grain iterator state.

Export any supported model (Qwen3 dense/MoE, Qwen3.5, Qwen3-VL) to HuggingFace safetensors:

uv run scripts/export_to_hf.py --model-id qwen3-smoke --out-dir /tmp/qwen3-smoke-export --tp-size 1 --fsdp-size 1

Quickstart (vision-language)

Initialize a VLM (Qwen3.5 or Qwen3-VL) and run a multimodal forward pass:

import jax
import jax.numpy as jnp
from omegalax import vlm

rng = jax.random.key(0)
model, cfg = vlm.api.init_model("qwen3.5-smoke", rng, tp_size=1, fsdp_size=1)
tokens = jnp.ones((1, 16), dtype=jnp.int32)
pixel_values = jnp.zeros((1, 3, 2, 14, 14), dtype=jnp.float32)  # B, C, T, H, W
image_grid_thw = jnp.array([[1, 1, 1]], dtype=jnp.int32)
logits, aux_loss = vlm.api.forward(
    model, tokens, pad_id=0, cfg=cfg, pixel_values=pixel_values, image_grid_thw=image_grid_thw
)

Loading HuggingFace checkpoints

All loaders expect a directory containing safetensors and config.json:

from huggingface_hub import snapshot_download
from omegalax.models.qwen3.params import create_qwen3_from_safetensors

ckpt_dir = snapshot_download("Qwen/Qwen3-8B")
model = create_qwen3_from_safetensors(ckpt_dir, "Qwen/Qwen3-8B", tp_size=1, fsdp_size=1)

For Qwen3.5 and Qwen3-VL, use create_qwen3_5_from_safetensors(..., tp_size=1, fsdp_size=1) or create_qwen3_vl_from_safetensors(..., tp_size=1, fsdp_size=1) respectively. When starting from a raw HF config, omegalax.models.qwen3_vl.make_vl_config_from_hf() will build a matching JAX config.
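Since every loader expects safetensors shards plus config.json in one directory, a pre-flight check can fail fast before any weights are touched. This helper is hypothetical (not an omegalax API); it only encodes the directory contract stated above.

```python
from pathlib import Path

# Hypothetical pre-flight check for the stated loader contract:
# a checkpoint directory must contain config.json and *.safetensors shards.
def check_ckpt_dir(ckpt_dir):
    d = Path(ckpt_dir)
    has_config = (d / "config.json").is_file()
    shards = sorted(d.glob("*.safetensors"))
    return has_config and bool(shards)
```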

Tests

Tests use absltest:

  • Run all non-real-weight tests (default):
    uv run --extra=torch-tests -- python -m unittest discover -s tests -p "test_*.py"
  • Run everything including real-weight parity suites (downloads checkpoints; slow):
    OMEGALAX_RUN_REAL_WEIGHTS_TESTS=1 uv run --extra=torch-tests -- python -m unittest discover -s tests -p "test_*.py"
  • Smoke/tiny-model checks only (CPU-friendly, no HF downloads):
    uv run --extra=torch-tests -- python -m unittest discover -s tests -p "test_*smoke.py"

Run a single suite via uv run --extra=torch-tests -- python -m unittest tests.test_qwen3_0_6b.

About

A training codebase that isn't slop.
