- Qwen3 dense and MoE (`omegalax/models/qwen3`) with cache-aware decode in `omegalax/text/api.py`.
- Qwen3.5 MoE and Qwen3-VL (`omegalax/models/qwen3_5`, `omegalax/models/qwen3_vl`, `omegalax/vlm/api.py`).
- HuggingFace safetensor loaders for all architectures: `create_qwen3_from_safetensors`, `create_qwen3_5_from_safetensors`, and `create_qwen3_vl_from_safetensors`.
- Supported models:
  - Qwen3 dense: `Qwen/Qwen3-0.6B`, `Qwen/Qwen3-1.7B`, `Qwen/Qwen3-4B`, `Qwen/Qwen3-8B`, `Qwen/Qwen3-14B`, `Qwen/Qwen3-32B`.
  - Qwen3 MoE: `Qwen/Qwen3-30B-A3B-Instruct-2507`.
  - Qwen3.5: `Qwen/Qwen3.5-397B-A17B`.
  - Qwen3-VL: `Qwen/Qwen3-VL-2B-Instruct`.
All tensor variables use Shazeer's shape-suffix notation.
The full dimension key lives in the `omegalax.models` package docstring (`omegalax/models/__init__.py`).
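As a quick illustration of the convention (the variable names and axis letters below are invented for this example; the authoritative dimension key is the one in `omegalax/models/__init__.py`):

```python
import numpy as np

# Shape-suffix notation: every tensor name ends with the letters of its axes,
# e.g. B=batch, T=time, D=model dim, V=vocab. These letters are assumptions
# for this sketch, not the official omegalax key.
B, T, D, V = 2, 32, 64, 1000
x_BTD = np.zeros((B, T, D))   # activations: (batch, time, model_dim)
w_DV = np.zeros((D, V))       # unembedding matrix: (model_dim, vocab)
logits_BTV = x_BTD @ w_DV     # (B, T, D) @ (D, V) -> (B, T, V)
print(logits_BTV.shape)
```

The payoff is that shape errors become visible in the source: a contraction is only legal when the adjacent suffixes match.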
Use Python 3.11+ with a JAX build that matches your accelerator (e.g., `jax[cuda12]` for CUDA 12):

```bash
uv sync
```

Create a Qwen3 text model and run a forward+decode step:
```python
import jax
import jax.numpy as jnp

from omegalax.text import api

rng = jax.random.key(0)
model, cfg = api.init_model("Qwen/Qwen3-0.6B", rng, tp_size=1, fsdp_size=1)
tokens = jax.random.randint(rng, (2, 32), 0, cfg.vocab_size, jnp.int32)
logits, aux_loss = api.forward(model, tokens, pad_id=0, cfg=cfg)
cache = api.make_cache(cfg, batch_size=2, token_len=32, generate_steps=8)
next_logits, cache, aux_loss = api.decode(model, cache, tokens, pad_id=0, cfg=cfg)
```

The expected flow is:
- Start from a raw JSONL file where each line is a session with a `session_id` and `messages`.
- Compile it into canonical payload-block ArrayRecord shards.
- Build a chunk-index dataset at the target sequence length.
- Train from the chunk-index dataset.
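To make the third step concrete, here is a conceptual sketch of what a chunk index does (this is an illustration of the idea only, not omegalax code): fixed-length windows at the target sequence length are mapped back into a flat token stream.

```python
import numpy as np

# Conceptual sketch (not omegalax code): slice a tokenized stream into
# fixed-length chunks at the target sequence length.
max_length = 8
tokens = np.arange(20)                          # stand-in for tokenized messages
starts = np.arange(0, len(tokens), max_length)  # chunk-index entries: window starts
chunks = [tokens[s:s + max_length] for s in starts]
print([len(c) for c in chunks])                 # the final chunk may need padding
```

Storing only the index (window starts) rather than materialized chunks is what lets the same payload shards be reused at different sequence lengths.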
Example raw JSONL row:

```json
{"session_id": "demo-0", "messages": [{"role": "user", "content": "Hello"}, {"role": "assistant", "content": "Hi there"}]}
```

Compile a raw SFT dataset into Grain payload shards:
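A minimal standard-library sketch for producing such a file (the output path and row contents are placeholders for this example):

```python
import json
import tempfile

# Write a tiny raw SFT dataset in the JSONL layout shown above,
# one session object per line.
rows = [
    {
        "session_id": f"demo-{i}",
        "messages": [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there"},
        ],
    }
    for i in range(2)
]
path = tempfile.mktemp(suffix=".jsonl")
with open(path, "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")
```

The resulting file can be passed directly as `--data-path` to the compile script below.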
```bash
uv run scripts/compile_sft_dataset.py \
  --data-path /path/to/train.jsonl \
  --out-dir /path/to/train_payload \
  --messages-per-record 128
```

Build a chunk index for text SFT:
```bash
uv run scripts/build_sft_chunk_index.py \
  --data-path /path/to/train_payload \
  --out-dir /path/to/train_chunks \
  --model-id qwen3-smoke \
  --max-length 512
```

If the dataset contains image content, also pass `--processor` (and optionally `--preprocessor-config`) when building the chunk index.
Run text SFT from the compiled Grain chunk-index dataset:
```bash
uv run scripts/train_text_sft.py \
  --model-id qwen3-smoke \
  --data-path /path/to/train_chunks \
  --max-length 512 \
  --batch-size 8 \
  --tp-size 1 \
  --fsdp-size 1
```

Run VLM SFT from the compiled Grain chunk-index dataset:
```bash
uv run scripts/train_vlm_sft.py \
  --model-id qwen3-vl-smoke \
  --data-path /path/to/train_chunks \
  --processor Qwen/Qwen3-VL-2B-Instruct \
  --max-length 512 \
  --batch-size 4 \
  --tp-size 1 \
  --fsdp-size 1
```

Resume from the latest checkpoint with `--resume`. Training checkpoints also persist the Grain iterator state.
Export any supported model (Qwen3 dense/MoE, Qwen3.5, Qwen3-VL) to HuggingFace safetensors:
```bash
uv run scripts/export_to_hf.py --model-id qwen3-smoke --out-dir /tmp/qwen3-smoke-export --tp-size 1 --fsdp-size 1
```

Initialize a VLM (Qwen3.5 or Qwen3-VL) and run a multimodal forward pass:
```python
import jax
import jax.numpy as jnp

from omegalax import vlm

rng = jax.random.key(0)
model, cfg = vlm.api.init_model("qwen3.5-smoke", rng, tp_size=1, fsdp_size=1)
tokens = jnp.ones((1, 16), dtype=jnp.int32)
pixel_values = jnp.zeros((1, 3, 2, 14, 14), dtype=jnp.float32)  # B, C, T, H, W
image_grid_thw = jnp.array([[1, 1, 1]], dtype=jnp.int32)
logits, aux_loss = vlm.api.forward(
    model, tokens, pad_id=0, cfg=cfg, pixel_values=pixel_values, image_grid_thw=image_grid_thw
)
```

All loaders expect a directory containing safetensors and `config.json`:
```python
from huggingface_hub import snapshot_download

from omegalax.models.qwen3.params import create_qwen3_from_safetensors

ckpt_dir = snapshot_download("Qwen/Qwen3-8B")
model = create_qwen3_from_safetensors(ckpt_dir, "Qwen/Qwen3-8B", tp_size=1, fsdp_size=1)
```

For Qwen3.5 and Qwen3-VL, use `create_qwen3_5_from_safetensors(..., tp_size=1, fsdp_size=1)` or `create_qwen3_vl_from_safetensors(..., tp_size=1, fsdp_size=1)` respectively. When starting from a raw HF config, `omegalax.models.qwen3_vl.make_vl_config_from_hf()` will build a matching JAX config.
Tests use `absltest`:

- Run all non-real-weight tests (default):

  ```bash
  uv run --extra=torch-tests -- python -m unittest discover -s tests -p "test_*.py"
  ```

- Run everything including real-weight parity suites (downloads checkpoints; slow):

  ```bash
  OMEGALAX_RUN_REAL_WEIGHTS_TESTS=1 uv run --extra=torch-tests -- python -m unittest discover -s tests -p "test_*.py"
  ```

- Smoke/tiny-model checks only (CPU-friendly, no HF downloads):

  ```bash
  uv run --extra=torch-tests -- python -m unittest discover -s tests -p "test_*smoke.py"
  ```

Run a single suite via `uv run --extra=torch-tests -- python -m unittest tests.test_qwen3_0_6b`.
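Because `absltest.TestCase` is a drop-in subclass of `unittest.TestCase`, a new suite picked up by the discovery commands above can be as small as the following sketch (the file name, class name, and assertion are invented for illustration; plain `unittest` is used here to keep the sketch dependency-free):

```python
import unittest

# Sketch of a minimal test file, e.g. tests/test_example_smoke.py (assumed
# name, not part of the repo). Name it test_*smoke.py to land in the
# CPU-friendly smoke bucket.
class ExampleSmokeTest(unittest.TestCase):
    def test_shape_contract(self):
        # Stand-in assertion; real suites check model outputs and shapes.
        self.assertEqual((2, 32), (2, 32))

if __name__ == "__main__":
    unittest.main()
```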
