diff --git a/.claude/skills b/.claude/skills new file mode 120000 index 0000000..5a74acc --- /dev/null +++ b/.claude/skills @@ -0,0 +1 @@ +skills \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..8b46091 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,58 @@ +name: CI + +on: + push: + branches: + - master + pull_request: + +jobs: + test: + name: Full Test Suite (${{ matrix.os }}, Ruby ${{ matrix.ruby }}) + runs-on: ${{ matrix.os }} + timeout-minutes: 120 + strategy: + fail-fast: false + matrix: + os: + - ubuntu-latest + - macos-14 + ruby: + - "4.0.1" + + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + submodules: recursive + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: pip + cache-dependency-path: requirements.txt + + - name: Install Linux native dependencies + if: runner.os == 'Linux' + run: | + sudo apt-get update + sudo apt-get install -y libopenblas-dev liblapack-dev liblapacke-dev gfortran + if [ ! -f /usr/include/lapacke.h ] && [ -f /usr/include/x86_64-linux-gnu/lapacke.h ]; then + sudo ln -s /usr/include/x86_64-linux-gnu/lapacke.h /usr/include/lapacke.h + fi + echo "CMAKE_INCLUDE_PATH=/usr/include/x86_64-linux-gnu:/usr/include" >> "$GITHUB_ENV" + + - name: Set up Ruby + uses: ruby/setup-ruby@v1 + with: + ruby-version: ${{ matrix.ruby }} + bundler-cache: true + + - name: Install test dependencies + run: | + bundle exec rake test:deps + echo "${GITHUB_WORKSPACE}/.venv-test/bin" >> "$GITHUB_PATH" + + - name: Run all tests + run: bundle exec rake test diff --git a/Gemfile b/Gemfile index 8a27905..5f67ed8 100644 --- a/Gemfile +++ b/Gemfile @@ -2,5 +2,8 @@ source "https://rubygems.org" gemspec -gem "minitest", "~> 5.20" -gem "rake", "~> 13.0" +# Force CI/dependency resolution to use released mlx gem, not local submodule gemspecs. +gem "mlx", ">= 0.30.7.6", "< 1.0" + +# Use local mlx-ruby submodule during development. +# gem "mlx", path: "mlx-ruby" diff --git a/Gemfile.lock b/Gemfile.lock index 1da79e7..ea2cdd3 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -2,25 +2,63 @@ PATH remote: . specs: mlx-ruby-lm (0.30.7.1) - mlx (~> 0.1) + mlx (>= 0.30.7.5, < 1.0) safetensors (~> 0.2) tokenizers (~> 0.6) GEM remote: https://rubygems.org/ specs: - minitest (5.20.0) - rake (13.1.0) + minitest (5.27.0) + mlx (0.30.7.6) + ostruct (0.6.3) + rake (13.3.1) + safetensors (0.2.2-aarch64-linux) + safetensors (0.2.2-aarch64-linux-musl) + safetensors (0.2.2-arm64-darwin) + safetensors (0.2.2-x86_64-darwin) safetensors (0.2.2-x86_64-linux) + safetensors (0.2.2-x86_64-linux-musl) + tokenizers (0.6.3-aarch64-linux) + tokenizers (0.6.3-aarch64-linux-musl) + tokenizers (0.6.3-arm64-darwin) + tokenizers (0.6.3-x86_64-darwin) tokenizers (0.6.3-x86_64-linux) + tokenizers (0.6.3-x86_64-linux-musl) PLATFORMS + aarch64-linux + aarch64-linux-musl + arm64-darwin + x86_64-darwin x86_64-linux + x86_64-linux-musl DEPENDENCIES minitest (~> 5.20) + mlx (>= 0.30.7.6, < 1.0) mlx-ruby-lm! 
+ ostruct rake (~> 13.0) +CHECKSUMS + minitest (5.27.0) sha256=2d3b17f8a36fe7801c1adcffdbc38233b938eb0b4966e97a6739055a45fa77d5 + mlx (0.30.7.6) sha256=1bd1f6b944e990147fdbe2654ba2830f14dc9ff7dcb9c4c5a314c916b4b92d66 + mlx-ruby-lm (0.30.7.1) + ostruct (0.6.3) sha256=95a2ed4a4bd1d190784e666b47b2d3f078e4a9efda2fccf18f84ddc6538ed912 + rake (13.3.1) sha256=8c9e89d09f66a26a01264e7e3480ec0607f0c497a861ef16063604b1b08eb19c + safetensors (0.2.2-aarch64-linux) sha256=5b50146d50a76fe0395b7aef4d13a1da8fcad44e9cf0f5aead935d5d17fb04dd + safetensors (0.2.2-aarch64-linux-musl) sha256=d6dea4e4f5ca11cff8ba4c017382838df5d33d78f79fabd9a5e5e482aa6afd57 + safetensors (0.2.2-arm64-darwin) sha256=19d77df47154038974f76a4e1bac2d778ea04ca2c49abcd5b9f9c0f1a899d10b + safetensors (0.2.2-x86_64-darwin) sha256=a1dc2b415f6ef35c8887b15a6f72c673f3b008455c33aa399154c0eabf5adbcd + safetensors (0.2.2-x86_64-linux) sha256=f447d3d3110a7592b521a23f58b0251283659b27f3700ab627ac6ba517fa04ff + safetensors (0.2.2-x86_64-linux-musl) sha256=0d52871f2b672485cda73bc94807bb6bd74409a33414fdc341950ceb88f76049 + tokenizers (0.6.3-aarch64-linux) sha256=9d54a23f2e2246cc942d183af4549e3972b937d9b01f7a387cb146bf698eee84 + tokenizers (0.6.3-aarch64-linux-musl) sha256=c178d8556769256857d77fb396f8ab004b29d058f59c620a2cfc56b01b501e27 + tokenizers (0.6.3-arm64-darwin) sha256=29a6a5582dce106d846a906ee9e4254c12db45a3855c3ff6881d4be8be03e6b6 + tokenizers (0.6.3-x86_64-darwin) sha256=4b71386cc08ceff5f86b448c74b2b297c00a280a1d502399b6cda23ef94e01fd + tokenizers (0.6.3-x86_64-linux) sha256=77a45cbde59daac33bdda1a74d45c18080478992a00ee7d898e7b8d15d0b3149 + tokenizers (0.6.3-x86_64-linux-musl) sha256=a4b08c53bf0c8f7674c3abd03e013f0bb7c0c2457174b116c2872a37c64f0297 + BUNDLED WITH 4.0.6 diff --git a/README.md b/README.md index e69de29..606bf44 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,176 @@ +# mlx-ruby-lm + +Ruby LLM inference toolkit built on the `mlx` gem. + +## Included tools + +### CLI + +Executable: `mlx_lm` + +Commands: + +- `mlx_lm generate` +- `mlx_lm chat` +- `mlx_lm server` + +Example: + +```bash +mlx_lm generate --model /path/to/model --prompt "Hello" +``` + +### Ruby APIs + +- `MlxLm::LoadUtils`: load model weights/config/tokenizer from model directory. +- `MlxLm::Generate`: token generation (`generate`, `stream_generate`, `generate_step`). +- `MlxLm::SampleUtils`: samplers and logits processors (`top_p`, `top_k`, repetition penalty). +- `MlxLm::ChatTemplate`: default/chatml prompt formatting. +- `MlxLm::Server`: OpenAI-compatible chat completion server (`/v1/models`, `/v1/chat/completions`). +- `MlxLm::Quantize`: model quantization/dequantization helpers. +- `MlxLm::Perplexity`: perplexity/log-likelihood helpers. +- `MlxLm::Benchmark`: simple generation throughput and model stats helpers. +- `MlxLm::Tuner`: LoRA adapters (`LoRALinear`, `LoRAEmbedding`, `apply_lora_layers`). +- `MlxLm::ConvertUtils`: dtype conversion and parameter/size utilities. + +Minimal usage: + +```ruby +require "mlx" +require "mlx_lm" + +model, tokenizer = MlxLm::LoadUtils.load("/path/to/model") +text = MlxLm::Generate.generate(model, tokenizer, "Hello", max_tokens: 64) +puts text +``` + +## Included models + +Current registry includes 106 `model_type` values. 
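+
+Architectures that share an implementation are remapped internally (for example, `mistral` resolves to the `llama` implementation and `falcon_mamba` to `mamba`), so they do not appear as separate entries below. A minimal sketch of checking whether a given `model_type` is supported, assuming `REGISTRY` and `REMAPPING` are plain Hashes keyed by `model_type` strings (`supported_model_type?` is an illustrative helper, not part of the gem):
+
+```ruby
+require "mlx_lm"
+
+# Resolve a model_type through REMAPPING before checking the registry.
+def supported_model_type?(model_type)
+  resolved = MlxLm::Models::REMAPPING.fetch(model_type, model_type)
+  MlxLm::Models::REGISTRY.key?(resolved)
+end
+
+supported_model_type?("mistral") # remapped to "llama"
+supported_model_type?("qwen3")
+```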
+ +Families covered include: + +- Llama/Gemma/Qwen/Phi +- Mistral/Mixtral/Granite/Cohere +- DeepSeek/GLM/InternLM/Kimi +- Mamba/RWKV/Recurrent Gemma +- MoE variants (for example `*_moe`, `mixtral`, `jamba`, `granitemoe*`) +- Vision-language variants (for example `qwen*_vl`, `kimi_vl`, `pixtral`, `lfm2-vl`) + +Registered `model_type` values: + +```text +Klear +afm7 +afmoe +apertus +baichuan_m1 +bailing_moe +bailing_moe_linear +bitnet +cohere +cohere2 +dbrx +deepseek +deepseek_v2 +deepseek_v3 +deepseek_v32 +dots1 +ernie4_5 +ernie4_5_moe +exaone +exaone4 +exaone_moe +falcon_h1 +gemma +gemma2 +gemma3 +gemma3_text +gemma3n +glm +glm4 +glm4_moe +glm4_moe_lite +glm_moe_dsa +gpt2 +gpt_bigcode +gpt_neox +gpt_oss +granite +granitemoe +granitemoehybrid +helium +hunyuan +hunyuan_v1_dense +internlm2 +internlm3 +iquestloopcoder +jamba +kimi_k25 +kimi_linear +kimi_vl +lfm2 +lfm2-vl +lfm2_moe +lille-130m +llama +llama4 +llama4_text +longcat_flash +longcat_flash_ngram +mamba +mamba2 +mimo +mimo_v2_flash +minicpm +minicpm3 +minimax +ministral3 +mistral3 +mixtral +nanochat +nemotron +nemotron-nas +nemotron_h +olmo +olmo2 +olmo3 +olmoe +openelm +phi +phi3 +phi3small +phimoe +phixtral +pixtral +plamo +plamo2 +qwen +qwen2 +qwen2_moe +qwen2_vl +qwen3 +qwen3_5 +qwen3_5_moe +qwen3_moe +qwen3_next +qwen3_vl +qwen3_vl_moe +recurrent_gemma +rwkv7 +seed_oss +smollm3 +solar_open +stablelm +starcoder2 +step3p5 +telechat3 +youtu_llm +``` + +To inspect the current registry from Ruby: + +```ruby +require "mlx_lm" +puts MlxLm::Models::REGISTRY.keys.sort +``` diff --git a/Rakefile b/Rakefile index 2906ed6..116b703 100644 --- a/Rakefile +++ b/Rakefile @@ -1,4 +1,10 @@ require "rake/testtask" +require_relative "tasks/onnx_report_task" +require_relative "tasks/parity_inventory_task" + +VENV_DIR = File.expand_path(".venv-test", __dir__) +VENV_PYTHON = File.join(VENV_DIR, "bin", "python") +REQUIREMENTS_FILE = File.expand_path("requirements.txt", __dir__) Rake::TestTask.new(:test) do |t| t.libs << "test" << "lib" @@ -6,10 +12,40 @@ Rake::TestTask.new(:test) do |t| end namespace :test do + desc "Install Python dependencies required by parity tests" + task :deps do + next unless File.exist?(REQUIREMENTS_FILE) + + sh("python3 -m venv #{VENV_DIR}") unless File.exist?(VENV_PYTHON) + sh("#{VENV_PYTHON} -m pip install --upgrade pip") + sh("#{VENV_PYTHON} -m pip install -r #{REQUIREMENTS_FILE}") + end + Rake::TestTask.new(:parity) do |t| t.libs << "test" << "lib" t.test_files = FileList["test/parity/**/*_test.rb"] end end +namespace :parity do + desc "Regenerate the Python/Ruby parity inventory snapshot" + task :inventory do + ParityInventoryTask.run! + end + + desc "Verify the parity inventory snapshot is up-to-date" + task :inventory_check do + next if ParityInventoryTask.run!(check: true) + + raise "parity inventory snapshot is stale" + end +end + +namespace :onnx do + desc "Run compat-only ONNX suite and generate report artifacts under test/reports" + task :report do + OnnxReportTask.run! 
+ end +end + task default: :test diff --git a/lib/mlx_lm.rb b/lib/mlx_lm.rb index eafdbc8..13714bc 100644 --- a/lib/mlx_lm.rb +++ b/lib/mlx_lm.rb @@ -6,19 +6,120 @@ require_relative "mlx_lm/sample_utils" require_relative "mlx_lm/models" require_relative "mlx_lm/models/cache" +require_relative "mlx_lm/models/activations" +require_relative "mlx_lm/models/bitlinear_layers" +require_relative "mlx_lm/models/bitnet" +require_relative "mlx_lm/models/gated_delta" +require_relative "mlx_lm/models/pipeline" +require_relative "mlx_lm/models/rope_utils" +require_relative "mlx_lm/models/ssm" +require_relative "mlx_lm/models/mla" require_relative "mlx_lm/models/llama" require_relative "mlx_lm/models/gemma" require_relative "mlx_lm/models/qwen2" +require_relative "mlx_lm/models/qwen2_vl" +require_relative "mlx_lm/models/qwen" +require_relative "mlx_lm/models/qwen3" +require_relative "mlx_lm/models/qwen3_vl" +require_relative "mlx_lm/models/qwen3_5" +require_relative "mlx_lm/models/qwen3_5_moe" +require_relative "mlx_lm/models/phi" require_relative "mlx_lm/models/phi3" +require_relative "mlx_lm/models/exaone" +require_relative "mlx_lm/models/exaone4" +require_relative "mlx_lm/models/glm" +require_relative "mlx_lm/models/glm4" +require_relative "mlx_lm/models/helium" +require_relative "mlx_lm/models/olmo" +require_relative "mlx_lm/models/seed_oss" require_relative "mlx_lm/models/starcoder2" require_relative "mlx_lm/models/stablelm" require_relative "mlx_lm/models/cohere" +require_relative "mlx_lm/models/cohere2" +require_relative "mlx_lm/models/pixtral" require_relative "mlx_lm/models/gemma2" +require_relative "mlx_lm/models/gemma3_text" +require_relative "mlx_lm/models/gemma3" +require_relative "mlx_lm/models/gemma3n" +require_relative "mlx_lm/models/granite" +require_relative "mlx_lm/models/granitemoe" require_relative "mlx_lm/models/olmo2" +require_relative "mlx_lm/models/olmoe" +require_relative "mlx_lm/models/openelm" require_relative "mlx_lm/models/gpt_neox" +require_relative "mlx_lm/models/switch_layers" +require_relative "mlx_lm/models/qwen3_moe" +require_relative "mlx_lm/models/qwen3_vl_moe" require_relative "mlx_lm/models/mixtral" +require_relative "mlx_lm/models/phixtral" +require_relative "mlx_lm/models/mistral3" +require_relative "mlx_lm/models/minicpm" +require_relative "mlx_lm/models/minicpm3" +require_relative "mlx_lm/models/nanochat" +require_relative "mlx_lm/models/smollm3" +require_relative "mlx_lm/models/lfm2" +require_relative "mlx_lm/models/lfm2_vl" require_relative "mlx_lm/models/deepseek" +require_relative "mlx_lm/models/deepseek_v2" +require_relative "mlx_lm/models/deepseek_v3" +require_relative "mlx_lm/models/deepseek_v32" +require_relative "mlx_lm/models/glm_moe_dsa" +require_relative "mlx_lm/models/kimi_k25" +require_relative "mlx_lm/models/kimi_vl" require_relative "mlx_lm/models/internlm2" +require_relative "mlx_lm/models/internlm3" +require_relative "mlx_lm/models/telechat3" +require_relative "mlx_lm/models/olmo3" +require_relative "mlx_lm/models/gpt2" +require_relative "mlx_lm/models/gpt_bigcode" +require_relative "mlx_lm/models/nemotron" +require_relative "mlx_lm/models/apertus" +require_relative "mlx_lm/models/youtu_llm" +require_relative "mlx_lm/models/ernie4_5" +require_relative "mlx_lm/models/ernie4_5_moe" +require_relative "mlx_lm/models/baichuan_m1" +require_relative "mlx_lm/models/solar_open" +require_relative "mlx_lm/models/lille_130m" +require_relative "mlx_lm/models/mimo" +require_relative "mlx_lm/models/qwen2_moe" +require_relative "mlx_lm/models/phimoe" 
+require_relative "mlx_lm/models/llama4_text" +require_relative "mlx_lm/models/plamo" +require_relative "mlx_lm/models/mamba" +require_relative "mlx_lm/models/mamba2" +require_relative "mlx_lm/models/hunyuan_v1_dense" +require_relative "mlx_lm/models/dbrx" +require_relative "mlx_lm/models/klear" +require_relative "mlx_lm/models/iquestloopcoder" +require_relative "mlx_lm/models/phi3small" +require_relative "mlx_lm/models/dots1" +require_relative "mlx_lm/models/llama4" +require_relative "mlx_lm/models/ministral3" +require_relative "mlx_lm/models/hunyuan" +require_relative "mlx_lm/models/gpt_oss" +require_relative "mlx_lm/models/mimo_v2_flash" +require_relative "mlx_lm/models/lfm2_moe" +require_relative "mlx_lm/models/afmoe" +require_relative "mlx_lm/models/bailing_moe" +require_relative "mlx_lm/models/exaone_moe" +require_relative "mlx_lm/models/glm4_moe" +require_relative "mlx_lm/models/minimax" +require_relative "mlx_lm/models/nemotron_nas" +require_relative "mlx_lm/models/recurrent_gemma" +require_relative "mlx_lm/models/step3p5" +require_relative "mlx_lm/models/afm7" +require_relative "mlx_lm/models/bailing_moe_linear" +require_relative "mlx_lm/models/falcon_h1" +require_relative "mlx_lm/models/glm4_moe_lite" +require_relative "mlx_lm/models/granitemoehybrid" +require_relative "mlx_lm/models/jamba" +require_relative "mlx_lm/models/kimi_linear" +require_relative "mlx_lm/models/longcat_flash" +require_relative "mlx_lm/models/longcat_flash_ngram" +require_relative "mlx_lm/models/nemotron_h" +require_relative "mlx_lm/models/plamo2" +require_relative "mlx_lm/models/qwen3_next" +require_relative "mlx_lm/models/rwkv7" require_relative "mlx_lm/generate" require_relative "mlx_lm/quantize" require_relative "mlx_lm/load_utils" diff --git a/lib/mlx_lm/benchmark.rb b/lib/mlx_lm/benchmark.rb index 9c3e3e3..64c0a8f 100644 --- a/lib/mlx_lm/benchmark.rb +++ b/lib/mlx_lm/benchmark.rb @@ -6,9 +6,8 @@ module Benchmark def measure_generation(model, prompt_tokens: 32, gen_tokens: 64, vocab_size: 32000) mx = MLX::Core - # Create random prompt tokens (generate float then cast to int) - prompt = mx.random_uniform([prompt_tokens], 0.0, (vocab_size - 1).to_f, mx.float32) - prompt = prompt.astype(mx.int32) + # Create random prompt tokens + prompt = mx.random_uniform([prompt_tokens], 0.0, (vocab_size - 1).to_f, mx.float32).astype(mx.int32) mx.eval(prompt) # Create cache diff --git a/lib/mlx_lm/generate.rb b/lib/mlx_lm/generate.rb index 597ac7b..2cad83c 100644 --- a/lib/mlx_lm/generate.rb +++ b/lib/mlx_lm/generate.rb @@ -71,7 +71,7 @@ def generate_step( } # Prompt prefilling - process prompt in chunks - prompt_arr = prompt.is_a?(::Array) ? mx.array(prompt).astype(mx.uint32) : prompt + prompt_arr = prompt.is_a?(::Array) ? 
mx.array(prompt, dtype: mx.uint32) : prompt total_prompt_tokens = prompt_arr.size # Process prompt chunks (all but last token) @@ -115,7 +115,7 @@ def stream_generate(model, tokenizer, prompt, max_tokens: 256, **kwargs) if prompt.is_a?(String) prompt = tokenizer.encode(prompt) end - prompt = MLX::Core.array(prompt).astype(MLX::Core.uint32) + prompt = MLX::Core.array(prompt, dtype: MLX::Core.uint32) end detokenizer = tokenizer.detokenizer diff --git a/lib/mlx_lm/models.rb b/lib/mlx_lm/models.rb index 21477c6..a417d9f 100644 --- a/lib/mlx_lm/models.rb +++ b/lib/mlx_lm/models.rb @@ -7,6 +7,7 @@ module Models # Remapping for architectures that share implementation REMAPPING = { "mistral" => "llama", + "falcon_mamba" => "mamba", }.freeze module_function diff --git a/lib/mlx_lm/models/activations.rb b/lib/mlx_lm/models/activations.rb new file mode 100644 index 0000000..25533d9 --- /dev/null +++ b/lib/mlx_lm/models/activations.rb @@ -0,0 +1,46 @@ +module MlxLm + module Models + module Activations + module_function + + def swiglu(gate, x) + MLX::NN.silu(gate) * x + end + + def xielu(x, alpha_p, alpha_n, beta, eps) + mx = MLX::Core + alpha_p = MLX::NN.softplus(alpha_p) + alpha_n = beta + MLX::NN.softplus(alpha_n) + + mx.where( + mx.greater(x, 0.0), + alpha_p * mx.square(x) + beta * x, + (mx.expm1(mx.minimum(x, eps)) - x) * alpha_n + beta * x + ) + end + + class XieLU < MLX::NN::Module + def initialize( + alpha_p_init: 0.8, + alpha_n_init: 0.8, + beta: 0.5, + eps: -1e-6 + ) + super() + mx = MLX::Core + alpha_p_tensor = mx.array(alpha_p_init) + alpha_n_tensor = mx.array(alpha_n_init - beta) + + self.alpha_p = mx.log(mx.exp(alpha_p_tensor) - 1.0) + self.alpha_n = mx.log(mx.exp(alpha_n_tensor) - 1.0) + self.beta = mx.array(beta) + self.eps = mx.array(eps) + end + + def call(x) + Activations.xielu(x, alpha_p, alpha_n, beta, eps) + end + end + end + end +end diff --git a/lib/mlx_lm/models/afm7.rb b/lib/mlx_lm/models/afm7.rb new file mode 100644 index 0000000..c757c55 --- /dev/null +++ b/lib/mlx_lm/models/afm7.rb @@ -0,0 +1,131 @@ +require_relative "afmoe" + +module MlxLm + module Models + module Afm7 + class ModelArgs < Afmoe::ModelArgs + field :model_type, default: "afm7" + field :hidden_dim, default: nil + field :num_layers, default: nil + field :num_kv_reuse_layers, default: 0 + field :num_heads, default: nil + field :num_kv_heads, default: nil + field :hidden_dim_scale_factor, default: nil + + def initialize(**kwargs) + afm7_style = _afm7_style_kwargs?(kwargs) + super + + @hidden_size = @hidden_dim if kwargs.key?(:hidden_dim) && !@hidden_dim.nil? + @num_hidden_layers = @num_layers if kwargs.key?(:num_layers) && !@num_layers.nil? + @num_attention_heads = @num_heads if kwargs.key?(:num_heads) && !@num_heads.nil? + @num_key_value_heads = @num_kv_heads if kwargs.key?(:num_kv_heads) && !@num_kv_heads.nil? + + if kwargs.key?(:hidden_dim_scale_factor) && !@hidden_dim_scale_factor.nil? && !@hidden_size.nil? + @intermediate_size = (@hidden_size * @hidden_dim_scale_factor.to_f).to_i + end + + if !@hidden_size.nil? && !@num_attention_heads.nil? && @num_attention_heads.to_i > 0 + @head_dim = @hidden_size / @num_attention_heads + end + + if kwargs.key?(:num_kv_reuse_layers) && !@num_hidden_layers.nil? + @num_dense_layers = [@num_hidden_layers.to_i - @num_kv_reuse_layers.to_i, 0].max + elsif afm7_style && !@num_hidden_layers.nil? 
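+            # Without an explicit num_kv_reuse_layers, afm7-style configs mark every layer as dense (no MoE routing).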
+ @num_dense_layers = @num_hidden_layers + end + + if afm7_style + @num_experts = 1 unless kwargs.key?(:num_experts) + @num_experts_per_tok = 1 unless kwargs.key?(:num_experts_per_tok) + @num_shared_experts = 0 unless kwargs.key?(:num_shared_experts) + @mup_enabled = false unless kwargs.key?(:mup_enabled) + @layer_types = Array.new(@num_hidden_layers) { "full_attention" } unless kwargs.key?(:layer_types) + end + + @num_key_value_heads ||= @num_attention_heads + @layer_types ||= Array.new(@num_hidden_layers) { "full_attention" } unless @num_hidden_layers.nil? + end + + def to_afmoe_dict + { + "model_type" => @model_type, + "layer_types" => @layer_types, + "vocab_size" => @vocab_size, + "hidden_size" => @hidden_size, + "intermediate_size" => @intermediate_size, + "moe_intermediate_size" => @moe_intermediate_size, + "num_hidden_layers" => @num_hidden_layers, + "num_attention_heads" => @num_attention_heads, + "num_key_value_heads" => @num_key_value_heads, + "head_dim" => @head_dim, + "max_position_embeddings" => @max_position_embeddings, + "rms_norm_eps" => @rms_norm_eps, + "rope_theta" => @rope_theta, + "rope_scaling" => @rope_scaling, + "tie_word_embeddings" => @tie_word_embeddings, + "num_experts" => @num_experts, + "num_experts_per_tok" => @num_experts_per_tok, + "num_shared_experts" => @num_shared_experts, + "num_dense_layers" => @num_dense_layers, + "route_norm" => @route_norm, + "route_scale" => @route_scale, + "score_func" => @score_func, + "n_group" => @n_group, + "topk_group" => @topk_group, + "sliding_window" => @sliding_window, + "mup_enabled" => @mup_enabled, + } + end + + private + + def _afm7_style_kwargs?(kwargs) + kwargs.key?(:hidden_dim) || + kwargs.key?(:num_layers) || + kwargs.key?(:num_heads) || + kwargs.key?(:num_kv_heads) || + kwargs.key?(:num_kv_reuse_layers) || + kwargs.key?(:hidden_dim_scale_factor) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.wrapped_model = Afmoe::Model.new(Afmoe::ModelArgs.from_dict(args.to_afmoe_dict)) + end + + def call(inputs, cache: nil) + wrapped_model.call(inputs, cache: cache) + end + + def sanitize(weights) + wrapped_model.sanitize(weights) + end + + def layers + wrapped_model.layers + end + + def make_cache + return nil unless wrapped_model.respond_to?(:make_cache) + + wrapped_model.make_cache + end + + def cast_predicate + wrapped_model.cast_predicate + end + + def quant_predicate + wrapped_model.quant_predicate + end + end + + Models.register("afm7", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/afmoe.rb b/lib/mlx_lm/models/afmoe.rb new file mode 100644 index 0000000..d8611d4 --- /dev/null +++ b/lib/mlx_lm/models/afmoe.rb @@ -0,0 +1,421 @@ +require_relative "activations" +require_relative "cache" +require_relative "rope_utils" +require_relative "switch_layers" + +module MlxLm + module Models + module Afmoe + class ModelArgs < BaseModelArgs + field :model_type + field :layer_types + field :vocab_size, default: 200_192 + field :hidden_size, default: 2048 + field :intermediate_size, default: 6144 + field :moe_intermediate_size, default: 1024 + field :num_hidden_layers, default: 32 + field :num_attention_heads, default: 32 + field :num_key_value_heads, default: 4 + field :head_dim, default: 64 + field :max_position_embeddings, default: 131_072 + field :rms_norm_eps, default: 1e-5 + field :rope_theta, default: 10_000.0 + field :rope_scaling, default: nil + field :tie_word_embeddings, default: false + field :num_experts, default: 128 
+ field :num_experts_per_tok, default: 8 + field :num_shared_experts, default: 1 + field :num_dense_layers, default: 2 + field :route_norm, default: true + field :route_scale, default: 2.826 + field :score_func, default: "sigmoid" + field :n_group, default: 1 + field :topk_group, default: 1 + field :sliding_window, default: 2048 + field :mup_enabled, default: true + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @layer_types ||= Array.new(@num_hidden_layers) { "full_attention" } + end + end + + class Attention < MLX::NN::Module + def initialize(args, is_local_attention: false) + super() + @hidden_size = args.hidden_size + @num_attention_heads = args.num_attention_heads + @num_key_value_heads = args.num_key_value_heads + @head_dim = args.head_dim + @is_local_attention = is_local_attention + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new( + @hidden_size, + @num_attention_heads * @head_dim, + bias: false + ) + self.k_proj = MLX::NN::Linear.new( + @hidden_size, + @num_key_value_heads * @head_dim, + bias: false + ) + self.v_proj = MLX::NN::Linear.new( + @hidden_size, + @num_key_value_heads * @head_dim, + bias: false + ) + self.o_proj = MLX::NN::Linear.new( + @num_attention_heads * @head_dim, + @hidden_size, + bias: false + ) + + self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.gate_proj = MLX::NN::Linear.new( + @hidden_size, + @num_attention_heads * @head_dim, + bias: false + ) + + if @is_local_attention + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + false, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + + queries = q_norm.call(queries) + keys = k_norm.call(keys) + + if @is_local_attention && respond_to?(:rope) + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + end + + if cache + keys, values = cache.update_and_fetch(keys, values) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim]) + + gate = mx.sigmoid(gate_proj.call(x)) + output = output * gate + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args, intermediate_size: nil) + super() + dim = args.hidden_size + hidden_dim = intermediate_size || args.intermediate_size + + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class MoERouter < MLX::NN::Module + def initialize(args) + super() + self.gate = MLX::NN::Linear.new(args.hidden_size, args.num_experts, bias: false) + end + + def call(x) + gate.call(x) + end + end + + class AfmoeMoE < MLX::NN::Module + def initialize(args) + super() + @args = args + 
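+          # Capture routing hyperparameters from the config: expert counts, top-k, score function, grouping, and route scaling.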
@num_experts = args.num_experts + @num_experts_per_tok = args.num_experts_per_tok + @route_norm = args.route_norm + @route_scale = args.route_scale + @score_func = args.score_func + @n_group = args.n_group + @topk_group = args.topk_group + + self.router = MoERouter.new(args) + self.expert_bias = MLX::Core.zeros([args.num_experts]) + self.experts = SwitchLayers::SwitchGLU.new( + args.hidden_size, + args.moe_intermediate_size, + args.num_experts + ) + + if args.num_shared_experts.to_i > 0 + shared_intermediate_size = args.moe_intermediate_size * args.num_shared_experts + self.shared_experts = MLP.new(args, intermediate_size: shared_intermediate_size) + end + end + + def call(x) + mx = MLX::Core + + gates = router.call(x) + scores = if @score_func == "sigmoid" + mx.sigmoid(gates.astype(mx.float32)) + else + mx.softmax(gates.astype(mx.float32), -1) + end + + selection_scores = scores + expert_bias + + if @n_group.to_i > 1 + experts_per_group = selection_scores.shape[-1] / @n_group + selection_scores = mx.unflatten(selection_scores, -1, [@n_group, experts_per_group]) + group_scores = mx.topk(selection_scores, 2, -1) + group_scores = mx.expand_dims(mx.sum(group_scores, -1), -1) + + drop_count = @n_group - @topk_group.to_i + if drop_count > 0 + group_idx = mx.argpartition(group_scores, drop_count - 1, -2) + take_ids = mx.array((0...drop_count).to_a, dtype: mx.int32) + group_idx = mx.take(group_idx, take_ids, -2) + selection_scores = mx.put_along_axis( + selection_scores, + mx.stop_gradient(group_idx), + mx.array(0.0), + -2 + ) + end + + selection_scores = mx.flatten(selection_scores, -2, -1) + end + + k = [@num_experts_per_tok.to_i, selection_scores.shape[-1]].min + inds = mx.argpartition(selection_scores * -1.0, k - 1, -1) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + + selected_scores = mx.take_along_axis(scores, inds, -1) + if @route_norm && k > 1 + denominator = mx.expand_dims(mx.sum(selected_scores, -1), -1) + selected_scores = selected_scores / denominator + end + selected_scores = selected_scores * @route_scale + + y = experts.call(x, inds) + y = mx.sum(y * mx.expand_dims(selected_scores, -1), -2).astype(y.dtype) + y = y + shared_experts.call(x) if @args.num_shared_experts.to_i > 0 + y + end + end + + class DecoderLayer < MLX::NN::Module + attr_reader :use_sliding + + def initialize(args, layer_idx, use_sliding: false) + super() + @use_sliding = use_sliding + self.self_attn = Attention.new(args, is_local_attention: @use_sliding) + self.mlp = if layer_idx < args.num_dense_layers + MLP.new(args) + else + AfmoeMoE.new(args) + end + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.pre_mlp_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_mlp_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + r = post_attention_layernorm.call(r) + h = x + r + + r = mlp.call(pre_mlp_layernorm.call(h)) + r = post_mlp_layernorm.call(r) + h + r + end + end + + class AfmoeModel < MLX::NN::Module + attr_reader :layer_types, :sliding_window + + def initialize(args) + super() + @hidden_size = args.hidden_size + @layer_types = args.layer_types + @sliding_window = args.sliding_window + @mup_enabled = args.mup_enabled + + self.embed_tokens = 
MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = @layer_types.each_with_index.map do |layer_type, idx| + DecoderLayer.new(args, idx, use_sliding: layer_type == "sliding_attention") + end + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + + self.fa_idx = @layer_types.index("full_attention") || 0 + self.swa_idx = @layer_types.index("sliding_attention") + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + h = h * Math.sqrt(@hidden_size) if @mup_enabled + + layer_cache = cache || [nil] * layers.length + full_mask = _create_attention_mask(h, layer_cache[fa_idx]) + sliding_mask = if swa_idx.nil? + nil + else + _create_attention_mask(h, layer_cache[swa_idx], window_size: @sliding_window) + end + + layers.each_with_index do |layer, i| + mask = layer.use_sliding ? sliding_mask : full_mask + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + norm.call(h) + end + + private + + def _create_attention_mask(h, cache = nil, window_size: nil) + n = h.shape[1] + if cache && cache.respond_to?(:make_mask) + return cache.make_mask(n, window_size: window_size) + end + + if window_size + offset = 0 + if cache + offset = cache.offset if cache.respond_to?(:offset) + if cache.instance_variable_defined?(:@max_size) + max_size = cache.instance_variable_get(:@max_size) + offset = [max_size - 1, offset].min if max_size && max_size > 0 + end + end + return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size + end + + return nil if n == 1 + + "causal" + end + + def _create_causal_mask(n, offset: 0, window_size: nil) + mx = MLX::Core + rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n]) + linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1]) + + mask = mx.greater_equal(linds, rinds) + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + mask + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = AfmoeModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + mx = MLX::Core + result = weights.reject { |key, _| key.to_s.include?("rotary_emb.inv_freq") } + result = result.dup + result.delete("lm_head.weight") if @args.tie_word_embeddings + + @args.num_hidden_layers.times do |layer_idx| + next if layer_idx < @args.num_dense_layers.to_i + + prefix = "model.layers.#{layer_idx}" + %w[up_proj down_proj gate_proj].each do |projection| + %w[weight scales biases].each do |param| + first_key = "#{prefix}.mlp.experts.0.#{projection}.#{param}" + next unless result.key?(first_key) + + expert_keys = (0...@args.num_experts).map do |expert_idx| + "#{prefix}.mlp.experts.#{expert_idx}.#{projection}.#{param}" + end + next unless expert_keys.all? 
{ |key| result.key?(key) } + + stacked = expert_keys.map { |key| result.delete(key) } + result["#{prefix}.mlp.experts.#{projection}.#{param}"] = mx.stack(stacked) + end + end + end + + result + end + + def layers + model.layers + end + + def make_cache + layers.map do |layer| + if layer.use_sliding + MlxLm::RotatingKVCache.new(max_size: model.sliding_window) + else + MlxLm::KVCache.new + end + end + end + + def cast_predicate + lambda { |key| !key.to_s.include?("expert_bias") } + end + + def quant_predicate + lambda do |path, _| + if path.to_s.include?("router.gate") + { group_size: 64, bits: 8 } + else + true + end + end + end + end + + Models.register("afmoe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/apertus.rb b/lib/mlx_lm/models/apertus.rb new file mode 100644 index 0000000..6d563f0 --- /dev/null +++ b/lib/mlx_lm/models/apertus.rb @@ -0,0 +1,179 @@ +module MlxLm + module Models + module Apertus + class ModelArgs < BaseModelArgs + field :model_type + field :hidden_size + field :num_hidden_layers + field :intermediate_size + field :mlp_bias + field :num_attention_heads + field :attention_bias + field :rms_norm_eps + field :vocab_size + field :num_key_value_heads + field :max_position_embeddings + field :rope_theta + field :post_norm + field :qk_norm + field :tie_word_embeddings + field :rope_traditional, default: false + field :rope_scaling, default: nil + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + end + end + + class ApertusMLP < MLX::NN::Module + def initialize(args) + super() + self.up_proj = MLX::NN::Linear.new( + args.hidden_size, + args.intermediate_size, + bias: args.mlp_bias + ) + self.down_proj = MLX::NN::Linear.new( + args.intermediate_size, + args.hidden_size, + bias: args.mlp_bias + ) + self.act_fn = Activations::XieLU.new + end + + def call(x) + down_proj.call(act_fn.call(up_proj.call(x))) + end + end + + class ApertusAttention < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + @num_attention_heads = args.num_attention_heads + @num_key_value_heads = args.num_key_value_heads + @head_dim = dim / @num_attention_heads + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @num_attention_heads * @head_dim, bias: false) + self.k_proj = MLX::NN::Linear.new(dim, @num_key_value_heads * @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(dim, @num_key_value_heads * @head_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@num_attention_heads * @head_dim, dim, bias: false) + + self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = q_norm.call(queries.reshape([b, l, @num_attention_heads, @head_dim])).transpose([0, 2, 1, 3]) + keys = k_norm.call(keys.reshape([b, l, @num_key_value_heads, @head_dim])).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = 
rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim]) + o_proj.call(output) + end + end + + class ApertusDecoderLayer < MLX::NN::Module + def initialize(args) + super() + self.self_attn = ApertusAttention.new(args) + self.mlp = ApertusMLP.new(args) + self.attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.feedforward_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + h = x + self_attn.call(attention_layernorm.call(x), mask: mask, cache: cache) + h + mlp.call(feedforward_layernorm.call(h)) + end + end + + class ApertusModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { ApertusDecoderLayer.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache) + return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if h.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + self.args = args + self.model_type = args.model_type + self.model = ApertusModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + lm_head.call(out) + end + + def sanitize(weights) + mx = MLX::Core + weights.each do |k, v| + if k.end_with?("alpha_p") || k.end_with?("alpha_n") + weights[k] = mx.squeeze(v) + end + end + weights + end + + def layers + model.layers + end + end + + Models.register("apertus", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/baichuan_m1.rb b/lib/mlx_lm/models/baichuan_m1.rb new file mode 100644 index 0000000..7599ef2 --- /dev/null +++ b/lib/mlx_lm/models/baichuan_m1.rb @@ -0,0 +1,306 @@ +module MlxLm + module Models + module BaichuanM1 + class ModelArgs < BaseModelArgs + field :vocab_size + field :hidden_size + field :intermediate_size + field :num_hidden_layers + field :num_attention_heads + field :num_key_value_heads + field :rope_theta + field :sliding_window + field :sliding_window_layers + field :conv_window + field :rms_norm_eps + field :model_type, default: "baichuan_m1" + field :num_swa_attention_heads, default: nil + field :num_swa_key_value_heads, default: nil + field :tie_word_embeddings, default: false + end + + class Attention < MLX::NN::Module + def initialize(config, layer_idx: nil) + super() + + raise ArgumentError, "Layer index must be provided to Attention module." if layer_idx.nil? 
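+          # Sliding-window (SWA) layers may use different attention/KV head counts than global-attention layers.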
+ + swa_layers = config.sliding_window_layers || [] + @is_swa = swa_layers.include?(layer_idx) + + @num_heads = if @is_swa && config.num_swa_attention_heads + config.num_swa_attention_heads + else + config.num_attention_heads + end + + @num_kv_heads = if @is_swa && config.num_swa_key_value_heads + config.num_swa_key_value_heads + else + config.num_key_value_heads + end + + @hidden_size = config.hidden_size + @head_dim = @hidden_size / @num_heads + + unless (@head_dim * @num_heads) == @hidden_size + raise ArgumentError, "hidden_size must be divisible by num_heads" + end + + @scale = @head_dim**(-0.5) + + self.w_pack = MLX::NN::Linear.new( + config.hidden_size, + @hidden_size + 2 * @num_kv_heads * @head_dim, + bias: false + ) + self.o_proj = MLX::NN::Linear.new( + @num_heads * @head_dim, + config.hidden_size, + bias: false + ) + + self.rope = MLX::NN::RoPE.new(@head_dim, traditional: false, base: config.rope_theta) + + @conv_window = config.conv_window + raise ArgumentError, "conv_window must be 2" unless @conv_window == 2 + + mx = MLX::Core + self.conv_k = mx.zeros([1, 1, @num_kv_heads, 1, @conv_window]) + self.conv_v = mx.zeros([1, 1, @num_kv_heads, 1, @conv_window]) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, d = x.shape + + proj = w_pack.call(x) + q, k, v = mx.split(proj, [d, d + @num_kv_heads * @head_dim], -1) + + q = q.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3]) + k = k.reshape([b, l, @num_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + v = v.reshape([b, l, @num_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + layer_cache = cache || [nil, nil] + conv_cache = layer_cache[0] + kv_cache = layer_cache[1] + + if conv_cache + offset = kv_cache.offset + last_k = conv_cache[0] + last_v = conv_cache[1] + else + offset = 0 + last_k = nil + last_v = nil + end + + k_init = k + v_init = v + + k = _custom_convolution(k, conv_k, state: last_k) + v = _custom_convolution(v, conv_v, state: last_v) + q = rope.call(q, offset: offset) + k = rope.call(k, offset: offset) + + if conv_cache + k, v = kv_cache.update_and_fetch(k, v) + if l > 0 + conv_cache[0] = mx.split(k_init, [l - 1], 2)[1] + conv_cache[1] = mx.split(v_init, [l - 1], 2)[1] + end + end + + out = mx.scaled_dot_product_attention(q, k, v, @scale, mask) + out = out.transpose([0, 2, 1, 3]).reshape([b, l, @num_heads * @head_dim]) + o_proj.call(out) + end + + private + + def _custom_convolution(u, weights, state: nil) + mx = MLX::Core + b, h, l, d = u.shape + + weights = weights.reshape([1, h, @conv_window, 1, 1]) + w0 = mx.take(weights, 0, 2) + w1 = mx.take(weights, 1, 2) + + state ||= mx.zeros([b, h, 1, d], u.dtype) + if l > 1 + u_prev = mx.concatenate([state, mx.split(u, [l - 1], 2)[0]], 2) + else + u_prev = state + end + + mx.add(mx.multiply(u_prev, w0), mx.multiply(u, w1)) + end + end + + class MLP < MLX::NN::Module + def initialize(config) + super() + self.gate_proj = MLX::NN::Linear.new( + config.hidden_size, + config.intermediate_size, + bias: false + ) + self.up_proj = MLX::NN::Linear.new( + config.hidden_size, + config.intermediate_size, + bias: false + ) + self.down_proj = MLX::NN::Linear.new( + config.intermediate_size, + config.hidden_size, + bias: false + ) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class DecoderLayer < MLX::NN::Module + def initialize(config, layer_idx) + super() + self.self_attn = Attention.new(config, layer_idx: layer_idx) + self.mlp = MLP.new(config) + self.input_layernorm = 
MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class BaichuanModel < MLX::NN::Module + def initialize(config) + super() + @config = config + @sliding_window = config.sliding_window + @swa_layers = config.sliding_window_layers || [] + + self.embed_tokens = MLX::NN::Embedding.new(config.vocab_size, config.hidden_size) + self.layers = Array.new(config.num_hidden_layers) { |i| DecoderLayer.new(config, i) } + self.norm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + + self.first_swa_idx = @swa_layers.empty? ? nil : @swa_layers[0] + self.first_global_idx = nil + config.num_hidden_layers.times do |i| + next if @swa_layers.include?(i) + + self.first_global_idx = i + break + end + end + + def call(inputs, cache: nil) + x = embed_tokens.call(inputs) + layer_cache = cache || Array.new(layers.length) { [nil, nil] } + + c_global = first_global_idx.nil? ? nil : layer_cache[first_global_idx][1] + c_swa = first_swa_idx.nil? ? nil : layer_cache[first_swa_idx][1] + + global_mask = _create_attention_mask(x, c_global) + swa_mask = _create_attention_mask(x, c_swa, window_size: @sliding_window) + + layers.each_with_index do |layer, i| + mask = @swa_layers.include?(i) ? swa_mask : global_mask + x = layer.call(x, mask: mask, cache: layer_cache[i]) + end + + norm.call(x) + end + + private + + def _create_attention_mask(x, cache = nil, window_size: nil) + n = x.shape[1] + return cache.make_mask(n, window_size: window_size) if cache && cache.respond_to?(:make_mask) + return nil if n == 1 + return _create_causal_mask(n, window_size: window_size) if window_size && n > window_size + + "causal" + end + + def _create_causal_mask(n, offset: 0, window_size: nil) + mx = MLX::Core + rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n]) + linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1]) + + mask = mx.greater_equal(linds, rinds) + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + mask + end + end + + class Model < MLX::NN::Module + def initialize(config) + super() + @config = config + self.model_type = config.model_type + self.model = BaichuanModel.new(config) + @tie_word_embeddings = config.tie_word_embeddings + unless @tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false) + end + end + + def make_cache + caches = [] + swa_layers = @config.sliding_window_layers || [] + @config.num_hidden_layers.times do |i| + is_swa = swa_layers.include?(i) + conv_cache = MlxLm::ArraysCache.new(2) + kv_cache = if is_swa + MlxLm::RotatingKVCache.new(max_size: @config.sliding_window) + else + MlxLm::KVCache.new + end + caches << MlxLm::CacheList.new(conv_cache, kv_cache) + end + caches + end + + def sanitize(weights) + mx = MLX::Core + is_quantized = weights.key?("lm_head.scales") + + if !is_quantized && weights.key?("lm_head.weight") + w = weights["lm_head.weight"] + dtype = w.dtype + w = w.astype(mx.float32) + norm = mx.norm(w, nil, -1, true) + w = (w / (norm + 1e-7)).astype(dtype) + weights["lm_head.weight"] = w + end + + weights + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @tie_word_embeddings + 
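+            # With tied embeddings, project hidden states through the embedding matrix instead of a separate lm_head.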
model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def layers + model.layers + end + end + + Models.register("baichuan_m1", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/bailing_moe.rb b/lib/mlx_lm/models/bailing_moe.rb new file mode 100644 index 0000000..16ace89 --- /dev/null +++ b/lib/mlx_lm/models/bailing_moe.rb @@ -0,0 +1,399 @@ +require_relative "activations" +require_relative "rope_utils" +require_relative "switch_layers" + +module MlxLm + module Models + module BailingMoe + class ModelArgs < BaseModelArgs + field :model_type + field :hidden_size + field :intermediate_size + field :max_position_embeddings + field :moe_intermediate_size + field :num_experts + field :num_shared_experts + field :norm_topk_prob + field :num_attention_heads + field :num_experts_per_tok + field :num_hidden_layers + field :num_key_value_heads + field :rms_norm_eps + field :rope_theta + field :vocab_size + field :first_k_dense_replace + field :rope_scaling, default: nil + field :use_bias, default: false + field :use_qkv_bias, default: false + field :norm_head, default: false + field :norm_softmax, default: false + field :use_qk_norm, default: false + field :tie_word_embeddings, default: false + field :partial_rotary_factor, default: 1.0 + field :rotary_dim, default: nil + field :moe_router_enable_expert_bias, default: false + field :moe_router_enable_routed_scaling, default: true + field :routed_scaling_factor, default: 1.0 + field :score_function, default: "softmax" + field :n_group, default: 1 + field :topk_group, default: 4 + field :moe_shared_expert_intermediate_size, default: nil + field :moe_router_enable_shared_expert, default: true + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + end + end + + module_function + + def aggregate_expert_outputs(expert_outputs, scores) + mx = MLX::Core + mx.sum(expert_outputs * mx.expand_dims(scores, -1), -2).astype(expert_outputs.dtype) + end + + def group_expert_select( + gates, + e_score_correction_bias, + top_k, + n_group, + topk_group, + routed_scaling_factor, + norm_topk_prob, + score_function + ) + mx = MLX::Core + in_type = gates.dtype + + scores = if score_function == "sigmoid" + mx.sigmoid(gates.astype(mx.float32)) + else + mx.softmax(gates.astype(mx.float32), -1) + end + orig_scores = scores + scores = scores + e_score_correction_bias if e_score_correction_bias + + if n_group.to_i > 1 + experts_per_group = scores.shape[-1] / n_group + scores = mx.unflatten(scores, -1, [n_group, experts_per_group]) + group_scores = mx.topk(scores, 2, -1) + group_scores = mx.expand_dims(mx.sum(group_scores, -1), -1) + + drop_count = n_group - topk_group.to_i + if drop_count > 0 + group_idx = mx.argpartition(group_scores, drop_count - 1, -2) + take_ids = mx.array((0...drop_count).to_a, dtype: mx.int32) + group_idx = mx.take(group_idx, take_ids, -2) + scores = mx.put_along_axis( + scores, + mx.stop_gradient(group_idx), + mx.array(0.0), + -2 + ) + end + + scores = mx.flatten(scores, -2, -1) + end + + k = [top_k.to_i, scores.shape[-1]].min + inds = mx.argpartition(scores * -1.0, k - 1, -1) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + + selected_scores = mx.take_along_axis(orig_scores, inds, -1) + if k > 1 && norm_topk_prob + denominator = mx.expand_dims(mx.sum(selected_scores, -1), -1) + 1e-20 + selected_scores = selected_scores / denominator + end + + selected_scores = selected_scores * routed_scaling_factor.to_f + [inds, 
selected_scores.astype(in_type)] + end + + class BailingMoeMLP < MLX::NN::Module + def initialize(args, intermediate_size: nil) + super() + hidden_dim = intermediate_size || args.intermediate_size + + self.gate_proj = MLX::NN::Linear.new(args.hidden_size, hidden_dim, bias: args.use_bias) + self.down_proj = MLX::NN::Linear.new(hidden_dim, args.hidden_size, bias: args.use_bias) + self.up_proj = MLX::NN::Linear.new(args.hidden_size, hidden_dim, bias: args.use_bias) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class BailingMoeAttention < MLX::NN::Module + def initialize(args) + super() + @use_qk_norm = args.use_qk_norm + @num_attention_heads = args.num_attention_heads + @num_key_value_heads = args.num_key_value_heads + @head_dim = args.hidden_size / @num_attention_heads + @scale = @head_dim**(-0.5) + + self.query_key_value = MLX::NN::Linear.new( + args.hidden_size, + (@num_attention_heads + 2 * @num_key_value_heads) * @head_dim, + bias: args.use_qkv_bias + ) + self.dense = MLX::NN::Linear.new( + @num_attention_heads * @head_dim, + args.hidden_size, + bias: args.use_bias + ) + + if @use_qk_norm + self.key_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.query_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + end + + rope_dim = args.rotary_dim || (@head_dim * args.partial_rotary_factor.to_f).to_i + rope_dim = [rope_dim, 1].max + self.rope = MlxLm::Models.initialize_rope( + rope_dim, + args.rope_theta, + false, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + qkv = query_key_value.call(x) + + q_size = @num_attention_heads * @head_dim + kv_size = @num_key_value_heads * @head_dim + q, k, v = mx.split(qkv, [q_size, q_size + kv_size], -1) + + queries = q.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if @use_qk_norm + queries = query_layernorm.call(queries) + keys = key_layernorm.call(keys) + end + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim]) + dense.call(output) + end + end + + class BailingMoeGate < MLX::NN::Module + def initialize(args) + super() + @norm_topk_prob = args.norm_topk_prob + @top_k = args.num_experts_per_tok + @n_group = args.n_group + @topk_group = args.topk_group + @routed_scaling_factor = args.routed_scaling_factor + @score_function = args.score_function + + self.gate_proj = MLX::NN::Linear.new(args.hidden_size, args.num_experts, bias: false) + self.expert_bias = if args.moe_router_enable_expert_bias + MLX::Core.zeros([args.num_experts]) + else + nil + end + end + + def call(x) + BailingMoe.group_expert_select( + gate_proj.call(x), + expert_bias, + @top_k, + @n_group, + @topk_group, + @routed_scaling_factor, + @norm_topk_prob, + @score_function + ) + end + end + + class BailingMoeSparseMoeBlock < MLX::NN::Module + def initialize(args) + super() + self.switch_mlp = SwitchLayers::SwitchGLU.new( + 
args.hidden_size, + args.moe_intermediate_size, + args.num_experts, + bias: args.use_bias + ) + self.gate = BailingMoeGate.new(args) + + shared_dim = args.moe_shared_expert_intermediate_size || args.moe_intermediate_size + self.shared_experts = if args.num_shared_experts.to_i > 0 && args.moe_router_enable_shared_expert + BailingMoeMLP.new( + args, + intermediate_size: shared_dim * args.num_shared_experts + ) + end + end + + def call(x) + topk_idx, topk_weight = gate.call(x) + out = switch_mlp.call(x, topk_idx) + out = BailingMoe.aggregate_expert_outputs(out, topk_weight) + out = out + shared_experts.call(x) if respond_to?(:shared_experts) + out + end + end + + class BailingMoeDecoderLayer < MLX::NN::Module + def initialize(args, layer_idx:) + super() + self.attention = BailingMoeAttention.new(args) + self.mlp = if !args.num_experts.nil? && layer_idx >= args.first_k_dense_replace + BailingMoeSparseMoeBlock.new(args) + else + BailingMoeMLP.new(args) + end + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = attention.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class BailingMoeModel < MLX::NN::Module + def initialize(args) + super() + self.word_embeddings = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) do |layer_idx| + BailingMoeDecoderLayer.new(args, layer_idx: layer_idx) + end + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = word_embeddings.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, layer_idx| + h = layer.call(h, mask: mask, cache: layer_cache[layer_idx]) + end + norm.call(h) + end + + private + + def _create_attention_mask(hidden, cache) + return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if hidden.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + @norm_head = args.norm_head + self.model_type = args.model_type + self.model = BailingMoeModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.word_embeddings.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + mx = MLX::Core + result = weights.dup + + result.delete("lm_head.weight") if @args.tie_word_embeddings + + if @norm_head && result.key?("lm_head.weight") + w = result["lm_head.weight"] + dtype = w.dtype + w_fp32 = w.astype(mx.float32) + weight_norm = mx.sqrt(mx.sum(mx.square(w_fp32), 0, true)) + 1e-7 + result["lm_head.weight"] = (w_fp32 / weight_norm).astype(dtype) + end + + @args.num_hidden_layers.times do |layer_idx| + next if layer_idx < @args.first_k_dense_replace.to_i + + prefix = "model.layers.#{layer_idx}" + %w[gate_proj down_proj up_proj].each do |projection| + %w[weight scales biases].each do |param| + first_key = "#{prefix}.mlp.experts.0.#{projection}.#{param}" + next unless result.key?(first_key) + + expert_keys = (0...@args.num_experts).map do |expert_idx| + 
"#{prefix}.mlp.experts.#{expert_idx}.#{projection}.#{param}" + end + next unless expert_keys.all? { |key| result.key?(key) } + + stacked = expert_keys.map { |key| result.delete(key) } + result["#{prefix}.mlp.switch_mlp.#{projection}.#{param}"] = mx.stack(stacked) + end + end + + if result.key?("#{prefix}.mlp.gate.weight") + result["#{prefix}.mlp.gate.gate_proj.weight"] = result.delete("#{prefix}.mlp.gate.weight") + end + if result.key?("#{prefix}.mlp.gate.bias") + result["#{prefix}.mlp.gate.gate_proj.bias"] = result.delete("#{prefix}.mlp.gate.bias") + end + end + + result + end + + def quant_predicate + lambda do |path, _| + if path.to_s.end_with?("mlp.gate.gate_proj") + { group_size: 64, bits: 8 } + else + true + end + end + end + + def cast_predicate + lambda { |key| !key.to_s.include?("expert_bias") } + end + + def layers + model.layers + end + end + + Models.register("bailing_moe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/bailing_moe_linear.rb b/lib/mlx_lm/models/bailing_moe_linear.rb new file mode 100644 index 0000000..2a08583 --- /dev/null +++ b/lib/mlx_lm/models/bailing_moe_linear.rb @@ -0,0 +1,91 @@ +require_relative "bailing_moe" + +module MlxLm + module Models + module BailingMoeLinear + class ModelArgs < BailingMoe::ModelArgs + field :model_type, default: "bailing_moe_linear" + field :layer_group_size, default: nil + field :group_norm_size, default: nil + field :use_rmsnorm, default: nil + field :head_dim, default: nil + field :rope_traditional, default: false + + def to_bailing_moe_dict + { + "model_type" => @model_type, + "hidden_size" => @hidden_size, + "intermediate_size" => @intermediate_size, + "max_position_embeddings" => @max_position_embeddings, + "moe_intermediate_size" => @moe_intermediate_size, + "num_experts" => @num_experts, + "num_shared_experts" => @num_shared_experts, + "norm_topk_prob" => @norm_topk_prob, + "num_attention_heads" => @num_attention_heads, + "num_experts_per_tok" => @num_experts_per_tok, + "num_hidden_layers" => @num_hidden_layers, + "num_key_value_heads" => @num_key_value_heads, + "rms_norm_eps" => @rms_norm_eps, + "rope_theta" => @rope_theta, + "vocab_size" => @vocab_size, + "first_k_dense_replace" => @first_k_dense_replace, + "rope_scaling" => @rope_scaling, + "use_bias" => @use_bias, + "use_qkv_bias" => @use_qkv_bias, + "norm_head" => @norm_head, + "norm_softmax" => @norm_softmax, + "use_qk_norm" => @use_qk_norm, + "tie_word_embeddings" => @tie_word_embeddings, + "partial_rotary_factor" => @partial_rotary_factor, + "rotary_dim" => @rotary_dim, + "moe_router_enable_expert_bias" => @moe_router_enable_expert_bias, + "moe_router_enable_routed_scaling" => @moe_router_enable_routed_scaling, + "routed_scaling_factor" => @routed_scaling_factor, + "score_function" => @score_function, + "n_group" => @n_group, + "topk_group" => @topk_group, + "moe_shared_expert_intermediate_size" => @moe_shared_expert_intermediate_size, + "moe_router_enable_shared_expert" => @moe_router_enable_shared_expert, + } + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.wrapped_model = BailingMoe::Model.new(BailingMoe::ModelArgs.from_dict(args.to_bailing_moe_dict)) + end + + def call(inputs, cache: nil) + wrapped_model.call(inputs, cache: cache) + end + + def sanitize(weights) + wrapped_model.sanitize(weights) + end + + def layers + wrapped_model.layers + end + + def make_cache + return nil unless wrapped_model.respond_to?(:make_cache) + + wrapped_model.make_cache + end 
+ + def cast_predicate + wrapped_model.cast_predicate + end + + def quant_predicate + wrapped_model.quant_predicate + end + end + + Models.register("bailing_moe_linear", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/bitlinear_layers.rb b/lib/mlx_lm/models/bitlinear_layers.rb new file mode 100644 index 0000000..485ab8f --- /dev/null +++ b/lib/mlx_lm/models/bitlinear_layers.rb @@ -0,0 +1,108 @@ +module MlxLm + module Models + class BitLinear < MLX::NN::Module + attr_reader :in_features, :out_features, :invert_weight_scales + + def initialize( + in_features, + out_features, + bias: true, + invert_weight_scales: false + ) + super() + mx = MLX::Core + + @in_features = in_features + @out_features = out_features + @invert_weight_scales = invert_weight_scales + + packed_out_features = (out_features + 3) / 4 + self.weight = mx.zeros([packed_out_features, in_features], mx.uint8) + self.weight_scale = mx.array([1.0], dtype: mx.float32) + self.bias = mx.zeros([out_features], mx.float32) if bias + end + + def call(x) + y = execute_matmul_kernel(x, weight) + state.key?("bias") ? MLX::Core.add(y, bias) : y + end + + def execute_matmul_kernel(x, packed_weights) + # TODO(phase1e): switch to a custom Metal kernel once MLX Ruby exposes + # a stable fast-kernel API equivalent to Python's mx.fast.metal_kernel. + execute_matmul_fallback(x, packed_weights) + end + + private + + def execute_matmul_fallback(x, packed_weights) + input_dims = x.shape[-1] + unless input_dims == @in_features + raise ArgumentError, "Expected input features #{@in_features}, got #{input_dims}" + end + + ternary_weight = unpack_packed_weights(packed_weights, x.dtype) + out = MLX::Core.matmul(x, ternary_weight.T) + + scale = weight_scale.astype(x.dtype) + scale = MLX::Core.divide(1.0, scale) if invert_weight_scales + MLX::Core.multiply(out, scale) + end + + def unpack_packed_weights(packed_weights, dtype) + mx = MLX::Core + + w0 = (mx.bitwise_and(packed_weights, 0x03).astype(dtype) - 1.0) + w1 = (mx.bitwise_and(mx.right_shift(packed_weights, 2), 0x03).astype(dtype) - 1.0) + w2 = (mx.bitwise_and(mx.right_shift(packed_weights, 4), 0x03).astype(dtype) - 1.0) + w3 = (mx.bitwise_and(mx.right_shift(packed_weights, 6), 0x03).astype(dtype) - 1.0) + + expanded = mx.concatenate([w0, w1, w2, w3], 0) + return expanded if expanded.shape[0] == @out_features + + keep = mx.arange(0, @out_features, 1, mx.int32) + mx.take(expanded, keep, 0) + end + end + + module_function + + def bitnet_quantize(model, quantization_config = {}) + modules_to_not_convert = Array(config_value(quantization_config, "modules_to_not_convert", [])) + .map(&:to_s) + invert_weight_scales = config_value(quantization_config, "linear_class", "").to_s != "autobitlinear" + + replacements = [] + leaves = model.leaf_modules + flat = MLX::Utils.tree_flatten(leaves, is_leaf: lambda { |node| node.is_a?(MLX::NN::Module) }) + + flat.each do |path, layer| + path_s = path.to_s + next if modules_to_not_convert.include?(path_s) + next unless layer.is_a?(MLX::NN::Linear) + + out_features, in_features = layer.weight.shape + replacements << [ + path_s, + BitLinear.new( + in_features, + out_features, + bias: layer.state.key?("bias"), + invert_weight_scales: invert_weight_scales + ), + ] + end + + model.update_modules(MLX::Utils.tree_unflatten(replacements)) unless replacements.empty? + model + end + + def config_value(config, key, default = nil) + return default if config.nil? 
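For context on `BitLinear` above: the fallback path assumes weights are ternary (-1, 0, +1), stored as the 2-bit codes 0..2 with four weights packed into each byte, and `unpack_packed_weights` recovers them by masking successive 2-bit fields and subtracting 1. A pure-Ruby round-trip sketch of that encoding (no MLX; values are hypothetical):

```ruby
# Sketch only: pack four ternary weights (-1/0/+1) into one byte, then unpack them.
ternary = [-1, 0, 1, 1]

# Pack: place each (w + 1) in its own 2-bit slot.
byte = ternary.each_with_index.reduce(0) { |acc, (w, i)| acc | ((w + 1) << (2 * i)) }

# Unpack: mask each 2-bit field and subtract 1, mirroring w0..w3 above.
unpacked = (0..3).map { |i| ((byte >> (2 * i)) & 0x03) - 1 }

raise "round-trip failed" unless unpacked == ternary
p [byte, unpacked]  # => [164, [-1, 0, 1, 1]]
```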
+ return config[key] if config.key?(key) + + config.fetch(key.to_sym, default) + end + private_class_method :config_value + end +end diff --git a/lib/mlx_lm/models/bitnet.rb b/lib/mlx_lm/models/bitnet.rb new file mode 100644 index 0000000..54b10c0 --- /dev/null +++ b/lib/mlx_lm/models/bitnet.rb @@ -0,0 +1,176 @@ +module MlxLm + module Models + module Bitnet + class ModelArgs < BaseModelArgs + field :model_type, default: "bitnet" + field :hidden_size, default: 4096 + field :num_hidden_layers, default: 32 + field :intermediate_size, default: 11_008 + field :num_attention_heads, default: 32 + field :num_key_value_heads, default: nil + field :rms_norm_eps, default: 1e-6 + field :vocab_size, default: 32_000 + field :head_dim, default: nil + field :max_position_embeddings, default: nil + field :attention_bias, default: false + field :mlp_bias, default: false + field :rope_theta, default: 10_000.0 + field :rope_traditional, default: false + field :rope_scaling, default: nil + field :tie_word_embeddings, default: true + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + bias = args.attention_bias + self.q_proj = BitLinear.new(dim, @n_heads * @head_dim, bias: bias) + self.k_proj = BitLinear.new(dim, @n_kv_heads * @head_dim, bias: bias) + self.v_proj = BitLinear.new(dim, @n_kv_heads * @head_dim, bias: bias) + self.o_proj = BitLinear.new(@n_heads * @head_dim, dim, bias: bias) + + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + self.attn_sub_norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(attn_sub_norm.call(output)) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + hidden_dim = args.intermediate_size + bias = args.mlp_bias + + self.gate_proj = BitLinear.new(dim, hidden_dim, bias: bias) + self.down_proj = BitLinear.new(hidden_dim, dim, bias: bias) + self.up_proj = BitLinear.new(dim, hidden_dim, bias: bias) + self.ffn_sub_norm = MLX::NN::RMSNorm.new(hidden_dim, eps: args.rms_norm_eps) + end + + def call(x) + h = MLX::NN.relu2(gate_proj.call(x)) * up_proj.call(x) + down_proj.call(ffn_sub_norm.call(h)) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: 
args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class BitnetModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + norm.call(h) + end + + private + + def _create_attention_mask(h, cache) + return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if h.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = BitnetModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } + result.delete("lm_head.weight") if @args.tie_word_embeddings + result + end + + def layers + model.layers + end + end + + Models.register("bitnet", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/cache.rb b/lib/mlx_lm/models/cache.rb index 86e06cf..a10b878 100644 --- a/lib/mlx_lm/models/cache.rb +++ b/lib/mlx_lm/models/cache.rb @@ -1,7 +1,53 @@ module MlxLm + # Ruby constant names cannot begin with "_", so this is the _BaseCache abstraction. + class BaseCache + def state + [] + end + + def state=(value) + return if value.nil? || (value.respond_to?(:empty?) && value.empty?) + + raise ArgumentError, "This cache has no state but a state was set." + end + + def meta_state + "" + end + + def meta_state=(value) + return if value.nil? || (value.respond_to?(:empty?) && value.empty?) + + raise ArgumentError, "This cache has no meta_state but a meta_state was set." + end + + def is_trimmable + false + end + + def size + 0 + end + + def nbytes + raise NotImplementedError, "Cache sub-class must implement nbytes" + end + + def empty + raise NotImplementedError, "Cache sub-class must implement empty" + end + + def self.from_state(state, meta_state) + obj = allocate + obj.state = state + obj.meta_state = meta_state + obj + end + end + # Simple KV Cache — concatenates new K,V to existing. # Uses simple concatenation since MLX Ruby doesn't support in-place slice assignment. - class KVCache + class KVCache < BaseCache attr_reader :offset def initialize @@ -35,10 +81,101 @@ def state=(v) @keys, @values = v @offset = @keys ? @keys.shape[2] : 0 end + + def is_trimmable + true + end + + def trim(n) + return 0 if @keys.nil? 
|| n <= 0 + + n = [@offset, n].min + @offset -= n + @keys = _slice_prefix(@keys, @offset) + @values = _slice_prefix(@values, @offset) + n + end + + def to_quantized(group_size: 64, bits: 4) + quant_cache = QuantizedKVCache.new(group_size: group_size, bits: bits) + return quant_cache if @keys.nil? + + mx = MLX::Core + qk = mx.quantize(@keys, group_size, bits) + qv = mx.quantize(@values, group_size, bits) + quant_cache.state = [qk, qv] + quant_cache + end + + def empty + @keys.nil? + end + + def nbytes + return 0 if @keys.nil? + + @keys.nbytes + @values.nbytes + end + + def self.merge(caches) + non_empty = caches.reject(&:empty) + return new if non_empty.empty? + + mx = MLX::Core + template_k, template_v = non_empty.first.state + target_len = non_empty.map(&:size).max + + rows_k = caches.map do |cache| + if cache.empty + shape = template_k.shape.dup + shape[0] = 1 + shape[2] = target_len + mx.zeros(shape, template_k.dtype) + else + keys, _values = cache.state + _left_pad_seq(keys, target_len) + end + end + + rows_v = caches.map do |cache| + if cache.empty + shape = template_v.shape.dup + shape[0] = 1 + shape[2] = target_len + mx.zeros(shape, template_v.dtype) + else + _keys, values = cache.state + _left_pad_seq(values, target_len) + end + end + + out = new + out.state = [mx.concatenate(rows_k, 0), mx.concatenate(rows_v, 0)] + out + end + + private + + def _slice_prefix(array, length) + return array if array.shape[2] == length + + MLX::Core.split(array, [length], 2)[0] + end + + def self._left_pad_seq(array, target_len) + return array if array.shape[2] == target_len + + mx = MLX::Core + pad = target_len - array.shape[2] + pad_shape = array.shape.dup + pad_shape[2] = pad + padding = mx.zeros(pad_shape, array.dtype) + mx.concatenate([padding, array], 2) + end end # Rotating KV Cache — fixed maximum size, old entries rotate out. - class RotatingKVCache + class RotatingKVCache < BaseCache attr_reader :offset def initialize(max_size:, keep: 0) @@ -85,6 +222,495 @@ def update_and_fetch(keys, values) return @keys, @values end + + def state + [@keys, @values] + end + + def state=(v) + @keys, @values = v + @offset = @keys ? @keys.shape[2] : 0 + end + + def meta_state + [@keep, @max_size, @offset] + end + + def meta_state=(v) + @keep, @max_size, @offset = v.map(&:to_i) + end + + def is_trimmable + @offset < @max_size + end + + def trim(n) + return 0 if @keys.nil? || n <= 0 + + n = [@offset, n].min + @offset -= n + keep_len = [@keys.shape[2], @offset].min + @keys = _slice_prefix(@keys, keep_len) + @values = _slice_prefix(@values, keep_len) + n + end + + def empty + @keys.nil? + end + + def nbytes + return 0 if @keys.nil? + + @keys.nbytes + @values.nbytes + end + + def self.merge(caches) + KVCache.merge(caches) + end + + private + + def _slice_prefix(array, length) + return array if array.shape[2] == length + + MLX::Core.split(array, [length], 2)[0] + end + end + + class QuantizedKVCache < BaseCache + attr_reader :offset, :group_size, :bits + + def initialize(group_size: 64, bits: 8) + @keys = nil + @values = nil + @offset = 0 + @group_size = group_size + @bits = bits + end + + def update_and_fetch(keys, values) + mx = MLX::Core + qk = mx.quantize(keys, @group_size, @bits) + qv = mx.quantize(values, @group_size, @bits) + + if @keys.nil? 
+ @keys = qk + @values = qv + else + @keys = _concat_quantized(@keys, qk) + @values = _concat_quantized(@values, qv) + end + + @offset += keys.shape[2] + [@keys, @values] + end + + def size + @offset + end + + def state + [@keys, @values] + end + + def state=(v) + @keys, @values = v + @offset = @keys ? @keys[0].shape[2] : 0 + end + + def meta_state + [@offset, @group_size, @bits] + end + + def meta_state=(v) + @offset, @group_size, @bits = v.map(&:to_i) + end + + def is_trimmable + true + end + + def trim(n) + return 0 if @keys.nil? || n <= 0 + + n = [@offset, n].min + @offset -= n + @keys = _slice_quantized(@keys, @offset) + @values = _slice_quantized(@values, @offset) + n + end + + def empty + @keys.nil? + end + + def nbytes + return 0 if @keys.nil? + + _sum_nbytes(@keys) + _sum_nbytes(@values) + end + + private + + def _concat_quantized(lhs, rhs) + lhs.each_with_index.map do |item, i| + MLX::Core.concatenate([item, rhs[i]], 2) + end + end + + def _slice_quantized(tensors, length) + tensors.map do |item| + item.shape[2] == length ? item : MLX::Core.split(item, [length], 2)[0] + end + end + + def _sum_nbytes(tensors) + tensors.reduce(0) { |acc, t| acc + t.nbytes } + end + end + + class ArraysCache < BaseCache + attr_reader :cache + attr_accessor :left_padding, :lengths + + def initialize(size, left_padding: nil) + @cache = Array.new(size) + @left_padding = left_padding ? MLX::Core.array(left_padding) : nil + @lengths = nil + end + + def []=(idx, value) + @cache[idx] = value + end + + def [](idx) + @cache[idx] + end + + def state + @cache + end + + def state=(v) + @cache = v + end + + def meta_state + [@left_padding, @lengths] + end + + def meta_state=(v) + @left_padding, @lengths = v + end + + def filter(batch_indices) + idx = _indices_array(batch_indices) + @cache = @cache.map { |c| c.nil? ? nil : MLX::Core.take(c, idx, 0) } + end + + def extend(other) + @cache = @cache.zip(other.cache).map do |c, o| + if c.nil? + o + elsif o.nil? + c + else + MLX::Core.concatenate([c, o], 0) + end + end + + if @left_padding && other.left_padding + @left_padding = MLX::Core.concatenate([@left_padding, other.left_padding], 0) + end + if @lengths && other.lengths + @lengths = MLX::Core.concatenate([@lengths, other.lengths], 0) + end + end + + def extract(idx) + single = _indices_array([idx]) + out = ArraysCache.new(@cache.length) + out.state = @cache.map { |c| c.nil? ? nil : MLX::Core.take(c, single, 0) } + if @left_padding + out.left_padding = MLX::Core.take(@left_padding, single, 0) + end + if @lengths + out.lengths = MLX::Core.take(@lengths, single, 0) + end + out + end + + def prepare(lengths: nil, **_kwargs) + @lengths = lengths.nil? ? nil : MLX::Core.array(lengths) + end + + def finalize + @lengths = nil + @left_padding = nil + end + + def advance(n) + @lengths = MLX::Core.subtract(@lengths, n) if @lengths + @left_padding = MLX::Core.subtract(@left_padding, n) if @left_padding + end + + def make_mask(n) + mx = MLX::Core + pos = mx.arange(n).reshape([1, n]) + if @left_padding + mx.greater_equal(pos, @left_padding.reshape([@left_padding.shape[0], 1])) + elsif @lengths + mx.less(pos, @lengths.reshape([@lengths.shape[0], 1])) + else + nil + end + end + + def self.merge(caches) + mx = MLX::Core + n_state = caches[0].cache.length + batch = caches.length + out = new(n_state) + + n_state.times do |e| + init = caches.map { |c| c[e] }.find { |v| !v.nil? } + next if init.nil? 
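The `merge` helpers in this file batch per-sequence caches of different lengths by left-padding the shorter sequences with zeros along the time axis, so the most recent tokens stay right-aligned across the batch. A pure-Ruby sketch of that left-padding step, with flat arrays standing in for the `[B, H, T, D]` MLX tensors:

```ruby
# Sketch only: left-pad per-sequence caches to a common length before batching.
def left_pad_seq(tokens, target_len, pad_value = 0)
  Array.new(target_len - tokens.length, pad_value) + tokens
end

caches = [[7, 8, 9], [5], []]        # three sequences of cached "tokens"
target = caches.map(&:length).max    # longest sequence wins

p caches.map { |seq| left_pad_seq(seq, target) }
# => [[7, 8, 9], [0, 0, 5], [0, 0, 0]]
```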
+ + shape = init.shape.dup + shape[0] = 1 + zero = mx.zeros(shape, init.dtype) + rows = caches.map { |c| c[e] || zero } + out[e] = mx.concatenate(rows, 0) + end + + left_padding_values = caches.map(&:left_padding).compact + out.left_padding = mx.concatenate(left_padding_values, 0) if left_padding_values.length == batch + + length_values = caches.map(&:lengths).compact + out.lengths = mx.concatenate(length_values, 0) if length_values.length == batch + + out + end + + def empty + @cache.empty? || @cache[0].nil? + end + + def nbytes + @cache.compact.reduce(0) { |acc, c| acc + c.nbytes } + end + + private + + def _indices_array(indices) + return indices if indices.is_a?(MLX::Core::Array) + + MLX::Core.array(indices, dtype: MLX::Core.int32) + end + end + + class ChunkedKVCache < BaseCache + attr_reader :offset, :chunk_size, :start_position + + def initialize(chunk_size) + @keys = nil + @values = nil + @offset = 0 + @chunk_size = chunk_size + @start_position = 0 + end + + def maybe_trim_front + return if @keys.nil? || @keys.shape[2] < @chunk_size + + excess = @keys.shape[2] - @chunk_size + return if excess <= 0 + + @start_position += excess + @keys = _slice_tail(@keys, @chunk_size) + @values = _slice_tail(@values, @chunk_size) + end + + def update_and_fetch(keys, values) + mx = MLX::Core + if @keys.nil? + @keys = keys + @values = values + else + @keys = mx.concatenate([@keys, keys], 2) + @values = mx.concatenate([@values, values], 2) + end + @offset += keys.shape[2] + [@keys, @values] + end + + def size + @offset - @start_position + end + + def state + [@keys, @values] + end + + def state=(v) + @keys, @values = v + @offset = @keys ? @keys.shape[2] : 0 + end + + def meta_state + [@chunk_size, @start_position] + end + + def meta_state=(v) + @chunk_size, @start_position = v.map(&:to_i) + end + + def is_trimmable + true + end + + def trim(n) + return 0 if @keys.nil? || n <= 0 + + available = @offset - @start_position + n = [available, n].min + @offset -= n + keep_len = @offset - @start_position + @keys = _slice_prefix(@keys, keep_len) + @values = _slice_prefix(@values, keep_len) + n + end + + def empty + @keys.nil? + end + + def nbytes + return 0 if @keys.nil? 
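`ChunkedKVCache#maybe_trim_front` above keeps memory bounded by discarding everything older than the last `chunk_size` positions and advancing `start_position` by the amount dropped, so `size` keeps reporting only the live window. A small pure-Ruby sketch of that bookkeeping, with an array standing in for the key/value tensors:

```ruby
# Sketch only: keep the last `chunk_size` positions and record how much was dropped.
chunk_size = 4
start_position = 0
keys = (1..7).to_a                 # 7 cached positions

excess = keys.length - chunk_size
if excess > 0
  start_position += excess         # history discarded so far
  keys = keys.last(chunk_size)
end

p [start_position, keys]           # => [3, [4, 5, 6, 7]]
```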
+ + @keys.nbytes + @values.nbytes + end + + private + + def _slice_prefix(array, length) + return array if array.shape[2] == length + + MLX::Core.split(array, [length], 2)[0] + end + + def _slice_tail(array, length) + return array if array.shape[2] == length + + split_idx = array.shape[2] - length + MLX::Core.split(array, [split_idx], 2)[1] + end + end + + class CacheList < BaseCache + attr_reader :caches + + def initialize(*caches) + @caches = caches + end + + def [](idx) + @caches[idx] + end + + def is_trimmable + @caches.all?(&:is_trimmable) + end + + def trim(n) + trimmed = 0 + @caches.each do |cache| + trimmed = cache.trim(n) + end + trimmed + end + + def state + @caches.map(&:state) + end + + def state=(v) + @caches.zip(v).each do |cache, cache_state| + cache.state = cache_state + end + end + + def meta_state + [ + @caches.map { |c| c.class.name.split("::").last }, + @caches.map(&:meta_state), + ] + end + + def meta_state=(v) + _classes, states = v + @caches.zip(states).each do |cache, cache_state| + cache.meta_state = cache_state + end + end + + def filter(batch_indices) + @caches.each { |cache| cache.filter(batch_indices) if cache.respond_to?(:filter) } + end + + def extend(other) + @caches.zip(other.caches).each do |cache, other_cache| + next unless cache.class.instance_method(:extend).owner != Object + + cache.extend(other_cache) + end + end + + def self.merge(caches) + merged = caches[0].caches.each_index.map do |i| + batch = caches.map { |c| c.caches[i] } + unless batch[0].class.respond_to?(:merge) + raise NotImplementedError, "#{batch[0].class} does not implement .merge" + end + + batch[0].class.merge(batch) + end + new(*merged) + end + + def extract(idx) + CacheList.new(*@caches.map { |cache| cache.extract(idx) }) + end + + def prepare(**kwargs) + @caches.each { |cache| cache.prepare(**kwargs) if cache.respond_to?(:prepare) } + end + + def finalize + @caches.each { |cache| cache.finalize if cache.respond_to?(:finalize) } + end + + def size + @caches.map(&:size).max || 0 + end + + def empty + @caches.empty? 
|| @caches[0].empty + end + + def nbytes + @caches.reduce(0) { |acc, cache| acc + cache.nbytes } + end + + def self.from_state(state, meta_state) + classes, metas = meta_state + caches = state.each_with_index.map do |sub_state, i| + klass = MlxLm.const_get(classes[i]) + klass.from_state(sub_state, metas[i]) + end + new(*caches) + end end module Cache diff --git a/lib/mlx_lm/models/cohere2.rb b/lib/mlx_lm/models/cohere2.rb new file mode 100644 index 0000000..f78a99d --- /dev/null +++ b/lib/mlx_lm/models/cohere2.rb @@ -0,0 +1,224 @@ +module MlxLm + module Models + module Cohere2 + class ModelArgs < BaseModelArgs + field :model_type, default: "cohere2" + field :hidden_size, default: 4096 + field :head_dim, default: 128 + field :num_hidden_layers, default: 32 + field :intermediate_size, default: 14336 + field :num_attention_heads, default: 32 + field :num_key_value_heads, default: 8 + field :rope_theta, default: 50_000.0 + field :vocab_size, default: 256000 + field :layer_norm_eps, default: 1e-5 + field :logit_scale, default: 0.0625 + field :attention_bias, default: false + field :layer_norm_bias, default: false + field :sliding_window, default: 4096 + field :sliding_window_pattern, default: 4 + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + end + end + + class Attention < MLX::NN::Module + def initialize(args, layer_idx) + super() + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + if (@head_dim * @n_heads) != dim + raise ArgumentError, + "hidden_size must equal num_attention_heads * head_dim (got #{dim} and #{@n_heads} * #{@head_dim})" + end + @scale = @head_dim**(-0.5) + + bias = args.attention_bias + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: bias) + + self.rope = MLX::NN::RoPE.new(@head_dim, traditional: true, base: args.rope_theta) + @use_sliding_window = ((layer_idx + 1) % args.sliding_window_pattern) != 0 + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if @use_sliding_window + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + end + + keys, values = cache.update_and_fetch(keys, values) if cache + + sdpa_type = queries.dtype == mx.float16 ? 
mx.float32 : queries.dtype + output = mx.scaled_dot_product_attention( + queries.astype(sdpa_type), + keys, + values, + @scale, + mask + ).astype(queries.dtype) + + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(dim, hidden_dim) + super() + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args, layer_idx) + super() + self.self_attn = Attention.new(args, layer_idx) + self.mlp = MLP.new(args.hidden_size, args.intermediate_size) + self.input_layernorm = MLX::NN::LayerNorm.new( + args.hidden_size, + eps: args.layer_norm_eps, + bias: args.layer_norm_bias + ) + end + + def call(x, mask: nil, cache: nil) + h = input_layernorm.call(x) + attn_h = self_attn.call(h, mask: mask, cache: cache) + ff_h = mlp.call(h) + attn_h + ff_h + x + end + end + + class Cohere2Model < MLX::NN::Module + def initialize(args) + super() + @args = args + @window_size = args.sliding_window + + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { |i| TransformerBlock.new(args, i) } + self.norm = MLX::NN::LayerNorm.new( + args.hidden_size, + eps: args.layer_norm_eps, + bias: args.layer_norm_bias + ) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + pattern = @args.sliding_window_pattern + full_mask = _create_attention_mask(h, layer_cache[pattern - 1]) + swa_mask = _create_attention_mask(h, layer_cache[0], window_size: @window_size) + + layers.each_with_index do |layer, i| + is_global = (i % pattern) == (pattern - 1) + mask = is_global ? full_mask : swa_mask + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache, window_size: nil) + n = h.shape[1] + offset = cache ? cache.offset : 0 + + if window_size + if cache || n > window_size + return _create_causal_mask(n, offset, window_size) + end + return nil if n == 1 + + return "causal" + end + + return nil if n == 1 + + "causal" + end + + def _create_causal_mask(n, offset, window_size = nil) + mx = MLX::Core + rinds = mx.arange(offset + n) + linds = offset.zero? ? 
rinds : mx.arange(offset, offset + n) + + linds = mx.expand_dims(linds, 1) + rinds = mx.expand_dims(rinds, 0) + mask = mx.greater_equal(linds, rinds) + + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + + mask + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model = Cohere2Model.new(args) + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + out = model.embed_tokens.as_linear(out) + out * @args.logit_scale + end + + def make_cache + caches = [] + @args.num_hidden_layers.times do |i| + is_global = (i % @args.sliding_window_pattern) == (@args.sliding_window_pattern - 1) + if is_global + caches << MlxLm::KVCache.new + else + caches << MlxLm::RotatingKVCache.new(max_size: @args.sliding_window, keep: 0) + end + end + caches + end + + def sanitize(weights) + weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } + end + + def layers + model.layers + end + end + + Models.register("cohere2", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/dbrx.rb b/lib/mlx_lm/models/dbrx.rb new file mode 100644 index 0000000..590799d --- /dev/null +++ b/lib/mlx_lm/models/dbrx.rb @@ -0,0 +1,286 @@ +require_relative "activations" + +module MlxLm + module Models + module Dbrx + class ModelArgs < BaseModelArgs + field :model_type, default: "dbrx" + field :vocab_size, default: 32_000 + field :d_model, default: 6144 + field :ffn_config, default: {} + field :attn_config, default: {} + field :n_layers, default: 40 + field :n_heads, default: 48 + + def initialize(**kwargs) + super + @ffn_config ||= {} + @attn_config ||= {} + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + @num_heads = args.n_heads + @d_model = args.d_model + @head_dim = @d_model / args.n_heads + @num_key_value_heads = _attn_value(args.attn_config, "kv_n_heads", args.n_heads).to_i + @clip_qkv = _attn_value(args.attn_config, "clip_qkv", 8.0).to_f + @rope_theta = _attn_value(args.attn_config, "rope_theta", 10_000.0).to_f + @scale = @head_dim**(-0.5) + + self.wqkv = MLX::NN::Linear.new( + args.d_model, + (@num_key_value_heads * 2 + @num_heads) * @head_dim, + bias: false + ) + self.out_proj = MLX::NN::Linear.new(args.d_model, args.d_model, bias: false) + self.rope = MLX::NN::RoPE.new(@head_dim, traditional: false, base: @rope_theta) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + qkv = wqkv.call(x) + qkv = mx.clip(qkv, -@clip_qkv, @clip_qkv) + + splits = [@d_model, @d_model + @head_dim * @num_key_value_heads] + queries, keys, values = mx.split(qkv, splits, -1) + + queries = queries.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @d_model]) + out_proj.call(output) + end + + private + + def _attn_value(config, key, default = nil) + return default if config.nil? 
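The `Cohere2Model#_create_causal_mask` logic above admits key position `k` for query position `q` (offset by the cache length) when `q >= k`, and, when a sliding window is set, additionally when `q - k < window_size`. A boolean pure-Ruby sketch of the same mask, with hypothetical sizes:

```ruby
# Sketch only: true means "query position q may attend to key position k".
def causal_mask(n, offset, window_size = nil)
  queries = (offset...(offset + n)).to_a
  keys    = (0...(offset + n)).to_a
  queries.map do |q|
    keys.map do |k|
      allowed = q >= k
      allowed &&= (q - k) < window_size if window_size
      allowed
    end
  end
end

pp causal_mask(2, 3, 3)
# => [[false, true, true, true, false],
#     [false, false, true, true, true]]
```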
+ return config[key] if config.key?(key) + + config.fetch(key.to_sym, default) + end + end + + class NormAttnNorm < MLX::NN::Module + def initialize(args) + super() + self.norm_1 = MLX::NN::LayerNorm.new(args.d_model, bias: false) + self.norm_2 = MLX::NN::LayerNorm.new(args.d_model, bias: false) + self.attn = Attention.new(args) + end + + def call(x, mask: nil, cache: nil) + h = attn.call(norm_1.call(x), mask: mask, cache: cache) + residual = x + h + [residual, norm_2.call(residual)] + end + end + + class MLP < MLX::NN::Module + def initialize(d_model, ffn_dim) + super() + self.v1 = MLX::NN::Linear.new(d_model, ffn_dim, bias: false) + self.w1 = MLX::NN::Linear.new(d_model, ffn_dim, bias: false) + self.w2 = MLX::NN::Linear.new(ffn_dim, d_model, bias: false) + end + + def call(x) + w2.call(Activations.swiglu(w1.call(x), v1.call(x))) + end + end + + class Router < MLX::NN::Module + def initialize(d_model, num_experts) + super() + self.layer = MLX::NN::Linear.new(d_model, num_experts, bias: false) + end + + def call(x) + layer.call(x) + end + end + + class SparseMoeBlock < MLX::NN::Module + def initialize(args) + super() + @d_model = args.d_model + @ffn_dim = _ffn_value(args.ffn_config, "ffn_hidden_size", args.d_model * 4).to_i + @num_experts = _ffn_value(args.ffn_config, "moe_num_experts", 1).to_i + @num_experts_per_tok = _ffn_value(args.ffn_config, "moe_top_k", 1).to_i + + self.router = Router.new(@d_model, @num_experts) + self.experts = Array.new(@num_experts) { MLP.new(@d_model, @ffn_dim) } + end + + def call(x) + mx = MLX::Core + + top_k = [[@num_experts_per_tok, 1].max, @num_experts].min + orig_shape = x.shape + token_count = orig_shape[0...-1].reduce(1, :*) + flat_x = x.reshape([token_count, orig_shape[-1]]) + + gates = router.call(flat_x) + gates = mx.softmax(gates.astype(mx.float32), -1) + + inds = mx.stop_gradient(mx.argpartition(gates * -1.0, top_k - 1, -1)) + take_ids = mx.array((0...top_k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + scores = mx.take_along_axis(gates, inds, -1) + scores = scores / mx.expand_dims(mx.sum(scores, -1), -1) + scores = scores.astype(flat_x.dtype) + + expert_ids = inds.to_a + expert_scores = scores.to_a + + outputs = Array.new(flat_x.shape[0]) do |token_idx| + token_ids = mx.array([token_idx], dtype: mx.int32) + token_state = mx.squeeze(mx.take(flat_x, token_ids, 0), 0) + + token_out = nil + expert_ids[token_idx].each_with_index do |expert_idx, score_idx| + expert_out = experts[expert_idx.to_i].call(token_state) + weighted = expert_out * expert_scores[token_idx][score_idx].to_f + token_out = token_out.nil? ? weighted : (token_out + weighted) + end + + token_out + end + + mx.stack(outputs, 0).reshape(orig_shape) + end + + private + + def _ffn_value(config, key, default = nil) + return default if config.nil? 
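The routing in `SparseMoeBlock#call` above softmaxes the router logits, keeps the `moe_top_k` largest gates per token (via `argpartition` on the negated scores), renormalizes those gates to sum to 1, and uses them to weight each selected expert's output. A pure-Ruby sketch of the per-token selection, with hypothetical logits:

```ruby
# Sketch only: softmax the router logits, keep top-k, renormalize the kept gates.
def softmax(xs)
  exps = xs.map { |x| Math.exp(x) }
  total = exps.sum
  exps.map { |e| e / total }
end

logits = [0.2, 1.5, -0.3, 0.9]   # one token, four experts
top_k  = 2

gates = softmax(logits)
top   = gates.each_with_index.max_by(top_k) { |g, _| g }  # [[gate, expert_index], ...]
denom = top.sum { |g, _| g }

pp top.map { |g, idx| [idx, (g / denom).round(3)] }
# => [[1, 0.646], [3, 0.354]]
```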
+ return config[key] if config.key?(key) + + config.fetch(key.to_sym, default) + end + end + + class DecoderLayer < MLX::NN::Module + def initialize(args) + super() + self.ffn = SparseMoeBlock.new(args) + self.norm_attn_norm = NormAttnNorm.new(args) + end + + def call(x, mask: nil, cache: nil) + residual, hidden = norm_attn_norm.call(x, mask: mask, cache: cache) + ffn.call(hidden) + residual + end + end + + class DbrxModel < MLX::NN::Module + def initialize(args) + super() + self.wte = MLX::NN::Embedding.new(args.vocab_size, args.d_model) + self.blocks = Array.new(args.n_layers) { DecoderLayer.new(args) } + self.norm_f = MLX::NN::LayerNorm.new(args.d_model, bias: false) + end + + def call(inputs, cache: nil) + h = wte.call(inputs) + layer_cache = cache || [nil] * blocks.length + mask = _create_attention_mask(h, layer_cache[0]) + + blocks.each_with_index do |layer, layer_idx| + h = layer.call(h, mask: mask, cache: layer_cache[layer_idx]) + end + + norm_f.call(h) + end + + private + + def _create_attention_mask(hidden, cache) + return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if hidden.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.transformer = DbrxModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.d_model, args.vocab_size, bias: false) + end + + def call(inputs, cache: nil) + out = transformer.call(inputs, cache: cache) + lm_head.call(out) + end + + def layers + transformer.blocks + end + + def sanitize(weights) + mx = MLX::Core + num_experts = _ffn_value(@args.ffn_config, "moe_num_experts", 0).to_i + return weights if num_experts <= 0 + + pattern = "experts.mlp" + sanitized = {} + + weights.each do |key, value| + unless key.include?(pattern) + sanitized[key] = value + next + end + + split_weights = mx.split(value, num_experts, 0) + split_weights.each_with_index do |slice, expert_idx| + expert_key = _expert_weight_key(key, expert_idx) + if key.end_with?("w2") || key.end_with?("w2.weight") + slice = slice.transpose([1, 0]) + end + sanitized[expert_key] = slice + end + end + + sanitized + end + + private + + def _expert_weight_key(key, expert_idx) + base = key.end_with?(".weight") ? key.sub(/\.weight\z/, "") : key + "#{base.sub('.mlp', ".#{expert_idx}")}.weight" + end + + def _ffn_value(config, key, default = nil) + return default if config.nil? 
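`Model#sanitize` above splits each fused DBRX expert tensor (stored under keys containing `experts.mlp`) into `moe_num_experts` slices and re-keys them as `experts.<i>.<proj>.weight`, transposing the `w2` slices, so they load into the per-expert `MLP` modules. A sketch of just the key rewrite, reusing the same substitution as `_expert_weight_key` on a hypothetical fused key:

```ruby
# Sketch only: rewrite one fused expert key into per-expert keys.
def expert_weight_key(key, expert_idx)
  base = key.end_with?(".weight") ? key.sub(/\.weight\z/, "") : key
  "#{base.sub('.mlp', ".#{expert_idx}")}.weight"
end

fused_key = "transformer.blocks.0.ffn.experts.mlp.w1"
pp 3.times.map { |i| expert_weight_key(fused_key, i) }
# => ["transformer.blocks.0.ffn.experts.0.w1.weight",
#     "transformer.blocks.0.ffn.experts.1.w1.weight",
#     "transformer.blocks.0.ffn.experts.2.w1.weight"]
```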
+ return config[key] if config.key?(key) + + config.fetch(key.to_sym, default) + end + end + + Models.register("dbrx", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/deepseek.rb b/lib/mlx_lm/models/deepseek.rb index b1941ff..5fe4626 100644 --- a/lib/mlx_lm/models/deepseek.rb +++ b/lib/mlx_lm/models/deepseek.rb @@ -92,28 +92,34 @@ def call(x) end class MoEGate < MLX::NN::Module - def initialize(dim, n_routed_experts, num_experts_per_tok) + def initialize(args) super() - @num_experts_per_tok = num_experts_per_tok - self.weight = MLX::Core.zeros([n_routed_experts, dim]) + @top_k = args.num_experts_per_tok + self.weight = MLX::Core.zeros([args.n_routed_experts, args.hidden_size]) end def call(x) mx = MLX::Core gates = mx.matmul(x, mx.transpose(weight)) - [gates, @num_experts_per_tok] + scores = mx.softmax(gates.astype(mx.float32), -1).astype(gates.dtype) + k = @top_k + inds = mx.stop_gradient(mx.argpartition(scores * -1.0, k - 1, -1)) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + scores = mx.take_along_axis(scores, inds, -1) + [inds, scores] end end class DeepseekMoE < MLX::NN::Module def initialize(args) super() + @n_shared_experts = args.n_shared_experts dim = args.hidden_size moe_dim = args.moe_intermediate_size - @num_experts_per_tok = args.num_experts_per_tok - self.gate = MoEGate.new(dim, args.n_routed_experts, args.num_experts_per_tok) - self.experts = Array.new(args.n_routed_experts) { DeepseekMLP.new(dim, moe_dim) } + self.switch_mlp = SwitchLayers::SwitchGLU.new(dim, moe_dim, args.n_routed_experts) + self.gate = MoEGate.new(args) if args.n_shared_experts && args.n_shared_experts > 0 shared_dim = moe_dim * args.n_shared_experts @@ -123,40 +129,11 @@ def initialize(args) def call(x) mx = MLX::Core - ne = @num_experts_per_tok - orig_shape = x.shape - dims = x.shape[-1] - tokens = x.size / dims - x_flat = x.reshape([tokens, dims]) - - gates, _ne = gate.call(x_flat) - inds = mx.argpartition(gates * -1.0, ne - 1, -1) - take_ids = mx.array((0...ne).to_a, mx.int32) - inds = mx.take(inds, take_ids, 1) - - scores = mx.take_along_axis(gates, inds, -1) - scores = mx.softmax(scores.astype(mx.float32), -1).astype(gates.dtype) - - inds_list = inds.tolist - y_rows = [] - (0...x_flat.shape[0]).each do |i| - xt = x_flat[i] - selected = [inds_list[i]].flatten - expert_outs = selected.map { |eidx| - mx.expand_dims(experts[eidx].call(xt), 0) - } - yt = mx.concatenate(expert_outs, 0) - st = scores[i] - weighted = yt * mx.expand_dims(st, -1) - summed = mx.sum(weighted, 0) - y_rows << mx.expand_dims(summed, 0) - end - - y = mx.concatenate(y_rows, 0) - y = y.reshape(orig_shape) + inds, scores = gate.call(x) + y = switch_mlp.call(x, inds) + y = mx.sum(y * mx.expand_dims(scores, -1), -2) - # Add shared experts if present - if respond_to?(:shared_experts) + if @n_shared_experts && @n_shared_experts > 0 y = y + shared_experts.call(x) end @@ -229,7 +206,26 @@ def call(inputs, cache: nil) end def sanitize(weights) - weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } + mx = MLX::Core + result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } + + # Convert per-expert weights to stacked SwitchGLU format + @args.num_hidden_layers.times do |l| + prefix = "model.layers.#{l}" + ["gate_proj", "down_proj", "up_proj"].each do |m| + ["weight", "scales", "biases"].each do |k| + key0 = "#{prefix}.mlp.experts.0.#{m}.#{k}" + if result.key?(key0) + to_join = (0...@args.n_routed_experts).map { |e| + 
result.delete("#{prefix}.mlp.experts.#{e}.#{m}.#{k}") + } + result["#{prefix}.mlp.switch_mlp.#{m}.#{k}"] = mx.stack(to_join) + end + end + end + end + + result end def layers diff --git a/lib/mlx_lm/models/deepseek_v2.rb b/lib/mlx_lm/models/deepseek_v2.rb new file mode 100644 index 0000000..01e2f9a --- /dev/null +++ b/lib/mlx_lm/models/deepseek_v2.rb @@ -0,0 +1,108 @@ +require_relative "deepseek" + +module MlxLm + module Models + module DeepseekV2 + class ModelArgs < BaseModelArgs + field :model_type, default: "deepseek_v2" + field :vocab_size, default: 102_400 + field :hidden_size, default: 4096 + field :intermediate_size, default: 11_008 + field :moe_intermediate_size, default: 1407 + field :num_hidden_layers, default: 30 + field :num_attention_heads, default: 32 + field :num_key_value_heads, default: 32 + field :n_shared_experts, default: nil + field :n_routed_experts, default: nil + field :routed_scaling_factor, default: 1.0 + field :kv_lora_rank, default: 512 + field :q_lora_rank, default: 1536 + field :qk_rope_head_dim, default: 64 + field :v_head_dim, default: 128 + field :qk_nope_head_dim, default: 128 + field :topk_method, default: "gready" + field :n_group, default: nil + field :topk_group, default: nil + field :num_experts_per_tok, default: nil + field :moe_layer_freq, default: 1 + field :first_k_dense_replace, default: 0 + field :max_position_embeddings, default: 2048 + field :rms_norm_eps, default: 1e-6 + field :rope_theta, default: 10_000.0 + field :rope_scaling, default: nil + field :attention_bias, default: false + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + end + end + + class Model < DeepSeek::Model + def initialize(args) + super(DeepSeek::ModelArgs.from_dict(_to_deepseek_config(args))) + self.model_type = args.model_type + end + + def sanitize(weights) + _stack_expert_weights(weights.dup) + end + + private + + def _to_deepseek_config(args) + { + "model_type" => args.model_type, + "vocab_size" => args.vocab_size, + "hidden_size" => args.hidden_size, + "intermediate_size" => args.intermediate_size, + "moe_intermediate_size" => args.moe_intermediate_size, + "num_hidden_layers" => args.num_hidden_layers, + "num_attention_heads" => args.num_attention_heads, + "num_key_value_heads" => args.num_key_value_heads, + "n_shared_experts" => args.n_shared_experts, + "n_routed_experts" => args.n_routed_experts, + "num_experts_per_tok" => args.num_experts_per_tok, + "moe_layer_freq" => args.moe_layer_freq, + "first_k_dense_replace" => args.first_k_dense_replace, + "max_position_embeddings" => args.max_position_embeddings, + "rms_norm_eps" => args.rms_norm_eps, + "rope_theta" => args.rope_theta, + "rope_scaling" => args.rope_scaling, + "attention_bias" => args.attention_bias, + } + end + + def _stack_expert_weights(weights) + num_experts = @args.n_routed_experts.to_i + return weights if num_experts <= 0 + + mx = MLX::Core + projections = %w[gate_proj down_proj up_proj].freeze + params = %w[weight scales biases].freeze + + @args.num_hidden_layers.times do |layer_idx| + prefix = "model.layers.#{layer_idx}.mlp" + projections.each do |projection| + params.each do |param| + expert_keys = (0...num_experts).map do |expert_idx| + "#{prefix}.experts.#{expert_idx}.#{projection}.#{param}" + end + next unless expert_keys.all? 
{ |key| weights.key?(key) } + + stacked = expert_keys.map { |key| weights.delete(key) } + weights["#{prefix}.switch_mlp.#{projection}.#{param}"] = mx.stack(stacked) + end + end + end + + weights + end + end + + Models.register("deepseek_v2", Model, ModelArgs) + end + + DeepSeekV2 = DeepseekV2 unless const_defined?(:DeepSeekV2) + end +end diff --git a/lib/mlx_lm/models/deepseek_v3.rb b/lib/mlx_lm/models/deepseek_v3.rb new file mode 100644 index 0000000..5ce58c3 --- /dev/null +++ b/lib/mlx_lm/models/deepseek_v3.rb @@ -0,0 +1,34 @@ +require_relative "deepseek_v2" + +module MlxLm + module Models + module DeepseekV3 + class ModelArgs < DeepseekV2::ModelArgs + field :model_type, default: "deepseek_v3" + field :topk_method, default: "noaux_tc" + field :scoring_func, default: "sigmoid" + field :norm_topk_prob, default: true + field :n_group, default: 1 + field :topk_group, default: 1 + field :num_experts_per_tok, default: 1 + end + + class Model < DeepseekV2::Model + def sanitize(weights) + super(weights).reject do |key, _| + key_name = key.to_s + key_name.start_with?("model.layers.61") || key_name.include?("rotary_emb.inv_freq") + end + end + + def cast_predicate + ->(key) { !key.to_s.include?("e_score_correction_bias") } + end + end + + Models.register("deepseek_v3", Model, ModelArgs) + end + + DeepSeekV3 = DeepseekV3 unless const_defined?(:DeepSeekV3) + end +end diff --git a/lib/mlx_lm/models/deepseek_v32.rb b/lib/mlx_lm/models/deepseek_v32.rb new file mode 100644 index 0000000..4960928 --- /dev/null +++ b/lib/mlx_lm/models/deepseek_v32.rb @@ -0,0 +1,45 @@ +require_relative "deepseek" + +module MlxLm + module Models + module DeepseekV32 + class ModelArgs < DeepSeek::ModelArgs + field :model_type, default: "deepseek_v32" + field :index_head_dim, default: 128 + field :index_n_heads, default: 64 + field :index_topk, default: 2048 + field :routed_scaling_factor, default: 1.0 + field :kv_lora_rank, default: 512 + field :q_lora_rank, default: 1536 + field :qk_rope_head_dim, default: 64 + field :v_head_dim, default: 128 + field :qk_nope_head_dim, default: 128 + field :topk_method, default: "noaux_tc" + field :scoring_func, default: "sigmoid" + field :norm_topk_prob, default: true + field :n_group, default: 1 + field :topk_group, default: 1 + end + + class Model < DeepSeek::Model + def sanitize(weights) + sanitized = super(weights) + drop_mtp_layer_weights(sanitized) + end + + private + + def drop_mtp_layer_weights(weights) + cutoff = @args.num_hidden_layers.to_i + + weights.reject do |key, _| + match = key.match(/\Amodel\.layers\.(\d+)(?:\.|\z)/) + match && match[1].to_i >= cutoff + end + end + end + + Models.register("deepseek_v32", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/dots1.rb b/lib/mlx_lm/models/dots1.rb new file mode 100644 index 0000000..ffb30ad --- /dev/null +++ b/lib/mlx_lm/models/dots1.rb @@ -0,0 +1,292 @@ +require_relative "activations" +require_relative "rope_utils" +require_relative "switch_layers" + +module MlxLm + module Models + module Dots1 + class ModelArgs < BaseModelArgs + field :model_type, default: "dots1" + field :hidden_size + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :rms_norm_eps + field :vocab_size + field :max_position_embeddings, default: nil + field :num_key_value_heads + field :first_k_dense_replace + field :moe_intermediate_size + field :n_routed_experts + field :n_shared_experts + field :norm_topk_prob + field :num_experts_per_tok + field :rope_theta + field :routed_scaling_factor + field 
:head_dim, default: nil + field :scoring_func, default: "noaux_tc" + field :n_group, default: 1 + field :topk_group, default: 1 + field :attention_bias, default: false + field :mlp_bias, default: false + field :rope_scaling, default: nil + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + @n_group ||= 1 + @topk_group ||= 1 + end + end + + class Dots1Attention < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.attention_bias) + + self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + false, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = q_norm.call(queries.reshape([b, l, @n_heads, @head_dim])).transpose([0, 2, 1, 3]) + keys = k_norm.call(keys.reshape([b, l, @n_kv_heads, @head_dim])).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class Dots1TopkRouter < MLX::NN::Module + def initialize(args) + super() + mx = MLX::Core + @top_k = args.num_experts_per_tok + @norm_topk_prob = args.norm_topk_prob + @n_routed_experts = args.n_routed_experts + @routed_scaling_factor = args.routed_scaling_factor + @n_group = args.n_group + @topk_group = args.topk_group + self.weight = mx.zeros([@n_routed_experts, args.hidden_size]).astype(mx.float32) + self.e_score_correction_bias = mx.zeros([@n_routed_experts]).astype(mx.float32) + end + + def call(x) + mx = MLX::Core + + gates = mx.matmul(x, mx.transpose(weight)) + scores = mx.sigmoid(gates.astype(mx.float32)) + scores = scores + e_score_correction_bias.reshape([1, 1, @n_routed_experts]) + + k = [[@top_k.to_i, 1].max, @n_routed_experts].min + inds = mx.stop_gradient(mx.argpartition(scores * -1.0, k - 1, -1)) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + + selected_scores = mx.take_along_axis(mx.sigmoid(gates.astype(mx.float32)), inds, -1) + if k > 1 && @norm_topk_prob + denom = mx.expand_dims(mx.sum(selected_scores, -1), -1) + selected_scores = selected_scores / denom + end + selected_scores = selected_scores * @routed_scaling_factor.to_f + + [inds, selected_scores.astype(gates.dtype)] + end + end + + class Dots1MLP < 
MLX::NN::Module + def initialize(args, hidden_size: nil, intermediate_size: nil) + super() + @hidden_size = hidden_size || args.hidden_size + @intermediate_size = intermediate_size || args.intermediate_size + + self.gate_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: args.mlp_bias) + self.up_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: args.mlp_bias) + self.down_proj = MLX::NN::Linear.new(@intermediate_size, @hidden_size, bias: args.mlp_bias) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class Dots1MoE < MLX::NN::Module + def initialize(args) + super() + @n_shared_experts = args.n_shared_experts + self.experts = SwitchLayers::SwitchGLU.new( + args.hidden_size, + args.moe_intermediate_size, + args.n_routed_experts, + bias: args.mlp_bias + ) + self.gate = Dots1TopkRouter.new(args) + + if @n_shared_experts && @n_shared_experts > 0 + self.shared_experts = Dots1MLP.new( + args, + intermediate_size: args.moe_intermediate_size * @n_shared_experts + ) + end + end + + def call(x) + mx = MLX::Core + inds, scores = gate.call(x) + y = experts.call(x, inds) + y = mx.sum(y * mx.expand_dims(scores.astype(y.dtype), -1), -2) + + y = y + shared_experts.call(x) if @n_shared_experts && @n_shared_experts > 0 + y + end + end + + class Dots1DecoderLayer < MLX::NN::Module + def initialize(args, layer_idx) + super() + self.self_attn = Dots1Attention.new(args) + if layer_idx >= args.first_k_dense_replace + self.mlp = Dots1MoE.new(args) + else + self.mlp = Dots1MLP.new(args) + end + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class Dots1Model < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { |layer_idx| Dots1DecoderLayer.new(args, layer_idx) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, layer_idx| + h = layer.call(h, mask: mask, cache: layer_cache[layer_idx]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache) + n = h.shape[1] + return cache.make_mask(n) if cache && cache.respond_to?(:make_mask) + return nil if n == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = Dots1Model.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + result = weights.dup + result.delete("lm_head.weight") if @args.tie_word_embeddings + + experts_count = @args.n_routed_experts.to_i + if experts_count > 0 + mx = MLX::Core + @args.num_hidden_layers.times do |layer_idx| + next if layer_idx < 
@args.first_k_dense_replace + + prefix = "model.layers.#{layer_idx}.mlp" + %w[gate_proj down_proj up_proj].each do |projection| + %w[weight scales biases].each do |param| + first_key = "#{prefix}.experts.0.#{projection}.#{param}" + next unless result.key?(first_key) + + expert_keys = (0...experts_count).map do |expert_idx| + "#{prefix}.experts.#{expert_idx}.#{projection}.#{param}" + end + next unless expert_keys.all? { |key| result.key?(key) } + + stacked = expert_keys.map { |key| result.delete(key) } + result["#{prefix}.experts.#{projection}.#{param}"] = mx.stack(stacked) + end + end + end + end + + result.reject { |k, _| k.include?("rotary_emb.inv_freq") } + end + + def layers + model.layers + end + end + + Models.register("dots1", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/ernie4_5.rb b/lib/mlx_lm/models/ernie4_5.rb new file mode 100644 index 0000000..89fb739 --- /dev/null +++ b/lib/mlx_lm/models/ernie4_5.rb @@ -0,0 +1,165 @@ +module MlxLm + module Models + module Ernie45 + class ModelArgs < BaseModelArgs + field :hidden_size + field :intermediate_size + field :model_type, default: "ernie4_5" + field :max_position_embeddings + field :num_attention_heads + field :num_key_value_heads + field :head_dim, default: nil + field :num_hidden_layers + field :rms_norm_eps + field :vocab_size + field :rope_theta + field :use_bias, default: false + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.use_bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.use_bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.use_bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.use_bias) + + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + true, + nil, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(dim, hidden_dim, use_bias: false) + super() + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: use_bias) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: use_bias) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: use_bias) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + 
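Nearly every dense MLP in these files funnels through `Activations.swiglu(gate, up)`. Assuming it is the usual SiLU-gated product, swiglu(g, u) = silu(g) * u with silu(x) = x * sigmoid(x), a scalar pure-Ruby sketch:

```ruby
# Sketch only: scalar SwiGLU, assuming swiglu(g, u) = silu(g) * u.
def sigmoid(x)
  1.0 / (1.0 + Math.exp(-x))
end

def silu(x)
  x * sigmoid(x)
end

def swiglu(g, u)
  silu(g) * u
end

p swiglu(1.0, 2.0)  # => ~1.462, since silu(1.0) ≈ 0.731
```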
end + + class DecoderLayer < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args.hidden_size, args.intermediate_size, use_bias: args.use_bias) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class Ernie45Model < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache) + return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if h.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = Ernie45Model.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def layers + model.layers + end + end + + Models.register("ernie4_5", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/ernie4_5_moe.rb b/lib/mlx_lm/models/ernie4_5_moe.rb new file mode 100644 index 0000000..43f1720 --- /dev/null +++ b/lib/mlx_lm/models/ernie4_5_moe.rb @@ -0,0 +1,97 @@ +require_relative "activations" +require_relative "rope_utils" +require_relative "ernie4_5" + +module MlxLm + module Models + module Ernie45Moe + class ModelArgs < Ernie45::ModelArgs + field :model_type, default: "ernie4_5_moe" + field :moe_num_experts, default: 0 + field :moe_layer_start_index, default: 0 + field :moe_intermediate_size, default: 0 + field :moe_capacity, default: [] + field :moe_k, default: 1 + field :moe_layer_interval, default: 1 + field :moe_use_aux_free, default: false + field :moe_num_shared_experts, default: 0 + field :moe_layer_end_index, default: nil + field :moe_gate_act, default: "softmax" + + def initialize(**kwargs) + super + @moe_capacity = Array(@moe_capacity).dup + end + end + + class Model < Ernie45::Model + REMOVE_PATTERNS = [ + "mtp_block.", + "mtp_linear_proj.", + "mtp_hidden_norm.", + "mtp_emb_norm.", + "e_score_correction_bias", + ].freeze + + EXPERT_PROJ_NAMES = %w[gate_proj down_proj up_proj].freeze + + def sanitize(weights) + result = weights.reject do |key, _| + REMOVE_PATTERNS.any? 
{ |pattern| key.include?(pattern) } + end + + stack_expert_weights!(result) + end + + private + + def stack_expert_weights!(weights) + mx = MLX::Core + num_experts = @args.moe_num_experts.to_i + return weights if num_experts <= 0 + + @args.num_hidden_layers.times do |layer_idx| + prefix = "model.layers.#{layer_idx}.mlp" + + EXPERT_PROJ_NAMES.each do |proj_name| + expert_weights = pop_complete_expert_weights(weights, prefix, proj_name, num_experts) + next unless expert_weights + + weights["#{prefix}.switch_mlp.#{proj_name}.weight"] = mx.stack(expert_weights) + end + end + + weights + end + + def pop_complete_expert_weights(weights, prefix, proj_name, num_experts) + first_key = expert_weight_key(prefix, 0, proj_name) + return nil unless weights.key?(first_key) + + popped = [] + num_experts.times do |expert_idx| + key = expert_weight_key(prefix, expert_idx, proj_name) + unless weights.key?(key) + restore_popped_weights!(weights, prefix, proj_name, popped) + return nil + end + popped << weights.delete(key) + end + popped + end + + def restore_popped_weights!(weights, prefix, proj_name, popped) + popped.each_with_index do |tensor, idx| + weights[expert_weight_key(prefix, idx, proj_name)] = tensor + end + end + + def expert_weight_key(prefix, expert_idx, proj_name) + "#{prefix}.experts.#{expert_idx}.#{proj_name}.weight" + end + end + + Models.register("ernie4_5_moe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/exaone.rb b/lib/mlx_lm/models/exaone.rb new file mode 100644 index 0000000..eaaaf2b --- /dev/null +++ b/lib/mlx_lm/models/exaone.rb @@ -0,0 +1,169 @@ +module MlxLm + module Models + module Exaone + class ModelArgs < BaseModelArgs + field :model_type + field :hidden_size + field :num_layers + field :intermediate_size + field :num_attention_heads + field :vocab_size + field :rope_theta + field :layer_norm_epsilon + field :num_key_value_heads + field :head_dim, default: nil + field :max_position_embeddings, default: nil + field :rope_traditional, default: false + field :rope_scaling, default: nil + field :tie_word_embeddings, default: true + field :attention_bias, default: false + field :mlp_bias, default: false + + def initialize(**kwargs) + super + @head_dim ||= @hidden_size / @num_attention_heads + end + end + + class AttentionModule < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + bias = args.attention_bias + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias) + self.out_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: bias) + + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, d = x.shape + + q = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + k = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + v = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + q = rope.call(q, offset: cache.offset) + k = rope.call(k, offset: cache.offset) + k, v = cache.update_and_fetch(k, v) + else + q = rope.call(q) + k = rope.call(k) + end + 
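# mask is nil for single-token decode or "causal" during prefill (set in ExaoneModel#call below); + # keys/values carry @n_kv_heads heads, which MLX's scaled_dot_product_attention broadcasts across the @n_heads query heads (grouped-query attention).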
+ out = mx.scaled_dot_product_attention(q, k, v, @scale, mask) + out = out.transpose([0, 2, 1, 3]).reshape([b, l, d]) + out_proj.call(out) + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + self.attention = AttentionModule.new(args) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + hidden_dim = args.intermediate_size + bias = args.mlp_bias + self.c_fc_0 = MLX::NN::Linear.new(dim, hidden_dim, bias: bias) + self.c_fc_1 = MLX::NN::Linear.new(dim, hidden_dim, bias: bias) + self.c_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias) + end + + def call(x) + c_proj.call(MlxLm::Models::Activations.swiglu(c_fc_0.call(x), c_fc_1.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.ln_1 = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon) + self.attn = Attention.new(args) + self.ln_2 = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon) + self.mlp = MLP.new(args) + end + + def call(x, mask: nil, cache: nil) + h = x + attn.attention.call(ln_1.call(x), mask: mask, cache: cache) + h + mlp.call(ln_2.call(h)) + end + end + + class ExaoneModel < MLX::NN::Module + def initialize(args) + super() + self.wte = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.h = Array.new(args.num_layers) { TransformerBlock.new(args) } + self.ln_f = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon) + end + + def call(inputs, cache: nil) + hidden = wte.call(inputs) + layer_cache = cache || [nil] * h.length + + mask = nil + mask = "causal" if hidden.shape[1] > 1 + + h.each_with_index do |layer, i| + hidden = layer.call(hidden, mask: mask, cache: layer_cache[i]) + end + + ln_f.call(hidden) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.transformer = ExaoneModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = transformer.call(inputs, cache: cache) + if @args.tie_word_embeddings + transformer.wte.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + result = weights.reject { |k, _| k.include?("rotary_emb.inv_freq") } + result.delete("lm_head.weight") if @args.tie_word_embeddings + result + end + + def layers + transformer.h + end + end + + Models.register("exaone", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/exaone4.rb b/lib/mlx_lm/models/exaone4.rb new file mode 100644 index 0000000..20770f6 --- /dev/null +++ b/lib/mlx_lm/models/exaone4.rb @@ -0,0 +1,233 @@ +module MlxLm + module Models + module Exaone4 + class ModelArgs < BaseModelArgs + field :model_type, default: "exaone4" + field :hidden_size + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :rms_norm_eps + field :vocab_size + field :num_key_value_heads + field :max_position_embeddings + field :rope_theta + field :head_dim + field :tie_word_embeddings + field :rope_scaling, default: nil + field :sliding_window, default: nil + field :sliding_window_pattern, default: nil + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + end + end + + class Attention < MLX::NN::Module + attr_reader :is_local + + def initialize(args, is_local) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + 
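# NOTE: with a sliding_window_pattern, is_local marks the sliding-window layers; purely global layers (is_local == false) skip RoPE entirely (see @use_rope below). +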
@n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false) + + self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + + @is_local = is_local || false + @use_rope = is_local.nil? || @is_local + if @use_rope + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + false, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = q_norm.call(queries.reshape([b, l, @n_heads, @head_dim])).transpose([0, 2, 1, 3]) + keys = k_norm.call(keys.reshape([b, l, @n_kv_heads, @head_dim])).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + if @use_rope + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + end + keys, values = cache.update_and_fetch(keys, values) + elsif @use_rope + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(dim, hidden_dim) + super() + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args, is_local:) + super() + self.self_attn = Attention.new(args, is_local) + self.mlp = MLP.new(args.hidden_size, args.intermediate_size) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_feedforward_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(x, mask: mask, cache: cache) + h = x + post_attention_layernorm.call(r) + r = mlp.call(h) + h + post_feedforward_layernorm.call(r) + end + end + + class ExaoneModel < MLX::NN::Module + def initialize(args) + super() + @args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + + pattern = args.sliding_window_pattern + self.layers = Array.new(args.num_hidden_layers) do |i| + is_local = pattern ? 
(pattern[i % pattern.length] == "L") : nil + TransformerBlock.new(args, is_local: is_local) + end + + if pattern + self.swa_idx = pattern.index("L") + self.full_idx = pattern.index("G") + else + self.swa_idx = nil + self.full_idx = 0 + end + + self.window_size = args.sliding_window + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + global_mask = _create_attention_mask(h, layer_cache[full_idx]) + if !swa_idx.nil? + swa_mask = _create_attention_mask( + h, + layer_cache[swa_idx], + window_size: window_size + ) + else + swa_mask = nil + end + + layers.each_with_index do |layer, i| + mask = layer.self_attn.is_local ? swa_mask : global_mask + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache = nil, window_size: nil) + n = h.shape[1] + if cache && cache.respond_to?(:make_mask) + return cache.make_mask(n, window_size: window_size) + end + return nil if n == 1 + return _create_causal_mask(n, window_size: window_size) if window_size && n > window_size + + "causal" + end + + def _create_causal_mask(n, offset: 0, window_size: nil) + mx = MLX::Core + rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n]) + linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1]) + + mask = mx.greater_equal(linds, rinds) + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + mask + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = ExaoneModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def make_cache + layers.map do |layer| + if layer.self_attn.is_local + RotatingKVCache.new(max_size: @args.sliding_window, keep: 0) + else + KVCache.new + end + end + end + + def layers + model.layers + end + end + + Models.register("exaone4", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/exaone_moe.rb b/lib/mlx_lm/models/exaone_moe.rb new file mode 100644 index 0000000..f133836 --- /dev/null +++ b/lib/mlx_lm/models/exaone_moe.rb @@ -0,0 +1,421 @@ +require_relative "activations" +require_relative "cache" +require_relative "rope_utils" +require_relative "switch_layers" + +module MlxLm + module Models + module ExaoneMoe + class ModelArgs < BaseModelArgs + field :model_type, default: "exaone_moe" + field :vocab_size + field :hidden_size + field :intermediate_size + field :moe_intermediate_size + field :num_hidden_layers + field :num_attention_heads + field :num_key_value_heads, default: nil + field :head_dim, default: nil + field :num_experts + field :num_experts_per_tok + field :num_shared_experts + field :rms_norm_eps + field :max_position_embeddings + field :sliding_window + field :layer_types, default: nil + field :is_moe_layer, default: nil + field :n_group, default: 1 + field :topk_group, default: 1 + field :routed_scaling_factor, default: 2.5 + field :norm_topk_prob, default: true + field :scoring_func, default: "sigmoid" + field :topk_method, default: "noaux_tc" + field :rope_theta, default: 1_000_000.0 + field :rope_scaling, default: nil + field 
:rope_parameters, default: nil + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + @layer_types ||= Array.new(@num_hidden_layers) { "full_attention" } + @is_moe_layer ||= Array.new(@num_hidden_layers, false) + + return unless @rope_parameters.respond_to?(:[]) + + rope_theta = @rope_parameters["rope_theta"] || @rope_parameters[:rope_theta] + @rope_theta = rope_theta unless rope_theta.nil? + end + end + + module_function + + def group_expert_select( + gates, + e_score_correction_bias, + top_k, + n_group, + topk_group, + routed_scaling_factor, + norm_topk_prob + ) + mx = MLX::Core + + scores = mx.sigmoid(gates.astype(mx.float32)) + orig_scores = scores + scores = scores + e_score_correction_bias + + if n_group.to_i > 1 + experts_per_group = scores.shape[-1] / n_group + scores = mx.unflatten(scores, -1, [n_group, experts_per_group]) + group_scores = mx.topk(scores, 2, -1) + group_scores = mx.expand_dims(mx.sum(group_scores, -1), -1) + + drop_count = n_group - topk_group.to_i + if drop_count > 0 + group_idx = mx.argpartition(group_scores, drop_count - 1, -2) + take_ids = mx.array((0...drop_count).to_a, dtype: mx.int32) + group_idx = mx.take(group_idx, take_ids, -2) + scores = mx.put_along_axis( + scores, + mx.stop_gradient(group_idx), + mx.array(0.0), + -2 + ) + end + + scores = mx.flatten(scores, -2, -1) + end + + k = [[top_k.to_i, 1].max, scores.shape[-1]].min + inds = mx.argpartition(scores * -1.0, k - 1, -1) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + + selected_scores = mx.take_along_axis(orig_scores, inds, -1) + if k > 1 && norm_topk_prob + denominator = mx.expand_dims(mx.sum(selected_scores, -1), -1) + selected_scores = selected_scores / (denominator + 1e-20) + end + + selected_scores = selected_scores * routed_scaling_factor.to_f + [inds, selected_scores] + end + + class MoEGate < MLX::NN::Module + def initialize(args) + super() + @top_k = args.num_experts_per_tok + @norm_topk_prob = args.norm_topk_prob + @n_routed_experts = args.num_experts + @routed_scaling_factor = args.routed_scaling_factor + @n_group = args.n_group + @topk_group = args.topk_group + + raise ArgumentError, "Unsupported topk method: #{args.topk_method}" unless args.topk_method == "noaux_tc" + + mx = MLX::Core + self.weight = mx.zeros([@n_routed_experts, args.hidden_size]) + self.e_score_correction_bias = mx.zeros([@n_routed_experts]) + end + + def call(x) + mx = MLX::Core + gates = mx.matmul(x, mx.transpose(weight)) + ExaoneMoe.group_expert_select( + gates, + e_score_correction_bias, + @top_k, + @n_group, + @topk_group, + @routed_scaling_factor, + @norm_topk_prob + ) + end + end + + class MLP < MLX::NN::Module + def initialize(args, intermediate_size: nil) + super() + hidden_size = args.hidden_size + intermediate_size ||= args.intermediate_size + + self.gate_proj = MLX::NN::Linear.new(hidden_size, intermediate_size, bias: false) + self.up_proj = MLX::NN::Linear.new(hidden_size, intermediate_size, bias: false) + self.down_proj = MLX::NN::Linear.new(intermediate_size, hidden_size, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class MoE < MLX::NN::Module + def initialize(args) + super() + @num_shared_experts = args.num_shared_experts + + self.switch_mlp = SwitchLayers::SwitchGLU.new( + args.hidden_size, + args.moe_intermediate_size, + args.num_experts + ) + self.gate 
= MoEGate.new(args) + + if !@num_shared_experts.nil? && @num_shared_experts > 0 + shared_intermediate = args.moe_intermediate_size * @num_shared_experts + self.shared_experts = MLP.new(args, intermediate_size: shared_intermediate) + end + end + + def call(x) + mx = MLX::Core + inds, scores = gate.call(x) + y = switch_mlp.call(x, inds) + y = mx.sum(y * mx.expand_dims(scores, -1), -2).astype(y.dtype) + y = y + shared_experts.call(x) if respond_to?(:shared_experts) + y + end + end + + class Attention < MLX::NN::Module + attr_reader :is_sliding_window + + def initialize(args, layer_idx) + super() + + @hidden_size = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(@hidden_size, @n_heads * @head_dim, bias: false) + self.k_proj = MLX::NN::Linear.new(@hidden_size, @n_kv_heads * @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(@hidden_size, @n_kv_heads * @head_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, @hidden_size, bias: false) + + self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + + @is_sliding_window = args.layer_types[layer_idx] == "sliding_attention" + apply_rope_all_layers = !args.layer_types.include?("sliding_attention") + @use_rope = @is_sliding_window || apply_rope_all_layers + + if @use_rope + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + false, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = q_norm.call(queries.reshape([b, l, @n_heads, @head_dim])).transpose([0, 2, 1, 3]) + keys = k_norm.call(keys.reshape([b, l, @n_kv_heads, @head_dim])).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + if @use_rope + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + end + keys, values = cache.update_and_fetch(keys, values) + elsif @use_rope + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class DecoderLayer < MLX::NN::Module + attr_reader :is_sliding_window + + def initialize(args, layer_idx) + super() + + self.self_attn = Attention.new(args, layer_idx) + self.mlp = args.is_moe_layer[layer_idx] ? 
MoE.new(args) : MLP.new(args) + @is_sliding_window = self_attn.is_sliding_window + + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class ExaoneMoeModel < MLX::NN::Module + def initialize(args) + super() + @window_size = args.sliding_window + + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { |idx| DecoderLayer.new(args, idx) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + + self.swa_idx = nil + self.ga_idx = nil + layers.each_with_index do |layer, idx| + self.swa_idx = idx if swa_idx.nil? && layer.is_sliding_window + self.ga_idx = idx if ga_idx.nil? && !layer.is_sliding_window + break unless swa_idx.nil? || ga_idx.nil? + end + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + global_cache = ga_idx.nil? ? layer_cache[0] : layer_cache[ga_idx] + swa_cache = swa_idx.nil? ? layer_cache[0] : layer_cache[swa_idx] + + global_mask = _create_attention_mask(h, global_cache) + swa_mask = _create_attention_mask(h, swa_cache, window_size: @window_size) + + layers.each_with_index do |layer, idx| + mask = layer.is_sliding_window ? swa_mask : global_mask + h = layer.call(h, mask: mask, cache: layer_cache[idx]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache = nil, window_size: nil) + n = h.shape[1] + if cache && cache.respond_to?(:make_mask) + return cache.make_mask(n, window_size: window_size) + end + + if window_size + offset = 0 + if cache + offset = cache.offset if cache.respond_to?(:offset) + if cache.instance_variable_defined?(:@max_size) + max_size = cache.instance_variable_get(:@max_size) + offset = [max_size - 1, offset].min if max_size && max_size > 0 + end + end + return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size + end + + return nil if n == 1 + + "causal" + end + + def _create_causal_mask(n, offset: 0, window_size: nil) + mx = MLX::Core + rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n]) + linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1]) + + mask = mx.greater_equal(linds, rinds) + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + mask + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = ExaoneMoeModel.new(args) + + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + mx = MLX::Core + result = weights.reject { |k, _| k.start_with?("mtp.") } + num_experts = @args.num_experts.to_i + + @args.num_hidden_layers.to_i.times do |layer_idx| + next unless @args.is_moe_layer[layer_idx] + + prefix = "model.layers.#{layer_idx}.mlp" + bias_key = "#{prefix}.e_score_correction_bias" + if result.key?(bias_key) + result["#{prefix}.gate.e_score_correction_bias"] = 
result.delete(bias_key) + end + + %w[gate_proj down_proj up_proj].each do |proj_name| + %w[weight scales biases].each do |param_name| + first_key = "#{prefix}.experts.0.#{proj_name}.#{param_name}" + last_key = "#{prefix}.experts.#{num_experts - 1}.#{proj_name}.#{param_name}" + next unless result.key?(first_key) && result.key?(last_key) + + expert_keys = (0...num_experts).map do |expert_idx| + "#{prefix}.experts.#{expert_idx}.#{proj_name}.#{param_name}" + end + next unless expert_keys.all? { |key| result.key?(key) } + + stacked = expert_keys.map { |key| result.delete(key) } + result["#{prefix}.switch_mlp.#{proj_name}.#{param_name}"] = mx.stack(stacked) + end + end + end + + result.delete("lm_head.weight") if @args.tie_word_embeddings + result + end + + def layers + model.layers + end + + def cast_predicate + lambda { |key| !key.include?("e_score_correction_bias") } + end + + def make_cache + max_window = @args.sliding_window || @args.max_position_embeddings || 1 + layers.map do |layer| + if layer.is_sliding_window + RotatingKVCache.new(max_size: max_window, keep: 0) + else + KVCache.new + end + end + end + end + + Models.register("exaone_moe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/falcon_h1.rb b/lib/mlx_lm/models/falcon_h1.rb new file mode 100644 index 0000000..f8ac384 --- /dev/null +++ b/lib/mlx_lm/models/falcon_h1.rb @@ -0,0 +1,102 @@ +require_relative "recurrent_gemma" + +module MlxLm + module Models + module FalconH1 + class ModelArgs < BaseModelArgs + field :model_type, default: "falcon_h1" + field :attention_bias, default: false + field :head_dim, default: 64 + field :hidden_size, default: 1024 + field :intermediate_size, default: 2048 + field :max_position_embeddings, default: 131_072 + field :mamba_d_conv, default: 4 + field :num_attention_heads, default: 8 + field :num_hidden_layers, default: 36 + field :num_key_value_heads, default: 2 + field :rms_norm_eps, default: 1e-5 + field :rope_theta, default: 100_000_000_000.0 + field :vocab_size, default: 32_784 + field :tie_word_embeddings, default: true + field :logits_soft_cap, default: nil + field :attention_window_size, default: nil + field :block_types, default: nil + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = RecurrentGemma::Model.new( + RecurrentGemma::ModelArgs.from_dict(_to_recurrent_gemma_config(args)) + ) + + if args.tie_word_embeddings + language_model.instance_variable_set(:@tie_word_embeddings, true) + language_model.lm_head = nil if language_model.respond_to?(:lm_head=) + end + end + + def call(inputs, cache: nil) + language_model.call(inputs, cache: cache) + end + + def sanitize(weights) + remapped = {} + weights.each do |key, value| + remapped[_remap_weight_key(key)] = value + end + language_model.sanitize(remapped) + end + + def layers + language_model.layers + end + + def make_cache + return language_model.make_cache if language_model.respond_to?(:make_cache) + + nil + end + + private + + def _to_recurrent_gemma_config(args) + { + "model_type" => args.model_type, + "attention_bias" => args.attention_bias, + "conv1d_width" => args.mamba_d_conv || 4, + "hidden_size" => args.hidden_size, + "intermediate_size" => args.intermediate_size, + "logits_soft_cap" => args.logits_soft_cap, + "num_attention_heads" => args.num_attention_heads, + "num_hidden_layers" => args.num_hidden_layers, + "num_key_value_heads" => args.num_key_value_heads || args.num_attention_heads, + "rms_norm_eps" => 
args.rms_norm_eps, + "rope_theta" => args.rope_theta, + "attention_window_size" => args.attention_window_size || [args.max_position_embeddings.to_i, 128].min, + "vocab_size" => args.vocab_size, + "embeddings_scale_by_sqrt_dim" => false, + "block_types" => args.block_types || ["recurrent", "attention"], + } + end + + def _remap_weight_key(key) + mapped = key.dup + mapped = mapped.gsub(".mamba.conv1d.", ".temporal_block.conv_1d.") + mapped = mapped.gsub(".mamba.out_proj.", ".temporal_block.linear_out.") + mapped = mapped.gsub(".mamba.in_proj.", ".temporal_block.linear_x.") + mapped = mapped.gsub(".self_attn.", ".temporal_block.") + mapped = mapped.gsub(".feed_forward.", ".mlp_block.") + mapped = mapped.gsub(".input_layernorm.", ".temporal_pre_norm.") + mapped = mapped.gsub(".pre_ff_layernorm.", ".channel_pre_norm.") + mapped = mapped.gsub("model.final_layernorm.", "model.final_norm.") + mapped + end + end + + Models.register("falcon_h1", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/gated_delta.rb b/lib/mlx_lm/models/gated_delta.rb new file mode 100644 index 0000000..425110d --- /dev/null +++ b/lib/mlx_lm/models/gated_delta.rb @@ -0,0 +1,136 @@ +module MlxLm + module Models + module GatedDelta + module_function + + def compute_g(a_log, a, dt_bias) + mx = MLX::Core + decay = mx.exp(a_log.astype(mx.float32)) * MLX::NN.softplus(a + dt_bias) + mx.exp( + mx.multiply(-1.0, decay) + ).astype(a.dtype) + end + + def gated_delta_kernel(q, k, v, g, beta, state, mask = nil) + # TODO: Add a Metal custom-kernel specialization for prefill throughput parity. + gated_delta_ops(q, k, v, g, beta, state, mask) + end + + def gated_delta_ops(q, k, v, g, beta, state = nil, mask = nil) + mx = MLX::Core + bsz, steps, hk, dk = q.shape + v_shape = v.shape + hv = v_shape[-2] + dv = v_shape[-1] + + state ||= mx.zeros([bsz, hv, dv, dk], q.dtype) + + repeat_factor = hv / hk + if repeat_factor > 1 + q = mx.repeat(q, repeat_factor, -2) + k = mx.repeat(k, repeat_factor, -2) + end + + q_steps = mx.split(q, steps, 1).map { |x| mx.squeeze(x, 1) } + k_steps = mx.split(k, steps, 1).map { |x| mx.squeeze(x, 1) } + v_steps = mx.split(v, steps, 1).map { |x| mx.squeeze(x, 1) } + g_steps = mx.split(g, steps, 1).map { |x| mx.squeeze(x, 1) } + beta_steps = mx.split(beta, steps, 1).map { |x| mx.squeeze(x, 1) } + mask_steps = + if mask.nil? + nil + elsif mask.ndim == 1 + [mask] + else + mx.split(mask, steps, 1).map { |x| mx.squeeze(x, 1) } + end + + ys = [] + steps.times do |t| + y, state = _gated_delta_step_ops( + q_steps[t], + k_steps[t], + v_steps[t], + g_steps[t], + beta_steps[t], + state, + mask_steps&.[](t) + ) + ys << y + end + + [mx.stack(ys, 1), state] + end + + def gated_delta_update( + q, + k, + v, + a, + b, + a_log, + dt_bias, + state = nil, + mask = nil, + use_kernel: true + ) + mx = MLX::Core + beta = mx.sigmoid(b) + g = compute_g(a_log, a, dt_bias) + + if state.nil? + bsz, = q.shape + dk = q.shape[-1] + hv = v.shape[-2] + dv = v.shape[-1] + state = mx.zeros([bsz, hv, dv, dk], q.dtype) + end + + if use_kernel && metal_kernel_available? 
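+ # NOTE: gated_delta_kernel currently just forwards to gated_delta_ops (see the TODO above), + # presumably so a fused Metal kernel can be dropped in later without changing callers.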
+ gated_delta_kernel(q, k, v, g, beta, state, mask) + else + gated_delta_ops(q, k, v, g, beta, state, mask) + end + end + + def _gated_delta_step_ops(q, k, v, g, beta, state, mask = nil) + mx = MLX::Core + old_state = state + + decay = case g.ndim + when 2 + mx.expand_dims(g, [2, 3]) + when 3 + mx.expand_dims(g, 2) + else + raise ArgumentError, "Unsupported gating shape #{g.shape.inspect}" + end + + state = state * decay + k_expanded = mx.expand_dims(k, 2) + kv_mem = (state * k_expanded).sum(-1) + delta = (v - kv_mem) * mx.expand_dims(beta, -1) + state = state + k_expanded * mx.expand_dims(delta, -1) + y = (state * mx.expand_dims(q, 2)).sum(-1) + + unless mask.nil? + mask_shape = [mask.shape[0]] + [1] * (state.ndim - 1) + state = mx.where(mask.reshape(mask_shape), state, old_state) + end + + [y, state] + end + private_class_method :_gated_delta_step_ops + + def metal_kernel_available? + mx = MLX::Core + return false unless mx.respond_to?(:metal_is_available) && mx.metal_is_available + return false unless mx.respond_to?(:default_device) + + device = mx.default_device + device.respond_to?(:type) && device.type == :gpu + end + private_class_method :metal_kernel_available? + end + end +end diff --git a/lib/mlx_lm/models/gemma2.rb b/lib/mlx_lm/models/gemma2.rb index b87d40f..75ee6dc 100644 --- a/lib/mlx_lm/models/gemma2.rb +++ b/lib/mlx_lm/models/gemma2.rb @@ -35,7 +35,7 @@ def call(x) mx = MLX::Core # RMS normalization: x / sqrt(mean(x^2) + eps) * (1 + weight) x_sq = x * x - mean_sq = mx.expand_dims(mx.mean(x_sq, -1), -1) + mean_sq = mx.mean(x_sq, -1, keepdims: true) norm = x * mx.rsqrt(mean_sq + @eps) norm * (weight + 1.0) end diff --git a/lib/mlx_lm/models/gemma3.rb b/lib/mlx_lm/models/gemma3.rb new file mode 100644 index 0000000..431b000 --- /dev/null +++ b/lib/mlx_lm/models/gemma3.rb @@ -0,0 +1,85 @@ +require_relative "gemma3_text" + +module MlxLm + module Models + module Gemma3 + class ModelArgs < BaseModelArgs + field :model_type, default: "gemma3" + field :text_config, default: nil + field :vocab_size, default: 262208 + + def self.from_dict(params) + has_text_config = params.key?("text_config") || params.key?(:text_config) + return super if has_text_config + + model_type = params["model_type"] || params[:model_type] || "gemma3" + vocab_size = params["vocab_size"] || params[:vocab_size] || 262208 + new(model_type: model_type, text_config: params, vocab_size: vocab_size) + end + + def initialize(**kwargs) + super + @text_config = _stringify_keys(@text_config || {}) + @text_config["vocab_size"] = @vocab_size + @text_config["num_attention_heads"] ||= 8 + @text_config["num_key_value_heads"] ||= 4 + @text_config["model_type"] ||= "gemma3_text" + end + + private + + def _stringify_keys(hash) + hash.each_with_object({}) do |(key, value), out| + out[key.to_s] = value + end + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = Gemma3Text::Model.new( + Gemma3Text::ModelArgs.from_dict(args.text_config) + ) + end + + def call(inputs, cache: nil, input_embeddings: nil) + language_model.call( + inputs, + cache: cache, + input_embeddings: input_embeddings + ) + end + + def sanitize(weights) + flat_weights = weights.is_a?(Hash) ? 
weights : weights.to_h + nested = MLX::Utils.tree_unflatten(flat_weights.to_a) + + if nested.is_a?(Hash) + nested.delete("vision_tower") + nested.delete("multi_modal_projector") + + language_tree = nested["language_model"] || {} + language_weights = MLX::Utils.tree_flatten(language_tree, destination: {}) + sanitized_language = language_model.sanitize(language_weights) + nested["language_model"] = MLX::Utils.tree_unflatten(sanitized_language.to_a) + end + + MLX::Utils.tree_flatten(nested, destination: {}) + end + + def layers + language_model.layers + end + + def make_cache + language_model.make_cache + end + end + + Models.register("gemma3", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/gemma3_text.rb b/lib/mlx_lm/models/gemma3_text.rb new file mode 100644 index 0000000..b546817 --- /dev/null +++ b/lib/mlx_lm/models/gemma3_text.rb @@ -0,0 +1,270 @@ +require_relative "cache" +require_relative "rope_utils" + +module MlxLm + module Models + module Gemma3Text + class ModelArgs < BaseModelArgs + field :model_type, default: "gemma3_text" + field :hidden_size, default: 1152 + field :num_hidden_layers, default: 26 + field :intermediate_size, default: 6912 + field :num_attention_heads, default: 4 + field :head_dim, default: 256 + field :rms_norm_eps, default: 1.0e-6 + field :vocab_size, default: 262144 + field :num_key_value_heads, default: 1 + field :rope_theta, default: 1_000_000.0 + field :rope_local_base_freq, default: 10_000.0 + field :query_pre_attn_scalar, default: 256.0 + field :sliding_window, default: 512 + field :sliding_window_pattern, default: 6 + field :max_position_embeddings, default: 32768 + field :rope_scaling, default: nil + end + + class RMSNorm < MLX::NN::Module + def initialize(dims:, eps: 1e-6) + super() + self.weight = MLX::Core.ones([dims]) + @eps = eps + end + + def call(x) + mx = MLX::Core + x_sq = x * x + mean_sq = mx.mean(x_sq, -1, keepdims: true) + x * mx.rsqrt(mean_sq + @eps) * (1.0 + weight) + end + end + + class Attention < MLX::NN::Module + def initialize(args, layer_idx) + super() + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = args.query_pre_attn_scalar**(-0.5) + pattern = [args.sliding_window_pattern.to_i, 1].max + @is_sliding = ((layer_idx + 1) % pattern) != 0 + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false) + + self.q_norm = RMSNorm.new(dims: @head_dim, eps: args.rms_norm_eps) + self.k_norm = RMSNorm.new(dims: @head_dim, eps: args.rms_norm_eps) + + if @is_sliding + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_local_base_freq, + false + ) + else + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + false, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + queries = q_norm.call(queries) + keys = k_norm.call(keys) + + if cache + queries = rope.call(queries, 
offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(dim, hidden_dim) + super() + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + end + + def call(x) + down_proj.call(MLX::NN.gelu_approx(gate_proj.call(x)) * up_proj.call(x)) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args, layer_idx) + super() + self.self_attn = Attention.new(args, layer_idx) + self.mlp = MLP.new(args.hidden_size, args.intermediate_size) + self.input_layernorm = RMSNorm.new(dims: args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = RMSNorm.new(dims: args.hidden_size, eps: args.rms_norm_eps) + self.pre_feedforward_layernorm = RMSNorm.new(dims: args.hidden_size, eps: args.rms_norm_eps) + self.post_feedforward_layernorm = RMSNorm.new(dims: args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = clip_residual(x, post_attention_layernorm.call(r)) + r = mlp.call(pre_feedforward_layernorm.call(h)) + clip_residual(h, post_feedforward_layernorm.call(r)) + end + + private + + def clip_residual(x, y) + mx = MLX::Core + return x + y unless x.dtype == mx.float16 + + bound = mx.finfo(mx.float16).max + mx.clip( + x.astype(mx.float32) + y.astype(mx.float32), + -bound, + bound + ).astype(mx.float16) + end + end + + class Gemma3Model < MLX::NN::Module + attr_reader :sliding_window_pattern + + def initialize(args) + super() + @args = args + @window_size = args.sliding_window + @sliding_window_pattern = [args.sliding_window_pattern.to_i, 1].max + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) do |layer_idx| + TransformerBlock.new(args, layer_idx) + end + self.norm = RMSNorm.new(dims: args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil, input_embeddings: nil) + h = input_embeddings || embed_tokens.call(inputs) + h = h * Math.sqrt(@args.hidden_size) + layer_cache = cache || [nil] * layers.length + + global_idx = sliding_window_pattern - 1 + global_mask = _create_attention_mask(h, layer_cache[global_idx]) + sliding_window_mask = if sliding_window_pattern > 1 + _create_attention_mask(h, layer_cache[0], window_size: @window_size) + else + nil + end + + layers.each_with_index do |layer, i| + is_global = (i % sliding_window_pattern) == (sliding_window_pattern - 1) + mask = is_global ? global_mask : sliding_window_mask + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache = nil, window_size: nil) + n = h.shape[1] + return cache.make_mask(n) if cache && cache.respond_to?(:make_mask) + + if window_size + offset = cache ? 
cache.offset : 0 + if cache && cache.instance_variable_defined?(:@max_size) + max_size = cache.instance_variable_get(:@max_size) + offset = [max_size - 1, offset].min if max_size && max_size > 0 + end + return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size + end + + return nil if n == 1 + + "causal" + end + + def _create_causal_mask(n, offset: 0, window_size: nil) + mx = MLX::Core + rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n]) + linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1]) + + mask = mx.greater_equal(linds, rinds) + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + mask + end + end + + class Model < MLX::NN::Module + attr_reader :args + + def initialize(args) + super() + @args = args + @tie_word_embeddings = false + self.model_type = args.model_type + self.model = Gemma3Model.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + + def call(inputs, cache: nil, input_embeddings: nil) + out = model.call(inputs, cache: cache, input_embeddings: input_embeddings) + if @tie_word_embeddings || lm_head.nil? + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + sanitized = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } + unless sanitized.key?("lm_head.weight") + @tie_word_embeddings = true + self.lm_head = nil + end + sanitized + end + + def layers + model.layers + end + + def make_cache + pattern = [@args.sliding_window_pattern.to_i, 1].max + max_size = @args.sliding_window || @args.max_position_embeddings || 1 + Array.new(@args.num_hidden_layers) do |i| + is_global = (i % pattern) == (pattern - 1) + if is_global + MlxLm::KVCache.new + else + MlxLm::RotatingKVCache.new(max_size: max_size, keep: 0) + end + end + end + end + + Models.register("gemma3_text", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/gemma3n.rb b/lib/mlx_lm/models/gemma3n.rb new file mode 100644 index 0000000..34dc592 --- /dev/null +++ b/lib/mlx_lm/models/gemma3n.rb @@ -0,0 +1,79 @@ +require_relative "gemma2" + +module MlxLm + module Models + module Gemma3n + class ModelArgs < BaseModelArgs + field :model_type, default: "gemma3n" + field :text_config, default: nil + + def self.from_dict(params) + has_text_config = params.key?("text_config") || params.key?(:text_config) + return super if has_text_config + + new(model_type: params["model_type"] || params[:model_type], text_config: params) + end + + def initialize(**kwargs) + super + @text_config = (@text_config || {}).dup + end + end + + class Model < MLX::NN::Module + MULTIMODAL_MODEL_PREFIXES = %w[ + model.vision_tower + model.audio_tower + model.embed_audio + model.embed_vision + ].freeze + + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = Gemma2::Model.new(Gemma2::ModelArgs.from_dict(_text_config_for_gemma2(args))) + end + + def call(inputs, cache: nil, input_embeddings: nil) + supports_input_embeddings = language_model.method(:call).parameters.any? do |_, name| + name == :input_embeddings + end + + if supports_input_embeddings + language_model.call(inputs, cache: cache, input_embeddings: input_embeddings) + else + language_model.call(inputs, cache: cache) + end + end + + def sanitize(weights) + weights.reject do |key, _| + MULTIMODAL_MODEL_PREFIXES.any? 
{ |prefix| key == prefix || key.start_with?("#{prefix}.") } + end + end + + def layers + language_model.layers + end + + def make_cache + return language_model.make_cache if language_model.respond_to?(:make_cache) + + nil + end + + private + + def _text_config_for_gemma2(args) + config = {} + (args.text_config || {}).each { |key, value| config[key.to_s] = value } + config["model_type"] ||= args.model_type + config + end + end + + Models.register("gemma3n", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/glm.rb b/lib/mlx_lm/models/glm.rb new file mode 100644 index 0000000..2b84863 --- /dev/null +++ b/lib/mlx_lm/models/glm.rb @@ -0,0 +1,164 @@ +module MlxLm + module Models + module GLM + class ModelArgs < BaseModelArgs + field :model_type, default: "glm" + field :hidden_size, default: 4096 + field :num_hidden_layers, default: 28 + field :intermediate_size, default: 13696 + field :num_attention_heads, default: 32 + field :rms_norm_eps, default: 1e-5 + field :vocab_size, default: 151552 + field :head_dim, default: nil + field :num_key_value_heads, default: nil + field :max_position_embeddings, default: nil + field :attention_bias, default: false + field :rope_theta, default: 10_000.0 + field :tie_word_embeddings, default: true + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + bias = args.attention_bias + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false) + self.rope = MLX::NN::RoPE.new(@head_dim, traditional: true, base: args.rope_theta) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + self.gate_up_proj = MLX::NN::Linear.new( + args.hidden_size, + 2 * args.intermediate_size, + bias: false + ) + self.down_proj = MLX::NN::Linear.new( + args.intermediate_size, + args.hidden_size, + bias: false + ) + end + + def call(x) + mx = MLX::Core + x = gate_up_proj.call(x) + split_dim = x.shape[-1] / 2 + gate, up = mx.split(x, [split_dim], -1) + down_proj.call(Activations.swiglu(gate, up)) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: 
args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class GLMModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + @model_type = args.model_type + self.model = GLMModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } + result.delete("lm_head.weight") if @args.tie_word_embeddings + result + end + + def layers + model.layers + end + end + + Models.register("glm", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/glm4.rb b/lib/mlx_lm/models/glm4.rb new file mode 100644 index 0000000..1a718b3 --- /dev/null +++ b/lib/mlx_lm/models/glm4.rb @@ -0,0 +1,180 @@ +module MlxLm + module Models + module GLM4 + class ModelArgs < BaseModelArgs + field :model_type, default: "glm4" + field :hidden_size, default: 4096 + field :num_hidden_layers, default: 40 + field :intermediate_size, default: 13696 + field :num_attention_heads, default: 32 + field :attention_bias, default: false + field :head_dim, default: nil + field :rms_norm_eps, default: 1e-5 + field :vocab_size, default: 151552 + field :num_key_value_heads, default: nil + field :partial_rotary_factor, default: 0.5 + field :rope_theta, default: 10_000.0 + field :rope_traditional, default: true + field :max_position_embeddings, default: 32768 + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + end + end + + class GLM4MLP < MLX::NN::Module + def initialize(args) + super() + self.gate_up_proj = MLX::NN::Linear.new( + args.hidden_size, + 2 * args.intermediate_size, + bias: false + ) + self.down_proj = MLX::NN::Linear.new( + args.intermediate_size, + args.hidden_size, + bias: false + ) + end + + def call(x) + mx = MLX::Core + x = gate_up_proj.call(x) + split_dim = x.shape[-1] / 2 + gate, up_states = mx.split(x, [split_dim], -1) + down_proj.call(Activations.swiglu(gate, up_states)) + end + end + + class GLM4Attention < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + @head_dim = args.head_dim + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new( + dim, + args.num_attention_heads * @head_dim, + bias: args.attention_bias + ) + self.k_proj = MLX::NN::Linear.new( + dim, + args.num_key_value_heads * 
@head_dim, + bias: args.attention_bias + ) + self.v_proj = MLX::NN::Linear.new( + dim, + args.num_key_value_heads * @head_dim, + bias: args.attention_bias + ) + self.o_proj = MLX::NN::Linear.new( + args.num_attention_heads * @head_dim, + args.hidden_size, + bias: false + ) + + self.rope = MLX::NN::RoPE.new( + (args.partial_rotary_factor * @head_dim).to_i, + base: args.rope_theta, + traditional: args.rope_traditional + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class GLM4DecoderLayer < MLX::NN::Module + def initialize(args) + super() + self.self_attn = GLM4Attention.new(args) + self.mlp = GLM4MLP.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_self_attn_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_mlp_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + x = x + post_self_attn_layernorm.call( + self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + ) + residual = x + post_mlp_layernorm.call(mlp.call(post_attention_layernorm.call(x))) + residual + end + end + + class GLM4Model < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { GLM4DecoderLayer.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = GLM4Model.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + lm_head.call(out) + end + + def sanitize(weights) + weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } + end + + def layers + model.layers + end + end + + Models.register("glm4", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/glm4_moe.rb b/lib/mlx_lm/models/glm4_moe.rb new file mode 100644 index 0000000..a9043da --- /dev/null +++ b/lib/mlx_lm/models/glm4_moe.rb @@ -0,0 +1,343 @@ +require_relative "activations" +require_relative "pipeline" +require_relative "switch_layers" + +module MlxLm + module Models + module Glm4Moe + class ModelArgs < BaseModelArgs + field :model_type, default: 
"glm4_moe" + field :vocab_size + field :hidden_size + field :intermediate_size + field :max_position_embeddings + field :moe_intermediate_size + field :norm_topk_prob + field :num_attention_heads + field :n_group + field :head_dim, default: nil + field :topk_group + field :n_shared_experts + field :n_routed_experts + field :routed_scaling_factor + field :num_experts_per_tok + field :first_k_dense_replace + field :num_hidden_layers + field :num_key_value_heads, default: nil + field :rms_norm_eps + field :rope_theta + field :rope_scaling, default: nil + field :use_qk_norm + field :tie_word_embeddings + field :attention_bias + field :partial_rotary_factor + field :scoring_func, default: "sigmoid" + field :topk_method, default: "noaux_tc" + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false) + + @use_qk_norm = args.use_qk_norm + if @use_qk_norm + self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + end + + rope_dims = [(@head_dim * args.partial_rotary_factor.to_f).to_i, 1].max + self.rope = MLX::NN::RoPE.new( + rope_dims, + traditional: false, + base: args.rope_theta + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]) + + if @use_qk_norm + queries = q_norm.call(queries) + keys = k_norm.call(keys) + end + + queries = queries.transpose([0, 2, 1, 3]) + keys = keys.transpose([0, 2, 1, 3]) + values = values.transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(config, hidden_size: nil, intermediate_size: nil) + super() + hidden_size ||= config.hidden_size + intermediate_size ||= config.intermediate_size + + self.gate_proj = MLX::NN::Linear.new(hidden_size, intermediate_size, bias: false) + self.up_proj = MLX::NN::Linear.new(hidden_size, intermediate_size, bias: false) + self.down_proj = MLX::NN::Linear.new(intermediate_size, hidden_size, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + module_function + + def group_expert_select( + gates, + e_score_correction_bias, + top_k, + n_group, + topk_group, + routed_scaling_factor, + norm_topk_prob + ) + mx = MLX::Core + + scores = 
mx.sigmoid(gates.astype(mx.float32)) + orig_scores = scores + scores = scores + e_score_correction_bias + + if n_group.to_i > 1 + experts_per_group = scores.shape[-1] / n_group + scores = mx.unflatten(scores, -1, [n_group, experts_per_group]) + group_scores = mx.topk(scores, 2, -1) + group_scores = mx.expand_dims(mx.sum(group_scores, -1), -1) + + drop_count = n_group - topk_group.to_i + if drop_count > 0 + group_idx = mx.argpartition(group_scores, drop_count - 1, -2) + take_ids = mx.array((0...drop_count).to_a, dtype: mx.int32) + group_idx = mx.take(group_idx, take_ids, -2) + scores = mx.put_along_axis( + scores, + mx.stop_gradient(group_idx), + mx.array(0.0), + -2 + ) + end + + scores = mx.flatten(scores, -2, -1) + end + + k = [[top_k.to_i, 1].max, scores.shape[-1]].min + inds = mx.argpartition(scores * -1.0, k - 1, -1) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + + selected_scores = mx.take_along_axis(orig_scores, inds, -1) + if k > 1 && norm_topk_prob + denominator = mx.expand_dims(mx.sum(selected_scores, -1), -1) + selected_scores = selected_scores / (denominator + 1e-20) + end + + selected_scores = selected_scores * routed_scaling_factor.to_f + [inds, selected_scores] + end + + class MoEGate < MLX::NN::Module + def initialize(config) + super() + @top_k = config.num_experts_per_tok + @norm_topk_prob = config.norm_topk_prob + @n_routed_experts = config.n_routed_experts + @routed_scaling_factor = config.routed_scaling_factor + @n_group = config.n_group + @topk_group = config.topk_group + + raise ArgumentError, "Unsupported topk method: #{config.topk_method}" unless config.topk_method == "noaux_tc" + + mx = MLX::Core + self.weight = mx.zeros([@n_routed_experts, config.hidden_size]) + self.e_score_correction_bias = mx.zeros([@n_routed_experts]) + end + + def call(x) + mx = MLX::Core + gates = mx.matmul(x, mx.transpose(weight)) + Glm4Moe.group_expert_select( + gates, + e_score_correction_bias, + @top_k, + @n_group, + @topk_group, + @routed_scaling_factor, + @norm_topk_prob + ) + end + end + + class MoE < MLX::NN::Module + def initialize(config) + super() + @config = config + + self.switch_mlp = SwitchLayers::SwitchGLU.new( + config.hidden_size, + config.moe_intermediate_size, + config.n_routed_experts + ) + + self.gate = MoEGate.new(config) + unless config.n_shared_experts.nil? + shared_intermediate = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = MLP.new(config, intermediate_size: shared_intermediate) + end + end + + def call(x) + mx = MLX::Core + inds, scores = gate.call(x) + y = switch_mlp.call(x, inds) + y = mx.sum(y * mx.expand_dims(scores, -1), -2).astype(y.dtype) + y = y + shared_experts.call(x) unless @config.n_shared_experts.nil? + y + end + end + + class DecoderLayer < MLX::NN::Module + def initialize(config, layer_idx) + super() + self.self_attn = Attention.new(config) + self.mlp = if !config.n_routed_experts.nil? 
&& layer_idx >= config.first_k_dense_replace + MoE.new(config) + else + MLP.new(config) + end + + self.input_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class LanguageModel < MLX::NN::Module + include PipelineMixin + + def initialize(config) + super() + self.embed_tokens = MLX::NN::Embedding.new(config.vocab_size, config.hidden_size) + self.layers = Array.new(config.num_hidden_layers) { |idx| DecoderLayer.new(config, idx) } + self.norm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + end + + def call(x, cache: nil) + h = embed_tokens.call(x) + active_layers = pipeline_layers + layer_cache = cache || [nil] * active_layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + active_layers.each_with_index do |layer, idx| + h = layer.call(h, mask: mask, cache: layer_cache[idx]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache = nil) + n = h.shape[1] + return cache.make_mask(n) if cache && cache.respond_to?(:make_mask) + return nil if n == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(config) + super() + @args = config + self.model_type = config.model_type + self.model = LanguageModel.new(config) + self.lm_head = MLX::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false) + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + lm_head.call(out) + end + + def sanitize(weights) + mx = MLX::Core + result = weights.dup + mpt_layer = @args.num_hidden_layers.to_i + + @args.num_hidden_layers.to_i.times do |layer_idx| + prefix = "model.layers.#{layer_idx}.mlp" + %w[gate_proj down_proj up_proj].each do |proj_name| + %w[weight scales biases].each do |param_name| + first_key = "#{prefix}.experts.0.#{proj_name}.#{param_name}" + next unless result.key?(first_key) + + expert_keys = (0...@args.n_routed_experts.to_i).map do |expert_idx| + "#{prefix}.experts.#{expert_idx}.#{proj_name}.#{param_name}" + end + next unless expert_keys.all? 
{ |key| result.key?(key) } + + stacked = expert_keys.map { |key| result.delete(key) } + result["#{prefix}.switch_mlp.#{proj_name}.#{param_name}"] = mx.stack(stacked) + end + end + end + + result.reject { |key, _| key.start_with?("model.layers.#{mpt_layer}") } + end + + def layers + model.pipeline_layers + end + + def cast_predicate + lambda { |key| !key.include?("e_score_correction_bias") } + end + end + + Models.register("glm4_moe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/glm4_moe_lite.rb b/lib/mlx_lm/models/glm4_moe_lite.rb new file mode 100644 index 0000000..6cba340 --- /dev/null +++ b/lib/mlx_lm/models/glm4_moe_lite.rb @@ -0,0 +1,131 @@ +require_relative "glm4_moe" + +module MlxLm + module Models + module Glm4MoeLite + class ModelArgs < BaseModelArgs + field :model_type, default: "glm4_moe_lite" + field :vocab_size, default: 154_880 + field :hidden_size, default: 2048 + field :intermediate_size, default: 10_240 + field :moe_intermediate_size, default: 1536 + field :num_hidden_layers, default: 47 + field :num_attention_heads, default: 20 + field :num_key_value_heads, default: 20 + field :n_shared_experts, default: 1 + field :n_routed_experts, default: 64 + field :routed_scaling_factor, default: 1.8 + field :kv_lora_rank, default: 512 + field :q_lora_rank, default: 768 + field :qk_rope_head_dim, default: 64 + field :qk_nope_head_dim, default: 192 + field :v_head_dim, default: 256 + field :topk_method, default: "noaux_tc" + field :scoring_func, default: "sigmoid" + field :norm_topk_prob, default: true + field :n_group, default: 1 + field :topk_group, default: 1 + field :num_experts_per_tok, default: 4 + field :moe_layer_freq, default: 1 + field :first_k_dense_replace, default: 1 + field :max_position_embeddings, default: 202_752 + field :rms_norm_eps, default: 1e-5 + field :rope_theta, default: 1_000_000.0 + field :rope_scaling, default: nil + field :attention_bias, default: false + field :attention_dropout, default: 0.0 + field :partial_rotary_factor, default: 1.0 + field :tie_word_embeddings, default: false + field :num_nextn_predict_layers, default: 1 + field :quantization, default: nil + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = Glm4Moe::Model.new( + Glm4Moe::ModelArgs.from_dict(_to_glm4_moe_config(args)) + ) + end + + def call(inputs, cache: nil) + language_model.call(inputs, cache: cache) + end + + def sanitize(weights) + remapped = {} + weights.each do |key, value| + remapped[_remap_weight_key(key)] = value + end + language_model.sanitize(remapped) + end + + def layers + language_model.layers + end + + def make_cache + return language_model.make_cache if language_model.respond_to?(:make_cache) + + nil + end + + private + + def _to_glm4_moe_config(args) + inferred_head_dim = args.qk_nope_head_dim.to_i + args.qk_rope_head_dim.to_i + inferred_head_dim = args.hidden_size / args.num_attention_heads if inferred_head_dim <= 0 + + { + "model_type" => args.model_type, + "vocab_size" => args.vocab_size, + "hidden_size" => args.hidden_size, + "intermediate_size" => args.intermediate_size, + "max_position_embeddings" => args.max_position_embeddings, + "moe_intermediate_size" => args.moe_intermediate_size, + "norm_topk_prob" => args.norm_topk_prob, + "num_attention_heads" => args.num_attention_heads, + "n_group" => args.n_group, + "head_dim" => inferred_head_dim, + "topk_group" 
=> args.topk_group, + "n_shared_experts" => args.n_shared_experts, + "n_routed_experts" => args.n_routed_experts, + "routed_scaling_factor" => args.routed_scaling_factor, + "num_experts_per_tok" => args.num_experts_per_tok, + "first_k_dense_replace" => args.first_k_dense_replace, + "num_hidden_layers" => args.num_hidden_layers, + "num_key_value_heads" => args.num_key_value_heads, + "rms_norm_eps" => args.rms_norm_eps, + "rope_theta" => args.rope_theta, + "rope_scaling" => args.rope_scaling, + "use_qk_norm" => false, + "tie_word_embeddings" => args.tie_word_embeddings, + "attention_bias" => args.attention_bias, + "partial_rotary_factor" => args.partial_rotary_factor, + "scoring_func" => args.scoring_func, + "topk_method" => args.topk_method, + } + end + + def _remap_weight_key(key) + mapped = key.dup + mapped = mapped.gsub(".self_attn.embed_q.", ".self_attn.q_proj.") + mapped = mapped.gsub(".self_attn.unembed_out.", ".self_attn.v_proj.") + mapped = mapped.gsub(".self_attn.kv_a_proj_with_mqa.", ".self_attn.k_proj.") + mapped = mapped.gsub(".self_attn.q_a_proj.", ".self_attn.q_proj.") + mapped = mapped.gsub(".self_attn.q_b_proj.", ".self_attn.q_proj.") + mapped + end + end + + Models.register("glm4_moe_lite", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/glm_moe_dsa.rb b/lib/mlx_lm/models/glm_moe_dsa.rb new file mode 100644 index 0000000..159c79a --- /dev/null +++ b/lib/mlx_lm/models/glm_moe_dsa.rb @@ -0,0 +1,26 @@ +require_relative "deepseek_v32" + +module MlxLm + module Models + module GlmMoeDsa + class ModelArgs < DeepseekV32::ModelArgs + field :model_type, default: "glm_moe_dsa" + field :rope_parameters, default: nil + + def initialize(**kwargs) + super + return unless @rope_parameters.respond_to?(:[]) + + @rope_scaling = @rope_parameters + rope_theta = @rope_parameters["rope_theta"] || @rope_parameters[:rope_theta] + @rope_theta = rope_theta unless rope_theta.nil? + end + end + + class Model < DeepseekV32::Model + end + + Models.register("glm_moe_dsa", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/gpt2.rb b/lib/mlx_lm/models/gpt2.rb new file mode 100644 index 0000000..49f2a03 --- /dev/null +++ b/lib/mlx_lm/models/gpt2.rb @@ -0,0 +1,166 @@ +module MlxLm + module Models + module GPT2 + class ModelArgs < BaseModelArgs + field :model_type, default: "gpt2" + field :n_ctx + field :n_embd + field :n_head + field :n_layer + field :n_positions + field :layer_norm_epsilon + field :vocab_size + field :num_key_value_heads, default: nil + + def initialize(**kwargs) + super + @num_key_value_heads ||= @n_head + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + unless (args.n_embd % args.n_head).zero? 
+ raise ArgumentError, "n_embd must be divisible by n_head" + end + + @n_embd = args.n_embd + @n_head = args.n_head + @head_dim = @n_embd / @n_head + @scale = @head_dim**(-0.5) + + self.c_attn = MLX::NN::Linear.new(@n_embd, 3 * @n_embd, bias: true) + self.c_proj = MLX::NN::Linear.new(@n_embd, @n_embd, bias: true) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + qkv = c_attn.call(x) + queries, keys, values = mx.split(qkv, 3, 2) + + queries = queries.reshape([b, l, @n_head, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @n_head, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_head, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + keys, values = cache.update_and_fetch(keys, values) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_embd]) + c_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + self.c_fc = MLX::NN::Linear.new(args.n_embd, 4 * args.n_embd, bias: true) + self.c_proj = MLX::NN::Linear.new(4 * args.n_embd, args.n_embd, bias: true) + end + + def call(x) + c_proj.call(MLX::NN.gelu_approx(c_fc.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.attn = Attention.new(args) + self.mlp = MLP.new(args) + self.ln_1 = MLX::NN::LayerNorm.new(args.n_embd, eps: args.layer_norm_epsilon) + self.ln_2 = MLX::NN::LayerNorm.new(args.n_embd, eps: args.layer_norm_epsilon) + end + + def call(x, mask: nil, cache: nil) + r = attn.call(ln_1.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(ln_2.call(h)) + h + r + end + end + + class GPT2Model < MLX::NN::Module + def initialize(args) + super() + self.wte = MLX::NN::Embedding.new(args.vocab_size, args.n_embd) + self.wpe = MLX::NN::Embedding.new(args.n_positions, args.n_embd) + self.h = Array.new(args.n_layer) { TransformerBlock.new(args) } + self.ln_f = MLX::NN::LayerNorm.new(args.n_embd, eps: args.layer_norm_epsilon) + end + + def call(inputs, cache: nil) + mx = MLX::Core + _b, l = inputs.shape + + hidden_states = wte.call(inputs) + layer_cache = cache || [nil] * h.length + offset = layer_cache[0] ? 
layer_cache[0].offset : 0 + position_ids = mx.add(mx.arange(0, l, 1, mx.int32), offset) + hidden_states = hidden_states + wpe.call(position_ids) + + mask = _create_attention_mask(hidden_states, layer_cache[0]) + h.each_with_index do |layer, i| + hidden_states = layer.call(hidden_states, mask: mask, cache: layer_cache[i]) + end + ln_f.call(hidden_states) + end + + private + + def _create_attention_mask(h, cache) + return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if h.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + attr_reader :args + + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = GPT2Model.new(args) + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + model.wte.as_linear(out) + end + + def sanitize(weights) + result = {} + weights.each do |k, v| + next if k.match?(/\Ah\.\d+\.attn\.bias\z/) + + value = if k.match?(/\Ah\.\d+\.(attn\.c_attn|attn\.c_proj|mlp\.c_fc|mlp\.c_proj)\.weight\z/) + v.transpose([1, 0]) + else + v + end + + if k.start_with?("model.") + result[k] = value + else + result["model.#{k}"] = value + end + end + result + end + + def layers + model.h + end + end + + Models.register("gpt2", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/gpt_bigcode.rb b/lib/mlx_lm/models/gpt_bigcode.rb new file mode 100644 index 0000000..49d3714 --- /dev/null +++ b/lib/mlx_lm/models/gpt_bigcode.rb @@ -0,0 +1,154 @@ +module MlxLm + module Models + module GPTBigCode + class ModelArgs < BaseModelArgs + field :model_type, default: "gpt_bigcode" + field :n_embd + field :n_layer + field :n_inner + field :n_head + field :n_positions + field :layer_norm_epsilon + field :vocab_size + field :num_key_value_heads, default: nil + field :multi_query, default: true + field :attention_bias, default: true + field :mlp_bias, default: true + field :tie_word_embeddings, default: true + + def initialize(**kwargs) + super + @num_key_value_heads ||= @multi_query ? 1 : @n_head + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + @dim = args.n_embd + @n_heads = args.n_head + @n_kv_heads = args.multi_query ? 
1 : args.n_head + @head_dim = @dim / @n_heads + @kv_dim = @n_kv_heads * @head_dim + @scale = @head_dim**(-0.5) + + bias = args.attention_bias + self.c_attn = MLX::NN::Linear.new(@dim, @dim + 2 * @kv_dim, bias: bias) + self.c_proj = MLX::NN::Linear.new(@dim, @dim, bias: bias) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + qkv = c_attn.call(x) + queries, keys, values = mx.split(qkv, [@dim, @dim + @kv_dim], -1) + + queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + keys, values = cache.update_and_fetch(keys, values) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @dim]) + c_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + + dim = args.n_embd + hidden_dim = args.n_inner + bias = args.mlp_bias + self.c_fc = MLX::NN::Linear.new(dim, hidden_dim, bias: bias) + self.c_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias) + end + + def call(x) + c_proj.call(MLX::NN.gelu(c_fc.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.attn = Attention.new(args) + self.mlp = MLP.new(args) + self.ln_1 = MLX::NN::LayerNorm.new(args.n_embd, eps: args.layer_norm_epsilon) + self.ln_2 = MLX::NN::LayerNorm.new(args.n_embd, eps: args.layer_norm_epsilon) + end + + def call(x, mask: nil, cache: nil) + r = attn.call(ln_1.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(ln_2.call(h)) + h + r + end + end + + class GPTBigCodeModel < MLX::NN::Module + def initialize(args) + super() + self.wte = MLX::NN::Embedding.new(args.vocab_size, args.n_embd) + self.wpe = MLX::NN::Embedding.new(args.n_positions, args.n_embd) + self.h = Array.new(args.n_layer) { TransformerBlock.new(args) } + self.ln_f = MLX::NN::LayerNorm.new(args.n_embd, eps: args.layer_norm_epsilon) + end + + def call(inputs, cache: nil) + mx = MLX::Core + _b, l = inputs.shape + + hidden_states = wte.call(inputs) + layer_cache = cache || [nil] * h.length + offset = layer_cache[0] ? 
layer_cache[0].offset : 0 + position_ids = mx.arange(offset, offset + l, 1, mx.int32) + + mask = nil + mask = "causal" if hidden_states.shape[1] > 1 + + hidden_states = hidden_states + wpe.call(position_ids) + + h.each_with_index do |layer, i| + hidden_states = layer.call(hidden_states, mask: mask, cache: layer_cache[i]) + end + + ln_f.call(hidden_states) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.transformer = GPTBigCodeModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.n_embd, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = transformer.call(inputs, cache: cache) + if @args.tie_word_embeddings + transformer.wte.as_linear(out) + else + lm_head.call(out) + end + end + + def layers + transformer.h + end + end + + Models.register("gpt_bigcode", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/gpt_oss.rb b/lib/mlx_lm/models/gpt_oss.rb new file mode 100644 index 0000000..5c99e73 --- /dev/null +++ b/lib/mlx_lm/models/gpt_oss.rb @@ -0,0 +1,319 @@ +require_relative "cache" +require_relative "rope_utils" +require_relative "switch_layers" + +module MlxLm + module Models + module GptOss + class ModelArgs < BaseModelArgs + field :model_type, default: "gpt_oss" + field :num_hidden_layers, default: 36 + field :num_local_experts, default: 128 + field :num_experts_per_tok, default: 4 + field :vocab_size, default: 201_088 + field :rms_norm_eps, default: 1e-5 + field :hidden_size, default: 2880 + field :intermediate_size, default: 2880 + field :head_dim, default: 64 + field :num_attention_heads, default: 64 + field :num_key_value_heads, default: 8 + field :sliding_window, default: 128 + field :rope_theta, default: 150_000 + field :rope_scaling, default: nil + field :layer_types, default: nil + + def initialize(**kwargs) + super + @layer_types ||= Array.new(@num_hidden_layers) do |i| + i.even? ? 
"sliding_attention" : "full_attention" + end + end + end + + class AttentionBlock < MLX::NN::Module + def initialize(config) + super() + @head_dim = config.head_dim + @num_attention_heads = config.num_attention_heads + @num_key_value_heads = config.num_key_value_heads + @sm_scale = 1.0 / Math.sqrt(@head_dim) + + self.q_proj = MLX::NN::Linear.new( + config.hidden_size, + @num_attention_heads * @head_dim, + bias: true + ) + self.k_proj = MLX::NN::Linear.new( + config.hidden_size, + @num_key_value_heads * @head_dim, + bias: true + ) + self.v_proj = MLX::NN::Linear.new( + config.hidden_size, + @num_key_value_heads * @head_dim, + bias: true + ) + self.o_proj = MLX::NN::Linear.new( + @num_attention_heads * @head_dim, + config.hidden_size, + bias: true + ) + + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + config.rope_theta, + false, + config.rope_scaling + ) + end + + def call(x, mask:, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + q = q_proj.call(x).reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3]) + k = k_proj.call(x).reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + v = v_proj.call(x).reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + q = rope.call(q, offset: cache.offset) + k = rope.call(k, offset: cache.offset) + k, v = cache.update_and_fetch(k, v) + else + q = rope.call(q) + k = rope.call(k) + end + + out = mx.scaled_dot_product_attention(q, k, v, @sm_scale, mask) + out = out.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim]) + o_proj.call(out) + end + end + + class MLPBlock < MLX::NN::Module + def initialize(config) + super() + @num_local_experts = config.num_local_experts + @num_experts_per_tok = config.num_experts_per_tok + + self.experts = SwitchLayers::SwitchGLU.new( + config.hidden_size, + config.intermediate_size, + @num_local_experts, + bias: true + ) + self.router = MLX::NN::Linear.new( + config.hidden_size, + @num_local_experts, + bias: true + ) + end + + def call(x) + mx = MLX::Core + + gates = router.call(x) + k = [@num_experts_per_tok, @num_local_experts].min + inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1)) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + expert_weights = mx.take_along_axis(gates, inds, -1) + expert_weights = mx.softmax(expert_weights.astype(mx.float32), -1).astype(expert_weights.dtype) + + x = experts.call(x, inds) + x = x * mx.expand_dims(expert_weights, -1) + mx.sum(x, -2) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(config) + super() + self.self_attn = AttentionBlock.new(config) + self.mlp = MLPBlock.new(config) + self.input_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + end + + def call(x, mask:, cache: nil) + h = x + self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h + mlp.call(post_attention_layernorm.call(h)) + end + end + + class GptOssMoeModel < MLX::NN::Module + attr_reader :layer_types + + def initialize(args) + super() + @window_size = args.sliding_window + @layer_types = args.layer_types + + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) } + + @swa_idx = @layer_types.index("sliding_attention") || 0 
+ @ga_idx = @layer_types.index("full_attention") || 0 + end + + def call(inputs, cache: nil, input_embeddings: nil) + x = input_embeddings || embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + full_mask = _create_attention_mask(x, layer_cache[@ga_idx]) + swa_mask = _create_attention_mask( + x, + layer_cache[@swa_idx], + window_size: @window_size + ) + + layers.each_with_index do |layer, i| + layer_type = @layer_types[i] + mask = layer_type == "full_attention" ? full_mask : swa_mask + x = layer.call(x, mask: mask, cache: layer_cache[i]) + end + + norm.call(x) + end + + private + + def _create_attention_mask(h, cache = nil, window_size: nil) + n = h.shape[1] + if cache && cache.respond_to?(:make_mask) + return cache.make_mask(n, window_size: window_size) + end + + if window_size + offset = 0 + if cache + offset = cache.offset + if cache.instance_variable_defined?(:@max_size) + max_size = cache.instance_variable_get(:@max_size) + offset = [max_size - 1, offset].min if max_size && max_size > 0 + end + end + return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size + end + + return nil if n == 1 + + "causal" + end + + def _create_causal_mask(n, offset: 0, window_size: nil) + mx = MLX::Core + rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n]) + linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1]) + + mask = mx.greater_equal(linds, rinds) + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + mask + end + end + + class Model < MLX::NN::Module + attr_reader :args + + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = GptOssMoeModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + + def call(inputs, cache: nil, input_embeddings: nil) + lm_head.call(model.call(inputs, cache: cache, input_embeddings: input_embeddings)) + end + + def sanitize(weights) + return weights if weights.keys.any? 
{ |key| key.include?("gate_proj.weight") } + + result = {} + weights.each do |key, value| + if key.include?("gate_up_proj") && !key.include?("bias") + normalized_key, normalized_value = _normalize_moe_weight_param(key, value) + split_axis = normalized_value.shape.length - 2 + result[normalized_key.sub("gate_up_proj", "gate_proj")] = _take_every_other( + normalized_value, + start: 0, + axis: split_axis + ) + result[normalized_key.sub("gate_up_proj", "up_proj")] = _take_every_other( + normalized_value, + start: 1, + axis: split_axis + ) + elsif key.include?("down_proj") && !key.include?("bias") + normalized_key, normalized_value = _normalize_moe_weight_param(key, value) + result[normalized_key] = normalized_value + elsif key.include?("gate_up_proj_bias") + split_axis = value.shape.length - 1 + result[key.sub("gate_up_proj_bias", "gate_proj.bias")] = _take_every_other( + value, + start: 0, + axis: split_axis + ) + result[key.sub("gate_up_proj_bias", "up_proj.bias")] = _take_every_other( + value, + start: 1, + axis: split_axis + ) + elsif key.include?("down_proj_bias") + result[key.sub("down_proj_bias", "down_proj.bias")] = value + else + result[key] = value + end + end + + result + end + + def layers + model.layers + end + + def make_cache + model.layer_types.map do |layer_type| + if layer_type == "full_attention" + MlxLm::KVCache.new + else + MlxLm::RotatingKVCache.new(max_size: @args.sliding_window) + end + end + end + + private + + def _normalize_moe_weight_param(key, value) + mx = MLX::Core + normalized_key = key + normalized_value = value + + if key.include?("_blocks") + normalized_value = mx.flatten(value.view(mx.uint32), -2, -1) + normalized_key = normalized_key.sub("_blocks", ".weight") + end + if key.include?("_scales") + normalized_key = normalized_key.sub("_scales", ".scales") + end + + [normalized_key, normalized_value] + end + + def _take_every_other(value, start:, axis:) + mx = MLX::Core + indices = (start...value.shape[axis]).step(2).to_a + take_ids = mx.array(indices, dtype: mx.int32) + mx.take(value, take_ids, axis) + end + end + + Models.register("gpt_oss", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/granite.rb b/lib/mlx_lm/models/granite.rb new file mode 100644 index 0000000..b6aa895 --- /dev/null +++ b/lib/mlx_lm/models/granite.rb @@ -0,0 +1,170 @@ +module MlxLm + module Models + module Granite + class ModelArgs < BaseModelArgs + field :model_type, default: "granite" + field :hidden_size + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :rms_norm_eps + field :vocab_size + field :logits_scaling + field :attention_multiplier + field :embedding_multiplier + field :residual_multiplier + field :max_position_embeddings + field :num_key_value_heads + field :attention_bias + field :mlp_bias + field :rope_theta + field :rope_scaling, default: nil + field :tie_word_embeddings, default: true + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = dim / @n_heads + @scale = args.attention_multiplier + + bias = args.attention_bias + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias) + self.o_proj = 
MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: bias) + + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + false, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + hidden_dim = args.intermediate_size + bias = args.mlp_bias + + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + @residual_multiplier = args.residual_multiplier + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r * @residual_multiplier + r = mlp.call(post_attention_layernorm.call(h)) + h + r * @residual_multiplier + end + end + + class GraniteModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + @embedding_multiplier = args.embedding_multiplier + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) * @embedding_multiplier + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = GraniteModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings + @logits_scaling = args.logits_scaling + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + out = if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + out / @logits_scaling + end + + def layers + model.layers + end + end + + Models.register("granite", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/granitemoe.rb 
b/lib/mlx_lm/models/granitemoe.rb new file mode 100644 index 0000000..0c1f9ca --- /dev/null +++ b/lib/mlx_lm/models/granitemoe.rb @@ -0,0 +1,58 @@ +require_relative "granite" + +module MlxLm + module Models + module GraniteMoe + class ModelArgs < Granite::ModelArgs + field :model_type, default: "granitemoe" + field :num_local_experts + field :num_experts_per_tok + end + + class Model < Granite::Model + def sanitize(weights) + result = weights.dup + rewrite_legacy_moe_weights(result) + result.delete("lm_head.weight") if @args.tie_word_embeddings + result + end + + private + + def rewrite_legacy_moe_weights(weights) + mx = MLX::Core + + layers.length.times do |layer_idx| + prefix = "model.layers.#{layer_idx}.block_sparse_moe" + input_key = _first_existing_key( + weights, + ["#{prefix}.input_linear.weight", "#{prefix}.input_linear"] + ) + output_key = _first_existing_key( + weights, + ["#{prefix}.output_linear.weight", "#{prefix}.output_linear"] + ) + next unless input_key && output_key + + input_linear = weights.delete(input_key) + output_linear = weights.delete(output_key) + mid = input_linear.shape[1] / 2 + gate_proj, up_proj = mx.split(input_linear, [mid], 1) + + weights["#{prefix}.switch_mlp.gate_proj.weight"] = gate_proj + weights["#{prefix}.switch_mlp.up_proj.weight"] = up_proj + weights["#{prefix}.switch_mlp.down_proj.weight"] = output_linear + end + + weights + end + + def _first_existing_key(weights, candidates) + candidates.find { |key| weights.key?(key) } + end + end + + Models.register("granitemoe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/granitemoehybrid.rb b/lib/mlx_lm/models/granitemoehybrid.rb new file mode 100644 index 0000000..bc0e287 --- /dev/null +++ b/lib/mlx_lm/models/granitemoehybrid.rb @@ -0,0 +1,178 @@ +require_relative "falcon_h1" + +module MlxLm + module Models + module GraniteMoeHybrid + class ModelArgs < FalconH1::ModelArgs + field :model_type, default: "granitemoehybrid" + field :embedding_multiplier, default: 1.0 + field :attention_multiplier, default: 1.0 + field :logits_scaling, default: 1.0 + field :residual_multiplier, default: 1.0 + field :num_local_experts, default: nil + field :num_experts_per_tok, default: nil + field :shared_intermediate_size, default: nil + field :mamba_n_heads, default: nil + field :mamba_d_head, default: nil + field :mamba_proj_bias, default: false + field :mamba_d_state, default: nil + field :mamba_n_groups, default: nil + field :mamba_conv_bias, default: false + field :layer_types, default: nil + field :position_embedding_type, default: "rope" + field :time_step_limit, default: [0.001, 100.0] + field :mlp_bias, default: false + + def initialize(**kwargs) + super + @num_hidden_layers ||= Array(@layer_types).length + @num_attention_heads ||= @mamba_n_heads + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @mamba_d_head + @mamba_d_conv ||= 4 + @layer_types ||= _default_layer_types + @block_types ||= _to_block_types + end + + def to_falcon_h1_dict + hidden_size = @hidden_size + attention_heads = @num_attention_heads + inferred_head_dim = if !@head_dim.nil? + @head_dim + elsif !@mamba_d_head.nil? + @mamba_d_head + elsif !hidden_size.nil? 
&& attention_heads.to_i > 0 + hidden_size / attention_heads + else + 64 + end + + { + "model_type" => @model_type, + "attention_bias" => @attention_bias, + "head_dim" => inferred_head_dim, + "hidden_size" => hidden_size, + "intermediate_size" => @intermediate_size || @shared_intermediate_size || hidden_size.to_i * 2, + "max_position_embeddings" => @max_position_embeddings, + "mamba_d_conv" => @mamba_d_conv, + "num_attention_heads" => attention_heads, + "num_hidden_layers" => @num_hidden_layers, + "num_key_value_heads" => @num_key_value_heads, + "rms_norm_eps" => @rms_norm_eps, + "rope_theta" => @rope_theta, + "vocab_size" => @vocab_size, + "tie_word_embeddings" => @tie_word_embeddings, + "attention_window_size" => @attention_window_size, + "block_types" => @block_types, + } + end + + private + + def _default_layer_types + count = @num_hidden_layers.to_i + return nil if count <= 0 + + Array.new(count) { |idx| idx.even? ? "mamba" : "attention" } + end + + def _to_block_types + return @block_types if @block_types.is_a?(Array) && !@block_types.empty? + return nil unless @layer_types.is_a?(Array) && !@layer_types.empty? + + @layer_types.map { |layer_type| layer_type.to_s == "mamba" ? "recurrent" : "attention" } + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.wrapped_model = FalconH1::Model.new( + FalconH1::ModelArgs.from_dict(args.to_falcon_h1_dict) + ) + end + + def call(inputs, cache: nil) + wrapped_model.call(inputs, cache: cache) + end + + def sanitize(weights) + normalized = weights.dup + _rewrite_block_sparse_moe!(normalized) + _rewrite_shared_mlp!(normalized) + normalized.delete("lm_head.weight") if @args.tie_word_embeddings + + remapped = {} + normalized.each do |key, value| + remapped[_remap_weight_key(key)] = value + end + wrapped_model.sanitize(remapped) + end + + def layers + wrapped_model.layers + end + + def make_cache + return nil unless wrapped_model.respond_to?(:make_cache) + + wrapped_model.make_cache + end + + private + + def _rewrite_block_sparse_moe!(weights) + mx = MLX::Core + + @args.num_hidden_layers.to_i.times do |layer_idx| + prefix = "model.layers.#{layer_idx}.block_sparse_moe" + input_key = "#{prefix}.input_linear.weight" + output_key = "#{prefix}.output_linear.weight" + next unless weights.key?(input_key) && weights.key?(output_key) + + input_linear = weights.delete(input_key) + output_linear = weights.delete(output_key) + mid = input_linear.shape[1] / 2 + gate_proj, up_proj = mx.split(input_linear, [mid], 1) + + weights["#{prefix}.switch_mlp.gate_proj.weight"] = gate_proj + weights["#{prefix}.switch_mlp.up_proj.weight"] = up_proj + weights["#{prefix}.switch_mlp.down_proj.weight"] = output_linear + end + end + + def _rewrite_shared_mlp!(weights) + mx = MLX::Core + + @args.num_hidden_layers.to_i.times do |layer_idx| + prefix = "model.layers.#{layer_idx}.shared_mlp" + input_key = "#{prefix}.input_linear.weight" + output_key = "#{prefix}.output_linear.weight" + next unless weights.key?(input_key) && weights.key?(output_key) + + input_linear = weights.delete(input_key) + mid = input_linear.shape[0] / 2 + gate_proj, up_proj = mx.split(input_linear, [mid], 0) + + weights["model.layers.#{layer_idx}.mlp.gate_proj.weight"] = gate_proj + weights["model.layers.#{layer_idx}.mlp.up_proj.weight"] = up_proj + weights["model.layers.#{layer_idx}.mlp.down_proj.weight"] = weights.delete(output_key) + end + end + + def _remap_weight_key(key) + mapped = key.dup + mapped = 
mapped.gsub(".block_sparse_moe.", ".feed_forward.") + mapped = mapped.gsub(".shared_mlp.", ".feed_forward.") + mapped = mapped.gsub(".post_attention_layernorm.", ".pre_ff_layernorm.") + mapped = mapped.gsub("model.norm.", "model.final_layernorm.") + mapped + end + end + + Models.register("granitemoehybrid", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/helium.rb b/lib/mlx_lm/models/helium.rb new file mode 100644 index 0000000..f5383db --- /dev/null +++ b/lib/mlx_lm/models/helium.rb @@ -0,0 +1,158 @@ +module MlxLm + module Models + module Helium + class ModelArgs < BaseModelArgs + field :hidden_size, default: 256 + field :num_hidden_layers, default: 24 + field :intermediate_size, default: 1024 + field :num_attention_heads, default: 4 + field :num_key_value_heads, default: nil + field :rms_norm_eps, default: 1e-5 + field :vocab_size, default: 32_000 + field :attention_bias, default: false + field :head_dim, default: nil + field :max_position_embeddings, default: 2048 + field :mlp_bias, default: false + field :model_type, default: "helium" + field :rope_theta, default: 10_000.0 + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.hidden_size / @n_heads + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false) + self.rope = MLX::NN::RoPE.new(@head_dim, traditional: true, base: args.rope_theta) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + self.gate_proj = MLX::NN::Linear.new( + args.hidden_size, + args.intermediate_size, + bias: args.mlp_bias + ) + self.up_proj = MLX::NN::Linear.new( + args.hidden_size, + args.intermediate_size, + bias: args.mlp_bias + ) + self.down_proj = MLX::NN::Linear.new( + args.intermediate_size, + args.hidden_size, + bias: args.mlp_bias + ) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class DecoderLayer < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = 
MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class HeliumModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + @model_type = args.model_type + self.model = HeliumModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def layers + model.layers + end + end + + Models.register("helium", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/hunyuan.rb b/lib/mlx_lm/models/hunyuan.rb new file mode 100644 index 0000000..50ee4ca --- /dev/null +++ b/lib/mlx_lm/models/hunyuan.rb @@ -0,0 +1,378 @@ +require_relative "activations" +require_relative "switch_layers" + +module MlxLm + module Models + module Hunyuan + module_function + + def int_or_list(value, idx) + return value[idx] if value.is_a?(Array) + + value + end + + class ModelArgs < BaseModelArgs + field :model_type, default: "hunyuan" + field :vocab_size + field :hidden_size + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :num_key_value_heads, default: nil + field :attention_bias + field :moe_topk + field :num_experts + field :num_shared_expert + field :use_mixed_mlp_moe + field :use_qk_norm + field :rms_norm_eps + field :rope_theta + field :use_cla + field :cla_share_factor, default: 2 + field :moe_intermediate_size, default: nil + field :rope_scaling, default: nil + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + _validate_rope_scaling! + end + + private + + def _validate_rope_scaling! + return if @rope_scaling.nil? + + required_keys = %w[factor type] + return if required_keys.all? 
{ |key| _rope_scaling_has_key?(key) } + + raise ArgumentError, "rope_scaling must contain keys #{required_keys}" + end + + def _rope_scaling_has_key?(key) + @rope_scaling.key?(key) || @rope_scaling.key?(key.to_sym) + end + end + + class DynamicNTKAlphaRoPE < MLX::NN::Module + def initialize(dims, base: 10_000.0, scaling_alpha: 1.0) + super() + mx = MLX::Core + + @dims = dims + adjusted_base = base * (scaling_alpha**(dims.to_f / (dims - 2))) + self._freqs = mx.power( + adjusted_base, + mx.divide(mx.arange(0, dims, 2, mx.float32), dims.to_f) + ) + end + + def call(x, offset: 0) + MLX::Core.rope(x, @dims, false, nil, 1.0, offset, _freqs) + end + end + + class Attention < MLX::NN::Module + def initialize(kv_proj, args) + super() + dim = args.hidden_size + + @kv_proj = kv_proj + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = dim / @n_heads + @scale = @head_dim**(-0.5) + @use_qk_norm = args.use_qk_norm + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias) + if kv_proj + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + end + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.attention_bias) + + if @use_qk_norm + self.query_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.key_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + end + + scaling_alpha = _config_value(args.rope_scaling, "alpha", 1.0) + self.rope = DynamicNTKAlphaRoPE.new( + @head_dim, + base: args.rope_theta, + scaling_alpha: scaling_alpha + ) + end + + def call(x, mask: nil, cache: nil, kv_states: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + if kv_states + keys, values = kv_states + else + raise ArgumentError, "kv_states required when kv_proj is disabled" unless @kv_proj + + keys = k_proj.call(x) + values = v_proj.call(x) + kv_states = [keys, values] + end + + queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + offset = cache ? cache.offset : 0 + queries = rope.call(queries, offset: offset) + keys = rope.call(keys, offset: offset) + + if @use_qk_norm + queries = query_layernorm.call(queries) + keys = key_layernorm.call(keys) + end + + keys, values = cache.update_and_fetch(keys, values) if cache + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + [o_proj.call(output), kv_states] + end + + private + + def _config_value(config, key, default = nil) + return default if config.nil? 
+ return config[key] if config.key?(key) + + config.fetch(key.to_sym, default) + end + end + + class MLP < MLX::NN::Module + def initialize(dim, hidden_dim) + super() + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class Gate < MLX::NN::Module + def initialize(dim, num_experts) + super() + self.wg = MLX::NN::Linear.new(dim, num_experts, bias: false) + end + + def call(x) + wg.call(x) + end + end + + class MoeBlock < MLX::NN::Module + def initialize(args, layer_idx: 0) + super() + dim = args.hidden_size + intermediate_size = args.intermediate_size + + @use_shared_mlp = args.use_mixed_mlp_moe + if @use_shared_mlp + num_shared = Hunyuan.int_or_list(args.num_shared_expert, layer_idx).to_i + self.shared_mlp = MLP.new(dim, (intermediate_size * num_shared).to_i) + end + + @num_experts = args.num_experts + @top_k = Hunyuan.int_or_list(args.moe_topk, layer_idx).to_i + self.gate = Gate.new(dim, @num_experts) + + expert_intermediate_size = args.moe_intermediate_size.nil? ? + intermediate_size : + Hunyuan.int_or_list(args.moe_intermediate_size, layer_idx) + + self.switch_mlp = SwitchLayers::SwitchGLU.new( + dim, + expert_intermediate_size, + @num_experts + ) + end + + def call(x) + mx = MLX::Core + + gates = gate.call(x) + gates = mx.softmax(gates.astype(mx.float32), -1).astype(gates.dtype) + + k = [[@top_k, 1].max, @num_experts].min + inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1)) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + scores = mx.take_along_axis(gates, inds, -1) + + y = switch_mlp.call(x, inds) + y = mx.sum(y * mx.expand_dims(scores.astype(mx.float32), -1), -2).astype(y.dtype) + + y = y + shared_mlp.call(x) if @use_shared_mlp + y + end + end + + class DecoderLayer < MLX::NN::Module + def initialize(args, kv_proj:, layer_idx:) + super() + self.self_attn = Attention.new(kv_proj, args) + if args.num_experts.to_i == 1 + self.mlp = MLP.new(args.hidden_size, args.intermediate_size) + else + self.mlp = MoeBlock.new(args, layer_idx: layer_idx) + end + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil, shared_kv_states: nil) + r, shared_kv_states = self_attn.call( + input_layernorm.call(x), + mask: mask, + cache: cache, + kv_states: shared_kv_states + ) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + [h + r, shared_kv_states] + end + end + + class HunYuanModel < MLX::NN::Module + def initialize(args) + super() + @args = args + + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) do |i| + kv_proj = (!args.use_cla) || (i % args.cla_share_factor).zero? + DecoderLayer.new(args, kv_proj: kv_proj, layer_idx: i) + end + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + shared_kv_states = nil + layers.each_with_index do |layer, i| + if (!@args.use_cla) || (i % @args.cla_share_factor).zero? 
+ shared_kv_states = nil + end + h, shared_kv_states = layer.call( + h, + mask: mask, + cache: layer_cache[i], + shared_kv_states: shared_kv_states + ) + end + + norm.call(h) + end + + private + + def _create_attention_mask(hidden, cache) + return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if hidden.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = HunYuanModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + mx = MLX::Core + result = weights.dup + + if result.key?("model.layers.0.mlp.gate_and_up_proj.weight") + new_weights = {} + d = @args.hidden_size + n_kv_heads = @args.num_key_value_heads + n_kv_groups = @args.num_attention_heads / n_kv_heads + head_dim = d / @args.num_attention_heads + + result.each do |key, value| + if key.include?("qkv_proj") + reshaped = value.reshape([n_kv_heads, n_kv_groups + 2, head_dim, -1]) + qkv_splits = mx.split(reshaped, [n_kv_groups, n_kv_groups + 1], 1) + %w[q_proj k_proj v_proj].each_with_index do |proj, idx| + new_weights[key.sub("qkv_proj", proj)] = mx.flatten(qkv_splits[idx], 0, 2) + end + elsif key.include?("gate_and_up_proj") + split_idx = value.shape[0] / 2 + up_proj, gate_proj = mx.split(value, [split_idx], 0) + new_weights[key.sub("gate_and_up_proj", "up_proj")] = up_proj + new_weights[key.sub("gate_and_up_proj", "gate_proj")] = gate_proj + else + new_weights[key] = value + end + end + + result = new_weights + end + + if result.key?("model.layers.0.mlp.experts.0.up_proj.weight") + @args.num_hidden_layers.times do |layer_idx| + prefix = "model.layers.#{layer_idx}" + %w[up_proj down_proj gate_proj].each do |projection| + %w[weight scales biases].each do |param| + first_key = "#{prefix}.mlp.experts.0.#{projection}.#{param}" + next unless result.key?(first_key) + + expert_keys = (0...@args.num_experts).map do |expert_idx| + "#{prefix}.mlp.experts.#{expert_idx}.#{projection}.#{param}" + end + next unless expert_keys.all? 
{ |k| result.key?(k) } + + stacked = expert_keys.map { |k| result.delete(k) } + result["#{prefix}.mlp.switch_mlp.#{projection}.#{param}"] = mx.stack(stacked) + end + end + end + end + + result.delete("lm_head.weight") if @args.tie_word_embeddings + result + end + + def layers + model.layers + end + end + + Models.register("hunyuan", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/hunyuan_v1_dense.rb b/lib/mlx_lm/models/hunyuan_v1_dense.rb new file mode 100644 index 0000000..ae575b2 --- /dev/null +++ b/lib/mlx_lm/models/hunyuan_v1_dense.rb @@ -0,0 +1,235 @@ +require_relative "activations" + +module MlxLm + module Models + module HunyuanV1Dense + class ModelArgs < BaseModelArgs + field :model_type, default: "hunyuan_v1_dense" + field :vocab_size, default: 151_936 + field :hidden_size, default: 4096 + field :num_hidden_layers, default: 40 + field :intermediate_size, default: 12_288 + field :num_attention_heads, default: 32 + field :num_key_value_heads, default: 8 + field :rms_norm_eps, default: 1e-6 + field :rope_theta, default: 10_000.0 + field :max_position_embeddings, default: 32_768 + field :attention_bias, default: false + field :use_qk_norm, default: true + field :rope_scaling, default: nil + field :tie_word_embeddings, default: false + field :head_dim, default: nil + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + _validate_rope_scaling! + end + + private + + def _validate_rope_scaling! + return if @rope_scaling.nil? + + required_keys = %w[alpha factor type] + missing = required_keys.reject { |key| _config_has_key?(key) } + return if missing.empty? + + raise ArgumentError, "rope_scaling must contain keys #{required_keys}" + end + + def _config_has_key?(key) + @rope_scaling.key?(key) || @rope_scaling.key?(key.to_sym) + end + end + + class DynamicNTKAlphaRoPE < MLX::NN::Module + def initialize(dims, base: 10_000.0, scaling_alpha: 1.0) + super() + mx = MLX::Core + + @dims = dims + adjusted_base = base * (scaling_alpha**(dims.to_f / (dims - 2))) + self._freqs = mx.power( + adjusted_base, + mx.divide(mx.arange(0, dims, 2, mx.float32), dims.to_f) + ) + end + + def call(x, offset: 0) + MLX::Core.rope(x, @dims, false, nil, 1.0, offset, _freqs) + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + @use_qk_norm = args.use_qk_norm + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.attention_bias) + + if @use_qk_norm + self.query_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.key_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + end + + scaling_alpha = _config_value(args.rope_scaling, "alpha", 1.0) + self.rope = DynamicNTKAlphaRoPE.new( + @head_dim, + base: args.rope_theta, + scaling_alpha: scaling_alpha + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = 
keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + if @use_qk_norm + queries = query_layernorm.call(queries) + keys = key_layernorm.call(keys) + end + + if cache + keys, values = cache.update_and_fetch(keys, values) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + + private + + def _config_value(config, key, default = nil) + return default if config.nil? + return config[key] if config.key?(key) + + config.fetch(key.to_sym, default) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + hidden_dim = args.intermediate_size + + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class HunyuanV1DenseModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, layer_idx| + h = layer.call(h, mask: mask, cache: layer_cache[layer_idx]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(hidden, cache) + return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if hidden.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = HunyuanV1DenseModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + result = weights.dup + result.delete("lm_head.weight") if @args.tie_word_embeddings + result + end + + def layers + model.layers + end + end + + Models.register("hunyuan_v1_dense", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/internlm3.rb b/lib/mlx_lm/models/internlm3.rb new file mode 100644 index 0000000..399d10d --- /dev/null +++ 
b/lib/mlx_lm/models/internlm3.rb @@ -0,0 +1,237 @@ +module MlxLm + module Models + module InternLM3 + class ModelArgs < BaseModelArgs + field :model_type, default: "internlm3" + field :hidden_size, default: 4096 + field :num_hidden_layers, default: 32 + field :intermediate_size, default: 11008 + field :num_attention_heads, default: 32 + field :rms_norm_eps, default: 1e-6 + field :vocab_size, default: 103168 + field :bias, default: false + field :qkv_bias, default: false + field :max_position_embeddings, default: 32768 + field :num_key_value_heads, default: nil + field :rope_theta, default: 10_000.0 + field :rope_traditional, default: false + field :rope_scaling, default: nil + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + + return if @rope_scaling.nil? + + required_keys = %w[factor rope_type] + missing = required_keys.reject { |k| _config_has_key?(k) } + unless missing.empty? + raise ArgumentError, "rope_scaling must contain keys #{required_keys}" + end + + rope_type = _config_value("rope_type") + unless %w[linear dynamic].include?(rope_type) + raise ArgumentError, "rope_scaling 'rope_type' only supports 'linear' or 'dynamic'" + end + end + + private + + def _config_has_key?(key) + return false unless @rope_scaling.respond_to?(:key?) + + @rope_scaling.key?(key) || @rope_scaling.key?(key.to_sym) + end + + def _config_value(key, default = nil) + return default unless _config_has_key?(key) + + if @rope_scaling.key?(key) + @rope_scaling[key] + else + @rope_scaling[key.to_sym] + end + end + end + + class DynamicNTKScalingRoPE < MLX::NN::Module + def initialize( + dims, + max_position_embeddings: 2048, + traditional: false, + base: 10_000.0, + scale: 1.0 + ) + super() + @max_position_embeddings = max_position_embeddings + @original_base = base + @dims = dims + @traditional = traditional + @scale = scale + end + + def call(x, offset: 0) + seq_len = x.shape[-2] + offset + if seq_len > @max_position_embeddings + scaled_ctx = (@scale * seq_len.to_f / @max_position_embeddings) - (@scale - 1.0) + base = @original_base * (scaled_ctx**(@dims.to_f / (@dims - 2))) + else + base = @original_base + end + + MLX::Core.rope(x, @dims, @traditional, base, @scale, offset) + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + qkv_bias = args.qkv_bias + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.hidden_size / @n_heads + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: qkv_bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: qkv_bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: qkv_bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: qkv_bias) + + rope_scale = if args.rope_scaling && _config_value(args.rope_scaling, "rope_type") == "linear" + 1.0 / _config_value(args.rope_scaling, "factor").to_f + else + 2.0 + end + + self.rope = DynamicNTKScalingRoPE.new( + @head_dim, + max_position_embeddings: args.max_position_embeddings, + traditional: args.rope_traditional, + base: args.rope_theta, + scale: rope_scale + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @n_kv_heads, 
@head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + + private + + def _config_value(config, key, default = nil) + return default if config.nil? || !config.respond_to?(:key?) + return config[key] if config.key?(key) + + config.fetch(key.to_sym, default) + end + end + + class MLP < MLX::NN::Module + def initialize(dim, hidden_dim, bias) + super() + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args.hidden_size, args.intermediate_size, args.bias) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class InternLM3Model < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model = InternLM3Model.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + weights.reject { |k, _| k.include?("attention.rope.inv_freq") } + end + + def layers + model.layers + end + end + + Models.register("internlm3", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/iquestloopcoder.rb b/lib/mlx_lm/models/iquestloopcoder.rb new file mode 100644 index 0000000..5053ae0 --- /dev/null +++ b/lib/mlx_lm/models/iquestloopcoder.rb @@ -0,0 +1,261 @@ +require_relative "cache" +require_relative "rope_utils" + +module MlxLm + module Models + module Iquestloopcoder + class ModelArgs < BaseModelArgs + field :model_type, default: "iquestloopcoder" + field :hidden_size + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :rms_norm_eps + field :vocab_size + field :head_dim + field 
:num_key_value_heads + field :max_position_embeddings, default: 131_072 + field :attention_bias, default: false + field :mlp_bias, default: false + field :rope_theta, default: 500_000.0 + field :rope_scaling, default: nil + field :tie_word_embeddings, default: false + field :loop_num, default: 2 + field :loop_window_size, default: 64 + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + end + end + + class LoopGateProjection < MLX::NN::Module + def initialize(num_heads, head_dim) + super() + @num_heads = num_heads + @head_dim = head_dim + + mx = MLX::Core + self.weight = mx.zeros([num_heads, head_dim]) + self.bias = mx.zeros([num_heads]) + end + + def call(query) + mx = MLX::Core + projection = weight.reshape([@num_heads, @head_dim, 1]) + gate_logits = mx.matmul(query, projection) + gate_logits = gate_logits + bias.reshape([1, @num_heads, 1, 1]) + mx.sigmoid(gate_logits) + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.attention_bias) + + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + false, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def get_qkv(x, offset: 0) + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + queries = rope.call(queries, offset: offset) + keys = rope.call(keys, offset: offset) + + [queries, keys, values] + end + + def attention(queries, keys, values, mask: nil, cache: nil) + _cache = cache + MLX::Core.scaled_dot_product_attention(queries, keys, values, @scale, mask) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + hidden_dim = args.intermediate_size + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: args.mlp_bias) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: args.mlp_bias) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: args.mlp_bias) + end + + def call(x) + down_proj.call(MLX::NN.silu(gate_proj.call(x)) * up_proj.call(x)) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + end + + class IQuestLoopCoderModel < MLX::NN::Module + def initialize(args) + super() + @args = args + unless args.loop_num == 2 + raise ArgumentError, "Only loop_num=2 is supported, got #{args.loop_num}" + end + + self.vocab_size = args.vocab_size + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { 
TransformerBlock.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.gate_projections = Array.new(args.num_hidden_layers) do + LoopGateProjection.new(args.num_attention_heads, args.head_dim) + end + self.loop_num = args.loop_num + self.loop_window_size = args.loop_window_size + end + + def call(inputs, cache: nil) + mx = MLX::Core + b, l = inputs.shape[0], inputs.shape[1] + + h = embed_tokens.call(inputs) + layer_count = layers.length + layer_cache = cache || [nil] * (2 * layer_count) + + mask = _create_attention_mask(h, layer_cache[0]) + window_mask = _create_attention_mask(h, layer_cache[layer_count], window_size: loop_window_size) + + loop1_kv = [] + layers.each_with_index do |layer, idx| + c = layer_cache[idx] + h_norm = layer.input_layernorm.call(h) + offset = c ? c.offset : 0 + q1, k1, v1 = layer.self_attn.get_qkv(h_norm, offset: offset) + + if c + k1, v1 = c.update_and_fetch(k1, v1) + end + loop1_kv << [k1, v1] + + out = layer.self_attn.attention(q1, k1, v1, mask: mask, cache: c) + r = layer.self_attn.o_proj.call(out.transpose([0, 2, 1, 3]).reshape([b, l, @args.hidden_size])) + h = h + r + r = layer.mlp.call(layer.post_attention_layernorm.call(h)) + h = h + r + end + + layers.each_with_index do |layer, idx| + gate_proj = gate_projections[idx] + c = layer_cache[layer_count + idx] + k1, v1 = loop1_kv[idx] + + h_norm = layer.input_layernorm.call(h) + offset = c ? c.offset : 0 + q2, k2, v2 = layer.self_attn.get_qkv(h_norm, offset: offset) + + gate = gate_proj.call(q2) + attn_global = layer.self_attn.attention(q2, k1, v1, mask: mask, cache: c) + + if c + k2, v2 = c.update_and_fetch(k2, v2) + end + + attn_local = layer.self_attn.attention(q2, k2, v2, mask: window_mask, cache: c) + mixed = _mix_attention(gate, attn_global, attn_local) + + r = layer.self_attn.o_proj.call(mixed.transpose([0, 2, 1, 3]).reshape([b, l, @args.hidden_size])) + h = h + r + r = layer.mlp.call(layer.post_attention_layernorm.call(h)) + h = h + r + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache = nil, window_size: nil) + n = h.shape[1] + return cache.make_mask(n) if cache && cache.respond_to?(:make_mask) + return nil if n == 1 + return _create_causal_mask(n, window_size: window_size) if window_size && n > window_size + + "causal" + end + + def _create_causal_mask(n, offset: 0, window_size: nil) + mx = MLX::Core + rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n]) + linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1]) + + mask = mx.greater_equal(linds, rinds) + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + mask + end + + def _mix_attention(gate, attn_global, attn_local) + gate = gate.astype(attn_global.dtype) + (gate * attn_global) + ((1.0 - gate) * attn_local) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = IQuestLoopCoderModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def layers + model.layers + end + + def make_cache + Array.new(layers.length) { MlxLm::KVCache.new } + + Array.new(layers.length) { MlxLm::RotatingKVCache.new(max_size: @args.loop_window_size) } + end + end + + 
Models.register("iquestloopcoder", Model, ModelArgs) + end + + IQuestLoopCoder = Iquestloopcoder unless const_defined?(:IQuestLoopCoder) + end +end diff --git a/lib/mlx_lm/models/jamba.rb b/lib/mlx_lm/models/jamba.rb new file mode 100644 index 0000000..c67488e --- /dev/null +++ b/lib/mlx_lm/models/jamba.rb @@ -0,0 +1,158 @@ +require_relative "falcon_h1" + +module MlxLm + module Models + module Jamba + class ModelArgs < FalconH1::ModelArgs + field :model_type, default: "jamba" + field :attn_layer_offset, default: 1 + field :attn_layer_period, default: 2 + field :expert_layer_offset, default: 1 + field :expert_layer_period, default: 2 + field :mamba_d_state, default: nil + field :mamba_expand, default: nil + field :num_experts, default: 1 + field :num_experts_per_tok, default: 1 + field :mamba_dt_rank, default: "auto" + field :mamba_proj_bias, default: false + field :mamba_conv_bias, default: true + field :layers_block_type, default: nil + + def initialize(**kwargs) + super + @mamba_d_conv ||= 4 + @num_key_value_heads ||= @num_attention_heads + @layers_block_type ||= _default_layers_block_type + @num_hidden_layers ||= Array(@layers_block_type).length + @block_types ||= _to_block_types + end + + def to_falcon_h1_dict + hidden_size = @hidden_size + attention_heads = @num_attention_heads + inferred_head_dim = if !@head_dim.nil? + @head_dim + elsif !hidden_size.nil? && attention_heads.to_i > 0 + hidden_size / attention_heads + else + 64 + end + + { + "model_type" => @model_type, + "attention_bias" => @attention_bias, + "head_dim" => inferred_head_dim, + "hidden_size" => hidden_size, + "intermediate_size" => @intermediate_size, + "max_position_embeddings" => @max_position_embeddings, + "mamba_d_conv" => @mamba_d_conv, + "num_attention_heads" => attention_heads, + "num_hidden_layers" => @num_hidden_layers, + "num_key_value_heads" => @num_key_value_heads, + "rms_norm_eps" => @rms_norm_eps, + "rope_theta" => @rope_theta, + "vocab_size" => @vocab_size, + "tie_word_embeddings" => @tie_word_embeddings, + "attention_window_size" => @attention_window_size, + "block_types" => @block_types, + } + end + + private + + def _default_layers_block_type + count = @num_hidden_layers.to_i + return nil if count <= 0 + + period = @attn_layer_period.to_i + offset = @attn_layer_offset.to_i + period = 1 if period <= 0 + + Array.new(count) do |idx| + (idx % period == offset) ? "attention" : "mamba" + end + end + + def _to_block_types + return @block_types if @block_types.is_a?(Array) && !@block_types.empty? + return nil unless @layers_block_type.is_a?(Array) && !@layers_block_type.empty? + + @layers_block_type.map { |layer_type| layer_type.to_s == "mamba" ? 
"recurrent" : "attention" } + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.wrapped_model = FalconH1::Model.new( + FalconH1::ModelArgs.from_dict(args.to_falcon_h1_dict) + ) + end + + def call(inputs, cache: nil) + wrapped_model.call(inputs, cache: cache) + end + + def sanitize(weights) + normalized = weights.dup + _stack_experts!(normalized) + + remapped = {} + normalized.each do |key, value| + remapped[_remap_weight_key(key)] = value + end + wrapped_model.sanitize(remapped) + end + + def layers + wrapped_model.layers + end + + def make_cache + return nil unless wrapped_model.respond_to?(:make_cache) + + wrapped_model.make_cache + end + + private + + def _stack_experts!(weights) + mx = MLX::Core + + @args.num_hidden_layers.to_i.times do |layer_idx| + prefix = "model.layers.#{layer_idx}.feed_forward" + %w[gate_proj up_proj down_proj].each do |projection| + %w[weight bias scales biases].each do |param| + pattern = /\A#{Regexp.escape(prefix)}\.experts\.(\d+)\.#{projection}\.#{param}\z/ + matches = weights.keys.filter_map do |key| + match = pattern.match(key) + next nil unless match + + [match[1].to_i, key] + end + next if matches.empty? + + stacked = matches.sort_by(&:first).map do |(_, key)| + weights.delete(key) + end + weights["#{prefix}.switch_mlp.#{projection}.#{param}"] = mx.stack(stacked) + end + end + end + end + + def _remap_weight_key(key) + mapped = key.dup + mapped = mapped.gsub("model.norm.", "model.final_layernorm.") + mapped = mapped.gsub(".mixer.", ".mamba.") + mapped = mapped.gsub(".feed_forward.router.", ".feed_forward.gate.") + mapped + end + end + + Models.register("jamba", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/kimi_k25.rb b/lib/mlx_lm/models/kimi_k25.rb new file mode 100644 index 0000000..e3ed75f --- /dev/null +++ b/lib/mlx_lm/models/kimi_k25.rb @@ -0,0 +1,98 @@ +require_relative "deepseek" + +module MlxLm + module Models + module KimiK25 + class ModelArgs < BaseModelArgs + field :model_type, default: "kimi_k25" + field :text_config, default: nil + + def self.from_dict(params) + has_text_config = params.key?("text_config") || params.key?(:text_config) + return super if has_text_config + + model_type = params["model_type"] || params[:model_type] || "kimi_k25" + new(model_type: model_type, text_config: params) + end + + def initialize(**kwargs) + super + @text_config = _stringify_keys(@text_config || {}) + @text_config["model_type"] ||= "deepseek" + end + + private + + def _stringify_keys(hash) + hash.each_with_object({}) do |(key, value), out| + out[key.to_s] = value + end + end + end + + class Model < MLX::NN::Module + MULTIMODAL_PREFIXES = %w[ + vision_tower + vision_model + multi_modal_projector + mm_projector + ].freeze + + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = DeepSeek::Model.new( + DeepSeek::ModelArgs.from_dict(args.text_config) + ) + end + + def call(inputs, cache: nil, input_embeddings: nil) + language_model.call(inputs, cache: cache) + end + + def sanitize(weights) + language_weights = {} + flat_weights = weights.is_a?(Hash) ? weights : weights.to_h + + flat_weights.each do |key, value| + next if _multimodal_key?(key) + + normalized_key = key.start_with?("language_model.") ? 
key.delete_prefix("language_model.") : key + language_weights[normalized_key] = value + end + + sanitized_language = if language_model.respond_to?(:sanitize) + language_model.sanitize(language_weights) + else + language_weights + end + + sanitized_language.each_with_object({}) do |(key, value), out| + out["language_model.#{key}"] = value + end + end + + def model + language_model.model + end + + def layers + model.layers + end + + def cast_predicate + lambda { |key| !key.include?("e_score_correction_bias") } + end + + private + + def _multimodal_key?(key) + MULTIMODAL_PREFIXES.any? { |prefix| key == prefix || key.start_with?("#{prefix}.") } + end + end + + Models.register("kimi_k25", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/kimi_linear.rb b/lib/mlx_lm/models/kimi_linear.rb new file mode 100644 index 0000000..c01d748 --- /dev/null +++ b/lib/mlx_lm/models/kimi_linear.rb @@ -0,0 +1,124 @@ +require_relative "bailing_moe_linear" + +module MlxLm + module Models + module KimiLinear + class ModelArgs < BailingMoeLinear::ModelArgs + field :model_type, default: "kimi_linear" + field :hidden_dim, default: nil + field :ffn_hidden_size, default: nil + field :num_layers, default: nil + field :num_heads, default: nil + field :num_kv_heads, default: nil + field :num_local_experts, default: nil + field :n_routed_experts, default: nil + field :n_shared_experts, default: nil + field :top_k, default: nil + field :score_func, default: nil + + def self.from_dict(params) + normalized = params.each_with_object({}) do |(key, value), out| + out[key.to_s] = value + end + + { + "hidden_dim" => "hidden_size", + "ffn_hidden_size" => "intermediate_size", + "num_layers" => "num_hidden_layers", + "num_heads" => "num_attention_heads", + "num_kv_heads" => "num_key_value_heads", + "num_local_experts" => "num_experts", + "n_routed_experts" => "num_experts", + "n_shared_experts" => "num_shared_experts", + "top_k" => "num_experts_per_tok", + "score_func" => "score_function", + }.each do |source_key, target_key| + next unless normalized.key?(source_key) + + normalized[target_key] = normalized[source_key] unless normalized.key?(target_key) + end + + normalized["model_type"] ||= "kimi_linear" + super(normalized) + end + + def initialize(**kwargs) + super + @hidden_size = @hidden_dim if kwargs.key?(:hidden_dim) && !kwargs.key?(:hidden_size) && !@hidden_dim.nil? + @intermediate_size = @ffn_hidden_size if kwargs.key?(:ffn_hidden_size) && !kwargs.key?(:intermediate_size) && !@ffn_hidden_size.nil? + @num_hidden_layers = @num_layers if kwargs.key?(:num_layers) && !kwargs.key?(:num_hidden_layers) && !@num_layers.nil? + @num_attention_heads = @num_heads if kwargs.key?(:num_heads) && !kwargs.key?(:num_attention_heads) && !@num_heads.nil? + @num_key_value_heads = @num_kv_heads if kwargs.key?(:num_kv_heads) && !kwargs.key?(:num_key_value_heads) && !@num_kv_heads.nil? + @num_experts = @num_local_experts if kwargs.key?(:num_local_experts) && !kwargs.key?(:num_experts) && !@num_local_experts.nil? + @num_experts = @n_routed_experts if kwargs.key?(:n_routed_experts) && !kwargs.key?(:num_experts) && !kwargs.key?(:num_local_experts) && !@n_routed_experts.nil? + @num_shared_experts = @n_shared_experts if kwargs.key?(:n_shared_experts) && !kwargs.key?(:num_shared_experts) && !@n_shared_experts.nil? + @num_experts_per_tok = @top_k if kwargs.key?(:top_k) && !kwargs.key?(:num_experts_per_tok) && !@top_k.nil? + @score_function = @score_func if kwargs.key?(:score_func) && !kwargs.key?(:score_function) && !@score_func.nil? 
+ @num_key_value_heads ||= @num_attention_heads + end + + def to_bailing_moe_linear_dict + to_bailing_moe_dict + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.wrapped_model = BailingMoeLinear::Model.new( + BailingMoeLinear::ModelArgs.from_dict(args.to_bailing_moe_linear_dict) + ) + end + + def call(inputs, cache: nil) + wrapped_model.call(inputs, cache: cache) + end + + def sanitize(weights) + remapped = {} + flat_weights = weights.is_a?(Hash) ? weights : weights.to_h + flat_weights.each do |key, value| + remapped[_remap_weight_key(key)] = value + end + wrapped_model.sanitize(remapped) + end + + def layers + wrapped_model.layers + end + + def make_cache + return wrapped_model.make_cache if wrapped_model.respond_to?(:make_cache) + + nil + end + + def cast_predicate + return wrapped_model.cast_predicate if wrapped_model.respond_to?(:cast_predicate) + + lambda { |_key| true } + end + + def quant_predicate + return wrapped_model.quant_predicate if wrapped_model.respond_to?(:quant_predicate) + + lambda { |_key, _value| true } + end + + private + + def _remap_weight_key(key) + mapped = key.dup + mapped = mapped.gsub(".mlp.router.", ".mlp.gate.") + mapped = mapped.gsub("model.embed_tokens.", "model.word_embeddings.") + mapped = mapped.gsub("model.tok_embeddings.", "model.word_embeddings.") + mapped + end + end + + Models.register("kimi_linear", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/kimi_vl.rb b/lib/mlx_lm/models/kimi_vl.rb new file mode 100644 index 0000000..aa8ee60 --- /dev/null +++ b/lib/mlx_lm/models/kimi_vl.rb @@ -0,0 +1,93 @@ +require_relative "deepseek" + +module MlxLm + module Models + module KimiVL + class ModelArgs < BaseModelArgs + field :model_type, default: "kimi_vl" + field :text_config, default: nil + + def self.from_dict(params) + has_text_config = params.key?("text_config") || params.key?(:text_config) + return super if has_text_config + + model_type = params["model_type"] || params[:model_type] || "kimi_vl" + new(model_type: model_type, text_config: params) + end + + def initialize(**kwargs) + super + @text_config = _stringify_keys(@text_config || {}) + @text_config["model_type"] ||= "deepseek" + end + + private + + def _stringify_keys(hash) + hash.each_with_object({}) do |(key, value), out| + out[key.to_s] = value + end + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = DeepSeek::Model.new( + DeepSeek::ModelArgs.from_dict(args.text_config) + ) + end + + def call(inputs, cache: nil, input_embeddings: nil) + language_model.call(inputs, cache: cache) + end + + def sanitize(weights) + language_weights = {} + flat_weights = weights.is_a?(Hash) ? weights : weights.to_h + + flat_weights.each do |key, value| + next if _drop_key?(key) + + normalized_key = key.start_with?("language_model.") ? 
key.delete_prefix("language_model.") : key + language_weights[normalized_key] = value + end + + sanitized_language = if language_model.respond_to?(:sanitize) + language_model.sanitize(language_weights) + else + language_weights + end + + sanitized_language.each_with_object({}) do |(key, value), out| + out["language_model.#{key}"] = value + end + end + + def model + language_model.model + end + + def layers + model.layers + end + + def cast_predicate + lambda { |key| !key.include?("e_score_correction_bias") } + end + + private + + def _drop_key?(key) + key.include?("vision_tower") || + key.include?("multi_modal_projector") || + key.include?("rotary_emb") + end + end + + Models.register("kimi_vl", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/klear.rb b/lib/mlx_lm/models/klear.rb new file mode 100644 index 0000000..cb9604a --- /dev/null +++ b/lib/mlx_lm/models/klear.rb @@ -0,0 +1,283 @@ +require_relative "activations" +require_relative "switch_layers" + +module MlxLm + module Models + module Klear + class ModelArgs < BaseModelArgs + field :model_type, default: "Klear" + field :hidden_size + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :attention_bias + field :mlp_only_layers + field :num_experts + field :num_experts_per_tok + field :decoder_sparse_step + field :n_shared_experts + field :moe_intermediate_size + field :rms_norm_eps + field :vocab_size + field :num_key_value_heads + field :rope_theta + field :max_position_embeddings + field :norm_topk_prob + + def initialize(**kwargs) + super + @mlp_only_layers ||= [] + @num_key_value_heads ||= @num_attention_heads + end + end + + class KlearAttention < MLX::NN::Module + def initialize(args) + super() + @num_attention_heads = args.num_attention_heads + @num_key_value_heads = args.num_key_value_heads + @head_dim = args.hidden_size / args.num_attention_heads + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new( + args.hidden_size, + @num_attention_heads * @head_dim, + bias: args.attention_bias + ) + self.k_proj = MLX::NN::Linear.new( + args.hidden_size, + @num_key_value_heads * @head_dim, + bias: args.attention_bias + ) + self.v_proj = MLX::NN::Linear.new( + args.hidden_size, + @num_key_value_heads * @head_dim, + bias: args.attention_bias + ) + self.o_proj = MLX::NN::Linear.new( + @num_attention_heads * @head_dim, + args.hidden_size, + bias: args.attention_bias + ) + + self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.rope = MLX::NN::RoPE.new(@head_dim, traditional: false, base: args.rope_theta) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = q_norm.call(queries.reshape([b, l, @num_attention_heads, @head_dim])).transpose([0, 2, 1, 3]) + keys = k_norm.call(keys.reshape([b, l, @num_key_value_heads, @head_dim])).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim]) + o_proj.call(output) + 
end + end + + class KlearMLP < MLX::NN::Module + def initialize(dim, hidden_dim) + super() + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class KlearSparseMoeBlock < MLX::NN::Module + def initialize(args) + super() + @norm_topk_prob = args.norm_topk_prob + @num_experts = args.num_experts + @top_k = [args.num_experts_per_tok.to_i, 1].max + + self.gate = MLX::NN::Linear.new(args.hidden_size, @num_experts, bias: false) + self.experts = SwitchLayers::SwitchGLU.new( + args.hidden_size, + args.moe_intermediate_size, + @num_experts + ) + self.shared_experts = KlearMLP.new( + args.hidden_size, + args.moe_intermediate_size * args.n_shared_experts + ) + self.coefficient = MLX::NN::Linear.new(args.hidden_size, 2) + + mx = MLX::Core + self.expert_bias = mx.zeros([@num_experts]).astype(mx.float32) + end + + def call(x) + mx = MLX::Core + + routing_weights = mx.sigmoid(gate.call(x).astype(mx.float32)) + biased_weights = routing_weights + expert_bias.reshape([1, 1, @num_experts]) + + k = [@top_k, @num_experts].min + inds = mx.argpartition(biased_weights * -1.0, k - 1, -1) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + + scores = mx.take_along_axis(routing_weights, inds, -1) + if @norm_topk_prob + denom = mx.expand_dims(mx.sum(scores, -1), -1) + scores = scores / denom + end + + scores = scores.astype(x.dtype) + expert_out = experts.call(x, inds) + y_experts = mx.sum(expert_out * mx.expand_dims(scores, -1), -2) + + coef = mx.softmax(coefficient.call(x).astype(mx.float32), -1).astype(x.dtype) + coef_expert, coef_shared = mx.split(coef, [1], -1) + shared = shared_experts.call(x) + + y_experts * coef_expert + shared * coef_shared + end + end + + class KlearDecoderLayer < MLX::NN::Module + def initialize(args, layer_idx:) + super() + self.self_attn = KlearAttention.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + + if _use_sparse_moe_layer?(args, layer_idx) + self.mlp = KlearSparseMoeBlock.new(args) + else + self.mlp = KlearMLP.new(args.hidden_size, args.intermediate_size) + end + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + + private + + def _use_sparse_moe_layer?(args, layer_idx) + sparse_step = [args.decoder_sparse_step.to_i, 1].max + mlp_only_layers = args.mlp_only_layers || [] + + !mlp_only_layers.include?(layer_idx) && + args.num_experts.to_i > 0 && + ((layer_idx + 1) % sparse_step).zero? 
+ end + end + + class KlearModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) do |layer_idx| + KlearDecoderLayer.new(args, layer_idx: layer_idx) + end + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, layer_idx| + h = layer.call(h, mask: mask, cache: layer_cache[layer_idx]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache) + n = h.shape[1] + return cache.make_mask(n) if cache && cache.respond_to?(:make_mask) + return nil if n == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = KlearModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + + def call(inputs, cache: nil) + lm_head.call(model.call(inputs, cache: cache)) + end + + def sanitize(weights) + return weights unless weights.key?("model.layers.0.mlp.experts.0.gate_proj.weight") + + mx = MLX::Core + result = weights.dup + + @args.num_hidden_layers.times do |layer_idx| + prefix = "model.layers.#{layer_idx}.mlp.experts" + %w[gate_proj up_proj down_proj].each do |name| + expert_keys = (0...@args.num_experts).map do |expert_idx| + "#{prefix}.#{expert_idx}.#{name}.weight" + end + next unless expert_keys.all? { |key| result.key?(key) } + + stacked = expert_keys.map { |key| result.delete(key) } + result["#{prefix}.#{name}.weight"] = mx.stack(stacked) + end + end + + result + end + + def layers + model.layers + end + + def quant_predicate + lambda do |path, _module| + if path.to_s.end_with?("mlp.gate") + { "group_size" => 64, "bits" => 8 } + else + true + end + end + end + + def cast_predicate + lambda { |key| !key.to_s.include?("expert_bias") } + end + end + + Models.register("Klear", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/lfm2.rb b/lib/mlx_lm/models/lfm2.rb new file mode 100644 index 0000000..1440b93 --- /dev/null +++ b/lib/mlx_lm/models/lfm2.rb @@ -0,0 +1,120 @@ +require_relative "qwen3" + +module MlxLm + module Models + module Lfm2 + class ModelArgs < BaseModelArgs + field :model_type, default: "lfm2" + field :vocab_size, default: 32000 + field :hidden_size, default: 4096 + field :num_hidden_layers, default: 32 + field :num_attention_heads, default: 32 + field :num_key_value_heads, default: nil + field :max_position_embeddings, default: 2048 + field :norm_eps, default: 1e-6 + field :conv_bias, default: false + field :conv_L_cache, default: 4 + field :block_dim, default: nil + field :block_ff_dim, default: nil + field :block_multiple_of, default: 256 + field :block_ffn_dim_multiplier, default: nil + field :block_auto_adjust_ff_dim, default: false + field :rope_theta, default: 1_000_000.0 + field :rope_parameters, default: nil + field :full_attn_idxs, default: nil + field :layer_types, default: nil + field :tie_word_embeddings, default: true + + def initialize(**kwargs) + super + rope_theta_from_params = _rope_theta_from_parameters + @rope_theta = rope_theta_from_params unless rope_theta_from_params.nil? 
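+ # Fill in defaults the checkpoint config may omit: KV heads fall back to the attention head count, block_dim to hidden_size (block_ff_dim at 4x), and full-attention layer indices are derived from layer_types.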
+ @num_key_value_heads ||= @num_attention_heads + @block_dim ||= @hidden_size + @block_ff_dim ||= @block_dim * 4 + @full_attn_idxs ||= _full_attn_idxs_from_layer_types + end + + private + + def _rope_theta_from_parameters + return nil unless @rope_parameters.is_a?(Hash) + + @rope_parameters["rope_theta"] || @rope_parameters[:rope_theta] + end + + def _full_attn_idxs_from_layer_types + return [] unless @layer_types.is_a?(Array) + + @layer_types.each_with_index.filter_map do |layer_type, i| + i if layer_type.to_s == "full_attention" + end + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = Qwen3::Model.new(Qwen3::ModelArgs.from_dict(_qwen3_config(args))) + end + + def call(inputs, cache: nil, input_embeddings: nil) + language_model.call(inputs, cache: cache, input_embeddings: input_embeddings) + end + + def sanitize(weights) + sanitized = {} + weights.each do |name, param| + current = param + if name.include?("conv.weight") && _transpose_conv_weight?(param) + current = MLX::Core.swapaxes(param, 1, 2) + end + sanitized[name] = current + end + sanitized + end + + def layers + language_model.layers + end + + def make_cache + return language_model.make_cache if language_model.respond_to?(:make_cache) + return nil unless defined?(MlxLm::KVCache) + + Array.new(layers.length) { MlxLm::KVCache.new } + end + + private + + def _transpose_conv_weight?(param) + return false unless param.respond_to?(:shape) + return false unless param.shape.is_a?(Array) + return false unless param.shape.length >= 3 + + param.shape[-1] > param.shape[1] + end + + def _qwen3_config(args) + { + "model_type" => "qwen3", + "hidden_size" => args.hidden_size, + "num_hidden_layers" => args.num_hidden_layers, + "intermediate_size" => args.block_ff_dim, + "num_attention_heads" => args.num_attention_heads, + "num_key_value_heads" => args.num_key_value_heads, + "rms_norm_eps" => args.norm_eps, + "vocab_size" => args.vocab_size, + "rope_theta" => args.rope_theta, + "max_position_embeddings" => args.max_position_embeddings, + "tie_word_embeddings" => args.tie_word_embeddings, + } + end + end + + Models.register("lfm2", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/lfm2_moe.rb b/lib/mlx_lm/models/lfm2_moe.rb new file mode 100644 index 0000000..816d7f0 --- /dev/null +++ b/lib/mlx_lm/models/lfm2_moe.rb @@ -0,0 +1,421 @@ +require_relative "activations" +require_relative "cache" +require_relative "rope_utils" +require_relative "switch_layers" + +module MlxLm + module Models + module Lfm2Moe + class ModelArgs < BaseModelArgs + field :model_type, default: "lfm2_moe" + field :vocab_size + field :hidden_size + field :intermediate_size + field :moe_intermediate_size + field :num_hidden_layers + field :num_experts + field :num_experts_per_tok + field :norm_topk_prob + field :num_attention_heads + field :num_key_value_heads, default: nil + field :max_position_embeddings + field :use_expert_bias + field :num_dense_layers + field :norm_eps + field :conv_bias + field :conv_L_cache + field :rope_theta, default: 1_000_000.0 + field :rope_parameters, default: nil + field :full_attn_idxs, default: nil + field :layer_types, default: nil + + def initialize(**kwargs) + super + rope_theta_from_params = _rope_theta_from_parameters + @rope_theta = rope_theta_from_params unless rope_theta_from_params.nil? 
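+ # Same fallbacks as the dense LFM2 config: KV heads default to the attention head count and full-attention layer indices come from layer_types.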
+ @num_key_value_heads ||= @num_attention_heads + @full_attn_idxs ||= _full_attn_idxs_from_layer_types + end + + private + + def _rope_theta_from_parameters + return nil unless @rope_parameters.is_a?(Hash) + + @rope_parameters["rope_theta"] || @rope_parameters[:rope_theta] + end + + def _full_attn_idxs_from_layer_types + return [] unless @layer_types.is_a?(Array) + + @layer_types.each_with_index.filter_map do |layer_type, i| + i if layer_type.to_s == "full_attention" + end + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.hidden_size / @n_heads + @scale = @head_dim**(-0.5) + + self.q_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.norm_eps) + self.k_layernorm = MLX::NN::RMSNorm.new(@head_dim, eps: args.norm_eps) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.out_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false) + + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + false, + nil, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = q_layernorm.call(queries.reshape([b, l, @n_heads, @head_dim])).transpose([0, 2, 1, 3]) + keys = k_layernorm.call(keys.reshape([b, l, @n_kv_heads, @head_dim])).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + out_proj.call(output) + end + end + + class ShortConv < MLX::NN::Module + def initialize(args, layer_idx) + super() + _ = layer_idx + @args = args + @l_cache = args.conv_L_cache + @hidden_size = args.hidden_size + + self.conv = MLX::NN::Conv1d.new( + args.hidden_size, + args.hidden_size, + @l_cache, + padding: 0, + groups: args.hidden_size, + bias: args.conv_bias + ) + self.in_proj = MLX::NN::Linear.new(args.hidden_size, 3 * args.hidden_size, bias: args.conv_bias) + self.out_proj = MLX::NN::Linear.new(args.hidden_size, args.hidden_size, bias: args.conv_bias) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + + projected = in_proj.call(x) + b_gate, c_gate, x_gate = mx.split(projected, [@hidden_size, 2 * @hidden_size], -1) + bx = b_gate * x_gate + bx = mx.where(mask.reshape([mask.shape[0], mask.shape[1], 1]), bx, 0) unless mask.nil? + + if cache + state = if cache[0].nil? 
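+ # First call with an empty cache: seed the conv state with (conv_L_cache - 1) zero frames so the grouped causal conv sees a full window.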
+ mx.zeros([bx.shape[0], @l_cache - 1, @hidden_size], dtype: bx.dtype) + else + cache[0] + end + + bx = mx.concatenate([state, bx], 1) + n_keep = @l_cache - 1 + t = x_gate.shape[1] + + if cache.lengths + ends = mx.clip(cache.lengths, 0, t) + positions = mx.expand_dims( + mx.expand_dims(ends, 1) + mx.arange(n_keep), + -1 + ) + cache[0] = mx.take_along_axis(bx, positions, 1) + else + if n_keep > 0 + split_at = bx.shape[1] - n_keep + cache[0] = mx.split(bx, [split_at], 1)[1] + else + cache[0] = mx.zeros([bx.shape[0], 0, bx.shape[2]], dtype: bx.dtype) + end + end + + cache.advance(t) + else + bx = mx.pad( + bx, + [ + [0, 0], + [@l_cache - 1, 0], + [0, 0], + ] + ) + end + + conv_out = conv.call(bx) + out_proj.call(c_gate * conv_out) + end + end + + class MLP < MLX::NN::Module + def initialize(config, intermediate_size: nil) + super() + @hidden_size = config.hidden_size + @intermediate_size = intermediate_size || config.intermediate_size + self.gate_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: false) + self.up_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: false) + self.down_proj = MLX::NN::Linear.new(@intermediate_size, @hidden_size, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class SparseMoeBlock < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + intermediate_size = args.moe_intermediate_size + + @num_experts = args.num_experts + @top_k = args.num_experts_per_tok + @norm_topk_prob = args.norm_topk_prob + @use_expert_bias = args.use_expert_bias + + self.gate = MLX::NN::Linear.new(dim, @num_experts, bias: false) + self.switch_mlp = SwitchLayers::SwitchGLU.new(dim, intermediate_size, @num_experts) + self.expert_bias = MLX::Core.zeros([@num_experts]) if @use_expert_bias + end + + def call(x) + mx = MLX::Core + + gates = gate.call(x).astype(mx.float32) + gates = mx.softmax(gates, -1) + gates = gates + expert_bias if @use_expert_bias + + k = [[@top_k.to_i, 1].max, @num_experts].min + inds = mx.argpartition(gates, -k, -1) + take_ids = mx.array((@num_experts - k...@num_experts).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + + scores = mx.take_along_axis(gates, inds, -1) + if @norm_topk_prob + scores = scores / (mx.expand_dims(mx.sum(scores, -1), -1) + 1e-20) + end + scores = scores.astype(x.dtype) + + y = switch_mlp.call(x, inds) + mx.sum(y * mx.expand_dims(scores, -1), -2) + end + end + + class DecoderLayer < MLX::NN::Module + attr_reader :is_attention_layer + + def initialize(args, layer_idx) + super() + @is_attention_layer = args.full_attn_idxs.include?(layer_idx) + + if @is_attention_layer + self.self_attn = Attention.new(args) + else + self.conv = ShortConv.new(args, layer_idx) + end + + self.feed_forward = if layer_idx < args.num_dense_layers + MLP.new(args, intermediate_size: args.intermediate_size) + else + SparseMoeBlock.new(args) + end + + self.operator_norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.norm_eps) + self.ffn_norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = if @is_attention_layer + self_attn.call(operator_norm.call(x), mask: mask, cache: cache) + else + conv.call(operator_norm.call(x), mask: mask, cache: cache) + end + + h = x + r + h + feed_forward.call(ffn_norm.call(h)) + end + end + + class Lfm2MoeModel < MLX::NN::Module + def initialize(args) + super() + @args = args + + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, 
args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { |i| DecoderLayer.new(args, i) } + self.embedding_norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.norm_eps) + + self.fa_idx = args.full_attn_idxs[0] || 0 + self.conv_idx = 0 + args.num_hidden_layers.times do |i| + if args.full_attn_idxs.include?(i) + self.conv_idx += 1 + else + break + end + end + self.conv_idx = [conv_idx, args.num_hidden_layers - 1].min + end + + def call(inputs, cache: nil, input_embeddings: nil) + h = input_embeddings || embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + attn_mask = _create_attention_mask(h, layer_cache[fa_idx]) + conv_mask = _create_ssm_mask(h, layer_cache[conv_idx]) + + layers.each_with_index do |layer, i| + mask = layer.is_attention_layer ? attn_mask : conv_mask + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + embedding_norm.call(h) + end + + private + + def _create_attention_mask(h, cache = nil) + n = h.shape[1] + return cache.make_mask(n) if cache && cache.respond_to?(:make_mask) + return nil if n == 1 + + "causal" + end + + def _create_ssm_mask(h, cache = nil) + return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask) + + nil + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = Lfm2MoeModel.new(args) + end + + def call(inputs, cache: nil, input_embeddings: nil) + out = model.call(inputs, cache: cache, input_embeddings: input_embeddings) + model.embed_tokens.as_linear(out) + end + + def sanitize(weights) + mx = MLX::Core + sanitized = {} + + weights.each do |name, param| + current = param + if name.include?("conv.weight") && _transpose_conv_weight?(param) + current = mx.swapaxes(param, 1, 2) + end + + key = name + { + "w1.weight" => "gate_proj.weight", + "w2.weight" => "down_proj.weight", + "w3.weight" => "up_proj.weight", + }.each do |old_name, new_name| + key = key.gsub(old_name, new_name) if key.include?(old_name) + end + + sanitized[key] = current + end + + @args.num_hidden_layers.times do |layer_idx| + prefix = "model.layers.#{layer_idx}" + %w[gate_proj down_proj up_proj].each do |projection| + first_key = "#{prefix}.feed_forward.experts.0.#{projection}.weight" + next unless sanitized.key?(first_key) + + expert_keys = (0...@args.num_experts).map do |expert_idx| + "#{prefix}.feed_forward.experts.#{expert_idx}.#{projection}.weight" + end + next unless expert_keys.all? 
{ |k| sanitized.key?(k) } + + stacked = expert_keys.map { |k| sanitized.delete(k) } + sanitized["#{prefix}.feed_forward.switch_mlp.#{projection}.weight"] = mx.stack(stacked) + end + end + + sanitized + end + + def layers + model.layers + end + + def make_cache + layers.map do |layer| + if layer.is_attention_layer + MlxLm::KVCache.new + else + MlxLm::ArraysCache.new(1) + end + end + end + + def quant_predicate + lambda do |path, _| + if path.end_with?("feed_forward.gate") + { group_size: 64, bits: 8 } + else + true + end + end + end + + def cast_predicate + lambda { |k| !k.include?("expert_bias") } + end + + private + + def _transpose_conv_weight?(param) + return false unless param.respond_to?(:shape) + return false unless param.shape.is_a?(Array) + return false unless param.shape.length >= 3 + + param.shape[-1] > param.shape[1] + end + end + + Models.register("lfm2_moe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/lfm2_vl.rb b/lib/mlx_lm/models/lfm2_vl.rb new file mode 100644 index 0000000..da62224 --- /dev/null +++ b/lib/mlx_lm/models/lfm2_vl.rb @@ -0,0 +1,67 @@ +require_relative "lfm2" + +module MlxLm + module Models + module Lfm2VL + class ModelArgs < BaseModelArgs + field :model_type, default: "lfm2-vl" + field :text_config, default: nil + + def self.from_dict(params) + has_text_config = params.key?("text_config") || params.key?(:text_config) + return super if has_text_config + + new(model_type: params["model_type"] || params[:model_type], text_config: params) + end + + def initialize(**kwargs) + super + @text_config = _stringify_keys(@text_config || {}) + @text_config["tie_word_embeddings"] = false + end + + private + + def _stringify_keys(hash) + hash.each_with_object({}) do |(key, value), out| + out[key.to_s] = value + end + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = Lfm2::Model.new(Lfm2::ModelArgs.from_dict(args.text_config)) + end + + def call(inputs, cache: nil, input_embeddings: nil) + language_model.call(inputs, cache: cache, input_embeddings: input_embeddings) + end + + def sanitize(weights) + nested = MLX::Utils.tree_unflatten(weights.to_a) + if nested.is_a?(Hash) + nested.delete("vision_tower") + nested.delete("multi_modal_projector") + end + MLX::Utils.tree_flatten(nested, destination: {}) + end + + def layers + language_model.layers + end + + def make_cache + return language_model.make_cache if language_model.respond_to?(:make_cache) + + nil + end + end + + Models.register("lfm2-vl", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/lille_130m.rb b/lib/mlx_lm/models/lille_130m.rb new file mode 100644 index 0000000..c38bf08 --- /dev/null +++ b/lib/mlx_lm/models/lille_130m.rb @@ -0,0 +1,148 @@ +module MlxLm + module Models + module Lille130m + class ModelArgs < BaseModelArgs + field :model_type, default: "lille-130m" + field :block_size + field :layer_norm_eps + field :n_embd + field :n_head + field :n_kv_heads + field :n_layer + field :rope_theta + field :vocab_size + field :tie_word_embeddings, default: true + end + + class Lille130mAttention < MLX::NN::Module + def initialize(args) + super() + @n_head = args.n_head + @n_kv_heads = args.n_kv_heads + @head_dim = args.n_embd / @n_head + @scale = @head_dim**(-0.5) + + self.qkv_proj = MLX::NN::Linear.new( + args.n_embd, + (@n_head + (2 * @n_kv_heads)) * @head_dim, + bias: false + ) + self.out_proj = MLX::NN::Linear.new(@n_head * @head_dim, args.n_embd, bias: false) + self.norm = 
MLX::NN::RMSNorm.new(args.n_embd, eps: args.layer_norm_eps) + self.rope = MLX::NN::RoPE.new(@head_dim, traditional: true, base: args.rope_theta) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + qkv = qkv_proj.call(norm.call(x)) + q_size = @n_head * @head_dim + kv_size = @n_kv_heads * @head_dim + queries, keys, values = mx.split(qkv, [q_size, q_size + kv_size], 2) + + queries = queries.reshape([b, l, @n_head, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_head * @head_dim]) + out_proj.call(output) + end + end + + class Lille130mMLP < MLX::NN::Module + def initialize(args) + super() + hidden_dim = 256 * ((8 * args.n_embd / 3) / 256.0).round + hidden_dim = 256 if hidden_dim.zero? + + self.norm = MLX::NN::RMSNorm.new(args.n_embd, eps: args.layer_norm_eps) + self.gate_proj = MLX::NN::Linear.new(args.n_embd, hidden_dim, bias: false) + self.up_proj = MLX::NN::Linear.new(args.n_embd, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, args.n_embd, bias: false) + end + + def call(x) + h = norm.call(x) + down_proj.call(Activations.swiglu(gate_proj.call(h), up_proj.call(h))) + end + end + + class Lille130Block < MLX::NN::Module + def initialize(args) + super() + self.attention = Lille130mAttention.new(args) + self.feed_forward = Lille130mMLP.new(args) + end + + def call(x, mask: nil, cache: nil) + h = x + attention.call(x, mask: mask, cache: cache) + h + feed_forward.call(h) + end + end + + class Lille130 < MLX::NN::Module + def initialize(args) + super() + self.tok_embeddings = MLX::NN::Embedding.new(args.vocab_size, args.n_embd) + self.layers = Array.new(args.n_layer) { Lille130Block.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.n_embd, eps: args.layer_norm_eps) + end + + def call(inputs, cache: nil) + h = tok_embeddings.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + tok_embeddings.as_linear(norm.call(h)) + end + + private + + def _create_attention_mask(h, cache) + return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if h.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + self.args = args + self.model_type = args.model_type + self.transformer = Lille130.new(args) + end + + def call(inputs, cache: nil) + transformer.call(inputs, cache: cache) + end + + def layers + transformer.layers + end + + def sanitize(weights) + weights.reject { |k, _| k.include?("rotary_emb") } + end + end + + Models.register("lille-130m", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/llama4.rb b/lib/mlx_lm/models/llama4.rb new file mode 100644 index 0000000..a513593 --- /dev/null +++ b/lib/mlx_lm/models/llama4.rb @@ -0,0 +1,357 @@ +require_relative "activations" +require_relative "cache" +require_relative "rope_utils" +require_relative "switch_layers" + +module 
MlxLm + module Models + module Llama4 + class TextArgs < BaseModelArgs + field :model_type, default: "llama4_text" + field :attention_bias, default: false + field :attention_chunk_size, default: 1024 + field :head_dim, default: nil + field :hidden_size + field :interleave_moe_layer_step, default: 1 + field :intermediate_size + field :intermediate_size_mlp, default: nil + field :max_position_embeddings, default: 4096 + field :num_attention_heads + field :num_experts_per_tok, default: 1 + field :num_hidden_layers + field :num_key_value_heads, default: nil + field :num_local_experts, default: 1 + field :rms_norm_eps, default: 1e-5 + field :rope_scaling, default: nil + field :rope_theta, default: 10_000.0 + field :use_qk_norm, default: false + field :vocab_size + field :attn_temperature_tuning, default: 4 + field :floor_scale, default: 8192 + field :attn_scale, default: 0.1 + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + @intermediate_size_mlp ||= @intermediate_size + @attention_chunk_size = [@attention_chunk_size.to_i, 1].max + @interleave_moe_layer_step = [@interleave_moe_layer_step.to_i, 1].max + end + end + + class ModelArgs < BaseModelArgs + field :model_type, default: "llama4" + field :text_config, default: nil + + def self.from_dict(params) + has_text_config = params.key?("text_config") || params.key?(:text_config) + return super if has_text_config + + new(model_type: params["model_type"] || params[:model_type], text_config: params) + end + + def initialize(**kwargs) + super + @text_config = _to_text_args(@text_config || {}) + end + + private + + def _to_text_args(config) + return config if config.is_a?(TextArgs) + + normalized = {} + config.each { |key, value| normalized[key.to_s] = value } + TextArgs.from_dict(normalized) + end + end + + class Attention < MLX::NN::Module + def initialize(args, layer_idx) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + @use_rope = ((layer_idx + 1) % 4) != 0 + @attn_temperature_tuning = args.attn_temperature_tuning + @floor_scale = args.floor_scale + @attn_scale = args.attn_scale + @use_qk_norm = args.use_qk_norm && @use_rope + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.attention_bias) + + if @use_rope + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + true, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + offset = cache ? 
cache.offset : 0 + if @use_rope + queries = rope.call(queries, offset: offset) + keys = rope.call(keys, offset: offset) + end + + if @use_qk_norm + queries = mx.rms_norm(queries, nil, 1e-6) + keys = mx.rms_norm(keys, nil, 1e-6) + end + + if @attn_temperature_tuning && !@use_rope + attn_scales = (mx.log(mx.floor(mx.arange(offset + 1, offset + l + 1) / @floor_scale) + 1.0) * @attn_scale) + 1.0 + queries = (queries * attn_scales.reshape([l, 1])).astype(queries.dtype) + end + + keys, values = cache.update_and_fetch(keys, values) if cache + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args, intermediate_size = nil) + super() + dim = args.hidden_size + hidden_dim = intermediate_size || args.intermediate_size + + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class MoE < MLX::NN::Module + def initialize(args) + super() + @top_k = args.num_experts_per_tok + raise ArgumentError, "Only 1 expert per token supported" unless @top_k == 1 + + @num_experts = args.num_local_experts + self.experts = SwitchLayers::SwitchGLU.new( + args.hidden_size, + args.intermediate_size, + @num_experts + ) + self.router = MLX::NN::Linear.new(args.hidden_size, @num_experts, bias: false) + self.shared_expert = MLP.new(args) + end + + def call(x) + mx = MLX::Core + logits = router.call(x) + + indices = mx.argpartition(logits * -1.0, @top_k - 1, -1) + take_ids = mx.array((0...@top_k).to_a, dtype: mx.int32) + indices = mx.take(indices, take_ids, -1) + scores = mx.take_along_axis(logits, indices, -1) + scores = mx.sigmoid(scores.astype(mx.float32)).astype(x.dtype) + + out = mx.squeeze(experts.call(x * scores, indices), 2) + out + shared_expert.call(x) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args, layer_idx) + super() + self.self_attn = Attention.new(args, layer_idx) + is_moe_layer = (layer_idx % args.interleave_moe_layer_step) == (args.interleave_moe_layer_step - 1) + if is_moe_layer + self.feed_forward = MoE.new(args) + else + self.feed_forward = MLP.new(args, args.intermediate_size_mlp) + end + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = feed_forward.call(post_attention_layernorm.call(h)) + h + r + end + end + + class LlamaModel < MLX::NN::Module + def initialize(args) + super() + @attention_chunk_size = args.attention_chunk_size + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { |i| TransformerBlock.new(args, i) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + mx = MLX::Core + h = embed_tokens.call(inputs) + layer_cache = cache || Array.new(layers.length) + + if cache + cache.each_with_index do |c, idx| + next unless ((idx + 1) % 4) != 0 + next unless c && c.respond_to?(:maybe_trim_front) + + c.maybe_trim_front + end 
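+ # Layers whose 1-based index is a multiple of 4 use the global causal mask; every other
+ # layer uses the chunked mask built below, which lets a query attend only to keys that are
+ # not in the future and whose floor(position / attention_chunk_size) block matches its own.
+ # Example with attention_chunk_size = 4 and no cache: position 5 may attend to positions
+ # 4 and 5 (same chunk, not in the future) but not to 0..3 (earlier chunk) or 6+ (future).
+ # `start` and `offset` are read from the first layer's (chunked) cache so key/query position
+ # indices stay aligned after ChunkedKVCache trims tokens that fell out of the current chunk.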
+ first_cache = cache[0] + start = first_cache&.respond_to?(:start_position) ? first_cache.start_position : 0 + offset = first_cache&.respond_to?(:offset) ? first_cache.offset : 0 + else + start = 0 + offset = 0 + end + + finish = offset + h.shape[1] + linds = mx.arange(start, finish) + rinds = mx.arange(offset, finish).reshape([h.shape[1], 1]) + + block_pos = mx.abs( + mx.floor_divide(linds, @attention_chunk_size) - + mx.floor_divide(rinds, @attention_chunk_size) + ) + token_pos = mx.less_equal(linds, rinds) + chunk_mask = mx.logical_and(mx.equal(block_pos, 0), token_pos) + global_mask = _create_attention_mask(h, layer_cache[3]) + + layers.each_with_index do |layer, idx| + use_chunked_attention = ((idx + 1) % 4) != 0 + mask = use_chunked_attention ? chunk_mask : global_mask + h = layer.call(h, mask: mask, cache: layer_cache[idx]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache = nil) + return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if h.shape[1] == 1 + + "causal" + end + end + + class LanguageModel < MLX::NN::Module + def initialize(args) + super() + self.model = LlamaModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + + def call(inputs, cache: nil) + lm_head.call(model.call(inputs, cache: cache)) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = LanguageModel.new(args.text_config) + end + + def call(inputs, cache: nil) + language_model.call(inputs, cache: cache) + end + + def sanitize(weights) + mx = MLX::Core + + sanitized = {} + weights.each do |key, value| + next if _multimodal_key?(key) + + sanitized[key] = value + end + + @args.text_config.num_hidden_layers.to_i.times do |layer_idx| + prefix = "language_model.model.layers.#{layer_idx}.feed_forward.experts" + + gate_up = _pop_first( + sanitized, + ["#{prefix}.gate_up_proj", "#{prefix}.gate_up_proj.weight"] + ) + if gate_up + split = gate_up.shape[-1] / 2 + gate_proj, up_proj = mx.split(gate_up, [split], -1) + sanitized["#{prefix}.gate_proj.weight"] = mx.swapaxes(gate_proj, 1, 2) + sanitized["#{prefix}.up_proj.weight"] = mx.swapaxes(up_proj, 1, 2) + end + + down_proj = _pop_first( + sanitized, + ["#{prefix}.down_proj", "#{prefix}.down_proj.weight"] + ) + if down_proj + sanitized["#{prefix}.down_proj.weight"] = mx.swapaxes(down_proj, 1, 2) + end + end + + sanitized + end + + def layers + language_model.model.layers + end + + def make_cache + chunk_size = [@args.text_config.attention_chunk_size.to_i, 1].max + Array.new(layers.length) do |i| + if ((i + 1) % 4) != 0 + MlxLm::ChunkedKVCache.new(chunk_size) + else + MlxLm::KVCache.new + end + end + end + + private + + def _pop_first(weights, keys) + keys.each do |key| + return weights.delete(key) if weights.key?(key) + end + nil + end + + def _multimodal_key?(key) + key_name = key.to_s + key_name.include?("vision_model") || + key_name.include?("vision_tower") || + key_name.include?("multi_modal_projector") + end + end + + Models.register("llama4", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/llama4_text.rb b/lib/mlx_lm/models/llama4_text.rb new file mode 100644 index 0000000..9acb368 --- /dev/null +++ b/lib/mlx_lm/models/llama4_text.rb @@ -0,0 +1,195 @@ +module MlxLm + module Models + module Llama4Text + class ModelArgs < BaseModelArgs + field :model_type, default: "llama4_text" + field :hidden_size + field :num_attention_heads + field 
:num_hidden_layers + field :vocab_size + field :intermediate_size, default: nil + field :intermediate_size_mlp, default: nil + field :num_key_value_heads, default: nil + field :rms_norm_eps, default: 1e-5 + field :rope_theta, default: 10_000.0 + field :head_dim, default: nil + field :tie_word_embeddings, default: true + field :no_rope_layers, default: nil + field :use_qk_norm, default: false + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + @intermediate_size_mlp ||= @intermediate_size + + if @no_rope_layers.nil? + @no_rope_layers = Array.new(@num_hidden_layers, 1) + elsif @no_rope_layers.length != @num_hidden_layers + raise ArgumentError, "`no_rope_layers` length mismatch" + end + end + end + + class Attention < MLX::NN::Module + def initialize(args, use_rope) + super() + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + @use_rope = !!use_rope + @use_qk_norm = !!args.use_qk_norm + @rms_norm_eps = args.rms_norm_eps + + self.q_proj = MLX::NN::Linear.new(args.hidden_size, @n_heads * @head_dim, bias: false) + self.k_proj = MLX::NN::Linear.new(args.hidden_size, @n_kv_heads * @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(args.hidden_size, @n_kv_heads * @head_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, args.hidden_size, bias: false) + + if @use_rope + self.rope = MLX::NN::RoPE.new(@head_dim, traditional: true, base: args.rope_theta) + end + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = queries.reshape([b, l, @n_heads, @head_dim]) + keys = keys.reshape([b, l, @n_kv_heads, @head_dim]) + + if @use_qk_norm + queries = mx.rms_norm(queries, nil, @rms_norm_eps) + keys = mx.rms_norm(keys, nil, @rms_norm_eps) + end + + queries = queries.transpose([0, 2, 1, 3]) + keys = keys.transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if @use_rope + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + end + + keys, values = cache.update_and_fetch(keys, values) if cache + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(dim, intermediate_size) + super() + self.gate_proj = MLX::NN::Linear.new(dim, intermediate_size, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, intermediate_size, bias: false) + self.down_proj = MLX::NN::Linear.new(intermediate_size, dim, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args, use_rope) + super() + self.self_attn = Attention.new(args, use_rope) + self.feed_forward = MLP.new(args.hidden_size, args.intermediate_size_mlp) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r 
= feed_forward.call(post_attention_layernorm.call(h)) + h + r + end + end + + class LanguageModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) do |i| + TransformerBlock.new(args, args.no_rope_layers[i]) + end + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache) + return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if h.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = LanguageModel.new(args) + self.output = nil + unless args.tie_word_embeddings + self.output = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + h = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(h) + else + output.call(h) + end + end + + def sanitize(weights) + sanitized = weights.reject do |k, _| + k.include?("self_attn.rotary_emb.inv_freq") || k.include?("self_attn.rope.inv_freq") + end + if @args.tie_word_embeddings + sanitized.delete("output.weight") + sanitized.delete("lm_head.weight") + end + sanitized + end + + def layers + model.layers + end + end + + Models.register("llama4_text", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/longcat_flash.rb b/lib/mlx_lm/models/longcat_flash.rb new file mode 100644 index 0000000..fcdf8a9 --- /dev/null +++ b/lib/mlx_lm/models/longcat_flash.rb @@ -0,0 +1,153 @@ +require_relative "glm4_moe_lite" + +module MlxLm + module Models + module LongcatFlash + class ModelArgs < Glm4MoeLite::ModelArgs + field :model_type, default: "longcat_flash" + field :hidden_dim, default: nil + field :ffn_hidden_size, default: nil + field :num_layers, default: nil + field :num_heads, default: nil + field :num_kv_heads, default: nil + field :num_experts, default: nil + field :num_local_experts, default: nil + field :num_shared_experts, default: nil + field :top_k, default: nil + field :score_function, default: nil + + def self.from_dict(params) + normalized = params.each_with_object({}) do |(key, value), out| + out[key.to_s] = value + end + + { + "hidden_dim" => "hidden_size", + "ffn_hidden_size" => "intermediate_size", + "num_layers" => "num_hidden_layers", + "num_heads" => "num_attention_heads", + "num_kv_heads" => "num_key_value_heads", + "num_local_experts" => "n_routed_experts", + "num_experts" => "n_routed_experts", + "num_shared_experts" => "n_shared_experts", + "top_k" => "num_experts_per_tok", + "score_function" => "scoring_func", + }.each do |source_key, target_key| + next unless normalized.key?(source_key) + + normalized[target_key] = normalized[source_key] unless normalized.key?(target_key) + end + + normalized["model_type"] ||= "longcat_flash" + super(normalized) + end + + def initialize(**kwargs) + super + @hidden_size = @hidden_dim if kwargs.key?(:hidden_dim) && !kwargs.key?(:hidden_size) && !@hidden_dim.nil? 
+ @intermediate_size = @ffn_hidden_size if kwargs.key?(:ffn_hidden_size) && !kwargs.key?(:intermediate_size) && !@ffn_hidden_size.nil? + @num_hidden_layers = @num_layers if kwargs.key?(:num_layers) && !kwargs.key?(:num_hidden_layers) && !@num_layers.nil? + @num_attention_heads = @num_heads if kwargs.key?(:num_heads) && !kwargs.key?(:num_attention_heads) && !@num_heads.nil? + @num_key_value_heads = @num_kv_heads if kwargs.key?(:num_kv_heads) && !kwargs.key?(:num_key_value_heads) && !@num_kv_heads.nil? + @n_routed_experts = @num_local_experts if kwargs.key?(:num_local_experts) && !kwargs.key?(:n_routed_experts) && !@num_local_experts.nil? + @n_routed_experts = @num_experts if kwargs.key?(:num_experts) && !kwargs.key?(:n_routed_experts) && !kwargs.key?(:num_local_experts) && !@num_experts.nil? + @n_shared_experts = @num_shared_experts if kwargs.key?(:num_shared_experts) && !kwargs.key?(:n_shared_experts) && !@num_shared_experts.nil? + @num_experts_per_tok = @top_k if kwargs.key?(:top_k) && !kwargs.key?(:num_experts_per_tok) && !@top_k.nil? + @scoring_func = @score_function if kwargs.key?(:score_function) && !kwargs.key?(:scoring_func) && !@score_function.nil? + @num_key_value_heads ||= @num_attention_heads + end + + def to_glm4_moe_lite_dict + { + "model_type" => @model_type, + "vocab_size" => @vocab_size, + "hidden_size" => @hidden_size, + "intermediate_size" => @intermediate_size, + "moe_intermediate_size" => @moe_intermediate_size, + "num_hidden_layers" => @num_hidden_layers, + "num_attention_heads" => @num_attention_heads, + "num_key_value_heads" => @num_key_value_heads, + "n_shared_experts" => @n_shared_experts, + "n_routed_experts" => @n_routed_experts, + "routed_scaling_factor" => @routed_scaling_factor, + "kv_lora_rank" => @kv_lora_rank, + "q_lora_rank" => @q_lora_rank, + "qk_rope_head_dim" => @qk_rope_head_dim, + "qk_nope_head_dim" => @qk_nope_head_dim, + "v_head_dim" => @v_head_dim, + "topk_method" => @topk_method, + "scoring_func" => @scoring_func, + "norm_topk_prob" => @norm_topk_prob, + "n_group" => @n_group, + "topk_group" => @topk_group, + "num_experts_per_tok" => @num_experts_per_tok, + "moe_layer_freq" => @moe_layer_freq, + "first_k_dense_replace" => @first_k_dense_replace, + "max_position_embeddings" => @max_position_embeddings, + "rms_norm_eps" => @rms_norm_eps, + "rope_theta" => @rope_theta, + "rope_scaling" => @rope_scaling, + "attention_bias" => @attention_bias, + "attention_dropout" => @attention_dropout, + "partial_rotary_factor" => @partial_rotary_factor, + "tie_word_embeddings" => @tie_word_embeddings, + "num_nextn_predict_layers" => @num_nextn_predict_layers, + "quantization" => @quantization, + } + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.wrapped_model = Glm4MoeLite::Model.new( + Glm4MoeLite::ModelArgs.from_dict(args.to_glm4_moe_lite_dict) + ) + end + + def call(inputs, cache: nil) + wrapped_model.call(inputs, cache: cache) + end + + def sanitize(weights) + remapped = {} + flat_weights = weights.is_a?(Hash) ? 
weights : weights.to_h + flat_weights.each do |key, value| + remapped[_remap_weight_key(key)] = value + end + wrapped_model.sanitize(remapped) + end + + def layers + wrapped_model.layers + end + + def make_cache + return wrapped_model.make_cache if wrapped_model.respond_to?(:make_cache) + + nil + end + + def cast_predicate + return wrapped_model.cast_predicate if wrapped_model.respond_to?(:cast_predicate) + + lambda { |_key| true } + end + + private + + def _remap_weight_key(key) + mapped = key.dup + mapped = mapped.gsub(".attention.", ".self_attn.") + mapped = mapped.gsub(".block_sparse_moe.", ".mlp.") + mapped = mapped.gsub(".mlp.router.", ".mlp.gate.") + mapped + end + end + + Models.register("longcat_flash", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/longcat_flash_ngram.rb b/lib/mlx_lm/models/longcat_flash_ngram.rb new file mode 100644 index 0000000..9e8a6f7 --- /dev/null +++ b/lib/mlx_lm/models/longcat_flash_ngram.rb @@ -0,0 +1,137 @@ +require_relative "longcat_flash" + +module MlxLm + module Models + module LongcatFlashNgram + class ModelArgs < LongcatFlash::ModelArgs + field :model_type, default: "longcat_flash_ngram" + field :attention_method, default: nil + field :zero_expert_type, default: "identity" + field :moe_topk, default: nil + field :expert_ffn_hidden_size, default: nil + field :zero_expert_num, default: nil + field :num_layers, default: nil + field :ngram_vocab_size_ratio, default: 78 + field :emb_neighbor_num, default: 4 + field :emb_split_num, default: 4 + field :mla_scale_q_lora, default: nil + field :mla_scale_kv_lora, default: nil + field :router_bias, default: false + + def self.from_dict(params) + normalized = params.each_with_object({}) do |(key, value), out| + out[key.to_s] = value + end + + { + "num_layers" => "num_hidden_layers", + "moe_topk" => "num_experts_per_tok", + "expert_ffn_hidden_size" => "moe_intermediate_size", + }.each do |source_key, target_key| + next unless normalized.key?(source_key) + + normalized[target_key] = normalized[source_key] unless normalized.key?(target_key) + end + + if normalized.key?("n_routed_experts") && normalized.key?("zero_expert_num") && !normalized.key?("num_local_experts") + normalized["num_local_experts"] = normalized["n_routed_experts"].to_i + normalized["zero_expert_num"].to_i + end + + if normalized.key?("num_attention_heads") && !normalized.key?("num_key_value_heads") && !normalized.key?("num_kv_heads") + normalized["num_key_value_heads"] = normalized["num_attention_heads"] + end + + normalized["model_type"] ||= "longcat_flash_ngram" + super(normalized) + end + + def initialize(**kwargs) + super + @num_hidden_layers = @num_layers if kwargs.key?(:num_layers) && !kwargs.key?(:num_hidden_layers) && !@num_layers.nil? + @num_experts_per_tok = @moe_topk if kwargs.key?(:moe_topk) && !kwargs.key?(:num_experts_per_tok) && !@moe_topk.nil? + @moe_intermediate_size = @expert_ffn_hidden_size if kwargs.key?(:expert_ffn_hidden_size) && !kwargs.key?(:moe_intermediate_size) && !@expert_ffn_hidden_size.nil? + + if kwargs.key?(:zero_expert_num) && !@zero_expert_num.nil? && !kwargs.key?(:num_local_experts) && !kwargs.key?(:n_routed_experts) && !@n_routed_experts.nil? 
+ @n_routed_experts = @n_routed_experts.to_i + @zero_expert_num.to_i + end + + if kwargs.key?(:num_attention_heads) && !kwargs.key?(:num_key_value_heads) && !kwargs.key?(:num_kv_heads) + @num_key_value_heads = @num_attention_heads + end + + @num_key_value_heads ||= @num_attention_heads + end + + def to_longcat_flash_dict + routed_experts = @n_routed_experts + if !@zero_expert_num.nil? && !routed_experts.nil? + routed_experts = routed_experts.to_i + @zero_expert_num.to_i + end + + dict = to_glm4_moe_lite_dict + dict["model_type"] = @model_type + dict["n_routed_experts"] = routed_experts unless routed_experts.nil? + dict + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.wrapped_model = LongcatFlash::Model.new( + LongcatFlash::ModelArgs.from_dict(args.to_longcat_flash_dict) + ) + end + + def call(inputs, cache: nil) + wrapped_model.call(inputs, cache: cache) + end + + def sanitize(weights) + remapped = {} + flat_weights = weights.is_a?(Hash) ? weights : weights.to_h + flat_weights.each do |key, value| + remapped[_to_longcat_flash_key(key)] = value + end + + sanitized = wrapped_model.sanitize(remapped) + restored = {} + sanitized.each do |key, value| + restored[_from_longcat_flash_key(key)] = value + end + restored + end + + def layers + wrapped_model.layers + end + + def make_cache + return wrapped_model.make_cache if wrapped_model.respond_to?(:make_cache) + + nil + end + + def cast_predicate + return wrapped_model.cast_predicate if wrapped_model.respond_to?(:cast_predicate) + + lambda { |_key| true } + end + + private + + def _to_longcat_flash_key(key) + key.to_s.gsub("model.ngram_embeddings.word_embeddings.", "model.embed_tokens.") + end + + def _from_longcat_flash_key(key) + key.to_s.gsub("model.embed_tokens.", "model.ngram_embeddings.word_embeddings.") + end + end + + Models.register("longcat_flash_ngram", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/mamba.rb b/lib/mlx_lm/models/mamba.rb new file mode 100644 index 0000000..3981db1 --- /dev/null +++ b/lib/mlx_lm/models/mamba.rb @@ -0,0 +1,301 @@ +require_relative "activations" +require_relative "cache" + +module MlxLm + module Models + module Mamba + class ModelArgs < BaseModelArgs + field :model_type, default: "mamba" + field :vocab_size + field :hidden_size, default: nil + field :intermediate_size, default: nil + field :state_size, default: nil + field :num_hidden_layers, default: nil + field :conv_kernel, default: nil + field :use_bias, default: nil + field :use_conv_bias, default: nil + field :time_step_rank, default: "auto" + field :tie_word_embeddings, default: true + field :use_bcdt_rms, default: false + field :mixer_rms_eps, default: 1e-6 + + field :d_model, default: nil + field :d_inner, default: nil + field :d_state, default: nil + field :n_layer, default: nil + field :n_layers, default: nil + field :d_conv, default: nil + field :bias, default: nil + field :conv_bias, default: nil + + def initialize(**kwargs) + super + + @hidden_size ||= @d_model + @intermediate_size ||= @d_inner + @state_size ||= @d_state + @num_hidden_layers ||= @n_layer + @num_hidden_layers ||= @n_layers + @conv_kernel ||= @d_conv + @use_bias = @bias if @use_bias.nil? + @use_conv_bias = @conv_bias if @use_conv_bias.nil? 
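+ # The assignments above map original Mamba config names (d_model, d_inner, d_state,
+ # n_layer/n_layers, d_conv, bias, conv_bias) onto the HF-style fields; the statements below
+ # resolve an "auto" time_step_rank as ceil(hidden_size / 16), force use_bcdt_rms for
+ # falcon_mamba checkpoints, and fill any fields still nil with fixed fallback defaults.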
+ + @time_step_rank = (@hidden_size.to_f / 16.0).ceil if @time_step_rank == "auto" + @use_bcdt_rms = true if @model_type == "falcon_mamba" + + @hidden_size ||= 768 + @intermediate_size ||= 1536 + @state_size ||= 16 + @num_hidden_layers ||= 24 + @conv_kernel ||= 4 + @use_bias = true if @use_bias.nil? + @use_conv_bias = true if @use_conv_bias.nil? + end + end + + class MambaBlock < MLX::NN::Module + def initialize(args) + super() + + @hidden_size = args.hidden_size + @ssm_state_size = args.state_size + @conv_kernel_size = args.conv_kernel + @intermediate_size = args.intermediate_size + @time_step_rank = args.time_step_rank.to_i + @use_conv_bias = args.use_conv_bias + @use_bcdt_rms = args.use_bcdt_rms + @mixer_rms_eps = args.mixer_rms_eps + + self.in_proj = MLX::NN::Linear.new( + @hidden_size, + @intermediate_size * 2, + bias: args.use_bias + ) + + self.conv1d = MLX::NN::Conv1d.new( + @intermediate_size, + @intermediate_size, + @conv_kernel_size, + groups: @intermediate_size, + bias: @use_conv_bias, + padding: 0 + ) + + self.x_proj = MLX::NN::Linear.new( + @intermediate_size, + @time_step_rank + 2 * @ssm_state_size, + bias: false + ) + self.dt_proj = MLX::NN::Linear.new(@time_step_rank, @intermediate_size, bias: true) + + mx = MLX::Core + a = mx.repeat( + mx.arange(1.0, @ssm_state_size + 1.0, 1.0).reshape([1, @ssm_state_size]), + @intermediate_size, + 0 + ) + self.a_log = mx.log(a) + self.d = mx.ones([@intermediate_size]) + + self.out_proj = MLX::NN::Linear.new( + @intermediate_size, + @hidden_size, + bias: args.use_bias + ) + end + + def call(x, cache) + if cache.nil? + conv_cache = nil + state_cache = nil + else + conv_cache = cache[0] + state_cache = cache[1] + end + + output, new_conv_cache, new_state_cache = _process_sequence(x, conv_cache, state_cache) + + if cache.is_a?(MlxLm::ArraysCache) + cache[0] = new_conv_cache + cache[1] = new_state_cache + end + + output + end + + def ssm_step(x, a, state = nil) + mx = MLX::Core + + delta_bc = x_proj.call(x) + delta, b, c = mx.split( + delta_bc, + [@time_step_rank, @time_step_rank + @ssm_state_size], + -1 + ) + + if @use_bcdt_rms + delta = _rms_norm(delta, eps: @mixer_rms_eps) + b = _rms_norm(b, eps: @mixer_rms_eps) + c = _rms_norm(c, eps: @mixer_rms_eps) + end + + delta = MLX::NN.softplus(dt_proj.call(delta)) + new_state = mx.expand_dims(delta * x, -1) * mx.expand_dims(b, 1) + + unless state.nil? + new_state = new_state + state * mx.exp(mx.expand_dims(delta, -1) * a) + end + + y = mx.squeeze(mx.matmul(new_state, mx.expand_dims(c, -1)), 2) + y = y + d * x + + [y, new_state] + end + + private + + def _process_sequence(x, conv_cache, state_cache) + mx = MLX::Core + + xz = in_proj.call(x) + x_part, z = mx.split(xz, 2, -1) + + if conv_cache.nil? 
+ x_full = mx.pad( + x_part, + [ + [0, 0], + [@conv_kernel_size - 1, 0], + [0, 0], + ] + ) + else + x_full = mx.concatenate([conv_cache, x_part], 1) + end + + conv_out = conv1d.call(x_full) + + n_keep = @conv_kernel_size - 1 + new_conv_cache = if n_keep > 0 + split_at = x_full.shape[1] - n_keep + mx.split(x_full, [split_at], 1)[1] + else + mx.zeros([x_full.shape[0], 0, x_full.shape[2]], x_full.dtype) + end + + x_part = MLX::NN.silu(conv_out) + a = mx.multiply(-1.0, mx.exp(a_log)) + + current_state = state_cache + ys = [] + x_part.shape[1].times do |t| + x_t = _slice_step(x_part, t) + y_t, current_state = ssm_step(x_t, a, current_state) + ys << y_t + end + + y = mx.stack(ys, 1) + out = out_proj.call(Activations.swiglu(z, y)) + + [out, new_conv_cache, current_state] + end + + def _slice_step(array, idx) + mx = MLX::Core + tail = idx.zero? ? array : mx.split(array, [idx], 1)[1] + mx.squeeze(mx.split(tail, [1], 1)[0], 1) + end + + def _rms_norm(x, eps:) + mx = MLX::Core + variance = mx.mean(mx.square(x), -1, true) + x * mx.rsqrt(variance + eps) + end + end + + class ResidualBlock < MLX::NN::Module + def initialize(args) + super() + self.mixer = MambaBlock.new(args) + self.norm = MLX::NN::RMSNorm.new(args.hidden_size) + end + + def call(x, cache) + mixer.call(norm.call(x), cache) + x + end + end + + class MambaModel < MLX::NN::Module + def initialize(args) + super() + self.embeddings = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { ResidualBlock.new(args) } + self.norm_f = MLX::NN::RMSNorm.new(args.hidden_size) + end + + def call(x, cache) + hidden = embeddings.call(x) + layer_cache = cache || [nil] * layers.length + + layers.each_with_index do |layer, i| + hidden = layer.call(hidden, layer_cache[i]) + end + + norm_f.call(hidden) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + self.args = args + self.model_type = args.model_type + self.backbone = MambaModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings + end + + def call(inputs, cache: nil) + hidden = backbone.call(inputs, cache) + + if args.tie_word_embeddings + backbone.embeddings.as_linear(hidden) + else + lm_head.call(hidden) + end + end + + def make_cache + Array.new(layers.length) { MlxLm::ArraysCache.new(2) } + end + + def layers + backbone.layers + end + + def sanitize(weights) + sanitized = {} + weights.each do |name, param| + current = param + if name.include?("conv1d.weight") && _transpose_conv_weight?(param) + current = MLX::Core.swapaxes(param, 1, 2) + end + sanitized[name] = current + end + sanitized + end + + private + + def _transpose_conv_weight?(param) + return false unless param.respond_to?(:shape) + return false unless param.shape.is_a?(Array) + return false unless param.shape.length >= 3 + + param.shape[-1] != 1 + end + end + + Models.register("mamba", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/mamba2.rb b/lib/mlx_lm/models/mamba2.rb new file mode 100644 index 0000000..fd2b190 --- /dev/null +++ b/lib/mlx_lm/models/mamba2.rb @@ -0,0 +1,292 @@ +require_relative "activations" +require_relative "cache" +require_relative "ssm" + +module MlxLm + module Models + module Mamba2 + class ModelArgs < BaseModelArgs + field :model_type, default: "mamba2" + field :num_heads + field :head_dim + field :vocab_size + field :hidden_size + field :intermediate_size, default: nil + field :state_size + field :num_hidden_layers + field 
:layer_norm_epsilon, default: 1e-6 + field :conv_kernel + field :n_groups + field :use_bias, default: true + field :use_conv_bias, default: true + field :tie_word_embeddings, default: true + field :time_step_limit, default: [0.001, 100.0] + field :time_step_rank, default: "auto" + field :ssm_state_size, default: nil + field :max_position_embeddings, default: 2056 + + def initialize(**kwargs) + super + + @time_step_rank = (@hidden_size.to_f / 16.0).ceil if @time_step_rank == "auto" + @ssm_state_size ||= @state_size + @intermediate_size ||= @num_heads * @head_dim + end + end + + class MambaRMSNormGated < MLX::NN::Module + def initialize(hidden_size, eps: 1e-6) + super() + @eps = eps + self.weight = MLX::Core.ones([hidden_size]) + end + + def call(hidden_states, gate = nil) + hidden_states = Activations.swiglu(gate, hidden_states) unless gate.nil? + MLX::Core.rms_norm(hidden_states, weight, @eps) + end + end + + class Mamba2Block < MLX::NN::Module + def initialize(args, layer_idx) + super() + + _ = layer_idx + @num_heads = args.num_heads + @hidden_size = args.hidden_size + @ssm_state_size = args.ssm_state_size + @conv_kernel_size = args.conv_kernel + @intermediate_size = args.num_heads * args.head_dim + @n_groups = args.n_groups + @head_dim = args.head_dim + @time_step_limit = args.time_step_limit + @heads_per_group = @num_heads / @n_groups + + @conv_dim = @intermediate_size + 2 * @n_groups * @ssm_state_size + + self.conv1d = MLX::NN::Conv1d.new( + @conv_dim, + @conv_dim, + args.conv_kernel, + padding: 0, + groups: @conv_dim, + bias: args.use_conv_bias + ) + + projection_size = @intermediate_size + @conv_dim + @num_heads + self.in_proj = MLX::NN::Linear.new(@hidden_size, projection_size, bias: args.use_bias) + + mx = MLX::Core + self.dt_bias = mx.ones([@num_heads]) + self.a_log = mx.log(mx.arange(1, @num_heads + 1, 1, mx.float32)) + self.d = mx.ones([@num_heads]) + + self.norm = MambaRMSNormGated.new(@intermediate_size, eps: args.layer_norm_epsilon) + self.out_proj = MLX::NN::Linear.new(@intermediate_size, @hidden_size, bias: args.use_bias) + end + + def call(hidden_states, mask, cache = nil) + mx = MLX::Core + + projected = in_proj.call(hidden_states) + gate, conv_input, dt = mx.split( + projected, + [@intermediate_size, @intermediate_size + @conv_dim], + -1 + ) + + conv_output = _conv(conv_input, cache, mask) + ssm_hidden, b, c = mx.split( + conv_output, + [@intermediate_size, @intermediate_size + @n_groups * @ssm_state_size], + -1 + ) + + y = _ssm(ssm_hidden, b, c, dt, cache, mask: mask) + cache.advance(y.shape[1]) if cache + + out_proj.call(norm.call(y, gate)) + end + + private + + def _conv(conv_input, cache, mask) + mx = MLX::Core + + conv_input = mx.where(mx.expand_dims(mask, -1), conv_input, 0) unless mask.nil? + + if cache + conv_state = if cache[0].nil? 
+ mx.zeros( + [conv_input.shape[0], @conv_kernel_size - 1, @conv_dim], + conv_input.dtype + ) + else + cache[0] + end + + padded_input = mx.concatenate([conv_state, conv_input], 1) + n_keep = @conv_kernel_size - 1 + + if cache.lengths + t = padded_input.shape[1] + ends = mx.clip(cache.lengths, 0, t - n_keep) + positions = mx.expand_dims( + mx.expand_dims(ends, 1) + mx.arange(n_keep), + -1 + ) + cache[0] = mx.take_along_axis(padded_input, positions, 1) + else + if n_keep > 0 + split_at = padded_input.shape[1] - n_keep + cache[0] = mx.split(padded_input, [split_at], 1)[1] + else + cache[0] = mx.zeros([padded_input.shape[0], 0, padded_input.shape[2]], padded_input.dtype) + end + end + else + padded_input = mx.pad( + conv_input, + [ + [0, 0], + [@conv_kernel_size - 1, 0], + [0, 0], + ] + ) + end + + MLX::NN.silu(conv1d.call(padded_input)) + end + + def _ssm(hidden_states, b, c, dt, cache, mask:) + batch_size, seq_len, = hidden_states.shape + hidden_states = hidden_states.reshape( + [batch_size, seq_len, @num_heads, @head_dim] + ) + b = b.reshape([batch_size, seq_len, @n_groups, @ssm_state_size]) + c = c.reshape([batch_size, seq_len, @n_groups, @ssm_state_size]) + + if cache + state = cache[1] + lengths = cache.lengths + else + state = nil + lengths = nil + end + + y, state = SSM.ssm_update( + hidden_states, + a_log, + b, + c, + d, + dt, + dt_bias, + state: state, + time_step_limit: @time_step_limit, + mask: mask, + lengths: lengths + ) + + cache[1] = state if cache + y.reshape([batch_size, seq_len, @intermediate_size]) + end + end + + class ResidualBlock < MLX::NN::Module + def initialize(args, layer_idx) + super() + self.mixer = Mamba2Block.new(args, layer_idx) + self.norm = MLX::NN::RMSNorm.new(args.hidden_size) + end + + def call(x, mask, cache = nil) + mixer.call(norm.call(x), mask, cache) + x + end + end + + class Mamba2Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.embeddings = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { |i| ResidualBlock.new(args, i) } + self.norm_f = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon) + end + + def call(x, cache = nil) + hidden = embeddings.call(x) + layer_cache = cache || [nil] * layers.length + + mask = _create_ssm_mask(hidden, layer_cache[0]) + layers.each_with_index do |layer, i| + hidden = layer.call(hidden, mask, layer_cache[i]) + end + + norm_f.call(hidden) + end + + private + + def _create_ssm_mask(hidden, cache) + return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask) + + nil + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + self.args = args + self.model_type = args.model_type + self.backbone = Mamba2Model.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings + end + + def call(inputs, cache: nil) + hidden = backbone.call(inputs, cache) + + if args.tie_word_embeddings + backbone.embeddings.as_linear(hidden) + else + lm_head.call(hidden) + end + end + + def make_cache(batch_size: 1) + _ = batch_size + Array.new(args.num_hidden_layers) { MlxLm::ArraysCache.new(2) } + end + + def layers + backbone.layers + end + + def sanitize(weights) + sanitized = {} + weights.each do |name, param| + current = param + if name.include?("conv1d.weight") && _transpose_conv_weight?(param) + current = MLX::Core.swapaxes(param, 1, 2) + end + sanitized[name] = current + end + sanitized + end + + private + + def 
_transpose_conv_weight?(param) + return false unless param.respond_to?(:shape) + return false unless param.shape.is_a?(Array) + return false unless param.shape.length >= 3 + + param.shape[-1] != 1 + end + end + + Models.register("mamba2", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/mimo.rb b/lib/mlx_lm/models/mimo.rb new file mode 100644 index 0000000..7883cae --- /dev/null +++ b/lib/mlx_lm/models/mimo.rb @@ -0,0 +1,174 @@ +module MlxLm + module Models + module Mimo + class ModelArgs < BaseModelArgs + field :model_type, default: "mimo" + field :hidden_size + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :rms_norm_eps + field :vocab_size + field :num_key_value_heads, default: nil + field :max_position_embeddings, default: 32_768 + field :rope_theta, default: 10_000.0 + field :rope_traditional, default: false + field :rope_scaling, default: nil + field :tie_word_embeddings, default: false + field :num_nextn_predict_layers, default: 2 + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = dim / @n_heads + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: true) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: true) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: true) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false) + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(dim, hidden_dim) + super() + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args.hidden_size, args.intermediate_size) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + 
r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class MimoModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache) + return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if h.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = MimoModel.new(args) + self.lm_head = nil + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + result = weights.dup + result.delete("lm_head.weight") if @args.tie_word_embeddings + result.reject do |k, _| + k.include?("self_attn.rotary_emb.inv_freq") || + k.start_with?("model.mtp_layers.") + end + end + + def layers + model.layers + end + end + + Models.register("mimo", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/mimo_v2_flash.rb b/lib/mlx_lm/models/mimo_v2_flash.rb new file mode 100644 index 0000000..af52737 --- /dev/null +++ b/lib/mlx_lm/models/mimo_v2_flash.rb @@ -0,0 +1,491 @@ +require_relative "activations" +require_relative "cache" +require_relative "rope_utils" +require_relative "switch_layers" + +module MlxLm + module Models + module MimoV2Flash + class ModelArgs < BaseModelArgs + field :model_type, default: "mimo_v2_flash" + field :num_experts_per_tok, default: 1 + field :hybrid_layer_pattern, default: nil + field :moe_layer_freq, default: nil + field :add_swa_attention_sink_bias, default: false + field :add_full_attention_sink_bias, default: false + field :sliding_window_size, default: 4096 + field :vocab_size + field :hidden_size + field :intermediate_size + field :moe_intermediate_size + field :num_hidden_layers + field :num_attention_heads + field :num_key_value_heads, default: nil + field :n_shared_experts, default: nil + field :n_routed_experts, default: nil + field :routed_scaling_factor, default: nil + field :topk_method, default: "noaux_tc" + field :scoring_func, default: "sigmoid" + field :norm_topk_prob, default: false + field :n_group, default: 1 + field :topk_group, default: 1 + field :max_position_embeddings, default: 32768 + field :layernorm_epsilon, default: 1e-6 + field :rope_theta, default: 10_000.0 + field :swa_rope_theta, default: nil + field :swa_num_attention_heads, default: nil + field :swa_num_key_value_heads, default: nil + field :head_dim, default: nil + field :v_head_dim, default: nil + field :swa_head_dim, default: nil + field :swa_v_head_dim, default: nil + field :partial_rotary_factor, default: 1.0 + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @swa_num_attention_heads ||= @num_attention_heads + @swa_num_key_value_heads ||= 
@num_key_value_heads + + @head_dim ||= @hidden_size / @num_attention_heads + @v_head_dim ||= @head_dim + @swa_head_dim ||= @head_dim + @swa_v_head_dim ||= @swa_head_dim + @swa_rope_theta ||= @rope_theta + + @n_routed_experts ||= 1 + @routed_scaling_factor = 1.0 if @routed_scaling_factor.nil? + @hybrid_layer_pattern ||= Array.new(@num_hidden_layers, 0) + @moe_layer_freq ||= Array.new(@num_hidden_layers, 0) + @topk_group ||= @n_group + end + end + + class Attention < MLX::NN::Module + def initialize(args, is_sliding_window) + super() + + dim = args.hidden_size + @is_sliding_window = is_sliding_window + if @is_sliding_window + @n_heads = args.swa_num_attention_heads + @n_kv_heads = args.swa_num_key_value_heads + @has_sinks = args.add_swa_attention_sink_bias + @head_dim = args.swa_head_dim + @v_head_dim = args.swa_v_head_dim + rope_theta = args.swa_rope_theta + else + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @has_sinks = args.add_full_attention_sink_bias + @head_dim = args.head_dim + @v_head_dim = args.v_head_dim + rope_theta = args.rope_theta + end + + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @v_head_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@n_heads * @v_head_dim, dim, bias: false) + self.attention_sink_bias = if @has_sinks + MLX::Core.ones([@n_heads]) + else + nil + end + + rotary_dim = [(@head_dim * args.partial_rotary_factor.to_f).to_i, 1].max + self.rope = MLX::NN::RoPE.new( + rotary_dim, + traditional: false, + base: rope_theta + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @v_head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = _scaled_dot_product_attention(queries, keys, values, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @v_head_dim]) + o_proj.call(output) + end + + private + + def _scaled_dot_product_attention(queries, keys, values, mask) + mx = MLX::Core + + if attention_sink_bias + begin + return mx.scaled_dot_product_attention( + queries, + keys, + values, + @scale, + mask, + sinks: attention_sink_bias + ) + rescue StandardError + # Fallback when sinks are unsupported by the local MLX runtime. 
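+ # When the `sinks:` keyword raises, execution falls through to the plain call below,
+ # so the attention sink bias is simply ignored on runtimes that do not support it.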
+ end + end + + mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + end + end + + class MLP < MLX::NN::Module + def initialize(config, hidden_size: nil, intermediate_size: nil) + super() + @hidden_size = hidden_size || config.hidden_size + @intermediate_size = intermediate_size || config.intermediate_size + + self.gate_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: false) + self.up_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: false) + self.down_proj = MLX::NN::Linear.new(@intermediate_size, @hidden_size, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + module_function + + def group_expert_select( + gates, + e_score_correction_bias, + top_k, + n_group, + topk_group, + routed_scaling_factor, + norm_topk_prob + ) + mx = MLX::Core + + scores = mx.sigmoid(gates.astype(mx.float32)) + orig_scores = scores + scores = scores + e_score_correction_bias + + if n_group.to_i > 1 + experts_per_group = scores.shape[-1] / n_group + scores = mx.unflatten(scores, -1, [n_group, experts_per_group]) + group_scores = mx.topk(scores, 2, -1) + group_scores = mx.expand_dims(mx.sum(group_scores, -1), -1) + + drop_count = n_group - topk_group.to_i + if drop_count > 0 + group_idx = mx.argpartition(group_scores, drop_count - 1, -2) + take_ids = mx.array((0...drop_count).to_a, dtype: mx.int32) + group_idx = mx.take(group_idx, take_ids, -2) + scores = mx.put_along_axis( + scores, + mx.stop_gradient(group_idx), + mx.array(0.0), + -2 + ) + end + + scores = mx.flatten(scores, -2, -1) + end + + k = [top_k.to_i, scores.shape[-1]].min + inds = mx.argpartition(scores * -1.0, k - 1, -1) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + + selected_scores = mx.take_along_axis(orig_scores, inds, -1) + if k > 1 && norm_topk_prob + denominator = mx.expand_dims(mx.sum(selected_scores, -1), -1) + selected_scores = selected_scores / (denominator + 1e-20) + end + + selected_scores = selected_scores * routed_scaling_factor.to_f + [inds, selected_scores] + end + + class MoEGate < MLX::NN::Module + def initialize(config) + super() + @top_k = config.num_experts_per_tok + @norm_topk_prob = config.norm_topk_prob + @n_routed_experts = config.n_routed_experts + @routed_scaling_factor = config.routed_scaling_factor || 1.0 + @n_group = config.n_group + @topk_group = config.topk_group + + raise ArgumentError, "Unsupported topk method: #{config.topk_method}" unless config.topk_method == "noaux_tc" + + mx = MLX::Core + self.weight = mx.zeros([@n_routed_experts, config.hidden_size]) + self.e_score_correction_bias = mx.zeros([@n_routed_experts]) + end + + def call(x) + mx = MLX::Core + gates = mx.matmul(x, mx.transpose(weight)) + MimoV2Flash.group_expert_select( + gates, + e_score_correction_bias, + @top_k, + @n_group, + @topk_group, + @routed_scaling_factor, + @norm_topk_prob + ) + end + end + + class MoE < MLX::NN::Module + def initialize(config) + super() + @config = config + + self.switch_mlp = SwitchLayers::SwitchGLU.new( + config.hidden_size, + config.moe_intermediate_size, + config.n_routed_experts + ) + + self.gate = MoEGate.new(config) + if config.n_shared_experts + shared_intermediate = config.moe_intermediate_size * config.n_shared_experts + self.shared_experts = MLP.new(config, intermediate_size: shared_intermediate) + end + end + + def call(x) + mx = MLX::Core + inds, scores = gate.call(x) + y = switch_mlp.call(x, inds) + y = mx.sum(y * mx.expand_dims(scores, -1), 
-2).astype(y.dtype) + y = y + shared_experts.call(x) if @config.n_shared_experts + y + end + end + + class DecoderLayer < MLX::NN::Module + attr_reader :is_sliding_window + + def initialize(config, is_moe, is_sliding_window) + super() + @is_sliding_window = is_sliding_window + + self.self_attn = Attention.new(config, is_sliding_window) + self.mlp = is_moe ? MoE.new(config) : MLP.new(config) + self.input_layernorm = MLX::NN::RMSNorm.new( + config.hidden_size, + eps: config.layernorm_epsilon + ) + self.post_attention_layernorm = MLX::NN::RMSNorm.new( + config.hidden_size, + eps: config.layernorm_epsilon + ) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class LanguageModel < MLX::NN::Module + def initialize(config) + super() + @hybrid_layer_pattern = config.hybrid_layer_pattern + @sliding_window_size = config.sliding_window_size + + self.embed_tokens = MLX::NN::Embedding.new(config.vocab_size, config.hidden_size) + self.layers = Array.new(config.num_hidden_layers) do |idx| + DecoderLayer.new( + config, + config.moe_layer_freq[idx] == 1, + config.hybrid_layer_pattern[idx] == 1 + ) + end + self.norm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.layernorm_epsilon) + self.swa_idx = @hybrid_layer_pattern.index(1) || 0 + self.ga_idx = @hybrid_layer_pattern.index(0) || 0 + end + + def call(x, cache: nil) + h = embed_tokens.call(x) + layer_cache = cache || [nil] * layers.length + + full_mask = _create_attention_mask(h, layer_cache[ga_idx]) + swa_mask = _create_attention_mask( + h, + layer_cache[swa_idx], + window_size: @sliding_window_size + ) + + layers.each_with_index do |layer, i| + mask = layer.is_sliding_window ? 
swa_mask : full_mask + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache = nil, window_size: nil) + n = h.shape[1] + if cache && cache.respond_to?(:make_mask) + return cache.make_mask(n, window_size: window_size) + end + + if window_size + offset = 0 + if cache + offset = cache.offset if cache.respond_to?(:offset) + if cache.instance_variable_defined?(:@max_size) + max_size = cache.instance_variable_get(:@max_size) + offset = [max_size - 1, offset].min if max_size && max_size > 0 + end + end + return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size + end + + return nil if n == 1 + + "causal" + end + + def _create_causal_mask(n, offset: 0, window_size: nil) + mx = MLX::Core + rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n]) + linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1]) + + mask = mx.greater_equal(linds, rinds) + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + mask + end + end + + class Model < MLX::NN::Module + def initialize(config) + super() + @args = config + self.model_type = config.model_type + self.model = LanguageModel.new(config) + self.lm_head = MLX::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false) + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + lm_head.call(out) + end + + def sanitize(weights) + mx = MLX::Core + new_weights = {} + + weights.each do |k, v| + if k.include?("weight_scale_inv") + wk = k.sub("_scale_inv", "") + if weights.key?(wk) + new_weights[wk] = _dequant(weights[wk], v) + end + elsif !new_weights.key?(k) + new_weights[k] = v + end + end + + result = new_weights + @args.num_hidden_layers.times do |layer_idx| + prefix = "model.layers.#{layer_idx}" + %w[gate_proj down_proj up_proj].each do |proj| + %w[weight scales biases].each do |param| + first_key = "#{prefix}.mlp.experts.0.#{proj}.#{param}" + next unless result.key?(first_key) + + expert_keys = (0...@args.n_routed_experts).map do |expert_idx| + "#{prefix}.mlp.experts.#{expert_idx}.#{proj}.#{param}" + end + next unless expert_keys.all? 
{ |key| result.key?(key) } + + stacked = expert_keys.map { |key| result.delete(key) } + result["#{prefix}.mlp.switch_mlp.#{proj}.#{param}"] = mx.stack(stacked) + end + end + end + + result.reject { |k, _| k.start_with?("model.mtp") } + end + + def layers + model.layers + end + + def cast_predicate + lambda { |k| !k.include?("e_score_correction_bias") } + end + + def make_cache + layers.map do |layer| + if layer.is_sliding_window + MlxLm::RotatingKVCache.new(max_size: @args.sliding_window_size) + else + MlxLm::KVCache.new + end + end + end + + private + + def _dequant(weight, scale_inv) + mx = MLX::Core + dtype = mx.bfloat16 + block_size = 128 + + dequantized = mx.from_fp8(weight, dtype: dtype) + m, n = dequantized.shape + pad_bottom = block_size * scale_inv.shape[0] - m + pad_side = block_size * scale_inv.shape[1] - n + + dequantized = mx.pad(dequantized, [[0, pad_bottom], [0, pad_side]]) + dequantized = dequantized.reshape([ + (m + pad_bottom) / block_size, + block_size, + (n + pad_side) / block_size, + block_size, + ]) + + scaled = dequantized * scale_inv.reshape([scale_inv.shape[0], 1, scale_inv.shape[1], 1]) + scaled = scaled.reshape([m + pad_bottom, n + pad_side]) + scaled = mx.split(scaled, [m], 0)[0] + scaled = mx.split(scaled, [n], 1)[0] + scaled.astype(dtype) + rescue StandardError + weight + end + end + + Models.register("mimo_v2_flash", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/minicpm.rb b/lib/mlx_lm/models/minicpm.rb new file mode 100644 index 0000000..db13654 --- /dev/null +++ b/lib/mlx_lm/models/minicpm.rb @@ -0,0 +1,169 @@ +module MlxLm + module Models + module MiniCPM + class ModelArgs < BaseModelArgs + field :model_type, default: "minicpm" + field :hidden_size + field :dim_model_base + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :rms_norm_eps + field :vocab_size + field :num_key_value_heads + field :scale_depth + field :scale_emb + field :max_position_embeddings, default: nil + field :rope_theta, default: 1_000_000.0 + field :rope_traditional, default: false + field :rope_scaling, default: nil + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + self.gate_proj = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false) + self.up_proj = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false) + self.down_proj = MLX::NN::Linear.new(args.intermediate_size, args.hidden_size, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = dim / @n_heads + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false) + + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + args.rope_traditional, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) 
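+ # Project hidden states into per-head queries/keys/values; the reshape/transpose below puts them in (batch, heads, seq, head_dim) layout before RoPE and attention.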
+ keys = k_proj.call(x) + values = v_proj.call(x) + + queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class DecoderLayer < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + @residual_scale = args.scale_depth / Math.sqrt(args.num_hidden_layers) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r * @residual_scale + r = mlp.call(post_attention_layernorm.call(h)) + h + r * @residual_scale + end + end + + class MiniCPMModel < MLX::NN::Module + def initialize(args) + super() + @args = args + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) * @args.scale_emb + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = MiniCPMModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings + end + + def call(inputs, cache: nil) + mx = MLX::Core + out = model.call(inputs, cache: cache) + + if @args.tie_word_embeddings + mx.matmul(out, model.embed_tokens.weight.T) + else + lm_head.call(out / (@args.hidden_size.to_f / @args.dim_model_base)) + end + end + + def sanitize(weights) + unless weights.key?("lm_head.weight") + weights["lm_head.weight"] = weights["model.embed_tokens.weight"] + end + weights + end + + def layers + model.layers + end + end + + Models.register("minicpm", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/minicpm3.rb b/lib/mlx_lm/models/minicpm3.rb new file mode 100644 index 0000000..7096b62 --- /dev/null +++ b/lib/mlx_lm/models/minicpm3.rb @@ -0,0 +1,237 @@ +require_relative "activations" +require_relative "rope_utils" + +module MlxLm + module Models + module MiniCPM3 + class ModelArgs < BaseModelArgs + field :model_type, default: "minicpm3" + field :hidden_size + field :dim_model_base + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :rms_norm_eps + field :vocab_size + field :num_key_value_heads + field :q_lora_rank + field :qk_nope_head_dim + field :qk_rope_head_dim + field :kv_lora_rank + field :scale_depth + field :scale_emb + field :max_position_embeddings + field 
:attention_bias, default: false + field :rope_theta, default: 1_000_000.0 + field :rope_traditional, default: false + field :rope_scaling, default: nil + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @rope_scaling ||= {} + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + @qk_rope_head_dim = args.qk_rope_head_dim + @qk_nope_head_dim = args.qk_nope_head_dim + @kv_lora_rank = args.kv_lora_rank + @num_heads = args.num_attention_heads + @hidden_size = args.hidden_size + @v_head_dim = @hidden_size / @num_heads + @q_head_dim = @qk_nope_head_dim + @qk_rope_head_dim + @kv_head_dim = @qk_nope_head_dim + @v_head_dim + @softmax_scale = @q_head_dim**(-0.5) + + self.q_a_proj = MLX::NN::Linear.new( + @hidden_size, + args.q_lora_rank, + bias: args.attention_bias + ) + self.q_a_layernorm = MLX::NN::RMSNorm.new(args.q_lora_rank, eps: args.rms_norm_eps) + self.q_b_proj = MLX::NN::Linear.new( + args.q_lora_rank, + @num_heads * @q_head_dim, + bias: false + ) + + self.kv_a_proj_with_mqa = MLX::NN::Linear.new( + @hidden_size, + @kv_lora_rank + @qk_rope_head_dim, + bias: args.attention_bias + ) + self.kv_a_layernorm = MLX::NN::RMSNorm.new(@kv_lora_rank, eps: args.rms_norm_eps) + self.kv_b_proj = MLX::NN::Linear.new( + @kv_lora_rank, + @num_heads * @kv_head_dim, + bias: false + ) + + self.o_proj = MLX::NN::Linear.new( + @num_heads * @v_head_dim, + @hidden_size, + bias: args.attention_bias + ) + + self.rope = SuScaledRoPE.new( + @qk_rope_head_dim, + base: args.rope_theta, + max_position_embeddings: args.max_position_embeddings, + original_max_position_embeddings: scaling_value(args.rope_scaling, "original_max_position_embeddings", 4096), + short_factor: scaling_value(args.rope_scaling, "short_factor", 1.0), + long_factor: scaling_value(args.rope_scaling, "long_factor", 1.0) + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _ = x.shape + + q = q_b_proj.call(q_a_layernorm.call(q_a_proj.call(x))) + q = q.reshape([b, l, @num_heads, @q_head_dim]).transpose([0, 2, 1, 3]) + q_nope, q_pe = mx.split(q, [@qk_nope_head_dim], -1) + + compressed_kv = kv_a_proj_with_mqa.call(x) + compressed_kv, k_pe = mx.split(compressed_kv, [@kv_lora_rank], -1) + k_pe = k_pe.reshape([b, l, 1, @qk_rope_head_dim]).transpose([0, 2, 1, 3]) + + kv = kv_b_proj.call(kv_a_layernorm.call(compressed_kv)) + kv = kv.reshape([b, l, @num_heads, @kv_head_dim]).transpose([0, 2, 1, 3]) + k_nope, values = mx.split(kv, [@qk_nope_head_dim], -1) + + if cache + q_pe = rope.call(q_pe, offset: cache.offset) + k_pe = rope.call(k_pe, offset: cache.offset) + else + q_pe = rope.call(q_pe) + k_pe = rope.call(k_pe) + end + + k_pe_broadcasted = mx.broadcast_to(k_pe, [b, @num_heads, l, @qk_rope_head_dim]) + queries = mx.concatenate([q_nope, q_pe], -1) + keys = mx.concatenate([k_nope, k_pe_broadcasted], -1) + + if cache + keys, values = cache.update_and_fetch(keys, values) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @softmax_scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_heads * @v_head_dim]) + o_proj.call(output) + end + + private + + def scaling_value(config, key, default) + return default if config.nil? 
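+ # rope_scaling entries may arrive with string or symbol keys; try the string form first, then fall back to the symbol form or the supplied default.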
+ return config[key] if config.key?(key) + + config.fetch(key.to_sym, default) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + self.gate_proj = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false) + self.up_proj = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false) + self.down_proj = MLX::NN::Linear.new(args.intermediate_size, args.hidden_size, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class DecoderLayer < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + @residual_scale = args.scale_depth / Math.sqrt(args.num_hidden_layers) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r * @residual_scale + r = mlp.call(post_attention_layernorm.call(h)) + h + r * @residual_scale + end + end + + class MiniCPM3Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, mask: nil, cache: nil) + h = embed_tokens.call(inputs) * @args.scale_emb + layer_cache = cache || [nil] * layers.length + local_mask = mask || _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: local_mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(hidden, cache) + return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if hidden.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = MiniCPM3Model.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings + end + + def call(inputs, mask: nil, cache: nil) + out = model.call(inputs, mask: mask, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out / (@args.hidden_size.to_f / @args.dim_model_base)) + end + end + + def sanitize(weights) + result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } + + if @args.tie_word_embeddings + result.delete("lm_head.weight") + elsif !result.key?("lm_head.weight") && result.key?("model.embed_tokens.weight") + result["lm_head.weight"] = result["model.embed_tokens.weight"] + end + + result + end + + def layers + model.layers + end + end + + Models.register("minicpm3", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/minimax.rb b/lib/mlx_lm/models/minimax.rb new file mode 100644 index 0000000..860bbed --- /dev/null +++ b/lib/mlx_lm/models/minimax.rb @@ -0,0 +1,282 @@ +require_relative "switch_layers" + +module MlxLm + module Models + module Minimax + class ModelArgs < BaseModelArgs + field :model_type, default: "minimax" + field :hidden_size + field :intermediate_size + field :num_attention_heads + field :num_key_value_heads + field :max_position_embeddings + field :num_experts_per_tok + field 
:num_local_experts + field :shared_intermediate_size + field :num_hidden_layers + field :rms_norm_eps + field :rope_theta + field :rotary_dim + field :vocab_size + field :tie_word_embeddings, default: false + field :scoring_func, default: "sigmoid" + field :head_dim, default: nil + field :use_qk_norm, default: true + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + @rotary_dim ||= @head_dim + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + @num_attention_heads = args.num_attention_heads + @num_key_value_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + @use_qk_norm = args.use_qk_norm + + self.q_proj = MLX::NN::Linear.new(dim, @num_attention_heads * @head_dim, bias: false) + self.k_proj = MLX::NN::Linear.new(dim, @num_key_value_heads * @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(dim, @num_key_value_heads * @head_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@num_attention_heads * @head_dim, dim, bias: false) + + if @use_qk_norm + self.q_norm = MLX::NN::RMSNorm.new(@head_dim * @num_attention_heads, eps: args.rms_norm_eps) + self.k_norm = MLX::NN::RMSNorm.new(@head_dim * @num_key_value_heads, eps: args.rms_norm_eps) + end + + self.rope = MLX::NN::RoPE.new(args.rotary_dim, traditional: false, base: args.rope_theta) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + if @use_qk_norm + queries = q_norm.call(queries) + keys = k_norm.call(keys) + end + + queries = queries.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim]) + o_proj.call(output) + end + end + + class SparseMoeBlock < MLX::NN::Module + def initialize(args) + super() + mx = MLX::Core + @num_experts_per_tok = args.num_experts_per_tok + @num_local_experts = args.num_local_experts + + self.gate = MLX::NN::Linear.new(args.hidden_size, @num_local_experts, bias: false) + self.switch_mlp = SwitchLayers::SwitchGLU.new(args.hidden_size, args.intermediate_size, @num_local_experts) + self.e_score_correction_bias = mx.zeros([@num_local_experts]) + end + + def call(x) + mx = MLX::Core + + gates = gate.call(x.astype(mx.float32)) + orig_scores = mx.sigmoid(gates) + scores = orig_scores + e_score_correction_bias + + k = [[@num_experts_per_tok.to_i, 1].max, @num_local_experts.to_i].min + inds = mx.argpartition(scores * -1.0, k - 1, -1) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + + scores = mx.take_along_axis(orig_scores, inds, -1) + scores = scores / (mx.expand_dims(mx.sum(scores, -1), -1) + 1e-20) + scores = scores.astype(x.dtype) + + y = switch_mlp.call(x, inds) + mx.sum(y * mx.expand_dims(scores, -1), -2) + end + end + + class DecoderLayer < MLX::NN::Module + def initialize(args) + 
super() + self.self_attn = Attention.new(args) + self.block_sparse_moe = SparseMoeBlock.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + h = x + self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h + block_sparse_moe.call(post_attention_layernorm.call(h)) + end + end + + class MiniMaxModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, mask: nil, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + local_mask = mask || _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: local_mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache) + n = h.shape[1] + return cache.make_mask(n) if cache && cache.respond_to?(:make_mask) + return nil if n == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = MiniMaxModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings + end + + def call(inputs, mask: nil, cache: nil) + out = model.call(inputs, mask: mask, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + mx = MLX::Core + dequantized = {} + + weights.each do |key, value| + if key.include?("weight_scale_inv") + weight_key = key.sub("_scale_inv", "") + next unless weights.key?(weight_key) + + dequantized[weight_key] = _dequant(weights[weight_key], value) + elsif !dequantized.key?(key) + dequantized[key] = value + end + end + + result = dequantized + return result unless result.key?("model.layers.0.block_sparse_moe.experts.0.w1.weight") + + mapping = { + "w1" => "gate_proj", + "w2" => "down_proj", + "w3" => "up_proj", + } + experts_count = @args.num_local_experts.to_i + return result if experts_count <= 0 + + @args.num_hidden_layers.times do |layer_idx| + prefix = "model.layers.#{layer_idx}" + mapping.each do |old_name, new_name| + first_key = "#{prefix}.block_sparse_moe.experts.0.#{old_name}.weight" + next unless result.key?(first_key) + + expert_keys = (0...experts_count).map do |expert_idx| + "#{prefix}.block_sparse_moe.experts.#{expert_idx}.#{old_name}.weight" + end + next unless expert_keys.all? 
{ |k| result.key?(k) } + + stacked = expert_keys.map { |k| result.delete(k) } + result["#{prefix}.block_sparse_moe.switch_mlp.#{new_name}.weight"] = mx.stack(stacked) + end + end + + result + end + + def layers + model.layers + end + + def cast_predicate + lambda { |key| !key.include?("e_score_correction_bias") } + end + + def quant_predicate + lambda do |path, _| + if path.end_with?("block_sparse_moe.gate") + { group_size: 64, bits: 8 } + else + true + end + end + end + + private + + def _dequant(weight, scale_inv) + mx = MLX::Core + dtype = mx.bfloat16 + block_size = 128 + + dequantized = mx.from_fp8(weight, dtype: dtype) + m, n = dequantized.shape + pad_bottom = block_size * scale_inv.shape[0] - m + pad_side = block_size * scale_inv.shape[1] - n + + dequantized = mx.pad(dequantized, [[0, pad_bottom], [0, pad_side]]) + dequantized = dequantized.reshape([ + (m + pad_bottom) / block_size, + block_size, + (n + pad_side) / block_size, + block_size, + ]) + + scaled = dequantized * scale_inv.reshape([scale_inv.shape[0], 1, scale_inv.shape[1], 1]) + scaled = scaled.reshape([m + pad_bottom, n + pad_side]) + scaled = mx.split(scaled, [m], 0)[0] + scaled = mx.split(scaled, [n], 1)[0] + scaled.astype(dtype) + rescue StandardError + weight + end + end + + Models.register("minimax", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/ministral3.rb b/lib/mlx_lm/models/ministral3.rb new file mode 100644 index 0000000..547eb2f --- /dev/null +++ b/lib/mlx_lm/models/ministral3.rb @@ -0,0 +1,304 @@ +require_relative "activations" +require_relative "cache" +require_relative "pipeline" +require_relative "rope_utils" + +module MlxLm + module Models + module Ministral3 + def self.llama4_attn_scale(size, offset, beta, max_position_embeddings) + mx = MLX::Core + positions = mx.arange(size) + offset + scale = 1.0 + beta.to_f * mx.log(1.0 + mx.floor(positions / max_position_embeddings.to_f)) + scale.reshape([size, 1]) + end + + class ModelArgs < BaseModelArgs + field :model_type, default: "ministral3" + field :hidden_size + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :rms_norm_eps + field :vocab_size + field :head_dim, default: nil + field :max_position_embeddings, default: nil + field :num_key_value_heads, default: nil + field :rope_parameters, default: nil + field :tie_word_embeddings, default: true + field :layer_types, default: nil + field :sliding_window, default: nil + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + @rope_parameters = _stringify_keys(@rope_parameters || {}) + @rope_parameters["rope_theta"] = 10_000.0 unless @rope_parameters.key?("rope_theta") + @layer_types ||= Array.new(@num_hidden_layers) { "full_attention" } + end + + def rope_parameter(key, default = nil) + return default unless @rope_parameters.is_a?(Hash) + return @rope_parameters[key.to_s] if @rope_parameters.key?(key.to_s) + return @rope_parameters[key.to_sym] if @rope_parameters.key?(key.to_sym) + + default + end + + private + + def _stringify_keys(hash) + hash.each_with_object({}) do |(key, value), out| + out[key.to_s] = value + end + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false) + self.k_proj = 
MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false) + + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_parameter("rope_theta", 10_000.0), + false, + args.rope_parameters, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, attn_scale:, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + queries = queries * attn_scale + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + hidden_dim = args.intermediate_size + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + attr_reader :use_sliding + + def initialize(args, use_sliding: false) + super() + @use_sliding = use_sliding + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, attn_scale:, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), attn_scale: attn_scale, mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class LanguageModel < MLX::NN::Module + include PipelineMixin + attr_reader :sliding_window + + def initialize(args) + super() + @args = args + @sliding_window = args.sliding_window + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = args.layer_types.map do |layer_type| + TransformerBlock.new(args, use_sliding: layer_type == "sliding_attention") + end + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil, input_embeddings: nil) + h = input_embeddings || embed_tokens.call(inputs) + active_layers = pipeline_layers + layer_cache = cache || Array.new(active_layers.length) + + first_cache = layer_cache.find { |entry| !entry.nil? } + offset = first_cache ? first_cache.offset : 0 + + fa_idx = nil + swa_idx = nil + active_layers.each_with_index do |layer, i| + if layer.use_sliding + swa_idx ||= i + else + fa_idx ||= i + end + break if fa_idx && swa_idx + end + + fa_mask = fa_idx.nil? ? nil : _create_attention_mask(h, layer_cache[fa_idx]) + swa_mask = if swa_idx.nil? 
+ nil + else + _create_attention_mask(h, layer_cache[swa_idx], window_size: @sliding_window) + end + + beta = @args.rope_parameter("llama_4_scaling_beta", 0.0).to_f + max_pos = @args.rope_parameter( + "original_max_position_embeddings", + @args.max_position_embeddings || h.shape[1] + ).to_i + max_pos = 1 if max_pos <= 0 + + attn_scale = MlxLm::Models::Ministral3.llama4_attn_scale( + inputs.shape[1], + offset, + beta, + max_pos + ).astype(h.dtype) + + active_layers.each_with_index do |layer, idx| + mask = layer.use_sliding ? swa_mask : fa_mask + h = layer.call(h, attn_scale: attn_scale, mask: mask, cache: layer_cache[idx]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache = nil, window_size: nil) + n = h.shape[1] + if cache && cache.respond_to?(:make_mask) + return cache.make_mask(n, window_size: window_size) + end + + if window_size + offset = cache ? cache.offset : 0 + if cache && cache.instance_variable_defined?(:@max_size) + max_size = cache.instance_variable_get(:@max_size) + offset = [max_size - 1, offset].min if max_size && max_size > 0 + end + + return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size + end + + return nil if n == 1 + + "causal" + end + + def _create_causal_mask(n, offset: 0, window_size: nil) + mx = MLX::Core + rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n]) + linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1]) + + mask = mx.greater_equal(linds, rinds) + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + mask + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = LanguageModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil, input_embeddings: nil) + out = model.call(inputs, cache: cache, input_embeddings: input_embeddings) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + sanitized = weights.reject do |key, _| + key_name = key.to_s + key_name.include?("self_attn.rotary_emb.inv_freq") || key_name.include?("self_attn.rope.inv_freq") + end + sanitized.delete("lm_head.weight") if @args.tie_word_embeddings + + new_weights = {} + sanitized.each do |key, value| + key_name = key.to_s + if key_name.include?("weight_scale_inv") + wk = key_name.sub("_scale_inv", "") + next unless sanitized.key?(wk) + + new_weights[wk] = sanitized[wk] * value + elsif key_name.include?("activation_scale") + next + elsif !new_weights.key?(key) + new_weights[key] = value + end + end + new_weights + end + + def layers + model.pipeline_layers + end + + def make_cache + max_size = @args.sliding_window || @args.max_position_embeddings || 1 + layers.map do |layer| + if layer.use_sliding + MlxLm::RotatingKVCache.new(max_size: max_size) + else + MlxLm::KVCache.new + end + end + end + end + + Models.register("ministral3", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/mistral3.rb b/lib/mlx_lm/models/mistral3.rb new file mode 100644 index 0000000..60e37fb --- /dev/null +++ b/lib/mlx_lm/models/mistral3.rb @@ -0,0 +1,84 @@ +require_relative "llama" + +module MlxLm + module Models + module Mistral3 + class ModelArgs < BaseModelArgs + field :model_type, default: "mistral3" + field :text_config, default: nil + + def initialize(**kwargs) + super 
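+ # Copy the nested text_config and default tie_word_embeddings to false when the checkpoint does not specify it.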
+ @text_config = (@text_config || {}).dup + @text_config["tie_word_embeddings"] = false unless @text_config.key?("tie_word_embeddings") + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + + text_config = args.text_config || {} + text_model_type = text_config["model_type"] + + if text_model_type == "ministral3" && Models::REGISTRY.key?("ministral3") + model_class, args_class = Models.get_classes(text_config) + self.language_model = model_class.new(args_class.from_dict(text_config)) + else + self.language_model = Llama::Model.new(Llama::ModelArgs.from_dict(text_config)) + end + end + + def call(inputs, cache: nil, input_embeddings: nil) + supports_input_embeddings = language_model.method(:call).parameters.any? do |_, name| + name == :input_embeddings + end + + if supports_input_embeddings + language_model.call(inputs, cache: cache, input_embeddings: input_embeddings) + else + language_model.call(inputs, cache: cache) + end + end + + def sanitize(weights) + result = {} + language_weights = {} + + weights.each do |k, v| + next if k == "vision_tower" || k.start_with?("vision_tower.") + next if k == "multi_modal_projector" || k.start_with?("multi_modal_projector.") + + if k.start_with?("language_model.") + language_weights[k.delete_prefix("language_model.")] = v + else + result[k] = v + end + end + + sanitized_language = if language_model.respond_to?(:sanitize) + language_model.sanitize(language_weights) + else + language_weights + end + + sanitized_language.each do |k, v| + result["language_model.#{k}"] = v + end + + result + end + + def layers + return language_model.model.layers if language_model.respond_to?(:model) && language_model.model.respond_to?(:layers) + + language_model.layers + end + end + + Models.register("mistral3", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/mixtral.rb b/lib/mlx_lm/models/mixtral.rb index 8369292..982c6df 100644 --- a/lib/mlx_lm/models/mixtral.rb +++ b/lib/mlx_lm/models/mixtral.rb @@ -67,19 +67,6 @@ def call(x, mask: nil, cache: nil) end end - class Expert < MLX::NN::Module - def initialize(dim, hidden_dim) - super() - self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) - self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) - self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) - end - - def call(x) - down_proj.call(MLX::NN.silu(gate_proj.call(x)) * up_proj.call(x)) - end - end - class SparseMoeBlock < MLX::NN::Module def initialize(args) super() @@ -89,46 +76,24 @@ def initialize(args) hidden_dim = args.intermediate_size self.gate = MLX::NN::Linear.new(dim, @num_experts, bias: false) - self.experts = Array.new(@num_experts) { Expert.new(dim, hidden_dim) } + self.switch_mlp = SwitchLayers::SwitchGLU.new(dim, hidden_dim, @num_experts) end def call(x) mx = MLX::Core - ne = @num_experts_per_tok - orig_shape = x.shape - dims = x.shape[-1] - tokens = x.size / dims - x_flat = x.reshape([tokens, dims]) - - # Route tokens to experts - gates = gate.call(x_flat) - inds = mx.argpartition(gates * -1.0, ne - 1, -1) - take_ids = mx.array((0...ne).to_a, mx.int32) - inds = mx.take(inds, take_ids, 1) + k = @num_experts_per_tok + + gates = gate.call(x) + inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1)) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) scores = mx.take_along_axis(gates, inds, -1) scores = mx.softmax(scores.astype(mx.float32), -1).astype(gates.dtype) - # Evaluate 
experts per token - inds_list = inds.tolist - y_rows = [] - (0...x_flat.shape[0]).each do |i| - xt = x_flat[i] - selected = inds_list[i] - selected = [selected].flatten - expert_outs = selected.map { |eidx| - mx.expand_dims(experts[eidx].call(xt), 0) - } - yt = mx.concatenate(expert_outs, 0) - # Weighted sum: yt shape [ne, dim], scores[i] shape [ne] - st = scores[i] - weighted = yt * mx.expand_dims(st, -1) - summed = mx.sum(weighted, 0) - y_rows << mx.expand_dims(summed, 0) - end - - y = mx.concatenate(y_rows, 0) - y.reshape(orig_shape) + y = switch_mlp.call(x, inds) + y = mx.sum(y * mx.expand_dims(scores, -1), -2) + y end end @@ -193,8 +158,26 @@ def call(inputs, cache: nil) end def sanitize(weights) + mx = MLX::Core result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } result.delete("lm_head.weight") if @args.tie_word_embeddings + + # Convert per-expert weights to stacked SwitchGLU format + @args.num_hidden_layers.times do |l| + prefix = "model.layers.#{l}" + [["w1", "gate_proj"], ["w2", "down_proj"], ["w3", "up_proj"]].each do |n, m| + ["weight", "scales", "biases"].each do |k| + key0 = "#{prefix}.block_sparse_moe.experts.0.#{n}.#{k}" + if result.key?(key0) + to_join = (0...@args.num_local_experts).map { |e| + result.delete("#{prefix}.block_sparse_moe.experts.#{e}.#{n}.#{k}") + } + result["#{prefix}.block_sparse_moe.switch_mlp.#{m}.#{k}"] = mx.stack(to_join) + end + end + end + end + result end diff --git a/lib/mlx_lm/models/mla.rb b/lib/mlx_lm/models/mla.rb new file mode 100644 index 0000000..27e3b7e --- /dev/null +++ b/lib/mlx_lm/models/mla.rb @@ -0,0 +1,75 @@ +module MlxLm + module Models + module MLA + class MultiLinear < MLX::NN::Module + def initialize(input_dims, output_dims, num_heads) + super() + scale = Math.sqrt(1.0 / input_dims) + self.weight = MLX::Core.uniform([num_heads, output_dims, input_dims], -scale, scale) + end + + def call(x, transpose: true) + if transpose + MLX::Core.matmul(x, MLX::Core.swapaxes(weight, -1, -2)) + else + MLX::Core.matmul(x, weight) + end + end + + def to_quantized(group_size: nil, bits: nil, mode: "affine", quantize_input: false) + raise ArgumentError, "Quantized input is not supported." if quantize_input + + QuantizedMultiLinear.from_multi_linear(self, group_size, bits, mode: mode) + end + end + + class QuantizedMultiLinear < MLX::NN::Module + attr_reader :group_size, :bits, :mode + + def initialize(input_dims, output_dims, num_heads, group_size = nil, bits = nil, mode: "affine") + super() + + @group_size, @bits = MLX::NN.__send__(:defaults_for_mode, mode, group_size, bits) + @mode = mode + + scale = Math.sqrt(1.0 / input_dims) + weight = MLX::Core.uniform([num_heads, output_dims, input_dims], -scale, scale) + q_weight, q_scales, *q_biases = MLX::Core.quantize(weight, @group_size, @bits, @mode) + self.weight = q_weight + self.scales = q_scales + self.biases = q_biases.empty? ? nil : q_biases[0] + + freeze + end + + def call(x, transpose: true) + MLX::Core.quantized_matmul( + x, + weight, + scales, + biases, + transpose, + @group_size, + @bits, + @mode + ) + end + + def self.from_multi_linear(multi_linear_layer, group_size = nil, bits = nil, mode: "affine") + num_heads, output_dims, input_dims = multi_linear_layer.weight.shape + out = new(input_dims, output_dims, num_heads, group_size, bits, mode: mode) + q_weight, q_scales, *q_biases = MLX::Core.quantize( + multi_linear_layer.weight, + out.group_size, + out.bits, + out.mode + ) + out.weight = q_weight + out.scales = q_scales + out.biases = q_biases.empty? ? 
nil : q_biases[0] + out + end + end + end + end +end diff --git a/lib/mlx_lm/models/nanochat.rb b/lib/mlx_lm/models/nanochat.rb new file mode 100644 index 0000000..1ff416b --- /dev/null +++ b/lib/mlx_lm/models/nanochat.rb @@ -0,0 +1,167 @@ +module MlxLm + module Models + module Nanochat + module_function + + def rms_norm(x, eps: 1e-5) + mx = MLX::Core + variance = mx.mean(mx.square(x), -1, true) + mx.multiply(x, mx.rsqrt(mx.add(variance, eps))) + end + + def softcap(logits, cap: 15.0) + mx = MLX::Core + mx.multiply(cap, mx.tanh(mx.divide(logits, cap))) + end + + class ModelArgs < BaseModelArgs + field :model_type, default: "nanochat" + field :hidden_size, default: 1280 + field :num_hidden_layers, default: 20 + field :num_attention_heads, default: 10 + field :num_key_value_heads, default: 10 + field :vocab_size, default: 65_536 + field :max_position_embeddings, default: 2048 + field :intermediate_size, default: 5120 + field :rope_theta, default: 10_000.0 + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + @hidden_size = args.hidden_size + @num_heads = args.num_attention_heads + @num_kv_heads = args.num_key_value_heads + @head_dim = @hidden_size / @num_heads + @scale = @head_dim**(-0.5) + @rope_theta = args.rope_theta + + self.c_q = MLX::NN::Linear.new(@hidden_size, @num_heads * @head_dim, bias: false) + self.c_k = MLX::NN::Linear.new(@hidden_size, @num_kv_heads * @head_dim, bias: false) + self.c_v = MLX::NN::Linear.new(@hidden_size, @num_kv_heads * @head_dim, bias: false) + self.c_proj = MLX::NN::Linear.new(@hidden_size, @hidden_size, bias: false) + + mx = MLX::Core + exponent = mx.multiply( + mx.arange(0, @head_dim, 2, mx.float32), + Math.log(@rope_theta) / @head_dim.to_f + ) + self._rope_freqs = mx.multiply(-1.0, mx.exp(exponent)) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = c_q.call(x) + keys = c_k.call(x) + values = c_v.call(x) + + queries = queries.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @num_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @num_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + offset = cache ? 
cache.offset : 0 + queries = _apply_rotary_emb(queries, offset: offset) + keys = _apply_rotary_emb(keys, offset: offset) + + queries = Nanochat.rms_norm(queries) + keys = Nanochat.rms_norm(keys) + + if cache + keys, values = cache.update_and_fetch(keys, values) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @hidden_size]) + c_proj.call(output) + end + + private + + def _apply_rotary_emb(x, offset:) + MLX::Core.rope(x, @head_dim, false, nil, 1.0, offset, _rope_freqs) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + self.c_fc = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false) + self.c_proj = MLX::NN::Linear.new(args.intermediate_size, args.hidden_size, bias: false) + end + + def call(x) + c_proj.call(MLX::NN.relu2(c_fc.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.attn = Attention.new(args) + self.mlp = MLP.new(args) + end + + def call(x, mask: nil, cache: nil) + h = x + attn.call(Nanochat.rms_norm(x), mask: mask, cache: cache) + h + mlp.call(Nanochat.rms_norm(h)) + end + end + + class NanoChatModel < MLX::NN::Module + def initialize(args) + super() + self.wte = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.h = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) } + end + + def call(inputs, cache: nil) + hidden = wte.call(inputs) + hidden = Nanochat.rms_norm(hidden) + + layer_cache = cache || [nil] * h.length + mask = _create_attention_mask(hidden, layer_cache[0]) + + h.each_with_index do |layer, i| + hidden = layer.call(hidden, mask: mask, cache: layer_cache[i]) + end + + Nanochat.rms_norm(hidden) + end + + private + + def _create_attention_mask(hidden, cache) + return cache.make_mask(hidden.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if hidden.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + self.args = args + self.model_type = args.model_type + self.transformer = NanoChatModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + + def call(inputs, cache: nil) + out = transformer.call(inputs, cache: cache) + logits = lm_head.call(out) + Nanochat.softcap(logits) + end + + def layers + transformer.h + end + end + + Models.register("nanochat", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/nemotron.rb b/lib/mlx_lm/models/nemotron.rb new file mode 100644 index 0000000..b371070 --- /dev/null +++ b/lib/mlx_lm/models/nemotron.rb @@ -0,0 +1,202 @@ +module MlxLm + module Models + module Nemotron + class ModelArgs < BaseModelArgs + field :model_type, default: "nemotron" + field :hidden_size + field :hidden_act + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :norm_eps + field :vocab_size + field :num_key_value_heads + field :head_dim, default: nil + field :max_position_embeddings, default: nil + field :attention_bias, default: false + field :mlp_bias, default: false + field :partial_rotary_factor, default: 0.5 + field :rope_theta, default: 10_000.0 + field :rope_traditional, default: false + field :rope_scaling, default: nil + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + @head_dim ||= @hidden_size / @num_attention_heads + validate_rope_scaling! 
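+ # NOTE: only linear rope_scaling is accepted here; validate_rope_scaling! raises for any other type.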
+ end + + private + + def rope_scaling_value(key) + return nil unless @rope_scaling + + @rope_scaling[key] || @rope_scaling[key.to_s] + end + + def validate_rope_scaling! + return unless @rope_scaling + + raise ArgumentError, "rope_scaling must contain 'factor'" if rope_scaling_value(:factor).nil? + + rope_type = rope_scaling_value(:type) || rope_scaling_value(:rope_type) + if rope_type.nil? + raise ArgumentError, "rope_scaling must contain either 'type' or 'rope_type'" + end + return if rope_type == "linear" + + raise ArgumentError, "rope_scaling 'type' currently only supports 'linear'" + end + end + + class NemotronLayerNorm1P < MLX::NN::LayerNorm + def call(x) + w = state.key?("weight") ? weight + 1.0 : nil + b = state.key?("bias") ? bias : nil + MLX::Core.layer_norm(x, w, b, @eps) + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @partial_rotary_factor = args.partial_rotary_factor + @scale = @head_dim**(-0.5) + + bias = args.attention_bias + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: bias) + + rope_scale = 1.0 + if args.rope_scaling + rope_type = args.rope_scaling[:type] || args.rope_scaling["type"] || + args.rope_scaling[:rope_type] || args.rope_scaling["rope_type"] + if rope_type == "linear" + factor = args.rope_scaling[:factor] || args.rope_scaling["factor"] + rope_scale = 1.0 / factor.to_f + end + end + + self.rope = MLX::NN::RoPE.new( + (@partial_rotary_factor * @head_dim).to_i, + base: args.rope_theta, + scale: rope_scale + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + hidden_dim = args.intermediate_size + bias = args.mlp_bias + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias) + end + + def call(x) + down_proj.call(MLX::NN.relu2(up_proj.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args) + self.input_layernorm = NemotronLayerNorm1P.new(args.hidden_size, eps: args.norm_eps) + self.post_attention_layernorm = NemotronLayerNorm1P.new(args.hidden_size, eps: args.norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + 
end + end + + class NemotronModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) } + self.norm = NemotronLayerNorm1P.new(args.hidden_size, eps: args.norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = NemotronModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def layers + model.layers + end + end + + Models.register("nemotron", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/nemotron_h.rb b/lib/mlx_lm/models/nemotron_h.rb new file mode 100644 index 0000000..235ec1e --- /dev/null +++ b/lib/mlx_lm/models/nemotron_h.rb @@ -0,0 +1,212 @@ +require_relative "falcon_h1" + +module MlxLm + module Models + module NemotronH + class ModelArgs < FalconH1::ModelArgs + field :model_type, default: "nemotron_h" + field :tie_word_embeddings, default: false + field :mamba_num_heads, default: nil + field :mamba_head_dim, default: nil + field :mamba_proj_bias, default: nil + field :ssm_state_size, default: nil + field :conv_kernel, default: nil + field :n_groups, default: nil + field :mlp_bias, default: nil + field :layer_norm_epsilon, default: nil + field :use_bias, default: nil + field :use_conv_bias, default: nil + field :hybrid_override_pattern, default: nil + field :moe_intermediate_size, default: nil + field :moe_shared_expert_intermediate_size, default: nil + field :n_group, default: nil + field :n_routed_experts, default: nil + field :n_shared_experts, default: nil + field :topk_group, default: nil + field :num_experts_per_tok, default: nil + field :norm_topk_prob, default: nil + field :routed_scaling_factor, default: nil + field :time_step_limit, default: nil + field :time_step_min, default: nil + field :time_step_max, default: nil + + def initialize(**kwargs) + super + + @mamba_d_conv = @conv_kernel if kwargs.key?(:conv_kernel) && !kwargs.key?(:mamba_d_conv) && !@conv_kernel.nil? + @rms_norm_eps = @layer_norm_epsilon if kwargs.key?(:layer_norm_epsilon) && !kwargs.key?(:rms_norm_eps) && !@layer_norm_epsilon.nil? + @num_attention_heads ||= @mamba_num_heads + @head_dim ||= @mamba_head_dim + + pattern = _hybrid_pattern_array + @hybrid_override_pattern = pattern unless pattern.nil? + @hybrid_override_pattern ||= _default_hybrid_pattern + + if @num_hidden_layers.nil? && @hybrid_override_pattern.is_a?(Array) && !@hybrid_override_pattern.empty? + @num_hidden_layers = @hybrid_override_pattern.length + end + + @num_key_value_heads ||= @num_attention_heads + @mamba_d_conv ||= 4 + @block_types ||= _to_block_types(@hybrid_override_pattern) + end + + def to_falcon_h1_dict + hidden_size = @hidden_size + attention_heads = @num_attention_heads + inferred_head_dim = if !@head_dim.nil? + @head_dim + elsif !@mamba_head_dim.nil? + @mamba_head_dim + elsif !hidden_size.nil? 
&& attention_heads.to_i > 0 + hidden_size / attention_heads + else + 64 + end + + { + "model_type" => @model_type, + "attention_bias" => @attention_bias, + "head_dim" => inferred_head_dim, + "hidden_size" => hidden_size, + "intermediate_size" => @intermediate_size || @moe_shared_expert_intermediate_size || hidden_size.to_i * 2, + "max_position_embeddings" => @max_position_embeddings, + "mamba_d_conv" => @mamba_d_conv, + "num_attention_heads" => attention_heads, + "num_hidden_layers" => @num_hidden_layers, + "num_key_value_heads" => @num_key_value_heads, + "rms_norm_eps" => @rms_norm_eps || @layer_norm_epsilon || 1e-5, + "rope_theta" => @rope_theta, + "vocab_size" => @vocab_size, + "tie_word_embeddings" => @tie_word_embeddings, + "attention_window_size" => @attention_window_size, + "block_types" => @block_types, + } + end + + private + + def _hybrid_pattern_array + return nil if @hybrid_override_pattern.nil? + return @hybrid_override_pattern if @hybrid_override_pattern.is_a?(Array) + return @hybrid_override_pattern.chars if @hybrid_override_pattern.is_a?(String) + + nil + end + + def _default_hybrid_pattern + count = @num_hidden_layers.to_i + return nil if count <= 0 + + Array.new(count) { |idx| idx.even? ? "*" : "M" } + end + + def _to_block_types(pattern) + return @block_types if @block_types.is_a?(Array) && !@block_types.empty? + return nil unless pattern.is_a?(Array) && !pattern.empty? + + pattern.map do |block_type| + case block_type.to_s + when "*" + "attention" + else + "recurrent" + end + end + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.wrapped_model = FalconH1::Model.new( + FalconH1::ModelArgs.from_dict(args.to_falcon_h1_dict) + ) + end + + def call(inputs, cache: nil) + wrapped_model.call(inputs, cache: cache) + end + + def sanitize(weights) + normalized = weights.is_a?(Hash) ? weights.dup : weights.to_h + _stack_experts!(normalized) + + remapped = {} + normalized.each do |key, value| + remapped[_remap_weight_key(key)] = value + end + + wrapped_model.sanitize(remapped) + end + + def layers + wrapped_model.layers + end + + def make_cache + return nil unless wrapped_model.respond_to?(:make_cache) + + wrapped_model.make_cache + end + + private + + def _stack_experts!(weights) + mx = MLX::Core + grouped = Hash.new { |h, k| h[k] = [] } + pattern = /\A(backbone\.layers\.\d+\.mixer|model\.layers(?:\.layers)?\.\d+\.mixer)\.experts\.(\d+)\.(up_proj|down_proj)\.(weight|bias|scales|biases)\z/ + + weights.keys.each do |key| + match = pattern.match(key) + next unless match + + prefix = match[1] + expert_idx = match[2].to_i + projection = match[3] + param = match[4] + grouped[[prefix, projection, param]] << [expert_idx, key] + end + + grouped.each do |(prefix, projection, param), entries| + next if entries.empty? + + stacked = entries.sort_by(&:first).map { |_, key| weights.delete(key) } + target = projection == "up_proj" ? "fc1" : "fc2" + weights["#{prefix}.switch_mlp.#{target}.#{param}"] = mx.stack(stacked) + end + end + + def _remap_weight_key(key) + mapped = key.dup + mapped = mapped.gsub("backbone.embeddings.", "model.embed_tokens.") + mapped = mapped.gsub("backbone.norm_f.", "model.final_layernorm.") + mapped = mapped.gsub("backbone.layers.", "model.layers.") + mapped = mapped.gsub("model.layers.layers.", "model.layers.") + + mapped = mapped.gsub(/\.layers\.(\d+)\.norm\./) { ".layers.#{$1}.input_layernorm." 
} + + mapped = mapped.gsub(".mixer.conv1d.", ".mamba.conv1d.") + mapped = mapped.gsub(".mixer.in_proj.", ".mamba.in_proj.") + mapped = mapped.gsub(".mixer.out_proj.", ".mamba.out_proj.") + mapped = mapped.gsub(".mixer.q_proj.", ".self_attn.q_proj.") + mapped = mapped.gsub(".mixer.k_proj.", ".self_attn.k_proj.") + mapped = mapped.gsub(".mixer.v_proj.", ".self_attn.v_proj.") + mapped = mapped.gsub(".mixer.o_proj.", ".self_attn.o_proj.") + mapped = mapped.gsub(".mixer.gate.", ".feed_forward.router.") + mapped = mapped.gsub(".mixer.switch_mlp.fc1.", ".feed_forward.switch_mlp.up_proj.") + mapped = mapped.gsub(".mixer.switch_mlp.fc2.", ".feed_forward.switch_mlp.down_proj.") + mapped = mapped.gsub(".mixer.shared_experts.up_proj.", ".feed_forward.up_proj.") + mapped = mapped.gsub(".mixer.shared_experts.down_proj.", ".feed_forward.down_proj.") + mapped = mapped.gsub(".mixer.up_proj.", ".feed_forward.up_proj.") + mapped = mapped.gsub(".mixer.down_proj.", ".feed_forward.down_proj.") + mapped + end + end + + Models.register("nemotron_h", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/nemotron_nas.rb b/lib/mlx_lm/models/nemotron_nas.rb new file mode 100644 index 0000000..b6c7587 --- /dev/null +++ b/lib/mlx_lm/models/nemotron_nas.rb @@ -0,0 +1,404 @@ +require_relative "cache" +require_relative "rope_utils" + +module MlxLm + module Models + module NemotronNas + module_function + + def find_multiple(n, k) + remainder = n % k + remainder.zero? ? n : (n + k - remainder) + end + + def ffn_mult_to_intermediate_size(ffn_mult, hidden_size) + intermediate_size = (2 * ffn_mult.to_f * hidden_size / 3).to_i + find_multiple(intermediate_size, 256) + end + + class AttentionConfig + attr_reader :no_op, :replace_with_linear, :sparsify, :n_heads_in_group, :window_length, + :num_sink_tokens, :use_prefill_window_in_sink_attention, :unshifted_sink + + def initialize( + no_op: false, + replace_with_linear: false, + sparsify: nil, + n_heads_in_group: nil, + window_length: nil, + num_sink_tokens: nil, + use_prefill_window_in_sink_attention: false, + unshifted_sink: false + ) + @no_op = no_op + @replace_with_linear = replace_with_linear + @sparsify = sparsify + @n_heads_in_group = n_heads_in_group + @window_length = window_length + @num_sink_tokens = num_sink_tokens + @use_prefill_window_in_sink_attention = use_prefill_window_in_sink_attention + @unshifted_sink = unshifted_sink + + if @no_op || @replace_with_linear + @n_heads_in_group = nil + @window_length = nil + @num_sink_tokens = nil + else + raise ArgumentError, "n_heads_in_group must be specified for active attention blocks" if @n_heads_in_group.nil? + raise ArgumentError, "n_heads_in_group must be positive, got #{@n_heads_in_group}" if @n_heads_in_group.to_i <= 0 + end + end + + def self.from_dict(data) + hash = _symbolize_keys(data || {}) + new(**hash) + end + + def self._symbolize_keys(hash) + hash.each_with_object({}) { |(k, v), out| out[k.to_sym] = v } + end + private_class_method :_symbolize_keys + end + + class FFNConfig + attr_reader :no_op, :replace_with_linear, :sparsify, :ffn_mult + + def initialize( + no_op: false, + replace_with_linear: false, + sparsify: nil, + ffn_mult: nil + ) + @no_op = no_op + @replace_with_linear = replace_with_linear + @sparsify = sparsify + @ffn_mult = ffn_mult + + if @no_op || @replace_with_linear + @ffn_mult = nil + else + raise ArgumentError, "ffn_mult must be specified for active FFN blocks" if @ffn_mult.nil? 
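+ # Normalize ffn_mult to a float rounded to six decimal places; it is later fed to ffn_mult_to_intermediate_size.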
+ @ffn_mult = @ffn_mult.to_f.round(6) + end + end + + def self.from_dict(data) + hash = _symbolize_keys(data || {}) + new(**hash) + end + + def self._symbolize_keys(hash) + hash.each_with_object({}) { |(k, v), out| out[k.to_sym] = v } + end + private_class_method :_symbolize_keys + end + + class BlockConfig + attr_reader :attention, :ffn + + def initialize(attention:, ffn:) + @attention = attention + @ffn = ffn + end + + def self.from_dict(data) + hash = data || {} + attention_data = hash["attention"] || hash[:attention] || {} + ffn_data = hash["ffn"] || hash[:ffn] || {} + new( + attention: AttentionConfig.from_dict(attention_data), + ffn: FFNConfig.from_dict(ffn_data) + ) + end + end + + class ModelArgs < BaseModelArgs + field :model_type, default: "nemotron-nas" + field :hidden_size, default: 8192 + field :num_hidden_layers, default: 80 + field :num_attention_heads, default: 64 + field :rms_norm_eps, default: 1e-5 + field :vocab_size, default: 128_256 + field :block_configs, default: [] + field :hidden_act, default: "silu" + field :attention_bias, default: false + field :mlp_bias, default: false + field :rope_theta, default: 500_000.0 + field :rope_scaling, default: nil + field :max_position_embeddings, default: 131_072 + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + @block_configs = Array(@block_configs).map do |config| + config.is_a?(BlockConfig) ? config : BlockConfig.from_dict(config) + end + + if @block_configs.length != @num_hidden_layers + raise ArgumentError, + "Number of block_configs (#{@block_configs.length}) must match num_hidden_layers (#{@num_hidden_layers})" + end + + validate_rope_scaling! + validate_block_configs! + end + + private + + def validate_rope_scaling! + return unless @rope_scaling + + factor = rope_scaling_value(:factor) + raise ArgumentError, "rope_scaling must contain 'factor'" if factor.nil? + + rope_type = rope_scaling_value(:rope_type) || rope_scaling_value(:type) + raise ArgumentError, "rope_scaling must contain 'rope_type'" if rope_type.nil? + + normalized = @rope_scaling.dup + normalized["rope_type"] = rope_type + normalized[:rope_type] = rope_type + @rope_scaling = normalized + end + + def rope_scaling_value(key) + return nil unless @rope_scaling + return @rope_scaling[key] if @rope_scaling.key?(key) + + @rope_scaling[key.to_s] + end + + def validate_block_configs! 
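+          # n_heads_in_group defines the query-to-KV-head grouping, so every active attention block must divide num_attention_heads evenly.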
+ @block_configs.each_with_index do |block_config, i| + attention = block_config.attention + next if attention.no_op || attention.replace_with_linear + + heads_in_group = attention.n_heads_in_group.to_i + if (@num_attention_heads % heads_in_group) != 0 + raise ArgumentError, + "Layer #{i}: num_attention_heads (#{@num_attention_heads}) must be divisible by n_heads_in_group (#{attention.n_heads_in_group})" + end + end + end + end + + class Attention < MLX::NN::Module + def initialize(args, attention_config) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = @n_heads / attention_config.n_heads_in_group + @head_dim = args.hidden_size / @n_heads + raise ArgumentError, "hidden_size (#{dim}) must be divisible by num_attention_heads (#{@n_heads})" if (@head_dim * @n_heads) != dim + + @scale = @head_dim**(-0.5) + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: args.attention_bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: args.attention_bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: args.attention_bias) + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + false, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args, ffn_config) + super() + hidden_dim = NemotronNas.ffn_mult_to_intermediate_size(ffn_config.ffn_mult, args.hidden_size) + @act_fn = args.hidden_act + + supported = %w[silu relu gelu gelu_new gelu_fast] + unless supported.include?(@act_fn) + raise ArgumentError, "Unknown activation function: #{@act_fn}" + end + + self.gate_proj = MLX::NN::Linear.new(args.hidden_size, hidden_dim, bias: args.mlp_bias) + self.down_proj = MLX::NN::Linear.new(hidden_dim, args.hidden_size, bias: args.mlp_bias) + self.up_proj = MLX::NN::Linear.new(args.hidden_size, hidden_dim, bias: args.mlp_bias) + end + + def call(x) + gate = _activate(gate_proj.call(x)) + down_proj.call(gate * up_proj.call(x)) + end + + private + + def _activate(x) + case @act_fn + when "silu" + MLX::NN.silu(x) + when "relu" + MLX::NN.relu(x) + when "gelu" + MLX::NN.gelu(x) + when "gelu_new", "gelu_fast" + MLX::NN.gelu_approx(x) + else + x + end + end + end + + class LinearSubblockReplacement < MLX::NN::Module + def initialize(hidden_size, bias) + super() + self.linear = MLX::NN::Linear.new(hidden_size, hidden_size, bias: bias) + end + + def call(x, mask: nil, cache: nil) + _ = mask + _ = cache + linear.call(x) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args, layer_idx) + super() + block_config = args.block_configs[layer_idx] + 
@attention_config = block_config.attention + @ffn_config = block_config.ffn + + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) unless @attention_config.no_op + self.self_attn = if @attention_config.no_op + nil + elsif @attention_config.replace_with_linear + LinearSubblockReplacement.new(args.hidden_size, args.attention_bias) + else + Attention.new(args, @attention_config) + end + + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) unless @ffn_config.no_op + self.mlp = if @ffn_config.no_op + nil + elsif @ffn_config.replace_with_linear + LinearSubblockReplacement.new(args.hidden_size, args.mlp_bias) + else + MLP.new(args, @ffn_config) + end + end + + def call(x, mask: nil, cache: nil) + if self_attn + residual = x + h = input_layernorm.call(x) + x = residual + self_attn.call(h, mask: mask, cache: cache) + end + + if mlp + residual = x + h = post_attention_layernorm.call(x) + x = residual + mlp.call(h) + end + + x + end + end + + class NemotronNASModel < MLX::NN::Module + attr_reader :num_attn_layers + + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { |layer_idx| TransformerBlock.new(args, layer_idx) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + @num_attn_layers = layers.count { |layer| !layer.self_attn.nil? } + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * @num_attn_layers + mask = _create_attention_mask(h, layer_cache[0]) + + cache_idx = 0 + layers.each do |layer| + layer_state = if layer.self_attn + state = layer_cache[cache_idx] + cache_idx += 1 + state + end + h = layer.call(h, mask: mask, cache: layer_state) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache) + n = h.shape[1] + return cache.make_mask(n) if cache && cache.respond_to?(:make_mask) + return nil if n == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = NemotronNASModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) unless args.tie_word_embeddings + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } + result.delete("lm_head.weight") if @args.tie_word_embeddings + result + end + + def layers + model.layers + end + + def make_cache + layers.filter_map do |layer| + MlxLm::KVCache.new if layer.self_attn + end + end + end + + Models.register("nemotron-nas", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/olmo.rb b/lib/mlx_lm/models/olmo.rb new file mode 100644 index 0000000..7a259fd --- /dev/null +++ b/lib/mlx_lm/models/olmo.rb @@ -0,0 +1,165 @@ +module MlxLm + module Models + module OLMo + class ModelArgs < BaseModelArgs + field :model_type, default: "olmo" + field :d_model, default: nil + field :n_layers, default: nil + field :mlp_hidden_size, default: nil + field :n_heads, default: nil + field :vocab_size, default: 50304 + field :embedding_size, default: nil + field :rope_theta, default: 10000.0 + field :rope_traditional, default: false + field :mlp_ratio, default: 4 + field :weight_tying, default: false + 
+ # Compatibility aliases used in some generic tests/config builders. + field :hidden_size, default: nil + field :num_hidden_layers, default: nil + field :intermediate_size, default: nil + field :num_attention_heads, default: nil + field :tie_word_embeddings, default: nil + + def initialize(**kwargs) + super + @d_model = @hidden_size if @hidden_size + @n_layers = @num_hidden_layers if @num_hidden_layers + @n_heads = @num_attention_heads if @num_attention_heads + @mlp_hidden_size = @intermediate_size if @intermediate_size + @weight_tying = @tie_word_embeddings unless @tie_word_embeddings.nil? + + @d_model ||= 4096 + @n_layers ||= 32 + @n_heads ||= 32 + @embedding_size ||= @vocab_size + @mlp_hidden_size ||= @mlp_ratio * @d_model + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + dim = args.d_model + @n_heads = args.n_heads + @head_dim = dim / @n_heads + @scale = @head_dim**(-0.5) + @ff_hidden_size = args.mlp_hidden_size + + self.ff_proj = MLX::NN::Linear.new(dim, @ff_hidden_size, bias: false) + self.ff_out = MLX::NN::Linear.new(@ff_hidden_size / 2, dim, bias: false) + + self.att_norm = MLX::NN::LayerNorm.new(dim, affine: false) + self.ff_norm = MLX::NN::LayerNorm.new(dim, affine: false) + + self.att_proj = MLX::NN::Linear.new(dim, 3 * dim, bias: false) + self.attn_out = MLX::NN::Linear.new(dim, dim, bias: false) + + self.rope = MLX::NN::RoPE.new( + @head_dim, + traditional: args.rope_traditional, + base: args.rope_theta + ) + end + + def attend(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, d = x.shape + + qkv = att_proj.call(x) + queries, keys, values = mx.split(qkv, [d, 2 * d], 2) + + queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, d]) + attn_out.call(output) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + + r = attend(att_norm.call(x), mask: mask, cache: cache) + h = x + r + + ff_hidden = ff_proj.call(ff_norm.call(h)) + x1, x2 = mx.split(ff_hidden, [@ff_hidden_size / 2], 2) + h + ff_out.call(Activations.swiglu(x2, x1)) + end + end + + class Transformer < MLX::NN::Module + def initialize(args) + super() + @weight_tying = args.weight_tying + + self.wte = MLX::NN::Embedding.new(args.embedding_size, args.d_model) + self.blocks = Array.new(args.n_layers) { TransformerBlock.new(args) } + self.ff_out = MLX::NN::Linear.new(args.d_model, args.embedding_size, bias: false) unless @weight_tying + self.norm = MLX::NN::LayerNorm.new(args.d_model, affine: false) + end + + def call(inputs, cache: nil) + h = wte.call(inputs) + layer_cache = cache || [nil] * blocks.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + blocks.each_with_index do |block, i| + h = block.call(h, mask: mask, cache: layer_cache[i]) + end + + h = norm.call(h) + + if @weight_tying + wte.as_linear(h) + else + ff_out.call(h) + end + end + end + + class OlmoModel < MLX::NN::Module + def initialize(args) + super() + self.transformer = Transformer.new(args) + end + + def call(inputs, cache: nil) + 
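+          # Pass straight through to the wrapped transformer; parameters register under model.transformer.*.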
transformer.call(inputs, cache: cache) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + self.model_type = args.model_type + self.model = OlmoModel.new(args) + self.args = args + end + + def call(inputs, cache: nil) + model.call(inputs, cache: cache) + end + + def layers + model.transformer.blocks + end + end + + Models.register("olmo", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/olmo3.rb b/lib/mlx_lm/models/olmo3.rb new file mode 100644 index 0000000..df39749 --- /dev/null +++ b/lib/mlx_lm/models/olmo3.rb @@ -0,0 +1,254 @@ +module MlxLm + module Models + module OLMo3 + class ModelArgs < BaseModelArgs + field :model_type, default: "olmo3" + field :hidden_size + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :rms_norm_eps + field :vocab_size + field :max_position_embeddings + field :sliding_window + field :rope_theta + field :attention_bias, default: false + field :layer_types, default: nil + field :num_key_value_heads, default: nil + field :head_dim, default: nil + field :rope_scaling, default: nil + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + @layer_types ||= Array.new(@num_hidden_layers) do |i| + ((i + 1) % 4).zero? ? "full_attention" : "sliding_attention" + end + end + end + + class Olmo3Attention < MLX::NN::Module + def initialize(args, layer_idx:) + super() + @num_attention_heads = args.num_attention_heads + @num_key_value_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new( + args.hidden_size, + @num_attention_heads * @head_dim, + bias: args.attention_bias + ) + self.k_proj = MLX::NN::Linear.new( + args.hidden_size, + @num_key_value_heads * @head_dim, + bias: args.attention_bias + ) + self.v_proj = MLX::NN::Linear.new( + args.hidden_size, + @num_key_value_heads * @head_dim, + bias: args.attention_bias + ) + self.o_proj = MLX::NN::Linear.new( + @num_attention_heads * @head_dim, + args.hidden_size, + bias: args.attention_bias + ) + + self.q_norm = MLX::NN::RMSNorm.new( + @num_attention_heads * @head_dim, + eps: args.rms_norm_eps + ) + self.k_norm = MLX::NN::RMSNorm.new( + @num_key_value_heads * @head_dim, + eps: args.rms_norm_eps + ) + + if args.layer_types[layer_idx] != "full_attention" + self.rope = MLX::NN::RoPE.new(@head_dim, traditional: false, base: args.rope_theta) + else + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + false, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_norm.call(q_proj.call(x)) + keys = k_norm.call(k_proj.call(x)) + values = v_proj.call(x) + + queries = queries.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, 
@num_attention_heads * @head_dim]) + o_proj.call(output) + end + end + + class Olmo3MLP < MLX::NN::Module + def initialize(args) + super() + self.gate_proj = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false) + self.down_proj = MLX::NN::Linear.new(args.intermediate_size, args.hidden_size, bias: false) + self.up_proj = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class Olmo3DecoderLayer < MLX::NN::Module + def initialize(args, layer_idx:) + super() + self.self_attn = Olmo3Attention.new(args, layer_idx: layer_idx) + self.mlp = Olmo3MLP.new(args) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_feedforward_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = post_attention_layernorm.call(self_attn.call(x, mask: mask, cache: cache)) + h = x + r + r = post_feedforward_layernorm.call(mlp.call(h)) + h + r + end + end + + class Olmo3Model < MLX::NN::Module + attr_reader :layer_types + + def initialize(args) + super() + @sliding_window = args.sliding_window + @layer_types = args.layer_types + + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) do |i| + Olmo3DecoderLayer.new(args, layer_idx: i) + end + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + + self.swa_idx = @layer_types.index("sliding_attention") || 0 + self.ga_idx = @layer_types.index("full_attention") || 0 + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + full_mask = _create_attention_mask(h, layer_cache[ga_idx]) + sliding_window_mask = _create_attention_mask( + h, + layer_cache[swa_idx], + window_size: @sliding_window + ) + + layers.each_with_index do |layer, i| + mask = @layer_types[i] == "full_attention" ? 
full_mask : sliding_window_mask + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache = nil, window_size: nil) + n = h.shape[1] + if cache && cache.respond_to?(:make_mask) + return cache.make_mask(n, window_size: window_size) + end + + if window_size + offset = 0 + if cache + offset = cache.offset + if cache.instance_variable_defined?(:@max_size) + max_size = cache.instance_variable_get(:@max_size) + offset = [max_size - 1, offset].min if max_size && max_size > 0 + end + end + return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size + end + return nil if n == 1 + + "causal" + end + + def _create_causal_mask(n, offset: 0, window_size: nil) + mx = MLX::Core + rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n]) + linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1]) + + mask = mx.greater_equal(linds, rinds) + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + mask + end + end + + class Model < MLX::NN::Module + attr_reader :args + + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = Olmo3Model.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def layers + model.layers + end + + def make_cache + model.layer_types.map do |layer_type| + if layer_type == "full_attention" + KVCache.new + else + RotatingKVCache.new(max_size: args.sliding_window) + end + end + end + end + + Models.register("olmo3", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/olmoe.rb b/lib/mlx_lm/models/olmoe.rb new file mode 100644 index 0000000..4fa7e1e --- /dev/null +++ b/lib/mlx_lm/models/olmoe.rb @@ -0,0 +1,64 @@ +require_relative "olmo2" + +module MlxLm + module Models + module OLMoE + class ModelArgs < OLMo2::ModelArgs + field :model_type, default: "olmoe" + field :num_experts + field :num_experts_per_tok + field :norm_topk_prob, default: false + end + + class Model < OLMo2::Model + def sanitize(weights) + result = super(weights) + rewrite_expert_weights(result) + end + + private + + def rewrite_expert_weights(weights) + return weights unless weights.key?("model.layers.0.mlp.experts.0.up_proj.weight") + + mx = MLX::Core + + layers.length.times do |layer_idx| + prefix = "model.layers.#{layer_idx}.mlp" + %w[up_proj down_proj gate_proj].each do |projection| + %w[weight scales biases].each do |param| + first_key = "#{prefix}.experts.0.#{projection}.#{param}" + next unless weights.key?(first_key) + + expert_count = @args.num_experts || infer_expert_count(weights, prefix, projection, param) + next unless expert_count && expert_count.positive? + + expert_keys = (0...expert_count).map do |expert_idx| + "#{prefix}.experts.#{expert_idx}.#{projection}.#{param}" + end + next unless expert_keys.all? 
{ |key| weights.key?(key) } + + weights["#{prefix}.switch_mlp.#{projection}.#{param}"] = mx.stack(expert_keys.map { |key| weights.delete(key) }) + end + end + end + + weights + end + + def infer_expert_count(weights, prefix, projection, param) + pattern = /\A#{Regexp.escape(prefix)}\.experts\.(\d+)\.#{projection}\.#{param}\z/ + indices = weights.keys.filter_map do |key| + match = pattern.match(key) + match[1].to_i if match + end + return 0 if indices.empty? + + indices.max + 1 + end + end + + Models.register("olmoe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/openelm.rb b/lib/mlx_lm/models/openelm.rb new file mode 100644 index 0000000..56fbd6c --- /dev/null +++ b/lib/mlx_lm/models/openelm.rb @@ -0,0 +1,208 @@ +module MlxLm + module Models + module OpenELM + module_function + + def make_divisible(v, divisor = 8, min_value = nil) + min_value ||= divisor + rounded = ((v + (divisor.to_f / 2)).to_i / divisor) * divisor + new_v = [min_value, rounded].max + new_v += divisor if new_v < (0.9 * v) + new_v + end + + class ModelArgs < BaseModelArgs + field :model_type, default: "openelm" + field :head_dim, default: 64 + field :num_transformer_layers, default: 12 + field :model_dim, default: 2048 + field :vocab_size, default: 32_000 + field :ffn_dim_divisor, default: 8 + field :num_query_heads, default: [32] + field :num_kv_heads, default: [] + field :ffn_multipliers, default: [1.0] + field :ffn_with_glu, default: true + field :normalize_qk_projections, default: true + field :share_input_output_layers, default: true + field :rms_norm_eps, default: 1e-6 + field :rope_freq_constant, default: 10_000.0 + + def initialize(**kwargs) + super + @num_query_heads = normalize_schedule(@num_query_heads, @num_transformer_layers, 1, "num_query_heads").map(&:to_i) + + if @num_kv_heads.nil? || Array(@num_kv_heads).empty? + @num_kv_heads = @num_query_heads.dup + else + @num_kv_heads = normalize_schedule(@num_kv_heads, @num_transformer_layers, @num_query_heads[0], "num_kv_heads").map(&:to_i) + end + + @ffn_multipliers = normalize_schedule(@ffn_multipliers, @num_transformer_layers, 1.0, "ffn_multipliers").map(&:to_f) + end + + private + + def normalize_schedule(values, layers, fallback, field_name) + items = Array(values) + items = [fallback] if items.empty? 
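+          # A single entry is broadcast to every layer; anything longer must list exactly one value per layer.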
+ items = Array.new(layers, items[0]) if items.length == 1 && layers > 1 + + unless items.length == layers + raise ArgumentError, "#{field_name} must have #{layers} entries, got #{items.length}" + end + + items + end + end + + class Attention < MLX::NN::Module + def initialize(args, layer_id:) + super() + @head_dim = args.head_dim + @n_heads = args.num_query_heads[layer_id] + @n_kv_heads = args.num_kv_heads[layer_id] + @scale = @head_dim**(-0.5) + @normalize_qk_projections = args.normalize_qk_projections + + op_size = (@n_heads + (2 * @n_kv_heads)) * @head_dim + self.qkv_proj = MLX::NN::Linear.new(args.model_dim, op_size, bias: false) + self.out_proj = MLX::NN::Linear.new(@n_heads * @head_dim, args.model_dim, bias: false) + + if @normalize_qk_projections + self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + end + + self.rope = MLX::NN::RoPE.new(@head_dim, traditional: false, base: args.rope_freq_constant) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + qkv = qkv_proj.call(x) + qkv = qkv.reshape([b, l, @n_heads + (2 * @n_kv_heads), @head_dim]).transpose([0, 2, 1, 3]) + queries, keys, values = mx.split(qkv, [@n_heads, @n_heads + @n_kv_heads], 1) + + if @normalize_qk_projections + queries = q_norm.call(queries) + keys = k_norm.call(keys) + end + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + out_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args, layer_id:) + super() + @ffn_with_glu = args.ffn_with_glu + dim = args.model_dim + multiplier = args.ffn_multipliers[layer_id] + @intermediate_dim = OpenELM.make_divisible(multiplier * dim, args.ffn_dim_divisor).to_i + + proj_1_dim = @ffn_with_glu ? 
(2 * @intermediate_dim) : @intermediate_dim + self.proj_1 = MLX::NN::Linear.new(dim, proj_1_dim, bias: false) + self.proj_2 = MLX::NN::Linear.new(@intermediate_dim, dim, bias: false) + end + + def call(x) + x = proj_1.call(x) + x = if @ffn_with_glu + gate, value = MLX::Core.split(x, [@intermediate_dim], -1) + Activations.swiglu(gate, value) + else + MLX::NN.gelu_approx(x) + end + proj_2.call(x) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args, layer_id:) + super() + self.attn = Attention.new(args, layer_id: layer_id) + self.ffn = MLP.new(args, layer_id: layer_id) + self.attn_norm = MLX::NN::RMSNorm.new(args.model_dim, eps: args.rms_norm_eps) + self.ffn_norm = MLX::NN::RMSNorm.new(args.model_dim, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = attn.call(attn_norm.call(x), mask: mask, cache: cache) + h = x + r + r = ffn.call(ffn_norm.call(h)) + h + r + end + end + + class OpenELMModel < MLX::NN::Module + def initialize(args) + super() + self.token_embeddings = MLX::NN::Embedding.new(args.vocab_size, args.model_dim) + self.layers = Array.new(args.num_transformer_layers) do |layer_id| + TransformerBlock.new(args, layer_id: layer_id) + end + self.norm = MLX::NN::RMSNorm.new(args.model_dim, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = token_embeddings.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + norm.call(h) + end + + private + + def _create_attention_mask(h, cache) + return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if h.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.transformer = OpenELMModel.new(args) + unless args.share_input_output_layers + self.lm_head = MLX::NN::Linear.new(args.model_dim, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = transformer.call(inputs, cache: cache) + if @args.share_input_output_layers + transformer.token_embeddings.as_linear(out) + else + lm_head.call(out) + end + end + + def layers + transformer.layers + end + end + + Models.register("openelm", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/phi.rb b/lib/mlx_lm/models/phi.rb new file mode 100644 index 0000000..53dc507 --- /dev/null +++ b/lib/mlx_lm/models/phi.rb @@ -0,0 +1,156 @@ +module MlxLm + module Models + module Phi + class ModelArgs < BaseModelArgs + field :model_type, default: "phi" + field :max_position_embeddings, default: 2048 + field :vocab_size, default: 51_200 + field :hidden_size, default: 2560 + field :num_attention_heads, default: 32 + field :num_hidden_layers, default: 32 + field :num_key_value_heads, default: nil + field :partial_rotary_factor, default: 0.4 + field :intermediate_size, default: 10_240 + field :layer_norm_eps, default: 1e-5 + field :rope_theta, default: 10_000.0 + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + end + end + + class PhiAttention < MLX::NN::Module + def initialize(args) + super() + @hidden_size = args.hidden_size + @num_heads = args.num_attention_heads + @head_dim = @hidden_size / @num_heads + @num_key_value_heads = args.num_key_value_heads + @scale = @head_dim**(-0.5) + + if (@head_dim * @num_heads) != @hidden_size + raise ArgumentError, + "hidden_size must be divisible by 
num_heads (hidden_size=#{@hidden_size}, num_heads=#{@num_heads})" + end + + self.q_proj = MLX::NN::Linear.new(@hidden_size, @num_heads * @head_dim, bias: true) + self.k_proj = MLX::NN::Linear.new(@hidden_size, @num_key_value_heads * @head_dim, bias: true) + self.v_proj = MLX::NN::Linear.new(@hidden_size, @num_key_value_heads * @head_dim, bias: true) + self.dense = MLX::NN::Linear.new(@num_heads * @head_dim, @hidden_size, bias: true) + + self.rope = MLX::NN::RoPE.new( + (args.partial_rotary_factor * @head_dim).to_i, + traditional: false, + base: args.rope_theta + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + b, l, _d = queries.shape + queries = queries.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention( + queries.astype(mx.float32), + keys, + values, + @scale, + mask + ).astype(values.dtype) + + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_heads * @head_dim]) + dense.call(output) + end + end + + class PhiMLP < MLX::NN::Module + def initialize(args) + super() + self.fc1 = MLX::NN::Linear.new(args.hidden_size, args.intermediate_size, bias: true) + self.fc2 = MLX::NN::Linear.new(args.intermediate_size, args.hidden_size, bias: true) + end + + def call(x) + fc2.call(MLX::NN.gelu_approx(fc1.call(x))) + end + end + + class PhiDecoderLayer < MLX::NN::Module + def initialize(args) + super() + self.self_attn = PhiAttention.new(args) + self.input_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_eps) + self.mlp = PhiMLP.new(args) + end + + def call(x, mask: nil, cache: nil) + h = input_layernorm.call(x) + attn_h = self_attn.call(h, mask: mask, cache: cache) + ff_h = mlp.call(h) + attn_h + ff_h + x + end + end + + class PhiModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { PhiDecoderLayer.new(args) } + self.final_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + final_layernorm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + self.model = PhiModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: true) + end + + def call(inputs, cache: nil) + lm_head.call(model.call(inputs, cache: cache)) + end + + def sanitize(weights) + weights.reject { |k, _| k.include?("rotary_emb.inv_freq") } + end + + def layers + model.layers + end + end + + Models.register("phi", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/phi3small.rb b/lib/mlx_lm/models/phi3small.rb new file mode 100644 index 0000000..140b188 --- /dev/null +++ b/lib/mlx_lm/models/phi3small.rb @@ -0,0 +1,196 @@ +module MlxLm + 
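+  # Phi-3-small port: gegelu-style feed-forward and muP scaling multipliers; the blocksparse_* fields are accepted but block-sparse masking itself is not applied here.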
module Models + module Phi3small + class ModelArgs < BaseModelArgs + field :model_type, default: "phi3small" + field :hidden_size + field :dense_attention_every_n_layers + field :ff_intermediate_size + field :gegelu_limit + field :num_hidden_layers + field :num_attention_heads + field :layer_norm_epsilon + field :vocab_size + field :num_key_value_heads + field :mup_attn_multiplier, default: 1.0 + field :mup_use_scaling, default: true + field :mup_embedding_multiplier, default: 10.0 + field :mup_width_multiplier, default: 8.0 + field :rope_embedding_base, default: 1_000_000.0 + field :rope_position_scale, default: 1.0 + field :blocksparse_block_size, default: 64 + field :blocksparse_num_local_blocks, default: 16 + field :blocksparse_vert_stride, default: 8 + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + end + end + + class Attention < MLX::NN::Module + def initialize(args, layer_idx) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @n_q_per_kv = @n_heads / @n_kv_heads + @head_dim = dim / @n_heads + + self.query_key_value = MLX::NN::Linear.new( + dim, + (@n_heads + 2 * @n_kv_heads) * @head_dim + ) + self.dense = MLX::NN::Linear.new(dim, dim) + + norm_factor = if args.mup_use_scaling + @head_dim / args.mup_attn_multiplier.to_f + else + Math.sqrt(@head_dim) + end + @scale = 1.0 / norm_factor + + self.rope = MLX::NN::RoPE.new( + @head_dim, + traditional: false, + base: args.rope_embedding_base, + scale: args.rope_position_scale + ) + + @block_sparse = (layer_idx % args.dense_attention_every_n_layers).zero? + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + qkv = query_key_value.call(x) + q_size = @n_heads * @head_dim + k_size = @n_kv_heads * @head_dim + + queries = mx.split(qkv, [q_size, q_size + k_size], -1)[0] + keys = mx.split(qkv, [q_size, q_size + k_size], -1)[1] + values = mx.split(qkv, [q_size + k_size], -1)[1] + + queries = queries.reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + dense.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + @hidden_dim = args.ff_intermediate_size + self.up_proj = MLX::NN::Linear.new(dim, 2 * @hidden_dim) + self.down_proj = MLX::NN::Linear.new(@hidden_dim, dim) + end + + def call(x) + mx = MLX::Core + x = up_proj.call(x) + a_gelu, a_linear = mx.split(x, [@hidden_dim], -1) + out_gelu = a_gelu * mx.sigmoid(1.702 * a_gelu) + down_proj.call(out_gelu * (a_linear + 1.0)) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args, layer_idx) + super() + self.self_attn = Attention.new(args, layer_idx) + self.mlp = MLP.new(args) + self.input_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_epsilon) + self.post_attention_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_epsilon) + end + + def call(x, mask: nil, cache: nil) + r = 
self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class Phi3Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { |layer_idx| TransformerBlock.new(args, layer_idx) } + self.final_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.layer_norm_epsilon) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + h = h * @args.mup_embedding_multiplier if @args.mup_embedding_multiplier + + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, layer_idx| + h = layer.call(h, mask: mask, cache: layer_cache[layer_idx]) + end + + final_layernorm.call(h) + end + + private + + def _create_attention_mask(h, cache) + n = h.shape[1] + return cache.make_mask(n) if cache && cache.respond_to?(:make_mask) + return nil if n == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = Phi3Model.new(args) + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + out = model.embed_tokens.as_linear(out) + out = out / @args.mup_width_multiplier if @args.mup_width_multiplier + out + end + + def sanitize(weights) + weights.reject do |key, _| + key_name = key.to_s + key_name.include?("self_attn.rotary_emb.inv_freq") || + key_name.include?("rotary_emb.inv_freq") || + key_name.include?("position_embeddings.inv_freq") + end + end + + def layers + model.layers + end + end + + Models.register("phi3small", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/phimoe.rb b/lib/mlx_lm/models/phimoe.rb new file mode 100644 index 0000000..a75d098 --- /dev/null +++ b/lib/mlx_lm/models/phimoe.rb @@ -0,0 +1,206 @@ +require_relative "rope_utils" +require_relative "switch_layers" + +module MlxLm + module Models + module PhiMoe + class ModelArgs < BaseModelArgs + field :model_type, default: "phimoe" + field :vocab_size, default: 32064 + field :hidden_size, default: 4096 + field :intermediate_size, default: 6400 + field :num_hidden_layers, default: 32 + field :num_attention_heads, default: 32 + field :num_key_value_heads, default: 8 + field :max_position_embeddings, default: 131072 + field :original_max_position_embeddings, default: 4096 + field :rms_norm_eps, default: 1e-6 + field :rope_scaling, default: nil + field :num_local_experts, default: 16 + field :num_experts_per_tok, default: 2 + field :rope_theta, default: 10_000.0 + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = dim / @n_heads + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: true) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: true) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: true) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: true) + + scaling = args.rope_scaling || {} + self.rope = SuScaledRoPE.new( + @head_dim, + base: args.rope_theta, + max_position_embeddings: args.max_position_embeddings, + original_max_position_embeddings: args.original_max_position_embeddings, + short_factor: _config_value(scaling, 
"short_factor", 1.0), + long_factor: _config_value(scaling, "long_factor", 1.0), + short_mscale: _config_value(scaling, "short_mscale"), + long_mscale: _config_value(scaling, "long_mscale") + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + + private + + def _config_value(config, key, default = nil) + return default if config.nil? + return config[key] if config.key?(key) + + config.fetch(key.to_sym, default) + end + end + + class PhiMoESparseMoeBlock < MLX::NN::Module + def initialize(args) + super() + + @hidden_dim = args.hidden_size + @ffn_dim = args.intermediate_size + @num_experts = args.num_local_experts + @top_k = args.num_experts_per_tok + + self.gate = MLX::NN::Linear.new(@hidden_dim, @num_experts, bias: false) + self.switch_mlp = SwitchLayers::SwitchGLU.new(@hidden_dim, @ffn_dim, @num_experts) + end + + def call(x) + mx = MLX::Core + + k = [@top_k, @num_experts].min + gates = gate.call(x) + inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1)) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + scores = mx.take_along_axis(gates, inds, -1) + scores = mx.softmax(scores.astype(mx.float32), -1).astype(gates.dtype) + + y = switch_mlp.call(x, inds) + mx.sum(y * mx.expand_dims(scores, -1), -2) + end + end + + class DecoderLayer < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.block_sparse_moe = PhiMoESparseMoeBlock.new(args) + self.input_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + residual = x + hidden_states = input_layernorm.call(x) + hidden_states = self_attn.call(hidden_states, mask: mask, cache: cache) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = post_attention_layernorm.call(hidden_states) + hidden_states = block_sparse_moe.call(hidden_states) + residual + hidden_states + end + end + + class PhiMoEModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) } + self.norm = MLX::NN::LayerNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, layer_idx| + h = layer.call(h, mask: mask, cache: layer_cache[layer_idx]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = PhiMoEModel.new(args) 
+ self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: true) + end + + def call(inputs, cache: nil) + lm_head.call(model.call(inputs, cache: cache)) + end + + def sanitize(weights) + return weights unless weights.key?("model.layers.0.block_sparse_moe.experts.0.w1.weight") + + mx = MLX::Core + result = weights.dup + + @args.num_hidden_layers.times do |layer_idx| + prefix = "model.layers.#{layer_idx}" + [["w1", "gate_proj"], ["w2", "down_proj"], ["w3", "up_proj"]].each do |source, target| + %w[weight scales biases].each do |param| + first_key = "#{prefix}.block_sparse_moe.experts.0.#{source}.#{param}" + next unless result.key?(first_key) + + expert_keys = (0...@args.num_local_experts).map do |expert_idx| + "#{prefix}.block_sparse_moe.experts.#{expert_idx}.#{source}.#{param}" + end + next unless expert_keys.all? { |key| result.key?(key) } + + stacked = expert_keys.map { |key| result.delete(key) } + result["#{prefix}.block_sparse_moe.switch_mlp.#{target}.#{param}"] = mx.stack(stacked) + end + end + end + + result + end + + def layers + model.layers + end + end + + Models.register("phimoe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/phixtral.rb b/lib/mlx_lm/models/phixtral.rb new file mode 100644 index 0000000..fb62ac3 --- /dev/null +++ b/lib/mlx_lm/models/phixtral.rb @@ -0,0 +1,208 @@ +require_relative "switch_layers" + +module MlxLm + module Models + module Phixtral + class ModelArgs < BaseModelArgs + field :model_type, default: "phixtral" + field :num_vocab, default: 51_200 + field :model_dim, default: 2_560 + field :num_heads, default: 32 + field :num_layers, default: 32 + field :rotary_dim, default: 32 + field :num_experts_per_tok, default: 2 + field :num_local_experts, default: 4 + end + + class RoPEAttention < MLX::NN::Module + def initialize(dims, num_heads, rotary_dim) + super() + @num_heads = num_heads + @head_dim = dims / num_heads + @scale = @head_dim**(-0.5) + + self.rope = MLX::NN::RoPE.new(rotary_dim, traditional: false) + self.wqkv = MLX::NN::Linear.new(dims, 3 * dims) + self.out_proj = MLX::NN::Linear.new(dims, dims) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, d = x.shape + + qkv = wqkv.call(x) + queries, keys, values = mx.split(qkv, [d, 2 * d], -1) + + queries = queries.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + queries = queries.astype(mx.float32) + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask).astype(values.dtype) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, d]) + out_proj.call(output) + end + end + + class MOE < MLX::NN::Module + def initialize(args, dim, hidden_dim) + super() + @num_experts = args.num_local_experts + @num_experts_per_tok = args.num_experts_per_tok + + self.switch_mlp = SwitchLayers::SwitchMLP.new( + dim, + hidden_dim, + @num_experts, + bias: true + ) + self.gate = MLX::NN::Linear.new(args.model_dim, @num_experts, bias: false) + end + + def call(x) + mx = MLX::Core + k = @num_experts_per_tok + + gates = gate.call(x) + inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1)) + take_ids = mx.array((0...k).to_a, dtype: 
mx.int32) + inds = mx.take(inds, take_ids, -1) + + scores = mx.take_along_axis(gates, inds, -1) + scores = mx.softmax(scores.astype(mx.float32), -1).astype(gates.dtype) + + y = switch_mlp.call(x, inds) + mx.sum(y * mx.expand_dims(scores, -1), -2) + end + end + + class ParallelBlock < MLX::NN::Module + def initialize(config) + super() + dims = config.model_dim + mlp_dims = dims * 4 + + self.mixer = RoPEAttention.new(dims, config.num_heads, config.rotary_dim) + self.ln = MLX::NN::LayerNorm.new(dims) + self.moe = MOE.new(config, dims, mlp_dims) + end + + def call(x, mask: nil, cache: nil) + h = ln.call(x) + attn_h = mixer.call(h, mask: mask, cache: cache) + ff_h = moe.call(h) + attn_h + ff_h + x + end + end + + class Embd < MLX::NN::Module + def initialize(config) + super() + self.wte = MLX::NN::Embedding.new(config.num_vocab, config.model_dim) + end + + def call(x) + wte.call(x) + end + end + + class TransformerDecoder < MLX::NN::Module + def initialize(config) + super() + self.embd = Embd.new(config) + self.h = Array.new(config.num_layers) { ParallelBlock.new(config) } + end + + def call(x, mask: nil, cache: nil) + hidden = embd.call(x) + layer_cache = cache || [nil] * h.length + + h.each_with_index do |layer, i| + hidden = layer.call(hidden, mask: mask, cache: layer_cache[i]) + end + + hidden + end + end + + class OutputHead < MLX::NN::Module + def initialize(config) + super() + self.ln = MLX::NN::LayerNorm.new(config.model_dim) + self.linear = MLX::NN::Linear.new(config.model_dim, config.num_vocab) + end + + def call(inputs) + linear.call(ln.call(inputs)) + end + end + + class Model < MLX::NN::Module + def initialize(config) + super() + @args = config + + self.model_type = config.model_type + self.transformer = TransformerDecoder.new(config) + self.lm_head = OutputHead.new(config) + end + + def call(x, mask: nil, cache: nil) + local_mask = mask || _create_attention_mask(x, cache) + y = transformer.call(x, mask: local_mask, cache: cache) + lm_head.call(y) + end + + def sanitize(weights) + first_key = "transformer.h.0.moe.mlp.0.fc1.weight" + return weights unless weights.key?(first_key) + + mx = MLX::Core + result = weights.dup + + @args.num_layers.times do |layer_idx| + prefix = "transformer.h.#{layer_idx}" + %w[fc1 fc2].each do |proj| + %w[weight scales biases bias].each do |suffix| + expert_keys = (0...@args.num_local_experts).map do |expert_idx| + "#{prefix}.moe.mlp.#{expert_idx}.#{proj}.#{suffix}" + end + next unless expert_keys.all? { |k| result.key?(k) } + + stacked = expert_keys.map { |k| result.delete(k) } + result["#{prefix}.moe.switch_mlp.#{proj}.#{suffix}"] = mx.stack(stacked) + end + end + end + + result + end + + def layers + transformer.h + end + + private + + def _create_attention_mask(tokens, cache) + first_cache = cache.is_a?(Array) ? 
cache[0] : cache + return first_cache.make_mask(tokens.shape[1]) if first_cache && first_cache.respond_to?(:make_mask) + return nil if tokens.shape[1] == 1 + + "causal" + end + end + + Models.register("phixtral", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/pipeline.rb b/lib/mlx_lm/models/pipeline.rb new file mode 100644 index 0000000..a82e23e --- /dev/null +++ b/lib/mlx_lm/models/pipeline.rb @@ -0,0 +1,37 @@ +module MlxLm + module Models + module PipelineMixin + attr_accessor :pipeline_rank, :pipeline_size, :start_idx, :end_idx + + def initialize(*args, **kwargs) + super(*args, **kwargs) + @pipeline_rank = 0 + @pipeline_size = 1 + @start_idx = 0 + @end_idx = nil + end + + def pipeline_layers + layers[@start_idx...@end_idx] + end + + def pipeline(group) + # Split layers in reverse so rank=0 gets the last layers and + # rank=pipeline_size-1 gets the first. + @pipeline_rank = group.rank + @pipeline_size = group.size + layers_per_rank = layers.length / @pipeline_size + extra = layers.length - (layers_per_rank * @pipeline_size) + layers_per_rank += 1 if @pipeline_rank < extra + + @start_idx = (@pipeline_size - @pipeline_rank - 1) * layers_per_rank + @end_idx = @start_idx + layers_per_rank + + self.layers = layers[0...@end_idx] + # Keep layer numbering stable for checkpoint loading. + self.layers[0...@start_idx] = Array.new(@start_idx, nil) + self + end + end + end +end diff --git a/lib/mlx_lm/models/pixtral.rb b/lib/mlx_lm/models/pixtral.rb new file mode 100644 index 0000000..8439d79 --- /dev/null +++ b/lib/mlx_lm/models/pixtral.rb @@ -0,0 +1,47 @@ +module MlxLm + module Models + module Pixtral + class ModelArgs < BaseModelArgs + field :model_type, default: "pixtral" + field :text_config + + def initialize(**kwargs) + super + @text_config ||= {} + @text_config["tie_word_embeddings"] = false + unless @text_config.key?("num_attention_heads") || @text_config.key?(:num_attention_heads) + @text_config["num_attention_heads"] = 32 + end + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = Llama::Model.new(Llama::ModelArgs.from_dict(args.text_config)) + end + + def call(inputs, cache: nil, input_embeddings: nil) + language_model.call(inputs, cache: cache) + end + + def sanitize(weights) + weights.reject do |key, _| + key == "vision_tower" || + key.start_with?("vision_tower.") || + key == "multi_modal_projector" || + key.start_with?("multi_modal_projector.") + end + end + + def layers + language_model.model.layers + end + end + + Models.register("pixtral", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/plamo.rb b/lib/mlx_lm/models/plamo.rb new file mode 100644 index 0000000..1e7035b --- /dev/null +++ b/lib/mlx_lm/models/plamo.rb @@ -0,0 +1,169 @@ +module MlxLm + module Models + module Plamo + class ModelArgs < BaseModelArgs + field :model_type, default: "plamo" + field :hidden_size + field :num_hidden_layers + field :intermediate_size + field :num_attention_heads + field :rms_norm_eps + field :vocab_size + field :n_shared_head, default: 8 + field :rope_theta, default: 10_000.0 + field :rope_traditional, default: false + end + + class Attention < MLX::NN::Module + def initialize(config) + super() + @config = config + @hidden_size = config.hidden_size + @q_num_heads = config.num_attention_heads + @head_dim = @hidden_size / @q_num_heads + @qk_dim = @head_dim + @v_dim = @head_dim + @k_num_heads = (@q_num_heads.to_f / config.n_shared_head).ceil + @v_num_heads = 
@k_num_heads + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(@hidden_size, @q_num_heads * @qk_dim, bias: false) + self.k_proj = MLX::NN::Linear.new(@hidden_size, @k_num_heads * @qk_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(@hidden_size, @v_num_heads * @v_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@q_num_heads * @v_dim, @hidden_size, bias: false) + self.rotary_emb = MLX::NN::RoPE.new( + @head_dim, + traditional: config.rope_traditional, + base: config.rope_theta, + scale: 1.0 + ) + end + + def call(hidden_states, attention_mask: nil, cache: nil) + mx = MLX::Core + bsz, q_len, _d = hidden_states.shape + + queries = q_proj.call(hidden_states) + keys = k_proj.call(hidden_states) + values = v_proj.call(hidden_states) + + queries = queries.reshape([bsz, q_len, @q_num_heads, @qk_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([bsz, q_len, @k_num_heads, @qk_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([bsz, q_len, @v_num_heads, @v_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rotary_emb.call(queries, offset: cache.offset) + keys = rotary_emb.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rotary_emb.call(queries) + keys = rotary_emb.call(keys) + end + + keys = mx.tile(keys, [1, @config.n_shared_head, 1, 1]) + values = mx.tile(values, [1, @config.n_shared_head, 1, 1]) + + output = mx.scaled_dot_product_attention( + queries, + keys, + values, + @scale, + attention_mask + ) + output = output.transpose([0, 2, 1, 3]).reshape([bsz, q_len, @q_num_heads * @v_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(config) + super() + self.gate_proj = MLX::NN::Linear.new(config.hidden_size, config.intermediate_size, bias: false) + self.up_proj = MLX::NN::Linear.new(config.hidden_size, config.intermediate_size, bias: false) + self.down_proj = MLX::NN::Linear.new(config.intermediate_size, config.hidden_size, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class DecoderLayer < MLX::NN::Module + def initialize(config) + super() + self.self_attn = Attention.new(config) + self.mlp = MLP.new(config) + self.norm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + end + + def call(hidden_states, attention_mask: nil, cache: nil) + residual = hidden_states + hidden_states = norm.call(hidden_states) + + hidden_states_sa = self_attn.call( + hidden_states, + attention_mask: attention_mask, + cache: cache + ) + hidden_states_mlp = mlp.call(hidden_states) + + residual + hidden_states_sa + hidden_states_mlp + end + end + + class PlamoModel < MLX::NN::Module + def initialize(config) + super() + self.embed_tokens = MLX::NN::Embedding.new(config.vocab_size, config.hidden_size) + self.layers = Array.new(config.num_hidden_layers) { DecoderLayer.new(config) } + self.norm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, i| + h = layer.call(h, attention_mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache) + return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if h.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(args) + 
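+          # PLaMo always uses a separate lm_head in this port; there is no embedding-tying option.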
super() + self.model_type = args.model_type + self.model = PlamoModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + lm_head.call(out) + end + + def sanitize(weights) + weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } + end + + def layers + model.layers + end + end + + Models.register("plamo", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/plamo2.rb b/lib/mlx_lm/models/plamo2.rb new file mode 100644 index 0000000..f8f0adb --- /dev/null +++ b/lib/mlx_lm/models/plamo2.rb @@ -0,0 +1,173 @@ +require_relative "falcon_h1" + +module MlxLm + module Models + module Plamo2 + class ModelArgs < FalconH1::ModelArgs + field :model_type, default: "plamo2" + field :rope_theta, default: 10_000.0 + field :tie_word_embeddings, default: true + field :hidden_size_per_head, default: nil + field :full_attention_idx, default: nil + field :mamba_d_state, default: nil + field :mamba_num_heads, default: nil + field :mamba_step, default: 2 + field :mamba_chunk_size, default: nil + field :mamba_enabled, default: true + + def initialize(**kwargs) + super + @head_dim = @hidden_size_per_head if kwargs.key?(:hidden_size_per_head) && !kwargs.key?(:head_dim) && !@hidden_size_per_head.nil? + @num_attention_heads ||= @mamba_num_heads + @num_key_value_heads ||= @num_attention_heads + @mamba_d_conv ||= 4 + @attention_window_size ||= @max_position_embeddings + @block_types ||= _to_block_types + end + + def to_falcon_h1_dict + hidden_size = @hidden_size + attention_heads = @num_attention_heads + inferred_head_dim = if !@head_dim.nil? + @head_dim + elsif !@hidden_size_per_head.nil? + @hidden_size_per_head + elsif !hidden_size.nil? && attention_heads.to_i > 0 + hidden_size / attention_heads + else + 64 + end + + { + "model_type" => @model_type, + "attention_bias" => @attention_bias, + "head_dim" => inferred_head_dim, + "hidden_size" => hidden_size, + "intermediate_size" => @intermediate_size, + "max_position_embeddings" => @max_position_embeddings, + "mamba_d_conv" => @mamba_d_conv, + "num_attention_heads" => attention_heads, + "num_hidden_layers" => @num_hidden_layers, + "num_key_value_heads" => @num_key_value_heads, + "rms_norm_eps" => @rms_norm_eps, + "rope_theta" => @rope_theta, + "vocab_size" => @vocab_size, + "tie_word_embeddings" => @tie_word_embeddings, + "attention_window_size" => @attention_window_size, + "block_types" => @block_types, + } + end + + private + + def _to_block_types + return @block_types if @block_types.is_a?(Array) && !@block_types.empty? + + count = @num_hidden_layers.to_i + return nil if count <= 0 + + if @full_attention_idx.is_a?(Array) && !@full_attention_idx.empty? + full_attention = @full_attention_idx.map(&:to_i) + return Array.new(count) { |i| full_attention.include?(i) ? "attention" : "recurrent" } + end + + return Array.new(count, "attention") unless @mamba_enabled + + step = @mamba_step.to_i + step = 2 if step <= 1 + midpoint = step / 2 + + if count <= midpoint + return Array.new(count) { |i| i == count - 1 ? "attention" : "recurrent" } + end + + Array.new(count) { |i| (i % step) == midpoint ? 
"attention" : "recurrent" } + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.wrapped_model = FalconH1::Model.new( + FalconH1::ModelArgs.from_dict(args.to_falcon_h1_dict) + ) + end + + def call(inputs, cache: nil) + wrapped_model.call(inputs, cache: cache) + end + + def sanitize(weights) + normalized = weights.is_a?(Hash) ? weights.dup : weights.to_h + _split_gate_up_proj!(normalized) + + remapped = {} + normalized.each do |key, value| + remapped[_remap_weight_key(key)] = value + end + + wrapped_model.sanitize(remapped) + end + + def layers + wrapped_model.layers + end + + def make_cache + return nil unless wrapped_model.respond_to?(:make_cache) + + wrapped_model.make_cache + end + + private + + def _split_gate_up_proj!(weights) + mx = MLX::Core + pattern = /\A(model\.layers(?:\.layers)?\.\d+\.mlp)\.gate_up_proj\.(weight|bias|scales|biases)\z/ + + weights.keys.each do |key| + match = pattern.match(key) + next unless match + + prefix = match[1] + param = match[2] + gate_up = weights.delete(key) + mid = gate_up.shape[0] / 2 + next if mid <= 0 + + gate_proj, up_proj = mx.split(gate_up, [mid], 0) + weights["#{prefix}.gate_proj.#{param}"] = gate_proj + weights["#{prefix}.up_proj.#{param}"] = up_proj + end + end + + def _remap_weight_key(key) + mapped = key.dup + mapped = mapped.gsub("model.layers.layers.", "model.layers.") + mapped = mapped.gsub("model.norm.", "model.final_layernorm.") + + mapped = mapped.gsub(/\.layers\.(\d+)\.pre_mixer_norm\./) { ".layers.#{$1}.input_layernorm." } + mapped = mapped.gsub(/\.layers\.(\d+)\.pre_mlp_norm\./) { ".layers.#{$1}.pre_ff_layernorm." } + + mapped = mapped.gsub(".mixer.conv1d.", ".mamba.conv1d.") + mapped = mapped.gsub(".mixer.in_proj.", ".mamba.in_proj.") + mapped = mapped.gsub(".mixer.out_proj.", ".mamba.out_proj.") + mapped = mapped.gsub(".mixer.qkv_proj.", ".self_attn.q_proj.") + mapped = mapped.gsub(".mixer.q_proj.", ".self_attn.q_proj.") + mapped = mapped.gsub(".mixer.k_proj.", ".self_attn.k_proj.") + mapped = mapped.gsub(".mixer.v_proj.", ".self_attn.v_proj.") + mapped = mapped.gsub(".mixer.o_proj.", ".self_attn.o_proj.") + mapped = mapped.gsub(".mlp.gate_up_proj.", ".feed_forward.gate_proj.") + mapped = mapped.gsub(".mlp.gate_proj.", ".feed_forward.gate_proj.") + mapped = mapped.gsub(".mlp.up_proj.", ".feed_forward.up_proj.") + mapped = mapped.gsub(".mlp.down_proj.", ".feed_forward.down_proj.") + mapped + end + end + + Models.register("plamo2", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/qwen.rb b/lib/mlx_lm/models/qwen.rb new file mode 100644 index 0000000..154062c --- /dev/null +++ b/lib/mlx_lm/models/qwen.rb @@ -0,0 +1,175 @@ +module MlxLm + module Models + module Qwen + class ModelArgs < BaseModelArgs + field :model_type, default: "qwen" + field :hidden_size, default: 2048 + field :num_attention_heads, default: 16 + field :num_hidden_layers, default: 24 + field :kv_channels, default: 128 + field :max_position_embeddings, default: 8192 + field :layer_norm_epsilon, default: 1e-6 + field :intermediate_size, default: 11008 + field :no_bias, default: true + field :vocab_size, default: 151936 + field :num_key_value_heads, default: nil + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + hidden_size = args.hidden_size + @num_attention_heads = args.num_attention_heads + hidden_size_per_attention_head = hidden_size / 
@num_attention_heads + + self.rotary_emb = MLX::NN::RoPE.new( + hidden_size_per_attention_head, + traditional: false + ) + + @proj_size = args.kv_channels * @num_attention_heads + + self.c_attn = MLX::NN::Linear.new(hidden_size, @proj_size * 3, bias: true) + self.c_proj = MLX::NN::Linear.new(hidden_size, @proj_size, bias: !args.no_bias) + + @head_dim = args.kv_channels + @scale = hidden_size_per_attention_head**(-0.5) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + + qkv = c_attn.call(x) + q, k, v = mx.split(qkv, [@proj_size, 2 * @proj_size], -1) + + b, l, _ = q.shape + + queries = q.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rotary_emb.call(queries, offset: cache.offset) + keys = rotary_emb.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rotary_emb.call(queries) + keys = rotary_emb.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @proj_size]) + + c_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(args) + super() + + self.w1 = MLX::NN::Linear.new( + args.hidden_size, + args.intermediate_size / 2, + bias: !args.no_bias + ) + self.w2 = MLX::NN::Linear.new( + args.hidden_size, + args.intermediate_size / 2, + bias: !args.no_bias + ) + self.c_proj = MLX::NN::Linear.new( + args.intermediate_size / 2, + args.hidden_size, + bias: !args.no_bias + ) + end + + def call(x) + a1 = w1.call(x) + a2 = w2.call(x) + c_proj.call(Activations.swiglu(a2, a1)) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + + self.ln_1 = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon) + self.attn = Attention.new(args) + self.ln_2 = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon) + self.mlp = MLP.new(args) + end + + def call(x, mask: nil, cache: nil) + residual = x + x = ln_1.call(x) + x = attn.call(x, mask: mask, cache: cache) + residual = x + residual + x = ln_2.call(residual) + x = mlp.call(x) + x + residual + end + end + + class QwenModel < MLX::NN::Module + def initialize(args) + super() + self.wte = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.h = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) } + self.ln_f = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.layer_norm_epsilon) + end + + def call(inputs, cache: nil) + x = wte.call(inputs) + layer_cache = cache || [nil] * h.length + + mask = nil + mask = "causal" if x.shape[1] > 1 + + h.each_with_index do |layer, i| + x = layer.call(x, mask: mask, cache: layer_cache[i]) + end + + ln_f.call(x) + end + end + + class Model < MLX::NN::Module + def initialize(config) + super() + @args = config + self.model_type = config.model_type + self.transformer = QwenModel.new(config) + self.lm_head = MLX::NN::Linear.new( + config.hidden_size, + config.vocab_size, + bias: !config.no_bias + ) + end + + def call(x, cache: nil) + y = transformer.call(x, cache: cache) + lm_head.call(y) + end + + def sanitize(weights) + weights.reject { |k, _| k.include?("rotary_emb.inv_freq") } + end + + def layers + transformer.h + end + end + + Models.register("qwen", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/qwen2_moe.rb 
b/lib/mlx_lm/models/qwen2_moe.rb new file mode 100644 index 0000000..dae2589 --- /dev/null +++ b/lib/mlx_lm/models/qwen2_moe.rb @@ -0,0 +1,189 @@ +require_relative "activations" +require_relative "qwen2" +require_relative "switch_layers" + +module MlxLm + module Models + module Qwen2Moe + class ModelArgs < Qwen2::ModelArgs + field :model_type, default: "qwen2_moe" + field :num_key_value_heads, default: nil + field :num_experts_per_tok + field :num_experts + field :moe_intermediate_size + field :shared_expert_intermediate_size + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + validate_rope_scaling! + end + + private + + def validate_rope_scaling! + return unless @rope_scaling + + required_keys = %w[factor type] + unless required_keys.all? { |key| _rope_scaling_has_key?(key) } + raise ArgumentError, "rope_scaling must contain keys #{required_keys}" + end + + return if _rope_scaling_value("type") == "linear" + + raise ArgumentError, "rope_scaling 'type' currently only supports 'linear'" + end + + def _rope_scaling_has_key?(key) + @rope_scaling.key?(key) || @rope_scaling.key?(key.to_sym) + end + + def _rope_scaling_value(key) + return @rope_scaling[key] if @rope_scaling.key?(key) + + @rope_scaling[key.to_sym] + end + end + + class SharedExpertMLP < MLX::NN::Module + def initialize(dim, hidden_dim) + super() + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class SparseMoeBlock < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + intermediate_size = args.moe_intermediate_size + shared_expert_intermediate_size = args.shared_expert_intermediate_size + + @num_experts = args.num_experts + @top_k = args.num_experts_per_tok + + self.gate = MLX::NN::Linear.new(dim, @num_experts, bias: false) + self.switch_mlp = SwitchLayers::SwitchGLU.new(dim, intermediate_size, @num_experts) + + self.shared_expert = SharedExpertMLP.new(dim, shared_expert_intermediate_size) + self.shared_expert_gate = MLX::NN::Linear.new(dim, 1, bias: false) + end + + def call(x) + mx = MLX::Core + + gates = gate.call(x) + gates = mx.softmax(gates.astype(mx.float32), -1).astype(gates.dtype) + + k = [@top_k, @num_experts].min + inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1)) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + scores = mx.take_along_axis(gates, inds, -1) + + y = switch_mlp.call(x, inds) + y = mx.sum(y * mx.expand_dims(scores, -1), -2) + + shared_expert_output = shared_expert.call(x) + shared_expert_output = mx.sigmoid(shared_expert_gate.call(x)) * shared_expert_output + + y + shared_expert_output + end + end + + class DecoderLayer < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Qwen2::Attention.new(args) + self.mlp = SparseMoeBlock.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class Qwen2MoeModel < MLX::NN::Module + def initialize(args) + super() + @args = args + 
self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { DecoderLayer.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, layer_idx| + h = layer.call(h, mask: mask, cache: layer_cache[layer_idx]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = Qwen2MoeModel.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + + def call(inputs, cache: nil) + lm_head.call(model.call(inputs, cache: cache)) + end + + def sanitize(weights) + return weights unless weights.key?("model.layers.0.mlp.experts.0.up_proj.weight") + + mx = MLX::Core + result = weights.dup + + @args.num_hidden_layers.times do |layer_idx| + prefix = "model.layers.#{layer_idx}" + %w[up_proj down_proj gate_proj].each do |projection| + %w[weight scales biases].each do |param| + first_key = "#{prefix}.mlp.experts.0.#{projection}.#{param}" + next unless result.key?(first_key) + + expert_keys = (0...@args.num_experts).map do |expert_idx| + "#{prefix}.mlp.experts.#{expert_idx}.#{projection}.#{param}" + end + next unless expert_keys.all? { |key| result.key?(key) } + + stacked = expert_keys.map { |key| result.delete(key) } + result["#{prefix}.mlp.switch_mlp.#{projection}.#{param}"] = mx.stack(stacked) + end + end + end + + result + end + + def layers + model.layers + end + end + + Models.register("qwen2_moe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/qwen2_vl.rb b/lib/mlx_lm/models/qwen2_vl.rb new file mode 100644 index 0000000..7692987 --- /dev/null +++ b/lib/mlx_lm/models/qwen2_vl.rb @@ -0,0 +1,48 @@ +module MlxLm + module Models + module Qwen2VL + class ModelArgs < BaseModelArgs + field :model_type, default: "qwen2_vl" + field :text_config + + def self.from_dict(params) + has_text_config = params.key?("text_config") || params.key?(:text_config) + return super if has_text_config + + new(model_type: params["model_type"] || params[:model_type], text_config: params) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = Qwen2::Model.new(Qwen2::ModelArgs.from_dict(args.text_config)) + end + + def call(inputs, cache: nil, input_embeddings: nil) + language_model.call(inputs, cache: cache) + end + + def sanitize(weights) + sanitized = {} + weights.each do |key, value| + next if key == "visual" || key.start_with?("visual.") + next if key == "vision_tower" || key.start_with?("vision_tower.") + + mapped_key = key.start_with?("language_model.") ? 
key : "language_model.#{key}" + sanitized[mapped_key] = value + end + sanitized + end + + def layers + language_model.model.layers + end + end + + Models.register("qwen2_vl", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/qwen3.rb b/lib/mlx_lm/models/qwen3.rb new file mode 100644 index 0000000..ceb2b15 --- /dev/null +++ b/lib/mlx_lm/models/qwen3.rb @@ -0,0 +1,167 @@ +module MlxLm + module Models + module Qwen3 + class ModelArgs < BaseModelArgs + field :model_type, default: "qwen3" + field :hidden_size, default: 2048 + field :num_hidden_layers, default: 24 + field :intermediate_size, default: 11008 + field :num_attention_heads, default: 16 + field :rms_norm_eps, default: 1e-6 + field :vocab_size, default: 151936 + field :num_key_value_heads, default: nil + field :max_position_embeddings, default: 32768 + field :rope_theta, default: 1_000_000.0 + field :head_dim, default: nil + field :tie_word_embeddings, default: true + field :rope_scaling, default: nil + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: false) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: false) + + self.q_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.k_norm = MLX::NN::RMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + false, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = q_norm.call(queries.reshape([b, l, @n_heads, @head_dim])).transpose([0, 2, 1, 3]) + keys = k_norm.call(keys.reshape([b, l, @n_kv_heads, @head_dim])).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(dim, hidden_dim) + super() + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: false) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: false) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args.hidden_size, args.intermediate_size) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = 
MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class Qwen3Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil, input_embeddings: nil) + h = input_embeddings || embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = Qwen3Model.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil, input_embeddings: nil) + out = model.call(inputs, cache: cache, input_embeddings: input_embeddings) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + result = weights.dup + result.delete("lm_head.weight") if @args.tie_word_embeddings + result + end + + def layers + model.layers + end + end + + Models.register("qwen3", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/qwen3_5.rb b/lib/mlx_lm/models/qwen3_5.rb new file mode 100644 index 0000000..a5c8cf1 --- /dev/null +++ b/lib/mlx_lm/models/qwen3_5.rb @@ -0,0 +1,69 @@ +module MlxLm + module Models + module Qwen35 + class ModelArgs < BaseModelArgs + field :model_type, default: "qwen3_5" + field :text_config, default: nil + + def self.from_dict(params) + has_text_config = params.key?("text_config") || params.key?(:text_config) + return super if has_text_config + + new(model_type: params["model_type"] || params[:model_type], text_config: params) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = Qwen3::Model.new(Qwen3::ModelArgs.from_dict(_text_config_for_qwen3(args))) + end + + def call(inputs, cache: nil, input_embeddings: nil) + language_model.call(inputs, cache: cache, input_embeddings: input_embeddings) + end + + def sanitize(weights) + language_model.sanitize(remap_language_model_weights(weights)) + end + + def layers + language_model.layers + end + + protected + + def remap_language_model_weights(weights) + remapped = {} + weights.each do |key, value| + next if key.start_with?("model.visual") + + mapped_key = if key.start_with?("model.language_model") + key.sub("model.language_model", "language_model.model") + elsif key.start_with?("language_model.") + key + else + "language_model.#{key}" + end + remapped[mapped_key] = value + end + remapped + end + + private + + def _text_config_for_qwen3(args) + config = {} + (args.text_config || {}).each { |key, value| config[key.to_s] = value } + config["model_type"] ||= args.model_type + config["tie_word_embeddings"] = false unless config.key?("tie_word_embeddings") + config + end + end + + Models.register("qwen3_5", Model, ModelArgs) + end + end +end diff --git 
a/lib/mlx_lm/models/qwen3_5_moe.rb b/lib/mlx_lm/models/qwen3_5_moe.rb new file mode 100644 index 0000000..99f8198 --- /dev/null +++ b/lib/mlx_lm/models/qwen3_5_moe.rb @@ -0,0 +1,54 @@ +module MlxLm + module Models + module Qwen35Moe + class ModelArgs < Qwen35::ModelArgs + field :model_type, default: "qwen3_5_moe" + end + + class Model < Qwen35::Model + def sanitize(weights) + remapped = remap_language_model_weights(weights) + rewrite_moe_expert_weights(remapped) + language_model.sanitize(remapped) + end + + private + + def rewrite_moe_expert_weights(weights) + mx = MLX::Core + + layers.length.times do |layer_idx| + prefix = "language_model.model.layers.#{layer_idx}.mlp" + gate_up_key = _first_existing_key( + weights, + ["#{prefix}.experts.gate_up_proj", "#{prefix}.experts.gate_up_proj.weight"] + ) + down_proj_key = _first_existing_key( + weights, + ["#{prefix}.experts.down_proj", "#{prefix}.experts.down_proj.weight"] + ) + + next unless gate_up_key && down_proj_key + + gate_up = weights.delete(gate_up_key) + down_proj = weights.delete(down_proj_key) + mid = gate_up.shape[-2] / 2 + gate_proj, up_proj = mx.split(gate_up, [mid], -2) + + weights["#{prefix}.switch_mlp.gate_proj.weight"] = gate_proj + weights["#{prefix}.switch_mlp.up_proj.weight"] = up_proj + weights["#{prefix}.switch_mlp.down_proj.weight"] = down_proj + end + + weights + end + + def _first_existing_key(weights, candidates) + candidates.find { |key| weights.key?(key) } + end + end + + Models.register("qwen3_5_moe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/qwen3_moe.rb b/lib/mlx_lm/models/qwen3_moe.rb new file mode 100644 index 0000000..e9cf269 --- /dev/null +++ b/lib/mlx_lm/models/qwen3_moe.rb @@ -0,0 +1,166 @@ +require_relative "qwen3" +require_relative "switch_layers" + +module MlxLm + module Models + module Qwen3Moe + class ModelArgs < Qwen3::ModelArgs + field :model_type, default: "qwen3_moe" + field :num_experts, default: 128 + field :num_experts_per_tok, default: 8 + field :decoder_sparse_step, default: 1 + field :mlp_only_layers, default: [] + field :moe_intermediate_size, default: 1408 + field :norm_topk_prob, default: false + + def initialize(**kwargs) + super + @mlp_only_layers ||= [] + end + end + + class SparseMoeBlock < MLX::NN::Module + def initialize(args) + super() + @top_k = [args.num_experts_per_tok.to_i, 1].max + @num_experts = args.num_experts + @norm_topk_prob = args.norm_topk_prob + + dim = args.hidden_size + hidden_dim = args.moe_intermediate_size + + self.gate = MLX::NN::Linear.new(dim, @num_experts, bias: false) + self.switch_mlp = SwitchLayers::SwitchGLU.new(dim, hidden_dim, @num_experts) + end + + def call(x) + mx = MLX::Core + + gates = gate.call(x) + gates = mx.softmax(gates.astype(mx.float32), -1).astype(gates.dtype) + + k = [@top_k, @num_experts].min + inds = mx.stop_gradient(mx.argpartition(gates * -1.0, k - 1, -1)) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + inds = mx.take(inds, take_ids, -1) + scores = mx.take_along_axis(gates, inds, -1) + + if @norm_topk_prob + denom = mx.expand_dims(mx.sum(scores, -1), -1) + scores = scores / denom + end + + y = switch_mlp.call(x, inds) + mx.sum(y * mx.expand_dims(scores, -1), -2) + end + end + + class DecoderLayer < MLX::NN::Module + def initialize(args, layer_idx) + super() + self.self_attn = Qwen3::Attention.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + + if 
_use_sparse_moe_layer?(args, layer_idx) + self.mlp = SparseMoeBlock.new(args) + else + self.mlp = Qwen3::MLP.new(args.hidden_size, args.intermediate_size) + end + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + + private + + def _use_sparse_moe_layer?(args, layer_idx) + sparse_step = [args.decoder_sparse_step.to_i, 1].max + mlp_only_layers = args.mlp_only_layers || [] + + !mlp_only_layers.include?(layer_idx) && + args.num_experts.to_i > 0 && + ((layer_idx + 1) % sparse_step).zero? + end + end + + class Qwen3MoeModel < MLX::NN::Module + def initialize(args) + super() + @args = args + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { |layer_idx| DecoderLayer.new(args, layer_idx) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil, input_embeddings: nil) + h = input_embeddings || embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, layer_idx| + h = layer.call(h, mask: mask, cache: layer_cache[layer_idx]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = Qwen3MoeModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil, input_embeddings: nil) + out = model.call(inputs, cache: cache, input_embeddings: input_embeddings) + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + mx = MLX::Core + + result = weights.dup + result.delete("lm_head.weight") if @args.tie_word_embeddings + return result unless result.key?("model.layers.0.mlp.experts.0.up_proj.weight") + + @args.num_hidden_layers.times do |layer_idx| + prefix = "model.layers.#{layer_idx}.mlp" + %w[up_proj down_proj gate_proj].each do |projection| + expert_keys = (0...@args.num_experts).map do |expert_idx| + "#{prefix}.experts.#{expert_idx}.#{projection}.weight" + end + next unless expert_keys.all? 
{ |key| result.key?(key) } + + stacked = expert_keys.map { |key| result.delete(key) } + result["#{prefix}.switch_mlp.#{projection}.weight"] = mx.stack(stacked) + end + end + + result + end + + def layers + model.layers + end + end + + Models.register("qwen3_moe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/qwen3_next.rb b/lib/mlx_lm/models/qwen3_next.rb new file mode 100644 index 0000000..5bd69b9 --- /dev/null +++ b/lib/mlx_lm/models/qwen3_next.rb @@ -0,0 +1,147 @@ +require_relative "kimi_linear" + +module MlxLm + module Models + module Qwen3Next + class ModelArgs < KimiLinear::ModelArgs + field :model_type, default: "qwen3_next" + field :linear_num_value_heads, default: nil + field :linear_num_key_heads, default: nil + field :linear_key_head_dim, default: nil + field :linear_value_head_dim, default: nil + field :linear_conv_kernel_dim, default: nil + field :decoder_sparse_step, default: nil + field :shared_expert_intermediate_size, default: nil + field :mlp_only_layers, default: [] + field :full_attention_interval, default: 4 + field :head_dim, default: nil + field :attention_bias, default: false + field :num_shared_experts, default: 1 + field :norm_topk_prob, default: false + field :first_k_dense_replace, default: 0 + + def self.from_dict(params) + normalized = params.each_with_object({}) do |(key, value), out| + out[key.to_s] = value + end + + { + "shared_expert_intermediate_size" => "moe_shared_expert_intermediate_size", + }.each do |source_key, target_key| + next unless normalized.key?(source_key) + + normalized[target_key] = normalized[source_key] unless normalized.key?(target_key) + end + + if normalized.key?("attention_bias") + normalized["use_bias"] = normalized["attention_bias"] unless normalized.key?("use_bias") + normalized["use_qkv_bias"] = normalized["attention_bias"] unless normalized.key?("use_qkv_bias") + end + + if normalized.key?("linear_num_key_heads") && !normalized.key?("num_key_value_heads") + normalized["num_key_value_heads"] = normalized["linear_num_key_heads"] + end + + if normalized.key?("mlp_only_layers") && !normalized.key?("first_k_dense_replace") + normalized["first_k_dense_replace"] = _dense_prefix_length(normalized["mlp_only_layers"]) + end + + normalized["num_shared_experts"] = 1 unless normalized.key?("num_shared_experts") + normalized["norm_topk_prob"] = false unless normalized.key?("norm_topk_prob") + normalized["first_k_dense_replace"] = 0 unless normalized.key?("first_k_dense_replace") + normalized["model_type"] ||= "qwen3_next" + super(normalized) + end + + def initialize(**kwargs) + super + @moe_shared_expert_intermediate_size = @shared_expert_intermediate_size if kwargs.key?(:shared_expert_intermediate_size) && !kwargs.key?(:moe_shared_expert_intermediate_size) && !@shared_expert_intermediate_size.nil? + + if kwargs.key?(:attention_bias) && !@attention_bias.nil? + @use_bias = @attention_bias unless kwargs.key?(:use_bias) + @use_qkv_bias = @attention_bias unless kwargs.key?(:use_qkv_bias) + end + + if kwargs.key?(:mlp_only_layers) && !kwargs.key?(:first_k_dense_replace) + @first_k_dense_replace = self.class._dense_prefix_length(@mlp_only_layers) + end + + @num_shared_experts = 1 if @num_shared_experts.nil? + @norm_topk_prob = false if @norm_topk_prob.nil? + @first_k_dense_replace = 0 if @first_k_dense_replace.nil? 
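+ # Descriptive note (not in the original patch): fall back to the attention head count when the config does not set num_key_value_heads.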
+ @num_key_value_heads ||= @num_attention_heads + end + + def to_kimi_linear_dict + dict = to_bailing_moe_linear_dict + dict["model_type"] = @model_type + dict["num_shared_experts"] = @num_shared_experts || 1 + dict["norm_topk_prob"] = @norm_topk_prob.nil? ? false : @norm_topk_prob + dict["first_k_dense_replace"] = @first_k_dense_replace || 0 + dict["use_bias"] = @use_bias + dict["use_qkv_bias"] = @use_qkv_bias + dict["moe_shared_expert_intermediate_size"] = @moe_shared_expert_intermediate_size unless @moe_shared_expert_intermediate_size.nil? + dict + end + + def self._dense_prefix_length(mlp_only_layers) + layers = Array(mlp_only_layers).map(&:to_i) + count = 0 + count += 1 while layers.include?(count) + count + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.wrapped_model = KimiLinear::Model.new( + KimiLinear::ModelArgs.from_dict(args.to_kimi_linear_dict) + ) + end + + def call(inputs, cache: nil) + wrapped_model.call(inputs, cache: cache) + end + + def sanitize(weights) + remapped = {} + flat_weights = weights.is_a?(Hash) ? weights : weights.to_h + flat_weights.each do |key, value| + mapped = key.to_s.gsub(".mlp.shared_expert.", ".mlp.shared_experts.") + next if mapped.include?(".mtp.") + + remapped[mapped] = value + end + wrapped_model.sanitize(remapped) + end + + def layers + wrapped_model.layers + end + + def make_cache + return wrapped_model.make_cache if wrapped_model.respond_to?(:make_cache) + + nil + end + + def cast_predicate + return wrapped_model.cast_predicate if wrapped_model.respond_to?(:cast_predicate) + + lambda { |_key| true } + end + + def quant_predicate + return wrapped_model.quant_predicate if wrapped_model.respond_to?(:quant_predicate) + + lambda { |_key, _value| true } + end + end + + Models.register("qwen3_next", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/qwen3_vl.rb b/lib/mlx_lm/models/qwen3_vl.rb new file mode 100644 index 0000000..b763d2b --- /dev/null +++ b/lib/mlx_lm/models/qwen3_vl.rb @@ -0,0 +1,48 @@ +module MlxLm + module Models + module Qwen3VL + class ModelArgs < BaseModelArgs + field :model_type, default: "qwen3_vl" + field :text_config, default: nil + + def self.from_dict(params) + return super if params.key?("text_config") + + new(model_type: params["model_type"], text_config: params) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = Qwen3::Model.new(Qwen3::ModelArgs.from_dict(args.text_config)) + end + + def call(inputs, cache: nil, input_embeddings: nil) + language_model.call(inputs, cache: cache, input_embeddings: input_embeddings) + end + + def sanitize(weights) + nested = MLX::Utils.tree_unflatten(weights.to_a) + nested.delete("vision_tower") if nested.is_a?(Hash) + + flattened = MLX::Utils.tree_flatten(nested, destination: {}) + sanitized = {} + flattened.each do |key, value| + sanitized_key = key.start_with?("language_model.") ? 
key : "language_model.#{key}" + sanitized[sanitized_key] = value + end + sanitized + end + + def layers + language_model.layers + end + end + + Models.register("qwen3_vl", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/qwen3_vl_moe.rb b/lib/mlx_lm/models/qwen3_vl_moe.rb new file mode 100644 index 0000000..bac9f73 --- /dev/null +++ b/lib/mlx_lm/models/qwen3_vl_moe.rb @@ -0,0 +1,92 @@ +require_relative "qwen3_moe" + +module MlxLm + module Models + module Qwen3VLMoe + class ModelArgs < BaseModelArgs + field :model_type, default: "qwen3_vl_moe" + field :text_config, default: nil + + def self.from_dict(params) + has_text_config = params.key?("text_config") || params.key?(:text_config) + return super if has_text_config + + new(model_type: params["model_type"] || params[:model_type], text_config: params) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.language_model = Qwen3Moe::Model.new(Qwen3Moe::ModelArgs.from_dict(args.text_config)) + end + + def call(inputs, cache: nil, input_embeddings: nil) + language_model.call(inputs, cache: cache, input_embeddings: input_embeddings) + end + + def sanitize(weights) + nested = MLX::Utils.tree_unflatten(weights.to_a) + nested.delete("visual") if nested.is_a?(Hash) + + language_model_tree = {} + if nested.is_a?(Hash) + language_model_node = nested["language_model"] + if language_model_node.is_a?(Hash) + language_model_tree["model"] = language_model_node["model"] if language_model_node.key?("model") + language_model_tree["lm_head"] = language_model_node["lm_head"] if language_model_node.key?("lm_head") + end + end + + flattened = MLX::Utils.tree_flatten({ "language_model" => language_model_tree }, destination: {}) + sanitized = flattened.is_a?(Hash) ? 
flattened : {} + rewrite_moe_expert_weights(sanitized) + sanitized + end + + def layers + language_model.model.layers + end + + private + + def rewrite_moe_expert_weights(weights) + mx = MLX::Core + + layers.length.times do |layer_idx| + prefix = "language_model.model.layers.#{layer_idx}.mlp" + gate_up_key = _first_existing_key( + weights, + ["#{prefix}.experts.gate_up_proj", "#{prefix}.experts.gate_up_proj.weight"] + ) + down_proj_key = _first_existing_key( + weights, + ["#{prefix}.experts.down_proj", "#{prefix}.experts.down_proj.weight"] + ) + + next unless gate_up_key && down_proj_key + + gate_up = weights.delete(gate_up_key) + down_proj = weights.delete(down_proj_key) + mid = gate_up.shape[-1] / 2 + gate_proj, up_proj = mx.split(gate_up, [mid], -1) + + weights["#{prefix}.switch_mlp.gate_proj.weight"] = mx.swapaxes(gate_proj, -2, -1) + weights["#{prefix}.switch_mlp.up_proj.weight"] = mx.swapaxes(up_proj, -2, -1) + weights["#{prefix}.switch_mlp.down_proj.weight"] = mx.swapaxes(down_proj, -2, -1) + end + + weights + end + + def _first_existing_key(weights, candidates) + candidates.find { |key| weights.key?(key) } + end + end + + Models.register("qwen3_vl_moe", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/recurrent_gemma.rb b/lib/mlx_lm/models/recurrent_gemma.rb new file mode 100644 index 0000000..29e0243 --- /dev/null +++ b/lib/mlx_lm/models/recurrent_gemma.rb @@ -0,0 +1,444 @@ +require_relative "cache" + +module MlxLm + module Models + module RecurrentGemma + class ModelArgs < BaseModelArgs + field :model_type, default: "recurrent_gemma" + field :attention_bias + field :conv1d_width + field :hidden_size + field :intermediate_size + field :logits_soft_cap + field :num_attention_heads + field :num_hidden_layers + field :num_key_value_heads + field :rms_norm_eps + field :rope_theta + field :attention_window_size + field :vocab_size + field :embeddings_scale_by_sqrt_dim, default: true + field :block_types, default: nil + field :_block_types, default: nil + + def initialize(**kwargs) + super + @block_types ||= @_block_types + @block_types ||= ["recurrent", "attention"] + end + end + + class RMSNorm < MLX::NN::Module + def initialize(dims, eps: 1e-5) + super() + self.weight = MLX::Core.ones([dims]) + @eps = eps + end + + def call(x) + mx = MLX::Core + mean_sq = mx.mean(x * x, -1, keepdims: true) + norm = x * mx.rsqrt(mean_sq + @eps) + norm * (weight + 1.0) + end + end + + class RGLRU < MLX::NN::Module + def initialize(width:, num_heads:) + super() + @width = width + @num_heads = num_heads + @head_dim = @width / @num_heads + + mx = MLX::Core + self.recurrent_param = mx.zeros([@width]) + self.input_gate_weight = mx.zeros([@num_heads, @head_dim, @head_dim]) + self.input_gate_bias = mx.zeros([@num_heads, @head_dim]) + self.recurrent_gate_weight = mx.zeros([@num_heads, @head_dim, @head_dim]) + self.recurrent_gate_bias = mx.zeros([@num_heads, @head_dim]) + end + + def call(x, cache: nil) + mx = MLX::Core + b, l, _ = x.shape + + gate_x = _apply_block_linear(x, input_gate_weight, input_gate_bias, batch: b, seq: l) + gate_a = _apply_block_linear(x, recurrent_gate_weight, recurrent_gate_bias, batch: b, seq: l) + + log_a = -8.0 * gate_a * MLX::NN.softplus(recurrent_param) + a = mx.exp(log_a) + a_square = mx.exp(2.0 * log_a) + + gated_x = x * gate_x + multiplier = mx.sqrt(1.0 - a_square) + if cache.nil? 
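+ # Descriptive note (not in the original patch): with no cache this is the start of a sequence, so the first timestep has no prior recurrent state and its multiplier is forced to 1.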
+ first = mx.ones([b, 1, @width], multiplier.dtype) + if l == 1 + multiplier = first + else + rest = mx.split(multiplier, [1], 1)[1] + multiplier = mx.concatenate([first, rest], 1) + end + end + + normalized_x = gated_x * multiplier.astype(x.dtype) + _rnn_scan(normalized_x, a, cache) + end + + private + + def _apply_block_linear(h, w, b, batch:, seq:) + mx = MLX::Core + h = h.reshape([batch, seq, @num_heads, @head_dim]).transpose([0, 2, 1, 3]) + h = mx.matmul(h, w).transpose([0, 2, 1, 3]) + b + mx.sigmoid(h.reshape([batch, seq, @width])) + end + + def _rnn_scan(x, a, h0) + mx = MLX::Core + b, l, d = x.shape + + if l == 1 + if h0.nil? + return x, _slice_step(x, 0) + end + + y = a * mx.expand_dims(h0, 1) + x + return y, _slice_step(y, 0) + end + + h_t = h0 || mx.zeros([b, d], x.dtype) + ys = [] + l.times do |t| + h_t = _slice_step(a, t) * h_t + _slice_step(x, t) + ys << h_t + end + [mx.stack(ys, 1), h_t] + end + + def _slice_step(array, idx) + mx = MLX::Core + idx_arr = mx.array([idx], dtype: mx.int32) + mx.squeeze(mx.take(array, idx_arr, 1), 1) + end + end + + class RecurrentBlock < MLX::NN::Module + def initialize(width:, num_heads:, lru_width: nil, conv1d_temporal_width: 4) + super() + @width = width + @num_heads = num_heads + @lru_width = lru_width || width + @conv1d_temporal_width = conv1d_temporal_width + + self.linear_y = MLX::NN::Linear.new(width, @lru_width) + self.linear_x = MLX::NN::Linear.new(width, @lru_width) + self.linear_out = MLX::NN::Linear.new(@lru_width, width) + self.conv_1d = MLX::NN::Conv1d.new( + @lru_width, + @lru_width, + @conv1d_temporal_width, + groups: @lru_width, + bias: true, + padding: 0 + ) + self.rg_lru = RGLRU.new(width: @lru_width, num_heads: @num_heads) + end + + def call(x, cache: nil, mask: nil) + _ = mask + mx = MLX::Core + + y = MLX::NN.gelu_approx(linear_y.call(x)) + x = linear_x.call(x) + + conv_cache = _read_cache(cache, 0) + rnn_cache = _read_cache(cache, 1) + + x = if conv_cache + mx.concatenate([conv_cache, x], 1) + else + mx.pad(x, [[0, 0], [@conv1d_temporal_width - 1, 0], [0, 0]]) + end + + conv_input = x + x = conv_1d.call(x) + _write_cache(cache, 0, _tail_cache(conv_input)) + + x, last_h = rg_lru.call(x, cache: rnn_cache) + _write_cache(cache, 1, last_h) + + linear_out.call(x * y) + end + + private + + def _tail_cache(full_x) + mx = MLX::Core + n_keep = @conv1d_temporal_width - 1 + return mx.zeros([full_x.shape[0], 0, full_x.shape[2]], full_x.dtype) if n_keep <= 0 + + split_at = full_x.shape[1] - n_keep + mx.split(full_x, [split_at], 1)[1] + end + + def _read_cache(cache, idx) + if cache.is_a?(MlxLm::ArraysCache) || cache.is_a?(Array) + cache[idx] + else + nil + end + end + + def _write_cache(cache, idx, value) + return unless cache.is_a?(MlxLm::ArraysCache) || cache.is_a?(Array) + + cache[idx] = value + end + end + + class LocalAttentionBlock < MLX::NN::Module + def initialize(width:, num_heads:, window_size:) + super() + @width = width + @num_heads = num_heads + @window_size = window_size + @scale = (width / num_heads)**(-0.5) + @head_dim = @width / @num_heads + + self.q_proj = MLX::NN::Linear.new(@width, @width, bias: false) + self.k_proj = MLX::NN::Linear.new(@width, @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(@width, @head_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@width, @width, bias: true) + self.rope = MLX::NN::RoPE.new(@head_dim / 2, traditional: false) + end + + def call(x, cache: nil, mask: nil) + mx = MLX::Core + b, l, _ = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = 
v_proj.call(x) + + queries = queries.reshape([b, l, @num_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, 1, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, 1, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @width]) + o_proj.call(output) + end + end + + class MLPBlock < MLX::NN::Module + def initialize(width:, expanded_width:) + super() + hidden = expanded_width / 2 + self.up_proj = MLX::NN::Linear.new(width, hidden) + self.gate_proj = MLX::NN::Linear.new(width, hidden) + self.down_proj = MLX::NN::Linear.new(hidden, width) + end + + def call(x) + down_proj.call(MLX::NN.gelu_approx(gate_proj.call(x)) * up_proj.call(x)) + end + end + + class ResidualBlock < MLX::NN::Module + attr_reader :temporal_block_type + + def initialize( + width:, + mlp_expanded_width:, + num_heads:, + attention_window_size:, + temporal_block_type:, + lru_width: nil, + conv1d_temporal_width: 4 + ) + super() + @temporal_block_type = temporal_block_type + + self.temporal_pre_norm = RMSNorm.new(width) + self.temporal_block = if temporal_block_type == "recurrent" + RecurrentBlock.new( + width: width, + num_heads: num_heads, + lru_width: lru_width, + conv1d_temporal_width: conv1d_temporal_width + ) + else + LocalAttentionBlock.new( + width: width, + num_heads: num_heads, + window_size: attention_window_size + ) + end + + self.channel_pre_norm = RMSNorm.new(width) + self.mlp_block = MLPBlock.new(width: width, expanded_width: mlp_expanded_width) + end + + def call(x, cache: nil, mask: nil) + raw_x = x + x = temporal_block.call(temporal_pre_norm.call(raw_x), cache: cache, mask: mask) + residual = x + raw_x + x = mlp_block.call(channel_pre_norm.call(residual)) + x + residual + end + end + + class Griffin < MLX::NN::Module + attr_reader :window_size, :swa_idx + + def initialize(config) + super() + @config = config + @scale_by_sqrt_dim = config.embeddings_scale_by_sqrt_dim + + block_types = Array(config.block_types) + block_types = ["recurrent"] if block_types.empty? 
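+ # Descriptive note (not in the original patch): the layers built below cycle through block_types via block_types[i % block_types.length], interleaving recurrent and local-attention blocks.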
+ + self.embed_tokens = MLX::NN::Embedding.new(config.vocab_size, config.hidden_size) + self.layers = Array.new(config.num_hidden_layers) do |i| + ResidualBlock.new( + width: config.hidden_size, + mlp_expanded_width: config.intermediate_size, + num_heads: config.num_attention_heads, + attention_window_size: config.attention_window_size, + temporal_block_type: block_types[i % block_types.length], + lru_width: nil, + conv1d_temporal_width: config.conv1d_width + ) + end + self.final_norm = RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + + @window_size = config.attention_window_size + @swa_idx = block_types.index("attention") || 0 + end + + def call(tokens, cache: nil) + x = embed_tokens.call(tokens) + x = x * Math.sqrt(x.shape[-1]) if @scale_by_sqrt_dim + + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(x, layer_cache[@swa_idx], window_size: @window_size) + + layers.each_with_index do |block, i| + x = block.call(x, mask: mask, cache: layer_cache[i]) + end + + final_norm.call(x) + end + + private + + def _create_attention_mask(h, cache = nil, window_size: nil) + n = h.shape[1] + if cache && cache.respond_to?(:make_mask) + return cache.make_mask(n, window_size: window_size) + end + + if window_size + offset = 0 + if cache + offset = cache.offset + if cache.instance_variable_defined?(:@max_size) + max_size = cache.instance_variable_get(:@max_size) + offset = [max_size - 1, offset].min if max_size && max_size > 0 + end + end + return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size + end + return nil if n == 1 + + "causal" + end + + def _create_causal_mask(n, offset: 0, window_size: nil) + mx = MLX::Core + rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n]) + linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1]) + + mask = mx.greater_equal(linds, rinds) + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + mask + end + end + + class Model < MLX::NN::Module + attr_reader :args + + def initialize(config) + super() + @args = config + @tie_word_embeddings = false + self.model_type = config.model_type + self.model = Griffin.new(config) + self.lm_head = MLX::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false) + end + + def call(tokens, cache: nil) + mx = MLX::Core + logits = model.call(tokens, cache: cache) + logits = if @tie_word_embeddings || lm_head.nil? 
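+ # Descriptive note (not in the original patch): sanitize drops lm_head when the checkpoint has no lm_head.weight, so the token embedding is reused as the output projection here.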
+ model.embed_tokens.as_linear(logits) + else + lm_head.call(logits) + end + + c = args.logits_soft_cap + logits = mx.tanh(logits / c) * c if c && c != 0 + logits + end + + def layers + model.layers + end + + def sanitize(weights) + mx = MLX::Core + sanitized = {} + weights.each do |key, value| + current = value + if key.include?("conv_1d.weight") && value.shape[-1] != 1 + current = mx.swapaxes(value, 1, 2) + end + sanitized[key] = current + end + + unless sanitized.key?("lm_head.weight") + @tie_word_embeddings = true + self.lm_head = nil + end + + sanitized + end + + def make_cache + layers.map do |layer| + if layer.temporal_block_type == "recurrent" + MlxLm::ArraysCache.new(2) + else + MlxLm::RotatingKVCache.new(max_size: args.attention_window_size) + end + end + end + end + + Models.register("recurrent_gemma", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/rope_utils.rb b/lib/mlx_lm/models/rope_utils.rb new file mode 100644 index 0000000..047ebeb --- /dev/null +++ b/lib/mlx_lm/models/rope_utils.rb @@ -0,0 +1,316 @@ +module MlxLm + module Models + class SuScaledRoPE < MLX::NN::Module + def initialize( + dims, + base: 10_000.0, + max_position_embeddings: 131_072, + original_max_position_embeddings: 4096, + short_factor: 1.0, + long_factor: 1.0, + short_mscale: nil, + long_mscale: nil + ) + super() + mx = MLX::Core + @dim = dims + @original_max_position_embeddings = original_max_position_embeddings + + freqs = mx.power( + base.to_f, + mx.divide(mx.arange(0, dims, 2, mx.float32), dims.to_f) + ) + self._freqs = mx.multiply(mx.array(long_factor, dtype: mx.float32), freqs) + + factor = max_position_embeddings.to_f / original_max_position_embeddings + self._scale = long_mscale || if factor <= 1.0 + 1.0 + else + Math.sqrt(1 + Math.log(factor) / Math.log(original_max_position_embeddings)) + end + end + + def call(x, offset: 0) + mx = MLX::Core + x = scale_rotary_part(x, _scale) + mx.rope(x, @dim, false, nil, 1.0, offset, _freqs) + end + + private + + def scale_rotary_part(x, scale) + return x if scale == 1.0 + + mx = MLX::Core + rotary, rest = mx.split(x, [@dim], -1) + mx.concatenate([mx.multiply(rotary, scale), rest], -1) + end + end + + class Llama3RoPE < MLX::NN::Module + def initialize( + dims:, + max_position_embeddings: 2048, + traditional: false, + base: 10_000, + scaling_config: nil + ) + super() + mx = MLX::Core + + @dims = dims + @max_position_embeddings = max_position_embeddings + @traditional = traditional + + factor = config_value(scaling_config, "factor") + low_freq_factor = config_value(scaling_config, "low_freq_factor", 1.0) + high_freq_factor = config_value(scaling_config, "high_freq_factor", 4.0) + old_context_len = config_value( + scaling_config, + "original_max_position_embeddings", + 8192 + ) + + low_freq_wavelen = old_context_len.to_f / low_freq_factor + high_freq_wavelen = old_context_len.to_f / high_freq_factor + + freqs = mx.power( + base.to_f, + mx.divide(mx.arange(0, dims, 2), dims.to_f) + ) + wavelens = mx.multiply(2.0 * Math::PI, freqs) + + freqs = mx.where( + mx.greater(wavelens, low_freq_wavelen), + mx.multiply(freqs, factor), + freqs + ) + + is_medium_freq = mx.logical_and( + mx.greater(wavelens, high_freq_wavelen), + mx.less(wavelens, low_freq_wavelen) + ) + + smooth_factors = mx.divide( + mx.subtract(mx.divide(old_context_len.to_f, wavelens), low_freq_factor), + (high_freq_factor - low_freq_factor).to_f + ) + + smooth_freqs = mx.divide( + freqs, + mx.add( + mx.divide(mx.subtract(1.0, smooth_factors), factor.to_f), + smooth_factors + ) + ) + + 
self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) + end + + def extra_repr + "#{@dims}, traditional=#{@traditional}, max_position_embeddings=#{@max_position_embeddings}" + end + + def call(x, offset: 0) + MLX::Core.rope(x, @dims, @traditional, nil, 1.0, offset, _freqs) + end + + private + + def config_value(config, key, default = nil) + return default if config.nil? + return config[key] if config.key?(key) + + config.fetch(key.to_sym, default) + end + end + + class YarnRoPE < MLX::NN::Module + def initialize( + dims, + traditional: false, + max_position_embeddings: 2048, + base: 10_000, + scaling_factor: 1.0, + original_max_position_embeddings: 4096, + beta_fast: 32, + beta_slow: 1, + mscale: 1, + mscale_all_dim: 0 + ) + super() + mx = MLX::Core + + self.mscale = yarn_get_mscale(scaling_factor, mscale) / + yarn_get_mscale(scaling_factor, mscale_all_dim) + + freq_extra = mx.power( + base.to_f, + mx.divide(mx.arange(0, dims, 2, mx.float32), dims.to_f) + ) + freq_inter = mx.multiply(scaling_factor.to_f, freq_extra) + + low, high = yarn_find_correction_range( + dims, + base, + original_max_position_embeddings, + beta_fast, + beta_slow + ) + + freq_mask = mx.subtract(1.0, yarn_linear_ramp_mask(low, high, dims / 2)) + self._freqs = mx.divide( + mx.multiply(freq_inter, freq_extra), + mx.add( + mx.multiply(freq_inter, freq_mask), + mx.multiply(freq_extra, mx.subtract(1.0, freq_mask)) + ) + ) + + @dims = dims + @traditional = traditional + end + + def call(x, offset: 0) + mx = MLX::Core + x = scale_rotary_part(x, mscale) unless mscale == 1.0 + + mx.rope(x, @dims, @traditional, nil, 1.0, offset, _freqs) + end + + private + + def scale_rotary_part(x, scale) + mx = MLX::Core + rotary, rest = mx.split(x, [@dims], -1) + mx.concatenate([mx.multiply(rotary, scale), rest], -1) + end + + def yarn_find_correction_dim(dims, base, original_max_position_embeddings, num_rotations) + dims * Math.log(original_max_position_embeddings.to_f / (num_rotations * 2 * Math::PI)) / + (2 * Math.log(base)) + end + + def yarn_find_correction_range(dims, base, original_max_position_embeddings, beta_fast, beta_slow) + low = yarn_find_correction_dim(dims, base, original_max_position_embeddings, beta_fast).floor + high = yarn_find_correction_dim(dims, base, original_max_position_embeddings, beta_slow).ceil + [ + [low, 0].max, + [high, dims - 1].min, + ] + end + + def yarn_get_mscale(scale = 1, mscale = 1) + return 1.0 if scale <= 1 + + 0.1 * mscale * Math.log(scale) + 1.0 + end + + def yarn_linear_ramp_mask(min_val, max_val, dim) + mx = MLX::Core + + max_val += 0.001 if min_val == max_val + + linear = mx.divide( + mx.subtract(mx.arange(0, dim, 1, mx.float32), min_val), + max_val - min_val + ) + mx.clip(linear, 0.0, 1.0) + end + end + + module_function + + def initialize_rope( + dims, + base, + traditional, + scaling_config = nil, + max_position_embeddings: nil + ) + rope_type = if scaling_config + config_value(scaling_config, "type") || + config_value(scaling_config, "rope_type", "default") + else + "default" + end + + case rope_type + when "default", "linear" + scale = rope_type == "linear" ? 
1.0 / config_value(scaling_config, "factor") : 1.0 + MLX::NN::RoPE.new(dims, traditional: traditional, base: base, scale: scale) + when "llama3" + Llama3RoPE.new( + dims: dims, + max_position_embeddings: max_position_embeddings, + traditional: traditional, + base: base, + scaling_config: scaling_config + ) + when "yarn", "deepseek_yarn", "telechat3-yarn" + rope_kwargs = {} + %w[ + original_max_position_embeddings + beta_fast + beta_slow + mscale + mscale_all_dim + ].each do |key| + value = config_value(scaling_config, key) + rope_kwargs[key.to_sym] = value unless value.nil? + end + + YarnRoPE.new( + dims, + max_position_embeddings: max_position_embeddings, + traditional: traditional, + scaling_factor: config_value(scaling_config, "factor"), + base: base, + **rope_kwargs + ) + when "longrope" + SuScaledRoPE.new( + dims, + base: base, + max_position_embeddings: max_position_embeddings, + original_max_position_embeddings: config_value( + scaling_config, + "original_max_position_embeddings" + ), + short_factor: config_value(scaling_config, "short_factor"), + long_factor: config_value(scaling_config, "long_factor") + ) + when "mrope" + mrope_section = config_value(scaling_config, "mrope_section", []) + unless mrope_section.length == 3 + raise ArgumentError, + "MRoPE currently only supports 3 sections, got #{mrope_section.length}." + end + + MLX::NN::RoPE.new(dims, traditional: traditional, base: base) + else + raise ArgumentError, "Unsupported RoPE type #{rope_type}" + end + end + + def config_value(config, key, default = nil) + return default if config.nil? + return config[key] if config.key?(key) + + config.fetch(key.to_sym, default) + end + private_class_method :config_value + + module RoPEUtils + SuScaledRoPE = MlxLm::Models::SuScaledRoPE + Llama3RoPE = MlxLm::Models::Llama3RoPE + YarnRoPE = MlxLm::Models::YarnRoPE + + module_function + + def initialize_rope(*args, **kwargs) + MlxLm::Models.initialize_rope(*args, **kwargs) + end + end + end +end diff --git a/lib/mlx_lm/models/rwkv7.rb b/lib/mlx_lm/models/rwkv7.rb new file mode 100644 index 0000000..bad530e --- /dev/null +++ b/lib/mlx_lm/models/rwkv7.rb @@ -0,0 +1,101 @@ +require_relative "recurrent_gemma" + +module MlxLm + module Models + module Rwkv7 + class ModelArgs < BaseModelArgs + field :model_type, default: "rwkv7" + field :vocab_size + field :hidden_size + field :intermediate_size + field :norm_eps, default: 1e-5 + field :head_dim + field :num_hidden_layers + field :a_low_rank_dim, default: nil + field :v_low_rank_dim, default: nil + field :gate_low_rank_dim, default: nil + field :decay_low_rank_dim, default: nil + field :tie_word_embeddings, default: false + field :rope_theta, default: 10_000.0 + field :attention_window_size, default: 128 + field :block_types, default: nil + field :num_attention_heads, default: nil + field :num_key_value_heads, default: nil + + def initialize(**kwargs) + super + if @num_attention_heads.nil? && !@hidden_size.nil? && !@head_dim.nil? 
&& @head_dim.to_i > 0 + @num_attention_heads = @hidden_size / @head_dim + end + @num_attention_heads ||= 1 + @num_key_value_heads ||= @num_attention_heads + @block_types ||= Array.new(@num_hidden_layers.to_i, "recurrent") + end + + def to_recurrent_gemma_dict + { + "model_type" => @model_type, + "attention_bias" => false, + "conv1d_width" => 3, + "hidden_size" => @hidden_size, + "intermediate_size" => @intermediate_size, + "logits_soft_cap" => nil, + "num_attention_heads" => @num_attention_heads, + "num_hidden_layers" => @num_hidden_layers, + "num_key_value_heads" => @num_key_value_heads, + "rms_norm_eps" => @norm_eps, + "rope_theta" => @rope_theta, + "attention_window_size" => @attention_window_size, + "vocab_size" => @vocab_size, + "embeddings_scale_by_sqrt_dim" => false, + "block_types" => @block_types, + } + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.wrapped_model = RecurrentGemma::Model.new( + RecurrentGemma::ModelArgs.from_dict(args.to_recurrent_gemma_dict) + ) + end + + def call(inputs, cache: nil) + wrapped_model.call(inputs, cache: cache) + end + + def sanitize(weights) + remapped = {} + weights.each do |key, value| + remapped[_remap_weight_key(key)] = value + end + wrapped_model.sanitize(remapped) + end + + def layers + wrapped_model.layers + end + + def make_cache + return wrapped_model.make_cache if wrapped_model.respond_to?(:make_cache) + + nil + end + + private + + def _remap_weight_key(key) + mapped = key.dup + mapped = mapped.gsub(/\Ablocks\./, "model.layers.") + mapped = mapped.gsub(".time_mix.", ".temporal_block.") + mapped + end + end + + Models.register("rwkv7", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/seed_oss.rb b/lib/mlx_lm/models/seed_oss.rb new file mode 100644 index 0000000..c420538 --- /dev/null +++ b/lib/mlx_lm/models/seed_oss.rb @@ -0,0 +1,167 @@ +module MlxLm + module Models + module SeedOSS + class ModelArgs < BaseModelArgs + field :model_type, default: "seed_oss" + field :hidden_size, default: 4096 + field :num_hidden_layers, default: 32 + field :intermediate_size, default: 11008 + field :num_attention_heads, default: 32 + field :rms_norm_eps, default: 1e-6 + field :vocab_size, default: 151936 + field :num_key_value_heads, default: nil + field :head_dim, default: nil + field :max_position_embeddings, default: nil + field :attention_bias, default: false + field :attention_out_bias, default: false + field :mlp_bias, default: false + field :rope_theta, default: 10000.0 + field :rope_traditional, default: false + field :rope_scaling, default: nil + field :tie_word_embeddings, default: true + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + end + end + + class Attention < MLX::NN::Module + def initialize(args) + super() + + dim = args.hidden_size + @n_heads = args.num_attention_heads + @n_kv_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + input_bias = args.attention_bias + output_bias = args.attention_out_bias + + self.q_proj = MLX::NN::Linear.new(dim, @n_heads * @head_dim, bias: input_bias) + self.k_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: input_bias) + self.v_proj = MLX::NN::Linear.new(dim, @n_kv_heads * @head_dim, bias: input_bias) + self.o_proj = MLX::NN::Linear.new(@n_heads * @head_dim, dim, bias: output_bias) + + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + 
args.rope_traditional, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x).reshape([b, l, @n_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = k_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = v_proj.call(x).reshape([b, l, @n_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @n_heads * @head_dim]) + o_proj.call(output) + end + end + + class MLP < MLX::NN::Module + def initialize(dim, hidden_dim, bias: false) + super() + self.gate_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias) + self.down_proj = MLX::NN::Linear.new(hidden_dim, dim, bias: bias) + self.up_proj = MLX::NN::Linear.new(dim, hidden_dim, bias: bias) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class TransformerBlock < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Attention.new(args) + self.mlp = MLP.new(args.hidden_size, args.intermediate_size, bias: args.mlp_bias) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class SeedModel < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { TransformerBlock.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = SeedModel.new(args) + self.tie_word_embeddings = args.tie_word_embeddings + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + h = model.call(inputs, cache: cache) + if tie_word_embeddings + model.embed_tokens.as_linear(h) + else + lm_head.call(h) + end + end + + def sanitize(weights) + weights.delete("lm_head.weight") if tie_word_embeddings + weights + end + + def layers + model.layers + end + end + + Models.register("seed_oss", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/smollm3.rb b/lib/mlx_lm/models/smollm3.rb new file mode 100644 index 0000000..49f0f16 --- /dev/null +++ b/lib/mlx_lm/models/smollm3.rb @@ -0,0 +1,89 @@ +module MlxLm + module Models + module SmolLM3 + class ModelArgs < Llama::ModelArgs + field :model_type, default: "smollm3" + field :no_rope_layer_interval, default: 4 + field 
:no_rope_layers, default: nil + + def initialize(**kwargs) + super + + if @no_rope_layers.nil? + @no_rope_layers = Array.new(@num_hidden_layers) do |i| + ((i + 1) % @no_rope_layer_interval).zero? ? 0 : 1 + end + elsif @no_rope_layers.length != @num_hidden_layers + raise ArgumentError, "`no_rope_layers` length mismatch" + end + end + end + + class NoPE < MLX::NN::Module + def call(x, offset: 0) + x + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = Llama::LlamaModel.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + + args.no_rope_layers.each_with_index do |use_rope, idx| + next if use_rope && use_rope != 0 + + model.layers[idx].self_attn.rope = NoPE.new + end + end + + def call(inputs, cache: nil, input_embeddings: nil) + out = if input_embeddings.nil? + model.call(inputs, cache: cache) + else + _call_with_input_embeddings(input_embeddings, cache) + end + + if @args.tie_word_embeddings + model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def layers + model.layers + end + + def sanitize(weights) + result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } + result.delete("lm_head.weight") if @args.tie_word_embeddings + result + end + + private + + def _call_with_input_embeddings(input_embeddings, cache) + h = input_embeddings + layer_cache = cache || [nil] * model.layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + model.layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + model.norm.call(h) + end + end + + Models.register("smollm3", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/solar_open.rb b/lib/mlx_lm/models/solar_open.rb new file mode 100644 index 0000000..7ea0b04 --- /dev/null +++ b/lib/mlx_lm/models/solar_open.rb @@ -0,0 +1,79 @@ +require_relative "deepseek" + +module MlxLm + module Models + module SolarOpen + class ModelArgs < BaseModelArgs + field :model_type, default: "solar_open" + field :vocab_size + field :hidden_size + field :intermediate_size + field :moe_intermediate_size + field :num_hidden_layers + field :num_attention_heads + field :num_key_value_heads + field :head_dim + field :n_shared_experts + field :n_routed_experts + field :routed_scaling_factor + field :num_experts_per_tok + field :first_k_dense_replace + field :norm_topk_prob + field :max_position_embeddings + field :rms_norm_eps + field :rope_theta + field :tie_word_embeddings + field :partial_rotary_factor + field :rope_scaling, default: nil + field :attention_bias, default: false + field :use_qk_norm, default: false + field :n_group, default: 1 + field :topk_group, default: 1 + field :scoring_func, default: "sigmoid" + field :topk_method, default: "noaux_tc" + end + + class Model < DeepSeek::Model + def initialize(args) + super(DeepSeek::ModelArgs.from_dict(_to_deepseek_config(args))) + self.model_type = args.model_type + end + + def sanitize(weights) + sanitized = super(weights) + mpt_prefix = "model.layers.#{@args.num_hidden_layers}" + sanitized.reject do |k, _| + k == mpt_prefix || k.start_with?("#{mpt_prefix}.") + end + end + + private + + def _to_deepseek_config(args) + { + "model_type" => args.model_type, + "vocab_size" => args.vocab_size, + "hidden_size" => args.hidden_size, + "intermediate_size" => args.intermediate_size, + "moe_intermediate_size" => args.moe_intermediate_size, + 
"num_hidden_layers" => args.num_hidden_layers, + "num_attention_heads" => args.num_attention_heads, + "num_key_value_heads" => args.num_key_value_heads, + "n_shared_experts" => args.n_shared_experts, + "n_routed_experts" => args.n_routed_experts, + "num_experts_per_tok" => args.num_experts_per_tok, + "first_k_dense_replace" => args.first_k_dense_replace, + "moe_layer_freq" => 1, + "max_position_embeddings" => args.max_position_embeddings, + "rms_norm_eps" => args.rms_norm_eps, + "rope_theta" => args.rope_theta, + "rope_scaling" => args.rope_scaling, + "attention_bias" => args.attention_bias, + } + end + end + + Models.register("solar_open", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/ssm.rb b/lib/mlx_lm/models/ssm.rb new file mode 100644 index 0000000..d45681d --- /dev/null +++ b/lib/mlx_lm/models/ssm.rb @@ -0,0 +1,162 @@ +module MlxLm + module Models + module SSM + module_function + + def compute_dt(dt, dt_bias, time_step_limit = [0.001, 100.0]) + dt = MLX::NN.softplus(dt + dt_bias) + MLX::Core.clip(dt, time_step_limit[0], time_step_limit[1]) + end + + def segsum(x, mask: nil) + mx = MLX::Core + l = x.shape[-1] + + unless mask.nil? + mask_e = mx.expand_dims(mask, 1) + x = x * mask_e + end + + x = mx.repeat(mx.expand_dims(x, -1), l, -1) + x = mx.tril(x, -1) + x_segsum = mx.cumsum(x, -2) + + unless mask.nil? + mask_e = mx.expand_dims(mask, 1) + valid = mx.multiply(mx.expand_dims(mask_e, -1), mx.expand_dims(mask_e, -2)) + x_segsum = mx.where(valid, x_segsum, -Float::INFINITY) + end + + x_segsum + end + + # Baseline implementation for SSD-SSM using explicit recurrence. + def ssm_attn( + x, + a_log, + b, + c, + d, + dt, + dt_bias, + state: nil, + time_step_limit: [0.001, 100.0], + mask: nil, + lengths: nil, + step: 256 + ) + _ = step + raise NotImplementedError, "length-aware SSM path is not implemented yet" unless lengths.nil? + + mx = MLX::Core + batch_size, seq_len, num_heads, head_dim = x.shape + _, _, num_groups, state_dim = b.shape + + repeats = num_heads / num_groups + dt = compute_dt(dt, dt_bias, time_step_limit) + dt = mx.expand_dims(dt, 0) if dt.ndim == 2 + a = mx.multiply(-1.0, mx.exp(a_log).astype(dt.dtype)) + + state ||= mx.zeros([batch_size, num_heads, head_dim, state_dim], x.dtype) + + ys = [] + seq_len.times do |t| + x_t = _slice_step(x, t) + dt_t = _slice_step(dt, t) + b_t = _slice_step(b, t) + c_t = _slice_step(c, t) + + if repeats > 1 + b_t = mx.repeat(b_t, repeats, 1) + c_t = mx.repeat(c_t, repeats, 1) + end + + decay = mx.exp(dt_t * a.reshape([1, num_heads])) + prev_state = state + state = state * decay.reshape([batch_size, num_heads, 1, 1]) + + dB = dt_t.reshape([batch_size, num_heads, 1, 1]) * b_t.reshape([batch_size, num_heads, 1, state_dim]) + state = state + x_t.reshape([batch_size, num_heads, head_dim, 1]) * dB + + y_t = (state * c_t.reshape([batch_size, num_heads, 1, state_dim])).sum(-1) + y_t = y_t + x_t * d.reshape([1, num_heads, 1]) + + unless mask.nil? 
+ m_t = _slice_step(mask, t) + m_t = m_t.reshape([batch_size, 1, 1]) + state = mx.where(m_t, state, prev_state) + y_t = mx.where(m_t, y_t, mx.zeros(y_t.shape, y_t.dtype)) + end + + ys << y_t + end + + [mx.stack(ys, 1), state] + end + + def ssm_update_kernel(*_args, **_kwargs) + raise NotImplementedError, + "SSM metal kernel path is not implemented in mlx-ruby-lm yet" + end + + def ssm_update( + hidden_states, + a_log, + b, + c, + d, + dt, + dt_bias, + state: nil, + time_step_limit: [0.001, 100.0], + mask: nil, + lengths: nil + ) + mx = MLX::Core + seq_len = hidden_states.shape[1] + + use_attn_path = seq_len > 1 || + state.nil? || + !mx.respond_to?(:metal_is_available) || + !mx.metal_is_available || + !mx.respond_to?(:default_device) || + (mx.default_device.respond_to?(:type) && mx.default_device.type != :gpu) + + if use_attn_path + return ssm_attn( + hidden_states, + a_log, + b, + c, + d, + dt, + dt_bias, + state: state, + time_step_limit: time_step_limit, + mask: mask, + lengths: lengths + ) + end + + ssm_update_kernel( + hidden_states, + a_log, + b, + c, + d, + dt, + dt_bias, + state, + time_step_limit + ) + end + + def _slice_step(array, idx) + mx = MLX::Core + tail = idx.zero? ? array : mx.split(array, [idx], 1)[1] + mx.squeeze(mx.split(tail, [1], 1)[0], 1) + end + private_class_method :_slice_step + end + end +end diff --git a/lib/mlx_lm/models/step3p5.rb b/lib/mlx_lm/models/step3p5.rb new file mode 100644 index 0000000..5ac9b37 --- /dev/null +++ b/lib/mlx_lm/models/step3p5.rb @@ -0,0 +1,479 @@ +require_relative "activations" +require_relative "cache" +require_relative "rope_utils" +require_relative "switch_layers" + +module MlxLm + module Models + module Step3p5 + def self.clamped_swiglu(x, gate, limit) + mx = MLX::Core + clipped_gate = mx.minimum(MLX::NN.silu(gate), limit) + clipped_x = mx.clip(x, -limit, limit) + clipped_gate * clipped_x + end + + class ModelArgs < BaseModelArgs + field :model_type, default: "step3p5" + field :hidden_size + field :num_hidden_layers + field :vocab_size + field :num_attention_heads + field :num_attention_groups + field :head_dim + field :intermediate_size + field :rms_norm_eps, default: 1e-5 + field :rope_theta, default: 10_000.0 + field :rope_scaling, default: nil + field :max_position_embeddings, default: 262_144 + field :sliding_window, default: 512 + field :layer_types, default: nil + field :yarn_only_types, default: nil + field :partial_rotary_factors, default: nil + field :attention_other_setting, default: nil + field :use_head_wise_attn_gate, default: true + field :moe_num_experts, default: 288 + field :moe_top_k, default: 8 + field :moe_intermediate_size, default: 1280 + field :share_expert_dim, default: 1280 + field :moe_layers_enum, default: nil + field :moe_router_scaling_factor, default: 3.0 + field :norm_expert_weight, default: true + field :swiglu_limits, default: nil + field :swiglu_limits_shared, default: nil + field :tie_word_embeddings, default: false + end + + class ZeroCenteredRMSNorm < MLX::NN::Module + def initialize(dims, eps: 1e-5) + super() + self.weight = MLX::Core.ones([dims]) + @eps = eps + end + + def call(x) + mx = MLX::Core + mean_sq = mx.mean(x * x, -1, keepdims: true) + (x * mx.rsqrt(mean_sq + @eps)) * weight + end + end + + class Step3p5MLP < MLX::NN::Module + def initialize(args, intermediate_size:, swiglu_limit: 0) + super() + @hidden_size = args.hidden_size + @intermediate_size = intermediate_size + + self.gate_proj = MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: false) + self.up_proj = 
MLX::NN::Linear.new(@hidden_size, @intermediate_size, bias: false) + self.down_proj = MLX::NN::Linear.new(@intermediate_size, @hidden_size, bias: false) + + @limit = swiglu_limit && swiglu_limit > 0 ? swiglu_limit : nil + end + + def call(x) + if @limit + return down_proj.call( + Step3p5.clamped_swiglu(up_proj.call(x), gate_proj.call(x), @limit) + ) + end + + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class Step3p5MoEGate < MLX::NN::Module + def initialize(args) + super() + @top_k = args.moe_top_k + @n_routed_experts = args.moe_num_experts + @routed_scaling_factor = args.moe_router_scaling_factor + @norm_topk_prob = args.norm_expert_weight + + self.gate = MLX::NN::Linear.new(args.hidden_size, @n_routed_experts, bias: false) + self.router_bias = MLX::Core.zeros([@n_routed_experts]) + end + + def call(x) + _moe_gate_select(gate.call(x)) + end + + private + + def _moe_gate_select(gates) + mx = MLX::Core + scores = mx.sigmoid(gates.astype(mx.float32)) + corrected_scores = scores + router_bias + + k = [[@top_k.to_i, 1].max, @n_routed_experts].min + topk_indices = mx.argpartition(corrected_scores * -1.0, k - 1, -1) + take_ids = mx.array((0...k).to_a, dtype: mx.int32) + topk_indices = mx.take(topk_indices, take_ids, -1) + topk_weights = mx.take_along_axis(scores, topk_indices, -1) + + if @norm_topk_prob + topk_weights = topk_weights / (mx.expand_dims(mx.sum(topk_weights, -1), -1) + 1e-20) + end + + [topk_indices, topk_weights * @routed_scaling_factor] + end + end + + class Step3p5MoE < MLX::NN::Module + def initialize(args, layer_idx) + super() + swiglu_limit = _limit_at(args.swiglu_limits, layer_idx) + swiglu_limit_shared = _limit_at(args.swiglu_limits_shared, layer_idx) + + self.gate = Step3p5MoEGate.new(args) + self.switch_mlp = SwitchLayers::SwitchGLU.new( + args.hidden_size, + args.moe_intermediate_size, + args.moe_num_experts + ) + self.share_expert = Step3p5MLP.new( + args, + intermediate_size: args.share_expert_dim, + swiglu_limit: swiglu_limit_shared + ) + + @swiglu_limit = swiglu_limit + end + + def call(x) + mx = MLX::Core + topk_indices, topk_weights = gate.call(x) + + routed_output = switch_mlp.call(x, topk_indices) + routed_output = mx.sum(routed_output * mx.expand_dims(topk_weights, -1), -2).astype(routed_output.dtype) + routed_output + share_expert.call(x) + end + + private + + def _limit_at(values, idx) + arr = Array(values) + return 0 unless idx < arr.length + + arr[idx] || 0 + end + end + + class Step3p5Attention < MLX::NN::Module + attr_reader :is_sliding + + def initialize(args, layer_idx) + super() + dim = args.hidden_size + layer_types = Array(args.layer_types) + + @is_sliding = if layer_types.empty? + layer_idx.even? 
+ else + layer_types[layer_idx] == "sliding_attention" + end + + if @is_sliding && args.attention_other_setting + settings = args.attention_other_setting + @num_heads = _cfg_value(settings, "num_attention_heads", args.num_attention_heads) + @num_kv_heads = _cfg_value(settings, "num_attention_groups", args.num_attention_groups) + else + @num_heads = args.num_attention_heads + @num_kv_heads = args.num_attention_groups + end + + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new(dim, @num_heads * @head_dim, bias: false) + self.k_proj = MLX::NN::Linear.new(dim, @num_kv_heads * @head_dim, bias: false) + self.v_proj = MLX::NN::Linear.new(dim, @num_kv_heads * @head_dim, bias: false) + self.o_proj = MLX::NN::Linear.new(@num_heads * @head_dim, dim, bias: false) + + self.q_norm = ZeroCenteredRMSNorm.new(@head_dim, eps: args.rms_norm_eps) + self.k_norm = ZeroCenteredRMSNorm.new(@head_dim, eps: args.rms_norm_eps) + + @use_head_wise_attn_gate = args.use_head_wise_attn_gate + self.g_proj = MLX::NN::Linear.new(dim, @num_heads, bias: false) if @use_head_wise_attn_gate + + rope_theta = args.rope_theta + if rope_theta.is_a?(Array) + rope_theta = rope_theta[layer_idx] || rope_theta[0] + end + + partial_rotary_factor = _partial_rotary_factor(args.partial_rotary_factors, layer_idx) + rope_dims = (@head_dim * partial_rotary_factor).to_i + rope_dims = 1 if rope_dims < 1 + + yarn_only_types = Array(args.yarn_only_types) + layer_type = layer_types.empty? ? "full_attention" : layer_types[layer_idx] + rope_scaling = if !yarn_only_types.empty? && !yarn_only_types.include?(layer_type) + nil + else + args.rope_scaling + end + + self.rope = MlxLm::Models.initialize_rope( + rope_dims, + rope_theta, + false, + rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _ = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = q_norm.call(queries.reshape([b, l, @num_heads, @head_dim])).transpose([0, 2, 1, 3]) + keys = k_norm.call(keys.reshape([b, l, @num_kv_heads, @head_dim])).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @num_kv_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]) + + if @use_head_wise_attn_gate + output = output * mx.expand_dims(mx.sigmoid(g_proj.call(x)), -1) + end + + o_proj.call(output.reshape([b, l, @num_heads * @head_dim])) + end + + private + + def _partial_rotary_factor(factors, idx) + arr = Array(factors) + return 1.0 unless idx < arr.length + + arr[idx] || 1.0 + end + + def _cfg_value(hash, key, default = nil) + return hash[key] if hash.key?(key) + + hash.fetch(key.to_sym, default) + end + end + + class Step3p5DecoderLayer < MLX::NN::Module + attr_reader :is_sliding + + def initialize(args, layer_idx) + super() + self.self_attn = Step3p5Attention.new(args, layer_idx) + @is_sliding = self_attn.is_sliding + + moe_layers_idx = _build_moe_layers_idx(args) + is_moe_layer = moe_layers_idx[layer_idx] + + if is_moe_layer + self.mlp = Step3p5MoE.new(args, layer_idx) + else + swiglu_limit = _limit_at(args.swiglu_limits_shared, layer_idx) + self.mlp = Step3p5MLP.new( + args, + 
intermediate_size: args.intermediate_size, + swiglu_limit: swiglu_limit + ) + end + + self.input_layernorm = ZeroCenteredRMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = ZeroCenteredRMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + h + mlp.call(post_attention_layernorm.call(h)) + end + + private + + def _build_moe_layers_idx(args) + mapping = {} + if args.moe_layers_enum + args.moe_layers_enum.split(",").each do |idx| + stripped = idx.strip + next if stripped.empty? + + mapping[stripped.to_i] = true + end + else + (1...args.num_hidden_layers).each { |idx| mapping[idx] = true } + end + mapping + end + + def _limit_at(values, idx) + arr = Array(values) + return 0 unless idx < arr.length + + arr[idx] || 0 + end + end + + class Step3p5Model < MLX::NN::Module + def initialize(args) + super() + @args = args + @num_layers = args.num_hidden_layers + + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { |layer_idx| Step3p5DecoderLayer.new(args, layer_idx) } + self.norm = ZeroCenteredRMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + + @swa_idx = layers.index(&:is_sliding) + @full_idx = layers.index { |layer| !layer.is_sliding } + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * @num_layers + + full_mask = @full_idx.nil? ? nil : _create_attention_mask(h, layer_cache[@full_idx]) + swa_mask = if @swa_idx.nil? + nil + else + _create_attention_mask(h, layer_cache[@swa_idx], window_size: @args.sliding_window) + end + + layers.each_with_index do |layer, i| + mask = layer.is_sliding ? 
swa_mask : full_mask + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache = nil, window_size: nil) + n = h.shape[1] + if cache && cache.respond_to?(:make_mask) + return cache.make_mask(n, window_size: window_size) + end + + if window_size + offset = 0 + if cache + offset = cache.offset + if cache.instance_variable_defined?(:@max_size) + max_size = cache.instance_variable_get(:@max_size) + offset = [max_size - 1, offset].min if max_size && max_size > 0 + end + end + return _create_causal_mask(n, offset: offset, window_size: window_size) if offset + n > window_size + end + return nil if n == 1 + + "causal" + end + + def _create_causal_mask(n, offset: 0, window_size: nil) + mx = MLX::Core + rinds = mx.arange(0, offset + n, 1, mx.int32).reshape([1, offset + n]) + linds = mx.arange(offset, offset + n, 1, mx.int32).reshape([n, 1]) + + mask = mx.greater_equal(linds, rinds) + if window_size + mask = mx.logical_and(mask, mx.less(linds, mx.add(rinds, window_size))) + end + mask + end + end + + class Model < MLX::NN::Module + attr_reader :args + + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = Step3p5Model.new(args) + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + + def call(inputs, cache: nil) + lm_head.call(model.call(inputs, cache: cache)) + end + + def layers + model.layers + end + + def make_cache + Array.new(layers.length) { MlxLm::KVCache.new } + end + + def sanitize(weights) + remappings = [ + [".moe.gate_proj.", ".mlp.switch_mlp.gate_proj."], + [".moe.up_proj.", ".mlp.switch_mlp.up_proj."], + [".moe.down_proj.", ".mlp.switch_mlp.down_proj."], + [".moe.gate.", ".mlp.gate.gate."], + [".moe.router_bias", ".mlp.gate.router_bias"], + [".share_expert.", ".mlp.share_expert."], + ] + + is_vanilla = weights.any? do |key, _| + remappings.any? { |src, dst| key.include?(src) && !key.include?(dst) } + end + + sanitized = {} + weights.each do |key, value| + next if key.include?(".mtp") + + if (match = key.match(/model\.layers\.(\d+)\./)) && match[1].to_i >= args.num_hidden_layers + next + end + + mapped_key = key + remappings.each do |src, dst| + if mapped_key.include?(src) && !mapped_key.include?(dst) + mapped_key = mapped_key.gsub(src, dst) + break + end + end + + mapped_value = value + if is_vanilla && mapped_key.end_with?(".weight") && mapped_key.include?("norm") + mapped_value = mapped_value + 1 + end + + sanitized[mapped_key] = mapped_value + end + + sanitized + end + + def cast_predicate + ->(key) { !key.include?("router_bias") } + end + + def quant_predicate + lambda do |path, _| + return {group_size: 64, bits: 8} if path.include?("mlp.gate.gate") + + true + end + end + end + + Models.register("step3p5", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/switch_layers.rb b/lib/mlx_lm/models/switch_layers.rb new file mode 100644 index 0000000..6ee2e75 --- /dev/null +++ b/lib/mlx_lm/models/switch_layers.rb @@ -0,0 +1,221 @@ +module MlxLm + module Models + module SwitchLayers + # Gather-sort helper: reorder tokens so same-expert tokens are contiguous. + # Returns [sorted_x, sorted_indices, inv_order]. 
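+ # Illustrative trace (hypothetical values, not from the source):
+ #   indices      = [[2, 0], [1, 2]]  # 2 tokens, top_k = 2 expert ids each
+ #   flat_indices = [2, 0, 1, 2]
+ #   order        = [1, 2, 0, 3]      # positions sorted by expert id -> [0, 1, 2, 2]
+ #   token_ids    = order / m = [0, 1, 0, 1], so sorted_x gathers each token's
+ #   row once per assigned expert, grouped by expert id; scatter_unsort with
+ #   inv_order restores the original token order afterwards.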
+ def self.gather_sort(x, indices) + mx = MLX::Core + m = indices.shape[-1] + flat_indices = mx.flatten(indices) + order = mx.argsort(flat_indices) + inv_order = mx.argsort(order) + token_ids = mx.floor_divide(order, m) + sorted_x = mx.take(mx.flatten(x, 0, -3), token_ids, 0) + sorted_indices = mx.take(flat_indices, order) + [sorted_x, sorted_indices, inv_order] + end + + # Scatter-unsort helper: restore original token order after sorted computation. + def self.scatter_unsort(x, inv_order, shape = nil) + mx = MLX::Core + x = mx.take(x, inv_order, 0) + x = mx.unflatten(x, 0, shape) if shape + x + end + + # SwitchLinear: batched expert linear layer using gather_mm. + # Stacks all expert weights into a single [num_experts, output_dims, input_dims] tensor + # and dispatches via mx.gather_mm. + class SwitchLinear < MLX::NN::Module + def initialize(input_dims, output_dims, num_experts, bias: false) + super() + mx = MLX::Core + scale = Math.sqrt(1.0 / input_dims) + self.weight = mx.random_uniform( + [num_experts, output_dims, input_dims], + scale * -1.0, scale, mx.float32 + ) + self.bias = mx.zeros([num_experts, output_dims]) if bias + end + + def call(x, indices, sorted_indices: false) + mx = MLX::Core + x = mx.gather_mm( + x, + mx.swapaxes(weight, -1, -2), + nil, + indices, + sorted_indices + ) + if respond_to?(:bias) + x = x + mx.expand_dims(mx.take(bias, indices, 0), -2) + end + x + end + + def to_quantized(group_size: nil, bits: nil, mode: "affine", quantize_input: false) + raise ArgumentError, "Quantized input is not supported." if quantize_input + + QuantizedSwitchLinear.from_switch_linear(self, group_size, bits, mode: mode) + end + end + + # Quantized version of SwitchLinear using gather_qmm. + class QuantizedSwitchLinear < MLX::NN::Module + attr_reader :group_size, :bits, :mode + + def initialize(input_dims, output_dims, num_experts, bias: false, group_size: nil, bits: nil, mode: "affine") + super() + + @group_size, @bits = MLX::NN.__send__(:defaults_for_mode, mode, group_size, bits) + @mode = mode + + mx = MLX::Core + scale = Math.sqrt(1.0 / input_dims) + q_weight, q_scales, *q_biases = mx.quantize( + mx.random_uniform( + [num_experts, output_dims, input_dims], + scale * -1.0, + scale, + mx.float32 + ), + @group_size, + @bits, + @mode + ) + self.weight = q_weight + self.scales = q_scales + self.biases = q_biases.empty? ? nil : q_biases[0] + self.bias = mx.zeros([num_experts, output_dims]) if bias + + freeze + end + + def call(x, indices, sorted_indices: false) + mx = MLX::Core + q_biases = respond_to?(:biases) ? biases : nil + x = mx.gather_qmm( + x, + weight, + scales, + q_biases, + nil, + indices, + true, + @group_size, + @bits, + @mode, + sorted_indices + ) + if respond_to?(:bias) + x = x + mx.expand_dims(mx.take(bias, indices, 0), -2) + end + x + end + + def self.from_switch_linear(linear_layer, group_size = nil, bits = nil, mode: "affine") + num_experts, output_dims, input_dims = linear_layer.weight.shape + out = new( + input_dims, + output_dims, + num_experts, + bias: false, + group_size: group_size, + bits: bits, + mode: mode + ) + q_weight, q_scales, *q_biases = MLX::Core.quantize( + linear_layer.weight, + out.group_size, + out.bits, + out.mode + ) + out.weight = q_weight + out.scales = q_scales + out.biases = q_biases.empty? ? nil : q_biases[0] + out.bias = linear_layer.bias if linear_layer.state.key?("bias") + out + end + end + + # SwitchGLU: batched expert MLP with SwiGLU activation using SwitchLinear. 
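+ # Expected usage (as wired up by the MoE blocks in this diff, e.g. Step3p5MoE):
+ # switch_mlp.call(x, topk_indices), where topk_indices holds the router's
+ # per-token expert ids with shape [batch, seq, top_k].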
+ # Replaces per-token expert routing loops with gather_mm for ONNX traceability. + class SwitchGLU < MLX::NN::Module + def initialize(input_dims, hidden_dims, num_experts, bias: false) + super() + self.gate_proj = SwitchLinear.new(input_dims, hidden_dims, num_experts, bias: bias) + self.up_proj = SwitchLinear.new(input_dims, hidden_dims, num_experts, bias: bias) + self.down_proj = SwitchLinear.new(hidden_dims, input_dims, num_experts, bias: bias) + end + + def call(x, indices) + mx = MLX::Core + x = mx.expand_dims(x, [-2, -3]) + + # Sort optimization for many tokens + do_sort = indices.size >= 64 + idx = indices + inv_order = nil + + if do_sort + x, idx, inv_order = SwitchLayers.gather_sort(x, indices) + end + + idx = mx.stop_gradient(idx) if training + + x_up = up_proj.call(x, idx, sorted_indices: do_sort) + x_gate = gate_proj.call(x, idx, sorted_indices: do_sort) + + # SwiGLU activation: silu(gate) * up + x = down_proj.call( + MLX::NN.silu(x_gate) * x_up, + idx, + sorted_indices: do_sort + ) + + if do_sort + x = SwitchLayers.scatter_unsort(x, inv_order, indices.shape) + end + + mx.squeeze(x, -2) + end + end + + # Batched expert MLP with configurable activation. + class SwitchMLP < MLX::NN::Module + def initialize(input_dims, hidden_dims, num_experts, activation: nil, bias: false) + super() + self.fc1 = SwitchLinear.new(input_dims, hidden_dims, num_experts, bias: bias) + self.fc2 = SwitchLinear.new(hidden_dims, input_dims, num_experts, bias: bias) + self.activation = activation || MLX::NN::GELU.new("precise") + end + + def call(x, indices) + mx = MLX::Core + x = mx.expand_dims(x, [-2, -3]) + + # Sort optimization for many tokens + do_sort = indices.size >= 64 + idx = indices + inv_order = nil + + if do_sort + x, idx, inv_order = SwitchLayers.gather_sort(x, indices) + end + + idx = mx.stop_gradient(idx) if training + + x = fc1.call(x, idx, sorted_indices: do_sort) + x = activation.call(x) + x = fc2.call(x, idx, sorted_indices: do_sort) + + if do_sort + x = SwitchLayers.scatter_unsort(x, inv_order, indices.shape) + end + + mx.squeeze(x, -2) + end + end + end + end +end diff --git a/lib/mlx_lm/models/telechat3.rb b/lib/mlx_lm/models/telechat3.rb new file mode 100644 index 0000000..2763567 --- /dev/null +++ b/lib/mlx_lm/models/telechat3.rb @@ -0,0 +1,192 @@ +module MlxLm + module Models + module Telechat3 + class ModelArgs < BaseModelArgs + field :model_type, default: "telechat3" + field :hidden_size, default: 4096 + field :intermediate_size, default: 14336 + field :max_position_embeddings, default: 32768 + field :num_attention_heads, default: 32 + field :num_hidden_layers, default: 32 + field :num_key_value_heads, default: nil + field :rms_norm_eps, default: 1e-6 + field :vocab_size, default: 151936 + field :rope_theta, default: 10_000.0 + field :mlp_bias, default: false + field :attention_bias, default: false + field :head_dim, default: nil + field :rope_scaling, default: nil + field :tie_word_embeddings, default: false + + def initialize(**kwargs) + super + @num_key_value_heads ||= @num_attention_heads + @head_dim ||= @hidden_size / @num_attention_heads + end + end + + class Telechat3Attention < MLX::NN::Module + def initialize(args) + super() + dim = args.hidden_size + @num_attention_heads = args.num_attention_heads + @num_key_value_heads = args.num_key_value_heads + @head_dim = args.head_dim + @scale = @head_dim**(-0.5) + + self.q_proj = MLX::NN::Linear.new( + dim, + args.num_attention_heads * @head_dim, + bias: args.attention_bias + ) + self.k_proj = MLX::NN::Linear.new( + dim, + 
args.num_key_value_heads * @head_dim, + bias: args.attention_bias + ) + self.v_proj = MLX::NN::Linear.new( + dim, + args.num_key_value_heads * @head_dim, + bias: args.attention_bias + ) + self.o_proj = MLX::NN::Linear.new( + args.num_attention_heads * @head_dim, + dim, + bias: args.attention_bias + ) + + self.rope = MlxLm::Models.initialize_rope( + @head_dim, + args.rope_theta, + false, + args.rope_scaling, + max_position_embeddings: args.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + queries = q_proj.call(x) + keys = k_proj.call(x) + values = v_proj.call(x) + + queries = queries.reshape([b, l, @num_attention_heads, @head_dim]).transpose([0, 2, 1, 3]) + keys = keys.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + values = values.reshape([b, l, @num_key_value_heads, @head_dim]).transpose([0, 2, 1, 3]) + + if cache + queries = rope.call(queries, offset: cache.offset) + keys = rope.call(keys, offset: cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else + queries = rope.call(queries) + keys = rope.call(keys) + end + + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_attention_heads * @head_dim]) + o_proj.call(output) + end + end + + class Telechat3MLP < MLX::NN::Module + def initialize(args) + super() + self.gate_proj = MLX::NN::Linear.new( + args.hidden_size, + args.intermediate_size, + bias: args.mlp_bias + ) + self.down_proj = MLX::NN::Linear.new( + args.intermediate_size, + args.hidden_size, + bias: args.mlp_bias + ) + self.up_proj = MLX::NN::Linear.new( + args.hidden_size, + args.intermediate_size, + bias: args.mlp_bias + ) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class Telechat3DecoderLayer < MLX::NN::Module + def initialize(args) + super() + self.self_attn = Telechat3Attention.new(args) + self.mlp = Telechat3MLP.new(args) + self.input_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + h + mlp.call(post_attention_layernorm.call(h)) + end + end + + class Telechat3Model < MLX::NN::Module + def initialize(args) + super() + self.embed_tokens = MLX::NN::Embedding.new(args.vocab_size, args.hidden_size) + self.layers = Array.new(args.num_hidden_layers) { Telechat3DecoderLayer.new(args) } + self.norm = MLX::NN::RMSNorm.new(args.hidden_size, eps: args.rms_norm_eps) + end + + def call(inputs, cache: nil, input_embeddings: nil) + h = input_embeddings || embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + + mask = nil + mask = "causal" if h.shape[1] > 1 + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + end + + class Model < MLX::NN::Module + def initialize(args) + super() + @args = args + self.model_type = args.model_type + self.model = Telechat3Model.new(args) + unless args.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(args.hidden_size, args.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil, input_embeddings: nil) + out = model.call(inputs, cache: cache, input_embeddings: input_embeddings) + if @args.tie_word_embeddings + 
model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + result = weights.reject { |k, _| k.include?("self_attn.rotary_emb.inv_freq") } + result.delete("lm_head.weight") if @args.tie_word_embeddings + result + end + + def layers + model.layers + end + end + + Models.register("telechat3", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/models/youtu_llm.rb b/lib/mlx_lm/models/youtu_llm.rb new file mode 100644 index 0000000..448025a --- /dev/null +++ b/lib/mlx_lm/models/youtu_llm.rb @@ -0,0 +1,230 @@ +module MlxLm + module Models + module YoutuLLM + class ModelArgs < BaseModelArgs + field :model_type, default: "youtu_llm" + field :vocab_size, default: 128_256 + field :hidden_size, default: 2048 + field :intermediate_size, default: 6144 + field :num_hidden_layers, default: 32 + field :num_attention_heads, default: 16 + field :num_key_value_heads, default: 16 + field :kv_lora_rank, default: 512 + field :q_lora_rank, default: 1536 + field :qk_rope_head_dim, default: 64 + field :v_head_dim, default: 128 + field :qk_nope_head_dim, default: 128 + field :max_position_embeddings, default: 131_072 + field :rms_norm_eps, default: 1e-6 + field :rope_theta, default: 1_600_000.0 + field :rope_traditional, default: true + field :rope_scaling, default: nil + field :attention_bias, default: false + field :mlp_bias, default: false + field :tie_word_embeddings, default: true + end + + class YoutuLLMAttention < MLX::NN::Module + def initialize(config) + super() + @hidden_size = config.hidden_size + @num_heads = config.num_attention_heads + @q_lora_rank = config.q_lora_rank + @qk_rope_head_dim = config.qk_rope_head_dim + @kv_lora_rank = config.kv_lora_rank + @v_head_dim = config.v_head_dim + @qk_nope_head_dim = config.qk_nope_head_dim + @q_head_dim = @qk_nope_head_dim + @qk_rope_head_dim + @kv_head_dim = @qk_nope_head_dim + @v_head_dim + @scale = @q_head_dim**(-0.5) + + if @q_lora_rank.nil? + self.q_proj = MLX::NN::Linear.new( + @hidden_size, + @num_heads * @q_head_dim, + bias: false + ) + else + self.q_a_proj = MLX::NN::Linear.new( + @hidden_size, + @q_lora_rank, + bias: config.attention_bias + ) + self.q_a_layernorm = MLX::NN::RMSNorm.new(@q_lora_rank, eps: config.rms_norm_eps) + self.q_b_proj = MLX::NN::Linear.new(@q_lora_rank, @num_heads * @q_head_dim, bias: false) + end + + self.kv_a_proj_with_mqa = MLX::NN::Linear.new( + @hidden_size, + @kv_lora_rank + @qk_rope_head_dim, + bias: config.attention_bias + ) + self.kv_a_layernorm = MLX::NN::RMSNorm.new(@kv_lora_rank, eps: config.rms_norm_eps) + self.kv_b_proj = MLX::NN::Linear.new( + @kv_lora_rank, + @num_heads * (@q_head_dim - @qk_rope_head_dim + @v_head_dim), + bias: false + ) + + self.o_proj = MLX::NN::Linear.new( + @num_heads * @v_head_dim, + @hidden_size, + bias: config.attention_bias + ) + + self.rope = MlxLm::Models.initialize_rope( + @qk_rope_head_dim, + config.rope_theta, + config.rope_traditional, + config.rope_scaling, + max_position_embeddings: config.max_position_embeddings + ) + end + + def call(x, mask: nil, cache: nil) + mx = MLX::Core + b, l, _d = x.shape + + q = if @q_lora_rank.nil? 
+ q_proj.call(x) + else + q_b_proj.call(q_a_layernorm.call(q_a_proj.call(x))) + end + + q = q.reshape([b, l, @num_heads, @q_head_dim]).transpose([0, 2, 1, 3]) + q_nope, q_pe = mx.split(q, [@qk_nope_head_dim], -1) + + compressed_kv = kv_a_proj_with_mqa.call(x) + compressed_kv, k_pe = mx.split(compressed_kv, [@kv_lora_rank], -1) + k_pe = k_pe.reshape([b, l, 1, @qk_rope_head_dim]).transpose([0, 2, 1, 3]) + + kv = kv_b_proj.call(kv_a_layernorm.call(compressed_kv)) + kv = kv.reshape([b, l, @num_heads, @kv_head_dim]).transpose([0, 2, 1, 3]) + k_nope, values = mx.split(kv, [@qk_nope_head_dim], -1) + + if cache + q_pe = rope.call(q_pe, offset: cache.offset) + k_pe = rope.call(k_pe, offset: cache.offset) + k_pe = mx.repeat(k_pe, @num_heads, 1) + keys, values = cache.update_and_fetch(mx.concatenate([k_nope, k_pe], -1), values) + else + q_pe = rope.call(q_pe) + k_pe = rope.call(k_pe) + k_pe = mx.repeat(k_pe, @num_heads, 1) + keys = mx.concatenate([k_nope, k_pe], -1) + end + + queries = mx.concatenate([q_nope, q_pe], -1) + output = mx.scaled_dot_product_attention(queries, keys, values, @scale, mask) + output = output.transpose([0, 2, 1, 3]).reshape([b, l, @num_heads * @v_head_dim]) + o_proj.call(output) + end + end + + class YoutuLLMMLP < MLX::NN::Module + def initialize(config) + super() + self.gate_proj = MLX::NN::Linear.new( + config.hidden_size, + config.intermediate_size, + bias: config.mlp_bias + ) + self.up_proj = MLX::NN::Linear.new( + config.hidden_size, + config.intermediate_size, + bias: config.mlp_bias + ) + self.down_proj = MLX::NN::Linear.new( + config.intermediate_size, + config.hidden_size, + bias: config.mlp_bias + ) + end + + def call(x) + down_proj.call(Activations.swiglu(gate_proj.call(x), up_proj.call(x))) + end + end + + class YoutuLLMDecoderLayer < MLX::NN::Module + def initialize(config) + super() + self.self_attn = YoutuLLMAttention.new(config) + self.mlp = YoutuLLMMLP.new(config) + self.input_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + self.post_attention_layernorm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + end + + def call(x, mask: nil, cache: nil) + r = self_attn.call(input_layernorm.call(x), mask: mask, cache: cache) + h = x + r + r = mlp.call(post_attention_layernorm.call(h)) + h + r + end + end + + class YoutuLLMModel < MLX::NN::Module + def initialize(config) + super() + self.embed_tokens = MLX::NN::Embedding.new(config.vocab_size, config.hidden_size) + self.layers = Array.new(config.num_hidden_layers) { YoutuLLMDecoderLayer.new(config) } + self.norm = MLX::NN::RMSNorm.new(config.hidden_size, eps: config.rms_norm_eps) + end + + def call(inputs, cache: nil) + h = embed_tokens.call(inputs) + layer_cache = cache || [nil] * layers.length + mask = _create_attention_mask(h, layer_cache[0]) + + layers.each_with_index do |layer, i| + h = layer.call(h, mask: mask, cache: layer_cache[i]) + end + + norm.call(h) + end + + private + + def _create_attention_mask(h, cache) + return cache.make_mask(h.shape[1]) if cache && cache.respond_to?(:make_mask) + return nil if h.shape[1] == 1 + + "causal" + end + end + + class Model < MLX::NN::Module + def initialize(config) + super() + @config = config + self.model_type = config.model_type + self.model = YoutuLLMModel.new(config) + unless config.tie_word_embeddings + self.lm_head = MLX::NN::Linear.new(config.hidden_size, config.vocab_size, bias: false) + end + end + + def call(inputs, cache: nil) + out = model.call(inputs, cache: cache) + if @config.tie_word_embeddings + 
model.embed_tokens.as_linear(out) + else + lm_head.call(out) + end + end + + def sanitize(weights) + result = weights.dup + result.delete("lm_head.weight") if @config.tie_word_embeddings + result + end + + def layers + model.layers + end + end + + Models.register("youtu_llm", Model, ModelArgs) + end + end +end diff --git a/lib/mlx_lm/perplexity.rb b/lib/mlx_lm/perplexity.rb index 935101b..3964c16 100644 --- a/lib/mlx_lm/perplexity.rb +++ b/lib/mlx_lm/perplexity.rb @@ -16,7 +16,7 @@ def compute(model, tokens, batch_size: nil) def log_likelihood(model, tokens, batch_size: nil) mx = MLX::Core - token_arr = tokens.is_a?(MLX::Core::Array) ? tokens : mx.array(tokens).astype(mx.int32) + token_arr = tokens.is_a?(MLX::Core::Array) ? tokens : mx.array(tokens, dtype: mx.int32) total_tokens = token_arr.size # Process all at once for small sequences diff --git a/lib/mlx_lm/sample_utils.rb b/lib/mlx_lm/sample_utils.rb index c7b314e..dfe9eab 100644 --- a/lib/mlx_lm/sample_utils.rb +++ b/lib/mlx_lm/sample_utils.rb @@ -51,7 +51,7 @@ def apply_top_k(logprobs, top_k) mask_idx = mx.argpartition(neg_logprobs, top_k - 1, -1) # Get indices after top_k (the ones to mask) rest = mx.split(mask_idx, [top_k], -1)[1] - neg_inf = mx.array([-Float::INFINITY]).astype(logprobs.dtype) + neg_inf = mx.array([-Float::INFINITY], dtype: logprobs.dtype) mx.put_along_axis(logprobs, rest, neg_inf, -1) end @@ -73,7 +73,7 @@ def apply_min_p(logprobs, min_p, min_tokens_to_keep = 1) # Mask tokens below threshold tokens_to_remove = mx.less(sorted_logprobs, scaled_min_p) - neg_inf = mx.array(-Float::INFINITY).astype(sorted_logprobs.dtype) + neg_inf = mx.array(-Float::INFINITY, dtype: sorted_logprobs.dtype) selected_logprobs = mx.where(tokens_to_remove, neg_inf, sorted_logprobs) # Restore the top min_tokens_to_keep tokens regardless @@ -113,9 +113,9 @@ def apply_top_p(logprobs, top_p) cumulative_probs = mx.take_along_axis(cumulative_probs, inverse_indices, -1) # select tokens with cumulative probs above threshold - threshold = mx.array(1.0 - top_p).astype(cumulative_probs.dtype) + threshold = mx.array(1.0 - top_p, dtype: cumulative_probs.dtype) mask = mx.greater(cumulative_probs, threshold) - neg_inf = mx.array(-Float::INFINITY).astype(logprobs.dtype) + neg_inf = mx.array(-Float::INFINITY, dtype: logprobs.dtype) mx.where(mask, logprobs, neg_inf) end @@ -138,11 +138,11 @@ def make_repetition_penalty(penalty, context_size = 20) [] end if recent.length > 0 - token_indices = mx.array(recent).astype(mx.int32) + token_indices = mx.array(recent, dtype: mx.int32) n_tokens = recent.length idx_2d = token_indices.reshape([1, n_tokens]) selected_logits = mx.take_along_axis(logits, idx_2d, -1) - zero = mx.array(0.0).astype(selected_logits.dtype) + zero = mx.array(0.0, dtype: selected_logits.dtype) is_negative = mx.less(selected_logits, zero) selected_logits = mx.where( is_negative, diff --git a/lib/mlx_lm/weight_utils.rb b/lib/mlx_lm/weight_utils.rb index 851e2bd..cffa5c4 100644 --- a/lib/mlx_lm/weight_utils.rb +++ b/lib/mlx_lm/weight_utils.rb @@ -58,19 +58,19 @@ def _tensor_to_mlx(info, mx) elsif dtype_str == "F16" # 16-bit float: unpack as uint16, create array as float32, then view as float16 values = data.unpack("S<*") - mx.array(values).astype(mx.uint16).view(mx.float16).reshape(shape) + mx.array(values, dtype: mx.uint16).view(mx.float16).reshape(shape) elsif dtype_str == "BF16" values = data.unpack("S<*") - mx.array(values).astype(mx.uint16).view(mx.bfloat16).reshape(shape) + mx.array(values, dtype: mx.uint16).view(mx.bfloat16).reshape(shape) 
elsif dtype_str == "I32" || dtype_str == "int32" values = data.unpack("l<*") - mx.array(values).astype(mx.int32).reshape(shape) + mx.array(values, dtype: mx.int32).reshape(shape) elsif dtype_str == "I64" values = data.unpack("q<*") - mx.array(values).astype(mx.int64).reshape(shape) + mx.array(values, dtype: mx.int64).reshape(shape) elsif dtype_str == "U8" values = data.unpack("C*") - mx.array(values).astype(mx.uint8).reshape(shape) + mx.array(values, dtype: mx.uint8).reshape(shape) else # Fallback: try F32 values = data.unpack("e*") diff --git a/mlx-ruby b/mlx-ruby index 85afc8a..476f721 160000 --- a/mlx-ruby +++ b/mlx-ruby @@ -1 +1 @@ -Subproject commit 85afc8a3e3ec461e003bb240d4a544c7bfcf2208 +Subproject commit 476f72167371fa1de1044d0b0f4066145cf07a6e diff --git a/mlx-ruby-lm.gemspec b/mlx-ruby-lm.gemspec index 7e035bc..2c4bd0e 100644 --- a/mlx-ruby-lm.gemspec +++ b/mlx-ruby-lm.gemspec @@ -27,7 +27,11 @@ Gem::Specification.new do |s| s.executables = s.files.grep(%r{\Aexe/}) { |f| File.basename(f) } s.require_paths = ["lib"] - s.add_dependency "mlx", "~> 0.1" + s.add_dependency "mlx", ">= 0.30.7.5", "< 1.0" s.add_dependency "safetensors", "~> 0.2" s.add_dependency "tokenizers", "~> 0.6" + + s.add_development_dependency "minitest", "~> 5.20" + s.add_development_dependency "ostruct" + s.add_development_dependency "rake", "~> 13.0" end diff --git a/prd/2026_02_25_python_ruby_parity_checklist.md b/prd/2026_02_25_python_ruby_parity_checklist.md new file mode 100644 index 0000000..3df004f --- /dev/null +++ b/prd/2026_02_25_python_ruby_parity_checklist.md @@ -0,0 +1,794 @@ +# Python `mlx-lm` -> Ruby `mlx-ruby-lm` Class Parity Checklist + +**Status:** Active +**Date:** 2026-02-26 +**Scope:** Full Python class inventory from `mlx-lm/mlx_lm/**/*.py` +**Ruby surface:** `lib/mlx_lm/**/*.rb` + +## Summary + +| Metric | Count | +|---|---:| +| Python classes discovered | 768 | +| Implemented | 527 | +| Partial | 221 | +| Missing | 20 | + +## Status Rules + +- `Implemented`: class is present in the expected Ruby file (or uniquely implemented in a different Ruby file with a note). +- `Partial`: Ruby file exists for that Python module, but the class is absent/renamed/merged. +- `Missing`: no Ruby class/file counterpart found for that Python class. 
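+
+A minimal sketch of how these rules could be applied mechanically (a hypothetical
+helper, not the script that produced this checklist; the file path and the
+`class`/`module` regex are assumptions):
+
+```ruby
+# Classify one Python class against its expected Ruby counterpart file.
+# "Implemented" if the file defines the class, "Partial" if the file exists
+# without it, "Missing" if the file is absent.
+def parity_status(python_class, expected_ruby_file)
+  return "Missing" unless File.exist?(expected_ruby_file)
+
+  source = File.read(expected_ruby_file)
+  defined = source.match?(/^\s*(?:class|module)\s+#{Regexp.escape(python_class)}\b/)
+  defined ? "Implemented" : "Partial"
+end
+
+puts parity_status("SwitchGLU", "lib/mlx_lm/models/switch_layers.rb") # => Implemented
+```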
+ +## Full Class Inventory + +| Python File | Line | Python Class | Ruby Status | Ruby Reference | Notes | +|---|---:|---|---|---|---| +| evaluate.py | 72 | MLXLM | Missing | - | | +| generate.py | 266 | GenerationResponse | Partial | generate.rb | Ruby file exists but defines no classes | +| generate.py | 804 | BatchStats | Partial | generate.rb | Ruby file exists but defines no classes | +| generate.py | 828 | BatchResponse | Partial | generate.rb | Ruby file exists but defines no classes | +| generate.py | 843 | Batch | Partial | generate.rb | Ruby file exists but defines no classes | +| generate.py | 930 | BatchGenerator | Partial | generate.rb | Ruby file exists but defines no classes | +| generate.py | 932 | Response | Partial | generate.rb | Ruby file exists but defines no classes | +| gguf.py | 10 | TokenType | Missing | - | | +| gguf.py | 19 | GGMLFileType | Missing | - | | +| gguf.py | 24 | HfVocab | Missing | - | | +| models/Klear.py | 15 | ModelArgs | Implemented | models/klear.rb | | +| models/Klear.py | 36 | KlearAttention | Implemented | models/klear.rb | | +| models/Klear.py | 110 | KlearMLP | Implemented | models/klear.rb | | +| models/Klear.py | 121 | KlearSparseMoeBlock | Implemented | models/klear.rb | | +| models/Klear.py | 156 | KlearDecoderLayer | Implemented | models/klear.rb | | +| models/Klear.py | 186 | KlearModel | Implemented | models/klear.rb | | +| models/Klear.py | 214 | Model | Implemented | models/klear.rb | | +| models/activations.py | 25 | XieLU | Implemented | models/activations.rb | | +| models/afm7.py | 19 | ModelArgs | Implemented | models/afm7.rb | | +| models/afm7.py | 32 | FusedLoRALinear | Partial | models/afm7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/afm7.py | 96 | FusedQuantizedLinear | Partial | models/afm7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/afm7.py | 123 | FusedLinear | Partial | models/afm7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/afm7.py | 165 | Attention | Partial | models/afm7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/afm7.py | 226 | KVReuseAttention | Partial | models/afm7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/afm7.py | 266 | MLP | Partial | models/afm7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/afm7.py | 283 | TransformerBlock | Partial | models/afm7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/afm7.py | 306 | KVReuseTransformerBlock | Partial | models/afm7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/afm7.py | 330 | AFMModel | Partial | models/afm7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/afm7.py | 369 | Model | Implemented | models/afm7.rb | | +| models/afmoe.py | 18 | ModelArgs | Implemented | models/afmoe.rb | | +| models/afmoe.py | 48 | Attention | Implemented | models/afmoe.rb | | +| models/afmoe.py | 137 | MLP | Implemented | models/afmoe.rb | | +| models/afmoe.py | 156 | MoERouter | Implemented | models/afmoe.rb | | +| models/afmoe.py | 167 | AfmoeMoE | Implemented | models/afmoe.rb | | +| models/afmoe.py | 242 | DecoderLayer | Implemented | models/afmoe.rb | | +| models/afmoe.py | 278 | AfmoeModel | Implemented | models/afmoe.rb | | +| models/afmoe.py | 332 | Model | Implemented | models/afmoe.rb | | +| models/apertus.py | 16 | ModelArgs | Implemented | models/apertus.rb | | +| models/apertus.py | 36 | ApertusMLP | Implemented | models/apertus.rb | | +| models/apertus.py | 51 | 
ApertusAttention | Implemented | models/apertus.rb | | +| models/apertus.py | 117 | ApertusDecoderLayer | Implemented | models/apertus.rb | | +| models/apertus.py | 137 | ApertusModel | Implemented | models/apertus.rb | | +| models/apertus.py | 164 | Model | Implemented | models/apertus.rb | | +| models/baichuan_m1.py | 15 | ModelArgs | Implemented | models/baichuan_m1.rb | | +| models/baichuan_m1.py | 33 | Attention | Implemented | models/baichuan_m1.rb | | +| models/baichuan_m1.py | 130 | MLP | Implemented | models/baichuan_m1.rb | | +| models/baichuan_m1.py | 147 | DecoderLayer | Implemented | models/baichuan_m1.rb | | +| models/baichuan_m1.py | 166 | BaichuanModel | Implemented | models/baichuan_m1.rb | | +| models/baichuan_m1.py | 212 | Model | Implemented | models/baichuan_m1.rb | | +| models/bailing_moe.py | 17 | ModelArgs | Implemented | models/bailing_moe.rb | | +| models/bailing_moe.py | 60 | BailingMoeMLP | Implemented | models/bailing_moe.rb | | +| models/bailing_moe.py | 83 | BailingMoeAttention | Implemented | models/bailing_moe.rb | | +| models/bailing_moe.py | 202 | BailingMoeGate | Implemented | models/bailing_moe.rb | | +| models/bailing_moe.py | 234 | BailingMoeSparseMoeBlock | Implemented | models/bailing_moe.rb | | +| models/bailing_moe.py | 267 | BailingMoeDecoderLayer | Implemented | models/bailing_moe.rb | | +| models/bailing_moe.py | 296 | BailingMoeModel | Implemented | models/bailing_moe.rb | | +| models/bailing_moe.py | 324 | Model | Implemented | models/bailing_moe.rb | | +| models/bailing_moe_linear.py | 23 | ModelArgs | Implemented | models/bailing_moe_linear.rb | | +| models/bailing_moe_linear.py | 100 | GroupRMSNorm | Partial | models/bailing_moe_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/bailing_moe_linear.py | 114 | MLP | Partial | models/bailing_moe_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/bailing_moe_linear.py | 137 | Attention | Partial | models/bailing_moe_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/bailing_moe_linear.py | 213 | LinearAttention | Partial | models/bailing_moe_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/bailing_moe_linear.py | 368 | Gate | Partial | models/bailing_moe_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/bailing_moe_linear.py | 400 | SparseMoeBlock | Partial | models/bailing_moe_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/bailing_moe_linear.py | 433 | DecoderLayer | Partial | models/bailing_moe_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/bailing_moe_linear.py | 475 | LanguageModel | Partial | models/bailing_moe_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/bailing_moe_linear.py | 509 | Model | Implemented | models/bailing_moe_linear.rb | | +| models/base.py | 12 | BaseModelArgs | Implemented | model_args.rb | Implemented in different Ruby file | +| models/bitlinear_layers.py | 92 | BitLinear | Implemented | models/bitlinear_layers.rb | | +| models/bitnet.py | 16 | ModelArgs | Implemented | models/bitnet.rb | | +| models/bitnet.py | 35 | Attention | Implemented | models/bitnet.rb | | +| models/bitnet.py | 96 | MLP | Implemented | models/bitnet.rb | | +| models/bitnet.py | 120 | TransformerBlock | Implemented | models/bitnet.rb | | +| models/bitnet.py | 146 | LlamaModel | Partial | models/bitnet.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, TransformerBlock, 
BitnetModel, Model | +| models/bitnet.py | 176 | Model | Implemented | models/bitnet.rb | | +| models/cache.py | 125 | _BaseCache | Partial | models/cache.rb | Ruby file exists; classes differ: BaseCache, KVCache, RotatingKVCache, QuantizedKVCache, ArraysCache, ChunkedKVCache, CacheList | +| models/cache.py | 176 | ConcatenateKVCache | Partial | models/cache.rb | Ruby file exists; classes differ: BaseCache, KVCache, RotatingKVCache, QuantizedKVCache, ArraysCache, ChunkedKVCache, CacheList | +| models/cache.py | 230 | QuantizedKVCache | Implemented | models/cache.rb | | +| models/cache.py | 323 | KVCache | Implemented | models/cache.rb | | +| models/cache.py | 408 | RotatingKVCache | Implemented | models/cache.rb | | +| models/cache.py | 592 | ArraysCache | Implemented | models/cache.rb | | +| models/cache.py | 682 | ChunkedKVCache | Implemented | models/cache.rb | | +| models/cache.py | 765 | CacheList | Implemented | models/cache.rb | | +| models/cache.py | 863 | BatchKVCache | Partial | models/cache.rb | Ruby file exists; classes differ: BaseCache, KVCache, RotatingKVCache, QuantizedKVCache, ArraysCache, ChunkedKVCache, CacheList | +| models/cache.py | 1058 | BatchRotatingKVCache | Partial | models/cache.rb | Ruby file exists; classes differ: BaseCache, KVCache, RotatingKVCache, QuantizedKVCache, ArraysCache, ChunkedKVCache, CacheList | +| models/cohere.py | 14 | ModelArgs | Implemented | models/cohere.rb | | +| models/cohere.py | 30 | LayerNorm2D | Partial | models/cohere.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, TransformerBlock, CohereModel, Model | +| models/cohere.py | 41 | Attention | Implemented | models/cohere.rb | | +| models/cohere.py | 105 | MLP | Implemented | models/cohere.rb | | +| models/cohere.py | 116 | TransformerBlock | Implemented | models/cohere.rb | | +| models/cohere.py | 141 | CohereModel | Implemented | models/cohere.rb | | +| models/cohere.py | 174 | Model | Implemented | models/cohere.rb | | +| models/cohere2.py | 15 | ModelArgs | Implemented | models/cohere2.rb | | +| models/cohere2.py | 33 | Attention | Implemented | models/cohere2.rb | | +| models/cohere2.py | 102 | MLP | Implemented | models/cohere2.rb | | +| models/cohere2.py | 113 | TransformerBlock | Implemented | models/cohere2.rb | | +| models/cohere2.py | 140 | CohereModel | Partial | models/cohere2.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, TransformerBlock, Cohere2Model, Model | +| models/cohere2.py | 184 | Model | Implemented | models/cohere2.rb | | +| models/dbrx.py | 15 | ModelArgs | Implemented | models/dbrx.rb | | +| models/dbrx.py | 25 | Attention | Implemented | models/dbrx.rb | | +| models/dbrx.py | 85 | NormAttnNorm | Implemented | models/dbrx.rb | | +| models/dbrx.py | 103 | MLP | Implemented | models/dbrx.rb | | +| models/dbrx.py | 116 | Router | Implemented | models/dbrx.rb | | +| models/dbrx.py | 125 | SparseMoeBlock | Implemented | models/dbrx.rb | | +| models/dbrx.py | 172 | DecoderLayer | Implemented | models/dbrx.rb | | +| models/dbrx.py | 189 | DBRX | Partial | models/dbrx.rb | Ruby file exists; classes differ: ModelArgs, Attention, NormAttnNorm, MLP, Router, SparseMoeBlock, DecoderLayer, DbrxModel, Model | +| models/dbrx.py | 215 | Model | Implemented | models/dbrx.rb | | +| models/deepseek.py | 13 | ModelArgs | Implemented | models/deepseek.rb | | +| models/deepseek.py | 34 | DeepseekAttention | Partial | models/deepseek.rb | Ruby file exists; classes differ: ModelArgs, Attention, DeepseekMLP, MoEGate, DeepseekMoE, DecoderLayer, 
DeepseekModel, Model | +| models/deepseek.py | 108 | DeepseekMLP | Implemented | models/deepseek.rb | | +| models/deepseek.py | 127 | MoEGate | Implemented | models/deepseek.rb | | +| models/deepseek.py | 144 | DeepseekMoE | Implemented | models/deepseek.rb | | +| models/deepseek.py | 169 | DeepseekDecoderLayer | Partial | models/deepseek.rb | Ruby file exists; classes differ: ModelArgs, Attention, DeepseekMLP, MoEGate, DeepseekMoE, DecoderLayer, DeepseekModel, Model | +| models/deepseek.py | 200 | DeepseekModel | Implemented | models/deepseek.rb | | +| models/deepseek.py | 228 | Model | Implemented | models/deepseek.rb | | +| models/deepseek_v2.py | 18 | ModelArgs | Implemented | models/deepseek_v2.rb | | +| models/deepseek_v2.py | 82 | DeepseekV2YarnRotaryEmbedding | Partial | models/deepseek_v2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v2.py | 129 | DeepseekV2Attention | Partial | models/deepseek_v2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v2.py | 248 | DeepseekV2MLP | Partial | models/deepseek_v2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v2.py | 268 | MoEGate | Partial | models/deepseek_v2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v2.py | 304 | DeepseekV2MoE | Partial | models/deepseek_v2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v2.py | 338 | DeepseekV2DecoderLayer | Partial | models/deepseek_v2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v2.py | 369 | DeepseekV2Model | Partial | models/deepseek_v2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v2.py | 414 | Model | Implemented | models/deepseek_v2.rb | | +| models/deepseek_v3.py | 21 | ModelArgs | Implemented | models/deepseek_v3.rb | | +| models/deepseek_v3.py | 53 | DeepseekV3Attention | Partial | models/deepseek_v3.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v3.py | 172 | DeepseekV3MLP | Partial | models/deepseek_v3.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v3.py | 227 | MoEGate | Partial | models/deepseek_v3.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v3.py | 253 | DeepseekV3MoE | Partial | models/deepseek_v3.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v3.py | 289 | DeepseekV3DecoderLayer | Partial | models/deepseek_v3.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v3.py | 319 | DeepseekV3Model | Partial | models/deepseek_v3.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v3.py | 364 | Model | Implemented | models/deepseek_v3.rb | | +| models/deepseek_v32.py | 20 | ModelArgs | Implemented | models/deepseek_v32.rb | | +| models/deepseek_v32.py | 55 | Indexer | Partial | models/deepseek_v32.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v32.py | 120 | DeepseekV32Attention | Partial | models/deepseek_v32.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v32.py | 265 | DeepseekV32MLP | Partial | models/deepseek_v32.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v32.py | 320 | MoEGate | Partial | models/deepseek_v32.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v32.py | 346 | DeepseekV32MoE | Partial | models/deepseek_v32.rb | Ruby file exists; classes differ: ModelArgs, Model | +| 
models/deepseek_v32.py | 382 | DeepseekV32DecoderLayer | Partial | models/deepseek_v32.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v32.py | 412 | DeepseekV32Model | Partial | models/deepseek_v32.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/deepseek_v32.py | 481 | Model | Implemented | models/deepseek_v32.rb | | +| models/dots1.py | 17 | ModelArgs | Implemented | models/dots1.rb | | +| models/dots1.py | 45 | Dots1Attention | Implemented | models/dots1.rb | | +| models/dots1.py | 138 | Dots1TopkRouter | Implemented | models/dots1.rb | | +| models/dots1.py | 162 | Dots1MLP | Implemented | models/dots1.rb | | +| models/dots1.py | 187 | Dots1MoE | Implemented | models/dots1.rb | | +| models/dots1.py | 216 | Dots1DecoderLayer | Implemented | models/dots1.rb | | +| models/dots1.py | 243 | Dots1Model | Implemented | models/dots1.rb | | +| models/dots1.py | 271 | Model | Implemented | models/dots1.rb | | +| models/ernie4_5.py | 15 | ModelArgs | Implemented | models/ernie4_5.rb | | +| models/ernie4_5.py | 31 | Attention | Implemented | models/ernie4_5.rb | | +| models/ernie4_5.py | 83 | MLP | Implemented | models/ernie4_5.rb | | +| models/ernie4_5.py | 94 | DecoderLayer | Implemented | models/ernie4_5.rb | | +| models/ernie4_5.py | 117 | Ernie45Model | Implemented | models/ernie4_5.rb | | +| models/ernie4_5.py | 142 | Model | Implemented | models/ernie4_5.rb | | +| models/ernie4_5_moe.py | 16 | ModelArgs | Implemented | models/ernie4_5_moe.rb | | +| models/ernie4_5_moe.py | 42 | Attention | Partial | models/ernie4_5_moe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/ernie4_5_moe.py | 94 | Ernie4_5_MLP | Partial | models/ernie4_5_moe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/ernie4_5_moe.py | 105 | Ernie4_5_MoeMLP | Partial | models/ernie4_5_moe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/ernie4_5_moe.py | 163 | Ernie4_5_DecoderLayer | Partial | models/ernie4_5_moe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/ernie4_5_moe.py | 211 | Ernie45Model | Partial | models/ernie4_5_moe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/ernie4_5_moe.py | 238 | Model | Implemented | models/ernie4_5_moe.rb | | +| models/exaone.py | 15 | ModelArgs | Implemented | models/exaone.rb | | +| models/exaone.py | 34 | AttentionModule | Implemented | models/exaone.rb | | +| models/exaone.py | 79 | Attention | Implemented | models/exaone.rb | | +| models/exaone.py | 85 | MLP | Implemented | models/exaone.rb | | +| models/exaone.py | 98 | TransformerBlock | Implemented | models/exaone.rb | | +| models/exaone.py | 117 | ExaoneModel | Implemented | models/exaone.rb | | +| models/exaone.py | 142 | Model | Implemented | models/exaone.rb | | +| models/exaone4.py | 16 | ModelArgs | Implemented | models/exaone4.rb | | +| models/exaone4.py | 34 | Attention | Implemented | models/exaone4.rb | | +| models/exaone4.py | 98 | MLP | Implemented | models/exaone4.rb | | +| models/exaone4.py | 109 | TransformerBlock | Implemented | models/exaone4.rb | | +| models/exaone4.py | 137 | ExaoneModel | Implemented | models/exaone4.rb | | +| models/exaone4.py | 187 | Model | Implemented | models/exaone4.rb | | +| models/exaone_moe.py | 18 | ModelArgs | Implemented | models/exaone_moe.rb | | +| models/exaone_moe.py | 88 | MoEGate | Implemented | models/exaone_moe.rb | | +| models/exaone_moe.py | 113 | MLP | Implemented | models/exaone_moe.rb | | +| models/exaone_moe.py | 126 | MoE | 
Implemented | models/exaone_moe.rb | | +| models/exaone_moe.py | 164 | Attention | Implemented | models/exaone_moe.rb | | +| models/exaone_moe.py | 237 | DecoderLayer | Implemented | models/exaone_moe.rb | | +| models/exaone_moe.py | 262 | ExaoneMoEModel | Implemented | models/exaone_moe.rb | Name variant in Ruby: ExaoneMoeModel | +| models/exaone_moe.py | 307 | Model | Implemented | models/exaone_moe.rb | | +| models/falcon_h1.py | 22 | ModelArgs | Implemented | models/falcon_h1.rb | | +| models/falcon_h1.py | 76 | FalconH1RMSNormGated | Partial | models/falcon_h1.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/falcon_h1.py | 116 | FalconH1Attention | Partial | models/falcon_h1.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/falcon_h1.py | 179 | FalconH1Mixer | Partial | models/falcon_h1.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/falcon_h1.py | 339 | FalconH1MLP | Partial | models/falcon_h1.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/falcon_h1.py | 357 | FalconH1DecoderLayer | Partial | models/falcon_h1.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/falcon_h1.py | 402 | FalconH1Model | Partial | models/falcon_h1.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/falcon_h1.py | 441 | Model | Implemented | models/falcon_h1.rb | | +| models/gemma.py | 13 | ModelArgs | Implemented | models/gemma.rb | | +| models/gemma.py | 27 | RMSNorm | Partial | models/gemma.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, TransformerBlock, GemmaModel, Model | +| models/gemma.py | 37 | Attention | Implemented | models/gemma.rb | | +| models/gemma.py | 90 | MLP | Implemented | models/gemma.rb | | +| models/gemma.py | 101 | TransformerBlock | Implemented | models/gemma.rb | | +| models/gemma.py | 125 | GemmaModel | Implemented | models/gemma.rb | | +| models/gemma.py | 157 | Model | Implemented | models/gemma.rb | | +| models/gemma2.py | 13 | ModelArgs | Implemented | models/gemma2.rb | | +| models/gemma2.py | 30 | RMSNorm | Partial | models/gemma2.rb | Ruby file exists; classes differ: ModelArgs, Gemma2RMSNorm, Attention, MLP, TransformerBlock, Gemma2Model, Model | +| models/gemma2.py | 40 | Attention | Implemented | models/gemma2.rb | | +| models/gemma2.py | 111 | MLP | Implemented | models/gemma2.rb | | +| models/gemma2.py | 122 | TransformerBlock | Implemented | models/gemma2.rb | | +| models/gemma2.py | 152 | GemmaModel | Partial | models/gemma2.rb | Ruby file exists; classes differ: ModelArgs, Gemma2RMSNorm, Attention, MLP, TransformerBlock, Gemma2Model, Model | +| models/gemma2.py | 184 | Model | Implemented | models/gemma2.rb | | +| models/gemma3.py | 15 | ModelArgs | Implemented | models/gemma3.rb | | +| models/gemma3.py | 30 | Model | Implemented | models/gemma3.rb | | +| models/gemma3_text.py | 16 | ModelArgs | Implemented | models/gemma3_text.rb | | +| models/gemma3_text.py | 35 | Attention | Implemented | models/gemma3_text.rb | | +| models/gemma3_text.py | 104 | RMSNorm | Implemented | models/gemma3_text.rb | | +| models/gemma3_text.py | 114 | MLP | Implemented | models/gemma3_text.rb | | +| models/gemma3_text.py | 135 | TransformerBlock | Implemented | models/gemma3_text.rb | | +| models/gemma3_text.py | 164 | Gemma3Model | Implemented | models/gemma3_text.rb | | +| models/gemma3_text.py | 215 | Model | Implemented | models/gemma3_text.rb | | +| models/gemma3n.py | 17 | TextConfig | Partial | models/gemma3n.rb | Ruby file exists; classes differ: 
ModelArgs, Model | +| models/gemma3n.py | 46 | ModelArgs | Implemented | models/gemma3n.rb | | +| models/gemma3n.py | 51 | RMSNoScale | Partial | models/gemma3n.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/gemma3n.py | 60 | Gemma3nLaurelBlock | Partial | models/gemma3n.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/gemma3n.py | 85 | Gemma3nAttention | Partial | models/gemma3n.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/gemma3n.py | 171 | MLP | Partial | models/gemma3n.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/gemma3n.py | 204 | Gemma3nAltUp | Partial | models/gemma3n.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/gemma3n.py | 283 | Gemma3nDecoderLayer | Partial | models/gemma3n.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/gemma3n.py | 379 | LanguageModel | Partial | models/gemma3n.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/gemma3n.py | 568 | Gemma3n | Partial | models/gemma3n.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/gemma3n.py | 587 | Model | Implemented | models/gemma3n.rb | | +| models/glm.py | 15 | ModelArgs | Implemented | models/glm.rb | | +| models/glm.py | 31 | GLMAttention | Partial | models/glm.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, TransformerBlock, GLMModel, Model | +| models/glm.py | 95 | GLMMLP | Partial | models/glm.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, TransformerBlock, GLMModel, Model | +| models/glm.py | 109 | GLMBlock | Partial | models/glm.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, TransformerBlock, GLMModel, Model | +| models/glm.py | 132 | GLMModel | Implemented | models/glm.rb | | +| models/glm.py | 157 | Model | Implemented | models/glm.rb | | +| models/glm4.py | 14 | ModelArgs | Implemented | models/glm4.rb | | +| models/glm4.py | 31 | Glm4MLP | Implemented | models/glm4.rb | Name variant in Ruby: GLM4MLP | +| models/glm4.py | 45 | Glm4Attention | Implemented | models/glm4.rb | Name variant in Ruby: GLM4Attention | +| models/glm4.py | 107 | Glm4DecoderLayer | Implemented | models/glm4.rb | Name variant in Ruby: GLM4DecoderLayer | +| models/glm4.py | 136 | Glm4Model | Implemented | models/glm4.rb | Name variant in Ruby: GLM4Model | +| models/glm4.py | 163 | Model | Implemented | models/glm4.rb | | +| models/glm4_moe.py | 19 | ModelArgs | Implemented | models/glm4_moe.rb | | +| models/glm4_moe.py | 49 | Attention | Implemented | models/glm4_moe.rb | | +| models/glm4_moe.py | 111 | MLP | Implemented | models/glm4_moe.rb | | +| models/glm4_moe.py | 166 | MoEGate | Implemented | models/glm4_moe.rb | | +| models/glm4_moe.py | 192 | MoE | Implemented | models/glm4_moe.rb | | +| models/glm4_moe.py | 228 | DecoderLayer | Implemented | models/glm4_moe.rb | | +| models/glm4_moe.py | 257 | LanguageModel | Implemented | models/glm4_moe.rb | | +| models/glm4_moe.py | 301 | Model | Implemented | models/glm4_moe.rb | | +| models/glm4_moe_lite.py | 20 | ModelArgs | Implemented | models/glm4_moe_lite.rb | | +| models/glm4_moe_lite.py | 57 | Glm4MoeLiteAttention | Partial | models/glm4_moe_lite.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/glm4_moe_lite.py | 177 | Glm4MoeLiteMLP | Partial | models/glm4_moe_lite.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/glm4_moe_lite.py | 231 | MoEGate | Partial | models/glm4_moe_lite.rb | Ruby file exists; classes differ: ModelArgs, Model | 
+| models/glm4_moe_lite.py | 257 | Glm4MoeLiteMoE | Partial | models/glm4_moe_lite.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/glm4_moe_lite.py | 293 | Glm4MoeLiteDecoderLayer | Partial | models/glm4_moe_lite.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/glm4_moe_lite.py | 320 | Glm4MoeLiteModel | Partial | models/glm4_moe_lite.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/glm4_moe_lite.py | 365 | Model | Implemented | models/glm4_moe_lite.rb | | +| models/glm_moe_dsa.py | 11 | ModelArgs | Implemented | models/glm_moe_dsa.rb | | +| models/glm_moe_dsa.py | 51 | Model | Implemented | models/glm_moe_dsa.rb | | +| models/gpt2.py | 13 | ModelArgs | Implemented | models/gpt2.rb | | +| models/gpt2.py | 29 | Attention | Implemented | models/gpt2.rb | | +| models/gpt2.py | 71 | MLP | Implemented | models/gpt2.rb | | +| models/gpt2.py | 83 | TransformerBlock | Implemented | models/gpt2.rb | | +| models/gpt2.py | 111 | GPT2Model | Implemented | models/gpt2.rb | | +| models/gpt2.py | 154 | Model | Implemented | models/gpt2.rb | | +| models/gpt_bigcode.py | 14 | ModelArgs | Implemented | models/gpt_bigcode.rb | | +| models/gpt_bigcode.py | 34 | Attention | Implemented | models/gpt_bigcode.rb | | +| models/gpt_bigcode.py | 84 | MLP | Implemented | models/gpt_bigcode.rb | | +| models/gpt_bigcode.py | 102 | TransformerBlock | Implemented | models/gpt_bigcode.rb | | +| models/gpt_bigcode.py | 126 | GPTBigCodeModel | Implemented | models/gpt_bigcode.rb | | +| models/gpt_bigcode.py | 162 | Model | Implemented | models/gpt_bigcode.rb | | +| models/gpt_neox.py | 16 | ModelArgs | Implemented | models/gpt_neox.rb | | +| models/gpt_neox.py | 34 | Attention | Implemented | models/gpt_neox.rb | | +| models/gpt_neox.py | 90 | MLP | Implemented | models/gpt_neox.rb | | +| models/gpt_neox.py | 103 | TransformerBlock | Implemented | models/gpt_neox.rb | | +| models/gpt_neox.py | 142 | GPTNeoXModel | Implemented | models/gpt_neox.rb | | +| models/gpt_neox.py | 178 | Model | Implemented | models/gpt_neox.rb | | +| models/gpt_oss.py | 19 | ModelArgs | Implemented | models/gpt_oss.rb | | +| models/gpt_oss.py | 62 | SwiGLU | Partial | models/gpt_oss.rb | Ruby file exists; classes differ: ModelArgs, AttentionBlock, MLPBlock, TransformerBlock, GptOssMoeModel, Model | +| models/gpt_oss.py | 70 | AttentionBlock | Implemented | models/gpt_oss.rb | | +| models/gpt_oss.py | 130 | MLPBlock | Implemented | models/gpt_oss.rb | | +| models/gpt_oss.py | 169 | TransformerBlock | Implemented | models/gpt_oss.rb | | +| models/gpt_oss.py | 192 | GptOssMoeModel | Implemented | models/gpt_oss.rb | | +| models/gpt_oss.py | 232 | Model | Implemented | models/gpt_oss.rb | | +| models/granite.py | 15 | ModelArgs | Implemented | models/granite.rb | | +| models/granite.py | 36 | Attention | Implemented | models/granite.rb | | +| models/granite.py | 92 | MLP | Implemented | models/granite.rb | | +| models/granite.py | 111 | TransformerBlock | Implemented | models/granite.rb | | +| models/granite.py | 137 | GraniteModel | Implemented | models/granite.rb | | +| models/granite.py | 169 | Model | Implemented | models/granite.rb | | +| models/granitemoe.py | 15 | ModelArgs | Implemented | models/granitemoe.rb | | +| models/granitemoe.py | 37 | GraniteMoeAttention | Partial | models/granitemoe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoe.py | 92 | GraniteMoeTopKGating | Partial | models/granitemoe.rb | Ruby file exists; classes differ: ModelArgs, Model | 
+| models/granitemoe.py | 110 | GraniteMoeMoE | Partial | models/granitemoe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoe.py | 131 | GraniteMoeDecoderLayer | Partial | models/granitemoe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoe.py | 155 | GraniteMoEModel | Partial | models/granitemoe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoe.py | 184 | Model | Implemented | models/granitemoe.rb | | +| models/granitemoehybrid.py | 23 | ModelArgs | Implemented | models/granitemoehybrid.rb | | +| models/granitemoehybrid.py | 71 | GraniteMoeHybridRMSNormGated | Partial | models/granitemoehybrid.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoehybrid.py | 83 | GraniteMoeHybridMamba2Mixer | Partial | models/granitemoehybrid.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoehybrid.py | 226 | GraniteMoeHybridAttention | Partial | models/granitemoehybrid.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoehybrid.py | 290 | GraniteMoeHybridTopKGating | Partial | models/granitemoehybrid.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoehybrid.py | 308 | GraniteMoeHybridMoE | Partial | models/granitemoehybrid.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoehybrid.py | 329 | GraniteMoeHybridSharedMLP | Partial | models/granitemoehybrid.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoehybrid.py | 344 | GraniteMoeHybridMLP | Partial | models/granitemoehybrid.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoehybrid.py | 359 | GraniteMoeHybridLayer | Partial | models/granitemoehybrid.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoehybrid.py | 420 | GraniteMoeHybridModel | Partial | models/granitemoehybrid.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/granitemoehybrid.py | 467 | Model | Implemented | models/granitemoehybrid.rb | | +| models/helium.py | 14 | ModelArgs | Implemented | models/helium.rb | | +| models/helium.py | 31 | HeliumAttention | Partial | models/helium.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, DecoderLayer, HeliumModel, Model | +| models/helium.py | 79 | HeliumMLP | Partial | models/helium.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, DecoderLayer, HeliumModel, Model | +| models/helium.py | 99 | HeliumDecoderLayer | Partial | models/helium.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, DecoderLayer, HeliumModel, Model | +| models/helium.py | 124 | HeliumModel | Implemented | models/helium.rb | | +| models/helium.py | 155 | Model | Implemented | models/helium.rb | | +| models/hunyuan.py | 15 | ModelArgs | Implemented | models/hunyuan.rb | | +| models/hunyuan.py | 51 | DynamicNTKAlphaRoPE | Implemented | models/hunyuan.rb | | +| models/hunyuan.py | 75 | Attention | Implemented | models/hunyuan.rb | | +| models/hunyuan.py | 144 | MLP | Implemented | models/hunyuan.rb | | +| models/hunyuan.py | 155 | Gate | Implemented | models/hunyuan.rb | | +| models/hunyuan.py | 164 | MoeBlock | Implemented | models/hunyuan.rb | | +| models/hunyuan.py | 210 | DecoderLayer | Implemented | models/hunyuan.rb | | +| models/hunyuan.py | 242 | HunYuanModel | Implemented | models/hunyuan.rb | | +| models/hunyuan.py | 279 | Model | Implemented | models/hunyuan.rb | | +| models/hunyuan_v1_dense.py | 14 | ModelArgs | 
Implemented | models/hunyuan_v1_dense.rb | | +| models/hunyuan_v1_dense.py | 38 | DynamicNTKAlphaRoPE | Implemented | models/hunyuan_v1_dense.rb | | +| models/hunyuan_v1_dense.py | 62 | Attention | Implemented | models/hunyuan_v1_dense.rb | | +| models/hunyuan_v1_dense.py | 136 | MLP | Implemented | models/hunyuan_v1_dense.rb | | +| models/hunyuan_v1_dense.py | 151 | TransformerBlock | Implemented | models/hunyuan_v1_dense.rb | | +| models/hunyuan_v1_dense.py | 177 | HunyuanV1DenseModel | Implemented | models/hunyuan_v1_dense.rb | | +| models/hunyuan_v1_dense.py | 204 | Model | Implemented | models/hunyuan_v1_dense.rb | | +| models/internlm2.py | 14 | ModelArgs | Implemented | models/internlm2.rb | | +| models/internlm2.py | 45 | DynamicNTKScalingRoPE | Partial | models/internlm2.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, TransformerBlock, InternLM2Model, Model | +| models/internlm2.py | 85 | Attention | Implemented | models/internlm2.rb | | +| models/internlm2.py | 152 | MLP | Implemented | models/internlm2.rb | | +| models/internlm2.py | 163 | TransformerBlock | Implemented | models/internlm2.rb | | +| models/internlm2.py | 184 | InternLM2Model | Implemented | models/internlm2.rb | | +| models/internlm2.py | 211 | Model | Implemented | models/internlm2.rb | | +| models/internlm3.py | 14 | ModelArgs | Implemented | models/internlm3.rb | | +| models/internlm3.py | 46 | DynamicNTKScalingRoPE | Implemented | models/internlm3.rb | | +| models/internlm3.py | 86 | Attention | Implemented | models/internlm3.rb | | +| models/internlm3.py | 150 | MLP | Implemented | models/internlm3.rb | | +| models/internlm3.py | 161 | TransformerBlock | Implemented | models/internlm3.rb | | +| models/internlm3.py | 184 | InternLM2Model | Partial | models/internlm3.rb | Ruby file exists; classes differ: ModelArgs, DynamicNTKScalingRoPE, Attention, MLP, TransformerBlock, InternLM3Model, Model | +| models/internlm3.py | 211 | Model | Implemented | models/internlm3.rb | | +| models/iquestloopcoder.py | 36 | ModelArgs | Implemented | models/iquestloopcoder.rb | | +| models/iquestloopcoder.py | 56 | LoopGateProjection | Implemented | models/iquestloopcoder.rb | | +| models/iquestloopcoder.py | 68 | Attention | Implemented | models/iquestloopcoder.rb | | +| models/iquestloopcoder.py | 117 | MLP | Implemented | models/iquestloopcoder.rb | | +| models/iquestloopcoder.py | 130 | TransformerBlock | Implemented | models/iquestloopcoder.rb | | +| models/iquestloopcoder.py | 141 | IQuestLoopCoderModel | Implemented | models/iquestloopcoder.rb | | +| models/iquestloopcoder.py | 219 | Model | Implemented | models/iquestloopcoder.rb | | +| models/jamba.py | 22 | ModelArgs | Implemented | models/jamba.rb | | +| models/jamba.py | 61 | JambaMLP | Partial | models/jamba.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/jamba.py | 72 | JambaAttention | Partial | models/jamba.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/jamba.py | 127 | JambaMambaMixer | Partial | models/jamba.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/jamba.py | 229 | JambaSparseMoeBlock | Partial | models/jamba.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/jamba.py | 250 | JambaDecoderLayer | Partial | models/jamba.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/jamba.py | 284 | JambaModel | Partial | models/jamba.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/jamba.py | 317 | Model | Implemented | models/jamba.rb | 
| +| models/kimi_k25.py | 17 | ModelArgs | Implemented | models/kimi_k25.rb | | +| models/kimi_k25.py | 26 | LanguageModel | Partial | models/kimi_k25.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/kimi_k25.py | 42 | Model | Implemented | models/kimi_k25.rb | | +| models/kimi_linear.py | 23 | ModelArgs | Implemented | models/kimi_linear.rb | | +| models/kimi_linear.py | 57 | KimiMLP | Partial | models/kimi_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/kimi_linear.py | 120 | KimiSparseMoE | Partial | models/kimi_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/kimi_linear.py | 158 | KimiMLAAttention | Partial | models/kimi_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/kimi_linear.py | 235 | ShortConv1d | Partial | models/kimi_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/kimi_linear.py | 275 | KimiDeltaAttention | Partial | models/kimi_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/kimi_linear.py | 388 | KimiDecoderLayer | Partial | models/kimi_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/kimi_linear.py | 426 | KimiLinearModel | Partial | models/kimi_linear.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/kimi_linear.py | 458 | Model | Implemented | models/kimi_linear.rb | | +| models/kimi_vl.py | 14 | TextArgs | Partial | models/kimi_vl.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/kimi_vl.py | 46 | ModelArgs | Implemented | models/kimi_vl.rb | | +| models/kimi_vl.py | 54 | LanguageModel | Partial | models/kimi_vl.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/kimi_vl.py | 70 | Model | Implemented | models/kimi_vl.rb | | +| models/lfm2-vl.py | 15 | ModelArgs | Implemented | models/lfm2_vl.rb | | +| models/lfm2-vl.py | 23 | Model | Implemented | models/lfm2_vl.rb | | +| models/lfm2.py | 19 | ModelArgs | Implemented | models/lfm2.rb | | +| models/lfm2.py | 53 | Attention | Partial | models/lfm2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/lfm2.py | 112 | ShortConv | Partial | models/lfm2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/lfm2.py | 173 | MLP | Partial | models/lfm2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/lfm2.py | 197 | Lfm2DecoderLayer | Partial | models/lfm2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/lfm2.py | 237 | Lfm2Model | Partial | models/lfm2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/lfm2.py | 282 | Model | Implemented | models/lfm2.rb | | +| models/lfm2_moe.py | 20 | ModelArgs | Implemented | models/lfm2_moe.rb | | +| models/lfm2_moe.py | 54 | Attention | Implemented | models/lfm2_moe.rb | | +| models/lfm2_moe.py | 113 | ShortConv | Implemented | models/lfm2_moe.rb | | +| models/lfm2_moe.py | 174 | MLP | Implemented | models/lfm2_moe.rb | | +| models/lfm2_moe.py | 189 | Lfm2MoeSparseMoeBlock | Partial | models/lfm2_moe.rb | Ruby file exists; classes differ: ModelArgs, Attention, ShortConv, MLP, SparseMoeBlock, DecoderLayer, Lfm2MoeModel, Model | +| models/lfm2_moe.py | 229 | Lfm2DecoderLayer | Partial | models/lfm2_moe.rb | Ruby file exists; classes differ: ModelArgs, Attention, ShortConv, MLP, SparseMoeBlock, DecoderLayer, Lfm2MoeModel, Model | +| models/lfm2_moe.py | 270 | Lfm2Model | Partial | models/lfm2_moe.rb | Ruby file exists; classes differ: ModelArgs, Attention, ShortConv, MLP, 
SparseMoeBlock, DecoderLayer, Lfm2MoeModel, Model | +| models/lfm2_moe.py | 315 | Model | Implemented | models/lfm2_moe.rb | | +| models/lille-130m.py | 14 | ModelArgs | Implemented | models/lille_130m.rb | | +| models/lille-130m.py | 27 | Lille130mAttention | Implemented | models/lille_130m.rb | | +| models/lille-130m.py | 79 | Lille130mMLP | Implemented | models/lille_130m.rb | | +| models/lille-130m.py | 94 | Lille130Block | Implemented | models/lille_130m.rb | | +| models/lille-130m.py | 111 | Lille130 | Implemented | models/lille_130m.rb | | +| models/lille-130m.py | 136 | Model | Implemented | models/lille_130m.rb | | +| models/llama.py | 17 | ModelArgs | Implemented | models/llama.rb | | +| models/llama.py | 45 | Attention | Implemented | models/llama.rb | | +| models/llama.py | 105 | MLP | Implemented | models/llama.rb | | +| models/llama.py | 124 | TransformerBlock | Implemented | models/llama.rb | | +| models/llama.py | 151 | LlamaModel | Implemented | models/llama.rb | | +| models/llama.py | 200 | Model | Implemented | models/llama.rb | | +| models/llama4.py | 17 | TextArgs | Implemented | models/llama4.rb | | +| models/llama4.py | 43 | ModelArgs | Implemented | models/llama4.rb | | +| models/llama4.py | 51 | Attention | Implemented | models/llama4.rb | | +| models/llama4.py | 137 | MLP | Implemented | models/llama4.rb | | +| models/llama4.py | 152 | MoE | Implemented | models/llama4.rb | | +| models/llama4.py | 175 | TransformerBlock | Implemented | models/llama4.rb | | +| models/llama4.py | 208 | LlamaModel | Implemented | models/llama4.rb | | +| models/llama4.py | 258 | LanguageModel | Implemented | models/llama4.rb | | +| models/llama4.py | 277 | Model | Implemented | models/llama4.rb | | +| models/llama4_text.py | 14 | ModelArgs | Implemented | models/llama4_text.rb | | +| models/llama4_text.py | 31 | Attention | Implemented | models/llama4_text.rb | | +| models/llama4_text.py | 91 | MLP | Implemented | models/llama4_text.rb | | +| models/llama4_text.py | 102 | TransformerBlock | Implemented | models/llama4_text.rb | | +| models/llama4_text.py | 129 | LanguageModel | Implemented | models/llama4_text.rb | | +| models/llama4_text.py | 158 | Model | Implemented | models/llama4_text.rb | | +| models/longcat_flash.py | 18 | ModelArgs | Implemented | models/longcat_flash.rb | | +| models/longcat_flash.py | 48 | LongcatFlashMLA | Partial | models/longcat_flash.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/longcat_flash.py | 182 | LongcatFlashMLP | Partial | models/longcat_flash.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/longcat_flash.py | 195 | LongcatFlashTopkRouter | Partial | models/longcat_flash.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/longcat_flash.py | 231 | LongcatFlashMoE | Partial | models/longcat_flash.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/longcat_flash.py | 280 | LongcatFlashDecoderLayer | Partial | models/longcat_flash.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/longcat_flash.py | 329 | LongcatFlashModel | Partial | models/longcat_flash.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/longcat_flash.py | 355 | Model | Implemented | models/longcat_flash.rb | | +| models/longcat_flash_ngram.py | 16 | ModelArgs | Implemented | models/longcat_flash_ngram.rb | | +| models/longcat_flash_ngram.py | 48 | NgramEmbedding | Partial | models/longcat_flash_ngram.rb | Ruby file exists; classes differ: ModelArgs, Model | +| 
models/longcat_flash_ngram.py | 146 | LongcatFlashNgramModel | Partial | models/longcat_flash_ngram.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/longcat_flash_ngram.py | 172 | Model | Implemented | models/longcat_flash_ngram.rb | | +| models/mamba.py | 15 | ModelArgs | Implemented | models/mamba.rb | | +| models/mamba.py | 54 | MambaBlock | Implemented | models/mamba.rb | | +| models/mamba.py | 163 | ResidualBlock | Implemented | models/mamba.rb | | +| models/mamba.py | 173 | Mamba | Partial | models/mamba.rb | Ruby file exists; classes differ: ModelArgs, MambaBlock, ResidualBlock, MambaModel, Model | +| models/mamba.py | 189 | Model | Implemented | models/mamba.rb | | +| models/mamba2.py | 17 | ModelArgs | Implemented | models/mamba2.rb | | +| models/mamba2.py | 44 | MambaRMSNormGated | Implemented | models/mamba2.rb | | +| models/mamba2.py | 56 | Mamba2Block | Implemented | models/mamba2.rb | | +| models/mamba2.py | 196 | ResidualBlock | Implemented | models/mamba2.rb | | +| models/mamba2.py | 209 | Mamba2 | Partial | models/mamba2.rb | Ruby file exists; classes differ: ModelArgs, MambaRMSNormGated, Mamba2Block, ResidualBlock, Mamba2Model, Model | +| models/mamba2.py | 232 | Model | Implemented | models/mamba2.rb | | +| models/mimo.py | 15 | ModelArgs | Implemented | models/mimo.rb | | +| models/mimo.py | 32 | Attention | Implemented | models/mimo.rb | | +| models/mimo.py | 86 | MLP | Implemented | models/mimo.rb | | +| models/mimo.py | 97 | TransformerBlock | Implemented | models/mimo.rb | | +| models/mimo.py | 123 | MiMoModel | Implemented | models/mimo.rb | Name variant in Ruby: MimoModel | +| models/mimo.py | 158 | Model | Implemented | models/mimo.rb | | +| models/mimo_v2_flash.py | 18 | ModelArgs | Implemented | models/mimo_v2_flash.rb | | +| models/mimo_v2_flash.py | 54 | Attention | Implemented | models/mimo_v2_flash.rb | | +| models/mimo_v2_flash.py | 127 | MLP | Implemented | models/mimo_v2_flash.rb | | +| models/mimo_v2_flash.py | 182 | MoEGate | Implemented | models/mimo_v2_flash.rb | | +| models/mimo_v2_flash.py | 212 | MoE | Implemented | models/mimo_v2_flash.rb | | +| models/mimo_v2_flash.py | 240 | DecoderLayer | Implemented | models/mimo_v2_flash.rb | | +| models/mimo_v2_flash.py | 265 | LanguageModel | Implemented | models/mimo_v2_flash.rb | | +| models/mimo_v2_flash.py | 305 | Model | Implemented | models/mimo_v2_flash.rb | | +| models/minicpm.py | 15 | ModelArgs | Implemented | models/minicpm.rb | | +| models/minicpm.py | 34 | MLP | Implemented | models/minicpm.rb | | +| models/minicpm.py | 45 | Attention | Implemented | models/minicpm.rb | | +| models/minicpm.py | 114 | DecoderLayer | Implemented | models/minicpm.rb | | +| models/minicpm.py | 144 | MiniCPMModel | Implemented | models/minicpm.rb | | +| models/minicpm.py | 173 | Model | Implemented | models/minicpm.rb | | +| models/minicpm3.py | 15 | ModelArgs | Implemented | models/minicpm3.rb | | +| models/minicpm3.py | 39 | Attention | Implemented | models/minicpm3.rb | | +| models/minicpm3.py | 152 | MLP | Implemented | models/minicpm3.rb | | +| models/minicpm3.py | 163 | DecoderLayer | Implemented | models/minicpm3.rb | | +| models/minicpm3.py | 193 | MiniCPM3Model | Implemented | models/minicpm3.rb | | +| models/minicpm3.py | 224 | Model | Implemented | models/minicpm3.rb | | +| models/minimax.py | 16 | ModelArgs | Implemented | models/minimax.rb | | +| models/minimax.py | 58 | ShardedRMSNorm | Partial | models/minimax.rb | Ruby file exists; classes differ: ModelArgs, Attention, 
SparseMoeBlock, DecoderLayer, MiniMaxModel, Model | +| models/minimax.py | 86 | MiniMaxAttention | Partial | models/minimax.rb | Ruby file exists; classes differ: ModelArgs, Attention, SparseMoeBlock, DecoderLayer, MiniMaxModel, Model | +| models/minimax.py | 162 | MiniMaxSparseMoeBlock | Partial | models/minimax.rb | Ruby file exists; classes differ: ModelArgs, Attention, SparseMoeBlock, DecoderLayer, MiniMaxModel, Model | +| models/minimax.py | 200 | MiniMaxDecoderLayer | Partial | models/minimax.rb | Ruby file exists; classes differ: ModelArgs, Attention, SparseMoeBlock, DecoderLayer, MiniMaxModel, Model | +| models/minimax.py | 224 | MiniMaxModel | Implemented | models/minimax.rb | | +| models/minimax.py | 254 | Model | Implemented | models/minimax.rb | | +| models/ministral3.py | 18 | ModelArgs | Implemented | models/ministral3.rb | | +| models/ministral3.py | 55 | Attention | Implemented | models/ministral3.rb | | +| models/ministral3.py | 114 | MLP | Implemented | models/ministral3.rb | | +| models/ministral3.py | 128 | TransformerBlock | Implemented | models/ministral3.rb | | +| models/ministral3.py | 156 | LanguageModel | Implemented | models/ministral3.rb | | +| models/ministral3.py | 245 | Model | Implemented | models/ministral3.rb | | +| models/mistral3.py | 15 | ModelArgs | Implemented | models/mistral3.rb | | +| models/mistral3.py | 24 | Model | Implemented | models/mistral3.rb | | +| models/mixtral.py | 14 | ModelArgs | Implemented | models/mixtral.rb | | +| models/mixtral.py | 35 | MixtralAttention | Partial | models/mixtral.rb | Ruby file exists; classes differ: ModelArgs, Attention, SparseMoeBlock, MixtralDecoderLayer, MixtralModel, Model | +| models/mixtral.py | 97 | MixtralSparseMoeBlock | Partial | models/mixtral.rb | Ruby file exists; classes differ: ModelArgs, Attention, SparseMoeBlock, MixtralDecoderLayer, MixtralModel, Model | +| models/mixtral.py | 124 | MixtralDecoderLayer | Implemented | models/mixtral.rb | | +| models/mixtral.py | 150 | MixtralModel | Implemented | models/mixtral.rb | | +| models/mixtral.py | 184 | Model | Implemented | models/mixtral.rb | | +| models/mla.py | 9 | MultiLinear | Implemented | models/mla.rb | | +| models/mla.py | 45 | QuantizedMultiLinear | Implemented | models/mla.rb | | +| models/nanochat.py | 15 | ModelArgs | Implemented | models/nanochat.rb | | +| models/nanochat.py | 66 | Attention | Implemented | models/nanochat.rb | | +| models/nanochat.py | 144 | MLP | Implemented | models/nanochat.rb | | +| models/nanochat.py | 157 | TransformerBlock | Implemented | models/nanochat.rb | | +| models/nanochat.py | 175 | NanoChatModel | Implemented | models/nanochat.rb | | +| models/nanochat.py | 210 | Model | Implemented | models/nanochat.rb | | +| models/nemotron-nas.py | 15 | AttentionConfig | Implemented | models/nemotron_nas.rb | | +| models/nemotron-nas.py | 47 | FFNConfig | Implemented | models/nemotron_nas.rb | | +| models/nemotron-nas.py | 66 | BlockConfig | Implemented | models/nemotron_nas.rb | | +| models/nemotron-nas.py | 102 | ModelArgs | Implemented | models/nemotron_nas.rb | | +| models/nemotron-nas.py | 150 | Attention | Implemented | models/nemotron_nas.rb | | +| models/nemotron-nas.py | 215 | MLP | Implemented | models/nemotron_nas.rb | | +| models/nemotron-nas.py | 238 | LinearSubblockReplacement | Implemented | models/nemotron_nas.rb | | +| models/nemotron-nas.py | 250 | TransformerBlock | Implemented | models/nemotron_nas.rb | | +| models/nemotron-nas.py | 319 | NemotronNASModel | Implemented | models/nemotron_nas.rb | 
| +| models/nemotron-nas.py | 361 | Model | Implemented | models/nemotron_nas.rb | | +| models/nemotron.py | 14 | ModelArgs | Implemented | models/nemotron.rb | | +| models/nemotron.py | 54 | NemotronLayerNorm1P | Implemented | models/nemotron.rb | | +| models/nemotron.py | 61 | Attention | Implemented | models/nemotron.rb | | +| models/nemotron.py | 123 | MLP | Implemented | models/nemotron.rb | | +| models/nemotron.py | 138 | TransformerBlock | Implemented | models/nemotron.rb | | +| models/nemotron.py | 163 | NemotronModel | Implemented | models/nemotron.rb | | +| models/nemotron.py | 193 | Model | Implemented | models/nemotron.rb | | +| models/nemotron_h.py | 23 | ModelArgs | Implemented | models/nemotron_h.rb | | +| models/nemotron_h.py | 67 | MambaRMSNormGated | Partial | models/nemotron_h.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/nemotron_h.py | 82 | NemotronHMamba2Mixer | Partial | models/nemotron_h.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/nemotron_h.py | 228 | NemotronHAttention | Partial | models/nemotron_h.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/nemotron_h.py | 288 | NemotronHMLP | Partial | models/nemotron_h.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/nemotron_h.py | 338 | MoEGate | Partial | models/nemotron_h.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/nemotron_h.py | 363 | NemotronHMoE | Partial | models/nemotron_h.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/nemotron_h.py | 392 | NemotronHBlock | Partial | models/nemotron_h.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/nemotron_h.py | 423 | NemotronHModel | Partial | models/nemotron_h.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/nemotron_h.py | 474 | Model | Implemented | models/nemotron_h.rb | | +| models/olmo.py | 21 | ModelArgs | Implemented | models/olmo.rb | | +| models/olmo.py | 42 | TransformerBlock | Implemented | models/olmo.rb | | +| models/olmo.py | 113 | Transformer | Implemented | models/olmo.rb | | +| models/olmo.py | 148 | OlmoModel | Implemented | models/olmo.rb | | +| models/olmo.py | 161 | Model | Implemented | models/olmo.rb | | +| models/olmo2.py | 15 | ModelArgs | Implemented | models/olmo2.rb | | +| models/olmo2.py | 38 | Attention | Implemented | models/olmo2.rb | | +| models/olmo2.py | 103 | MLP | Implemented | models/olmo2.rb | | +| models/olmo2.py | 122 | TransformerBlock | Implemented | models/olmo2.rb | | +| models/olmo2.py | 150 | LlamaModel | Partial | models/olmo2.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, TransformerBlock, OLMo2Model, Model | +| models/olmo2.py | 180 | Model | Implemented | models/olmo2.rb | | +| models/olmo3.py | 16 | ModelArgs | Implemented | models/olmo3.rb | | +| models/olmo3.py | 44 | Olmo3Attention | Implemented | models/olmo3.rb | | +| models/olmo3.py | 127 | Olmo3MLP | Implemented | models/olmo3.rb | | +| models/olmo3.py | 138 | Olmo3DecoderLayer | Implemented | models/olmo3.rb | | +| models/olmo3.py | 166 | Olmo3Model | Implemented | models/olmo3.rb | | +| models/olmo3.py | 204 | Model | Implemented | models/olmo3.rb | | +| models/olmoe.py | 15 | ModelArgs | Implemented | models/olmoe.rb | | +| models/olmoe.py | 41 | Attention | Partial | models/olmoe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/olmoe.py | 96 | OlmoeSparseMoeBlock | Partial | models/olmoe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| 
models/olmoe.py | 128 | TransformerBlock | Partial | models/olmoe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/olmoe.py | 149 | OlmoeModel | Partial | models/olmoe.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/olmoe.py | 176 | Model | Implemented | models/olmoe.rb | | +| models/openelm.py | 14 | ModelArgs | Implemented | models/openelm.rb | | +| models/openelm.py | 57 | Attention | Implemented | models/openelm.rb | | +| models/openelm.py | 120 | MLP | Implemented | models/openelm.rb | | +| models/openelm.py | 143 | TransformerBlock | Implemented | models/openelm.rb | | +| models/openelm.py | 165 | OpenELMModel | Implemented | models/openelm.rb | | +| models/openelm.py | 196 | Model | Implemented | models/openelm.rb | | +| models/phi.py | 13 | ModelArgs | Implemented | models/phi.rb | | +| models/phi.py | 31 | PhiAttention | Implemented | models/phi.rb | | +| models/phi.py | 109 | PhiMLP | Implemented | models/phi.rb | | +| models/phi.py | 119 | PhiDecoderLayer | Implemented | models/phi.rb | | +| models/phi.py | 135 | PhiModel | Implemented | models/phi.rb | | +| models/phi.py | 156 | Model | Implemented | models/phi.rb | | +| models/phi3.py | 15 | ModelArgs | Implemented | models/phi3.rb | | +| models/phi3.py | 48 | Attention | Implemented | models/phi3.rb | | +| models/phi3.py | 121 | MLP | Implemented | models/phi3.rb | | +| models/phi3.py | 133 | TransformerBlock | Implemented | models/phi3.rb | | +| models/phi3.py | 159 | Phi3Model | Implemented | models/phi3.rb | | +| models/phi3.py | 190 | Model | Implemented | models/phi3.rb | | +| models/phi3small.py | 15 | ModelArgs | Implemented | models/phi3small.rb | | +| models/phi3small.py | 58 | Attention | Implemented | models/phi3small.rb | | +| models/phi3small.py | 198 | MLP | Implemented | models/phi3small.rb | | +| models/phi3small.py | 212 | TransformerBlock | Implemented | models/phi3small.rb | | +| models/phi3small.py | 241 | Phi3Model | Implemented | models/phi3small.rb | | +| models/phi3small.py | 278 | Model | Implemented | models/phi3small.rb | | +| models/phimoe.py | 15 | ModelArgs | Implemented | models/phimoe.rb | | +| models/phimoe.py | 32 | Attention | Implemented | models/phimoe.rb | | +| models/phimoe.py | 89 | PhiMoESparseMoeBlock | Implemented | models/phimoe.rb | | +| models/phimoe.py | 114 | PhiMoEDecoderLayer | Partial | models/phimoe.rb | Ruby file exists; classes differ: ModelArgs, Attention, PhiMoESparseMoeBlock, DecoderLayer, PhiMoEModel, Model | +| models/phimoe.py | 145 | PhiMoEModel | Implemented | models/phimoe.rb | | +| models/phimoe.py | 172 | Model | Implemented | models/phimoe.rb | | +| models/phixtral.py | 16 | ModelArgs | Implemented | models/phixtral.rb | | +| models/phixtral.py | 37 | RoPEAttention | Implemented | models/phixtral.rb | | +| models/phixtral.py | 87 | MOE | Implemented | models/phixtral.rb | | +| models/phixtral.py | 113 | ParallelBlock | Implemented | models/phixtral.rb | | +| models/phixtral.py | 129 | TransformerDecoder | Implemented | models/phixtral.rb | | +| models/phixtral.py | 145 | Embd | Implemented | models/phixtral.rb | | +| models/phixtral.py | 154 | OutputHead | Implemented | models/phixtral.rb | | +| models/phixtral.py | 164 | Model | Implemented | models/phixtral.rb | | +| models/pipeline.py | 6 | PipelineMixin | Partial | models/pipeline.rb | Ruby file exists but defines no classes | +| models/pixtral.py | 15 | ModelArgs | Implemented | models/pixtral.rb | | +| models/pixtral.py | 26 | Model | Implemented | models/pixtral.rb 
| | +| models/plamo.py | 15 | ModelArgs | Implemented | models/plamo.rb | | +| models/plamo.py | 28 | Attention | Implemented | models/plamo.rb | | +| models/plamo.py | 108 | MLP | Implemented | models/plamo.rb | | +| models/plamo.py | 122 | PlamoDecoderLayer | Partial | models/plamo.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, DecoderLayer, PlamoModel, Model | +| models/plamo.py | 156 | PlamoDecoder | Partial | models/plamo.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, DecoderLayer, PlamoModel, Model | +| models/plamo.py | 164 | PlamoModel | Implemented | models/plamo.rb | | +| models/plamo.py | 192 | Model | Implemented | models/plamo.rb | | +| models/plamo2.py | 18 | ModelArgs | Implemented | models/plamo2.rb | | +| models/plamo2.py | 40 | RMSNorm | Partial | models/plamo2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/plamo2.py | 58 | Mamba | Partial | models/plamo2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/plamo2.py | 223 | Attention | Partial | models/plamo2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/plamo2.py | 295 | MLP | Partial | models/plamo2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/plamo2.py | 312 | PlamoDecoderLayer | Partial | models/plamo2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/plamo2.py | 378 | PlamoDecoder | Partial | models/plamo2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/plamo2.py | 411 | PlamoModel | Partial | models/plamo2.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/plamo2.py | 439 | Model | Implemented | models/plamo2.rb | | +| models/qwen.py | 13 | ModelArgs | Implemented | models/qwen.rb | | +| models/qwen.py | 31 | Attention | Implemented | models/qwen.rb | | +| models/qwen.py | 76 | MLP | Implemented | models/qwen.rb | | +| models/qwen.py | 96 | TransformerBlock | Implemented | models/qwen.rb | | +| models/qwen.py | 117 | QwenModel | Implemented | models/qwen.rb | | +| models/qwen.py | 137 | Model | Implemented | models/qwen.rb | | +| models/qwen2.py | 16 | ModelArgs | Implemented | models/qwen2.rb | | +| models/qwen2.py | 32 | Attention | Implemented | models/qwen2.rb | | +| models/qwen2.py | 87 | MLP | Implemented | models/qwen2.rb | | +| models/qwen2.py | 98 | TransformerBlock | Implemented | models/qwen2.rb | | +| models/qwen2.py | 124 | Qwen2Model | Implemented | models/qwen2.rb | | +| models/qwen2.py | 158 | Model | Implemented | models/qwen2.rb | | +| models/qwen2_moe.py | 15 | ModelArgs | Implemented | models/qwen2_moe.rb | | +| models/qwen2_moe.py | 46 | Attention | Partial | models/qwen2_moe.rb | Ruby file exists; classes differ: ModelArgs, SharedExpertMLP, SparseMoeBlock, DecoderLayer, Qwen2MoeModel, Model | +| models/qwen2_moe.py | 99 | MLP | Partial | models/qwen2_moe.rb | Ruby file exists; classes differ: ModelArgs, SharedExpertMLP, SparseMoeBlock, DecoderLayer, Qwen2MoeModel, Model | +| models/qwen2_moe.py | 110 | Qwen2MoeSparseMoeBlock | Partial | models/qwen2_moe.rb | Ruby file exists; classes differ: ModelArgs, SharedExpertMLP, SparseMoeBlock, DecoderLayer, Qwen2MoeModel, Model | +| models/qwen2_moe.py | 148 | Qwen2MoeDecoderLayer | Partial | models/qwen2_moe.rb | Ruby file exists; classes differ: ModelArgs, SharedExpertMLP, SparseMoeBlock, DecoderLayer, Qwen2MoeModel, Model | +| models/qwen2_moe.py | 174 | Qwen2MoeModel | Implemented | models/qwen2_moe.rb | | +| models/qwen2_moe.py | 205 | Model | Implemented | models/qwen2_moe.rb 
| | +| models/qwen2_vl.py | 15 | ModelArgs | Implemented | models/qwen2_vl.rb | | +| models/qwen2_vl.py | 26 | Model | Implemented | models/qwen2_vl.rb | | +| models/qwen3.py | 16 | ModelArgs | Implemented | models/qwen3.rb | | +| models/qwen3.py | 32 | Attention | Implemented | models/qwen3.rb | | +| models/qwen3.py | 92 | MLP | Implemented | models/qwen3.rb | | +| models/qwen3.py | 103 | TransformerBlock | Implemented | models/qwen3.rb | | +| models/qwen3.py | 129 | Qwen3Model | Implemented | models/qwen3.rb | | +| models/qwen3.py | 163 | Model | Implemented | models/qwen3.rb | | +| models/qwen3_5.py | 24 | TextModelArgs | Partial | models/qwen3_5.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/qwen3_5.py | 85 | GatedDeltaNet | Partial | models/qwen3_5.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/qwen3_5.py | 191 | DecoderLayer | Partial | models/qwen3_5.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/qwen3_5.py | 225 | Qwen3_5TextModel | Partial | models/qwen3_5.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/qwen3_5.py | 260 | TextModel | Partial | models/qwen3_5.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/qwen3_5.py | 338 | ModelArgs | Implemented | models/qwen3_5.rb | | +| models/qwen3_5.py | 349 | Model | Implemented | models/qwen3_5.rb | | +| models/qwen3_5_moe.py | 12 | ModelArgs | Implemented | models/qwen3_5_moe.rb | | +| models/qwen3_5_moe.py | 23 | Model | Implemented | models/qwen3_5_moe.rb | | +| models/qwen3_moe.py | 15 | ModelArgs | Implemented | models/qwen3_moe.rb | | +| models/qwen3_moe.py | 37 | Attention | Partial | models/qwen3_moe.rb | Ruby file exists; classes differ: ModelArgs, SparseMoeBlock, DecoderLayer, Qwen3MoeModel, Model | +| models/qwen3_moe.py | 99 | MLP | Partial | models/qwen3_moe.rb | Ruby file exists; classes differ: ModelArgs, SparseMoeBlock, DecoderLayer, Qwen3MoeModel, Model | +| models/qwen3_moe.py | 110 | Qwen3MoeSparseMoeBlock | Partial | models/qwen3_moe.rb | Ruby file exists; classes differ: ModelArgs, SparseMoeBlock, DecoderLayer, Qwen3MoeModel, Model | +| models/qwen3_moe.py | 142 | Qwen3MoeDecoderLayer | Partial | models/qwen3_moe.rb | Ruby file exists; classes differ: ModelArgs, SparseMoeBlock, DecoderLayer, Qwen3MoeModel, Model | +| models/qwen3_moe.py | 174 | Qwen3MoeModel | Implemented | models/qwen3_moe.rb | | +| models/qwen3_moe.py | 210 | Model | Implemented | models/qwen3_moe.rb | | +| models/qwen3_next.py | 25 | ModelArgs | Implemented | models/qwen3_next.rb | | +| models/qwen3_next.py | 56 | Qwen3NextRMSNormGated | Partial | models/qwen3_next.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/qwen3_next.py | 71 | Qwen3NextAttention | Partial | models/qwen3_next.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/qwen3_next.py | 151 | Qwen3NextMLP | Partial | models/qwen3_next.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/qwen3_next.py | 162 | Qwen3NextGatedDeltaNet | Partial | models/qwen3_next.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/qwen3_next.py | 298 | Qwen3NextSparseMoeBlock | Partial | models/qwen3_next.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/qwen3_next.py | 337 | Qwen3NextDecoderLayer | Partial | models/qwen3_next.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/qwen3_next.py | 372 | Qwen3NextModel | Partial | models/qwen3_next.rb | Ruby file exists; classes differ: ModelArgs, Model | +| 
models/qwen3_next.py | 404 | Model | Implemented | models/qwen3_next.rb | | +| models/qwen3_vl.py | 15 | ModelArgs | Implemented | models/qwen3_vl.rb | | +| models/qwen3_vl.py | 26 | Model | Implemented | models/qwen3_vl.rb | | +| models/qwen3_vl_moe.py | 15 | ModelArgs | Implemented | models/qwen3_vl_moe.rb | | +| models/qwen3_vl_moe.py | 20 | Model | Implemented | models/qwen3_vl_moe.rb | | +| models/recurrent_gemma.py | 15 | ModelArgs | Implemented | models/recurrent_gemma.rb | | +| models/recurrent_gemma.py | 39 | RMSNorm | Implemented | models/recurrent_gemma.rb | | +| models/recurrent_gemma.py | 79 | Conv1d | Partial | models/recurrent_gemma.rb | Ruby file exists; classes differ: ModelArgs, RMSNorm, RGLRU, RecurrentBlock, LocalAttentionBlock, MLPBlock, ResidualBlock, Griffin, Model | +| models/recurrent_gemma.py | 104 | RGLRU | Implemented | models/recurrent_gemma.rb | | +| models/recurrent_gemma.py | 170 | RecurrentBlock | Implemented | models/recurrent_gemma.rb | | +| models/recurrent_gemma.py | 220 | LocalAttentionBlock | Implemented | models/recurrent_gemma.rb | | +| models/recurrent_gemma.py | 273 | MLPBlock | Implemented | models/recurrent_gemma.rb | | +| models/recurrent_gemma.py | 287 | ResidualBlock | Implemented | models/recurrent_gemma.rb | | +| models/recurrent_gemma.py | 363 | Griffin | Implemented | models/recurrent_gemma.rb | | +| models/recurrent_gemma.py | 413 | Model | Implemented | models/recurrent_gemma.rb | | +| models/rope_utils.py | 10 | SuScaledRoPE | Implemented | models/rope_utils.rb | | +| models/rope_utils.py | 73 | Llama3RoPE | Implemented | models/rope_utils.rb | | +| models/rope_utils.py | 128 | YarnRoPE | Implemented | models/rope_utils.rb | | +| models/rwkv7.py | 15 | ModelArgs | Implemented | models/rwkv7.rb | | +| models/rwkv7.py | 151 | LayerNormPerHead | Partial | models/rwkv7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/rwkv7.py | 162 | LoRA | Partial | models/rwkv7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/rwkv7.py | 198 | TokenShift | Partial | models/rwkv7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/rwkv7.py | 209 | Rwkv7ChannelMixing | Partial | models/rwkv7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/rwkv7.py | 232 | Rwkv7TimeMixing | Partial | models/rwkv7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/rwkv7.py | 371 | Rwkv7Layer | Partial | models/rwkv7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/rwkv7.py | 392 | Rwkv7Model | Partial | models/rwkv7.rb | Ruby file exists; classes differ: ModelArgs, Model | +| models/rwkv7.py | 412 | Model | Implemented | models/rwkv7.rb | | +| models/seed_oss.py | 15 | ModelArgs | Implemented | models/seed_oss.rb | | +| models/seed_oss.py | 35 | Attention | Implemented | models/seed_oss.rb | | +| models/seed_oss.py | 92 | MLP | Implemented | models/seed_oss.rb | | +| models/seed_oss.py | 103 | TransformerBlock | Implemented | models/seed_oss.rb | | +| models/seed_oss.py | 128 | SeedModel | Implemented | models/seed_oss.rb | | +| models/seed_oss.py | 155 | Model | Implemented | models/seed_oss.rb | | +| models/smollm3.py | 13 | ModelArgs | Implemented | models/smollm3.rb | | +| models/smollm3.py | 29 | NoPE | Implemented | models/smollm3.rb | | +| models/smollm3.py | 36 | Model | Implemented | models/smollm3.rb | | +| models/solar_open.py | 11 | ModelArgs | Implemented | models/solar_open.rb | | +| models/stablelm.py | 14 | ModelArgs | Implemented | models/stablelm.rb 
| | +| models/stablelm.py | 30 | LayerNormPerHead | Partial | models/stablelm.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, DecoderLayer, StableLMModel, Model | +| models/stablelm.py | 44 | Attention | Implemented | models/stablelm.rb | | +| models/stablelm.py | 131 | MLP | Implemented | models/stablelm.rb | | +| models/stablelm.py | 142 | DecoderLayer | Implemented | models/stablelm.rb | | +| models/stablelm.py | 171 | StableLM | Partial | models/stablelm.rb | Ruby file exists; classes differ: ModelArgs, Attention, MLP, DecoderLayer, StableLMModel, Model | +| models/stablelm.py | 191 | Model | Implemented | models/stablelm.rb | | +| models/starcoder2.py | 13 | ModelArgs | Implemented | models/starcoder2.rb | | +| models/starcoder2.py | 26 | Attention | Implemented | models/starcoder2.rb | | +| models/starcoder2.py | 75 | MLP | Implemented | models/starcoder2.rb | | +| models/starcoder2.py | 85 | TransformerBlock | Implemented | models/starcoder2.rb | | +| models/starcoder2.py | 112 | Starcoder2Model | Implemented | models/starcoder2.rb | | +| models/starcoder2.py | 143 | Model | Implemented | models/starcoder2.rb | | +| models/step3p5.py | 25 | ClampedSwiGLU | Partial | models/step3p5.rb | Ruby file exists; classes differ: ModelArgs, ZeroCenteredRMSNorm, Step3p5MLP, Step3p5MoEGate, Step3p5MoE, Step3p5Attention, Step3p5DecoderLayer, Step3p5Model, Model | +| models/step3p5.py | 35 | ModelArgs | Implemented | models/step3p5.rb | | +| models/step3p5.py | 66 | ZeroCenteredRMSNorm | Implemented | models/step3p5.rb | | +| models/step3p5.py | 76 | Step3p5MLP | Implemented | models/step3p5.rb | | +| models/step3p5.py | 116 | Step3p5MoEGate | Implemented | models/step3p5.rb | | +| models/step3p5.py | 137 | Step3p5MoE | Implemented | models/step3p5.rb | | +| models/step3p5.py | 186 | Step3p5Attention | Implemented | models/step3p5.rb | | +| models/step3p5.py | 281 | Step3p5DecoderLayer | Implemented | models/step3p5.rb | | +| models/step3p5.py | 327 | Step3p5Model | Implemented | models/step3p5.rb | | +| models/step3p5.py | 376 | Model | Implemented | models/step3p5.rb | | +| models/switch_layers.py | 27 | QuantizedSwitchLinear | Implemented | models/switch_layers.rb | | +| models/switch_layers.py | 93 | SwitchLinear | Implemented | models/switch_layers.rb | | +| models/switch_layers.py | 152 | SwiGLU | Partial | models/switch_layers.rb | Ruby file exists; classes differ: SwitchLinear, QuantizedSwitchLinear, SwitchGLU, SwitchMLP | +| models/switch_layers.py | 160 | SwitchGLU | Implemented | models/switch_layers.rb | | +| models/switch_layers.py | 202 | SwitchMLP | Implemented | models/switch_layers.rb | | +| models/telechat3.py | 15 | ModelArgs | Implemented | models/telechat3.rb | | +| models/telechat3.py | 33 | Telechat3Attention | Implemented | models/telechat3.rb | | +| models/telechat3.py | 103 | Telechat3MLP | Implemented | models/telechat3.rb | | +| models/telechat3.py | 120 | Telechat3DecoderLayer | Implemented | models/telechat3.rb | | +| models/telechat3.py | 145 | Telechat3Model | Implemented | models/telechat3.rb | | +| models/telechat3.py | 178 | Model | Implemented | models/telechat3.rb | | +| models/youtu_llm.py | 15 | ModelArgs | Implemented | models/youtu_llm.rb | | +| models/youtu_llm.py | 38 | YoutuLLMAttention | Implemented | models/youtu_llm.rb | | +| models/youtu_llm.py | 141 | YoutuLLMMLP | Implemented | models/youtu_llm.rb | | +| models/youtu_llm.py | 158 | YoutuLLMDecoderLayer | Implemented | models/youtu_llm.rb | | +| models/youtu_llm.py | 182 | 
YoutuLLMModel | Implemented | models/youtu_llm.rb | | +| models/youtu_llm.py | 211 | Model | Implemented | models/youtu_llm.rb | | +| quant/awq.py | 25 | ScaleConfig | Missing | - | | +| quant/awq.py | 34 | AWQConfig | Missing | - | | +| quant/awq.py | 430 | Catcher | Missing | - | | +| quant/gptq.py | 40 | Catcher | Missing | - | | +| server.py | 64 | StopCondition | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 185 | LRUPromptCache | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 188 | CacheEntry | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 194 | SearchResult | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 352 | ModelDescription | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 359 | SamplingArguments | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 369 | LogitsProcessorArguments | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 376 | GenerationArguments | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 392 | CompletionRequest | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 403 | GenerationContext | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 424 | Response | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 432 | TimeBudget | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 472 | ModelProvider | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 618 | ResponseGenerator | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| server.py | 1059 | APIHandler | Partial | server.rb | Ruby file exists; classes differ: ChatCompletionRequest, ChatCompletionResponse, ChatCompletionChunk, ModelsListResponse | +| share.py | 27 | DirectoryEntry | Missing | - | | +| tokenizer_utils.py | 11 | StreamingDetokenizer | Implemented | tokenizer_utils.rb | | +| tokenizer_utils.py | 61 | NaiveStreamingDetokenizer | Partial | tokenizer_utils.rb | Ruby file exists; classes differ: TokenizerWrapper, StreamingDetokenizer | +| tokenizer_utils.py | 107 | SPMStreamingDetokenizer | Partial | tokenizer_utils.rb | Ruby file exists; classes differ: TokenizerWrapper, StreamingDetokenizer | 
+| tokenizer_utils.py | 155 | BPEStreamingDetokenizer | Partial | tokenizer_utils.rb | Ruby file exists; classes differ: TokenizerWrapper, StreamingDetokenizer | +| tokenizer_utils.py | 256 | TokenizerWrapper | Implemented | tokenizer_utils.rb | | +| tokenizer_utils.py | 401 | NewlineTokenizer | Partial | tokenizer_utils.rb | Ruby file exists; classes differ: TokenizerWrapper, StreamingDetokenizer | +| tuner/callbacks.py | 16 | TrainingCallback | Missing | - | | +| tuner/callbacks.py | 27 | WandBCallback | Missing | - | | +| tuner/callbacks.py | 65 | SwanLabCallback | Missing | - | | +| tuner/datasets.py | 11 | TextDataset | Missing | - | | +| tuner/datasets.py | 39 | ChatDataset | Missing | - | | +| tuner/datasets.py | 86 | CompletionsDataset | Missing | - | | +| tuner/datasets.py | 136 | ConcatenatedDataset | Missing | - | | +| tuner/datasets.py | 158 | CacheDataset | Missing | - | | +| tuner/dora.py | 9 | DoRALinear | Missing | - | | +| tuner/dora.py | 131 | DoRAEmbedding | Missing | - | | +| tuner/lora.py | 11 | LoRALinear | Implemented | tuner/lora.rb | | +| tuner/lora.py | 101 | LoRASwitchLinear | Partial | tuner/lora.rb | Ruby file exists; classes differ: LoRALinear, LoRAEmbedding | +| tuner/lora.py | 198 | LoRAEmbedding | Implemented | tuner/lora.rb | | +| tuner/trainer.py | 37 | TrainingArgs | Missing | - | | diff --git a/prd/2026_02_25_python_ruby_parity_prd.md b/prd/2026_02_25_python_ruby_parity_prd.md new file mode 100644 index 0000000..7928fcd --- /dev/null +++ b/prd/2026_02_25_python_ruby_parity_prd.md @@ -0,0 +1,281 @@ +# Python `mlx-lm` -> Ruby `mlx-ruby-lm` Parity PRD (Execution Revision) + +**Status:** Active (Architecture Parity Achieved; Class-Level Completion In Progress) +**Date:** 2026-02-26 +**Supersedes:** prior draft from 2026-02-25 +**References:** +- [Python-Ruby Parity Checklist](2026_02_25_python_ruby_parity_checklist.md) +- [Parity Inventory Snapshot](../test/reports/python_ruby_parity_inventory_snapshot.json) + +## 1. Purpose + +Maintain execution-grade parity for inference-focused model coverage between Python +`mlx-lm` and Ruby `mlx-ruby-lm`, with reproducible gates and artifacted validation. + +This revision updates the plan to match current repo reality: + +- inventory parity is complete at architecture-key level +- class-level parity now has an explicit closure plan (`768` classes total) +- inventory generation/checking is task-based (not script-based) +- ONNX reporting is task-based with Markdown/JSON artifacts +- default ONNX full-export behavior is gated to avoid unnecessary heavy runs + +## 2. Baseline Snapshot (Frozen) + +Source: `test/reports/python_ruby_parity_inventory_snapshot.json` +Regenerate with: `bundle exec rake parity:inventory` +Validate with: `bundle exec rake parity:inventory_check` + +- Python model files in `mlx-lm/mlx_lm/models`: **116** +- Python shared infra files: **10** +- Python architecture files: **106** +- Ruby model files in `lib/mlx_lm/models`: **115** +- Ruby shared infra files: **2** +- Ruby architecture files: **113** +- Ruby registered architecture keys: **106** +- Current architecture key gap: **0** +- Python classes discovered (`mlx-lm/mlx_lm/**/*.py`): **768** +- Ruby class parity status: **527 Implemented / 221 Partial / 20 Missing** + +## 3. 
Scope
+
+### In Scope
+
+- Maintain architecture-key parity between upstream Python model set and Ruby registry
+- Execute class-level parity closure for remaining Partial/Missing classes
+- Keep governance gates reproducible and green
+- Run ONNX compatibility coverage and produce artifacted reports
+- Track unsupported ONNX ops with exact invocation details
+
+### Out of Scope
+
+- Speculative decoding R&D paths
+- Performance tuning and benchmark optimization work
+
+## 4. Hard Gates (Must Pass)
+
+### G1: Inventory Freeze Gate
+
+**Requirement:** `test/reports/python_ruby_parity_inventory_snapshot.json` is current.
+
+- Regenerate snapshot: `bundle exec rake parity:inventory`
+- Check freshness: `bundle exec rake parity:inventory_check`
+- CI/test gate: `test/parity/governance_parity_gates_test.rb`
+
+### G2: ONNX Submodule Minimum Commit Gate
+
+**Requirement:** `mlx-ruby/submodules/mlx-onnx` includes required lowering support.
+
+- Minimum commit: `33d4b2eed2aa342f0836298dda60b6c5eb011b0f`
+- Validation method: `git merge-base --is-ancestor <minimum-commit> HEAD` (the minimum commit must be an ancestor of the submodule's checked-out `HEAD`)
+- CI/test gate: `test/parity/governance_parity_gates_test.rb`
+
+### G3: ONNX Reporting Gate
+
+**Requirement:** compat report generation is reproducible and artifacted.
+
+- Task: `bundle exec rake onnx:report`
+- Artifacts:
+  - `test/reports/onnx_compat_test_output.txt`
+  - `test/reports/onnx_compat_full_report.json`
+  - `test/reports/onnx_compat_full_report.md`
+  - `test/reports/onnx_compat_missing_ops_invocations.csv`
+
+### G4: Class Parity Checklist Gate
+
+**Requirement:** class-level parity checklist is current and reviewed before merges
+that affect parity-sensitive files.
+
+- Source of truth: `prd/2026_02_25_python_ruby_parity_checklist.md`
+- Required fields: Python class, status, Ruby reference, notes
+- Change policy: PRs changing model/server/generation/tuner code must update the checklist
+
+## 5. Public API / Workflow Contract Updates
+
+### Inventory workflow contract
+
+Inventory generation must run via task class, not ad hoc scripts:
+
+1. `tasks/parity_inventory_task.rb` is the implementation unit
+2. `rake parity:inventory` writes snapshot output
+3. `rake parity:inventory_check` enforces freshness and fails on drift
+
+### ONNX validation artifact contract
+
+Compat reporting must produce machine-readable and human-readable outputs:
+
+1. full per-model compat JSON payload
+2. markdown summary report
+3. unsupported-op union and per-model incidence
+4. unsupported-op invocation list (indexed)
+
+## 6. Execution Plan (Post-Parity)
+
+## Phase A: Governance Stability
+
+**Objective:** Keep parity gates deterministic and low-maintenance.
+
+**Exit Criteria:**
+- inventory and submodule governance tests are green
+- inventory task outputs remain stable
+
+## Phase B: ONNX Compat Matrix Quality
+
+**Objective:** maximize compat coverage and minimize unsupported ops.
+
+**Exit Criteria:**
+- compat suite runs for all ONNX model tests
+- missing-op list is explicit with model + invocation context
+- report artifacts regenerate cleanly through `rake onnx:report`
+
+## Phase C: ONNX Export Reliability
+
+**Objective:** reduce hangs and classify failures predictably.
+
+**Exit Criteria:**
+- subprocess timeout behavior is enforced
+- full export remains opt-in by default
+- failing exports are categorized in report outputs
+
+## Phase D: Core Runtime API Class Closure
+
+**Objective:** close high-impact runtime API gaps that currently block feature parity.
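+
+As a reference point only, the sketch below shows the kind of incremental detokenizer
+surface this phase closes. The class name, internals, and the naive re-decode strategy
+are illustrative assumptions, not the final `lib/mlx_lm/tokenizer_utils.rb`
+implementation, which must match Python `mlx-lm` behavior.
+
+```ruby
+# Sketch only: a naive streaming detokenizer with the add_token/finalize
+# surface targeted in Phase D. Re-decoding the full token list each step is
+# the simple baseline the real classes are expected to improve on.
+class NaiveStreamingDetokenizerSketch
+  def initialize(tokenizer)
+    @tokenizer = tokenizer # assumed to respond to #decode(Array) -> String
+    @tokens = []
+    @offset = 0
+  end
+
+  # Append one token id and return any newly finalized text segment.
+  def add_token(token_id)
+    @tokens << token_id
+    text = @tokenizer.decode(@tokens)
+    segment = text[@offset..] || ""
+    @offset = text.length
+    segment
+  end
+
+  # Flush trailing text once generation stops.
+  def finalize
+    @tokenizer.decode(@tokens)[@offset..] || ""
+  end
+end
+```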
+ +**Target areas:** +- `generate.py`: `GenerationResponse`, `BatchStats`, `BatchResponse`, `Batch`, `BatchGenerator`, `Response` +- `tokenizer_utils.py`: `NaiveStreamingDetokenizer`, `SPMStreamingDetokenizer`, + `BPEStreamingDetokenizer`, `NewlineTokenizer` +- `models/cache.py`: `_BaseCache` alignment, `ConcatenateKVCache`, + `BatchKVCache`, `BatchRotatingKVCache` + +**Exit Criteria:** +- targeted classes move from `Partial` to `Implemented` +- batch generation and detokenizer behavior have parity tests +- cache-family parity tests cover batch cache variants + +## Phase E: Server Surface Completion + +**Objective:** complete Python server class parity for request/response and runtime paths. + +**Target areas (`server.py`):** +- `StopCondition` +- `LRUPromptCache`, `CacheEntry`, `SearchResult` +- `ModelDescription`, `SamplingArguments`, `LogitsProcessorArguments`, + `GenerationArguments`, `CompletionRequest`, `GenerationContext`, `Response`, `TimeBudget` +- `ModelProvider`, `ResponseGenerator`, `APIHandler` + +**Exit Criteria:** +- `/v1/chat/completions`, `/v1/completions`, and `/v1/models` behavior reaches parity +- stop/tool and streaming behaviors are covered by parity tests +- `server.py` class inventory has zero `Partial` entries + +## Phase F: Model Internal Class Closure (High-Delta Files) + +**Objective:** close class-level internal gaps in architecture files that are currently only partially mirrored. + +**Wave F1 (SSM/Hybrid focus):** +- `models/falcon_h1.py`, `models/jamba.py`, `models/nemotron_h.py`, + `models/plamo2.py`, `models/qwen3_next.py`, `models/rwkv7.py`, `models/kimi_linear.py` + +**Wave F2 (MoE/DeepSeek focus):** +- `models/afm7.py`, `models/bailing_moe_linear.py`, `models/deepseek_v2.py`, + `models/deepseek_v3.py`, `models/deepseek_v32.py`, `models/granitemoehybrid.py`, + `models/glm4_moe_lite.py`, `models/longcat_flash.py`, `models/lfm2.py`, `models/minimax.py` + +**Wave F3 (Naming/alias cleanup):** +- close remaining class-name drift where structure exists but names differ + (for example `DBRX` vs `DbrxModel`, `CohereModel` vs `Cohere2Model`) + +**Exit Criteria:** +- high-delta model files above have no unresolved structural class gaps +- class parity `Partial` count is reduced materially from baseline (`221`) +- model parity and ONNX compat tests remain green for touched models + +## Phase G: Missing Utility Modules + +**Objective:** implement currently missing non-model utility classes. + +**Target areas (`Missing` classes):** +- `evaluate.py`: `MLXLM` +- `share.py`: `DirectoryEntry` +- `gguf.py`: `TokenType`, `GGMLFileType`, `HfVocab` +- `quant/awq.py`, `quant/gptq.py`: `ScaleConfig`, `AWQConfig`, `Catcher` + +**Exit Criteria:** +- utility class set above is implemented or explicitly superseded with documented mapping +- missing-class count is reduced to only tuner/training classes + +## Phase H: Tuner and Training Stack Parity + +**Objective:** complete the training-side classes currently missing in Ruby. 
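+
+For orientation, a minimal sketch of how the Python `TrainingArgs` dataclass could map
+onto an idiomatic Ruby value object is shown below. Field names and defaults are
+illustrative placeholders and must be confirmed against `tuner/trainer.py` before
+implementation.
+
+```ruby
+# Sketch only: keyword-initialized training arguments mirroring the Python
+# dataclass shape. Fields and defaults shown here are assumptions, not the
+# frozen API.
+module MlxLm
+  module Tuner
+    TrainingArgs = Struct.new(
+      :batch_size, :iters, :learning_rate, :steps_per_eval,
+      :grad_checkpoint, :adapter_file,
+      keyword_init: true
+    ) do
+      # Reasonable starting defaults for a smoke-test run.
+      def self.defaults
+        new(
+          batch_size: 4,
+          iters: 1_000,
+          learning_rate: 1e-5,
+          steps_per_eval: 200,
+          grad_checkpoint: false,
+          adapter_file: "adapters.safetensors"
+        )
+      end
+    end
+  end
+end
+```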
+ +**Target areas:** +- `tuner/trainer.py`: `TrainingArgs` +- `tuner/callbacks.py`: `TrainingCallback`, `WandBCallback`, `SwanLabCallback` +- `tuner/datasets.py`: `TextDataset`, `ChatDataset`, `CompletionsDataset`, + `ConcatenatedDataset`, `CacheDataset` +- `tuner/dora.py`: `DoRALinear`, `DoRAEmbedding` + +**Exit Criteria:** +- all tuner/training classes above have Ruby implementations +- training smoke tests cover LoRA + DoRA paths and dataset adapters +- `Missing` class count reaches `0` + +## Phase I: Final Parity Closure and Maintenance + +**Objective:** make class-level parity drift-resistant after closure. + +**Exit Criteria:** +- class inventory remains at `Missing = 0` +- remaining `Partial` entries are either closed or explicitly accepted as intentional aliases +- parity checklist updates are enforced in PR review flow + +## 7. Test Strategy + +### Governance tests + +- inventory snapshot freshness gate +- `mlx-onnx` minimum commit gate +- class parity checklist freshness/review gate + +### ONNX tests + +- compat-report path for full model set +- optional full-export path gated by env var + +### Class parity tests + +- per-phase class closure tests for touched modules +- API parity tests for generation/tokenizer/cache/server surfaces +- model parity tests for high-delta architecture files (Phase F waves) +- tuner smoke tests for callbacks/datasets/DoRA/trainer paths (Phase H) + +### Execution policy + +Every report refresh should produce committed artifacts when used for parity review: +- markdown summary +- full JSON payload +- missing-op invocation CSV + +## 8. Execution Status (Current) + +- [x] Architecture-key parity achieved (`106 / 106`) +- [x] Inventory task class implemented: `tasks/parity_inventory_task.rb` +- [x] Rake inventory tasks active: `parity:inventory`, `parity:inventory_check` +- [x] Governance parity gates active in parity test suite +- [x] ONNX report task active: `onnx:report` +- [x] ONNX report includes Markdown + JSON + invocation CSV artifacts +- [x] Full class-level checklist published (`768` classes tracked) +- [ ] Class parity closure backlog active (`221 Partial / 20 Missing`) +- [ ] Phase D-I implementation not yet started in code + +## 9. Success Definition + +This parity program is successful when: + +- architecture-key parity remains at zero gap against upstream inventory, +- class-level parity reaches `Missing = 0`, +- remaining `Partial` entries are zero or explicitly justified aliases, +- governance gates fail fast on drift, +- ONNX compat reports are reproducible and actionable, +- unresolved ONNX gaps are tracked by model, op, and invocation. diff --git a/prd/PRD.md b/prd/PRD.md deleted file mode 100644 index 0ff2ad1..0000000 --- a/prd/PRD.md +++ /dev/null @@ -1,41 +0,0 @@ -# Product Requirements Document: mlx-ruby-lm - -## 1. Product Overview - -**mlx-ruby-lm** is a Ruby gem that provides a complete port of the Python `mlx-lm` -package for large language model inference and fine-tuning, built on top of the -`mlx-ruby` gem. It targets 100% functional parity with the Python implementation -while being idiomatically Ruby. - -## 2. Goals - -- Provide Ruby developers with a native LLM inference library on MLX -- 100% functional parity with Python mlx-lm v0.30.7 -- Idiomatic Ruby API (snake_case, blocks, keyword arguments, Enumerable) -- Every feature validated by parity tests comparing Ruby ↔ Python output - -## 3. 
Non-Goals - -- GUI or web frontend -- Custom model training from scratch (only LoRA fine-tuning) -- Support for non-MLX backends - -## 4. Technical Architecture - -- **Runtime:** Ruby >= 3.3 -- **Test framework:** Minitest -- **ML backend:** mlx-ruby gem (wraps MLX C++ via native extension) -- **Tokenizer:** `tokenizers` gem (HuggingFace Rust bindings) -- **Weights:** `safetensors` gem for loading/saving -- **Namespace:** `MlxLm` module - -## 5. Implementation Phases - -See `prd/conversion_plan.md` for the complete 12-phase plan with parity tests. - -## 6. Success Criteria - -- All 76 parity tests pass (comparing Ruby output to Python output) -- `MlxLm.load` + `MlxLm.generate` produces identical output to Python -- All 100+ model architectures supported -- OpenAI-compatible HTTP server works identically diff --git a/prd/conversion_plan.md b/prd/conversion_plan.md deleted file mode 100644 index bdcfa89..0000000 --- a/prd/conversion_plan.md +++ /dev/null @@ -1,750 +0,0 @@ -# MLX-LM Python → Ruby Conversion Plan - -## Overview - -Convert the `mlx-lm` Python package into a Ruby gem (`mlx-ruby-lm`) that provides -100% identical functionality, built on top of the `mlx-ruby` gem. The Python -implementation serves as reference only — all code is rewritten idiomatically in Ruby. - -Each phase includes **parity tests** that run the equivalent operation in both Python -(via `mlx-lm`) and Ruby (via `mlx-ruby-lm`), comparing outputs numerically within -floating-point tolerance. - ---- - -## Phase 1: Project Scaffolding & Core Infrastructure - -**Goal:** Establish the gem structure, configuration system, and weight-loading primitives. - -### Deliverables - -1. **Gem skeleton** - - `mlx-ruby-lm.gemspec` with dependency on `mlx` gem (>= 0.30.7) - - `lib/mlx_lm.rb` entry point - - Directory layout: - ``` - lib/mlx_lm/ - version.rb - model_args.rb # BaseModelArgs equivalent - weight_utils.rb # safetensors loading / tree-map helpers - config.rb # config.json + generation_config.json parser - ``` - - `Rakefile` with test, build tasks - - `test/test_helper.rb` with parity-test harness (runs Python via subprocess, - compares against Ruby) - -2. **BaseModelArgs** - - Ruby data class (or `Struct`/`Data`) mirroring Python `BaseModelArgs` - - Supports arbitrary kwargs → instance attributes - - `from_dict` class method that filters unknown keys with warnings - -3. **Weight loading utilities** - - `load_safetensors(path)` → returns hash of `{ "layer.weight" => MLX::Core::Array }` - - `load_sharded_safetensors(directory)` → loads `model*.safetensors` shards - - `tree_unflatten(flat_hash)` → nested module-state hash - -4. **Config loading** - - Parse `config.json` → `ModelArgs` for the appropriate architecture - - Parse `generation_config.json` → generation defaults (temperature, top_p, etc.) - -### Parity Tests — Phase 1 - -| # | Test | Method | -|---|------|--------| -| 1.1 | `BaseModelArgs` round-trips from a config dict identically | Compare field values | -| 1.2 | `load_safetensors` returns tensors with identical shapes/dtypes | Shape + dtype check | -| 1.3 | `load_safetensors` tensor values match Python within `atol=1e-6` | Numerical comparison | -| 1.4 | `tree_unflatten` produces identical nested key structure | Key-tree diff | -| 1.5 | Config parsing extracts identical model hyperparameters | Field-by-field comparison | - ---- - -## Phase 2: Tokenizer Integration - -**Goal:** Provide tokenizer encode/decode that produces identical token IDs and text. - -### Deliverables - -1. 
**TokenizerWrapper** - - Wraps the `tokenizers` Ruby gem (HuggingFace Rust tokenizers bindings) - - Falls back to a Python subprocess bridge for tokenizers that require - `trust_remote_code` or custom HF `AutoTokenizer` behavior - - API: `encode(text) → Array`, `decode(ids) → String` - - `eos_token_id`, `bos_token_id`, `vocab_size` - -2. **Chat template support** - - Jinja2-compatible template rendering (via `liquid` or `jinja` Ruby gem, - or a minimal Jinja subset interpreter) - - `apply_chat_template(messages, add_generation_prompt:)` → token IDs - -3. **StreamingDetokenizer** - - Base class + SentencePiece and BPE implementations - - Incremental `add_token(id)` → emits text segments without O(T²) re-decoding - - `finalize` for trailing output - -### Parity Tests — Phase 2 - -| # | Test | Method | -|---|------|--------| -| 2.1 | `encode(text)` produces identical token ID arrays for 10+ diverse strings | Exact match | -| 2.2 | `decode(ids)` produces identical text | Exact string match | -| 2.3 | `apply_chat_template` on a multi-turn conversation → identical IDs | Exact match | -| 2.4 | StreamingDetokenizer accumulates to the same final string | String match | -| 2.5 | Special token handling (BOS, EOS, pad) matches Python | Token ID match | -| 2.6 | `encode` + `decode` round-trip preserves text for 20 samples | String match | - ---- - -## Phase 3: KV Cache & Base Model Architecture (Llama) - -**Goal:** Implement the Llama architecture as the foundational model pattern and -all KV cache variants. - -### Deliverables - -1. **KV Cache implementations** (in `lib/mlx_lm/models/cache.rb`) - - `KVCache` — simple concatenation cache - - `RotatingKVCache` — fixed-size circular buffer - - `QuantizedKVCache` — quantized K/V storage - - `CacheList` — container for per-layer caches - - `make_prompt_cache(model)` factory - - `save_prompt_cache` / `load_prompt_cache` for safetensors serialization - -2. **Llama model** (in `lib/mlx_lm/models/llama.rb`) - - `LlamaModelArgs < BaseModelArgs` - - `LlamaAttention` — GQA with RoPE, KV cache integration - - `LlamaMLP` — gate/up/down projections with SiLU - - `LlamaTransformerBlock` — attention + MLP + RMSNorm - - `LlamaModel` — embedding + N blocks + final norm + lm_head - - Forward pass: `model.call(input_ids, cache: nil) → logits` - -3. **Model registration** - - `MODEL_REGISTRY` hash mapping `model_type` string → model class - - `get_model_class(config)` lookup - -### Parity Tests — Phase 3 - -| # | Test | Method | -|---|------|--------| -| 3.1 | Llama forward pass (random weights, fixed input) → logits match `atol=1e-4` | Numerical | -| 3.2 | RoPE embeddings match Python for multiple sequence lengths | Numerical `atol=1e-6` | -| 3.3 | GQA attention output matches Python | Numerical `atol=1e-5` | -| 3.4 | KVCache state after N forward passes has identical shape/values | Numerical | -| 3.5 | RotatingKVCache wraps correctly at boundary | Numerical + shape | -| 3.6 | QuantizedKVCache quantize/dequantize round-trip within tolerance | Numerical `atol=0.1` | -| 3.7 | Full Llama model with loaded HF weights → identical logits for prompt | Numerical `atol=1e-3` | -| 3.8 | `make_prompt_cache` → save → load round-trip produces identical cache | Numerical | - ---- - -## Phase 4: Sampling & Generation Engine - -**Goal:** Implement token-by-token generation with all sampling strategies, producing -identical output sequences given the same seed. - -### Deliverables - -1. 
**Sampling** (`lib/mlx_lm/sample_utils.rb`) - - `make_sampler(temperature:, top_p:, top_k:, min_p:, xtc_probability:, xtc_threshold:)` - - Categorical sampling with MLX random - - Greedy (argmax) when temperature == 0 - -2. **Logits processors** (`lib/mlx_lm/sample_utils.rb`) - - `make_logits_processors(repetition_penalty:, repetition_context_size:, logit_bias:)` - - `RepetitionPenalty` — penalizes repeated tokens in context window - - `LogitBias` — additive bias to specific token logits - -3. **Generation core** (`lib/mlx_lm/generate.rb`) - - `generate_step(prompt_tokens, model, cache, sampler, logits_processors)` — yields - `(token, logprobs)` per step - - Prompt prefill with configurable `prefill_step_size` (default 2048) - - `generate(model, tokenizer, prompt, **kwargs)` → `GenerationResult` - - `stream_generate(model, tokenizer, prompt, **kwargs)` → `Enumerator` yielding - text chunks via `StreamingDetokenizer` - - `batch_generate(model, tokenizer, prompts, **kwargs)` → array of results - - Stop conditions: `max_tokens`, EOS, stop strings - -### Parity Tests — Phase 4 - -| # | Test | Method | -|---|------|--------| -| 4.1 | Greedy generation (temp=0) produces identical token sequence | Exact match | -| 4.2 | Top-p sampling with fixed seed → identical tokens | Exact match | -| 4.3 | Top-k sampling with fixed seed → identical tokens | Exact match | -| 4.4 | Min-p sampling with fixed seed → identical tokens | Exact match | -| 4.5 | Repetition penalty changes logits identically | Numerical `atol=1e-6` | -| 4.6 | Logit bias applied correctly | Numerical `atol=1e-6` | -| 4.7 | `generate` end-to-end with Llama → identical output text | String match | -| 4.8 | `stream_generate` yields identical incremental text segments | String match per chunk | -| 4.9 | Stop string detection halts at same position | Token count match | -| 4.10 | `batch_generate` produces identical results to sequential calls | String match | -| 4.11 | Generation timing/token count metadata matches | Approximate match | - ---- - -## Phase 5: Model Loading Pipeline & HuggingFace Integration - -**Goal:** Load models from HuggingFace Hub or local paths, matching Python's `load()` exactly. - -### Deliverables - -1. **HuggingFace Hub download** (`lib/mlx_lm/hub_utils.rb`) - - `snapshot_download(repo_id, revision:, allow_patterns:)` using the `huggingface_hub` - Ruby gem or HTTP API calls to HF - - Local cache management (`~/.cache/huggingface/hub/`) - - Support for gated/private models via HF token - -2. **Model loading** (`lib/mlx_lm/utils.rb`) - - `load(path_or_hf_repo, tokenizer_config:, adapter_path:, lazy:)` → - `[model, tokenizer]` - - Dynamic model class lookup via `MODEL_REGISTRY` - - Weight dequantization / requantization on load - - Adapter (LoRA) weight merging on load - -3. 
**Quantization-aware loading** - - Detect `quantization` key in config → apply MLX quantization - - Detect `quantization_config` → handle AWQ/GPTQ weight transforms - - `_dequantize_linear(model)` for full-precision fallback - -### Parity Tests — Phase 5 - -| # | Test | Method | -|---|------|--------| -| 5.1 | `load("mlx-community/Llama-3.2-1B-Instruct-4bit")` → identical model config | Field match | -| 5.2 | Loaded model weights have identical shapes and dtypes | Shape + dtype | -| 5.3 | Loaded model weights match numerically (for non-quantized) | `atol=1e-5` | -| 5.4 | Quantized model logits match Python for same prompt | `atol=1e-2` | -| 5.5 | Tokenizer loaded from HF produces identical encodings | Exact match | -| 5.6 | `load` with `lazy=True` defers evaluation correctly | Shape match, lazy check | - ---- - -## Phase 6: Popular Model Architectures (Batch 1) - -**Goal:** Implement the most commonly used architectures beyond Llama. - -### Deliverables - -1. **Mistral** — sliding window attention, different RoPE config -2. **Gemma / Gemma2** — GeGLU, different normalization, logit soft-capping -3. **Qwen2 / Qwen2.5** — bias in attention, different FFN -4. **Phi3** — su/yarn RoPE scaling, block-sparse attention -5. **Mixtral** — Mixture-of-Experts with top-k gating -6. **Starcoder2** — code-focused architecture -7. **Cohere** — layernorm placement variations - -Each model follows the established pattern: -- `ModelArgs < BaseModelArgs` -- `Attention`, `MLP`, `TransformerBlock`, `Model` classes -- Registered in `MODEL_REGISTRY` - -### Parity Tests — Phase 6 - -| # | Test | Method | -|---|------|--------| -| 6.1 | Each model: forward pass with random weights → logits match | `atol=1e-4` | -| 6.2 | Each model: loaded from HF → greedy generation matches Python | Exact token match | -| 6.3 | Mixtral: MoE routing selects same experts | Expert index match | -| 6.4 | Mixtral: MoE output matches Python | `atol=1e-4` | -| 6.5 | Gemma: soft-capping produces identical logits | `atol=1e-5` | -| 6.6 | Phi3: su/yarn RoPE matches Python | `atol=1e-5` | -| 6.7 | Each model: end-to-end `generate` matches Python output | String match | - ---- - -## Phase 7: Quantization Engine - -**Goal:** Full quantization support — convert, load, and run quantized models. - -### Deliverables - -1. **MLX native quantization** (`lib/mlx_lm/convert.rb`) - - `convert(model, config, quantize:)` → quantized model + config - - Affine quantization (default) — 2, 4, 8-bit - - MXFP4, MXFP8, NVFP4 modes - - Per-layer quantization config (skip embeddings, lm_head) - -2. **AWQ support** (`lib/mlx_lm/quant/awq.rb`) - - `AwqQuantizer` — activation-aware weight quantization - - `_transform_awq_weights` — convert AutoAWQ format to MLX format - -3. **GPTQ support** (`lib/mlx_lm/quant/gptq.rb`) - - `GptqQuantizer` — Hessian-based quantization - -4. 
**Model saving** (`lib/mlx_lm/utils.rb`) - - `save_model(path, model, tokenizer, config)` — safetensors + config.json - - Shard large models into multiple files - -### Parity Tests — Phase 7 - -| # | Test | Method | -|---|------|--------| -| 7.1 | Affine quantize → dequantize round-trip matches Python | `atol=0.5` (quantization noise) | -| 7.2 | 4-bit quantized Linear forward matches Python | `atol=1e-2` | -| 7.3 | AWQ weight transform produces identical packed weights | Exact match | -| 7.4 | Quantized model generation output matches Python | String match (greedy) | -| 7.5 | Saved quantized model loads back identically | Weight match | -| 7.6 | MXFP4 quantization matches Python | `atol=1e-1` | -| 7.7 | Per-layer quant config (skip lm_head) applied identically | Config match | - ---- - -## Phase 8: Model Architectures (Batch 2 — Extended) - -**Goal:** Expand architecture coverage to include advanced and specialized models. - -### Deliverables - -1. **Deepseek V2/V3** — MLA attention, MoE with shared experts -2. **Falcon** — multi-query attention, alibi -3. **StableLM** — partial rotary embeddings -4. **Qwen2-MoE** — shared expert + fine-grained MoE -5. **Gemma3** — interleaved local/global attention -6. **Mamba / Mamba2** — SSM (state-space model), `ArraysCache` -7. **RWKV7** — linear attention / RNN hybrid -8. **RecurrentGemma** — recurrent variant -9. **Grok** — large MoE -10. **OpenELM** — per-layer scaling - -### Parity Tests — Phase 8 - -| # | Test | Method | -|---|------|--------| -| 8.1 | Each model: forward pass → logits match Python | `atol=1e-4` | -| 8.2 | Mamba: SSM state update matches Python | `atol=1e-5` | -| 8.3 | Deepseek MLA attention matches Python | `atol=1e-4` | -| 8.4 | MoE routing + output for each MoE model matches | `atol=1e-4` | -| 8.5 | Each model: greedy generation matches Python | Token match | - ---- - -## Phase 9: LoRA & Fine-Tuning - -**Goal:** Full LoRA fine-tuning pipeline with training loop. - -### Deliverables - -1. **LoRA layers** (`lib/mlx_lm/tuner/lora.rb`) - - `LoRALinear` — low-rank adaptation of linear layers - - `LoRASwitchLinear` — MoE-compatible LoRA - - `LoRAEmbedding` — embedding LoRA - - `apply_lora(model, config)` — patch model layers with LoRA - -2. **Training framework** (`lib/mlx_lm/tuner/trainer.rb`) - - `TrainingArgs` data class (batch_size, epochs, lr, etc.) - - `train(model, tokenizer, args, train_dataset, val_dataset)` — main loop - - `evaluate(model, dataset)` — validation - - Cross-entropy loss with padding mask - - Gradient accumulation - - Checkpointing (save/resume) - -3. **Dataset loading** (`lib/mlx_lm/tuner/datasets.rb`) - - Load JSONL/JSON training data - - Chat-formatted dataset → token sequences - - Completion-formatted dataset - - Batch iterator with padding - -4. 
**Adapter fusion** (`lib/mlx_lm/fuse.rb`) - - `fuse_model(model, adapter_path)` — merge LoRA weights into base - - De-quantize before fusion option - - Save fused model - -### Parity Tests — Phase 9 - -| # | Test | Method | -|---|------|--------| -| 9.1 | LoRALinear forward matches Python for same weights | `atol=1e-5` | -| 9.2 | LoRA gradient computation matches Python | `atol=1e-4` | -| 9.3 | One training step produces identical weight updates | `atol=1e-4` | -| 9.4 | Training loss curve matches Python over 50 steps | `atol=1e-2` per step | -| 9.5 | Fused model weights match Python fusion | `atol=1e-5` | -| 9.6 | Fused model generation matches Python | String match (greedy) | -| 9.7 | Checkpoint save/load round-trip preserves training state | Exact match | - ---- - -## Phase 10: CLI & OpenAI-Compatible Server - -**Goal:** Feature-complete CLI and HTTP server matching Python's interface. - -### Deliverables - -1. **CLI** (`lib/mlx_lm/cli.rb`, `exe/mlx_lm`) - - `mlx_lm generate` — text generation - - `mlx_lm chat` — interactive REPL - - `mlx_lm convert` — model conversion/quantization - - `mlx_lm lora` — LoRA fine-tuning - - `mlx_lm fuse` — adapter fusion - - `mlx_lm server` — start HTTP server - - `mlx_lm benchmark` — performance benchmarking - - `mlx_lm evaluate` — model evaluation - - `mlx_lm manage` — download/cache management - - Argument parsing matching Python's argparse interface - -2. **HTTP Server** (`lib/mlx_lm/server.rb`) - - `POST /v1/completions` — text completion - - `POST /v1/chat/completions` — chat completion - - `GET /v1/models` — list models - - Server-Sent Events streaming - - Request queuing - - OpenAI-compatible request/response schema - - Built on `webrick` or `puma` - -3. **Chat REPL** (`lib/mlx_lm/chat.rb`) - - Interactive conversation with history - - System prompt support - - Streaming output to terminal - -### Parity Tests — Phase 10 - -| # | Test | Method | -|---|------|--------| -| 10.1 | CLI `generate` output matches Python CLI | String match | -| 10.2 | CLI argument parsing accepts all Python flags | Flag coverage check | -| 10.3 | Server `/v1/chat/completions` response schema matches OpenAI spec | Schema validation | -| 10.4 | Server streaming response matches Python server | Chunk-by-chunk match | -| 10.5 | Server non-streaming response body matches Python | JSON diff | -| 10.6 | Chat REPL multi-turn context matches Python | Token sequence match | - ---- - -## Phase 11: Model Architectures (Batch 3 — Long Tail) - -**Goal:** Complete coverage of remaining architectures. - -### Deliverables - -All remaining models from `mlx-lm/mlx_lm/models/`, including: - -1. **Transformer variants:** OLMo, OLMoE, Nemotron, Exaone, InternLM2, Minicpm, - GraniteSmall, GraniteMoE, Dbrx, Jamba, Arctic, Telechat, Hunyuan, Solar, PlaMo -2. **Multimodal text backbones:** Qwen3.5-text, Llama4-text, Gemma3-text, - LFM2-VL-text, Kimi-VL text encoder -3. **Specialized:** BitNet (ternary weights), Phixtral, Ministral3 - -Each follows the standard pattern and is registered in `MODEL_REGISTRY`. - -### Parity Tests — Phase 11 - -| # | Test | Method | -|---|------|--------| -| 11.1 | Each model: forward pass → logits match | `atol=1e-4` | -| 11.2 | Each model: greedy generation matches Python | Token match | -| 11.3 | BitNet ternary weight handling matches Python | Exact weight match | - ---- - -## Phase 12: Advanced Features & Polish - -**Goal:** Implement remaining advanced capabilities and achieve full feature parity. - -### Deliverables - -1. 
**Prompt caching** - - `cache_prompt(model, tokenizer, prompt)` → saved cache file - - Load cached prompt for fast repeated inference - -2. **Speculative decoding** - - Draft model integration in generation loop - - Accept/reject logic - -3. **Distributed inference** - - Pipeline parallelism (`pipeline_load`) - - Tensor parallelism (`sharded_load`) - - Distributed generation with `MLX::Distributed` - -4. **Evaluation** - - Perplexity computation (`lib/mlx_lm/perplexity.rb`) - - Log-likelihood scoring - -5. **Benchmarking** - - Tokens/sec measurement - - Memory tracking - - Prompt processing vs generation speed - -6. **GGUF support** - - GGUF tokenizer loading - - GGUF weight export - -7. **Model sharing / upload** - - `upload(path, repo_id)` to HuggingFace Hub - - Model card generation - -### Parity Tests — Phase 12 - -| # | Test | Method | -|---|------|--------| -| 12.1 | Cached prompt generation matches uncached | String match | -| 12.2 | Perplexity score matches Python within 0.1% | Numerical | -| 12.3 | Speculative decoding output matches standard generation | String match | -| 12.4 | GGUF export produces identical file | Byte-level or weight match | -| 12.5 | Benchmark tokens/sec within 20% of Python (same hardware) | Approximate | - ---- - -## Parity Test Infrastructure - -### Test Harness Design - -```ruby -# test/test_helper.rb - -module ParityTest - # Runs a Python snippet, captures output as JSON, returns parsed result - def python_eval(code) - result = `python3 -c "#{code}"` - JSON.parse(result) - end - - # Compares MLX arrays between Ruby and Python - def assert_array_parity(ruby_array, python_values, atol: 1e-5) - ruby_values = ruby_array.tolist - assert_in_delta_array(ruby_values, python_values, atol) - end - - # Compares generation output - def assert_generation_parity(ruby_text, python_text) - assert_equal python_text.strip, ruby_text.strip - end -end -``` - -### Running Parity Tests - -```bash -# Run all parity tests -bundle exec rake test:parity - -# Run parity tests for a specific phase -bundle exec rake test:parity[phase3] - -# Run parity tests for a specific model -bundle exec rake test:parity[llama] - -# Generate parity report -bundle exec rake test:parity_report -``` - -### CI Integration - -- Each phase PR must pass all parity tests for that phase and prior phases -- Parity test failures block merge -- Nightly full-suite parity run against latest `mlx-lm` release - ---- - -## Dependency Strategy - -| Python Dependency | Ruby Equivalent | Notes | -|---|---|---| -| `mlx` | `mlx` gem (mlx-ruby) | Already available, feature-complete | -| `transformers` | `tokenizers` gem + custom code | HF tokenizers Rust bindings for Ruby | -| `numpy` | `mlx` gem array ops | MLX covers all needed operations | -| `huggingface_hub` | HTTP API + `down` gem or custom | Direct HF Hub HTTP API calls | -| `sentencepiece` | `tokenizers` gem | Covers SentencePiece models | -| `jinja2` | `liquid` or custom Jinja subset | For chat templates | -| `pyyaml` | `yaml` (stdlib) | Built into Ruby | -| `tqdm` | `ruby-progressbar` | Progress bars | -| `protobuf` | `google-protobuf` gem | If needed for legacy tokenizers | - ---- - -## Milestone Summary - -| Phase | Scope | Est. 
Files | Cumulative Parity Tests | -|-------|-------|-----------|------------------------| -| 1 | Scaffolding & infrastructure | ~10 | 5 | -| 2 | Tokenizer | ~5 | 11 | -| 3 | KV Cache & Llama | ~5 | 19 | -| 4 | Generation engine | ~3 | 30 | -| 5 | Model loading & HF integration | ~4 | 36 | -| 6 | 7 popular architectures | ~7 | 43 | -| 7 | Quantization | ~5 | 50 | -| 8 | 10 extended architectures | ~10 | 55 | -| 9 | LoRA & fine-tuning | ~6 | 62 | -| 10 | CLI & server | ~5 | 68 | -| 11 | ~30 remaining architectures | ~30 | 71 | -| 12 | Advanced features & polish | ~8 | 76 | - -**After Phase 5**, the gem is usable for basic inference with Llama models. -**After Phase 7**, quantized inference works for 8 architectures. -**After Phase 10**, the gem is a full drop-in replacement for common use cases. -**After Phase 12**, the gem has 100% feature parity with `mlx-lm`. - ---- - -## Key Design Decisions - -1. **Idiomatic Ruby** — Use Ruby conventions (snake_case, blocks/procs for callbacks, - `Enumerable` for streaming, keyword arguments) rather than transliterating Python. - -2. **Module pattern for models** — Each model is a `MLX::NN::Module` subclass, matching - how `mlx-ruby` already structures neural network layers. - -3. **Lazy evaluation preserved** — Follow MLX's lazy computation model; only `eval` - when results are needed. - -4. **Tokenizer strategy** — Primary: `tokenizers` gem (Rust HF bindings). Fallback: - Python subprocess for exotic tokenizers. Goal: eliminate Python fallback by Phase 12. - -5. **No monkey-patching** — Clean module hierarchy under `MlxLm::` namespace. - -6. **Progressive testing** — Every phase is independently testable. Earlier phases - don't depend on later phases being complete. - ---- - -## MLX-Ruby Gaps & Workarounds - -The following issues were discovered during implementation. Each required a -workaround in the `mlx-ruby-lm` codebase. - -### API Issues - -| # | Issue | Severity | Workaround | -|---|-------|----------|------------| -| 1 | `mx.array(values, dtype: ...)` raises `ArgumentError` — the `dtype:` keyword is rejected even though the error message says it accepts `Dtype`, symbol, or string | High | Create float32 array first, then `.astype(mx.int32)` etc. 
| -| 2 | `mx.mean(x, axis, keepdims: true)` not supported — no `keepdims` parameter on `mean` | Medium | `mx.expand_dims(mx.mean(x, axis), -1)` | -| 3 | `MLX::NN::Dropout.new(p: 0.5)` rejects keyword arg — constructor only accepts positional `Dropout.new(0.5)` | Low | Use positional argument | -| 4 | `mx.random_uniform` only works with float dtypes — passing `mx.int32` raises error | Low | Generate as float32, then `.astype(mx.int32)` | -| 5 | `mx.save_safetensors` not available — MLX not compiled with `MLX_BUILD_SAFETENSORS=ON` | Medium | Use Ruby `safetensors` gem for serialization | -| 6 | No `SwitchGLU` layer — Python mlx-lm uses it for efficient stacked MoE experts | Medium | Per-token expert routing loop (functional but slower) | - -### Ruby ↔ MLX Coercion Issues - -| # | Issue | Severity | Workaround | -|---|-------|----------|------------| -| 7 | `Float * MLX::Array` raises `TypeError` — Ruby `Float#*` can't coerce MLX arrays | High | Always put MLX array on the left: `array * scalar` | -| 8 | `Float + MLX::Array` raises `TypeError` — same coercion issue with addition | High | `array + scalar` instead of `scalar + array` | -| 9 | No unary negation `-array` | Low | `MLX::Core.negative(array)` | -| 10 | No comparison operators `>`, `<` on arrays | Low | `MLX::Core.greater()`, `MLX::Core.less()` | - -### NN Module System - -| # | Issue | Severity | Workaround | -|---|-------|----------|------------| -| 11 | Instance variables invisible to Module traversal — `@x = Module.new(...)` doesn't register children in `@state`, so `children`, `leaf_modules`, `parameters`, `load_weights`, and `nn.quantize` all miss them | **Critical** | Must use `self.x = Module.new(...)` (goes through `method_missing` → `@state`) for every child module in every model. Required refactoring all model files. | -| 12 | `update_modules_impl` missing Module→Hash recursion — when `nn.quantize` replaces layers, the code didn't handle `current_value=Module` + `new_value=Hash` | High | Patched `mlx-ruby/lib/mlx/nn/base.rb` to add `elsif current_value.is_a?(Module) && (new_value.is_a?(Hash) \|\| new_value.is_a?(Array))` branch for recursive descent. | - -### Safetensors Gem Compatibility - -| # | Issue | Severity | Workaround | -|---|-------|----------|------------| -| 13 | Dtype string format mismatch — Ruby `safetensors` gem v0.2.2 expects lowercase `"float32"` for serialization but returns uppercase `"F32"` on deserialization | Medium | `weight_utils.rb` accepts both formats: `dtype_str == "F32" \|\| dtype_str == "float32"` | - -### Impact Summary - -- **Issues 7, 8, 11** were the most pervasive — they affected every model file and - many utility modules. Issue 11 alone required rewriting every constructor in every - model architecture. -- **Issue 12** required patching mlx-ruby itself (the only upstream change). -- **Issue 1** affected any code path that creates typed arrays from Ruby values. -- **Issues 9, 10** are minor and only arise in specific model architectures. - -### Recommendations for mlx-ruby Upstream - -1. Fix `mx.array(values, dtype:)` to accept dtype keyword argument -2. Add `keepdims:` parameter to reduction ops (`mean`, `sum`, etc.) -3. Implement Ruby `coerce` on `MLX::Core::Array` so `Float * array` works -4. Compile with `MLX_BUILD_SAFETENSORS=ON` by default -5. Add `SwitchGLU` layer for MoE model support -6. Document the `self.x =` vs `@x =` requirement for Module children prominently -7. Apply the `update_modules_impl` fix below (Issue 12) -8. 
Add `GreaterEqual` to the mlx-onnx IR lowering pass (Issue 14 below) - -### Required Upstream: ONNX Export — Missing Ops & Tracer Limitations (Issue 14) - -Tested all 13 model architectures for ONNX export via `MLX::ONNX.export_onnx`. -**0/13 models export successfully.** The test file is at `test/parity/onnx_export_test.rb`. - -#### Category 1: Missing `GreaterEqual` op (11 dense models) - -All dense (non-MoE) models fail on a single missing ONNX lowering: -**`GreaterEqual`**, used in the causal attention mask (`>=` comparison). - -| Model | Total Nodes | Supported | Coverage | Missing Ops | -|-------|-------------|-----------|----------|-------------| -| llama | 255 | 253 | 99.2% | `GreaterEqual` | -| gemma | 266 | 264 | 99.2% | `GreaterEqual` | -| gemma2 | 353 | 351 | 99.4% | `GreaterEqual` | -| qwen2 | 261 | 259 | 99.2% | `GreaterEqual` | -| phi3 | 243 | 241 | 99.2% | `GreaterEqual` | -| starcoder2 | 298 | 296 | 99.3% | `GreaterEqual` | -| stablelm | 289 | 287 | 99.3% | `GreaterEqual` | -| cohere | 267 | 265 | 99.3% | `GreaterEqual` | -| olmo2 | 293 | 291 | 99.3% | `GreaterEqual` | -| gpt_neox | 264 | 262 | 99.2% | `GreaterEqual` | -| internlm2 | 245 | 243 | 99.2% | `GreaterEqual` | - -**Fix:** Add `GreaterEqual` → ONNX `GreaterOrEqual` mapping to `mlx-onnx`'s -IR lowering pass. This single op blocks every dense model — only 2 unsupported -nodes out of 240–350 per model. Once added, all 11 dense models should export. - -#### Category 2: Tracer crash — SIGILL (2 MoE models) - -| Model | Signal | Root Cause | -|-------|--------|------------| -| mixtral | SIGILL | `inds.tolist` in `SparseMoeBlock#call` forces materialization during tracing | -| deepseek | SIGILL | `inds.tolist` in `DeepseekMoE#call` forces materialization during tracing | - -Both MoE models use per-token expert routing with `inds.tolist` to extract -expert indices into Ruby arrays, creating **data-dependent control flow** that -the ONNX graph tracer cannot follow. The process crashes (SIGILL) before even -the compatibility report can complete. - -**Fix options:** -1. **(Preferred) Upstream:** Add `SwitchGLU` or a vectorized MoE dispatch - primitive to mlx-ruby so expert routing can be expressed as pure tensor ops - (no Ruby-level iteration over `tolist` indices). -2. **(Workaround) Model-level:** Rewrite the MoE `call` methods to avoid - `tolist` — e.g., use `mx.take` / `mx.scatter` / masked computation so the - entire routing graph stays in MLX and is traceable. - -#### Priority - -| Priority | Action | Impact | -|----------|--------|--------| -| **P0** | Add `GreaterEqual` to mlx-onnx | Unblocks all 11 dense model architectures | -| **P1** | Vectorized MoE dispatch (no tolist) | Unblocks mixtral + deepseek ONNX export | - -### Required Upstream Patch: `update_modules_impl` (Issue 12) - -**File:** `lib/mlx/nn/base.rb` - -When `nn.quantize` (or any `update_modules` call) replaces layers in a model, -it builds a nested Hash/Array tree of replacement modules. The existing code -handles `Module→Module` replacement and `Hash/Array→Hash/Array` recursion, but -misses the case where the current value is a `Module` and the replacement is a -`Hash` or `Array` (meaning "recurse into this Module's children"). Without this -fix, `nn.quantize` fails with `"Received invalid type: Hash"`. 
- -**Patch** (against commit `148f658`): - -```diff ---- a/lib/mlx/nn/base.rb -+++ b/lib/mlx/nn/base.rb -@@ -323,6 +323,8 @@ module MLX - current_value = dst[k] - if current_value.is_a?(Module) && new_value.is_a?(Module) - dst[k] = new_value -+ elsif current_value.is_a?(Module) && (new_value.is_a?(Hash) || new_value.is_a?(Array)) -+ update_modules_impl(current_value, new_value, strict) - elsif current_value.is_a?(Hash) || current_value.is_a?(Array) - update_modules_impl(current_value, new_value, strict) - elsif strict && new_value != {} -@@ -337,6 +339,8 @@ module MLX - current_value = dst[i] - if current_value.is_a?(Module) && new_value.is_a?(Module) - dst[i] = new_value -+ elsif current_value.is_a?(Module) && (new_value.is_a?(Hash) || new_value.is_a?(Array)) -+ update_modules_impl(current_value, new_value, strict) - elsif current_value.is_a?(Hash) || current_value.is_a?(Array) - update_modules_impl(current_value, new_value, strict) - elsif strict && new_value != {} -``` - -This adds two `elsif` branches (one in the Hash iteration, one in the Array -iteration) that detect when a Module needs recursive descent rather than -direct replacement. The fix has been applied locally in the `mlx-ruby` -submodule at commit `85afc8a`. - diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..7dbbb7b --- /dev/null +++ b/requirements.txt @@ -0,0 +1,3 @@ +mlx>=0.30.6; platform_system == "Darwin" +mlx[cpu]>=0.30.6; platform_system == "Linux" +-e ./mlx-lm diff --git a/skills/phased-prd-red-green/SKILL.md b/skills/phased-prd-red-green/SKILL.md new file mode 100644 index 0000000..5dba6e0 --- /dev/null +++ b/skills/phased-prd-red-green/SKILL.md @@ -0,0 +1,54 @@ +--- +name: phased-prd-red-green +description: Create or update a phased PRD and then execute delivery using a strict red/green workflow. Use when a task spans multiple steps, has cross-cutting impact, needs explicit exit criteria, or requires reliable test-driven implementation sequencing. +--- + +# Phased PRD + Red/Green Delivery + +## Overview + +Use this skill to drive a task from planning to completion with explicit phase gates. + +Read `references/prd_red_green_template.md` before writing a new PRD. + +## Workflow + +1. Confirm scope, success target, and constraints. +2. Decide whether to create/update a PRD. +3. Draft or revise the PRD with phased red/green structure. +4. Execute each phase in order: red -> green -> refactor. +5. Track completion in the checklist and update status. + +## PRD Authoring Rules + +1. Place documents in `prd/`. +2. Use dated file names: `YYYY_MM_DD__prd.md`. +3. Include at minimum: + - Status + - Context + - Goals and non-goals + - Phased plan + - Exit criteria per phase + - Acceptance criteria + - Risks and mitigations + - Implementation checklist +4. Keep phases independently testable and incremental. + +## Red/Green Execution Rules + +1. Write failing tests/checks first for each phase. +2. Record the failing signal (error, unsupported op, mismatch, or failing assertion). +3. Implement the smallest possible change set to pass. +4. Re-run targeted tests, then adjacent regression tests. +5. Refactor only after green and keep semantics unchanged. + +## Evidence And Reporting + +1. Report which commands ran and which did not run. +2. Include key pass/fail outputs for each phase gate. +3. Do not mark a phase complete until exit criteria are verified. +4. Mark PRD as `Completed` only when all checklist items are done. 
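+
+## Worked Example (Illustrative)
+
+A minimal sketch of the red step for a hypothetical phase, using this repo's Minitest
+conventions. The test path, class, and asserted method are placeholders, not real
+project files:
+
+```ruby
+# test/parity/example_phase_gate_test.rb (illustrative path, not a real file)
+require "test_helper"
+
+class ExamplePhaseGateTest < Minitest::Test
+  # Red: this assertion fails until the phase's minimal implementation lands.
+  def test_batch_generate_is_exposed
+    assert_respond_to MlxLm::Generate, :batch_generate
+  end
+end
+```
+
+Record the failing signal, implement the smallest change that makes it pass, then
+re-run the targeted test followed by adjacent regression tests.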
+
+## Reference
+
+Use `references/prd_red_green_template.md` as the default skeleton and execution checklist.
diff --git a/skills/phased-prd-red-green/agents/openai.yaml b/skills/phased-prd-red-green/agents/openai.yaml
new file mode 100644
index 0000000..470d78b
--- /dev/null
+++ b/skills/phased-prd-red-green/agents/openai.yaml
@@ -0,0 +1,4 @@
+interface:
+  display_name: "Phased PRD + Red Green"
+  short_description: "Build phased PRDs and execute red/green delivery"
+  default_prompt: "Create a phased PRD with clear exit criteria, then execute each phase via red/green testing until complete."
diff --git a/skills/phased-prd-red-green/references/prd_red_green_template.md b/skills/phased-prd-red-green/references/prd_red_green_template.md
new file mode 100644
index 0000000..a3f9c34
--- /dev/null
+++ b/skills/phased-prd-red-green/references/prd_red_green_template.md
@@ -0,0 +1,45 @@
+# PRD + Red/Green Template
+
+## Header
+
+1. Title
+2. Status (`Proposed` | `In Progress` | `Completed`)
+3. Date
+
+## Core Sections
+
+1. Context
+2. Goals
+3. Non-goals
+4. Scope
+5. Risks and mitigations
+6. Acceptance criteria
+
+## Phased Plan (Red/Green)
+
+For each phase, include:
+
+1. Objective
+2. Red:
+   - Tests/checks to add first
+   - Expected failure signal
+3. Green:
+   - Minimal implementation targets
+   - Verification commands
+4. Exit criteria
+
+## Execution Checklist
+
+- [ ] Phase 0 complete
+- [ ] Phase 1 complete
+- [ ] Phase 2 complete
+- [ ] Integration/regression checks complete
+- [ ] Documentation/status updated
+
+## Command Log (recommended)
+
+Track key commands and outcomes:
+
+1. `command`
+   - result: `pass|fail|skipped`
+   - notes: short signal summary
diff --git a/tasks/onnx_report_task.rb b/tasks/onnx_report_task.rb
new file mode 100644
index 0000000..4b5ab95
--- /dev/null
+++ b/tasks/onnx_report_task.rb
@@ -0,0 +1,332 @@
+# frozen_string_literal: true
+
+require "json"
+require "open3"
+require "pathname"
+require "time"
+
+class OnnxReportTask
+  ROOT = Pathname.new(__dir__).join("..").expand_path
+  REPORT_DIR = ROOT.join("test", "reports")
+  RAW_OUTPUT_PATH = REPORT_DIR.join("onnx_compat_test_output.txt")
+  JSON_PATH = REPORT_DIR.join("onnx_compat_full_report.json")
+  MARKDOWN_PATH = REPORT_DIR.join("onnx_compat_full_report.md")
+  INVOCATIONS_CSV_PATH = REPORT_DIR.join("onnx_compat_missing_ops_invocations.csv")
+
+  COMMAND = [
+    "bundle", "exec", "rake", "test",
+    "TEST=test/onnx/*_test.rb",
+    "TESTOPTS=--name=/test_onnx_compat_report/"
+  ].freeze
+
+  STATUS_LINE_RE = /^\s*\[ONNX\]\s+(?<model>[^:]+):\s+(?:(?<status>PASS|FAIL)\s+[—-]\s+)?(?<supported>\d+)\/(?<total>\d+)\s+nodes(?:\s+\((?<coverage>[0-9.]+)%\))?\s+[—-]\s+missing:\s*(?<missing>.*?)(?:\s+\((?<suffix>CRASH during export)\))?\s*$/.freeze
+  JSON_LINE_RE = /^\s*\[ONNX-JSON\]\s+(?<model>[^:]+):\s+(?<payload>\{.*\})\s*$/.freeze
+  SUMMARY_LINE_RE = /(?<runs>\d+)\s+runs,\s+(?<assertions>\d+)\s+assertions,\s+(?<failures>\d+)\s+failures,\s+(?<errors>\d+)\s+errors,\s+(?<skips>\d+)\s+skips/.freeze
+
+  def self.run!
+    REPORT_DIR.mkpath
+    output, status = run_compat_suite
+    RAW_OUTPUT_PATH.write(output)
+
+    models, test_summary, warnings = parse_output(output)
+    model_rows = models.values.sort_by { |entry| entry["model_type"] }
+
+    unsupported_union = model_rows
+      .flat_map { |entry| entry.fetch("missing_ops", []) }
+      .uniq
+      .sort
+
+    missing_op_model_counts = Hash.new(0)
+    model_rows.each do |entry|
+      entry.fetch("missing_ops", []).uniq.each { |op| missing_op_model_counts[op] += 1 }
+    end
+
+    summary = {
+      "models_total" => model_rows.length,
+      "models_with_missing_ops" => model_rows.count { |entry| entry.fetch("missing_ops", []).any?
}, + "pass_count" => model_rows.count { |entry| entry["status"] == "PASS" }, + "fail_count" => model_rows.count { |entry| entry["status"] == "FAIL" }, + "crash_count" => model_rows.count { |entry| entry["status"] == "CRASH" }, + "unknown_count" => model_rows.count { |entry| entry["status"] == "UNKNOWN" }, + "unsupported_ops_union_size" => unsupported_union.length + } + + report_payload = { + "generated_at" => Time.now.utc.iso8601, + "command" => COMMAND.map { |value| shell_escape(value) }.join(" "), + "exit_status" => status.exitstatus, + "test_summary" => test_summary, + "summary" => summary, + "unsupported_ops_union" => unsupported_union, + "missing_op_model_counts" => missing_op_model_counts.sort.to_h, + "models" => model_rows, + "warnings" => warnings + } + + JSON_PATH.write("#{JSON.pretty_generate(report_payload)}\n") + write_invocation_csv(INVOCATIONS_CSV_PATH, model_rows) + MARKDOWN_PATH.write(render_markdown(report_payload)) + + puts "\nWrote reports:" + puts "- #{RAW_OUTPUT_PATH}" + puts "- #{JSON_PATH}" + puts "- #{INVOCATIONS_CSV_PATH}" + puts "- #{MARKDOWN_PATH}" + puts "Models: #{summary['models_total']} | Missing-op models: #{summary['models_with_missing_ops']} | Unsupported ops: #{unsupported_union.join(', ')}" + + return report_payload if status.success? + + exit_code = status.exitstatus || 1 + raise "ONNX compat suite failed with exit code #{exit_code}" + end + + class << self + private + + def parse_missing_ops(value) + text = value.to_s.strip + return [] if text.empty? || text == "none" + + text.split(",").map(&:strip).reject(&:empty?).uniq.sort + end + + def shell_escape(value) + "'" + value.to_s.gsub("'", %q('"'"')) + "'" + end + + def markdown_escape(value) + value.to_s.gsub("|", "\\|") + end + + def extract_invocations(report) + return [] unless report.is_a?(Hash) + + nodes = report["nodes"] + return [] unless nodes.is_a?(Array) + + nodes.filter_map do |node| + next unless node.is_a?(Hash) + next unless node["supported"] == false + + { + "index" => node["index"], + "op" => node["op"], + "onnx_op_type" => node["onnx_op_type"] + } + end + end + + def run_compat_suite + env = { + "ONNX_COMPAT_REPORT_JSON" => "1", + "ONNX_LOG_LINES" => "1" + } + output = +"" + status = nil + + Open3.popen2e(env, *COMMAND, chdir: ROOT.to_s) do |stdin, stdout_and_stderr, wait_thr| + stdin.close + stdout_and_stderr.each_line do |line| + unless line.include?("[ONNX-JSON]") || line.include?("[ONNX-INV]") + print line + end + output << line + end + status = wait_thr.value + end + + [output, status] + end + + def parse_output(output) + models = {} + warnings = [] + summary = nil + + output.each_line do |line| + if (m = line.match(STATUS_LINE_RE)) + model = m[:model].strip + models[model] ||= { "model_type" => model } + status = if m[:suffix] == "CRASH during export" + "CRASH" + else + m[:status] || "UNKNOWN" + end + models[model]["status"] = status + models[model]["supported_nodes"] = m[:supported].to_i + models[model]["total_nodes"] = m[:total].to_i + models[model]["coverage_percent"] = if m[:total].to_i.positive? 
+ ((m[:supported].to_f / m[:total].to_f) * 100).round(1) + else + 0.0 + end + models[model]["missing_ops"] = parse_missing_ops(m[:missing]) + next + end + + if (m = line.match(JSON_LINE_RE)) + model = m[:model].strip + models[model] ||= { "model_type" => model } + begin + models[model]["compat_report"] = JSON.parse(m[:payload]) + rescue JSON::ParserError => e + warnings << "#{model}: failed to parse ONNX-JSON line (#{e.message})" + end + next + end + + if (m = line.match(SUMMARY_LINE_RE)) + summary = { + "runs" => m[:runs].to_i, + "assertions" => m[:assertions].to_i, + "failures" => m[:failures].to_i, + "errors" => m[:errors].to_i, + "skips" => m[:skips].to_i + } + end + end + + models.each_value do |entry| + report = entry["compat_report"] + if report.is_a?(Hash) + entry["supported_nodes"] ||= report["supported_nodes"] + entry["total_nodes"] ||= report["total_nodes"] + entry["missing_ops"] ||= Array(report["unsupported_ops"]).map(&:to_s).sort + if entry["total_nodes"].to_i.positive? + entry["coverage_percent"] ||= ((entry["supported_nodes"].to_f / entry["total_nodes"].to_f) * 100).round(1) + else + entry["coverage_percent"] ||= 0.0 + end + else + entry["missing_ops"] ||= [] + end + + entry["status"] ||= "UNKNOWN" + entry["unsupported_invocations"] = extract_invocations(report) + end + + [models, summary, warnings] + end + + def write_invocation_csv(path, model_rows) + escape = lambda do |value| + text = value.nil? ? "" : value.to_s + if text.include?(",") || text.include?("\"") || text.include?("\n") + "\"#{text.gsub("\"", "\"\"")}\"" + else + text + end + end + + lines = [] + lines << "model_type,status,missing_op,onnx_op_type,node_index" + model_rows.each do |row| + invocations = row.fetch("unsupported_invocations", []) + if invocations.empty? + row.fetch("missing_ops", []).each do |op| + values = [row["model_type"], row["status"], op, nil, nil] + lines << values.map { |v| escape.call(v) }.join(",") + end + else + invocations.each do |inv| + values = [row["model_type"], row["status"], inv["op"], inv["onnx_op_type"], inv["index"]] + lines << values.map { |v| escape.call(v) }.join(",") + end + end + end + + path.write("#{lines.join("\n")}\n") + end + + def markdown_table_row(values) + "| #{values.map { |v| markdown_escape(v) }.join(' | ')} |" + end + + def render_markdown(report_payload) + model_rows = report_payload.fetch("models") + invocation_rows = model_rows.flat_map do |model| + model.fetch("unsupported_invocations", []).map do |inv| + [model["model_type"], inv["op"], inv["onnx_op_type"], inv["index"]] + end + end + + lines = [] + lines << "# ONNX Compat Report" + lines << "" + lines << "- Generated at: `#{report_payload.fetch("generated_at")}`" + lines << "- Command: `#{report_payload.fetch("command")}`" + lines << "- Exit status: `#{report_payload.fetch("exit_status")}`" + lines << "- Models: `#{report_payload.dig("summary", "models_total")}`" + lines << "- Models with missing ops: `#{report_payload.dig("summary", "models_with_missing_ops")}`" + lines << "- Unsupported op union size: `#{report_payload.dig("summary", "unsupported_ops_union_size")}`" + lines << "" + + test_summary = report_payload["test_summary"] + if test_summary + lines << "## Test Summary" + lines << "" + lines << markdown_table_row(%w[Runs Assertions Failures Errors Skips]) + lines << markdown_table_row(%w[--- --- --- --- ---]) + lines << markdown_table_row([ + test_summary["runs"], + test_summary["assertions"], + test_summary["failures"], + test_summary["errors"], + test_summary["skips"] + ]) + lines << "" + 
end + + lines << "## Per-Model Coverage" + lines << "" + lines << markdown_table_row(["Model", "Status", "Supported/Total", "Coverage %", "Missing Ops"]) + lines << markdown_table_row(%w[--- --- --- --- ---]) + model_rows.each do |row| + lines << markdown_table_row([ + row["model_type"], + row["status"], + "#{row["supported_nodes"]}/#{row["total_nodes"]}", + row["coverage_percent"], + row.fetch("missing_ops", []).join(", ") + ]) + end + lines << "" + + lines << "## Unsupported Ops Union" + lines << "" + unsupported_union = report_payload.fetch("unsupported_ops_union") + if unsupported_union.empty? + lines << "none" + else + unsupported_union.each do |op| + count = report_payload.dig("missing_op_model_counts", op).to_i + lines << "- `#{op}`: #{count} model(s)" + end + end + lines << "" + + lines << "## Unsupported Node Invocations" + lines << "" + if invocation_rows.empty? + lines << "none" + else + lines << markdown_table_row(["Model", "Op", "ONNX op type", "Node index"]) + lines << markdown_table_row(%w[--- --- --- ---]) + invocation_rows.each do |row| + lines << markdown_table_row(row) + end + end + lines << "" + + warnings = report_payload.fetch("warnings", []) + unless warnings.empty? + lines << "## Warnings" + lines << "" + warnings.each { |warning| lines << "- #{warning}" } + lines << "" + end + + lines.join("\n") + end + end +end diff --git a/tasks/parity_inventory_task.rb b/tasks/parity_inventory_task.rb new file mode 100644 index 0000000..d6426b6 --- /dev/null +++ b/tasks/parity_inventory_task.rb @@ -0,0 +1,150 @@ +# frozen_string_literal: true + +require "json" +require "pathname" + +class ParityInventoryTask + ROOT = Pathname.new(__dir__).join("..").expand_path + DEFAULT_OUTPUT = ROOT.join("test", "reports", "python_ruby_parity_inventory_snapshot.json") + + PY_MODELS_DIR = ROOT.join("mlx-lm", "mlx_lm", "models") + RB_MODELS_DIR = ROOT.join("lib", "mlx_lm", "models") + RB_REGISTRY_FILE = ROOT.join("lib", "mlx_lm", "models.rb") + + PY_INFRA_FILES = %w[ + activations.py + base.py + bitlinear_layers.py + cache.py + gated_delta.py + mla.py + pipeline.py + rope_utils.py + ssm.py + switch_layers.py + ].freeze + + RB_INFRA_FILES = %w[ + cache.rb + switch_layers.rb + ].freeze + + def self.run!(check: false, output: DEFAULT_OUTPUT) + output_path = resolve_output(output) + snapshot = build_snapshot + + if check + check_snapshot(output_path, snapshot) + else + write_snapshot(output_path, snapshot) + puts "wrote parity inventory snapshot: #{output_path}" + true + end + end + + class << self + private + + def resolve_output(path) + output = path.is_a?(Pathname) ? path : Pathname.new(path.to_s) + output.relative? ? ROOT.join(output) : output + end + + def model_files(dir, ext) + Dir.glob(dir.join("*.#{ext}").to_s) + .map { |path| File.basename(path) } + .reject { |name| name == "__init__.py" } + .sort + end + + def parse_registered_model_keys + keys = [] + Dir.glob(RB_MODELS_DIR.join("*.rb").to_s).sort.each do |path| + File.read(path).scan(/Models\.register\("([^"]+)"/) do |match| + keys << match.first + end + end + keys.uniq.sort + end + + def parse_remappings + content = File.read(RB_REGISTRY_FILE) + remap_block = content[/REMAPPING\s*=\s*\{(.*?)\}\s*\.freeze/m, 1] + return {} if remap_block.nil? 
+ + remappings = {} + remap_block.scan(/"([^"]+)"\s*=>\s*"([^"]+)"/) do |from, to| + remappings[from] = to + end + remappings + end + + def build_snapshot + py_all = model_files(PY_MODELS_DIR, "py") + py_arch = py_all - PY_INFRA_FILES + py_arch_keys = py_arch.map { |name| name.sub(/\.py\z/, "") } + + rb_all = model_files(RB_MODELS_DIR, "rb") + rb_arch = rb_all - RB_INFRA_FILES + rb_arch_keys = rb_arch.map { |name| name.sub(/\.rb\z/, "") } + rb_registered = parse_registered_model_keys + remappings = parse_remappings + missing_architecture_files = (py_arch_keys - rb_registered).sort + extra_registered_model_keys = (rb_registered - py_arch_keys).sort + + { + "inventory_version" => 1, + "source_paths" => { + "python_models_dir" => "mlx-lm/mlx_lm/models", + "ruby_models_dir" => "lib/mlx_lm/models", + "ruby_registry_file" => "lib/mlx_lm/models.rb" + }, + "python" => { + "model_files_total" => py_all.length, + "shared_infra_files" => PY_INFRA_FILES, + "architecture_files_total" => py_arch.length, + "architecture_files" => py_arch + }, + "ruby" => { + "model_files_total" => rb_all.length, + "shared_infra_files" => RB_INFRA_FILES, + "architecture_files_total" => rb_arch.length, + "architecture_files" => rb_arch, + "architecture_model_keys_total" => rb_arch_keys.length, + "architecture_model_keys" => rb_arch_keys, + "registered_model_keys_total" => rb_registered.length, + "registered_model_keys" => rb_registered, + "remappings_total" => remappings.length, + "remappings" => remappings + }, + "parity" => { + "missing_architecture_file_count" => missing_architecture_files.length, + "missing_architecture_files" => missing_architecture_files, + "extra_registered_model_keys_count" => extra_registered_model_keys.length, + "extra_registered_model_keys" => extra_registered_model_keys + } + } + end + + def write_snapshot(output, snapshot) + output.dirname.mkpath + output.write("#{JSON.pretty_generate(snapshot)}\n") + end + + def check_snapshot(output, snapshot) + unless output.exist? 
+ warn "snapshot file missing: #{output}" + return false + end + + current = JSON.parse(output.read) + if current == snapshot + puts "parity inventory snapshot is up-to-date" + true + else + warn "parity inventory snapshot is stale: #{output}" + false + end + end + end +end diff --git a/test/onnx/afm7_test.rb b/test/onnx/afm7_test.rb new file mode 100644 index 0000000..9c16980 --- /dev/null +++ b/test/onnx/afm7_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportAfm7Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("afm7") + end + + def test_onnx_compat_report + assert_onnx_compat_report("afm7") + end +end diff --git a/test/onnx/afmoe_test.rb b/test/onnx/afmoe_test.rb new file mode 100644 index 0000000..0d48a90 --- /dev/null +++ b/test/onnx/afmoe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportAfmoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("afmoe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("afmoe") + end +end diff --git a/test/onnx/apertus_test.rb b/test/onnx/apertus_test.rb new file mode 100644 index 0000000..83862cb --- /dev/null +++ b/test/onnx/apertus_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportApertusTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("apertus") + end + + def test_onnx_compat_report + assert_onnx_compat_report("apertus") + end +end diff --git a/test/onnx/baichuan_m1_test.rb b/test/onnx/baichuan_m1_test.rb new file mode 100644 index 0000000..af013a0 --- /dev/null +++ b/test/onnx/baichuan_m1_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportBaichuanM1Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("baichuan_m1") + end + + def test_onnx_compat_report + assert_onnx_compat_report("baichuan_m1") + end +end diff --git a/test/onnx/bailing_moe_linear_test.rb b/test/onnx/bailing_moe_linear_test.rb new file mode 100644 index 0000000..7102eb5 --- /dev/null +++ b/test/onnx/bailing_moe_linear_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportBailingMoeLinearTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("bailing_moe_linear") + end + + def test_onnx_compat_report + assert_onnx_compat_report("bailing_moe_linear") + end +end diff --git a/test/onnx/bailing_moe_test.rb b/test/onnx/bailing_moe_test.rb new file mode 100644 index 0000000..3a4c0db --- /dev/null +++ b/test/onnx/bailing_moe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportBailingMoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("bailing_moe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("bailing_moe") + end +end diff --git a/test/onnx/bitnet_test.rb b/test/onnx/bitnet_test.rb new file mode 100644 index 0000000..add277d --- /dev/null +++ b/test/onnx/bitnet_test.rb @@ -0,0 +1,16 @@ +# 
frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportBitnetTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("bitnet") + end + + def test_onnx_compat_report + assert_onnx_compat_report("bitnet") + end +end diff --git a/test/onnx/cohere2_test.rb b/test/onnx/cohere2_test.rb new file mode 100644 index 0000000..bfc175b --- /dev/null +++ b/test/onnx/cohere2_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportCohere2Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("cohere2") + end + + def test_onnx_compat_report + assert_onnx_compat_report("cohere2") + end +end diff --git a/test/onnx/cohere_test.rb b/test/onnx/cohere_test.rb new file mode 100644 index 0000000..b6e3dca --- /dev/null +++ b/test/onnx/cohere_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportCohereTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("cohere") + end + + def test_onnx_compat_report + assert_onnx_compat_report("cohere") + end +end diff --git a/test/onnx/dbrx_test.rb b/test/onnx/dbrx_test.rb new file mode 100644 index 0000000..8c42b2b --- /dev/null +++ b/test/onnx/dbrx_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportDbrxTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("dbrx") + end + + def test_onnx_compat_report + assert_onnx_compat_report("dbrx") + end +end diff --git a/test/onnx/deepseek_test.rb b/test/onnx/deepseek_test.rb new file mode 100644 index 0000000..a58fe8d --- /dev/null +++ b/test/onnx/deepseek_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportDeepseekTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("deepseek") + end + + def test_onnx_compat_report + assert_onnx_compat_report("deepseek") + end +end diff --git a/test/onnx/deepseek_v2_test.rb b/test/onnx/deepseek_v2_test.rb new file mode 100644 index 0000000..9e82189 --- /dev/null +++ b/test/onnx/deepseek_v2_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportDeepseekV2Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("deepseek_v2") + end + + def test_onnx_compat_report + assert_onnx_compat_report("deepseek_v2") + end +end diff --git a/test/onnx/deepseek_v32_test.rb b/test/onnx/deepseek_v32_test.rb new file mode 100644 index 0000000..3453ae8 --- /dev/null +++ b/test/onnx/deepseek_v32_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportDeepseekV32Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("deepseek_v32") + end + + def test_onnx_compat_report + assert_onnx_compat_report("deepseek_v32") + end +end diff --git a/test/onnx/deepseek_v3_test.rb b/test/onnx/deepseek_v3_test.rb new file mode 100644 index 0000000..88e6eca --- /dev/null +++ b/test/onnx/deepseek_v3_test.rb @@ -0,0 
+1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportDeepseekV3Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("deepseek_v3") + end + + def test_onnx_compat_report + assert_onnx_compat_report("deepseek_v3") + end +end diff --git a/test/onnx/dots1_test.rb b/test/onnx/dots1_test.rb new file mode 100644 index 0000000..3f30a32 --- /dev/null +++ b/test/onnx/dots1_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportDots1Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("dots1") + end + + def test_onnx_compat_report + assert_onnx_compat_report("dots1") + end +end diff --git a/test/onnx/ernie4_5_moe_test.rb b/test/onnx/ernie4_5_moe_test.rb new file mode 100644 index 0000000..59049dc --- /dev/null +++ b/test/onnx/ernie4_5_moe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportErnie45MoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("ernie4_5_moe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("ernie4_5_moe") + end +end diff --git a/test/onnx/ernie4_5_test.rb b/test/onnx/ernie4_5_test.rb new file mode 100644 index 0000000..760613f --- /dev/null +++ b/test/onnx/ernie4_5_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportErnie45Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("ernie4_5") + end + + def test_onnx_compat_report + assert_onnx_compat_report("ernie4_5") + end +end diff --git a/test/onnx/exaone4_test.rb b/test/onnx/exaone4_test.rb new file mode 100644 index 0000000..060fc39 --- /dev/null +++ b/test/onnx/exaone4_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportExaone4Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("exaone4") + end + + def test_onnx_compat_report + assert_onnx_compat_report("exaone4") + end +end diff --git a/test/onnx/exaone_moe_test.rb b/test/onnx/exaone_moe_test.rb new file mode 100644 index 0000000..1e01f55 --- /dev/null +++ b/test/onnx/exaone_moe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportExaoneMoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("exaone_moe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("exaone_moe") + end +end diff --git a/test/onnx/exaone_test.rb b/test/onnx/exaone_test.rb new file mode 100644 index 0000000..4929c1f --- /dev/null +++ b/test/onnx/exaone_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportExaoneTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("exaone") + end + + def test_onnx_compat_report + assert_onnx_compat_report("exaone") + end +end diff --git a/test/onnx/falcon_h1_test.rb b/test/onnx/falcon_h1_test.rb new file mode 100644 index 0000000..f0abf4a --- /dev/null +++ 
b/test/onnx/falcon_h1_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportFalconH1Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("falcon_h1") + end + + def test_onnx_compat_report + assert_onnx_compat_report("falcon_h1") + end +end diff --git a/test/onnx/gemma2_test.rb b/test/onnx/gemma2_test.rb new file mode 100644 index 0000000..3308a46 --- /dev/null +++ b/test/onnx/gemma2_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGemma2Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("gemma2") + end + + def test_onnx_compat_report + assert_onnx_compat_report("gemma2") + end +end diff --git a/test/onnx/gemma3_test.rb b/test/onnx/gemma3_test.rb new file mode 100644 index 0000000..c123d12 --- /dev/null +++ b/test/onnx/gemma3_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGemma3Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("gemma3") + end + + def test_onnx_compat_report + assert_onnx_compat_report("gemma3") + end +end diff --git a/test/onnx/gemma3_text_test.rb b/test/onnx/gemma3_text_test.rb new file mode 100644 index 0000000..674f304 --- /dev/null +++ b/test/onnx/gemma3_text_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGemma3TextTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("gemma3_text") + end + + def test_onnx_compat_report + assert_onnx_compat_report("gemma3_text") + end +end diff --git a/test/onnx/gemma3n_test.rb b/test/onnx/gemma3n_test.rb new file mode 100644 index 0000000..d89970a --- /dev/null +++ b/test/onnx/gemma3n_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGemma3nTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("gemma3n") + end + + def test_onnx_compat_report + assert_onnx_compat_report("gemma3n") + end +end diff --git a/test/onnx/gemma_test.rb b/test/onnx/gemma_test.rb new file mode 100644 index 0000000..334ff04 --- /dev/null +++ b/test/onnx/gemma_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGemmaTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("gemma") + end + + def test_onnx_compat_report + assert_onnx_compat_report("gemma") + end +end diff --git a/test/onnx/glm4_moe_lite_test.rb b/test/onnx/glm4_moe_lite_test.rb new file mode 100644 index 0000000..0ee4293 --- /dev/null +++ b/test/onnx/glm4_moe_lite_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGlm4MoeLiteTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("glm4_moe_lite") + end + + def test_onnx_compat_report + assert_onnx_compat_report("glm4_moe_lite") + end +end diff --git a/test/onnx/glm4_moe_test.rb b/test/onnx/glm4_moe_test.rb new file mode 100644 index 0000000..a133bb2 --- 
/dev/null +++ b/test/onnx/glm4_moe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGlm4MoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("glm4_moe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("glm4_moe") + end +end diff --git a/test/onnx/glm4_test.rb b/test/onnx/glm4_test.rb new file mode 100644 index 0000000..083db9a --- /dev/null +++ b/test/onnx/glm4_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGlm4Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("glm4") + end + + def test_onnx_compat_report + assert_onnx_compat_report("glm4") + end +end diff --git a/test/onnx/glm_moe_dsa_test.rb b/test/onnx/glm_moe_dsa_test.rb new file mode 100644 index 0000000..a88e0e8 --- /dev/null +++ b/test/onnx/glm_moe_dsa_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGlmMoeDsaTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("glm_moe_dsa") + end + + def test_onnx_compat_report + assert_onnx_compat_report("glm_moe_dsa") + end +end diff --git a/test/onnx/glm_test.rb b/test/onnx/glm_test.rb new file mode 100644 index 0000000..ff38610 --- /dev/null +++ b/test/onnx/glm_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGlmTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("glm") + end + + def test_onnx_compat_report + assert_onnx_compat_report("glm") + end +end diff --git a/test/onnx/gpt2_test.rb b/test/onnx/gpt2_test.rb new file mode 100644 index 0000000..8f10d52 --- /dev/null +++ b/test/onnx/gpt2_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGpt2Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("gpt2") + end + + def test_onnx_compat_report + assert_onnx_compat_report("gpt2") + end +end diff --git a/test/onnx/gpt_bigcode_test.rb b/test/onnx/gpt_bigcode_test.rb new file mode 100644 index 0000000..356ebf9 --- /dev/null +++ b/test/onnx/gpt_bigcode_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGptBigcodeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("gpt_bigcode") + end + + def test_onnx_compat_report + assert_onnx_compat_report("gpt_bigcode") + end +end diff --git a/test/onnx/gpt_neox_test.rb b/test/onnx/gpt_neox_test.rb new file mode 100644 index 0000000..fd9c698 --- /dev/null +++ b/test/onnx/gpt_neox_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGptNeoxTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("gpt_neox") + end + + def test_onnx_compat_report + assert_onnx_compat_report("gpt_neox") + end +end diff --git a/test/onnx/gpt_oss_test.rb b/test/onnx/gpt_oss_test.rb new file mode 100644 index 0000000..f227cb4 --- /dev/null +++ 
b/test/onnx/gpt_oss_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGptOssTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("gpt_oss") + end + + def test_onnx_compat_report + assert_onnx_compat_report("gpt_oss") + end +end diff --git a/test/onnx/granite_test.rb b/test/onnx/granite_test.rb new file mode 100644 index 0000000..c2d9a4b --- /dev/null +++ b/test/onnx/granite_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGraniteTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("granite") + end + + def test_onnx_compat_report + assert_onnx_compat_report("granite") + end +end diff --git a/test/onnx/granitemoe_test.rb b/test/onnx/granitemoe_test.rb new file mode 100644 index 0000000..8a588a2 --- /dev/null +++ b/test/onnx/granitemoe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGranitemoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("granitemoe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("granitemoe") + end +end diff --git a/test/onnx/granitemoehybrid_test.rb b/test/onnx/granitemoehybrid_test.rb new file mode 100644 index 0000000..6ffb595 --- /dev/null +++ b/test/onnx/granitemoehybrid_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportGranitemoehybridTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("granitemoehybrid") + end + + def test_onnx_compat_report + assert_onnx_compat_report("granitemoehybrid") + end +end diff --git a/test/onnx/helium_test.rb b/test/onnx/helium_test.rb new file mode 100644 index 0000000..fd94317 --- /dev/null +++ b/test/onnx/helium_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportHeliumTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("helium") + end + + def test_onnx_compat_report + assert_onnx_compat_report("helium") + end +end diff --git a/test/onnx/hunyuan_test.rb b/test/onnx/hunyuan_test.rb new file mode 100644 index 0000000..0ae5edf --- /dev/null +++ b/test/onnx/hunyuan_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportHunyuanTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("hunyuan") + end + + def test_onnx_compat_report + assert_onnx_compat_report("hunyuan") + end +end diff --git a/test/onnx/hunyuan_v1_dense_test.rb b/test/onnx/hunyuan_v1_dense_test.rb new file mode 100644 index 0000000..0ff73c3 --- /dev/null +++ b/test/onnx/hunyuan_v1_dense_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportHunyuanV1DenseTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("hunyuan_v1_dense") + end + + def test_onnx_compat_report + assert_onnx_compat_report("hunyuan_v1_dense") + end +end diff --git a/test/onnx/internlm2_test.rb 
b/test/onnx/internlm2_test.rb new file mode 100644 index 0000000..f88ad3c --- /dev/null +++ b/test/onnx/internlm2_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportInternlm2Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("internlm2") + end + + def test_onnx_compat_report + assert_onnx_compat_report("internlm2") + end +end diff --git a/test/onnx/internlm3_test.rb b/test/onnx/internlm3_test.rb new file mode 100644 index 0000000..45cb7cc --- /dev/null +++ b/test/onnx/internlm3_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportInternlm3Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("internlm3") + end + + def test_onnx_compat_report + assert_onnx_compat_report("internlm3") + end +end diff --git a/test/onnx/iquestloopcoder_test.rb b/test/onnx/iquestloopcoder_test.rb new file mode 100644 index 0000000..5326cbf --- /dev/null +++ b/test/onnx/iquestloopcoder_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportIquestloopcoderTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("iquestloopcoder") + end + + def test_onnx_compat_report + assert_onnx_compat_report("iquestloopcoder") + end +end diff --git a/test/onnx/jamba_test.rb b/test/onnx/jamba_test.rb new file mode 100644 index 0000000..13ea25d --- /dev/null +++ b/test/onnx/jamba_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportJambaTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("jamba") + end + + def test_onnx_compat_report + assert_onnx_compat_report("jamba") + end +end diff --git a/test/onnx/kimi_k25_test.rb b/test/onnx/kimi_k25_test.rb new file mode 100644 index 0000000..225b204 --- /dev/null +++ b/test/onnx/kimi_k25_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportKimiK25Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("kimi_k25") + end + + def test_onnx_compat_report + assert_onnx_compat_report("kimi_k25") + end +end diff --git a/test/onnx/kimi_linear_test.rb b/test/onnx/kimi_linear_test.rb new file mode 100644 index 0000000..506a0a6 --- /dev/null +++ b/test/onnx/kimi_linear_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportKimiLinearTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("kimi_linear") + end + + def test_onnx_compat_report + assert_onnx_compat_report("kimi_linear") + end +end diff --git a/test/onnx/kimi_vl_test.rb b/test/onnx/kimi_vl_test.rb new file mode 100644 index 0000000..b2b6c96 --- /dev/null +++ b/test/onnx/kimi_vl_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportKimiVlTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("kimi_vl") + end + + def test_onnx_compat_report + 
assert_onnx_compat_report("kimi_vl") + end +end diff --git a/test/onnx/klear_test.rb b/test/onnx/klear_test.rb new file mode 100644 index 0000000..112b568 --- /dev/null +++ b/test/onnx/klear_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportKlearTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("Klear") + end + + def test_onnx_compat_report + assert_onnx_compat_report("Klear") + end +end diff --git a/test/onnx/lfm2_moe_test.rb b/test/onnx/lfm2_moe_test.rb new file mode 100644 index 0000000..96eab68 --- /dev/null +++ b/test/onnx/lfm2_moe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportLfm2MoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("lfm2_moe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("lfm2_moe") + end +end diff --git a/test/onnx/lfm2_test.rb b/test/onnx/lfm2_test.rb new file mode 100644 index 0000000..a58129d --- /dev/null +++ b/test/onnx/lfm2_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportLfm2Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("lfm2") + end + + def test_onnx_compat_report + assert_onnx_compat_report("lfm2") + end +end diff --git a/test/onnx/lfm2_vl_test.rb b/test/onnx/lfm2_vl_test.rb new file mode 100644 index 0000000..1098f06 --- /dev/null +++ b/test/onnx/lfm2_vl_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportLfm2VlTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("lfm2-vl") + end + + def test_onnx_compat_report + assert_onnx_compat_report("lfm2-vl") + end +end diff --git a/test/onnx/lille_130m_test.rb b/test/onnx/lille_130m_test.rb new file mode 100644 index 0000000..1efbcd9 --- /dev/null +++ b/test/onnx/lille_130m_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportLille130mTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("lille-130m") + end + + def test_onnx_compat_report + assert_onnx_compat_report("lille-130m") + end +end diff --git a/test/onnx/llama4_test.rb b/test/onnx/llama4_test.rb new file mode 100644 index 0000000..f789fdf --- /dev/null +++ b/test/onnx/llama4_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportLlama4Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("llama4") + end + + def test_onnx_compat_report + assert_onnx_compat_report("llama4") + end +end diff --git a/test/onnx/llama4_text_test.rb b/test/onnx/llama4_text_test.rb new file mode 100644 index 0000000..0ea5bdf --- /dev/null +++ b/test/onnx/llama4_text_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportLlama4TextTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("llama4_text") + end + + def test_onnx_compat_report + 
assert_onnx_compat_report("llama4_text") + end +end diff --git a/test/onnx/llama_test.rb b/test/onnx/llama_test.rb new file mode 100644 index 0000000..2da77c5 --- /dev/null +++ b/test/onnx/llama_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportLlamaTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("llama") + end + + def test_onnx_compat_report + assert_onnx_compat_report("llama") + end +end diff --git a/test/onnx/longcat_flash_ngram_test.rb b/test/onnx/longcat_flash_ngram_test.rb new file mode 100644 index 0000000..f3af9c2 --- /dev/null +++ b/test/onnx/longcat_flash_ngram_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportLongcatFlashNgramTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("longcat_flash_ngram") + end + + def test_onnx_compat_report + assert_onnx_compat_report("longcat_flash_ngram") + end +end diff --git a/test/onnx/longcat_flash_test.rb b/test/onnx/longcat_flash_test.rb new file mode 100644 index 0000000..895138f --- /dev/null +++ b/test/onnx/longcat_flash_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportLongcatFlashTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("longcat_flash") + end + + def test_onnx_compat_report + assert_onnx_compat_report("longcat_flash") + end +end diff --git a/test/onnx/mamba2_test.rb b/test/onnx/mamba2_test.rb new file mode 100644 index 0000000..0391f93 --- /dev/null +++ b/test/onnx/mamba2_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportMamba2Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("mamba2") + end + + def test_onnx_compat_report + assert_onnx_compat_report("mamba2") + end +end diff --git a/test/onnx/mamba_test.rb b/test/onnx/mamba_test.rb new file mode 100644 index 0000000..18f2519 --- /dev/null +++ b/test/onnx/mamba_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportMambaTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("mamba") + end + + def test_onnx_compat_report + assert_onnx_compat_report("mamba") + end +end diff --git a/test/onnx/mimo_test.rb b/test/onnx/mimo_test.rb new file mode 100644 index 0000000..9c21dd5 --- /dev/null +++ b/test/onnx/mimo_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportMimoTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("mimo") + end + + def test_onnx_compat_report + assert_onnx_compat_report("mimo") + end +end diff --git a/test/onnx/mimo_v2_flash_test.rb b/test/onnx/mimo_v2_flash_test.rb new file mode 100644 index 0000000..3fabcba --- /dev/null +++ b/test/onnx/mimo_v2_flash_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportMimoV2FlashTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + 
assert_onnx_export("mimo_v2_flash") + end + + def test_onnx_compat_report + assert_onnx_compat_report("mimo_v2_flash") + end +end diff --git a/test/onnx/minicpm3_test.rb b/test/onnx/minicpm3_test.rb new file mode 100644 index 0000000..4b034c4 --- /dev/null +++ b/test/onnx/minicpm3_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportMinicpm3Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("minicpm3") + end + + def test_onnx_compat_report + assert_onnx_compat_report("minicpm3") + end +end diff --git a/test/onnx/minicpm_test.rb b/test/onnx/minicpm_test.rb new file mode 100644 index 0000000..24ca1ce --- /dev/null +++ b/test/onnx/minicpm_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportMinicpmTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("minicpm") + end + + def test_onnx_compat_report + assert_onnx_compat_report("minicpm") + end +end diff --git a/test/onnx/minimax_test.rb b/test/onnx/minimax_test.rb new file mode 100644 index 0000000..dbe3d43 --- /dev/null +++ b/test/onnx/minimax_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportMinimaxTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("minimax") + end + + def test_onnx_compat_report + assert_onnx_compat_report("minimax") + end +end diff --git a/test/onnx/ministral3_test.rb b/test/onnx/ministral3_test.rb new file mode 100644 index 0000000..a81942f --- /dev/null +++ b/test/onnx/ministral3_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportMinistral3Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("ministral3") + end + + def test_onnx_compat_report + assert_onnx_compat_report("ministral3") + end +end diff --git a/test/onnx/mistral3_test.rb b/test/onnx/mistral3_test.rb new file mode 100644 index 0000000..88a457a --- /dev/null +++ b/test/onnx/mistral3_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportMistral3Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("mistral3") + end + + def test_onnx_compat_report + assert_onnx_compat_report("mistral3") + end +end diff --git a/test/onnx/mixtral_test.rb b/test/onnx/mixtral_test.rb new file mode 100644 index 0000000..36c3943 --- /dev/null +++ b/test/onnx/mixtral_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportMixtralTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("mixtral") + end + + def test_onnx_compat_report + assert_onnx_compat_report("mixtral") + end +end diff --git a/test/onnx/nanochat_test.rb b/test/onnx/nanochat_test.rb new file mode 100644 index 0000000..44d1d20 --- /dev/null +++ b/test/onnx/nanochat_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportNanochatTest < Minitest::Test + include OnnxExportTestHelper + + def 
test_onnx_export + assert_onnx_export("nanochat") + end + + def test_onnx_compat_report + assert_onnx_compat_report("nanochat") + end +end diff --git a/test/onnx/nemotron_h_test.rb b/test/onnx/nemotron_h_test.rb new file mode 100644 index 0000000..38bbc2c --- /dev/null +++ b/test/onnx/nemotron_h_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportNemotronHTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("nemotron_h") + end + + def test_onnx_compat_report + assert_onnx_compat_report("nemotron_h") + end +end diff --git a/test/onnx/nemotron_nas_test.rb b/test/onnx/nemotron_nas_test.rb new file mode 100644 index 0000000..0900c40 --- /dev/null +++ b/test/onnx/nemotron_nas_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportNemotronNasTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("nemotron-nas") + end + + def test_onnx_compat_report + assert_onnx_compat_report("nemotron-nas") + end +end diff --git a/test/onnx/nemotron_test.rb b/test/onnx/nemotron_test.rb new file mode 100644 index 0000000..28b6ee5 --- /dev/null +++ b/test/onnx/nemotron_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportNemotronTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("nemotron") + end + + def test_onnx_compat_report + assert_onnx_compat_report("nemotron") + end +end diff --git a/test/onnx/olmo2_test.rb b/test/onnx/olmo2_test.rb new file mode 100644 index 0000000..3b72006 --- /dev/null +++ b/test/onnx/olmo2_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportOlmo2Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("olmo2") + end + + def test_onnx_compat_report + assert_onnx_compat_report("olmo2") + end +end diff --git a/test/onnx/olmo3_test.rb b/test/onnx/olmo3_test.rb new file mode 100644 index 0000000..3d952d9 --- /dev/null +++ b/test/onnx/olmo3_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportOlmo3Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("olmo3") + end + + def test_onnx_compat_report + assert_onnx_compat_report("olmo3") + end +end diff --git a/test/onnx/olmo_test.rb b/test/onnx/olmo_test.rb new file mode 100644 index 0000000..6c60fa9 --- /dev/null +++ b/test/onnx/olmo_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportOlmoTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("olmo") + end + + def test_onnx_compat_report + assert_onnx_compat_report("olmo") + end +end diff --git a/test/onnx/olmoe_test.rb b/test/onnx/olmoe_test.rb new file mode 100644 index 0000000..d7efd02 --- /dev/null +++ b/test/onnx/olmoe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportOlmoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + 
assert_onnx_export("olmoe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("olmoe") + end +end diff --git a/test/onnx/onnx_export_test.rb b/test/onnx/onnx_export_test.rb new file mode 100644 index 0000000..ebfea81 --- /dev/null +++ b/test/onnx/onnx_export_test.rb @@ -0,0 +1,954 @@ +# frozen_string_literal: true + +require "json" +require "open3" +require "tmpdir" + +module OnnxExportTestHelper + EXPLICIT_TINY_CONFIGS = { + "llama" => { + "model_type" => "llama", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + "tie_word_embeddings" => true, + }, + "gemma" => { + "model_type" => "gemma", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + "head_dim" => 32, + "tie_word_embeddings" => true, + }, + "gemma2" => { + "model_type" => "gemma2", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + "head_dim" => 32, + "query_pre_attn_scalar" => 32.0, + }, + "qwen2" => { + "model_type" => "qwen2", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + "tie_word_embeddings" => true, + }, + "phi3" => { + "model_type" => "phi3", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + "tie_word_embeddings" => true, + }, + "starcoder2" => { + "model_type" => "starcoder2", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + "tie_word_embeddings" => true, + }, + "stablelm" => { + "model_type" => "stablelm", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + }, + "cohere" => { + "model_type" => "cohere", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + }, + "olmo2" => { + "model_type" => "olmo2", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + "tie_word_embeddings" => true, + }, + "gpt_neox" => { + "model_type" => "gpt_neox", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "vocab_size" => 128, + "intermediate_size" => 256, + }, + "mixtral" => { + "model_type" => "mixtral", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + "tie_word_embeddings" => true, + }, + "deepseek" => { + "model_type" => "deepseek", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "moe_intermediate_size" => 64, + "vocab_size" => 128, + "n_routed_experts" => 2, + "num_experts_per_tok" => 1, + "n_shared_experts" => 1, + "moe_layer_freq" => 1, + "first_k_dense_replace" => 1, + }, + 
"internlm2" => { + "model_type" => "internlm2", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + "bias" => false, + "tie_word_embeddings" => true, + }, + }.freeze + + DEFAULT_TINY_CONFIG = { + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "intermediate_size_mlp" => 128, + "vocab_size" => 128, + "rms_norm_eps" => 1e-5, + "norm_eps" => 1e-5, + "layer_norm_eps" => 1e-5, + "layer_norm_epsilon" => 1e-5, + "norm_epsilon" => 1e-5, + "head_dim" => 32, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => true, + "attention_bias" => false, + "mlp_bias" => false, + "use_bias" => false, + "use_conv_bias" => true, + "num_layers" => 2, + "n_layer" => 2, + "n_head" => 2, + "n_embd" => 64, + "n_inner" => 128, + "n_positions" => 256, + "n_ctx" => 256, + "num_local_experts" => 2, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "num_shared_experts" => 1, + "n_shared_experts" => 1, + "n_routed_experts" => 2, + "moe_intermediate_size" => 64, + "moe_layer_freq" => 1, + "first_k_dense_replace" => 1, + "n_group" => 1, + "topk_group" => 1, + "norm_topk_prob" => true, + "routed_scaling_factor" => 1.0, + "partial_rotary_factor" => 0.5, + "scoring_func" => "sigmoid", + "topk_method" => "noaux_tc", + "sliding_window" => 64, + "num_heads" => 2, + "state_size" => 8, + "conv_kernel" => 3, + "n_groups" => 1, + "shared_intermediate_size" => 64, + "rotary_dim" => 16, + "hidden_act" => "silu", + }.freeze + + MODEL_TINY_CONFIG_OVERRIDES = { + "dots1" => { + "first_k_dense_replace" => 1, + "moe_intermediate_size" => 48, + "n_routed_experts" => 2, + "n_shared_experts" => 1, + "num_experts_per_tok" => 1, + "norm_topk_prob" => true, + "routed_scaling_factor" => 1.0, + "head_dim" => 8, + "scoring_func" => "noaux_tc", + "tie_word_embeddings" => false, + }, + "ernie4_5" => { + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "use_bias" => false, + "tie_word_embeddings" => true, + }, + "ernie4_5_moe" => { + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "use_bias" => false, + "tie_word_embeddings" => false, + "moe_num_experts" => 2, + "moe_k" => 1, + "moe_layer_interval" => 1, + "moe_layer_start_index" => 0, + "moe_num_shared_experts" => 1, + "moe_gate_act" => "softmax", + }, + "exaone" => { + "num_layers" => 2, + "layer_norm_epsilon" => 1e-5, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => true, + "attention_bias" => false, + "mlp_bias" => false, + }, + "exaone4" => { + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "head_dim" => 16, + "tie_word_embeddings" => false, + "sliding_window" => 4, + "sliding_window_pattern" => "LLGL", + }, + "exaone_moe" => { + "moe_intermediate_size" => 16, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "num_shared_experts" => 1, + "max_position_embeddings" => 256, + "sliding_window" => 2, + "layer_types" => ["full_attention", "sliding_attention"], + "is_moe_layer" => [true, false], + "norm_topk_prob" => true, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => true, + }, + "glm4_moe" => { + "max_position_embeddings" => 256, + "moe_intermediate_size" => 16, + "norm_topk_prob" => true, + "n_group" => 1, + "topk_group" => 1, + "n_shared_experts" => 1, + "n_routed_experts" => 2, + "routed_scaling_factor" => 1.0, + "num_experts_per_tok" => 1, + "first_k_dense_replace" => 0, + "use_qk_norm" => 
true, + "tie_word_embeddings" => false, + "attention_bias" => false, + "partial_rotary_factor" => 0.5, + }, + "iquestloopcoder" => { + "hidden_size" => 32, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "intermediate_size" => 64, + "head_dim" => 8, + "loop_num" => 2, + "loop_window_size" => 4, + "tie_word_embeddings" => false, + }, + "llama4" => { + "text_config" => { + "model_type" => "llama4_text", + "hidden_size" => 32, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "num_hidden_layers" => 2, + "vocab_size" => 97, + "intermediate_size" => 64, + "intermediate_size_mlp" => 64, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + "interleave_moe_layer_step" => 2, + "attention_chunk_size" => 4, + "max_position_embeddings" => 128, + "rope_theta" => 10_000.0, + "head_dim" => 8, + "rms_norm_eps" => 1e-5, + "attention_bias" => false, + "use_qk_norm" => true, + }, + }, + "llama4_text" => { + "intermediate_size" => 64, + "intermediate_size_mlp" => 64, + "no_rope_layers" => [0, 1], + "use_qk_norm" => true, + }, + "mamba2" => { + "num_heads" => 4, + "head_dim" => 4, + "hidden_size" => 32, + "intermediate_size" => 16, + "state_size" => 8, + "num_hidden_layers" => 2, + "layer_norm_epsilon" => 1e-5, + "conv_kernel" => 3, + "n_groups" => 2, + "use_bias" => true, + "use_conv_bias" => true, + "time_step_limit" => [0.001, 10.0], + "time_step_rank" => "auto", + }, + "mimo_v2_flash" => { + "num_experts_per_tok" => 1, + "hybrid_layer_pattern" => [0, 1], + "moe_layer_freq" => [0, 1], + "sliding_window_size" => 2, + "moe_intermediate_size" => 48, + "n_shared_experts" => 1, + "n_routed_experts" => 2, + "routed_scaling_factor" => 1.0, + "topk_method" => "noaux_tc", + "scoring_func" => "sigmoid", + "norm_topk_prob" => true, + "layernorm_epsilon" => 1e-5, + "swa_rope_theta" => 20_000.0, + "swa_num_attention_heads" => 4, + "swa_num_key_value_heads" => 2, + "head_dim" => 8, + "v_head_dim" => 8, + "swa_head_dim" => 8, + "swa_v_head_dim" => 8, + "partial_rotary_factor" => 1.0, + }, + "minimax" => { + "max_position_embeddings" => 128, + "num_experts_per_tok" => 1, + "num_local_experts" => 2, + "shared_intermediate_size" => 32, + "rotary_dim" => 8, + "tie_word_embeddings" => false, + "use_qk_norm" => true, + }, + "ministral3" => { + "num_hidden_layers" => 4, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "head_dim" => 8, + "max_position_embeddings" => 128, + "tie_word_embeddings" => true, + "sliding_window" => 8, + "layer_types" => ["sliding_attention", "full_attention", "sliding_attention", "full_attention"], + "rope_parameters" => { + "rope_theta" => 10_000.0, + "llama_4_scaling_beta" => 0.1, + "original_max_position_embeddings" => 128, + }, + }, + "nemotron" => { + "hidden_act" => "relu2", + "norm_eps" => 1e-5, + "partial_rotary_factor" => 0.5, + "rope_scaling" => { "type" => "linear", "factor" => 2.0 }, + "tie_word_embeddings" => false, + }, + "nemotron-nas" => { + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "vocab_size" => 101, + "rms_norm_eps" => 1e-5, + "hidden_act" => "silu", + "attention_bias" => false, + "mlp_bias" => false, + "rope_theta" => 10_000.0, + "rope_scaling" => { "type" => "linear", "factor" => 2.0 }, + "max_position_embeddings" => 128, + "tie_word_embeddings" => true, + "block_configs" => [ + { + "attention" => { "n_heads_in_group" => 2 }, + "ffn" => { "ffn_mult" => 1.5 }, + }, + { + "attention" => { "no_op" => true }, + "ffn" => { "replace_with_linear" => true }, + }, + ], + }, + "olmo3" => { + "hidden_size" => 
48, + "num_hidden_layers" => 4, + "intermediate_size" => 96, + "num_attention_heads" => 4, + "vocab_size" => 128, + "max_position_embeddings" => 256, + "sliding_window" => 4, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + }, + "mistral3" => { + "text_config" => { + "model_type" => "llama", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + }, + }, + "baichuan_m1" => { + "hidden_size" => 32, + "intermediate_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rope_theta" => 10_000.0, + "sliding_window" => 4, + "sliding_window_layers" => [0], + "conv_window" => 2, + "rms_norm_eps" => 1e-5, + "tie_word_embeddings" => false, + }, + "dbrx" => { + "vocab_size" => 101, + "d_model" => 24, + "n_layers" => 0, + "n_heads" => 4, + "attn_config" => { + "kv_n_heads" => 2, + "clip_qkv" => 8.0, + "rope_theta" => 10_000.0, + }, + "ffn_config" => { + "ffn_hidden_size" => 16, + "moe_num_experts" => 2, + "moe_top_k" => 1, + }, + }, + "granite" => { + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 128, + "logits_scaling" => 2.0, + "attention_multiplier" => 0.25, + "embedding_multiplier" => 1.5, + "residual_multiplier" => 0.75, + "max_position_embeddings" => 256, + "attention_bias" => false, + "mlp_bias" => false, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => true, + }, + "granitemoe" => { + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "rms_norm_eps" => 1e-5, + "vocab_size" => 97, + "logits_scaling" => 2.0, + "attention_multiplier" => 0.25, + "embedding_multiplier" => 1.25, + "residual_multiplier" => 0.75, + "max_position_embeddings" => 256, + "attention_bias" => false, + "mlp_bias" => false, + "rope_theta" => 10_000.0, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + "tie_word_embeddings" => true, + }, + "lfm2_moe" => { + "vocab_size" => 101, + "hidden_size" => 32, + "intermediate_size" => 64, + "moe_intermediate_size" => 48, + "num_hidden_layers" => 3, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "norm_topk_prob" => true, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "max_position_embeddings" => 128, + "use_expert_bias" => true, + "num_dense_layers" => 1, + "norm_eps" => 1e-5, + "conv_bias" => false, + "conv_L_cache" => 3, + "layer_types" => ["full_attention", "conv", "full_attention"], + "rope_parameters" => { "rope_theta" => 10_000.0 }, + }, + "lille-130m" => { + "block_size" => 128, + "layer_norm_eps" => 1e-5, + "n_embd" => 96, + "n_head" => 4, + "n_kv_heads" => 2, + "n_layer" => 2, + "rope_theta" => 10_000.0, + "vocab_size" => 89, + "tie_word_embeddings" => true, + }, + "minicpm" => { + "hidden_size" => 32, + "dim_model_base" => 16, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 96, + "scale_depth" => 1.4, + "scale_emb" => 8.0, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + }, + "minicpm3" => { + "hidden_size" => 32, + "dim_model_base" => 16, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + 
"num_attention_heads" => 4, + "num_key_value_heads" => 4, + "rms_norm_eps" => 1e-5, + "vocab_size" => 89, + "q_lora_rank" => 8, + "qk_nope_head_dim" => 4, + "qk_rope_head_dim" => 4, + "kv_lora_rank" => 8, + "scale_depth" => 1.0, + "scale_emb" => 1.25, + "max_position_embeddings" => 256, + "attention_bias" => false, + "rope_theta" => 10_000.0, + "rope_scaling" => { + "original_max_position_embeddings" => 128, + "short_factor" => 1.0, + "long_factor" => 1.0, + }, + "tie_word_embeddings" => false, + }, + "phi3small" => { + "hidden_size" => 32, + "dense_attention_every_n_layers" => 1, + "ff_intermediate_size" => 64, + "gegelu_limit" => 16.0, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "layer_norm_epsilon" => 1e-5, + "vocab_size" => 97, + "num_key_value_heads" => 2, + "mup_attn_multiplier" => 1.0, + "mup_use_scaling" => true, + "mup_embedding_multiplier" => 1.0, + "mup_width_multiplier" => 1.0, + "rope_embedding_base" => 10_000.0, + "rope_position_scale" => 1.0, + "blocksparse_block_size" => 64, + "blocksparse_num_local_blocks" => 4, + "blocksparse_vert_stride" => 2, + }, + "plamo" => { + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "rms_norm_eps" => 1e-5, + "vocab_size" => 101, + "n_shared_head" => 2, + "rope_theta" => 10_000.0, + "rope_traditional" => false, + }, + "qwen" => { + "hidden_size" => 32, + "num_attention_heads" => 2, + "num_hidden_layers" => 2, + "kv_channels" => 16, + "intermediate_size" => 64, + "vocab_size" => 100, + "no_bias" => true, + }, + "qwen2_moe" => { + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 109, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 256, + "tie_word_embeddings" => false, + "num_experts_per_tok" => 1, + "num_experts" => 2, + "moe_intermediate_size" => 16, + "shared_expert_intermediate_size" => 24, + }, + "recurrent_gemma" => { + "hidden_size" => 32, + "attention_bias" => false, + "conv1d_width" => 3, + "intermediate_size" => 64, + "logits_soft_cap" => 1.5, + "num_attention_heads" => 4, + "num_hidden_layers" => 3, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "attention_window_size" => 4, + "vocab_size" => 97, + "block_types" => ["recurrent", "attention"], + }, + "step3p5" => { + "hidden_size" => 32, + "num_hidden_layers" => 3, + "vocab_size" => 103, + "num_attention_heads" => 4, + "num_attention_groups" => 2, + "head_dim" => 8, + "intermediate_size" => 64, + "rms_norm_eps" => 1e-5, + "rope_theta" => [10_000.0, 10_000.0, 10_000.0], + "sliding_window" => 4, + "layer_types" => ["full_attention", "sliding_attention", "full_attention"], + "partial_rotary_factors" => [0.5, 1.0, 0.5], + "attention_other_setting" => { + "num_attention_heads" => 4, + "num_attention_groups" => 2, + }, + "use_head_wise_attn_gate" => true, + "moe_num_experts" => 2, + "moe_top_k" => 1, + "moe_intermediate_size" => 48, + "share_expert_dim" => 48, + "moe_layers_enum" => "1,2", + }, + "pixtral" => { + "text_config" => { + "model_type" => "llama", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "intermediate_size" => 128, + "vocab_size" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + }, + }, + }.freeze + + def tiny_config_for(model_type) + explicit = EXPLICIT_TINY_CONFIGS[model_type] + return explicit if 
explicit + + DEFAULT_TINY_CONFIG + .merge("model_type" => model_type) + .merge(MODEL_TINY_CONFIG_OVERRIDES.fetch(model_type, {})) + end + + SUBPROCESS_SCRIPT = <<~'RUBY' + require "json" + require "tmpdir" + + $LOAD_PATH.unshift File.expand_path("lib", __dir__) + $LOAD_PATH.unshift File.expand_path("mlx-ruby/lib", __dir__) + require "mlx" + require "mlx_lm" + + config = JSON.parse(ARGV[0]) + model_type = config["model_type"] + result = { "model_type" => model_type } + + begin + mx = MLX::Core + model_class, args_class = MlxLm::Models.get_classes(config) + args = args_class.from_dict(config) + model = model_class.new(args) + + params = MLX::Utils.tree_flatten(model.parameters).map { |_, v| v } + mx.eval(*params) unless params.empty? + + input = mx.array([[1, 2, 3]], dtype: mx.int32) + fun = ->(x) { model.call(x) } + + # Run compatibility report first (does not require full lowering) + begin + report = MLX::ONNX.export_onnx_compatibility_report(fun, input) + include_full_nodes = ENV["ONNX_COMPAT_REPORT_JSON"] == "1" + unsupported_invocations = Array(report["nodes"]).filter_map do |node| + next unless node.is_a?(Hash) && node["supported"] == false + + { + "index" => node["index"], + "op" => node["op"], + "onnx_op_type" => node["onnx_op_type"], + } + end + + result["compat_report"] = { + "total_nodes" => report["total_nodes"], + "supported_nodes" => report["supported_nodes"], + "unsupported_nodes" => report["unsupported_nodes"], + "unsupported_ops" => report["unsupported_ops"], + "ready" => report["ready_for_stub_conversion"], + "unsupported_invocations" => unsupported_invocations, + } + if include_full_nodes + result["compat_report"]["nodes"] = report["nodes"] + result["compat_report"]["format"] = report["format"] + result["compat_report"]["ir_version"] = report["ir_version"] + end + rescue => e + result["compat_error"] = "#{e.class}: #{e.message}" + end + + # Attempt full ONNX export + begin + Dir.mktmpdir do |dir| + path = File.join(dir, "#{model_type}.onnx") + MLX::ONNX.export_onnx(path, fun, input) + result["export"] = "success" + result["onnx_size"] = File.size(path) + end + rescue NotImplementedError, RuntimeError => e + result["export"] = "failed" + result["export_error"] = e.message + end + rescue => e + result["fatal"] = "#{e.class}: #{e.message}" + end + + puts JSON.generate(result) + RUBY + + def run_model_in_subprocess(model_type) + config_json = JSON.generate(tiny_config_for(model_type)) + project_root = File.expand_path("../..", __dir__) + timeout_seconds = Integer(ENV.fetch("ONNX_EXPORT_SUBPROCESS_TIMEOUT", "180")) + + out = +"" + err = +"" + status = nil + + Open3.popen3("ruby", "-e", SUBPROCESS_SCRIPT, config_json, chdir: project_root) do |stdin, stdout, stderr, wait_thr| + stdin.close + out_reader = Thread.new { stdout.read.to_s } + err_reader = Thread.new { stderr.read.to_s } + + unless wait_thr.join(timeout_seconds) + pid = wait_thr.pid + begin + Process.kill("TERM", pid) + rescue Errno::ESRCH + # already exited + end + unless wait_thr.join(5) + begin + Process.kill("KILL", pid) + rescue Errno::ESRCH + # already exited + end + wait_thr.join + end + + out = out_reader.value + err = err_reader.value + return { + "model_type" => model_type, + "export" => "timeout", + "timeout_seconds" => timeout_seconds, + "stdout" => out.lines.first(10).join, + "stderr" => err.lines.first(10).join, + } + end + + status = wait_thr.value + out = out_reader.value + err = err_reader.value + end + + if status.signaled? 
+ sig = status.termsig + signal_name = Signal.signame(sig) rescue sig.to_s + return { + "model_type" => model_type, + "export" => "crashed", + "crash_signal" => signal_name, + "stderr" => err.lines.first(5).join, + } + end + + unless status.success? + return { + "model_type" => model_type, + "export" => "process_error", + "exit_code" => status.exitstatus, + "stderr" => err.lines.first(10).join, + } + end + + JSON.parse(out) + rescue JSON::ParserError + { + "model_type" => model_type, + "export" => "parse_error", + "stdout" => out.to_s[0, 500], + "stderr" => err.to_s[0, 500], + } + end + + def onnx_log_lines_enabled? + value = ENV.fetch("ONNX_LOG_LINES", "0").strip.downcase + %w[1 true yes on].include?(value) + end + + def onnx_log_line(text) + puts text if onnx_log_lines_enabled? + end + + def onnx_full_export_enabled? + value = ENV.fetch("ONNX_FULL_EXPORT", "0").strip.downcase + %w[1 true yes on].include?(value) + end + + def assert_onnx_export(model_type) + unless onnx_full_export_enabled? + skip "#{model_type}: full ONNX export disabled by default (set ONNX_FULL_EXPORT=1 to enable)" + end + + result = run_model_in_subprocess(model_type) + + case result["export"] + when "success" + assert true, "#{model_type}: ONNX export succeeded (#{result['onnx_size']} bytes)" + report = result["compat_report"] + if report + onnx_log_line("\n [ONNX] #{model_type}: PASS — #{report['supported_nodes']}/#{report['total_nodes']} nodes, #{result['onnx_size']} bytes") + end + when "failed" + report = result["compat_report"] + msg = "#{model_type}: ONNX export failed — #{result['export_error']}" + if report + unsupported = report["unsupported_ops"] || [] + msg += "\n Nodes: #{report['supported_nodes']}/#{report['total_nodes']} supported" + msg += "\n Missing ops: #{unsupported.join(', ')}" + end + flunk(msg) + when "crashed" + report = result["compat_report"] + msg = "#{model_type}: ONNX tracing crashed with signal #{result['crash_signal']}" + if report + unsupported = report["unsupported_ops"] || [] + msg += "\n Compat report (pre-crash): #{report['supported_nodes']}/#{report['total_nodes']} nodes" + msg += "\n Missing ops: #{unsupported.empty? ? 'none' : unsupported.join(', ')}" + end + msg += "\n (MoE models crash because tolist forces data-dependent control flow during tracing)" + flunk(msg) + when "process_error" + flunk("#{model_type}: ONNX export process failed (exit #{result['exit_code']}): #{result['stderr']}") + when "parse_error" + flunk("#{model_type}: ONNX export subprocess parse error: stdout=#{result['stdout']}, stderr=#{result['stderr']}") + when "timeout" + flunk("#{model_type}: ONNX export subprocess timed out after #{result['timeout_seconds']}s") + else + flunk("#{model_type}: unexpected result — #{result.inspect}") + end + end + + def assert_onnx_compat_report(model_type) + result = run_model_in_subprocess(model_type) + + if result["compat_error"] + skip "#{model_type}: compat report unavailable — #{result['compat_error']}" + end + + if result["crash_signal"] + if result["compat_report"] + report = result["compat_report"] + assert_kind_of Integer, report["total_nodes"] + unsupported = report["unsupported_ops"] || [] + unsupported_invocations = report["unsupported_invocations"] || [] + pct = report["total_nodes"] > 0 ? (report["supported_nodes"].to_f / report["total_nodes"] * 100).round(1) : 0 + onnx_log_line("\n [ONNX] #{model_type}: #{report['supported_nodes']}/#{report['total_nodes']} nodes (#{pct}%) — missing: #{unsupported.empty? ? 
'none' : unsupported.join(', ')} (CRASH during export)") + unsupported_invocations.each do |inv| + op = inv["op"] || "unknown" + onnx = inv["onnx_op_type"] || "nil" + index = inv["index"].nil? ? "nil" : inv["index"] + onnx_log_line(" [ONNX-INV] #{model_type}: op=#{op} onnx=#{onnx} index=#{index}") + end + if ENV["ONNX_COMPAT_REPORT_JSON"] == "1" + onnx_log_line(" [ONNX-JSON] #{model_type}: #{JSON.generate(report)}") + end + else + skip "#{model_type}: process crashed (signal #{result['crash_signal']}) before compat report" + end + return + end + + report = result["compat_report"] + skip "#{model_type}: no compat report in result" unless report + + assert_kind_of Integer, report["total_nodes"] + assert_kind_of Integer, report["supported_nodes"] + assert_kind_of Integer, report["unsupported_nodes"] + + total = report["total_nodes"] + supported = report["supported_nodes"] + unsupported_ops = report["unsupported_ops"] || [] + unsupported_invocations = report["unsupported_invocations"] || [] + pct = total > 0 ? (supported.to_f / total * 100).round(1) : 0 + status = result["export"] == "success" ? "PASS" : "FAIL" + onnx_log_line("\n [ONNX] #{model_type}: #{status} — #{supported}/#{total} nodes (#{pct}%) — missing: #{unsupported_ops.empty? ? 'none' : unsupported_ops.join(', ')}") + unsupported_invocations.each do |inv| + op = inv["op"] || "unknown" + onnx = inv["onnx_op_type"] || "nil" + index = inv["index"].nil? ? "nil" : inv["index"] + onnx_log_line(" [ONNX-INV] #{model_type}: op=#{op} onnx=#{onnx} index=#{index}") + end + if ENV["ONNX_COMPAT_REPORT_JSON"] == "1" + onnx_log_line(" [ONNX-JSON] #{model_type}: #{JSON.generate(report)}") + end + end +end diff --git a/test/onnx/openelm_test.rb b/test/onnx/openelm_test.rb new file mode 100644 index 0000000..ed0b5cc --- /dev/null +++ b/test/onnx/openelm_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportOpenelmTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("openelm") + end + + def test_onnx_compat_report + assert_onnx_compat_report("openelm") + end +end diff --git a/test/onnx/phi3_test.rb b/test/onnx/phi3_test.rb new file mode 100644 index 0000000..21a3291 --- /dev/null +++ b/test/onnx/phi3_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportPhi3Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("phi3") + end + + def test_onnx_compat_report + assert_onnx_compat_report("phi3") + end +end diff --git a/test/onnx/phi3small_test.rb b/test/onnx/phi3small_test.rb new file mode 100644 index 0000000..9ce66d6 --- /dev/null +++ b/test/onnx/phi3small_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportPhi3smallTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("phi3small") + end + + def test_onnx_compat_report + assert_onnx_compat_report("phi3small") + end +end diff --git a/test/onnx/phi_test.rb b/test/onnx/phi_test.rb new file mode 100644 index 0000000..00ae140 --- /dev/null +++ b/test/onnx/phi_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportPhiTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + 
assert_onnx_export("phi") + end + + def test_onnx_compat_report + assert_onnx_compat_report("phi") + end +end diff --git a/test/onnx/phimoe_test.rb b/test/onnx/phimoe_test.rb new file mode 100644 index 0000000..409f638 --- /dev/null +++ b/test/onnx/phimoe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportPhimoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("phimoe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("phimoe") + end +end diff --git a/test/onnx/phixtral_test.rb b/test/onnx/phixtral_test.rb new file mode 100644 index 0000000..81e3050 --- /dev/null +++ b/test/onnx/phixtral_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportPhixtralTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("phixtral") + end + + def test_onnx_compat_report + assert_onnx_compat_report("phixtral") + end +end diff --git a/test/onnx/pixtral_test.rb b/test/onnx/pixtral_test.rb new file mode 100644 index 0000000..02b0732 --- /dev/null +++ b/test/onnx/pixtral_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportPixtralTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("pixtral") + end + + def test_onnx_compat_report + assert_onnx_compat_report("pixtral") + end +end diff --git a/test/onnx/plamo2_test.rb b/test/onnx/plamo2_test.rb new file mode 100644 index 0000000..f87cdcb --- /dev/null +++ b/test/onnx/plamo2_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportPlamo2Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("plamo2") + end + + def test_onnx_compat_report + assert_onnx_compat_report("plamo2") + end +end diff --git a/test/onnx/plamo_test.rb b/test/onnx/plamo_test.rb new file mode 100644 index 0000000..c0f4c23 --- /dev/null +++ b/test/onnx/plamo_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportPlamoTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("plamo") + end + + def test_onnx_compat_report + assert_onnx_compat_report("plamo") + end +end diff --git a/test/onnx/qwen2_moe_test.rb b/test/onnx/qwen2_moe_test.rb new file mode 100644 index 0000000..cff0235 --- /dev/null +++ b/test/onnx/qwen2_moe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportQwen2MoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("qwen2_moe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("qwen2_moe") + end +end diff --git a/test/onnx/qwen2_test.rb b/test/onnx/qwen2_test.rb new file mode 100644 index 0000000..4a8e099 --- /dev/null +++ b/test/onnx/qwen2_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportQwen2Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("qwen2") + end + + def 
test_onnx_compat_report + assert_onnx_compat_report("qwen2") + end +end diff --git a/test/onnx/qwen2_vl_test.rb b/test/onnx/qwen2_vl_test.rb new file mode 100644 index 0000000..c182821 --- /dev/null +++ b/test/onnx/qwen2_vl_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportQwen2VlTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("qwen2_vl") + end + + def test_onnx_compat_report + assert_onnx_compat_report("qwen2_vl") + end +end diff --git a/test/onnx/qwen3_5_moe_test.rb b/test/onnx/qwen3_5_moe_test.rb new file mode 100644 index 0000000..f39e34f --- /dev/null +++ b/test/onnx/qwen3_5_moe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportQwen35MoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("qwen3_5_moe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("qwen3_5_moe") + end +end diff --git a/test/onnx/qwen3_5_test.rb b/test/onnx/qwen3_5_test.rb new file mode 100644 index 0000000..9a1ab50 --- /dev/null +++ b/test/onnx/qwen3_5_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportQwen35Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("qwen3_5") + end + + def test_onnx_compat_report + assert_onnx_compat_report("qwen3_5") + end +end diff --git a/test/onnx/qwen3_moe_test.rb b/test/onnx/qwen3_moe_test.rb new file mode 100644 index 0000000..d509894 --- /dev/null +++ b/test/onnx/qwen3_moe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportQwen3MoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("qwen3_moe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("qwen3_moe") + end +end diff --git a/test/onnx/qwen3_next_test.rb b/test/onnx/qwen3_next_test.rb new file mode 100644 index 0000000..850954f --- /dev/null +++ b/test/onnx/qwen3_next_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportQwen3NextTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("qwen3_next") + end + + def test_onnx_compat_report + assert_onnx_compat_report("qwen3_next") + end +end diff --git a/test/onnx/qwen3_test.rb b/test/onnx/qwen3_test.rb new file mode 100644 index 0000000..1d512be --- /dev/null +++ b/test/onnx/qwen3_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportQwen3Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("qwen3") + end + + def test_onnx_compat_report + assert_onnx_compat_report("qwen3") + end +end diff --git a/test/onnx/qwen3_vl_moe_test.rb b/test/onnx/qwen3_vl_moe_test.rb new file mode 100644 index 0000000..6a39d30 --- /dev/null +++ b/test/onnx/qwen3_vl_moe_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportQwen3VlMoeTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + 
assert_onnx_export("qwen3_vl_moe") + end + + def test_onnx_compat_report + assert_onnx_compat_report("qwen3_vl_moe") + end +end diff --git a/test/onnx/qwen3_vl_test.rb b/test/onnx/qwen3_vl_test.rb new file mode 100644 index 0000000..2c6f39a --- /dev/null +++ b/test/onnx/qwen3_vl_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportQwen3VlTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("qwen3_vl") + end + + def test_onnx_compat_report + assert_onnx_compat_report("qwen3_vl") + end +end diff --git a/test/onnx/qwen_test.rb b/test/onnx/qwen_test.rb new file mode 100644 index 0000000..7233dd0 --- /dev/null +++ b/test/onnx/qwen_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportQwenTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("qwen") + end + + def test_onnx_compat_report + assert_onnx_compat_report("qwen") + end +end diff --git a/test/onnx/recurrent_gemma_test.rb b/test/onnx/recurrent_gemma_test.rb new file mode 100644 index 0000000..663208e --- /dev/null +++ b/test/onnx/recurrent_gemma_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportRecurrentGemmaTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("recurrent_gemma") + end + + def test_onnx_compat_report + assert_onnx_compat_report("recurrent_gemma") + end +end diff --git a/test/onnx/registry_coverage_test.rb b/test/onnx/registry_coverage_test.rb new file mode 100644 index 0000000..467d6b6 --- /dev/null +++ b/test/onnx/registry_coverage_test.rb @@ -0,0 +1,19 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class OnnxRegistryCoverageTest < Minitest::Test + def test_every_registered_model_has_onnx_wrapper_test + expected = MlxLm::Models::REGISTRY.keys.map { |model_type| "#{normalized_model_basename(model_type)}_test.rb" }.sort + actual = Dir.glob(File.join(__dir__, "*_test.rb")).map { |path| File.basename(path) } + + missing = expected - actual + assert_empty missing, "Missing ONNX wrapper tests for: #{missing.join(', ')}" + end + + private + + def normalized_model_basename(model_type) + model_type.downcase.gsub(/[^a-z0-9]+/, "_").gsub(/_+/, "_").gsub(/^_|_$/, "") + end +end diff --git a/test/onnx/rwkv7_test.rb b/test/onnx/rwkv7_test.rb new file mode 100644 index 0000000..77b15d2 --- /dev/null +++ b/test/onnx/rwkv7_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportRwkv7Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("rwkv7") + end + + def test_onnx_compat_report + assert_onnx_compat_report("rwkv7") + end +end diff --git a/test/onnx/seed_oss_test.rb b/test/onnx/seed_oss_test.rb new file mode 100644 index 0000000..1752909 --- /dev/null +++ b/test/onnx/seed_oss_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportSeedOssTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("seed_oss") + end + + def test_onnx_compat_report + assert_onnx_compat_report("seed_oss") + end +end diff --git 
a/test/onnx/smollm3_test.rb b/test/onnx/smollm3_test.rb new file mode 100644 index 0000000..5c87ccd --- /dev/null +++ b/test/onnx/smollm3_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportSmollm3Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("smollm3") + end + + def test_onnx_compat_report + assert_onnx_compat_report("smollm3") + end +end diff --git a/test/onnx/solar_open_test.rb b/test/onnx/solar_open_test.rb new file mode 100644 index 0000000..4f93174 --- /dev/null +++ b/test/onnx/solar_open_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportSolarOpenTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("solar_open") + end + + def test_onnx_compat_report + assert_onnx_compat_report("solar_open") + end +end diff --git a/test/onnx/stablelm_test.rb b/test/onnx/stablelm_test.rb new file mode 100644 index 0000000..436201d --- /dev/null +++ b/test/onnx/stablelm_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportStablelmTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("stablelm") + end + + def test_onnx_compat_report + assert_onnx_compat_report("stablelm") + end +end diff --git a/test/onnx/starcoder2_test.rb b/test/onnx/starcoder2_test.rb new file mode 100644 index 0000000..2b3c6c3 --- /dev/null +++ b/test/onnx/starcoder2_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportStarcoder2Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("starcoder2") + end + + def test_onnx_compat_report + assert_onnx_compat_report("starcoder2") + end +end diff --git a/test/onnx/step3p5_test.rb b/test/onnx/step3p5_test.rb new file mode 100644 index 0000000..fe05777 --- /dev/null +++ b/test/onnx/step3p5_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportStep3p5Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("step3p5") + end + + def test_onnx_compat_report + assert_onnx_compat_report("step3p5") + end +end diff --git a/test/onnx/telechat3_test.rb b/test/onnx/telechat3_test.rb new file mode 100644 index 0000000..cedd19d --- /dev/null +++ b/test/onnx/telechat3_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportTelechat3Test < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("telechat3") + end + + def test_onnx_compat_report + assert_onnx_compat_report("telechat3") + end +end diff --git a/test/onnx/youtu_llm_test.rb b/test/onnx/youtu_llm_test.rb new file mode 100644 index 0000000..e01cc91 --- /dev/null +++ b/test/onnx/youtu_llm_test.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "onnx_export_test" + +class OnnxExportYoutuLlmTest < Minitest::Test + include OnnxExportTestHelper + + def test_onnx_export + assert_onnx_export("youtu_llm") + end + + def test_onnx_compat_report + 
assert_onnx_compat_report("youtu_llm") + end +end diff --git a/test/parity/afm7_bailing_moe_linear_models_test.rb b/test/parity/afm7_bailing_moe_linear_models_test.rb new file mode 100644 index 0000000..b260de4 --- /dev/null +++ b/test/parity/afm7_bailing_moe_linear_models_test.rb @@ -0,0 +1,152 @@ +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/afmoe" +require_relative "../../lib/mlx_lm/models/bailing_moe" +require_relative "../../lib/mlx_lm/models/afm7" +require_relative "../../lib/mlx_lm/models/bailing_moe_linear" + +class Phase27DenseLaneAOAfm7Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_afm7_construct_forward_shape_sanitize_passthrough_and_predicates + args = MlxLm::Models::Afm7::ModelArgs.from_dict({ + "model_type" => "afm7", + "vocab_size" => 67, + "hidden_dim" => 32, + "num_layers" => 3, + "num_kv_reuse_layers" => 1, + "num_heads" => 4, + "num_kv_heads" => 2, + "hidden_dim_scale_factor" => 2.0, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 128, + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::Afm7::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 67], output.shape + assert_equal 3, model.layers.length + + weights = { + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([4]).astype(@mx.float32), + "model.layers.0.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "model.embed_tokens.weight" => @mx.zeros([67, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + assert sanitized.key?("model.layers.0.self_attn.q_proj.weight") + assert sanitized.key?("model.embed_tokens.weight") + + cast_predicate = model.cast_predicate + quant_predicate = model.quant_predicate + assert_equal false, cast_predicate.call("model.layers.0.mlp.expert_bias") + assert_equal true, cast_predicate.call("model.layers.0.self_attn.q_proj.weight") + assert_equal({group_size: 64, bits: 8}, quant_predicate.call("model.layers.2.mlp.router.gate", nil)) + assert_equal true, quant_predicate.call("model.layers.2.mlp.experts.gate_proj", nil) + end +end + +class Phase27DenseLaneAOBailingMoeLinearTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_bailing_moe_linear_construct_forward_shape_sanitize_stack_and_predicates + args = MlxLm::Models::BailingMoeLinear::ModelArgs.from_dict({ + "model_type" => "bailing_moe_linear", + "hidden_size" => 32, + "intermediate_size" => 64, + "max_position_embeddings" => 128, + "moe_intermediate_size" => 24, + "num_experts" => 2, + "num_shared_experts" => 1, + "norm_topk_prob" => true, + "num_attention_heads" => 4, + "num_experts_per_tok" => 2, + "num_hidden_layers" => 3, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "vocab_size" => 79, + "first_k_dense_replace" => 1, + "layer_group_size" => 2, + "group_norm_size" => 1, + "use_bias" => false, + "use_qkv_bias" => false, + "tie_word_embeddings" => false, + "score_function" => "softmax", + "n_group" => 1, + "topk_group" => 1, + "moe_router_enable_expert_bias" => true, + "moe_router_enable_shared_expert" => true, + }) + + model = MlxLm::Models::BailingMoeLinear::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], 
dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 79], output.shape + assert_equal 3, model.layers.length + + weights = { + "model.layers.1.mlp.experts.0.gate_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.1.mlp.experts.1.gate_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.1.mlp.experts.0.down_proj.weight" => @mx.array([[0.0, 1.0], [1.0, 0.0]], dtype: @mx.float32), + "model.layers.1.mlp.experts.1.down_proj.weight" => @mx.array([[2.0, 3.0], [4.0, 5.0]], dtype: @mx.float32), + "model.layers.1.mlp.experts.0.up_proj.weight" => @mx.array([[1.0, 1.0], [1.0, 1.0]], dtype: @mx.float32), + "model.layers.1.mlp.experts.1.up_proj.weight" => @mx.array([[2.0, 2.0], [2.0, 2.0]], dtype: @mx.float32), + "model.layers.1.mlp.gate.weight" => @mx.array([[1.0, 0.0], [0.0, 1.0]], dtype: @mx.float32), + "model.layers.1.mlp.gate.bias" => @mx.array([0.1, 0.2], dtype: @mx.float32), + "model.word_embeddings.weight" => @mx.zeros([79, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + gate = sanitized["model.layers.1.mlp.switch_mlp.gate_proj.weight"] + down = sanitized["model.layers.1.mlp.switch_mlp.down_proj.weight"] + up = sanitized["model.layers.1.mlp.switch_mlp.up_proj.weight"] + @mx.eval(gate, down, up) + + assert_equal [2, 2, 2], gate.shape + assert_equal [2, 2, 2], down.shape + assert_equal [2, 2, 2], up.shape + assert sanitized.key?("model.layers.1.mlp.gate.gate_proj.weight") + assert sanitized.key?("model.layers.1.mlp.gate.gate_proj.bias") + refute sanitized.key?("model.layers.1.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.1.gate_proj.weight") + refute sanitized.key?("model.layers.1.mlp.gate.weight") + refute sanitized.key?("model.layers.1.mlp.gate.bias") + assert sanitized.key?("model.word_embeddings.weight") + + cast_predicate = model.cast_predicate + quant_predicate = model.quant_predicate + assert_equal false, cast_predicate.call("model.layers.1.mlp.gate.expert_bias") + assert_equal true, cast_predicate.call("model.layers.1.mlp.gate.gate_proj.weight") + assert_equal({group_size: 64, bits: 8}, quant_predicate.call("model.layers.1.mlp.gate.gate_proj", nil)) + assert_equal true, quant_predicate.call("model.layers.1.mlp.switch_mlp.gate_proj", nil) + end +end + +class Phase27DenseLaneAORegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("afm7"), "afm7 should be registered" + assert MlxLm::Models::REGISTRY.key?("bailing_moe_linear"), "bailing_moe_linear should be registered" + + model_class, args_class = MlxLm::Models.get_classes({"model_type" => "afm7"}) + assert_equal MlxLm::Models::Afm7::Model, model_class + assert_equal MlxLm::Models::Afm7::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({"model_type" => "bailing_moe_linear"}) + assert_equal MlxLm::Models::BailingMoeLinear::Model, model_class + assert_equal MlxLm::Models::BailingMoeLinear::ModelArgs, args_class + end +end diff --git a/test/parity/afmoe_bailing_moe_models_test.rb b/test/parity/afmoe_bailing_moe_models_test.rb new file mode 100644 index 0000000..ede83f8 --- /dev/null +++ b/test/parity/afmoe_bailing_moe_models_test.rb @@ -0,0 +1,173 @@ +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/afmoe" +require_relative "../../lib/mlx_lm/models/bailing_moe" + +class Phase26DenseLaneAKAfmoeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def 
test_afmoe_construct_forward_shape_sanitize_and_make_cache + args = MlxLm::Models::Afmoe::ModelArgs.from_dict({ + "model_type" => "afmoe", + "vocab_size" => 73, + "hidden_size" => 32, + "intermediate_size" => 64, + "moe_intermediate_size" => 24, + "num_hidden_layers" => 4, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "head_dim" => 8, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + "num_experts" => 2, + "num_experts_per_tok" => 2, + "num_shared_experts" => 1, + "num_dense_layers" => 1, + "route_norm" => true, + "route_scale" => 1.0, + "score_func" => "sigmoid", + "n_group" => 1, + "topk_group" => 1, + "sliding_window" => 4, + "mup_enabled" => false, + "layer_types" => [ + "full_attention", + "sliding_attention", + "sliding_attention", + "full_attention", + ], + }) + + model = MlxLm::Models::Afmoe::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 73], output.shape + + weights = { + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([4]).astype(@mx.float32), + "model.layers.2.mlp.experts.0.gate_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.2.mlp.experts.1.gate_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.2.mlp.experts.0.down_proj.weight" => @mx.array([[0.0, 1.0], [1.0, 0.0]], dtype: @mx.float32), + "model.layers.2.mlp.experts.1.down_proj.weight" => @mx.array([[2.0, 3.0], [4.0, 5.0]], dtype: @mx.float32), + "model.layers.2.mlp.experts.0.up_proj.weight" => @mx.array([[1.0, 1.0], [1.0, 1.0]], dtype: @mx.float32), + "model.layers.2.mlp.experts.1.up_proj.weight" => @mx.array([[2.0, 2.0], [2.0, 2.0]], dtype: @mx.float32), + "model.embed_tokens.weight" => @mx.zeros([73, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + gate = sanitized["model.layers.2.mlp.experts.gate_proj.weight"] + down = sanitized["model.layers.2.mlp.experts.down_proj.weight"] + up = sanitized["model.layers.2.mlp.experts.up_proj.weight"] + @mx.eval(gate, down, up) + + assert_equal [2, 2, 2], gate.shape + assert_equal [2, 2, 2], down.shape + assert_equal [2, 2, 2], up.shape + + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + refute sanitized.key?("model.layers.2.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.2.mlp.experts.1.gate_proj.weight") + assert sanitized.key?("model.embed_tokens.weight") + + cache = model.make_cache + assert_equal 4, cache.length + assert_instance_of MlxLm::KVCache, cache[0] + assert_instance_of MlxLm::RotatingKVCache, cache[1] + assert_instance_of MlxLm::RotatingKVCache, cache[2] + assert_instance_of MlxLm::KVCache, cache[3] + end +end + +class Phase26DenseLaneAKBailingMoeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_bailing_moe_construct_forward_shape_and_sanitize + args = MlxLm::Models::BailingMoe::ModelArgs.from_dict({ + "model_type" => "bailing_moe", + "hidden_size" => 32, + "intermediate_size" => 64, + "max_position_embeddings" => 128, + "moe_intermediate_size" => 24, + "num_experts" => 2, + "num_shared_experts" => 1, + "norm_topk_prob" => true, + "num_attention_heads" => 4, + "num_experts_per_tok" => 2, + "num_hidden_layers" => 3, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "vocab_size" => 79, + 
"first_k_dense_replace" => 1, + "use_bias" => false, + "use_qkv_bias" => false, + "tie_word_embeddings" => false, + "score_function" => "softmax", + "n_group" => 1, + "topk_group" => 1, + "moe_router_enable_expert_bias" => true, + "moe_router_enable_shared_expert" => true, + }) + + model = MlxLm::Models::BailingMoe::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 79], output.shape + + weights = { + "model.layers.1.mlp.experts.0.gate_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.1.mlp.experts.1.gate_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.1.mlp.experts.0.down_proj.weight" => @mx.array([[0.0, 1.0], [1.0, 0.0]], dtype: @mx.float32), + "model.layers.1.mlp.experts.1.down_proj.weight" => @mx.array([[2.0, 3.0], [4.0, 5.0]], dtype: @mx.float32), + "model.layers.1.mlp.experts.0.up_proj.weight" => @mx.array([[1.0, 1.0], [1.0, 1.0]], dtype: @mx.float32), + "model.layers.1.mlp.experts.1.up_proj.weight" => @mx.array([[2.0, 2.0], [2.0, 2.0]], dtype: @mx.float32), + "model.layers.1.mlp.gate.weight" => @mx.array([[1.0, 0.0], [0.0, 1.0]], dtype: @mx.float32), + "model.layers.1.mlp.gate.bias" => @mx.array([0.1, 0.2], dtype: @mx.float32), + "model.word_embeddings.weight" => @mx.zeros([79, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + gate = sanitized["model.layers.1.mlp.switch_mlp.gate_proj.weight"] + down = sanitized["model.layers.1.mlp.switch_mlp.down_proj.weight"] + up = sanitized["model.layers.1.mlp.switch_mlp.up_proj.weight"] + @mx.eval(gate, down, up) + + assert_equal [2, 2, 2], gate.shape + assert_equal [2, 2, 2], down.shape + assert_equal [2, 2, 2], up.shape + + assert sanitized.key?("model.layers.1.mlp.gate.gate_proj.weight") + assert sanitized.key?("model.layers.1.mlp.gate.gate_proj.bias") + refute sanitized.key?("model.layers.1.mlp.gate.weight") + refute sanitized.key?("model.layers.1.mlp.gate.bias") + refute sanitized.key?("model.layers.1.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.1.gate_proj.weight") + assert sanitized.key?("model.word_embeddings.weight") + end +end + +class Phase26DenseLaneAKRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("afmoe"), "afmoe should be registered" + assert MlxLm::Models::REGISTRY.key?("bailing_moe"), "bailing_moe should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "afmoe" }) + assert_equal MlxLm::Models::Afmoe::Model, model_class + assert_equal MlxLm::Models::Afmoe::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "bailing_moe" }) + assert_equal MlxLm::Models::BailingMoe::Model, model_class + assert_equal MlxLm::Models::BailingMoe::ModelArgs, args_class + end +end diff --git a/test/parity/apertus_youtu_llm_models_test.rb b/test/parity/apertus_youtu_llm_models_test.rb new file mode 100644 index 0000000..7034c92 --- /dev/null +++ b/test/parity/apertus_youtu_llm_models_test.rb @@ -0,0 +1,98 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative 
"../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/apertus" +require_relative "../../lib/mlx_lm/models/youtu_llm" + +class Phase18DenseLaneKApertusTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_apertus_construct_forward_shape_and_registry_resolution + args = MlxLm::Models::Apertus::ModelArgs.from_dict({ + "model_type" => "apertus", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "mlp_bias" => false, + "num_attention_heads" => 4, + "attention_bias" => false, + "rms_norm_eps" => 1e-6, + "vocab_size" => 96, + "num_key_value_heads" => 2, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "post_norm" => false, + "qk_norm" => true, + "tie_word_embeddings" => false, + "rope_traditional" => false, + }) + + model = MlxLm::Models::Apertus::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + + assert_equal [1, 4, 96], output.shape + assert MlxLm::Models::REGISTRY.key?("apertus"), "apertus should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "apertus" }) + assert_equal MlxLm::Models::Apertus::Model, model_class + assert_equal MlxLm::Models::Apertus::ModelArgs, args_class + end +end + +class Phase18DenseLaneKYoutuLLMTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_youtu_llm_construct_forward_shape_and_registry_resolution + args = MlxLm::Models::YoutuLLM::ModelArgs.from_dict({ + "model_type" => "youtu_llm", + "vocab_size" => 128, + "hidden_size" => 64, + "intermediate_size" => 128, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "kv_lora_rank" => 16, + "q_lora_rank" => 24, + "qk_rope_head_dim" => 8, + "v_head_dim" => 16, + "qk_nope_head_dim" => 8, + "max_position_embeddings" => 256, + "rms_norm_eps" => 1e-6, + "rope_theta" => 10_000.0, + "rope_traditional" => true, + "attention_bias" => false, + "mlp_bias" => false, + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::YoutuLLM::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + + assert_equal [1, 3, 128], output.shape + assert MlxLm::Models::REGISTRY.key?("youtu_llm"), "youtu_llm should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "youtu_llm" }) + assert_equal MlxLm::Models::YoutuLLM::Model, model_class + assert_equal MlxLm::Models::YoutuLLM::ModelArgs, args_class + end +end diff --git a/test/parity/bitlinear_quantization_test.rb b/test/parity/bitlinear_quantization_test.rb new file mode 100644 index 0000000..92dd676 --- /dev/null +++ b/test/parity/bitlinear_quantization_test.rb @@ -0,0 +1,156 @@ +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/bitlinear_layers" + +class Phase15DummyBitnetBlock < MLX::NN::Module + def initialize + super() + self.keep = MLX::NN::Linear.new(4, 6, bias: true) + self.skip = MLX::NN::Linear.new(6, 4, bias: false) + end +end + +class Phase15DummyBitnetModel < MLX::NN::Module + def initialize + super() + self.block = Phase15DummyBitnetBlock.new + self.head = MLX::NN::Linear.new(4, 3, bias: false) + end +end + +class Phase15BitLinearTest < Minitest::Test 
+ include ParityTestHelpers + + def setup + @mx = MLX::Core + end + + def test_bitlinear_forward_fallback_matches_dense_reference + layer = MlxLm::Models::BitLinear.new(3, 5, bias: true, invert_weight_scales: false) + + full_weights = [ + [1, 0, -1], + [-1, 1, 0], + [0, 1, 1], + [1, -1, 0], + [-1, 0, 1], + ] + layer.weight = @mx.array(pack_ternary_weights(full_weights), dtype: @mx.uint8) + layer.weight_scale = @mx.array([2.0], dtype: @mx.float32) + layer.bias = @mx.array([0.25, -0.5, 0.0, 1.0, -1.5], dtype: @mx.float32) + + x = @mx.array( + [ + [[1.0, 2.0, -1.0], [0.5, 0.0, 1.0]], + [[-2.0, 1.0, 3.0], [1.5, -0.5, 0.25]], + ], + dtype: @mx.float32 + ) + + dense_weight = @mx.array(full_weights, dtype: @mx.float32) + expected = @mx.add(@mx.multiply(@mx.matmul(x, dense_weight.T), 2.0), layer.bias) + actual = layer.call(x) + + @mx.eval(expected, actual) + assert_equal [2, 2, 5], actual.shape + assert_equal expected.shape, actual.shape + assert_arrays_close(expected.tolist, actual.tolist, atol: 1e-6, msg: "bitlinear fallback output mismatch") + end + + def test_bitlinear_inverted_weight_scale_matches_reference + layer = MlxLm::Models::BitLinear.new(2, 4, bias: false, invert_weight_scales: true) + + full_weights = [ + [1, -1], + [0, 1], + [-1, 0], + [1, 1], + ] + layer.weight = @mx.array(pack_ternary_weights(full_weights), dtype: @mx.uint8) + layer.weight_scale = @mx.array([4.0], dtype: @mx.float32) + + x = @mx.array( + [ + [2.0, -1.0], + [0.5, 0.25], + [-3.0, 1.0], + ], + dtype: @mx.float32 + ) + + dense_weight = @mx.array(full_weights, dtype: @mx.float32) + expected = @mx.multiply(@mx.matmul(x, dense_weight.T), 0.25) + actual = layer.call(x) + + @mx.eval(expected, actual) + assert_equal expected.shape, actual.shape + assert_arrays_close(expected.tolist, actual.tolist, atol: 1e-6, msg: "inverted scale mismatch") + end + + def test_bitnet_quantize_replaces_linear_layers_with_skip_list_support + model = Phase15DummyBitnetModel.new + converted = MlxLm::Models.bitnet_quantize( + model, + { + modules_to_not_convert: ["block.skip"], + linear_class: "autobitlinear", + } + ) + + assert_same model, converted + + assert_instance_of MlxLm::Models::BitLinear, model.block.keep + assert_instance_of MLX::NN::Linear, model.block.skip + assert_instance_of MlxLm::Models::BitLinear, model.head + + assert_equal false, model.block.keep.invert_weight_scales + assert_equal false, model.head.invert_weight_scales + + assert model.block.keep.state.key?("bias") + assert_equal [2, 4], model.block.keep.weight.shape + refute model.head.state.key?("bias") + end + + def test_bitnet_quantize_defaults_to_inverted_weight_scales + model = Phase15DummyBitnetModel.new + MlxLm::Models.bitnet_quantize(model, {}) + + assert_instance_of MlxLm::Models::BitLinear, model.block.keep + assert_equal true, model.block.keep.invert_weight_scales + end + + private + + def pack_ternary_weights(full_weights) + out_features = full_weights.length + in_features = full_weights.first.length + packed_out_features = (out_features + 3) / 4 + + packed = Array.new(packed_out_features) { Array.new(in_features, 0) } + packed_out_features.times do |packed_row| + in_features.times do |input_col| + byte = 0 + 4.times do |group_idx| + out_row = packed_row + (group_idx * packed_out_features) + ternary = out_row < out_features ? 
full_weights[out_row][input_col] : 0 + encoded = encode_ternary(ternary) + byte |= (encoded << (group_idx * 2)) + end + packed[packed_row][input_col] = byte + end + end + packed + end + + def encode_ternary(value) + case value + when -1 + 0 + when 0 + 1 + when 1 + 2 + else + raise ArgumentError, "Expected ternary value in {-1, 0, 1}, got #{value.inspect}" + end + end +end diff --git a/test/parity/bitnet_openelm_models_test.rb b/test/parity/bitnet_openelm_models_test.rb new file mode 100644 index 0000000..2b9d9d1 --- /dev/null +++ b/test/parity/bitnet_openelm_models_test.rb @@ -0,0 +1,102 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/bitlinear_layers" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/bitnet" +require_relative "../../lib/mlx_lm/models/openelm" + +class Phase21DenseLaneUBitnetTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_bitnet_construct_forward_shape_and_sanitize_tied_embeddings + args = MlxLm::Models::Bitnet::ModelArgs.from_dict({ + "model_type" => "bitnet", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "head_dim" => 8, + "rms_norm_eps" => 1e-5, + "vocab_size" => 101, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::Bitnet::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 101], output.shape + + weights = { + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([101, 32]).astype(@mx.float32), + "model.embed_tokens.weight" => @mx.zeros([101, 32]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + refute sanitized.key?("lm_head.weight") + assert sanitized.key?("model.embed_tokens.weight") + end +end + +class Phase21DenseLaneUOpenELMTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_openelm_construct_forward_shape + args = MlxLm::Models::OpenELM::ModelArgs.from_dict({ + "model_type" => "openelm", + "head_dim" => 8, + "num_transformer_layers" => 2, + "model_dim" => 32, + "vocab_size" => 89, + "ffn_dim_divisor" => 8, + "num_query_heads" => [4, 4], + "num_kv_heads" => [2, 2], + "ffn_multipliers" => [2.0, 2.5], + "normalize_qk_projections" => true, + "share_input_output_layers" => false, + }) + + model = MlxLm::Models::OpenELM::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 89], output.shape + assert_equal 2, model.layers.length + refute_nil model.lm_head + end +end + +class Phase21DenseLaneURegistryTest < Minitest::Test + def test_models_registered_and_resolvable + assert MlxLm::Models::REGISTRY.key?("bitnet"), "bitnet should be registered" + assert MlxLm::Models::REGISTRY.key?("openelm"), "openelm should be registered" + + 
model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "bitnet" }) + assert_equal MlxLm::Models::Bitnet::Model, model_class + assert_equal MlxLm::Models::Bitnet::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "openelm" }) + assert_equal MlxLm::Models::OpenELM::Model, model_class + assert_equal MlxLm::Models::OpenELM::ModelArgs, args_class + end +end diff --git a/test/parity/cache_extensions_contract_test.rb b/test/parity/cache_extensions_contract_test.rb new file mode 100644 index 0000000..00d87f5 --- /dev/null +++ b/test/parity/cache_extensions_contract_test.rb @@ -0,0 +1,136 @@ +require_relative "../test_helper" + +class Phase14CacheExtensionsTest < Minitest::Test + include ParityTestHelpers + + def setup + @mx = MLX::Core + end + + def test_base_cache_default_contract + cache = MlxLm::BaseCache.new + + assert_equal [], cache.state + assert_equal "", cache.meta_state + assert_raises(ArgumentError) { cache.state = [@mx.ones([1])] } + assert_raises(ArgumentError) { cache.meta_state = "invalid" } + end + + def test_arrays_cache_masks_extract_and_finalize + cache = MlxLm::ArraysCache.new(2, left_padding: [1, 2]) + cache[0] = @mx.ones([2, 3]) + cache[1] = @mx.ones([2, 3]) * 2.0 + + mask = cache.make_mask(4) + @mx.eval(mask) + assert_equal [[false, true, true, true], [false, false, true, true]], mask.tolist + + cache.left_padding = nil + cache.prepare(lengths: [3, 1]) + len_mask = cache.make_mask(4) + @mx.eval(len_mask) + assert_equal [[true, true, true, false], [true, false, false, false]], len_mask.tolist + + extracted = cache.extract(1) + @mx.eval(extracted[0], extracted[1]) + assert_equal [1, 3], extracted[0].shape + assert_equal [1, 3], extracted[1].shape + + cache.advance(1) + advanced_mask = cache.make_mask(4) + @mx.eval(advanced_mask) + assert_equal [[true, true, false, false], [false, false, false, false]], advanced_mask.tolist + + cache.finalize + assert_nil cache.make_mask(4) + end + + def test_quantized_kv_cache_update_trim_and_restore + cache = MlxLm::QuantizedKVCache.new(group_size: 32, bits: 4) + + k1 = @mx.ones([1, 2, 3, 32]) + v1 = @mx.ones([1, 2, 3, 32]) * 2.0 + qk, qv = cache.update_and_fetch(k1, v1) + @mx.eval(*qk, *qv) + + assert_equal 3, cache.offset + assert_equal [1, 2, 3, 4], qk[0].shape + assert_equal [1, 2, 3, 1], qk[1].shape + + k2 = @mx.ones([1, 2, 2, 32]) * 3.0 + v2 = @mx.ones([1, 2, 2, 32]) * 4.0 + qk, qv = cache.update_and_fetch(k2, v2) + @mx.eval(*qk, *qv) + + assert_equal 5, cache.offset + assert_equal [1, 2, 5, 4], qk[0].shape + + trimmed = cache.trim(2) + assert_equal 2, trimmed + assert_equal 3, cache.offset + assert_equal 3, cache.state[0][0].shape[2] + + restored = MlxLm::QuantizedKVCache.from_state(cache.state, cache.meta_state) + assert_equal 3, restored.offset + assert_equal 32, restored.group_size + assert_equal 4, restored.bits + end + + def test_chunked_kv_cache_front_trim_and_trim + cache = MlxLm::ChunkedKVCache.new(4) + + k1 = @mx.ones([1, 1, 3, 8]) + v1 = @mx.ones([1, 1, 3, 8]) + cache.update_and_fetch(k1, v1) + + k2 = @mx.ones([1, 1, 3, 8]) * 2.0 + v2 = @mx.ones([1, 1, 3, 8]) * 3.0 + keys, values = cache.update_and_fetch(k2, v2) + @mx.eval(keys, values) + + assert_equal [1, 1, 6, 8], keys.shape + assert_equal 6, cache.offset + + cache.maybe_trim_front + keys, values = cache.state + @mx.eval(keys, values) + + assert_equal 2, cache.start_position + assert_equal [1, 1, 4, 8], keys.shape + assert_equal 4, cache.size + + trimmed = cache.trim(1) + assert_equal 1, trimmed + assert_equal 5, 
cache.offset + assert_equal 3, cache.size + assert_equal [1, 1, 3, 8], cache.state[0].shape + end + + def test_cache_list_trim_and_restore + k = @mx.ones([1, 1, 4, 8]) + v = @mx.ones([1, 1, 4, 8]) + + kv1 = MlxLm::KVCache.new + kv2 = MlxLm::KVCache.new + kv1.update_and_fetch(k, v) + kv2.update_and_fetch(k, v) + + list = MlxLm::CacheList.new(kv1, kv2) + assert list.is_trimmable + assert_equal 4, list.size + + trimmed = list.trim(2) + assert_equal 2, trimmed + assert_equal 2, kv1.offset + assert_equal 2, kv2.offset + + arr = MlxLm::ArraysCache.new(1) + arr[0] = @mx.ones([1, 4]) + mixed = MlxLm::CacheList.new(arr, MlxLm::KVCache.new) + refute mixed.is_trimmable + + restored = MlxLm::CacheList.from_state(mixed.state, mixed.meta_state) + assert_instance_of MlxLm::ArraysCache, restored[0] + assert_instance_of MlxLm::KVCache, restored[1] + end +end diff --git a/test/parity/phase10_test.rb b/test/parity/cli_server_schema_chat_template_test.rb similarity index 100% rename from test/parity/phase10_test.rb rename to test/parity/cli_server_schema_chat_template_test.rb diff --git a/test/parity/cohere2_internlm3_models_test.rb b/test/parity/cohere2_internlm3_models_test.rb new file mode 100644 index 0000000..3461ee3 --- /dev/null +++ b/test/parity/cohere2_internlm3_models_test.rb @@ -0,0 +1,104 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/cohere2" +require_relative "../../lib/mlx_lm/models/internlm3" + +class Phase17DenseLaneECohere2Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_cohere2_construct_forward_shape_and_cache_pattern + args = MlxLm::Models::Cohere2::ModelArgs.from_dict({ + "model_type" => "cohere2", + "hidden_size" => 32, + "head_dim" => 16, + "num_hidden_layers" => 4, + "intermediate_size" => 64, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "vocab_size" => 128, + "sliding_window" => 8, + "sliding_window_pattern" => 2, + "attention_bias" => false, + "layer_norm_bias" => false, + "logit_scale" => 0.5, + }) + + model = MlxLm::Models::Cohere2::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3]], @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 3, 128], output.shape + + caches = model.make_cache + assert_equal 4, caches.length + assert_instance_of MlxLm::RotatingKVCache, caches[0] + assert_instance_of MlxLm::KVCache, caches[1] + assert_instance_of MlxLm::RotatingKVCache, caches[2] + assert_instance_of MlxLm::KVCache, caches[3] + end +end + +class Phase17DenseLaneEInternLM3Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_internlm3_construct_forward_shape_and_rope_scaling_validation + args = MlxLm::Models::InternLM3::ModelArgs.from_dict({ + "model_type" => "internlm3", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-6, + "vocab_size" => 96, + "bias" => true, + "qkv_bias" => true, + "rope_scaling" => { + "factor" => 2.0, + "rope_type" => "linear", + }, + "tie_word_embeddings" => false, + }) + + model = MlxLm::Models::InternLM3::Model.new(args) + 
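# MLX arrays are lazy; evaluating the flattened parameters materializes the freshly initialized weights before the forward pass. +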
@mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3, 4]], @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 4, 96], output.shape + + error = assert_raises(ArgumentError) do + MlxLm::Models::InternLM3::ModelArgs.from_dict({ + "model_type" => "internlm3", + "hidden_size" => 32, + "num_hidden_layers" => 1, + "intermediate_size" => 64, + "num_attention_heads" => 2, + "vocab_size" => 96, + "rms_norm_eps" => 1e-6, + "rope_scaling" => { + "factor" => 2.0, + "rope_type" => "unsupported", + }, + }) + end + assert_match("rope_type", error.message) + end +end diff --git a/test/parity/deepseek_v2_v3_models_test.rb b/test/parity/deepseek_v2_v3_models_test.rb new file mode 100644 index 0000000..4571bef --- /dev/null +++ b/test/parity/deepseek_v2_v3_models_test.rb @@ -0,0 +1,129 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/deepseek" +require_relative "../../lib/mlx_lm/models/deepseek_v2" +require_relative "../../lib/mlx_lm/models/deepseek_v3" + +class Phase22DenseLaneYDeepseekV2Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_deepseek_v2_construct_forward_shape_and_sanitize_stacks_experts + args = MlxLm::Models::DeepseekV2::ModelArgs.from_dict({ + "model_type" => "deepseek_v2", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "intermediate_size" => 64, + "moe_intermediate_size" => 16, + "vocab_size" => 97, + "n_routed_experts" => 2, + "num_experts_per_tok" => 1, + "n_shared_experts" => 1, + "moe_layer_freq" => 1, + "first_k_dense_replace" => 0, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + }) + + model = MlxLm::Models::DeepseekV2::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 97], output.shape + + weights = { + "model.layers.0.mlp.experts.0.gate_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.gate_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked = sanitized["model.layers.0.mlp.switch_mlp.gate_proj.weight"] + @mx.eval(stacked) + + refute sanitized.key?("model.layers.0.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.gate_proj.weight") + assert sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + assert_equal [2, 2, 2], stacked.shape + end +end + +class Phase22DenseLaneYDeepseekV3Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_deepseek_v3_construct_forward_shape_and_sanitize_prunes_extra_keys + args = MlxLm::Models::DeepseekV3::ModelArgs.from_dict({ + "model_type" => "deepseek_v3", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "intermediate_size" => 64, + "moe_intermediate_size" => 16, + "vocab_size" => 103, + "n_routed_experts" => 2, + "num_experts_per_tok" => 1, + 
"n_shared_experts" => 1, + "moe_layer_freq" => 1, + "first_k_dense_replace" => 0, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + }) + + model = MlxLm::Models::DeepseekV3::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 103], output.shape + + weights = { + "model.layers.0.mlp.experts.0.up_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.up_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "model.layers.61.mlp.down_proj.weight" => @mx.zeros([1]).astype(@mx.float32), + "model.layers.1.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked = sanitized["model.layers.0.mlp.switch_mlp.up_proj.weight"] + @mx.eval(stacked) + + refute sanitized.key?("model.layers.0.mlp.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + refute sanitized.key?("model.layers.61.mlp.down_proj.weight") + assert sanitized.key?("model.layers.1.self_attn.q_proj.weight") + assert_equal [2, 2, 2], stacked.shape + end +end + +class Phase22DenseLaneYRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("deepseek_v2"), "deepseek_v2 should be registered" + assert MlxLm::Models::REGISTRY.key?("deepseek_v3"), "deepseek_v3 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "deepseek_v2" }) + assert_equal MlxLm::Models::DeepseekV2::Model, model_class + assert_equal MlxLm::Models::DeepseekV2::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "deepseek_v3" }) + assert_equal MlxLm::Models::DeepseekV3::Model, model_class + assert_equal MlxLm::Models::DeepseekV3::ModelArgs, args_class + end +end diff --git a/test/parity/deepseek_v32_glm_moe_dsa_models_test.rb b/test/parity/deepseek_v32_glm_moe_dsa_models_test.rb new file mode 100644 index 0000000..53be200 --- /dev/null +++ b/test/parity/deepseek_v32_glm_moe_dsa_models_test.rb @@ -0,0 +1,144 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/deepseek" +require_relative "../../lib/mlx_lm/models/deepseek_v32" +require_relative "../../lib/mlx_lm/models/glm_moe_dsa" + +class Phase22DenseLaneZDeepseekV32Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_deepseek_v32_construct_forward_shape_and_sanitize + args = MlxLm::Models::DeepseekV32::ModelArgs.from_dict({ + "model_type" => "deepseek_v32", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "intermediate_size" => 64, + "moe_intermediate_size" => 16, + "vocab_size" => 97, + "rms_norm_eps" => 1e-5, + "max_position_embeddings" => 256, + "rope_theta" => 
10_000.0, + "n_routed_experts" => 2, + "n_shared_experts" => 1, + "num_experts_per_tok" => 1, + "moe_layer_freq" => 1, + "first_k_dense_replace" => 0, + }) + + model = MlxLm::Models::DeepseekV32::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 97], output.shape + + weights = { + "model.layers.0.mlp.experts.0.gate_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.gate_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "model.layers.1.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "model.layers.2.mlp.gate_proj.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked = sanitized["model.layers.0.mlp.switch_mlp.gate_proj.weight"] + @mx.eval(stacked) + + refute sanitized.key?("model.layers.0.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.gate_proj.weight") + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + refute sanitized.key?("model.layers.2.mlp.gate_proj.weight") + assert sanitized.key?("model.layers.1.self_attn.q_proj.weight") + assert_equal [2, 2, 2], stacked.shape + assert_equal [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], stacked.to_a + end +end + +class Phase22DenseLaneZGlmMoeDsaTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_glm_moe_dsa_rope_parameters_mapping_construct_forward_shape_and_sanitize + args = MlxLm::Models::GlmMoeDsa::ModelArgs.from_dict({ + "model_type" => "glm_moe_dsa", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "intermediate_size" => 64, + "moe_intermediate_size" => 16, + "vocab_size" => 83, + "rms_norm_eps" => 1e-5, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "n_routed_experts" => 2, + "n_shared_experts" => 1, + "num_experts_per_tok" => 1, + "moe_layer_freq" => 1, + "first_k_dense_replace" => 0, + "rope_parameters" => { + "rope_theta" => 12_345.0, + "type" => "yarn", + "factor" => 8.0, + }, + }) + + assert_equal args.rope_parameters, args.rope_scaling + assert_equal 12_345.0, args.rope_theta + + model = MlxLm::Models::GlmMoeDsa::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 83], output.shape + + weights = { + "model.layers.0.mlp.experts.0.up_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.up_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "model.layers.2.mlp.up_proj.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked = sanitized["model.layers.0.mlp.switch_mlp.up_proj.weight"] + @mx.eval(stacked) + + refute sanitized.key?("model.layers.0.mlp.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + refute sanitized.key?("model.layers.2.mlp.up_proj.weight") + assert_equal [2, 2, 
2], stacked.shape + end +end + +class Phase22DenseLaneZRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("deepseek_v32"), "deepseek_v32 should be registered" + assert MlxLm::Models::REGISTRY.key?("glm_moe_dsa"), "glm_moe_dsa should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "deepseek_v32" }) + assert_equal MlxLm::Models::DeepseekV32::Model, model_class + assert_equal MlxLm::Models::DeepseekV32::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "glm_moe_dsa" }) + assert_equal MlxLm::Models::GlmMoeDsa::Model, model_class + assert_equal MlxLm::Models::GlmMoeDsa::ModelArgs, args_class + end +end diff --git a/test/parity/ernie45_baichuan_m1_models_test.rb b/test/parity/ernie45_baichuan_m1_models_test.rb new file mode 100644 index 0000000..9ae7519 --- /dev/null +++ b/test/parity/ernie45_baichuan_m1_models_test.rb @@ -0,0 +1,104 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/ernie4_5" +require_relative "../../lib/mlx_lm/models/baichuan_m1" + +class Phase18DenseLaneLErnie45Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_ernie45_construct_forward_shape_and_head_dim_default + args = MlxLm::Models::Ernie45::ModelArgs.from_dict({ + "model_type" => "ernie4_5", + "hidden_size" => 48, + "intermediate_size" => 96, + "max_position_embeddings" => 256, + "num_attention_heads" => 3, + "num_key_value_heads" => 3, + "num_hidden_layers" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 97, + "rope_theta" => 10_000.0, + "use_bias" => false, + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::Ernie45::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 4, 97], output.shape + assert_equal 16, model.layers.first.self_attn.instance_variable_get(:@head_dim) + end +end + +class Phase18DenseLaneLBaichuanM1Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_baichuan_m1_construct_forward_shape_and_sanitize_normalizes_lm_head + args = MlxLm::Models::BaichuanM1::ModelArgs.from_dict({ + "model_type" => "baichuan_m1", + "vocab_size" => 96, + "hidden_size" => 32, + "intermediate_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rope_theta" => 10_000.0, + "sliding_window" => 4, + "sliding_window_layers" => [0], + "conv_window" => 2, + "rms_norm_eps" => 1e-5, + "tie_word_embeddings" => false, + }) + + model = MlxLm::Models::BaichuanM1::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 4, 96], output.shape + + weights = { + "lm_head.weight" => @mx.array([[3.0, 4.0], [0.0, 5.0]], dtype: @mx.float32), + } + sanitized = model.sanitize(weights) + norms = @mx.norm(sanitized["lm_head.weight"], nil, -1) + @mx.eval(norms) + + assert_in_delta 1.0, 
norms.to_a[0], 1e-5 + assert_in_delta 1.0, norms.to_a[1], 1e-5 + end +end + +class Phase18DenseLaneLRegistryTest < Minitest::Test + def test_models_registered + assert MlxLm::Models::REGISTRY.key?("ernie4_5"), "ernie4_5 should be registered" + assert MlxLm::Models::REGISTRY.key?("baichuan_m1"), "baichuan_m1 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "ernie4_5" }) + assert_equal MlxLm::Models::Ernie45::Model, model_class + assert_equal MlxLm::Models::Ernie45::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "baichuan_m1" }) + assert_equal MlxLm::Models::BaichuanM1::Model, model_class + assert_equal MlxLm::Models::BaichuanM1::ModelArgs, args_class + end +end diff --git a/test/parity/exaone4_nanochat_models_test.rb b/test/parity/exaone4_nanochat_models_test.rb new file mode 100644 index 0000000..bf6310d --- /dev/null +++ b/test/parity/exaone4_nanochat_models_test.rb @@ -0,0 +1,90 @@ +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/exaone4" +require_relative "../../lib/mlx_lm/models/nanochat" + +class Phase17DenseLaneHExaone4Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_exaone4_construct_forward_shape_and_cache_pattern + args = MlxLm::Models::Exaone4::ModelArgs.from_dict({ + "model_type" => "exaone4", + "hidden_size" => 64, + "num_hidden_layers" => 4, + "intermediate_size" => 128, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 96, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "head_dim" => 16, + "tie_word_embeddings" => false, + "sliding_window" => 4, + "sliding_window_pattern" => "LLGL", + }) + + model = MlxLm::Models::Exaone4::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3, 4]], @mx.int32) + output = model.call(tokens) + @mx.eval(output) + + assert_equal [1, 4, 96], output.shape + + caches = model.make_cache + assert_equal 4, caches.length + assert_instance_of MlxLm::RotatingKVCache, caches[0] + assert_instance_of MlxLm::RotatingKVCache, caches[1] + assert_instance_of MlxLm::KVCache, caches[2] + assert_instance_of MlxLm::RotatingKVCache, caches[3] + end +end + +class Phase17DenseLaneHNanochatTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_nanochat_construct_forward_shape_and_softcap + args = MlxLm::Models::Nanochat::ModelArgs.from_dict({ + "model_type" => "nanochat", + "hidden_size" => 40, + "num_hidden_layers" => 2, + "num_attention_heads" => 5, + "num_key_value_heads" => 5, + "vocab_size" => 64, + "intermediate_size" => 80, + "rope_theta" => 10_000.0, + }) + + model = MlxLm::Models::Nanochat::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3]], @mx.int32) + output = model.call(tokens) + @mx.eval(output) + + assert_equal [1, 3, 64], output.shape + + values = output.to_a.flatten + assert values.all? 
{ |v| v <= 15.0 + 1e-5 && v >= -15.0 - 1e-5 }, "nanochat logits should be softcapped to [-15, 15]" + end +end + +class Phase17DenseLaneHRegistryTest < Minitest::Test + def test_models_registered + assert MlxLm::Models::REGISTRY.key?("exaone4"), "exaone4 should be registered" + assert MlxLm::Models::REGISTRY.key?("nanochat"), "nanochat should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "exaone4" }) + assert_equal MlxLm::Models::Exaone4::Model, model_class + assert_equal MlxLm::Models::Exaone4::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "nanochat" }) + assert_equal MlxLm::Models::Nanochat::Model, model_class + assert_equal MlxLm::Models::Nanochat::ModelArgs, args_class + end +end diff --git a/test/parity/exaone_moe_glm4_moe_models_test.rb b/test/parity/exaone_moe_glm4_moe_models_test.rb new file mode 100644 index 0000000..d7318f5 --- /dev/null +++ b/test/parity/exaone_moe_glm4_moe_models_test.rb @@ -0,0 +1,152 @@ +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/exaone_moe" +require_relative "../../lib/mlx_lm/models/glm4_moe" + +class Phase24DenseLaneALExaoneMoeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_exaone_moe_construct_forward_shape_sanitize_and_make_cache + args = MlxLm::Models::ExaoneMoe::ModelArgs.from_dict({ + "model_type" => "exaone_moe", + "vocab_size" => 97, + "hidden_size" => 32, + "intermediate_size" => 64, + "moe_intermediate_size" => 16, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "head_dim" => 8, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "num_shared_experts" => 1, + "rms_norm_eps" => 1e-5, + "max_position_embeddings" => 256, + "sliding_window" => 2, + "layer_types" => ["full_attention", "sliding_attention"], + "is_moe_layer" => [true, false], + "n_group" => 1, + "topk_group" => 1, + "routed_scaling_factor" => 1.0, + "norm_topk_prob" => true, + "rope_theta" => 10_000.0, + "rope_parameters" => { "rope_theta" => 20_000.0 }, + "tie_word_embeddings" => true, + }) + + assert_equal 20_000.0, args.rope_theta + + model = MlxLm::Models::ExaoneMoe::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 97], output.shape + + weights = { + "mtp.head.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([97, 32]).astype(@mx.float32), + "model.layers.0.mlp.e_score_correction_bias" => @mx.array([0.1, -0.1], dtype: @mx.float32), + "model.layers.0.mlp.experts.0.gate_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.gate_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.1.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked = sanitized["model.layers.0.mlp.switch_mlp.gate_proj.weight"] + @mx.eval(stacked) + + refute sanitized.key?("mtp.head.weight") + refute sanitized.key?("lm_head.weight") + refute sanitized.key?("model.layers.0.mlp.e_score_correction_bias") + refute sanitized.key?("model.layers.0.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.gate_proj.weight") + assert sanitized.key?("model.layers.0.mlp.gate.e_score_correction_bias") + assert 
sanitized.key?("model.layers.0.mlp.switch_mlp.gate_proj.weight") + assert sanitized.key?("model.layers.1.self_attn.q_proj.weight") + assert_equal [2, 2, 2], stacked.shape + + cache = model.make_cache + assert_equal 2, cache.length + assert_instance_of MlxLm::KVCache, cache[0] + assert_instance_of MlxLm::RotatingKVCache, cache[1] + end +end + +class Phase24DenseLaneALGlm4MoeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_glm4_moe_construct_forward_shape_and_sanitize + args = MlxLm::Models::Glm4Moe::ModelArgs.from_dict({ + "model_type" => "glm4_moe", + "vocab_size" => 101, + "hidden_size" => 32, + "intermediate_size" => 64, + "max_position_embeddings" => 256, + "moe_intermediate_size" => 16, + "norm_topk_prob" => true, + "num_attention_heads" => 4, + "n_group" => 1, + "head_dim" => 8, + "topk_group" => 1, + "n_shared_experts" => 1, + "n_routed_experts" => 2, + "routed_scaling_factor" => 1.0, + "num_experts_per_tok" => 1, + "first_k_dense_replace" => 0, + "num_hidden_layers" => 2, + "num_key_value_heads" => 4, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "use_qk_norm" => true, + "tie_word_embeddings" => false, + "attention_bias" => false, + "partial_rotary_factor" => 0.5, + }) + + model = MlxLm::Models::Glm4Moe::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 101], output.shape + + weights = { + "model.layers.0.mlp.experts.0.up_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.up_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.1.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "model.layers.2.mlp.up_proj.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked = sanitized["model.layers.0.mlp.switch_mlp.up_proj.weight"] + @mx.eval(stacked) + + refute sanitized.key?("model.layers.0.mlp.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.2.mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.1.self_attn.q_proj.weight") + assert_equal [2, 2, 2], stacked.shape + end +end + +class Phase24DenseLaneALRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("exaone_moe"), "exaone_moe should be registered" + assert MlxLm::Models::REGISTRY.key?("glm4_moe"), "glm4_moe should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "exaone_moe" }) + assert_equal MlxLm::Models::ExaoneMoe::Model, model_class + assert_equal MlxLm::Models::ExaoneMoe::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "glm4_moe" }) + assert_equal MlxLm::Models::Glm4Moe::Model, model_class + assert_equal MlxLm::Models::Glm4Moe::ModelArgs, args_class + end +end diff --git a/test/parity/falcon_h1_glm4_moe_lite_models_test.rb b/test/parity/falcon_h1_glm4_moe_lite_models_test.rb new file mode 100644 index 0000000..8d6ba87 --- /dev/null +++ b/test/parity/falcon_h1_glm4_moe_lite_models_test.rb @@ -0,0 +1,154 @@ +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/falcon_h1" +require_relative "../../lib/mlx_lm/models/glm4_moe_lite" + +class 
Phase26DenseLaneAPFalconH1Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_falcon_h1_wrapper_construct_forward_shape_sanitize_mapping_and_cache + args = MlxLm::Models::FalconH1::ModelArgs.from_dict({ + "model_type" => "falcon_h1", + "hidden_size" => 32, + "intermediate_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "mamba_d_conv" => 3, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "vocab_size" => 89, + "max_position_embeddings" => 128, + "attention_window_size" => 4, + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::FalconH1::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 89], output.shape + + conv_weight = @mx.array((0...24).to_a, dtype: @mx.float32).reshape([4, 1, 6]) + weights = { + "model.layers.0.mamba.conv1d.weight" => conv_weight, + "model.layers.0.feed_forward.gate_proj.weight" => @mx.zeros([64, 32]).astype(@mx.float32), + "model.layers.0.temporal_block.linear_x.weight" => @mx.ones([32, 32]).astype(@mx.float32), + "model.final_layernorm.weight" => @mx.ones([32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + sanitized_conv = sanitized["model.layers.0.temporal_block.conv_1d.weight"] + sanitized_mlp = sanitized["model.layers.0.mlp_block.gate_proj.weight"] + sanitized_pass_through = sanitized["model.layers.0.temporal_block.linear_x.weight"] + sanitized_final_norm = sanitized["model.final_norm.weight"] + @mx.eval(sanitized_conv, sanitized_mlp, sanitized_pass_through, sanitized_final_norm) + + refute sanitized.key?("model.layers.0.mamba.conv1d.weight") + refute sanitized.key?("model.layers.0.feed_forward.gate_proj.weight") + refute sanitized.key?("model.final_layernorm.weight") + assert sanitized.key?("model.layers.0.temporal_block.conv_1d.weight") + assert sanitized.key?("model.layers.0.mlp_block.gate_proj.weight") + assert sanitized.key?("model.layers.0.temporal_block.linear_x.weight") + assert sanitized.key?("model.final_norm.weight") + assert_equal [4, 6, 1], sanitized_conv.shape + assert_equal @mx.swapaxes(conv_weight, 1, 2).to_a, sanitized_conv.to_a + + cache = model.make_cache + assert_equal 2, cache.length + end +end + +class Phase26DenseLaneAPGlm4MoeLiteTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_glm4_moe_lite_wrapper_construct_forward_shape_and_sanitize_mapping_and_stacking + args = MlxLm::Models::Glm4MoeLite::ModelArgs.from_dict({ + "model_type" => "glm4_moe_lite", + "vocab_size" => 101, + "hidden_size" => 32, + "intermediate_size" => 64, + "moe_intermediate_size" => 16, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "n_shared_experts" => 1, + "n_routed_experts" => 2, + "routed_scaling_factor" => 1.0, + "kv_lora_rank" => 8, + "q_lora_rank" => 8, + "qk_rope_head_dim" => 8, + "qk_nope_head_dim" => 8, + "v_head_dim" => 8, + "topk_method" => "noaux_tc", + "scoring_func" => "sigmoid", + "norm_topk_prob" => true, + "n_group" => 1, + "topk_group" => 1, + "num_experts_per_tok" => 1, + "moe_layer_freq" => 1, + "first_k_dense_replace" => 0, + "max_position_embeddings" => 256, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "rope_scaling" => nil, + "attention_bias" => false, + "partial_rotary_factor" => 1.0, + "tie_word_embeddings" => false, + "num_nextn_predict_layers" => 1, + }) + + model = 
MlxLm::Models::Glm4MoeLite::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 101], output.shape + + embed_q_weight = @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32) + weights = { + "model.layers.0.self_attn.embed_q.weight" => embed_q_weight, + "model.layers.0.mlp.experts.0.up_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.up_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.1.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "model.layers.2.mlp.up_proj.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked = sanitized["model.layers.0.mlp.switch_mlp.up_proj.weight"] + mapped_q_proj = sanitized["model.layers.0.self_attn.q_proj.weight"] + pass_through_q_proj = sanitized["model.layers.1.self_attn.q_proj.weight"] + @mx.eval(stacked, mapped_q_proj, pass_through_q_proj) + + refute sanitized.key?("model.layers.0.self_attn.embed_q.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.2.mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.self_attn.q_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.1.self_attn.q_proj.weight") + assert_equal [2, 2, 2], stacked.shape + assert_equal embed_q_weight.to_a, mapped_q_proj.to_a + end +end + +class Phase26DenseLaneAPRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("falcon_h1"), "falcon_h1 should be registered" + assert MlxLm::Models::REGISTRY.key?("glm4_moe_lite"), "glm4_moe_lite should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "falcon_h1" }) + assert_equal MlxLm::Models::FalconH1::Model, model_class + assert_equal MlxLm::Models::FalconH1::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "glm4_moe_lite" }) + assert_equal MlxLm::Models::Glm4MoeLite::Model, model_class + assert_equal MlxLm::Models::Glm4MoeLite::ModelArgs, args_class + end +end diff --git a/test/parity/gated_delta_activations_parity_test.rb b/test/parity/gated_delta_activations_parity_test.rb new file mode 100644 index 0000000..5eda3b0 --- /dev/null +++ b/test/parity/gated_delta_activations_parity_test.rb @@ -0,0 +1,194 @@ +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/gated_delta" + +class Phase15ActivationsParityTest < Minitest::Test + include ParityTestHelpers + + def setup + @mx = MLX::Core + end + + def test_swiglu_xielu_and_xielu_module_match_python + x = @mx.array([[-2.0, -0.5, 0.0, 1.5]], dtype: @mx.float32) + gate = @mx.array([[0.25, -1.0, 0.5, 2.0]], dtype: @mx.float32) + alpha_p = @mx.array(0.2, dtype: @mx.float32) + alpha_n = @mx.array(-0.1, dtype: @mx.float32) + beta = @mx.array(0.5, dtype: @mx.float32) + eps = @mx.array(-1e-6, dtype: @mx.float32) + + swiglu = MlxLm::Models::Activations.swiglu(gate, x) + xielu = MlxLm::Models::Activations.xielu(x, alpha_p, alpha_n, beta, eps) + layer = MlxLm::Models::Activations::XieLU.new + layer_out = layer.call(x) + + @mx.eval(swiglu, xielu, layer_out, 
layer.alpha_p, layer.alpha_n, layer.beta, layer.eps) + + py = python_eval(<<~PY) + import json + import sys + import mlx.core as mx + + sys.path.insert(0, "mlx-lm") + from mlx_lm.models.activations import swiglu, xielu, XieLU + + x = mx.array([[-2.0, -0.5, 0.0, 1.5]], dtype=mx.float32) + gate = mx.array([[0.25, -1.0, 0.5, 2.0]], dtype=mx.float32) + alpha_p = mx.array(0.2, dtype=mx.float32) + alpha_n = mx.array(-0.1, dtype=mx.float32) + beta = mx.array(0.5, dtype=mx.float32) + eps = mx.array(-1e-6, dtype=mx.float32) + + y_swiglu = swiglu(gate, x) + y_xielu = xielu(x, alpha_p, alpha_n, beta, eps) + + layer = XieLU() + y_layer = layer(x) + + mx.eval(y_swiglu, y_xielu, y_layer, layer.alpha_p, layer.alpha_n, layer.beta, layer.eps) + print(json.dumps({ + "swiglu": y_swiglu.tolist(), + "xielu": y_xielu.tolist(), + "xielu_layer": y_layer.tolist(), + "alpha_p": float(layer.alpha_p), + "alpha_n": float(layer.alpha_n), + "beta": float(layer.beta), + "eps": float(layer.eps), + })) + PY + + assert_arrays_close py["swiglu"], swiglu.to_a, atol: 1e-6, msg: "swiglu parity" + assert_arrays_close py["xielu"], xielu.to_a, atol: 1e-6, msg: "xielu parity" + assert_arrays_close py["xielu_layer"], layer_out.to_a, atol: 1e-6, msg: "XieLU layer parity" + assert_in_delta py["alpha_p"], layer.alpha_p.to_a, 1e-6 + assert_in_delta py["alpha_n"], layer.alpha_n.to_a, 1e-6 + assert_in_delta py["beta"], layer.beta.to_a, 1e-6 + assert_in_delta py["eps"], layer.eps.to_a, 1e-6 + end +end + +class Phase15GatedDeltaParityTest < Minitest::Test + include ParityTestHelpers + + def setup + @mx = MLX::Core + end + + def test_gated_delta_ops_scalar_gating_masked_matches_python + q, k, v, state, mask = base_tensors + + g = @mx.sigmoid((@mx.arange(0, 12, 1, @mx.float32).reshape([1, 3, 4]) - 6.0) / 4.0) + beta = @mx.sigmoid((@mx.arange(0, 12, 1, @mx.float32).reshape([1, 3, 4]) - 4.0) / 5.0) + + y_rb, st_rb = MlxLm::Models::GatedDelta.gated_delta_ops(q, k, v, g, beta, state, mask) + @mx.eval(y_rb, st_rb) + + py = python_eval(<<~PY) + import json + import sys + import mlx.core as mx + + sys.path.insert(0, "mlx-lm") + from mlx_lm.models.gated_delta import gated_delta_ops + + q = (mx.arange(0, 18, dtype=mx.float32).reshape(1, 3, 2, 3) - 9.0) / 8.0 + k = (mx.arange(100, 118, dtype=mx.float32).reshape(1, 3, 2, 3) - 109.0) / 9.0 + v = (mx.arange(0, 24, dtype=mx.float32).reshape(1, 3, 4, 2) - 12.0) / 7.0 + state = (mx.arange(0, 24, dtype=mx.float32).reshape(1, 4, 2, 3) - 10.0) / 6.0 + mask = mx.array([[True, False, True]]) + + g = mx.sigmoid((mx.arange(0, 12, dtype=mx.float32).reshape(1, 3, 4) - 6.0) / 4.0) + beta = mx.sigmoid((mx.arange(0, 12, dtype=mx.float32).reshape(1, 3, 4) - 4.0) / 5.0) + + y, st = gated_delta_ops(q, k, v, g, beta, state, mask) + mx.eval(y, st) + print(json.dumps({"y": y.tolist(), "state": st.tolist()})) + PY + + assert_arrays_close py["y"], y_rb.to_a, atol: 1e-5, msg: "gated_delta_ops output parity" + assert_arrays_close py["state"], st_rb.to_a, atol: 1e-5, msg: "gated_delta_ops state parity" + end + + def test_gated_delta_update_vectorized_gating_matches_python_and_ops_reference + q, k, v, state, mask = base_tensors + a = (@mx.arange(0, 36, 1, @mx.float32).reshape([1, 3, 4, 3]) - 18.0) / 10.0 + b = (@mx.arange(0, 12, 1, @mx.float32).reshape([1, 3, 4]) - 5.0) / 6.0 + a_log = @mx.log(@mx.array([[1.1], [1.4], [1.8], [2.2]], dtype: @mx.float32)) + dt_bias = @mx.array( + [ + [0.05, -0.02, 0.01], + [0.10, 0.00, -0.10], + [-0.03, 0.07, 0.02], + [0.00, -0.04, 0.08], + ], + dtype: @mx.float32 + ) + + beta = @mx.sigmoid(b) + g = 
MlxLm::Models::GatedDelta.compute_g(a_log, a, dt_bias) + y_ref, st_ref = MlxLm::Models::GatedDelta.gated_delta_ops(q, k, v, g, beta, state, mask) + + y_update, st_update = MlxLm::Models::GatedDelta.gated_delta_update( + q, k, v, a, b, a_log, dt_bias, state, mask, use_kernel: false + ) + y_kernel, st_kernel = MlxLm::Models::GatedDelta.gated_delta_update( + q, k, v, a, b, a_log, dt_bias, state, mask, use_kernel: true + ) + + @mx.eval(y_ref, st_ref, y_update, st_update, y_kernel, st_kernel) + + assert_arrays_close y_ref.to_a, y_update.to_a, atol: 1e-5, msg: "update should match ops reference (output)" + assert_arrays_close st_ref.to_a, st_update.to_a, atol: 1e-5, msg: "update should match ops reference (state)" + assert_arrays_close y_update.to_a, y_kernel.to_a, atol: 1e-6, msg: "kernel path should match fallback output" + assert_arrays_close st_update.to_a, st_kernel.to_a, atol: 1e-6, msg: "kernel path should match fallback state" + + py = python_eval(<<~PY) + import json + import sys + import mlx.core as mx + + sys.path.insert(0, "mlx-lm") + from mlx_lm.models.gated_delta import gated_delta_update + + q = (mx.arange(0, 18, dtype=mx.float32).reshape(1, 3, 2, 3) - 9.0) / 8.0 + k = (mx.arange(100, 118, dtype=mx.float32).reshape(1, 3, 2, 3) - 109.0) / 9.0 + v = (mx.arange(0, 24, dtype=mx.float32).reshape(1, 3, 4, 2) - 12.0) / 7.0 + state = (mx.arange(0, 24, dtype=mx.float32).reshape(1, 4, 2, 3) - 10.0) / 6.0 + mask = mx.array([[True, False, True]]) + + a = (mx.arange(0, 36, dtype=mx.float32).reshape(1, 3, 4, 3) - 18.0) / 10.0 + b = (mx.arange(0, 12, dtype=mx.float32).reshape(1, 3, 4) - 5.0) / 6.0 + A_log = mx.log(mx.array([[1.1], [1.4], [1.8], [2.2]], dtype=mx.float32)) + dt_bias = mx.array( + [ + [0.05, -0.02, 0.01], + [0.10, 0.00, -0.10], + [-0.03, 0.07, 0.02], + [0.00, -0.04, 0.08], + ], + dtype=mx.float32, + ) + + y, st = gated_delta_update( + q, k, v, a, b, A_log, dt_bias, state=state, mask=mask, use_kernel=False + ) + mx.eval(y, st) + print(json.dumps({"y": y.tolist(), "state": st.tolist()})) + PY + + assert_arrays_close py["y"], y_update.to_a, atol: 1e-5, msg: "gated_delta_update output parity" + assert_arrays_close py["state"], st_update.to_a, atol: 1e-5, msg: "gated_delta_update state parity" + end + + private + + def base_tensors + q = (@mx.arange(0, 18, 1, @mx.float32).reshape([1, 3, 2, 3]) - 9.0) / 8.0 + k = (@mx.arange(100, 118, 1, @mx.float32).reshape([1, 3, 2, 3]) - 109.0) / 9.0 + v = (@mx.arange(0, 24, 1, @mx.float32).reshape([1, 3, 4, 2]) - 12.0) / 7.0 + state = (@mx.arange(0, 24, 1, @mx.float32).reshape([1, 4, 2, 3]) - 10.0) / 6.0 + mask = @mx.array([[true, false, true]]) + + [q, k, v, state, mask] + end +end diff --git a/test/parity/gemma3_text_gemma3_models_test.rb b/test/parity/gemma3_text_gemma3_models_test.rb new file mode 100644 index 0000000..ca45544 --- /dev/null +++ b/test/parity/gemma3_text_gemma3_models_test.rb @@ -0,0 +1,132 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/gemma3_text" +require_relative "../../lib/mlx_lm/models/gemma3" + +class Phase20DenseLaneQGemma3TextTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_gemma3_text_construct_forward_shape_and_sanitize_ties_embeddings 
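+ # With 4 layers and sliding_window_pattern 3, layer 2 is expected to get a global KVCache while the remaining layers use RotatingKVCache (asserted on make_cache below).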
+ args = MlxLm::Models::Gemma3Text::ModelArgs.from_dict({ + "model_type" => "gemma3_text", + "hidden_size" => 32, + "num_hidden_layers" => 4, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "head_dim" => 8, + "rms_norm_eps" => 1e-6, + "vocab_size" => 83, + "sliding_window" => 8, + "sliding_window_pattern" => 3, + "rope_theta" => 10_000.0, + "rope_local_base_freq" => 10_000.0, + "max_position_embeddings" => 128, + "query_pre_attn_scalar" => 16.0, + }) + + model = MlxLm::Models::Gemma3Text::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 83], output.shape + + cache = model.make_cache + assert_equal 4, cache.length + assert_instance_of MlxLm::RotatingKVCache, cache[0] + assert_instance_of MlxLm::RotatingKVCache, cache[1] + assert_instance_of MlxLm::KVCache, cache[2] + assert_instance_of MlxLm::RotatingKVCache, cache[3] + + weights = { + "model.embed_tokens.weight" => @mx.zeros([83, 32]).astype(@mx.float32), + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "model.norm.weight" => @mx.zeros([32]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + assert sanitized.key?("model.embed_tokens.weight") + assert sanitized.key?("model.norm.weight") + assert_equal true, model.instance_variable_get(:@tie_word_embeddings) + assert_nil model.lm_head + end +end + +class Phase20DenseLaneQGemma3Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_gemma3_construct_forward_shape_and_sanitize_multimodal_prefixes + args = MlxLm::Models::Gemma3::ModelArgs.from_dict({ + "model_type" => "gemma3", + "vocab_size" => 97, + "text_config" => { + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "head_dim" => 4, + "rms_norm_eps" => 1e-6, + "sliding_window" => 8, + "sliding_window_pattern" => 2, + "query_pre_attn_scalar" => 16.0, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 128, + }, + }) + + model = MlxLm::Models::Gemma3::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 97], output.shape + + weights = { + "vision_tower.patch_embed.weight" => @mx.zeros([1]).astype(@mx.float32), + "multi_modal_projector.linear.weight" => @mx.zeros([1]).astype(@mx.float32), + "language_model.model.embed_tokens.weight" => @mx.zeros([97, 32]).astype(@mx.float32), + "language_model.model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "language_model.model.norm.weight" => @mx.zeros([32]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("vision_tower.patch_embed.weight") + refute sanitized.key?("multi_modal_projector.linear.weight") + refute sanitized.key?("language_model.model.layers.0.self_attn.rotary_emb.inv_freq") + assert sanitized.key?("language_model.model.embed_tokens.weight") + assert sanitized.key?("language_model.model.norm.weight") + + assert_equal true, model.language_model.instance_variable_get(:@tie_word_embeddings) + assert_nil model.language_model.lm_head + end +end + +class Phase20DenseLaneQRegistryTest < Minitest::Test + def test_models_registered_and_resolvable + assert 
MlxLm::Models::REGISTRY.key?("gemma3_text"), "gemma3_text should be registered" + assert MlxLm::Models::REGISTRY.key?("gemma3"), "gemma3 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "gemma3_text" }) + assert_equal MlxLm::Models::Gemma3Text::Model, model_class + assert_equal MlxLm::Models::Gemma3Text::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "gemma3" }) + assert_equal MlxLm::Models::Gemma3::Model, model_class + assert_equal MlxLm::Models::Gemma3::ModelArgs, args_class + end +end diff --git a/test/parity/gemma3n_ernie4_5_moe_models_test.rb b/test/parity/gemma3n_ernie4_5_moe_models_test.rb new file mode 100644 index 0000000..d0eefef --- /dev/null +++ b/test/parity/gemma3n_ernie4_5_moe_models_test.rb @@ -0,0 +1,150 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/gemma2" +require_relative "../../lib/mlx_lm/models/ernie4_5" +require_relative "../../lib/mlx_lm/models/gemma3n" +require_relative "../../lib/mlx_lm/models/ernie4_5_moe" + +class Phase20DenseLaneRGemma3nTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_gemma3n_construct_forward_shape_and_sanitize_with_shared_config_handling + shared_config = { + "model_type" => "gemma3n", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "intermediate_size" => 64, + "vocab_size" => 111, + "head_dim" => 8, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 256, + "final_logit_softcapping" => 20.0, + "query_pre_attn_scalar" => 64.0, + "attn_logit_softcapping" => 25.0, + } + + args = MlxLm::Models::Gemma3n::ModelArgs.from_dict(shared_config) + refute_same shared_config, args.text_config + args.text_config["added_by_args"] = true + refute shared_config.key?("added_by_args") + + model = MlxLm::Models::Gemma3n::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + + assert_equal [1, 4, 111], output.shape + + weights = { + "model.vision_tower.patch_embed.weight" => @mx.zeros([1]).astype(@mx.float32), + "model.audio_tower.blocks.0.weight" => @mx.zeros([1]).astype(@mx.float32), + "model.embed_audio.proj.weight" => @mx.zeros([1]).astype(@mx.float32), + "model.embed_vision.proj.weight" => @mx.zeros([1]).astype(@mx.float32), + "model.language_model.embed_tokens.weight" => @mx.zeros([111, 32]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("model.vision_tower.patch_embed.weight") + refute sanitized.key?("model.audio_tower.blocks.0.weight") + refute sanitized.key?("model.embed_audio.proj.weight") + refute sanitized.key?("model.embed_vision.proj.weight") + assert sanitized.key?("model.language_model.embed_tokens.weight") + end +end + +class Phase20DenseLaneRErnie45MoeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_ernie4_5_moe_construct_forward_shape_and_sanitize_stacks_experts + args = 
MlxLm::Models::Ernie45Moe::ModelArgs.from_dict({ + "model_type" => "ernie4_5_moe", + "hidden_size" => 32, + "intermediate_size" => 64, + "max_position_embeddings" => 256, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "num_hidden_layers" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 101, + "rope_theta" => 10_000.0, + "use_bias" => false, + "tie_word_embeddings" => false, + "moe_num_experts" => 2, + "moe_k" => 1, + "moe_layer_interval" => 1, + "moe_layer_start_index" => 0, + "moe_num_shared_experts" => 1, + "moe_gate_act" => "softmax", + }) + + model = MlxLm::Models::Ernie45Moe::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + + assert_equal [1, 3, 101], output.shape + + weights = { + "model.layers.0.mlp.experts.0.gate_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.gate_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.0.down_proj.weight" => @mx.array([[0.0, 1.0], [1.0, 0.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.down_proj.weight" => @mx.array([[2.0, 3.0], [4.0, 5.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.0.up_proj.weight" => @mx.array([[1.0, 1.0], [1.0, 1.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.up_proj.weight" => @mx.array([[2.0, 2.0], [2.0, 2.0]], dtype: @mx.float32), + "model.layers.0.mtp_block.weight" => @mx.zeros([1]).astype(@mx.float32), + "model.layers.0.e_score_correction_bias" => @mx.zeros([1]).astype(@mx.float32), + "model.embed_tokens.weight" => @mx.zeros([101, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + gate = sanitized["model.layers.0.mlp.switch_mlp.gate_proj.weight"] + down = sanitized["model.layers.0.mlp.switch_mlp.down_proj.weight"] + up = sanitized["model.layers.0.mlp.switch_mlp.up_proj.weight"] + @mx.eval(gate, down, up) + + assert_equal [2, 2, 2], gate.shape + assert_equal [2, 2, 2], down.shape + assert_equal [2, 2, 2], up.shape + + refute sanitized.key?("model.layers.0.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.gate_proj.weight") + refute sanitized.key?("model.layers.0.mtp_block.weight") + refute sanitized.key?("model.layers.0.e_score_correction_bias") + assert sanitized.key?("model.embed_tokens.weight") + end +end + +class Phase20DenseLaneRRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("gemma3n"), "gemma3n should be registered" + assert MlxLm::Models::REGISTRY.key?("ernie4_5_moe"), "ernie4_5_moe should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "gemma3n" }) + assert_equal MlxLm::Models::Gemma3n::Model, model_class + assert_equal MlxLm::Models::Gemma3n::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "ernie4_5_moe" }) + assert_equal MlxLm::Models::Ernie45Moe::Model, model_class + assert_equal MlxLm::Models::Ernie45Moe::ModelArgs, args_class + end +end diff --git a/test/parity/glm4_telechat3_models_test.rb b/test/parity/glm4_telechat3_models_test.rb new file mode 100644 index 0000000..a549e64 --- /dev/null +++ b/test/parity/glm4_telechat3_models_test.rb @@ -0,0 +1,89 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require 
"minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/glm4" +require_relative "../../lib/mlx_lm/models/telechat3" + +class Phase17DenseLaneFGLM4Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_glm4_construct_forward_shape_and_registry_resolution + args = MlxLm::Models::GLM4::ModelArgs.from_dict({ + "model_type" => "glm4", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "intermediate_size" => 128, + "num_attention_heads" => 4, + "attention_bias" => false, + "head_dim" => 16, + "rms_norm_eps" => 1e-5, + "vocab_size" => 128, + "num_key_value_heads" => 2, + "partial_rotary_factor" => 0.5, + "rope_theta" => 10_000.0, + "rope_traditional" => true, + }) + + model = MlxLm::Models::GLM4::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 3, 128], output.shape + assert MlxLm::Models::REGISTRY.key?("glm4"), "glm4 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "glm4" }) + assert_equal MlxLm::Models::GLM4::Model, model_class + assert_equal MlxLm::Models::GLM4::ModelArgs, args_class + end +end + +class Phase17DenseLaneFTelechat3Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_telechat3_construct_forward_shape_and_registry_resolution + args = MlxLm::Models::Telechat3::ModelArgs.from_dict({ + "model_type" => "telechat3", + "hidden_size" => 64, + "intermediate_size" => 128, + "max_position_embeddings" => 256, + "num_attention_heads" => 4, + "num_hidden_layers" => 2, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-6, + "vocab_size" => 96, + "rope_theta" => 10_000.0, + "mlp_bias" => false, + "attention_bias" => false, + "head_dim" => 16, + "tie_word_embeddings" => false, + }) + + model = MlxLm::Models::Telechat3::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 4, 96], output.shape + assert MlxLm::Models::REGISTRY.key?("telechat3"), "telechat3 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "telechat3" }) + assert_equal MlxLm::Models::Telechat3::Model, model_class + assert_equal MlxLm::Models::Telechat3::ModelArgs, args_class + end +end diff --git a/test/parity/glm_helium_models_test.rb b/test/parity/glm_helium_models_test.rb new file mode 100644 index 0000000..14a874e --- /dev/null +++ b/test/parity/glm_helium_models_test.rb @@ -0,0 +1,89 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/glm" +require_relative "../../lib/mlx_lm/models/helium" + +class Phase16DenseLaneCGlmTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_glm_instantiation_forward_shape_and_sanitize + args = MlxLm::Models::GLM::ModelArgs.from_dict({ + "model_type" => "glm", + "hidden_size" => 64, + "num_hidden_layers" => 2, + 
"intermediate_size" => 128, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 128, + "head_dim" => 16, + "attention_bias" => false, + "tie_word_embeddings" => true, + }) + model = MlxLm::Models::GLM::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 4, 128], output.shape + assert MlxLm::Models::REGISTRY.key?("glm") + + sanitized = model.sanitize({ + "lm_head.weight" => 1, + "model.layers.0.self_attn.rotary_emb.inv_freq" => 2, + "model.embed_tokens.weight" => 3, + }) + refute sanitized.key?("lm_head.weight") + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + assert_equal 3, sanitized["model.embed_tokens.weight"] + end +end + +class Phase16DenseLaneCHeliumTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_helium_instantiation_forward_shape_and_mlp_bias + args = MlxLm::Models::Helium::ModelArgs.from_dict({ + "model_type" => "helium", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "intermediate_size" => 128, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 96, + "head_dim" => 8, + "attention_bias" => false, + "mlp_bias" => true, + "tie_word_embeddings" => false, + "rope_theta" => 1000.0, + }) + model = MlxLm::Models::Helium::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 3, 96], output.shape + assert MlxLm::Models::REGISTRY.key?("helium") + + mlp = model.layers[0].mlp + assert mlp.gate_proj.state.key?("bias") + assert mlp.up_proj.state.key?("bias") + assert mlp.down_proj.state.key?("bias") + end +end diff --git a/test/parity/governance_parity_gates_test.rb b/test/parity/governance_parity_gates_test.rb new file mode 100644 index 0000000..48ebdb9 --- /dev/null +++ b/test/parity/governance_parity_gates_test.rb @@ -0,0 +1,60 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require "open3" +require_relative "../../tasks/parity_inventory_task" + +class Phase13GovernanceGatesTest < Minitest::Test + MLX_ONNX_SUBMODULE_DIR = File.expand_path("../../mlx-ruby/submodules/mlx-onnx", __dir__) + REQUIRED_MLX_ONNX_OPS = { + "ArgPartition" => /ArgPartition|arg_partition/, + "GatherMM" => /GatherMM|gather_mm/, + }.freeze + + def test_parity_inventory_snapshot_is_current + message = <<~MSG + parity inventory snapshot is stale. + regenerate with: bundle exec rake parity:inventory + MSG + + assert ParityInventoryTask.run!(check: true), message + end + + def test_mlx_onnx_checkout_includes_required_ops + mlx_onnx_dir = MLX_ONNX_SUBMODULE_DIR + skip "mlx-onnx submodule checkout not available" unless Dir.exist?(mlx_onnx_dir) + + missing_ops = REQUIRED_MLX_ONNX_OPS.keys.reject do |op_name| + mlx_onnx_source_includes?(mlx_onnx_dir, REQUIRED_MLX_ONNX_OPS[op_name]) + end + + current_sha, = Open3.capture3("git", "-C", mlx_onnx_dir, "rev-parse", "HEAD") + + message = <<~MSG + mlx-onnx capability gate failed. + required ops: #{REQUIRED_MLX_ONNX_OPS.keys.join(", ")} + missing ops: #{missing_ops.join(", ")} + checkout path: #{mlx_onnx_dir} + current HEAD: #{current_sha.strip} + ensure mlx-ruby mlx-onnx checkout includes ArgPartition/GatherMM support. 
+ MSG + + assert missing_ops.empty?, message + end + + private + + SOURCE_GLOB = "**/*.{cc,cpp,c,h,hpp,hh,mm,m,py,rb}".freeze + + def mlx_onnx_source_includes?(root_dir, pattern) + Dir.glob(File.join(root_dir, SOURCE_GLOB)).any? do |path| + next false unless File.file?(path) + + begin + File.read(path).match?(pattern) + rescue Encoding::InvalidByteSequenceError, Encoding::UndefinedConversionError + false + end + end + end +end diff --git a/test/parity/gpt_bigcode_nemotron_models_test.rb b/test/parity/gpt_bigcode_nemotron_models_test.rb new file mode 100644 index 0000000..4b902c4 --- /dev/null +++ b/test/parity/gpt_bigcode_nemotron_models_test.rb @@ -0,0 +1,90 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/gpt_bigcode" +require_relative "../../lib/mlx_lm/models/nemotron" + +class Phase18DenseLaneJGPTBigCodeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_gpt_bigcode_construct_forward_shape_and_multi_query_kv_heads + args = MlxLm::Models::GPTBigCode::ModelArgs.from_dict({ + "model_type" => "gpt_bigcode", + "n_embd" => 64, + "n_layer" => 2, + "n_inner" => 128, + "n_head" => 4, + "n_positions" => 64, + "layer_norm_epsilon" => 1e-5, + "vocab_size" => 96, + "multi_query" => true, + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::GPTBigCode::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 4, 96], output.shape + assert_equal 1, model.layers.first.attn.instance_variable_get(:@n_kv_heads) + end +end + +class Phase18DenseLaneJNemotronTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_nemotron_construct_forward_shape_and_linear_rope_scale + args = MlxLm::Models::Nemotron::ModelArgs.from_dict({ + "model_type" => "nemotron", + "hidden_size" => 64, + "hidden_act" => "relu2", + "num_hidden_layers" => 2, + "intermediate_size" => 128, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "norm_eps" => 1e-5, + "vocab_size" => 80, + "partial_rotary_factor" => 0.5, + "rope_theta" => 10_000.0, + "rope_scaling" => { "type" => "linear", "factor" => 2.0 }, + "tie_word_embeddings" => false, + }) + + model = MlxLm::Models::Nemotron::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 3, 80], output.shape + assert_in_delta 0.5, model.layers.first.self_attn.rope.instance_variable_get(:@scale), 1e-12 + end +end + +class Phase18DenseLaneJRegistryTest < Minitest::Test + def test_models_registered + assert MlxLm::Models::REGISTRY.key?("gpt_bigcode"), "gpt_bigcode should be registered" + assert MlxLm::Models::REGISTRY.key?("nemotron"), "nemotron should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "gpt_bigcode" }) + assert_equal MlxLm::Models::GPTBigCode::Model, model_class + assert_equal MlxLm::Models::GPTBigCode::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "nemotron" }) + assert_equal MlxLm::Models::Nemotron::Model, model_class + assert_equal MlxLm::Models::Nemotron::ModelArgs, 
args_class + end +end diff --git a/test/parity/granite_minicpm_models_test.rb b/test/parity/granite_minicpm_models_test.rb new file mode 100644 index 0000000..ed0d777 --- /dev/null +++ b/test/parity/granite_minicpm_models_test.rb @@ -0,0 +1,127 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/granite" +require_relative "../../lib/mlx_lm/models/minicpm" + +class Phase17DenseLaneGRegistryTest < Minitest::Test + def test_registry_entries_for_lane_g_models + assert MlxLm::Models::REGISTRY.key?("granite"), "granite should be registered" + assert MlxLm::Models::REGISTRY.key?("minicpm"), "minicpm should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "granite" }) + assert_equal MlxLm::Models::Granite::Model, model_class + assert_equal MlxLm::Models::Granite::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "minicpm" }) + assert_equal MlxLm::Models::MiniCPM::Model, model_class + assert_equal MlxLm::Models::MiniCPM::ModelArgs, args_class + end +end + +class Phase17DenseLaneGGraniteTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_granite_construct_forward_shape_and_scaling_fields + args = MlxLm::Models::Granite::ModelArgs.from_dict({ + "model_type" => "granite", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 128, + "logits_scaling" => 2.0, + "attention_multiplier" => 0.25, + "embedding_multiplier" => 1.5, + "residual_multiplier" => 0.75, + "max_position_embeddings" => 256, + "attention_bias" => false, + "mlp_bias" => false, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => true, + }) + model = MlxLm::Models::Granite::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 3, 128], output.shape + assert_in_delta 2.0, model.instance_variable_get(:@logits_scaling), 1e-12 + assert_in_delta 1.5, model.model.instance_variable_get(:@embedding_multiplier), 1e-12 + assert_in_delta 0.75, model.layers.first.instance_variable_get(:@residual_multiplier), 1e-12 + assert_in_delta 0.25, model.layers.first.self_attn.instance_variable_get(:@scale), 1e-12 + end +end + +class Phase17DenseLaneGMiniCPMTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_minicpm_construct_forward_shape_and_depth_scaling + args = MlxLm::Models::MiniCPM::ModelArgs.from_dict({ + "model_type" => "minicpm", + "hidden_size" => 32, + "dim_model_base" => 16, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 96, + "scale_depth" => 1.4, + "scale_emb" => 8.0, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + }) + model = MlxLm::Models::MiniCPM::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = 
model.call(input) + @mx.eval(output) + + assert_equal [1, 4, 96], output.shape + expected_residual_scale = args.scale_depth / Math.sqrt(args.num_hidden_layers) + assert_in_delta expected_residual_scale, model.layers.first.instance_variable_get(:@residual_scale), 1e-12 + end + + def test_minicpm_sanitize_adds_missing_lm_head_weight + args = MlxLm::Models::MiniCPM::ModelArgs.from_dict({ + "model_type" => "minicpm", + "hidden_size" => 32, + "dim_model_base" => 16, + "num_hidden_layers" => 1, + "intermediate_size" => 64, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 96, + "scale_depth" => 1.0, + "scale_emb" => 1.0, + "tie_word_embeddings" => false, + }) + model = MlxLm::Models::MiniCPM::Model.new(args) + embed_weight = @mx.zeros([96, 32]).astype(@mx.float32) + weights = { "model.embed_tokens.weight" => embed_weight } + + sanitized = model.sanitize(weights) + + assert sanitized.key?("lm_head.weight") + assert_same embed_weight, sanitized["lm_head.weight"] + end +end diff --git a/test/parity/granitemoe_olmoe_models_test.rb b/test/parity/granitemoe_olmoe_models_test.rb new file mode 100644 index 0000000..d992076 --- /dev/null +++ b/test/parity/granitemoe_olmoe_models_test.rb @@ -0,0 +1,143 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/granite" +require_relative "../../lib/mlx_lm/models/granitemoe" +require_relative "../../lib/mlx_lm/models/olmo2" +require_relative "../../lib/mlx_lm/models/olmoe" + +class Phase20DenseLaneTGraniteMoeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_granitemoe_construct_forward_shape_and_sanitize_moe_linear_split + args = MlxLm::Models::GraniteMoe::ModelArgs.from_dict({ + "model_type" => "granitemoe", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "rms_norm_eps" => 1e-5, + "vocab_size" => 97, + "logits_scaling" => 2.0, + "attention_multiplier" => 0.25, + "embedding_multiplier" => 1.25, + "residual_multiplier" => 0.75, + "max_position_embeddings" => 256, + "attention_bias" => false, + "mlp_bias" => false, + "rope_theta" => 10_000.0, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::GraniteMoe::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 97], output.shape + + weights = { + "model.layers.0.block_sparse_moe.input_linear.weight" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([4, 6]), + "model.layers.0.block_sparse_moe.output_linear.weight" => @mx.zeros([3, 4]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([97, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + + refute sanitized.key?("model.layers.0.block_sparse_moe.input_linear.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.output_linear.weight") + refute sanitized.key?("lm_head.weight") + assert sanitized.key?("model.layers.0.block_sparse_moe.switch_mlp.gate_proj.weight") + assert 
sanitized.key?("model.layers.0.block_sparse_moe.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.block_sparse_moe.switch_mlp.down_proj.weight") + assert_equal [4, 3], sanitized["model.layers.0.block_sparse_moe.switch_mlp.gate_proj.weight"].shape + assert_equal [4, 3], sanitized["model.layers.0.block_sparse_moe.switch_mlp.up_proj.weight"].shape + assert_equal [3, 4], sanitized["model.layers.0.block_sparse_moe.switch_mlp.down_proj.weight"].shape + end +end + +class Phase20DenseLaneTOLMoETest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_olmoe_construct_forward_shape_and_sanitize_stacks_expert_weights + args = MlxLm::Models::OLMoE::ModelArgs.from_dict({ + "model_type" => "olmoe", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "intermediate_size" => 64, + "vocab_size" => 103, + "rms_norm_eps" => 1e-5, + "head_dim" => 8, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "norm_topk_prob" => false, + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::OLMoE::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 103], output.shape + + weights = { + "model.layers.0.mlp.experts.0.gate_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.gate_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.0.up_proj.weight" => @mx.array([[2.0, 2.0], [2.0, 2.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.up_proj.weight" => @mx.array([[4.0, 4.0], [4.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.0.down_proj.weight" => @mx.array([[9.0, 9.0], [9.0, 9.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.down_proj.weight" => @mx.array([[10.0, 10.0], [10.0, 10.0]], dtype: @mx.float32), + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([103, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked_gate = sanitized["model.layers.0.mlp.switch_mlp.gate_proj.weight"] + @mx.eval(stacked_gate) + + refute sanitized.key?("model.layers.0.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.gate_proj.weight") + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + refute sanitized.key?("lm_head.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.gate_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.down_proj.weight") + assert_equal [2, 2, 2], stacked_gate.shape + end +end + +class Phase20DenseLaneTRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("granitemoe"), "granitemoe should be registered" + assert MlxLm::Models::REGISTRY.key?("olmoe"), "olmoe should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "granitemoe" }) + assert_equal MlxLm::Models::GraniteMoe::Model, model_class + assert_equal MlxLm::Models::GraniteMoe::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "olmoe" }) + assert_equal MlxLm::Models::OLMoE::Model, model_class + assert_equal 
MlxLm::Models::OLMoE::ModelArgs, args_class + end +end diff --git a/test/parity/granitemoehybrid_jamba_models_test.rb b/test/parity/granitemoehybrid_jamba_models_test.rb new file mode 100644 index 0000000..052f68b --- /dev/null +++ b/test/parity/granitemoehybrid_jamba_models_test.rb @@ -0,0 +1,177 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/recurrent_gemma" +require_relative "../../lib/mlx_lm/models/falcon_h1" +require_relative "../../lib/mlx_lm/models/granitemoehybrid" +require_relative "../../lib/mlx_lm/models/jamba" + +class Phase27HybridLaneAQGraniteMoeHybridTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_granitemoehybrid_construct_forward_shape_sanitize_mapping_and_cache + args = MlxLm::Models::GraniteMoeHybrid::ModelArgs.from_dict({ + "model_type" => "granitemoehybrid", + "vocab_size" => 73, + "hidden_size" => 32, + "intermediate_size" => 64, + "num_hidden_layers" => 2, + "max_position_embeddings" => 128, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "attention_bias" => false, + "embedding_multiplier" => 1.0, + "attention_multiplier" => 1.0, + "logits_scaling" => 1.0, + "residual_multiplier" => 1.0, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + "shared_intermediate_size" => 32, + "mamba_n_heads" => 4, + "mamba_d_head" => 8, + "mamba_d_conv" => 3, + "layer_types" => ["mamba", "attention"], + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => true, + }) + assert_instance_of MlxLm::Models::GraniteMoeHybrid::ModelArgs, args + + model = MlxLm::Models::GraniteMoeHybrid::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 73], output.shape + + input_linear = @mx.array((0...24).to_a, dtype: @mx.float32).reshape([4, 6]) + output_linear = @mx.array((0...12).to_a, dtype: @mx.float32).reshape([3, 4]) + expected_gate, _expected_up = @mx.split(input_linear, [3], 1) + weights = { + "model.layers.0.block_sparse_moe.input_linear.weight" => input_linear, + "model.layers.0.block_sparse_moe.output_linear.weight" => output_linear, + "model.layers.0.mamba.in_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "model.layers.0.temporal_block.linear_out.weight" => @mx.ones([32, 32]).astype(@mx.float32), + "model.norm.weight" => @mx.ones([32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + mapped_gate = sanitized["model.layers.0.mlp_block.switch_mlp.gate_proj.weight"] + mapped_in_proj = sanitized["model.layers.0.temporal_block.linear_x.weight"] + pass_through = sanitized["model.layers.0.temporal_block.linear_out.weight"] + mapped_final_norm = sanitized["model.final_norm.weight"] + @mx.eval(mapped_gate, mapped_in_proj, pass_through, mapped_final_norm) + + refute sanitized.key?("model.layers.0.block_sparse_moe.input_linear.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.output_linear.weight") + refute sanitized.key?("model.layers.0.mamba.in_proj.weight") + refute sanitized.key?("model.norm.weight") + assert sanitized.key?("model.layers.0.mlp_block.switch_mlp.gate_proj.weight") + assert 
sanitized.key?("model.layers.0.mlp_block.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.mlp_block.switch_mlp.down_proj.weight") + assert sanitized.key?("model.layers.0.temporal_block.linear_x.weight") + assert sanitized.key?("model.layers.0.temporal_block.linear_out.weight") + assert sanitized.key?("model.final_norm.weight") + assert_equal [4, 3], mapped_gate.shape + assert_equal expected_gate.to_a, mapped_gate.to_a + + cache = model.make_cache + assert_equal 2, cache.length + end +end + +class Phase27HybridLaneAQJambaTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_jamba_construct_forward_shape_sanitize_mapping_and_cache + args = MlxLm::Models::Jamba::ModelArgs.from_dict({ + "model_type" => "jamba", + "hidden_size" => 32, + "intermediate_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "attn_layer_offset" => 1, + "attn_layer_period" => 2, + "expert_layer_offset" => 1, + "expert_layer_period" => 2, + "mamba_d_conv" => 3, + "mamba_d_state" => 16, + "mamba_expand" => 2, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "rms_norm_eps" => 1e-5, + "max_position_embeddings" => 128, + "rope_theta" => 10_000.0, + "vocab_size" => 79, + "tie_word_embeddings" => true, + }) + assert_instance_of MlxLm::Models::Jamba::ModelArgs, args + + model = MlxLm::Models::Jamba::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 79], output.shape + + expert0 = @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32) + expert1 = @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32) + weights = { + "model.layers.0.feed_forward.experts.0.up_proj.weight" => expert0, + "model.layers.0.feed_forward.experts.1.up_proj.weight" => expert1, + "model.layers.0.mamba.in_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "model.layers.0.temporal_block.linear_out.weight" => @mx.ones([32, 32]).astype(@mx.float32), + "model.final_layernorm.weight" => @mx.ones([32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked_experts = sanitized["model.layers.0.mlp_block.switch_mlp.up_proj.weight"] + mapped_in_proj = sanitized["model.layers.0.temporal_block.linear_x.weight"] + pass_through = sanitized["model.layers.0.temporal_block.linear_out.weight"] + mapped_final_norm = sanitized["model.final_norm.weight"] + @mx.eval(stacked_experts, mapped_in_proj, pass_through, mapped_final_norm) + + refute sanitized.key?("model.layers.0.feed_forward.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.0.feed_forward.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.0.mamba.in_proj.weight") + refute sanitized.key?("model.final_layernorm.weight") + assert sanitized.key?("model.layers.0.mlp_block.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.temporal_block.linear_x.weight") + assert sanitized.key?("model.layers.0.temporal_block.linear_out.weight") + assert sanitized.key?("model.final_norm.weight") + assert_equal [2, 2, 2], stacked_experts.shape + assert_equal expert0.to_a, stacked_experts[0].to_a + assert_equal expert1.to_a, stacked_experts[1].to_a + + cache = model.make_cache + assert_equal 2, cache.length + end +end + +class Phase27HybridLaneAQRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("granitemoehybrid"), "granitemoehybrid should be 
registered" + assert MlxLm::Models::REGISTRY.key?("jamba"), "jamba should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "granitemoehybrid" }) + assert_equal MlxLm::Models::GraniteMoeHybrid::Model, model_class + assert_equal MlxLm::Models::GraniteMoeHybrid::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "jamba" }) + assert_equal MlxLm::Models::Jamba::Model, model_class + assert_equal MlxLm::Models::Jamba::ModelArgs, args_class + end +end diff --git a/test/parity/hunyuan_gpt_oss_models_test.rb b/test/parity/hunyuan_gpt_oss_models_test.rb new file mode 100644 index 0000000..65fa502 --- /dev/null +++ b/test/parity/hunyuan_gpt_oss_models_test.rb @@ -0,0 +1,200 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/hunyuan" +require_relative "../../lib/mlx_lm/models/gpt_oss" + +class Phase24DenseLaneAIHunyuanTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_hunyuan_construct_forward_shape_and_sanitize_stacks_experts + args = MlxLm::Models::Hunyuan::ModelArgs.from_dict({ + "model_type" => "hunyuan", + "vocab_size" => 103, + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "attention_bias" => false, + "moe_topk" => 1, + "num_experts" => 2, + "num_shared_expert" => 1, + "use_mixed_mlp_moe" => true, + "use_qk_norm" => true, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "use_cla" => true, + "cla_share_factor" => 2, + "rope_scaling" => { + "alpha" => 1.0, + "factor" => 1.0, + "type" => "dynamic", + }, + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::Hunyuan::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 103], output.shape + assert_equal 2, model.layers.length + + weights = { + "model.layers.0.mlp.experts.0.up_proj.weight" => @mx.array((0...12).to_a, dtype: @mx.float32).reshape([4, 3]), + "model.layers.0.mlp.experts.1.up_proj.weight" => @mx.array((12...24).to_a, dtype: @mx.float32).reshape([4, 3]), + "model.layers.0.mlp.experts.0.down_proj.weight" => @mx.array((0...12).to_a, dtype: @mx.float32).reshape([3, 4]), + "model.layers.0.mlp.experts.1.down_proj.weight" => @mx.array((12...24).to_a, dtype: @mx.float32).reshape([3, 4]), + "model.layers.0.mlp.experts.0.gate_proj.weight" => @mx.array((0...12).to_a, dtype: @mx.float32).reshape([4, 3]), + "model.layers.0.mlp.experts.1.gate_proj.weight" => @mx.array((12...24).to_a, dtype: @mx.float32).reshape([4, 3]), + "model.layers.1.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([103, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + up_stacked = sanitized["model.layers.0.mlp.switch_mlp.up_proj.weight"] + down_stacked = sanitized["model.layers.0.mlp.switch_mlp.down_proj.weight"] + gate_stacked = 
sanitized["model.layers.0.mlp.switch_mlp.gate_proj.weight"] + @mx.eval(up_stacked, down_stacked, gate_stacked) + + refute sanitized.key?("lm_head.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.down_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.down_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.gate_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.down_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.gate_proj.weight") + assert sanitized.key?("model.layers.1.self_attn.q_proj.weight") + assert_equal [2, 4, 3], up_stacked.shape + assert_equal [2, 3, 4], down_stacked.shape + assert_equal [2, 4, 3], gate_stacked.shape + end +end + +class Phase24DenseLaneAIGptOssTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_gpt_oss_construct_forward_shape_sanitize_cleanup_and_cache_mix + args = MlxLm::Models::GptOss::ModelArgs.from_dict({ + "model_type" => "gpt_oss", + "num_hidden_layers" => 4, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + "vocab_size" => 109, + "rms_norm_eps" => 1e-5, + "hidden_size" => 32, + "intermediate_size" => 16, + "head_dim" => 8, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "sliding_window" => 4, + "rope_theta" => 10_000, + "layer_types" => [ + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + ], + }) + + model = MlxLm::Models::GptOss::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 109], output.shape + assert_equal 4, model.layers.length + + weights = { + "model.layers.0.mlp.experts.gate_up_proj.weight" => @mx.array((0...48).to_a, dtype: @mx.float32).reshape([2, 6, 4]), + "model.layers.0.mlp.experts.gate_up_proj_bias" => @mx.array((0...12).to_a, dtype: @mx.float32).reshape([2, 6]), + "model.layers.0.mlp.experts.down_proj.weight" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([2, 4, 3]), + "model.layers.0.mlp.experts.gate_up_proj_scales" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([2, 6, 2]), + "model.layers.0.mlp.experts.down_proj_scales" => @mx.array((0...16).to_a, dtype: @mx.float32).reshape([2, 4, 2]), + "model.layers.3.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + + gate_weight = sanitized["model.layers.0.mlp.experts.gate_proj.weight"] + up_weight = sanitized["model.layers.0.mlp.experts.up_proj.weight"] + gate_bias = sanitized["model.layers.0.mlp.experts.gate_proj.bias"] + up_bias = sanitized["model.layers.0.mlp.experts.up_proj.bias"] + gate_scales = sanitized["model.layers.0.mlp.experts.gate_proj.scales"] + up_scales = sanitized["model.layers.0.mlp.experts.up_proj.scales"] + down_weight = sanitized["model.layers.0.mlp.experts.down_proj.weight"] + down_scales = sanitized["model.layers.0.mlp.experts.down_proj.scales"] + @mx.eval( + gate_weight, + up_weight, + gate_bias, + up_bias, + gate_scales, + up_scales, + down_weight, + down_scales + ) + + refute sanitized.key?("model.layers.0.mlp.experts.gate_up_proj.weight") + refute 
sanitized.key?("model.layers.0.mlp.experts.gate_up_proj_bias") + refute sanitized.key?("model.layers.0.mlp.experts.gate_up_proj_scales") + refute sanitized.key?("model.layers.0.mlp.experts.down_proj_scales") + assert sanitized.key?("model.layers.0.mlp.experts.gate_proj.weight") + assert sanitized.key?("model.layers.0.mlp.experts.up_proj.weight") + assert sanitized.key?("model.layers.0.mlp.experts.gate_proj.bias") + assert sanitized.key?("model.layers.0.mlp.experts.up_proj.bias") + assert sanitized.key?("model.layers.0.mlp.experts.gate_proj.scales") + assert sanitized.key?("model.layers.0.mlp.experts.up_proj.scales") + assert sanitized.key?("model.layers.0.mlp.experts.down_proj.weight") + assert sanitized.key?("model.layers.0.mlp.experts.down_proj.scales") + assert sanitized.key?("model.layers.3.self_attn.q_proj.weight") + assert_equal [2, 3, 4], gate_weight.shape + assert_equal [2, 3, 4], up_weight.shape + assert_equal [2, 3], gate_bias.shape + assert_equal [2, 3], up_bias.shape + assert_equal [2, 3, 2], gate_scales.shape + assert_equal [2, 3, 2], up_scales.shape + + caches = model.make_cache + assert_equal 4, caches.length + assert_instance_of MlxLm::RotatingKVCache, caches[0] + assert_instance_of MlxLm::KVCache, caches[1] + assert_instance_of MlxLm::RotatingKVCache, caches[2] + assert_instance_of MlxLm::KVCache, caches[3] + end +end + +class Phase24DenseLaneAIRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("hunyuan"), "hunyuan should be registered" + assert MlxLm::Models::REGISTRY.key?("gpt_oss"), "gpt_oss should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "hunyuan" }) + assert_equal MlxLm::Models::Hunyuan::Model, model_class + assert_equal MlxLm::Models::Hunyuan::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "gpt_oss" }) + assert_equal MlxLm::Models::GptOss::Model, model_class + assert_equal MlxLm::Models::GptOss::ModelArgs, args_class + end +end diff --git a/test/parity/hunyuan_v1_dense_dbrx_models_test.rb b/test/parity/hunyuan_v1_dense_dbrx_models_test.rb new file mode 100644 index 0000000..66d1ef0 --- /dev/null +++ b/test/parity/hunyuan_v1_dense_dbrx_models_test.rb @@ -0,0 +1,128 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/hunyuan_v1_dense" +require_relative "../../lib/mlx_lm/models/dbrx" + +class Phase23DenseLaneAEHunyuanV1DenseTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_hunyuan_v1_dense_construct_forward_shape_and_sanitize_tied_embeddings + args = MlxLm::Models::HunyuanV1Dense::ModelArgs.from_dict({ + "model_type" => "hunyuan_v1_dense", + "vocab_size" => 97, + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 256, + "attention_bias" => false, + "use_qk_norm" => true, + "tie_word_embeddings" => true, + "rope_scaling" => { + "alpha" => 1.0, + "factor" => 1.0, + "type" => "dynamic", + }, + }) + + model = MlxLm::Models::HunyuanV1Dense::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], 
dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 97], output.shape + + weights = { + "model.embed_tokens.weight" => @mx.zeros([97, 32]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([97, 32]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("lm_head.weight") + assert sanitized.key?("model.embed_tokens.weight") + end +end + +class Phase23DenseLaneAEDbrxTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_dbrx_construct_forward_shape_and_sanitize_splits_expert_weights + args = MlxLm::Models::Dbrx::ModelArgs.from_dict({ + "model_type" => "dbrx", + "vocab_size" => 101, + "d_model" => 24, + "n_layers" => 2, + "n_heads" => 4, + "attn_config" => { + "kv_n_heads" => 2, + "clip_qkv" => 8.0, + "rope_theta" => 10_000.0, + }, + "ffn_config" => { + "ffn_hidden_size" => 16, + "moe_num_experts" => 2, + "moe_top_k" => 1, + }, + }) + + model = MlxLm::Models::Dbrx::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 101], output.shape + + expert_w1 = @mx.array((0...12).to_a, dtype: @mx.float32).reshape([4, 3]) + expert_w2 = @mx.array((12...24).to_a, dtype: @mx.float32).reshape([4, 3]) + weights = { + "transformer.blocks.0.ffn.experts.mlp.w1" => expert_w1, + "transformer.blocks.0.ffn.experts.mlp.w2" => expert_w2, + "transformer.blocks.1.norm_attn_norm.attn.out_proj.weight" => @mx.zeros([24, 24]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + + refute sanitized.key?("transformer.blocks.0.ffn.experts.mlp.w1") + refute sanitized.key?("transformer.blocks.0.ffn.experts.mlp.w2") + assert sanitized.key?("transformer.blocks.0.ffn.experts.0.w1.weight") + assert sanitized.key?("transformer.blocks.0.ffn.experts.1.w1.weight") + assert sanitized.key?("transformer.blocks.0.ffn.experts.0.w2.weight") + assert sanitized.key?("transformer.blocks.0.ffn.experts.1.w2.weight") + assert sanitized.key?("transformer.blocks.1.norm_attn_norm.attn.out_proj.weight") + + assert_equal [2, 3], sanitized["transformer.blocks.0.ffn.experts.0.w1.weight"].shape + assert_equal [2, 3], sanitized["transformer.blocks.0.ffn.experts.1.w1.weight"].shape + assert_equal [3, 2], sanitized["transformer.blocks.0.ffn.experts.0.w2.weight"].shape + assert_equal [3, 2], sanitized["transformer.blocks.0.ffn.experts.1.w2.weight"].shape + end +end + +class Phase23DenseLaneAERegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("hunyuan_v1_dense"), "hunyuan_v1_dense should be registered" + assert MlxLm::Models::REGISTRY.key?("dbrx"), "dbrx should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "hunyuan_v1_dense" }) + assert_equal MlxLm::Models::HunyuanV1Dense::Model, model_class + assert_equal MlxLm::Models::HunyuanV1Dense::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "dbrx" }) + assert_equal MlxLm::Models::Dbrx::Model, model_class + assert_equal MlxLm::Models::Dbrx::ModelArgs, args_class + end +end diff --git a/test/parity/kimi_k25_kimi_vl_models_test.rb b/test/parity/kimi_k25_kimi_vl_models_test.rb new file mode 100644 index 0000000..b98776f --- /dev/null +++ b/test/parity/kimi_k25_kimi_vl_models_test.rb @@ -0,0 +1,126 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift 
File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/deepseek" +require_relative "../../lib/mlx_lm/models/kimi_k25" +require_relative "../../lib/mlx_lm/models/kimi_vl" + +class Phase22DenseLaneAAKimiK25Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_kimi_k25_construct_forward_shape_and_sanitize_drops_multimodal_towers + args = MlxLm::Models::KimiK25::ModelArgs.from_dict({ + "model_type" => "kimi_k25", + "text_config" => { + "model_type" => "deepseek", + "vocab_size" => 83, + "hidden_size" => 32, + "intermediate_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "max_position_embeddings" => 256, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "attention_bias" => false, + }, + }) + + model = MlxLm::Models::KimiK25::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 83], output.shape + + weights = { + "vision_tower.encoder.weight" => @mx.zeros([1]).astype(@mx.float32), + "vision_model.encoder.weight" => @mx.zeros([1]).astype(@mx.float32), + "multi_modal_projector.linear.weight" => @mx.zeros([1]).astype(@mx.float32), + "mm_projector.linear.weight" => @mx.zeros([1]).astype(@mx.float32), + "language_model.model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "language_model.model.embed_tokens.weight" => @mx.zeros([83, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + + refute sanitized.key?("vision_tower.encoder.weight") + refute sanitized.key?("vision_model.encoder.weight") + refute sanitized.key?("multi_modal_projector.linear.weight") + refute sanitized.key?("mm_projector.linear.weight") + refute sanitized.key?("language_model.model.layers.0.self_attn.rotary_emb.inv_freq") + assert sanitized.key?("language_model.model.embed_tokens.weight") + end +end + +class Phase22DenseLaneAAKimiVLTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_kimi_vl_construct_forward_shape_and_sanitize_drops_multimodal_towers + args = MlxLm::Models::KimiVL::ModelArgs.from_dict({ + "model_type" => "kimi_vl", + "text_config" => { + "model_type" => "deepseek", + "vocab_size" => 89, + "hidden_size" => 32, + "intermediate_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "max_position_embeddings" => 256, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "attention_bias" => false, + }, + }) + + model = MlxLm::Models::KimiVL::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 89], output.shape + + weights = { + "vision_tower.patch_embed.weight" => @mx.zeros([1]).astype(@mx.float32), + "multi_modal_projector.linear.weight" => @mx.zeros([1]).astype(@mx.float32), + "language_model.model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "language_model.model.embed_tokens.weight" => @mx.zeros([89, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) 
+ + refute sanitized.key?("vision_tower.patch_embed.weight") + refute sanitized.key?("multi_modal_projector.linear.weight") + refute sanitized.key?("language_model.model.layers.0.self_attn.rotary_emb.inv_freq") + assert sanitized.key?("language_model.model.embed_tokens.weight") + end +end + +class Phase22DenseLaneAARegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("kimi_k25"), "kimi_k25 should be registered" + assert MlxLm::Models::REGISTRY.key?("kimi_vl"), "kimi_vl should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "kimi_k25" }) + assert_equal MlxLm::Models::KimiK25::Model, model_class + assert_equal MlxLm::Models::KimiK25::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "kimi_vl" }) + assert_equal MlxLm::Models::KimiVL::Model, model_class + assert_equal MlxLm::Models::KimiVL::ModelArgs, args_class + end +end diff --git a/test/parity/kimi_linear_longcat_flash_models_test.rb b/test/parity/kimi_linear_longcat_flash_models_test.rb new file mode 100644 index 0000000..b2a02b6 --- /dev/null +++ b/test/parity/kimi_linear_longcat_flash_models_test.rb @@ -0,0 +1,176 @@ +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/kimi_linear" +require_relative "../../lib/mlx_lm/models/longcat_flash" + +class Phase44DenseLaneARKimiLinearTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_kimi_linear_wrapper_from_dict_forward_shape_and_sanitize_mapping_and_stacking + args = MlxLm::Models::KimiLinear::ModelArgs.from_dict({ + "model_type" => "kimi_linear", + "vocab_size" => 73, + "hidden_dim" => 32, + "ffn_hidden_size" => 64, + "moe_intermediate_size" => 24, + "num_layers" => 2, + "num_heads" => 4, + "num_kv_heads" => 2, + "num_local_experts" => 2, + "n_shared_experts" => 1, + "top_k" => 1, + "norm_topk_prob" => true, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "first_k_dense_replace" => 0, + "layer_group_size" => 1, + "group_norm_size" => 1, + "use_bias" => false, + "use_qkv_bias" => false, + "tie_word_embeddings" => false, + "score_func" => "softmax", + "n_group" => 1, + "topk_group" => 1, + "moe_router_enable_expert_bias" => true, + "moe_router_enable_shared_expert" => true, + }) + + assert_equal 32, args.hidden_size + assert_equal 2, args.num_hidden_layers + assert_equal 4, args.num_attention_heads + assert_equal 2, args.num_experts + + model = MlxLm::Models::KimiLinear::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 73], output.shape + assert_equal 2, model.layers.length + assert_nil model.make_cache + + gate_weight = @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32) + weights = { + "model.layers.0.mlp.experts.0.gate_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.gate_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.0.mlp.router.weight" => gate_weight, + "model.layers.0.mlp.router.bias" => @mx.array([0.1, -0.2], dtype: @mx.float32), + "model.embed_tokens.weight" => @mx.zeros([73, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked = sanitized["model.layers.0.mlp.switch_mlp.gate_proj.weight"] + remapped_gate = sanitized["model.layers.0.mlp.gate.gate_proj.weight"] + 
remapped_bias = sanitized["model.layers.0.mlp.gate.gate_proj.bias"] + remapped_embed = sanitized["model.word_embeddings.weight"] + @mx.eval(stacked, remapped_gate, remapped_bias, remapped_embed) + + refute sanitized.key?("model.layers.0.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.router.weight") + refute sanitized.key?("model.layers.0.mlp.router.bias") + refute sanitized.key?("model.embed_tokens.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.gate_proj.weight") + assert sanitized.key?("model.layers.0.mlp.gate.gate_proj.weight") + assert sanitized.key?("model.layers.0.mlp.gate.gate_proj.bias") + assert sanitized.key?("model.word_embeddings.weight") + assert_equal [2, 2, 2], stacked.shape + assert_equal gate_weight.to_a, remapped_gate.to_a + end +end + +class Phase44DenseLaneARLongcatFlashTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_longcat_flash_wrapper_from_dict_forward_shape_and_sanitize_mapping_and_stacking + args = MlxLm::Models::LongcatFlash::ModelArgs.from_dict({ + "model_type" => "longcat_flash", + "vocab_size" => 71, + "hidden_dim" => 32, + "ffn_hidden_size" => 64, + "moe_intermediate_size" => 16, + "num_layers" => 2, + "num_heads" => 4, + "num_kv_heads" => 4, + "num_local_experts" => 2, + "num_shared_experts" => 1, + "routed_scaling_factor" => 1.0, + "kv_lora_rank" => 8, + "q_lora_rank" => 8, + "qk_rope_head_dim" => 8, + "qk_nope_head_dim" => 8, + "v_head_dim" => 8, + "topk_method" => "noaux_tc", + "score_function" => "sigmoid", + "norm_topk_prob" => true, + "n_group" => 1, + "topk_group" => 1, + "top_k" => 1, + "moe_layer_freq" => 1, + "first_k_dense_replace" => 0, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "attention_bias" => false, + "partial_rotary_factor" => 1.0, + "tie_word_embeddings" => false, + }) + + assert_equal 32, args.hidden_size + assert_equal 2, args.num_hidden_layers + assert_equal 4, args.num_attention_heads + assert_equal 2, args.n_routed_experts + + model = MlxLm::Models::LongcatFlash::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 71], output.shape + assert_equal 2, model.layers.length + assert_nil model.make_cache + + embed_q_weight = @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32) + weights = { + "model.layers.0.attention.embed_q.weight" => embed_q_weight, + "model.layers.0.mlp.experts.0.up_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.up_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.2.mlp.up_proj.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + mapped_q_proj = sanitized["model.layers.0.self_attn.q_proj.weight"] + stacked = sanitized["model.layers.0.mlp.switch_mlp.up_proj.weight"] + @mx.eval(mapped_q_proj, stacked) + + refute sanitized.key?("model.layers.0.attention.embed_q.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.2.mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.self_attn.q_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.up_proj.weight") + assert_equal 
embed_q_weight.to_a, mapped_q_proj.to_a + assert_equal [2, 2, 2], stacked.shape + end +end + +class Phase44DenseLaneARRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("kimi_linear"), "kimi_linear should be registered" + assert MlxLm::Models::REGISTRY.key?("longcat_flash"), "longcat_flash should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "kimi_linear" }) + assert_equal MlxLm::Models::KimiLinear::Model, model_class + assert_equal MlxLm::Models::KimiLinear::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "longcat_flash" }) + assert_equal MlxLm::Models::LongcatFlash::Model, model_class + assert_equal MlxLm::Models::LongcatFlash::ModelArgs, args_class + end +end diff --git a/test/parity/klear_iquestloopcoder_models_test.rb b/test/parity/klear_iquestloopcoder_models_test.rb new file mode 100644 index 0000000..b2595d9 --- /dev/null +++ b/test/parity/klear_iquestloopcoder_models_test.rb @@ -0,0 +1,138 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/klear" +require_relative "../../lib/mlx_lm/models/iquestloopcoder" + +class Phase23DenseLaneAFKlearTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_klear_construct_forward_shape_and_sanitize_stacks_experts + args = MlxLm::Models::Klear::ModelArgs.from_dict({ + "model_type" => "Klear", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "attention_bias" => false, + "mlp_only_layers" => [], + "num_experts" => 2, + "num_experts_per_tok" => 1, + "decoder_sparse_step" => 1, + "n_shared_experts" => 1, + "moe_intermediate_size" => 48, + "rms_norm_eps" => 1e-5, + "vocab_size" => 97, + "num_key_value_heads" => 4, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 256, + "norm_topk_prob" => true, + }) + + model = MlxLm::Models::Klear::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 97], output.shape + + weights = { + "model.layers.0.mlp.experts.0.gate_proj.weight" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.0.mlp.experts.1.gate_proj.weight" => @mx.array((24...48).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.0.mlp.experts.0.up_proj.weight" => @mx.array((48...72).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.0.mlp.experts.1.up_proj.weight" => @mx.array((72...96).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.0.mlp.experts.0.down_proj.weight" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([4, 6]), + "model.layers.0.mlp.experts.1.down_proj.weight" => @mx.array((24...48).to_a, dtype: @mx.float32).reshape([4, 6]), + "model.layers.0.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked_gate = 
sanitized["model.layers.0.mlp.experts.gate_proj.weight"] + stacked_up = sanitized["model.layers.0.mlp.experts.up_proj.weight"] + stacked_down = sanitized["model.layers.0.mlp.experts.down_proj.weight"] + @mx.eval(stacked_gate, stacked_up, stacked_down) + + refute sanitized.key?("model.layers.0.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.down_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.down_proj.weight") + assert sanitized.key?("model.layers.0.mlp.experts.gate_proj.weight") + assert sanitized.key?("model.layers.0.mlp.experts.up_proj.weight") + assert sanitized.key?("model.layers.0.mlp.experts.down_proj.weight") + assert sanitized.key?("model.layers.0.self_attn.q_proj.weight") + assert_equal [2, 6, 4], stacked_gate.shape + assert_equal [2, 6, 4], stacked_up.shape + assert_equal [2, 4, 6], stacked_down.shape + end +end + +class Phase23DenseLaneAFIquestloopcoderTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_iquestloopcoder_construct_forward_shape_and_make_cache_halves + args = MlxLm::Models::Iquestloopcoder::ModelArgs.from_dict({ + "model_type" => "iquestloopcoder", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 89, + "head_dim" => 8, + "max_position_embeddings" => 256, + "attention_bias" => false, + "mlp_bias" => false, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + "loop_num" => 2, + "loop_window_size" => 4, + }) + + model = MlxLm::Models::Iquestloopcoder::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 89], output.shape + + cache = model.make_cache + assert_equal 4, cache.length + assert_instance_of MlxLm::KVCache, cache[0] + assert_instance_of MlxLm::KVCache, cache[1] + assert_instance_of MlxLm::RotatingKVCache, cache[2] + assert_instance_of MlxLm::RotatingKVCache, cache[3] + end +end + +class Phase23DenseLaneAFRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("Klear"), "Klear should be registered" + assert MlxLm::Models::REGISTRY.key?("iquestloopcoder"), "iquestloopcoder should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "Klear" }) + assert_equal MlxLm::Models::Klear::Model, model_class + assert_equal MlxLm::Models::Klear::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "iquestloopcoder" }) + assert_equal MlxLm::Models::Iquestloopcoder::Model, model_class + assert_equal MlxLm::Models::Iquestloopcoder::ModelArgs, args_class + end +end diff --git a/test/parity/phase3_test.rb b/test/parity/kv_cache_and_llama_model_test.rb similarity index 95% rename from test/parity/phase3_test.rb rename to test/parity/kv_cache_and_llama_model_test.rb index e25ff06..1cc457e 100644 --- a/test/parity/phase3_test.rb +++ b/test/parity/kv_cache_and_llama_model_test.rb @@ -116,7 +116,7 @@ def test_forward_pass_shapes model = make_small_model mx_mod.eval(model.parameters) - input_ids = mx_mod.array([[1, 2, 3, 4]]).astype(mx_mod.int32) + input_ids 
= mx_mod.array([[1, 2, 3, 4]], dtype: mx_mod.int32) logits = model.call(input_ids) mx_mod.eval(logits) @@ -160,14 +160,14 @@ def test_forward_with_cache cache = MlxLm::Cache.make_prompt_cache(model) # Prefill - input_ids = mx.array([[1, 2, 3]]).astype(mx.int32) + input_ids = mx.array([[1, 2, 3]], dtype: mx.int32) logits1 = model.call(input_ids, cache: cache) mx.eval(logits1) assert_equal [1, 3, 100], logits1.shape # Generation step - next_token = mx.array([[4]]).astype(mx.int32) + next_token = mx.array([[4]], dtype: mx.int32) logits2 = model.call(next_token, cache: cache) mx.eval(logits2) @@ -188,7 +188,7 @@ def test_tied_embeddings model = MlxLm::Models::Llama::Model.new(args) mx.eval(model.parameters) - input_ids = mx.array([[1, 2]]).astype(mx.int32) + input_ids = mx.array([[1, 2]], dtype: mx.int32) logits = model.call(input_ids) mx.eval(logits) assert_equal [1, 2, 100], logits.shape @@ -207,7 +207,7 @@ def test_untied_embeddings model = MlxLm::Models::Llama::Model.new(args) mx.eval(model.parameters) - input_ids = mx.array([[1, 2]]).astype(mx.int32) + input_ids = mx.array([[1, 2]], dtype: mx.int32) logits = model.call(input_ids) mx.eval(logits) assert_equal [1, 2, 100], logits.shape diff --git a/test/parity/lfm2_lfm2_vl_models_test.rb b/test/parity/lfm2_lfm2_vl_models_test.rb new file mode 100644 index 0000000..a07dceb --- /dev/null +++ b/test/parity/lfm2_lfm2_vl_models_test.rb @@ -0,0 +1,130 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/qwen3" +require_relative "../../lib/mlx_lm/models/lfm2" +require_relative "../../lib/mlx_lm/models/lfm2_vl" + +class Phase22DenseLaneABLfm2Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_lfm2_construct_forward_shape_and_sanitize_conv_transpose + args = MlxLm::Models::Lfm2::ModelArgs.from_dict({ + "model_type" => "lfm2", + "vocab_size" => 97, + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "max_position_embeddings" => 128, + "norm_eps" => 1e-5, + "conv_bias" => false, + "conv_L_cache" => 3, + "block_dim" => 32, + "block_ff_dim" => 64, + "block_multiple_of" => 8, + "block_ffn_dim_multiplier" => 1.5, + "block_auto_adjust_ff_dim" => true, + "layer_types" => ["full_attention", "conv"], + "rope_parameters" => { "rope_theta" => 10_000.0 }, + }) + + model = MlxLm::Models::Lfm2::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 97], output.shape + + conv_weight = @mx.array((0...24).to_a, dtype: @mx.float32).reshape([2, 3, 4]) + weights = { + "model.layers.0.conv.weight" => conv_weight, + "model.layers.0.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + sanitized_conv = sanitized["model.layers.0.conv.weight"] + @mx.eval(sanitized_conv) + + assert_equal [2, 4, 3], sanitized_conv.shape + assert_equal @mx.swapaxes(conv_weight, 1, 2).to_a, sanitized_conv.to_a + assert sanitized.key?("model.layers.0.self_attn.q_proj.weight") + end +end + +class Phase22DenseLaneABLfm2VlTest < 
Minitest::Test + def setup + @mx = MLX::Core + end + + def test_lfm2_vl_construct_forward_shape_and_sanitize_multimodal_key_removal + args = MlxLm::Models::Lfm2VL::ModelArgs.from_dict({ + "model_type" => "lfm2-vl", + "text_config" => { + "model_type" => "lfm2", + "vocab_size" => 101, + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "max_position_embeddings" => 128, + "norm_eps" => 1e-5, + "conv_bias" => false, + "conv_L_cache" => 3, + "block_dim" => 32, + "block_ff_dim" => 64, + "block_multiple_of" => 8, + "block_ffn_dim_multiplier" => 1.5, + "block_auto_adjust_ff_dim" => true, + "layer_types" => ["full_attention", "conv"], + "rope_theta" => 10_000.0, + }, + }) + + model = MlxLm::Models::Lfm2VL::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 101], output.shape + + weights = { + "vision_tower.patch_embed.weight" => @mx.zeros([1]).astype(@mx.float32), + "multi_modal_projector.linear.weight" => @mx.zeros([1]).astype(@mx.float32), + "language_model.model.embed_tokens.weight" => @mx.zeros([101, 32]).astype(@mx.float32), + "language_model.model.layers.0.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + refute sanitized.key?("vision_tower.patch_embed.weight") + refute sanitized.key?("multi_modal_projector.linear.weight") + assert sanitized.key?("language_model.model.embed_tokens.weight") + assert sanitized.key?("language_model.model.layers.0.self_attn.q_proj.weight") + end +end + +class Phase22DenseLaneABRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("lfm2"), "lfm2 should be registered" + assert MlxLm::Models::REGISTRY.key?("lfm2-vl"), "lfm2-vl should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "lfm2" }) + assert_equal MlxLm::Models::Lfm2::Model, model_class + assert_equal MlxLm::Models::Lfm2::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "lfm2-vl" }) + assert_equal MlxLm::Models::Lfm2VL::Model, model_class + assert_equal MlxLm::Models::Lfm2VL::ModelArgs, args_class + end +end diff --git a/test/parity/lille130m_mimo_models_test.rb b/test/parity/lille130m_mimo_models_test.rb new file mode 100644 index 0000000..1421a8b --- /dev/null +++ b/test/parity/lille130m_mimo_models_test.rb @@ -0,0 +1,113 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/lille_130m" +require_relative "../../lib/mlx_lm/models/mimo" + +class Phase21DenseLaneVLille130mTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_lille_130m_construct_forward_shape_and_sanitize_rotary_weights + args = MlxLm::Models::Lille130m::ModelArgs.from_dict({ + "model_type" => "lille-130m", + "block_size" => 128, + "layer_norm_eps" => 1e-5, + "n_embd" => 96, + "n_head" => 4, + "n_kv_heads" => 2, + "n_layer" => 2, + "rope_theta" => 10_000.0, + "vocab_size" => 89, + 
"tie_word_embeddings" => true, + }) + + model = MlxLm::Models::Lille130m::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 89], output.shape + assert_equal 2, model.layers.length + + weights = { + "transformer.layers.0.attention.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "transformer.layers.0.feed_forward.gate_proj.weight" => @mx.zeros([1]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("transformer.layers.0.attention.rotary_emb.inv_freq") + assert sanitized.key?("transformer.layers.0.feed_forward.gate_proj.weight") + end +end + +class Phase21DenseLaneVMimoTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_mimo_construct_forward_shape_and_sanitize_tied_embeddings_and_mtp_weights + args = MlxLm::Models::Mimo::ModelArgs.from_dict({ + "model_type" => "mimo", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 101, + "max_position_embeddings" => 128, + "rope_theta" => 10_000.0, + "rope_traditional" => false, + "tie_word_embeddings" => true, + "num_nextn_predict_layers" => 2, + }) + + model = MlxLm::Models::Mimo::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 101], output.shape + assert_nil model.lm_head + + weights = { + "lm_head.weight" => @mx.zeros([101, 32]).astype(@mx.float32), + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "model.mtp_layers.0.proj.weight" => @mx.zeros([1]).astype(@mx.float32), + "model.embed_tokens.weight" => @mx.zeros([101, 32]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("lm_head.weight") + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + refute sanitized.key?("model.mtp_layers.0.proj.weight") + assert sanitized.key?("model.embed_tokens.weight") + end +end + +class Phase21DenseLaneVRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("lille-130m"), "lille-130m should be registered" + assert MlxLm::Models::REGISTRY.key?("mimo"), "mimo should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "lille-130m" }) + assert_equal MlxLm::Models::Lille130m::Model, model_class + assert_equal MlxLm::Models::Lille130m::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "mimo" }) + assert_equal MlxLm::Models::Mimo::Model, model_class + assert_equal MlxLm::Models::Mimo::ModelArgs, args_class + end +end diff --git a/test/parity/llama4_ministral3_models_test.rb b/test/parity/llama4_ministral3_models_test.rb new file mode 100644 index 0000000..33640f3 --- /dev/null +++ b/test/parity/llama4_ministral3_models_test.rb @@ -0,0 +1,158 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" 
+require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/pipeline" +require_relative "../../lib/mlx_lm/models/llama4" +require_relative "../../lib/mlx_lm/models/ministral3" + +class Phase24DenseLaneAHLlama4Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_llama4_construct_forward_shape_sanitize_and_make_cache + args = MlxLm::Models::Llama4::ModelArgs.from_dict({ + "model_type" => "llama4", + "text_config" => { + "model_type" => "llama4_text", + "hidden_size" => 32, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "num_hidden_layers" => 4, + "vocab_size" => 97, + "intermediate_size" => 48, + "intermediate_size_mlp" => 64, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + "interleave_moe_layer_step" => 2, + "attention_chunk_size" => 4, + "max_position_embeddings" => 128, + "rope_theta" => 10_000.0, + "head_dim" => 8, + "rms_norm_eps" => 1e-5, + "attention_bias" => false, + "use_qk_norm" => true, + }, + }) + + model = MlxLm::Models::Llama4::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 97], output.shape + + prefix = "language_model.model.layers.1.feed_forward.experts" + weights = { + "vision_model.patch_embed.weight" => @mx.zeros([1]).astype(@mx.float32), + "multi_modal_projector.linear.weight" => @mx.zeros([1]).astype(@mx.float32), + "#{prefix}.gate_up_proj" => @mx.array((0...48).to_a, dtype: @mx.float32).reshape([2, 3, 8]), + "#{prefix}.down_proj" => @mx.array((0...36).to_a, dtype: @mx.float32).reshape([2, 6, 3]), + "language_model.model.layers.0.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + gate_proj = sanitized["#{prefix}.gate_proj.weight"] + up_proj = sanitized["#{prefix}.up_proj.weight"] + down_proj = sanitized["#{prefix}.down_proj.weight"] + @mx.eval(gate_proj, up_proj, down_proj) + + refute sanitized.key?("vision_model.patch_embed.weight") + refute sanitized.key?("multi_modal_projector.linear.weight") + refute sanitized.key?("#{prefix}.gate_up_proj") + refute sanitized.key?("#{prefix}.down_proj") + assert_equal [2, 4, 3], gate_proj.shape + assert_equal [2, 4, 3], up_proj.shape + assert_equal [2, 3, 6], down_proj.shape + assert sanitized.key?("language_model.model.layers.0.self_attn.q_proj.weight") + + cache = model.make_cache + assert_equal 4, cache.length + assert_instance_of MlxLm::ChunkedKVCache, cache[0] + assert_instance_of MlxLm::ChunkedKVCache, cache[1] + assert_instance_of MlxLm::ChunkedKVCache, cache[2] + assert_instance_of MlxLm::KVCache, cache[3] + end +end + +class Phase24DenseLaneAHMinistral3Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_ministral3_construct_forward_shape_sanitize_and_make_cache + args = MlxLm::Models::Ministral3::ModelArgs.from_dict({ + "model_type" => "ministral3", + "hidden_size" => 32, + "num_hidden_layers" => 4, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "head_dim" => 8, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "vocab_size" => 101, + "tie_word_embeddings" => true, + "sliding_window" => 8, + "layer_types" => ["sliding_attention", "full_attention", "sliding_attention", "full_attention"], + "rope_parameters" => { + "rope_theta" => 10_000.0, + "llama_4_scaling_beta" => 0.1, + 
"original_max_position_embeddings" => 128, + }, + }) + + model = MlxLm::Models::Ministral3::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 101], output.shape + + weights = { + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([101, 32]).astype(@mx.float32), + "model.embed_tokens.weight" => @mx.zeros([101, 32]).astype(@mx.float32), + "model.layers.0.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + refute sanitized.key?("lm_head.weight") + assert sanitized.key?("model.embed_tokens.weight") + assert sanitized.key?("model.layers.0.self_attn.q_proj.weight") + + cache = model.make_cache + assert_equal 4, cache.length + assert_instance_of MlxLm::RotatingKVCache, cache[0] + assert_instance_of MlxLm::KVCache, cache[1] + assert_instance_of MlxLm::RotatingKVCache, cache[2] + assert_instance_of MlxLm::KVCache, cache[3] + end +end + +class Phase24DenseLaneAHRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("llama4"), "llama4 should be registered" + assert MlxLm::Models::REGISTRY.key?("ministral3"), "ministral3 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "llama4" }) + assert_equal MlxLm::Models::Llama4::Model, model_class + assert_equal MlxLm::Models::Llama4::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "ministral3" }) + assert_equal MlxLm::Models::Ministral3::Model, model_class + assert_equal MlxLm::Models::Ministral3::ModelArgs, args_class + end +end diff --git a/test/parity/llama4_text_plamo_models_test.rb b/test/parity/llama4_text_plamo_models_test.rb new file mode 100644 index 0000000..65d3094 --- /dev/null +++ b/test/parity/llama4_text_plamo_models_test.rb @@ -0,0 +1,100 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/llama4_text" +require_relative "../../lib/mlx_lm/models/plamo" + +class Phase23DenseLaneACLlama4TextTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_llama4_text_construct_forward_shape_and_tied_output_path + args_hash = { + "model_type" => "llama4_text", + "hidden_size" => 32, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "num_hidden_layers" => 2, + "vocab_size" => 97, + "intermediate_size" => 64, + "intermediate_size_mlp" => 64, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "head_dim" => 8, + "tie_word_embeddings" => true, + "no_rope_layers" => [0, 1], + "use_qk_norm" => true, + } + + args = MlxLm::Models::Llama4Text::ModelArgs.from_dict(args_hash) + model = MlxLm::Models::Llama4Text::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + + assert_equal [1, 4, 97], output.shape + assert_equal 2, model.layers.length + assert_nil model.output 
+ + untied_args = MlxLm::Models::Llama4Text::ModelArgs.from_dict( + args_hash.merge("tie_word_embeddings" => false) + ) + untied_model = MlxLm::Models::Llama4Text::Model.new(untied_args) + assert_instance_of MLX::NN::Linear, untied_model.output + end +end + +class Phase23DenseLaneACPlamoTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_plamo_construct_forward_shape + args = MlxLm::Models::Plamo::ModelArgs.from_dict({ + "model_type" => "plamo", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "rms_norm_eps" => 1e-5, + "vocab_size" => 101, + "n_shared_head" => 2, + "rope_theta" => 10_000.0, + "rope_traditional" => false, + }) + + model = MlxLm::Models::Plamo::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + + assert_equal [1, 3, 101], output.shape + assert_equal 2, model.layers.length + end +end + +class Phase23DenseLaneACRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("llama4_text"), "llama4_text should be registered" + assert MlxLm::Models::REGISTRY.key?("plamo"), "plamo should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "llama4_text" }) + assert_equal MlxLm::Models::Llama4Text::Model, model_class + assert_equal MlxLm::Models::Llama4Text::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "plamo" }) + assert_equal MlxLm::Models::Plamo::Model, model_class + assert_equal MlxLm::Models::Plamo::ModelArgs, args_class + end +end diff --git a/test/parity/longcat_flash_ngram_qwen3_next_models_test.rb b/test/parity/longcat_flash_ngram_qwen3_next_models_test.rb new file mode 100644 index 0000000..f5ad5f6 --- /dev/null +++ b/test/parity/longcat_flash_ngram_qwen3_next_models_test.rb @@ -0,0 +1,179 @@ +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/longcat_flash_ngram" +require_relative "../../lib/mlx_lm/models/qwen3_next" + +class Phase46DenseLaneASLongcatFlashNgramTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_longcat_flash_ngram_wrapper_from_dict_forward_shape_and_sanitize_mapping + args = MlxLm::Models::LongcatFlashNgram::ModelArgs.from_dict({ + "model_type" => "longcat_flash_ngram", + "vocab_size" => 67, + "hidden_size" => 32, + "ffn_hidden_size" => 64, + "expert_ffn_hidden_size" => 16, + "num_layers" => 2, + "num_attention_heads" => 4, + "n_routed_experts" => 2, + "zero_expert_num" => 1, + "moe_topk" => 1, + "kv_lora_rank" => 8, + "q_lora_rank" => 8, + "qk_rope_head_dim" => 8, + "qk_nope_head_dim" => 8, + "v_head_dim" => 8, + "routed_scaling_factor" => 1.0, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "attention_bias" => false, + "norm_topk_prob" => true, + "n_group" => 1, + "topk_group" => 1, + "first_k_dense_replace" => 0, + "tie_word_embeddings" => false, + }) + + assert_equal 2, args.num_hidden_layers + assert_equal 1, args.num_experts_per_tok + assert_equal 16, args.moe_intermediate_size + + model = MlxLm::Models::LongcatFlashNgram::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 67], output.shape + assert_equal 2, model.layers.length + 
assert_nil model.make_cache + + embed_q_weight = @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32) + weights = { + "model.ngram_embeddings.word_embeddings.weight" => @mx.zeros([67, 32]).astype(@mx.float32), + "model.layers.0.attention.embed_q.weight" => embed_q_weight, + "model.layers.0.mlp.experts.0.up_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.up_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.2.up_proj.weight" => @mx.array([[9.0, 10.0], [11.0, 12.0]], dtype: @mx.float32), + "model.layers.2.mlp.up_proj.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + mapped_q_proj = sanitized["model.layers.0.self_attn.q_proj.weight"] + stacked = sanitized["model.layers.0.mlp.switch_mlp.up_proj.weight"] + mapped_ngram_embed = sanitized["model.ngram_embeddings.word_embeddings.weight"] + @mx.eval(mapped_q_proj, stacked, mapped_ngram_embed) + + refute sanitized.key?("model.layers.0.attention.embed_q.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.2.up_proj.weight") + refute sanitized.key?("model.layers.2.mlp.up_proj.weight") + refute sanitized.key?("model.embed_tokens.weight") + assert sanitized.key?("model.layers.0.self_attn.q_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.up_proj.weight") + assert sanitized.key?("model.ngram_embeddings.word_embeddings.weight") + assert_equal embed_q_weight.to_a, mapped_q_proj.to_a + assert_equal [3, 2, 2], stacked.shape + end +end + +class Phase46DenseLaneASQwen3NextTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_qwen3_next_wrapper_from_dict_forward_shape_and_sanitize_mapping + args = MlxLm::Models::Qwen3Next::ModelArgs.from_dict({ + "model_type" => "qwen3_next", + "vocab_size" => 71, + "hidden_size" => 32, + "intermediate_size" => 64, + "moe_intermediate_size" => 24, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "shared_expert_intermediate_size" => 20, + "decoder_sparse_step" => 1, + "mlp_only_layers" => [0], + "linear_num_value_heads" => 2, + "linear_num_key_heads" => 2, + "linear_key_head_dim" => 8, + "linear_value_head_dim" => 8, + "linear_conv_kernel_dim" => 3, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "partial_rotary_factor" => 0.5, + "attention_bias" => false, + "tie_word_embeddings" => false, + }) + + assert_equal 1, args.first_k_dense_replace + assert_equal 1, args.num_shared_experts + assert_equal 20, args.moe_shared_expert_intermediate_size + + model = MlxLm::Models::Qwen3Next::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 71], output.shape + assert_equal 2, model.layers.length + assert_nil model.make_cache + + gate_weight = @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32) + weights = { + "model.layers.1.mlp.experts.0.gate_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.1.mlp.experts.1.gate_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.1.mlp.router.weight" => gate_weight, + 
"model.layers.1.mlp.router.bias" => @mx.array([0.1, -0.2], dtype: @mx.float32), + "model.layers.1.mlp.shared_expert.gate_proj.weight" => @mx.zeros([20, 32]).astype(@mx.float32), + "model.embed_tokens.weight" => @mx.zeros([71, 32]).astype(@mx.float32), + "model.mtp.layers.0.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked = sanitized["model.layers.1.mlp.switch_mlp.gate_proj.weight"] + remapped_gate = sanitized["model.layers.1.mlp.gate.gate_proj.weight"] + remapped_bias = sanitized["model.layers.1.mlp.gate.gate_proj.bias"] + remapped_shared = sanitized["model.layers.1.mlp.shared_experts.gate_proj.weight"] + remapped_embed = sanitized["model.word_embeddings.weight"] + @mx.eval(stacked, remapped_gate, remapped_bias, remapped_shared, remapped_embed) + + refute sanitized.key?("model.layers.1.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.1.gate_proj.weight") + refute sanitized.key?("model.layers.1.mlp.router.weight") + refute sanitized.key?("model.layers.1.mlp.router.bias") + refute sanitized.key?("model.layers.1.mlp.shared_expert.gate_proj.weight") + refute sanitized.key?("model.embed_tokens.weight") + refute sanitized.key?("model.mtp.layers.0.weight") + assert sanitized.key?("model.layers.1.mlp.switch_mlp.gate_proj.weight") + assert sanitized.key?("model.layers.1.mlp.gate.gate_proj.weight") + assert sanitized.key?("model.layers.1.mlp.gate.gate_proj.bias") + assert sanitized.key?("model.layers.1.mlp.shared_experts.gate_proj.weight") + assert sanitized.key?("model.word_embeddings.weight") + assert_equal [2, 2, 2], stacked.shape + assert_equal gate_weight.to_a, remapped_gate.to_a + end +end + +class Phase46DenseLaneASRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("longcat_flash_ngram"), "longcat_flash_ngram should be registered" + assert MlxLm::Models::REGISTRY.key?("qwen3_next"), "qwen3_next should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "longcat_flash_ngram" }) + assert_equal MlxLm::Models::LongcatFlashNgram::Model, model_class + assert_equal MlxLm::Models::LongcatFlashNgram::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "qwen3_next" }) + assert_equal MlxLm::Models::Qwen3Next::Model, model_class + assert_equal MlxLm::Models::Qwen3Next::ModelArgs, args_class + end +end diff --git a/test/parity/phase9_test.rb b/test/parity/lora_layers_training_test.rb similarity index 99% rename from test/parity/phase9_test.rb rename to test/parity/lora_layers_training_test.rb index eff771a..af2e0d4 100644 --- a/test/parity/phase9_test.rb +++ b/test/parity/lora_layers_training_test.rb @@ -63,7 +63,7 @@ def test_lora_embedding_forward lora_embed = MlxLm::Tuner::LoRAEmbedding.from_base(embed, r: 8, scale: 20.0) @mx.eval(*MLX::Utils.tree_flatten(lora_embed.parameters).map { |_, v| v }) - ids = @mx.array([[1, 2, 3]]).astype(@mx.int32) + ids = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = lora_embed.call(ids) assert_equal [1, 3, 32], output.shape end diff --git a/test/parity/mamba_mamba2_models_test.rb b/test/parity/mamba_mamba2_models_test.rb new file mode 100644 index 0000000..9e9b0a0 --- /dev/null +++ b/test/parity/mamba_mamba2_models_test.rb @@ -0,0 +1,142 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative 
"../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/ssm" +require_relative "../../lib/mlx_lm/models/mamba" +require_relative "../../lib/mlx_lm/models/mamba2" + +class Phase23DenseLaneADMambaTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_mamba_tiny_construct_forward_shape_and_sanitize_conv_transpose + args = MlxLm::Models::Mamba::ModelArgs.from_dict({ + "model_type" => "mamba", + "vocab_size" => 67, + "hidden_size" => 32, + "intermediate_size" => 16, + "state_size" => 8, + "num_hidden_layers" => 2, + "conv_kernel" => 3, + "use_bias" => true, + "use_conv_bias" => true, + "time_step_rank" => "auto", + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::Mamba::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 67], output.shape + + conv_weight = @mx.array((0...6).to_a, dtype: @mx.float32).reshape([2, 1, 3]) + weights = { + "backbone.layers.0.mixer.conv1d.weight" => conv_weight, + "backbone.layers.0.mixer.in_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + sanitized_conv = sanitized["backbone.layers.0.mixer.conv1d.weight"] + @mx.eval(sanitized_conv) + + assert_equal [2, 3, 1], sanitized_conv.shape + assert_equal @mx.swapaxes(conv_weight, 1, 2).to_a, sanitized_conv.to_a + assert sanitized.key?("backbone.layers.0.mixer.in_proj.weight") + end +end + +class Phase23DenseLaneADMamba2Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_mamba2_tiny_construct_forward_shape_and_sanitize_conv_transpose + args = MlxLm::Models::Mamba2::ModelArgs.from_dict({ + "model_type" => "mamba2", + "num_heads" => 4, + "head_dim" => 4, + "vocab_size" => 71, + "hidden_size" => 32, + "intermediate_size" => 16, + "state_size" => 8, + "num_hidden_layers" => 2, + "layer_norm_epsilon" => 1e-5, + "conv_kernel" => 3, + "n_groups" => 2, + "use_bias" => true, + "use_conv_bias" => true, + "tie_word_embeddings" => true, + "time_step_limit" => [0.001, 10.0], + "time_step_rank" => "auto", + }) + + model = MlxLm::Models::Mamba2::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 71], output.shape + + conv_weight = @mx.array((0...12).to_a, dtype: @mx.float32).reshape([3, 1, 4]) + weights = { + "backbone.layers.0.mixer.conv1d.weight" => conv_weight, + "backbone.layers.0.mixer.in_proj.weight" => @mx.zeros([16, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + sanitized_conv = sanitized["backbone.layers.0.mixer.conv1d.weight"] + @mx.eval(sanitized_conv) + + assert_equal [3, 4, 1], sanitized_conv.shape + assert_equal @mx.swapaxes(conv_weight, 1, 2).to_a, sanitized_conv.to_a + assert sanitized.key?("backbone.layers.0.mixer.in_proj.weight") + end +end + +class Phase23DenseLaneADRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("mamba"), "mamba should be registered" + assert MlxLm::Models::REGISTRY.key?("mamba2"), "mamba2 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "mamba" }) + assert_equal 
MlxLm::Models::Mamba::Model, model_class + assert_equal MlxLm::Models::Mamba::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "mamba2" }) + assert_equal MlxLm::Models::Mamba2::Model, model_class + assert_equal MlxLm::Models::Mamba2::ModelArgs, args_class + end + + def test_falcon_mamba_alias_remaps_to_mamba_and_enables_bcdt_rms + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "falcon_mamba" }) + assert_equal MlxLm::Models::Mamba::Model, model_class + assert_equal MlxLm::Models::Mamba::ModelArgs, args_class + + args = args_class.from_dict({ + "model_type" => "falcon_mamba", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 16, + "state_size" => 8, + "num_hidden_layers" => 1, + "conv_kernel" => 3, + "use_bias" => true, + "use_conv_bias" => true, + "time_step_rank" => "auto", + }) + + assert_equal true, args.use_bcdt_rms + end +end diff --git a/test/parity/mimo_v2_flash_lfm2_moe_models_test.rb b/test/parity/mimo_v2_flash_lfm2_moe_models_test.rb new file mode 100644 index 0000000..6a6f811 --- /dev/null +++ b/test/parity/mimo_v2_flash_lfm2_moe_models_test.rb @@ -0,0 +1,223 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/mimo_v2_flash" +require_relative "../../lib/mlx_lm/models/lfm2_moe" + +class Phase24DenseLaneAJMimoV2FlashTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_mimo_v2_flash_construct_forward_shape_and_sanitize_stacks_experts_and_cleans_mtp + args = MlxLm::Models::MimoV2Flash::ModelArgs.from_dict({ + "model_type" => "mimo_v2_flash", + "num_experts_per_tok" => 1, + "hybrid_layer_pattern" => [0, 1], + "moe_layer_freq" => [0, 1], + "add_swa_attention_sink_bias" => false, + "add_full_attention_sink_bias" => false, + "sliding_window_size" => 2, + "vocab_size" => 113, + "hidden_size" => 32, + "intermediate_size" => 64, + "moe_intermediate_size" => 48, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "n_shared_experts" => 1, + "n_routed_experts" => 2, + "routed_scaling_factor" => 1.0, + "topk_method" => "noaux_tc", + "scoring_func" => "sigmoid", + "norm_topk_prob" => true, + "n_group" => 1, + "topk_group" => 1, + "max_position_embeddings" => 128, + "layernorm_epsilon" => 1e-5, + "rope_theta" => 10_000.0, + "swa_rope_theta" => 20_000.0, + "swa_num_attention_heads" => 4, + "swa_num_key_value_heads" => 2, + "head_dim" => 8, + "v_head_dim" => 8, + "swa_head_dim" => 8, + "swa_v_head_dim" => 8, + "partial_rotary_factor" => 1.0, + }) + + model = MlxLm::Models::MimoV2Flash::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 113], output.shape + + weights = { + "model.layers.1.mlp.experts.0.gate_proj.weight" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.1.mlp.experts.1.gate_proj.weight" => @mx.array((24...48).to_a, dtype: @mx.float32).reshape([6, 4]), + 
"model.layers.1.mlp.experts.0.up_proj.weight" => @mx.array((48...72).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.1.mlp.experts.1.up_proj.weight" => @mx.array((72...96).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.1.mlp.experts.0.down_proj.weight" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([4, 6]), + "model.layers.1.mlp.experts.1.down_proj.weight" => @mx.array((24...48).to_a, dtype: @mx.float32).reshape([4, 6]), + "model.mtp.layers.0.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + "model.layers.0.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked_gate = sanitized["model.layers.1.mlp.switch_mlp.gate_proj.weight"] + stacked_up = sanitized["model.layers.1.mlp.switch_mlp.up_proj.weight"] + stacked_down = sanitized["model.layers.1.mlp.switch_mlp.down_proj.weight"] + @mx.eval(stacked_gate, stacked_up, stacked_down) + + refute sanitized.key?("model.layers.1.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.1.gate_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.0.down_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.1.down_proj.weight") + refute sanitized.key?("model.mtp.layers.0.weight") + + assert sanitized.key?("model.layers.1.mlp.switch_mlp.gate_proj.weight") + assert sanitized.key?("model.layers.1.mlp.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.1.mlp.switch_mlp.down_proj.weight") + assert sanitized.key?("model.layers.0.self_attn.q_proj.weight") + + assert_equal [2, 6, 4], stacked_gate.shape + assert_equal [2, 6, 4], stacked_up.shape + assert_equal [2, 4, 6], stacked_down.shape + end +end + +class Phase24DenseLaneAJLfm2MoeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_lfm2_moe_construct_forward_shape_and_sanitize_transposes_conv_and_stacks_experts + args = MlxLm::Models::Lfm2Moe::ModelArgs.from_dict({ + "model_type" => "lfm2_moe", + "vocab_size" => 101, + "hidden_size" => 32, + "intermediate_size" => 64, + "moe_intermediate_size" => 48, + "num_hidden_layers" => 3, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "norm_topk_prob" => true, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "max_position_embeddings" => 128, + "use_expert_bias" => true, + "num_dense_layers" => 1, + "norm_eps" => 1e-5, + "conv_bias" => false, + "conv_L_cache" => 3, + "layer_types" => ["full_attention", "conv", "full_attention"], + "rope_parameters" => { "rope_theta" => 10_000.0 }, + }) + + model = MlxLm::Models::Lfm2Moe::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 101], output.shape + + conv_weight = @mx.array((0...24).to_a, dtype: @mx.float32).reshape([2, 3, 4]) + weights = { + "model.layers.1.conv.weight" => conv_weight, + "model.layers.1.feed_forward.experts.0.w1.weight" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.1.feed_forward.experts.1.w1.weight" => @mx.array((24...48).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.1.feed_forward.experts.0.w2.weight" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([4, 6]), + "model.layers.1.feed_forward.experts.1.w2.weight" => @mx.array((24...48).to_a, dtype: 
@mx.float32).reshape([4, 6]), + "model.layers.1.feed_forward.experts.0.w3.weight" => @mx.array((48...72).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.1.feed_forward.experts.1.w3.weight" => @mx.array((72...96).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.0.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + sanitized_conv = sanitized["model.layers.1.conv.weight"] + stacked_gate = sanitized["model.layers.1.feed_forward.switch_mlp.gate_proj.weight"] + stacked_down = sanitized["model.layers.1.feed_forward.switch_mlp.down_proj.weight"] + stacked_up = sanitized["model.layers.1.feed_forward.switch_mlp.up_proj.weight"] + @mx.eval(sanitized_conv, stacked_gate, stacked_down, stacked_up) + + assert_equal [2, 4, 3], sanitized_conv.shape + assert_equal @mx.swapaxes(conv_weight, 1, 2).to_a, sanitized_conv.to_a + + refute sanitized.key?("model.layers.1.feed_forward.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.1.feed_forward.experts.1.gate_proj.weight") + refute sanitized.key?("model.layers.1.feed_forward.experts.0.down_proj.weight") + refute sanitized.key?("model.layers.1.feed_forward.experts.1.down_proj.weight") + refute sanitized.key?("model.layers.1.feed_forward.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.1.feed_forward.experts.1.up_proj.weight") + + assert sanitized.key?("model.layers.1.feed_forward.switch_mlp.gate_proj.weight") + assert sanitized.key?("model.layers.1.feed_forward.switch_mlp.down_proj.weight") + assert sanitized.key?("model.layers.1.feed_forward.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.self_attn.q_proj.weight") + + assert_equal [2, 6, 4], stacked_gate.shape + assert_equal [2, 4, 6], stacked_down.shape + assert_equal [2, 6, 4], stacked_up.shape + end + + def test_lfm2_moe_make_cache_returns_attention_and_conv_cache_types + args = MlxLm::Models::Lfm2Moe::ModelArgs.from_dict({ + "model_type" => "lfm2_moe", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "moe_intermediate_size" => 24, + "num_hidden_layers" => 3, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "norm_topk_prob" => false, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "max_position_embeddings" => 64, + "use_expert_bias" => false, + "num_dense_layers" => 1, + "norm_eps" => 1e-5, + "conv_bias" => false, + "conv_L_cache" => 3, + "full_attn_idxs" => [0, 2], + "rope_theta" => 10_000.0, + }) + + model = MlxLm::Models::Lfm2Moe::Model.new(args) + cache = model.make_cache + + assert_equal 3, cache.length + assert_instance_of MlxLm::KVCache, cache[0] + assert_instance_of MlxLm::ArraysCache, cache[1] + assert_instance_of MlxLm::KVCache, cache[2] + end +end + +class Phase24DenseLaneAJRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("mimo_v2_flash"), "mimo_v2_flash should be registered" + assert MlxLm::Models::REGISTRY.key?("lfm2_moe"), "lfm2_moe should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "mimo_v2_flash" }) + assert_equal MlxLm::Models::MimoV2Flash::Model, model_class + assert_equal MlxLm::Models::MimoV2Flash::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "lfm2_moe" }) + assert_equal MlxLm::Models::Lfm2Moe::Model, model_class + assert_equal MlxLm::Models::Lfm2Moe::ModelArgs, args_class + end +end diff --git a/test/parity/minimax_nemotron_nas_models_test.rb 
b/test/parity/minimax_nemotron_nas_models_test.rb new file mode 100644 index 0000000..9b5ad98 --- /dev/null +++ b/test/parity/minimax_nemotron_nas_models_test.rb @@ -0,0 +1,154 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/minimax" +require_relative "../../lib/mlx_lm/models/nemotron_nas" + +class Phase22DenseLaneAMMinimaxTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_minimax_construct_forward_shape_and_sanitize_moe_remap + args = MlxLm::Models::Minimax::ModelArgs.from_dict({ + "model_type" => "minimax", + "hidden_size" => 32, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "max_position_embeddings" => 128, + "num_experts_per_tok" => 1, + "num_local_experts" => 2, + "shared_intermediate_size" => 32, + "num_hidden_layers" => 2, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "rotary_dim" => 8, + "vocab_size" => 97, + "tie_word_embeddings" => false, + "use_qk_norm" => true, + }) + + model = MlxLm::Models::Minimax::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 3, 97], output.shape + + weights = { + "model.embed_tokens.weight" => @mx.zeros([97, 32]).astype(@mx.float32), + "model.layers.0.block_sparse_moe.experts.0.w1.weight" => @mx.zeros([4, 4]).astype(@mx.float32), + "model.layers.0.block_sparse_moe.experts.1.w1.weight" => @mx.ones([4, 4]).astype(@mx.float32), + "model.layers.0.block_sparse_moe.experts.0.w2.weight" => @mx.zeros([4, 4]).astype(@mx.float32), + "model.layers.0.block_sparse_moe.experts.1.w2.weight" => @mx.ones([4, 4]).astype(@mx.float32), + "model.layers.0.block_sparse_moe.experts.0.w3.weight" => @mx.zeros([4, 4]).astype(@mx.float32), + "model.layers.0.block_sparse_moe.experts.1.w3.weight" => @mx.ones([4, 4]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.0.w1.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.1.w1.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.0.w2.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.1.w2.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.0.w3.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.1.w3.weight") + assert sanitized.key?("model.layers.0.block_sparse_moe.switch_mlp.gate_proj.weight") + assert sanitized.key?("model.layers.0.block_sparse_moe.switch_mlp.down_proj.weight") + assert sanitized.key?("model.layers.0.block_sparse_moe.switch_mlp.up_proj.weight") + assert sanitized.key?("model.embed_tokens.weight") + assert_equal [2, 4, 4], sanitized["model.layers.0.block_sparse_moe.switch_mlp.gate_proj.weight"].shape + assert_equal [2, 4, 4], sanitized["model.layers.0.block_sparse_moe.switch_mlp.down_proj.weight"].shape + assert_equal [2, 4, 4], sanitized["model.layers.0.block_sparse_moe.switch_mlp.up_proj.weight"].shape + end +end + +class Phase22DenseLaneAMNemotronNasTest < Minitest::Test + def 
setup + @mx = MLX::Core + end + + def test_nemotron_nas_construct_forward_shape_sanitize_and_cache + args = MlxLm::Models::NemotronNas::ModelArgs.from_dict({ + "model_type" => "nemotron-nas", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "rms_norm_eps" => 1e-5, + "vocab_size" => 101, + "hidden_act" => "silu", + "attention_bias" => false, + "mlp_bias" => false, + "rope_theta" => 10_000.0, + "rope_scaling" => { + "type" => "linear", + "factor" => 2.0, + }, + "max_position_embeddings" => 128, + "tie_word_embeddings" => true, + "block_configs" => [ + { + "attention" => { "n_heads_in_group" => 2 }, + "ffn" => { "ffn_mult" => 1.5 }, + }, + { + "attention" => { "no_op" => true }, + "ffn" => { "replace_with_linear" => true }, + }, + ], + }) + + assert_equal "linear", args.rope_scaling["rope_type"] + + model = MlxLm::Models::NemotronNas::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 4, 101], output.shape + assert_equal 2, model.layers.length + assert_nil model.layers[1].self_attn + + weights = { + "lm_head.weight" => @mx.zeros([101, 32]).astype(@mx.float32), + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "model.embed_tokens.weight" => @mx.zeros([101, 32]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("lm_head.weight") + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + assert sanitized.key?("model.embed_tokens.weight") + + caches = model.make_cache + assert_equal 1, caches.length + assert_instance_of MlxLm::KVCache, caches[0] + end +end + +class Phase22DenseLaneAMRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("minimax"), "minimax should be registered" + assert MlxLm::Models::REGISTRY.key?("nemotron-nas"), "nemotron-nas should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "minimax" }) + assert_equal MlxLm::Models::Minimax::Model, model_class + assert_equal MlxLm::Models::Minimax::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "nemotron-nas" }) + assert_equal MlxLm::Models::NemotronNas::Model, model_class + assert_equal MlxLm::Models::NemotronNas::ModelArgs, args_class + end +end diff --git a/test/parity/mistral3_solar_open_models_test.rb b/test/parity/mistral3_solar_open_models_test.rb new file mode 100644 index 0000000..87c7040 --- /dev/null +++ b/test/parity/mistral3_solar_open_models_test.rb @@ -0,0 +1,138 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/mistral3" +require_relative "../../lib/mlx_lm/models/solar_open" + +class Phase19DenseLanePMistral3Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_mistral3_construct_forward_shape_and_sanitize_drops_multimodal_and_rotary + args = MlxLm::Models::Mistral3::ModelArgs.from_dict({ + "model_type" => "mistral3", + "text_config" => { + "model_type" => "llama", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + 
"num_key_value_heads" => 4, + "intermediate_size" => 64, + "vocab_size" => 96, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 256, + "tie_word_embeddings" => true, + }, + }) + + model = MlxLm::Models::Mistral3::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 96], output.shape + + weights = { + "vision_tower.patch_embed.weight" => @mx.zeros([1]).astype(@mx.float32), + "multi_modal_projector.linear.weight" => @mx.zeros([1]).astype(@mx.float32), + "language_model.model.embed_tokens.weight" => @mx.zeros([96, 32]).astype(@mx.float32), + "language_model.model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + "language_model.lm_head.weight" => @mx.zeros([96, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + + refute sanitized.key?("vision_tower.patch_embed.weight") + refute sanitized.key?("multi_modal_projector.linear.weight") + refute sanitized.key?("language_model.model.layers.0.self_attn.rotary_emb.inv_freq") + refute sanitized.key?("language_model.lm_head.weight") + assert sanitized.key?("language_model.model.embed_tokens.weight") + end +end + +class Phase19DenseLanePSolarOpenTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_solar_open_construct_forward_shape_and_sanitize_stacks_experts + args = MlxLm::Models::SolarOpen::ModelArgs.from_dict({ + "model_type" => "solar_open", + "vocab_size" => 96, + "hidden_size" => 32, + "intermediate_size" => 64, + "moe_intermediate_size" => 16, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "head_dim" => 8, + "n_shared_experts" => 1, + "n_routed_experts" => 2, + "routed_scaling_factor" => 1.0, + "num_experts_per_tok" => 1, + "first_k_dense_replace" => 1, + "norm_topk_prob" => false, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + "partial_rotary_factor" => 1.0, + "attention_bias" => false, + "use_qk_norm" => false, + "n_group" => 1, + "topk_group" => 1, + "scoring_func" => "sigmoid", + "topk_method" => "noaux_tc", + }) + + model = MlxLm::Models::SolarOpen::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 96], output.shape + + weights = { + "model.embed_tokens.weight" => @mx.zeros([96, 32]).astype(@mx.float32), + "model.layers.1.mlp.experts.0.gate_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.1.mlp.experts.1.gate_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.2.mlp.gate_proj.weight" => @mx.ones([2, 2]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked = sanitized["model.layers.1.mlp.switch_mlp.gate_proj.weight"] + @mx.eval(stacked) + + assert_equal [2, 2, 2], stacked.shape + assert_equal [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], stacked.to_a + refute sanitized.key?("model.layers.1.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.1.gate_proj.weight") + refute sanitized.key?("model.layers.2.mlp.gate_proj.weight") + assert sanitized.key?("model.embed_tokens.weight") + end +end + +class Phase19DenseLanePRegistryTest < Minitest::Test + def 
test_models_registered_and_resolvable + assert MlxLm::Models::REGISTRY.key?("mistral3"), "mistral3 should be registered" + assert MlxLm::Models::REGISTRY.key?("solar_open"), "solar_open should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "mistral3" }) + assert_equal MlxLm::Models::Mistral3::Model, model_class + assert_equal MlxLm::Models::Mistral3::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "solar_open" }) + assert_equal MlxLm::Models::SolarOpen::Model, model_class + assert_equal MlxLm::Models::SolarOpen::ModelArgs, args_class + end +end diff --git a/test/parity/mla_multilinear_quantization_test.rb b/test/parity/mla_multilinear_quantization_test.rb new file mode 100644 index 0000000..7eed925 --- /dev/null +++ b/test/parity/mla_multilinear_quantization_test.rb @@ -0,0 +1,64 @@ +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/mla" + +class Phase14MLATest < Minitest::Test + include ParityTestHelpers + + def setup + @mx = MLX::Core + end + + def test_multilinear_forward_matches_matmul_when_transposed + layer = MlxLm::Models::MLA::MultiLinear.new(16, 8, 3) + x = @mx.random_uniform([2, 3, 4, 16], -1.0, 1.0, @mx.float32) + + expected = @mx.matmul(x, @mx.swapaxes(layer.weight, -1, -2)) + actual = layer.call(x) + + @mx.eval(expected, actual) + assert_equal expected.shape, actual.shape + assert_arrays_close(expected.tolist, actual.tolist, atol: 1e-6, msg: "transpose=true should match matmul") + end + + def test_multilinear_forward_matches_matmul_without_transpose + layer = MlxLm::Models::MLA::MultiLinear.new(16, 8, 3) + x = @mx.random_uniform([2, 3, 4, 8], -1.0, 1.0, @mx.float32) + + expected = @mx.matmul(x, layer.weight) + actual = layer.call(x, transpose: false) + + @mx.eval(expected, actual) + assert_equal expected.shape, actual.shape + assert_arrays_close(expected.tolist, actual.tolist, atol: 1e-6, msg: "transpose=false should match matmul") + end + + def test_to_quantized_returns_quantized_multilinear_and_runs_forward + layer = MlxLm::Models::MLA::MultiLinear.new(64, 32, 2) + qlayer = layer.to_quantized(group_size: 32, bits: 4, mode: "affine") + + assert_instance_of MlxLm::Models::MLA::QuantizedMultiLinear, qlayer + assert_equal 32, qlayer.group_size + assert_equal 4, qlayer.bits + assert_equal "affine", qlayer.mode + + x = @mx.random_uniform([1, 2, 5, 64], -1.0, 1.0, @mx.float32) + y_fp = layer.call(x) + y_q = qlayer.call(x) + @mx.eval(y_fp, y_q) + + assert_equal y_fp.shape, y_q.shape + assert y_q.dtype == y_fp.dtype + + fp_flat = y_fp.tolist.flatten + q_flat = y_q.tolist.flatten + mae = fp_flat.zip(q_flat).map { |a, b| (a - b).abs }.sum / fp_flat.length.to_f + assert mae < 0.25, "quantized output drift too high: mae=#{mae}" + end + + def test_to_quantized_rejects_quantized_input + layer = MlxLm::Models::MLA::MultiLinear.new(16, 8, 2) + assert_raises(ArgumentError) do + layer.to_quantized(group_size: 32, bits: 4, quantize_input: true) + end + end +end diff --git a/test/parity/phase1_test.rb b/test/parity/model_args_weights_config_test.rb similarity index 100% rename from test/parity/phase1_test.rb rename to test/parity/model_args_weights_config_test.rb diff --git a/test/parity/nemotron_h_plamo2_models_test.rb b/test/parity/nemotron_h_plamo2_models_test.rb new file mode 100644 index 0000000..ef37b64 --- /dev/null +++ b/test/parity/nemotron_h_plamo2_models_test.rb @@ -0,0 +1,172 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative 
"../../lib/mlx_lm/models/nemotron_h" +require_relative "../../lib/mlx_lm/models/plamo2" + +class Phase28HybridLaneATNemotronHTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_nemotron_h_wrapper_construct_forward_shape_sanitize_mapping_and_cache + args = MlxLm::Models::NemotronH::ModelArgs.from_dict({ + "model_type" => "nemotron_h", + "vocab_size" => 83, + "hidden_size" => 32, + "intermediate_size" => 64, + "num_hidden_layers" => 2, + "max_position_embeddings" => 128, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "attention_bias" => false, + "mamba_num_heads" => 4, + "mamba_head_dim" => 8, + "conv_kernel" => 3, + "layer_norm_epsilon" => 1e-5, + "hybrid_override_pattern" => ["M", "*"], + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::NemotronH::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 83], output.shape + + expert0 = @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32) + expert1 = @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32) + conv_weight = @mx.array((0...24).to_a, dtype: @mx.float32).reshape([4, 1, 6]) + weights = { + "backbone.layers.0.mixer.experts.0.up_proj.weight" => expert0, + "backbone.layers.0.mixer.experts.1.up_proj.weight" => expert1, + "backbone.layers.0.mixer.in_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "backbone.layers.0.mixer.conv1d.weight" => conv_weight, + "backbone.layers.0.norm.weight" => @mx.ones([32]).astype(@mx.float32), + "backbone.norm_f.weight" => @mx.ones([32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + stacked_experts = sanitized["model.layers.0.mlp_block.switch_mlp.up_proj.weight"] + mapped_in_proj = sanitized["model.layers.0.temporal_block.linear_x.weight"] + mapped_conv = sanitized["model.layers.0.temporal_block.conv_1d.weight"] + mapped_layer_norm = sanitized["model.layers.0.temporal_pre_norm.weight"] + mapped_final_norm = sanitized["model.final_norm.weight"] + @mx.eval(stacked_experts, mapped_in_proj, mapped_conv, mapped_layer_norm, mapped_final_norm) + + refute sanitized.key?("backbone.layers.0.mixer.experts.0.up_proj.weight") + refute sanitized.key?("backbone.layers.0.mixer.experts.1.up_proj.weight") + refute sanitized.key?("backbone.layers.0.mixer.in_proj.weight") + refute sanitized.key?("backbone.layers.0.mixer.conv1d.weight") + refute sanitized.key?("backbone.layers.0.norm.weight") + refute sanitized.key?("backbone.norm_f.weight") + assert sanitized.key?("model.layers.0.mlp_block.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.temporal_block.linear_x.weight") + assert sanitized.key?("model.layers.0.temporal_block.conv_1d.weight") + assert sanitized.key?("model.layers.0.temporal_pre_norm.weight") + assert sanitized.key?("model.final_norm.weight") + assert_equal [2, 2, 2], stacked_experts.shape + assert_equal expert0.to_a, stacked_experts[0].to_a + assert_equal expert1.to_a, stacked_experts[1].to_a + assert_equal [4, 6, 1], mapped_conv.shape + + cache = model.make_cache + assert_equal 2, cache.length + end +end + +class Phase28HybridLaneATPlamo2Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_plamo2_wrapper_construct_forward_shape_sanitize_mapping_and_cache + args = MlxLm::Models::Plamo2::ModelArgs.from_dict({ + "model_type" => "plamo2", + "hidden_size" => 32, + "num_hidden_layers" => 3, + "rms_norm_eps" => 1e-5, + 
"tie_word_embeddings" => true, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "hidden_size_per_head" => 8, + "max_position_embeddings" => 128, + "attention_window_size" => 16, + "mamba_d_conv" => 3, + "mamba_step" => 2, + "mamba_enabled" => true, + "intermediate_size" => 64, + "vocab_size" => 71, + }) + + model = MlxLm::Models::Plamo2::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 71], output.shape + + gate_up = @mx.array((0...16).to_a, dtype: @mx.float32).reshape([8, 2]) + expected_gate, expected_up = @mx.split(gate_up, [4], 0) + conv_weight = @mx.array((0...24).to_a, dtype: @mx.float32).reshape([4, 1, 6]) + weights = { + "model.layers.layers.0.mlp.gate_up_proj.weight" => gate_up, + "model.layers.layers.0.mlp.down_proj.weight" => @mx.zeros([2, 4]).astype(@mx.float32), + "model.layers.layers.0.mixer.in_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "model.layers.layers.0.mixer.conv1d.weight" => conv_weight, + "model.layers.layers.0.pre_mixer_norm.weight" => @mx.ones([32]).astype(@mx.float32), + "model.layers.layers.0.pre_mlp_norm.weight" => @mx.ones([32]).astype(@mx.float32), + "model.norm.weight" => @mx.ones([32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + mapped_gate = sanitized["model.layers.0.mlp_block.gate_proj.weight"] + mapped_up = sanitized["model.layers.0.mlp_block.up_proj.weight"] + mapped_down = sanitized["model.layers.0.mlp_block.down_proj.weight"] + mapped_in_proj = sanitized["model.layers.0.temporal_block.linear_x.weight"] + mapped_conv = sanitized["model.layers.0.temporal_block.conv_1d.weight"] + mapped_pre_mixer_norm = sanitized["model.layers.0.temporal_pre_norm.weight"] + mapped_pre_mlp_norm = sanitized["model.layers.0.channel_pre_norm.weight"] + mapped_final_norm = sanitized["model.final_norm.weight"] + @mx.eval(mapped_gate, mapped_up, mapped_down, mapped_in_proj, mapped_conv, mapped_pre_mixer_norm, mapped_pre_mlp_norm, mapped_final_norm) + + refute sanitized.key?("model.layers.layers.0.mlp.gate_up_proj.weight") + refute sanitized.key?("model.layers.layers.0.mixer.in_proj.weight") + refute sanitized.key?("model.layers.layers.0.mixer.conv1d.weight") + refute sanitized.key?("model.layers.layers.0.pre_mixer_norm.weight") + refute sanitized.key?("model.layers.layers.0.pre_mlp_norm.weight") + refute sanitized.key?("model.norm.weight") + assert sanitized.key?("model.layers.0.mlp_block.gate_proj.weight") + assert sanitized.key?("model.layers.0.mlp_block.up_proj.weight") + assert sanitized.key?("model.layers.0.mlp_block.down_proj.weight") + assert sanitized.key?("model.layers.0.temporal_block.linear_x.weight") + assert sanitized.key?("model.layers.0.temporal_block.conv_1d.weight") + assert sanitized.key?("model.layers.0.temporal_pre_norm.weight") + assert sanitized.key?("model.layers.0.channel_pre_norm.weight") + assert sanitized.key?("model.final_norm.weight") + assert_equal expected_gate.to_a, mapped_gate.to_a + assert_equal expected_up.to_a, mapped_up.to_a + assert_equal [4, 6, 1], mapped_conv.shape + + cache = model.make_cache + assert_equal 3, cache.length + end +end + +class Phase28HybridLaneATRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("nemotron_h"), "nemotron_h should be registered" + assert MlxLm::Models::REGISTRY.key?("plamo2"), "plamo2 should be registered" + + model_class, 
args_class = MlxLm::Models.get_classes({ "model_type" => "nemotron_h" }) + assert_equal MlxLm::Models::NemotronH::Model, model_class + assert_equal MlxLm::Models::NemotronH::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "plamo2" }) + assert_equal MlxLm::Models::Plamo2::Model, model_class + assert_equal MlxLm::Models::Plamo2::ModelArgs, args_class + end +end diff --git a/test/parity/phase11_test.rb b/test/parity/olmo2_gpt_neox_mixtral_deepseek_internlm2_test.rb similarity index 95% rename from test/parity/phase11_test.rb rename to test/parity/olmo2_gpt_neox_mixtral_deepseek_internlm2_test.rb index e86f4bc..2bb2259 100644 --- a/test/parity/phase11_test.rb +++ b/test/parity/olmo2_gpt_neox_mixtral_deepseek_internlm2_test.rb @@ -26,7 +26,7 @@ def test_olmo2_forward model = MlxLm::Models::OLMo2::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - input = @mx.array([[1, 2, 3]]).astype(@mx.int32) + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(input) @mx.eval(output) assert_equal [1, 3, 128], output.shape @@ -81,7 +81,7 @@ def test_gpt_neox_forward model = MlxLm::Models::GPTNeoX::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - input = @mx.array([[1, 2, 3]]).astype(@mx.int32) + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(input) @mx.eval(output) assert_equal [1, 3, 128], output.shape @@ -102,7 +102,7 @@ def test_gpt_neox_parallel_residual model = MlxLm::Models::GPTNeoX::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - input = @mx.array([[1, 2, 3]]).astype(@mx.int32) + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(input) @mx.eval(output) assert_equal [1, 3, 128], output.shape @@ -138,7 +138,7 @@ def test_mixtral_forward model = MlxLm::Models::Mixtral::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - input = @mx.array([[1, 2, 3]]).astype(@mx.int32) + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(input) @mx.eval(output) assert_equal [1, 3, 128], output.shape @@ -192,7 +192,7 @@ def test_deepseek_forward_dense model = MlxLm::Models::DeepSeek::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - input = @mx.array([[1, 2, 3]]).astype(@mx.int32) + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(input) @mx.eval(output) assert_equal [1, 3, 128], output.shape @@ -218,7 +218,7 @@ def test_deepseek_forward_moe model = MlxLm::Models::DeepSeek::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - input = @mx.array([[1, 2, 3]]).astype(@mx.int32) + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(input) @mx.eval(output) assert_equal [1, 3, 128], output.shape @@ -253,7 +253,7 @@ def test_internlm2_forward model = MlxLm::Models::InternLM2::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - input = @mx.array([[1, 2, 3]]).astype(@mx.int32) + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(input) @mx.eval(output) assert_equal [1, 3, 128], output.shape diff --git a/test/parity/olmo3_gpt2_models_test.rb b/test/parity/olmo3_gpt2_models_test.rb new file mode 100644 index 0000000..9e92976 --- /dev/null +++ b/test/parity/olmo3_gpt2_models_test.rb @@ -0,0 +1,92 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift 
File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/olmo3" +require_relative "../../lib/mlx_lm/models/gpt2" + +class Phase18DenseLaneIRegistryTest < Minitest::Test + def test_registry_entries_for_lane_i_models + assert MlxLm::Models::REGISTRY.key?("olmo3"), "olmo3 should be registered" + assert MlxLm::Models::REGISTRY.key?("gpt2"), "gpt2 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "olmo3" }) + assert_equal MlxLm::Models::OLMo3::Model, model_class + assert_equal MlxLm::Models::OLMo3::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "gpt2" }) + assert_equal MlxLm::Models::GPT2::Model, model_class + assert_equal MlxLm::Models::GPT2::ModelArgs, args_class + end +end + +class Phase18DenseLaneIOLMo3Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_olmo3_construct_forward_shape_and_cache_types + args = MlxLm::Models::OLMo3::ModelArgs.from_dict({ + "model_type" => "olmo3", + "hidden_size" => 48, + "num_hidden_layers" => 4, + "intermediate_size" => 96, + "num_attention_heads" => 4, + "rms_norm_eps" => 1e-5, + "vocab_size" => 128, + "max_position_embeddings" => 256, + "sliding_window" => 4, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + }) + + model = MlxLm::Models::OLMo3::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3, 4]], @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 128], output.shape + + caches = model.make_cache + assert_equal 4, caches.length + assert_instance_of MlxLm::RotatingKVCache, caches[0] + assert_instance_of MlxLm::RotatingKVCache, caches[1] + assert_instance_of MlxLm::RotatingKVCache, caches[2] + assert_instance_of MlxLm::KVCache, caches[3] + end +end + +class Phase18DenseLaneIGPT2Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_gpt2_construct_and_forward_shape + args = MlxLm::Models::GPT2::ModelArgs.from_dict({ + "model_type" => "gpt2", + "n_ctx" => 16, + "n_embd" => 32, + "n_head" => 4, + "n_layer" => 2, + "n_positions" => 16, + "layer_norm_epsilon" => 1e-5, + "vocab_size" => 96, + }) + + model = MlxLm::Models::GPT2::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3]], @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 96], output.shape + end +end diff --git a/test/parity/olmo_seed_oss_models_test.rb b/test/parity/olmo_seed_oss_models_test.rb new file mode 100644 index 0000000..ddae11f --- /dev/null +++ b/test/parity/olmo_seed_oss_models_test.rb @@ -0,0 +1,69 @@ +require_relative "../test_helper" + +class Phase16DenseLaneDTest < Minitest::Test + include ParityTestHelpers + + def setup + @mx = MLX::Core + end + + def test_olmo_construct_and_forward_shape + args = MlxLm::Models::OLMo::ModelArgs.from_dict({ + "model_type" => "olmo", + "d_model" => 64, + "n_layers" => 2, + "mlp_hidden_size" => 128, + "n_heads" => 2, + "vocab_size" => 128, + "embedding_size" => 128, + "weight_tying" => false, + }) + + model = MlxLm::Models::OLMo::Model.new(args) + 
@mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3]], @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 3, 128], output.shape + end + + def test_seed_oss_construct_and_forward_shape + args = MlxLm::Models::SeedOSS::ModelArgs.from_dict({ + "model_type" => "seed_oss", + "hidden_size" => 64, + "num_hidden_layers" => 2, + "intermediate_size" => 128, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "head_dim" => 32, + "rms_norm_eps" => 1e-6, + "vocab_size" => 128, + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::SeedOSS::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3]], @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 3, 128], output.shape + end + + def test_olmo_registered + assert MlxLm::Models::REGISTRY.key?("olmo"), "olmo should be registered" + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "olmo" }) + assert_equal MlxLm::Models::OLMo::Model, model_class + assert_equal MlxLm::Models::OLMo::ModelArgs, args_class + end + + def test_seed_oss_registered + assert MlxLm::Models::REGISTRY.key?("seed_oss"), "seed_oss should be registered" + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "seed_oss" }) + assert_equal MlxLm::Models::SeedOSS::Model, model_class + assert_equal MlxLm::Models::SeedOSS::ModelArgs, args_class + end +end diff --git a/test/parity/onnx_export_test.rb b/test/parity/onnx_export_test.rb deleted file mode 100644 index c42892a..0000000 --- a/test/parity/onnx_export_test.rb +++ /dev/null @@ -1,347 +0,0 @@ -# frozen_string_literal: true - -require_relative "../test_helper" -require "tmpdir" -require "json" -require "open3" - -# ONNX Export Tests -# -# For every registered model architecture, instantiate a tiny model, trace a -# forward pass through MLX::ONNX.export_onnx, and report whether export -# succeeds. When export fails, run the compatibility report to identify the -# unsupported ops. -# -# Each model runs in an isolated subprocess because MoE models (mixtral, -# deepseek) segfault during ONNX tracing due to data-dependent control flow -# (tolist + per-token expert routing). - -class OnnxExportTest < Minitest::Test - include ParityTestHelpers - - # ── Tiny model configs ──────────────────────────────────────────────── - # Each config uses the smallest possible dimensions so instantiation and - # tracing finish quickly without consuming real memory. 
- - TINY_CONFIGS = { - "llama" => { - "model_type" => "llama", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "intermediate_size" => 128, - "vocab_size" => 128, - "tie_word_embeddings" => true, - }, - "gemma" => { - "model_type" => "gemma", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "intermediate_size" => 128, - "vocab_size" => 128, - "head_dim" => 32, - "tie_word_embeddings" => true, - }, - "gemma2" => { - "model_type" => "gemma2", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "intermediate_size" => 128, - "vocab_size" => 128, - "head_dim" => 32, - "query_pre_attn_scalar" => 32.0, - }, - "qwen2" => { - "model_type" => "qwen2", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "intermediate_size" => 128, - "vocab_size" => 128, - "tie_word_embeddings" => true, - }, - "phi3" => { - "model_type" => "phi3", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "intermediate_size" => 128, - "vocab_size" => 128, - "tie_word_embeddings" => true, - }, - "starcoder2" => { - "model_type" => "starcoder2", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "intermediate_size" => 128, - "vocab_size" => 128, - "tie_word_embeddings" => true, - }, - "stablelm" => { - "model_type" => "stablelm", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "intermediate_size" => 128, - "vocab_size" => 128, - }, - "cohere" => { - "model_type" => "cohere", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "intermediate_size" => 128, - "vocab_size" => 128, - }, - "olmo2" => { - "model_type" => "olmo2", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "intermediate_size" => 128, - "vocab_size" => 128, - "tie_word_embeddings" => true, - }, - "gpt_neox" => { - "model_type" => "gpt_neox", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "vocab_size" => 128, - "intermediate_size" => 256, - }, - "mixtral" => { - "model_type" => "mixtral", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "intermediate_size" => 128, - "vocab_size" => 128, - "num_local_experts" => 2, - "num_experts_per_tok" => 1, - "tie_word_embeddings" => true, - }, - "deepseek" => { - "model_type" => "deepseek", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "intermediate_size" => 128, - "moe_intermediate_size" => 64, - "vocab_size" => 128, - "n_routed_experts" => 2, - "num_experts_per_tok" => 1, - "n_shared_experts" => 1, - "moe_layer_freq" => 1, - "first_k_dense_replace" => 1, - }, - "internlm2" => { - "model_type" => "internlm2", - "hidden_size" => 64, - "num_hidden_layers" => 2, - "num_attention_heads" => 2, - "num_key_value_heads" => 2, - "intermediate_size" => 128, - "vocab_size" => 128, - "bias" => false, - "tie_word_embeddings" => true, - }, - }.freeze - - # ── Subprocess runner ───────────────────────────────────────────────── - # Runs a single model's ONNX export + compat report in an isolated 
Ruby - # process. Returns JSON with results. Catches segfaults gracefully. - - SUBPROCESS_SCRIPT = <<~'RUBY' - require "json" - require "tmpdir" - - $LOAD_PATH.unshift File.expand_path("lib", __dir__) - $LOAD_PATH.unshift File.expand_path("mlx-ruby/lib", __dir__) - require "mlx" - require "mlx_lm" - - config = JSON.parse(ARGV[0]) - model_type = config["model_type"] - result = { "model_type" => model_type } - - begin - mx = MLX::Core - model_class, args_class = MlxLm::Models.get_classes(config) - args = args_class.from_dict(config) - model = model_class.new(args) - - params = MLX::Utils.tree_flatten(model.parameters).map { |_, v| v } - mx.eval(*params) unless params.empty? - - input = mx.array([[1, 2, 3]]).astype(mx.int32) - fun = ->(x) { model.call(x) } - - # Run compatibility report first (does not require full lowering) - begin - report = MLX::ONNX.export_onnx_compatibility_report(fun, input) - result["compat_report"] = { - "total_nodes" => report["total_nodes"], - "supported_nodes" => report["supported_nodes"], - "unsupported_nodes"=> report["unsupported_nodes"], - "unsupported_ops" => report["unsupported_ops"], - "ready" => report["ready_for_stub_conversion"], - } - rescue => e - result["compat_error"] = "#{e.class}: #{e.message}" - end - - # Attempt full ONNX export - begin - Dir.mktmpdir do |dir| - path = File.join(dir, "#{model_type}.onnx") - MLX::ONNX.export_onnx(path, fun, input) - result["export"] = "success" - result["onnx_size"] = File.size(path) - end - rescue NotImplementedError, RuntimeError => e - result["export"] = "failed" - result["export_error"] = e.message - end - rescue => e - result["fatal"] = "#{e.class}: #{e.message}" - end - - puts JSON.generate(result) - RUBY - - def run_model_in_subprocess(model_type) - config_json = JSON.generate(TINY_CONFIGS.fetch(model_type)) - project_root = File.expand_path("../..", __dir__) - - out, err, status = Open3.capture3( - "ruby", "-e", SUBPROCESS_SCRIPT, config_json, - chdir: project_root - ) - - if status.signaled? - sig = status.termsig - signal_name = Signal.signame(sig) rescue sig.to_s - return { - "model_type" => model_type, - "export" => "crashed", - "crash_signal" => signal_name, - "stderr" => err.lines.first(5).join, - } - end - - unless status.success? 
- return { - "model_type" => model_type, - "export" => "process_error", - "exit_code" => status.exitstatus, - "stderr" => err.lines.first(10).join, - } - end - - JSON.parse(out) - rescue JSON::ParserError - { - "model_type" => model_type, - "export" => "parse_error", - "stdout" => out.to_s[0, 500], - "stderr" => err.to_s[0, 500], - } - end - - # ── Per-model export tests ──────────────────────────────────────────── - - TINY_CONFIGS.each_key do |model_type| - define_method(:"test_onnx_export_#{model_type}") do - result = run_model_in_subprocess(model_type) - - case result["export"] - when "success" - assert true, "#{model_type}: ONNX export succeeded (#{result['onnx_size']} bytes)" - report = result["compat_report"] - if report - puts "\n [ONNX] #{model_type}: PASS — #{report['supported_nodes']}/#{report['total_nodes']} nodes, #{result['onnx_size']} bytes" - end - - when "failed" - report = result["compat_report"] - msg = "#{model_type}: ONNX export failed — #{result['export_error']}" - if report - unsupported = report["unsupported_ops"] || [] - msg += "\n Nodes: #{report['supported_nodes']}/#{report['total_nodes']} supported" - msg += "\n Missing ops: #{unsupported.join(', ')}" - end - flunk(msg) - - when "crashed" - report = result["compat_report"] - msg = "#{model_type}: ONNX tracing crashed with signal #{result['crash_signal']}" - if report - unsupported = report["unsupported_ops"] || [] - msg += "\n Compat report (pre-crash): #{report['supported_nodes']}/#{report['total_nodes']} nodes" - msg += "\n Missing ops: #{unsupported.empty? ? 'none' : unsupported.join(', ')}" - end - msg += "\n (MoE models crash because tolist forces data-dependent control flow during tracing)" - flunk(msg) - - else - flunk("#{model_type}: unexpected result — #{result.inspect}") - end - end - end - - # ── Compatibility report tests (always run) ─────────────────────────── - - TINY_CONFIGS.each_key do |model_type| - define_method(:"test_onnx_compat_report_#{model_type}") do - result = run_model_in_subprocess(model_type) - - if result["compat_error"] - skip "#{model_type}: compat report unavailable — #{result['compat_error']}" - end - - if result["crash_signal"] - # Crashed before we could get a report - if result["compat_report"] - report = result["compat_report"] - assert_kind_of Integer, report["total_nodes"] - unsupported = report["unsupported_ops"] || [] - pct = report["total_nodes"] > 0 ? (report["supported_nodes"].to_f / report["total_nodes"] * 100).round(1) : 0 - puts "\n [ONNX] #{model_type}: #{report['supported_nodes']}/#{report['total_nodes']} nodes (#{pct}%) — missing: #{unsupported.empty? ? 'none' : unsupported.join(', ')} (CRASH during export)" - else - skip "#{model_type}: process crashed (signal #{result['crash_signal']}) before compat report" - end - return - end - - report = result["compat_report"] - skip "#{model_type}: no compat report in result" unless report - - assert_kind_of Integer, report["total_nodes"] - assert_kind_of Integer, report["supported_nodes"] - assert_kind_of Integer, report["unsupported_nodes"] - - total = report["total_nodes"] - supported = report["supported_nodes"] - unsupported_ops = report["unsupported_ops"] || [] - pct = total > 0 ? (supported.to_f / total * 100).round(1) : 0 - status = result["export"] == "success" ? "PASS" : "FAIL" - puts "\n [ONNX] #{model_type}: #{status} — #{supported}/#{total} nodes (#{pct}%) — missing: #{unsupported_ops.empty? ? 
'none' : unsupported_ops.join(', ')}" - end - end -end diff --git a/test/parity/phi3small_dots1_models_test.rb b/test/parity/phi3small_dots1_models_test.rb new file mode 100644 index 0000000..22d08ce --- /dev/null +++ b/test/parity/phi3small_dots1_models_test.rb @@ -0,0 +1,153 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/phi3small" +require_relative "../../lib/mlx_lm/models/dots1" + +class Phase24DenseLaneAGPhi3smallTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_phi3small_construct_forward_shape_and_sanitize_inv_freq_cleanup + args = MlxLm::Models::Phi3small::ModelArgs.from_dict({ + "model_type" => "phi3small", + "hidden_size" => 32, + "dense_attention_every_n_layers" => 1, + "ff_intermediate_size" => 64, + "gegelu_limit" => 16.0, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "layer_norm_epsilon" => 1e-5, + "vocab_size" => 97, + "num_key_value_heads" => 2, + "mup_attn_multiplier" => 1.0, + "mup_use_scaling" => true, + "mup_embedding_multiplier" => 1.0, + "mup_width_multiplier" => 1.0, + "rope_embedding_base" => 10_000.0, + "rope_position_scale" => 1.0, + "blocksparse_block_size" => 64, + "blocksparse_num_local_blocks" => 4, + "blocksparse_vert_stride" => 2, + }) + + model = MlxLm::Models::Phi3small::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 97], output.shape + + weights = { + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([1]).astype(@mx.float32), + "model.layers.0.self_attn.position_embeddings.inv_freq" => @mx.zeros([1]).astype(@mx.float32), + "model.layers.0.self_attn.query_key_value.weight" => @mx.zeros([64, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + refute sanitized.key?("model.layers.0.self_attn.position_embeddings.inv_freq") + assert sanitized.key?("model.layers.0.self_attn.query_key_value.weight") + end +end + +class Phase24DenseLaneAGDots1Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_dots1_construct_forward_shape_and_sanitize_stacks_switch_glu_experts + args = MlxLm::Models::Dots1::ModelArgs.from_dict({ + "model_type" => "dots1", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "rms_norm_eps" => 1e-5, + "vocab_size" => 101, + "max_position_embeddings" => 256, + "num_key_value_heads" => 2, + "first_k_dense_replace" => 1, + "moe_intermediate_size" => 48, + "n_routed_experts" => 2, + "n_shared_experts" => 1, + "norm_topk_prob" => true, + "num_experts_per_tok" => 1, + "rope_theta" => 10_000.0, + "routed_scaling_factor" => 1.0, + "head_dim" => 8, + "scoring_func" => "noaux_tc", + "n_group" => 1, + "topk_group" => 1, + "attention_bias" => false, + "mlp_bias" => false, + "tie_word_embeddings" => false, + }) + + model = MlxLm::Models::Dots1::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = 
@mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 101], output.shape + + weights = { + "model.layers.1.mlp.experts.0.gate_proj.weight" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.1.mlp.experts.1.gate_proj.weight" => @mx.array((24...48).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.1.mlp.experts.0.up_proj.weight" => @mx.array((48...72).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.1.mlp.experts.1.up_proj.weight" => @mx.array((72...96).to_a, dtype: @mx.float32).reshape([6, 4]), + "model.layers.1.mlp.experts.0.down_proj.weight" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([4, 6]), + "model.layers.1.mlp.experts.1.down_proj.weight" => @mx.array((24...48).to_a, dtype: @mx.float32).reshape([4, 6]), + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([1]).astype(@mx.float32), + "model.layers.0.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + + stacked_gate = sanitized["model.layers.1.mlp.experts.gate_proj.weight"] + stacked_up = sanitized["model.layers.1.mlp.experts.up_proj.weight"] + stacked_down = sanitized["model.layers.1.mlp.experts.down_proj.weight"] + @mx.eval(stacked_gate, stacked_up, stacked_down) + + refute sanitized.key?("model.layers.1.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.1.gate_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.0.down_proj.weight") + refute sanitized.key?("model.layers.1.mlp.experts.1.down_proj.weight") + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + assert sanitized.key?("model.layers.1.mlp.experts.gate_proj.weight") + assert sanitized.key?("model.layers.1.mlp.experts.up_proj.weight") + assert sanitized.key?("model.layers.1.mlp.experts.down_proj.weight") + assert sanitized.key?("model.layers.0.self_attn.q_proj.weight") + assert_equal [2, 6, 4], stacked_gate.shape + assert_equal [2, 6, 4], stacked_up.shape + assert_equal [2, 4, 6], stacked_down.shape + end +end + +class Phase24DenseLaneAGRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("phi3small"), "phi3small should be registered" + assert MlxLm::Models::REGISTRY.key?("dots1"), "dots1 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "phi3small" }) + assert_equal MlxLm::Models::Phi3small::Model, model_class + assert_equal MlxLm::Models::Phi3small::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "dots1" }) + assert_equal MlxLm::Models::Dots1::Model, model_class + assert_equal MlxLm::Models::Dots1::ModelArgs, args_class + end +end diff --git a/test/parity/phi_exaone_models_test.rb b/test/parity/phi_exaone_models_test.rb new file mode 100644 index 0000000..35f78b7 --- /dev/null +++ b/test/parity/phi_exaone_models_test.rb @@ -0,0 +1,93 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative 
"../../lib/mlx_lm/models/phi" +require_relative "../../lib/mlx_lm/models/exaone" + +class Phase16DenseLaneBRegistryTest < Minitest::Test + def test_models_registered + assert MlxLm::Models::REGISTRY.key?("phi"), "phi should be registered" + assert MlxLm::Models::REGISTRY.key?("exaone"), "exaone should be registered" + end + + def test_get_classes_resolves_lane_b_models + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "phi" }) + assert_equal MlxLm::Models::Phi::Model, model_class + assert_equal MlxLm::Models::Phi::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "exaone" }) + assert_equal MlxLm::Models::Exaone::Model, model_class + assert_equal MlxLm::Models::Exaone::ModelArgs, args_class + end +end + +class Phase16DenseLaneBPhiTest < Minitest::Test + def setup + @mx = MLX::Core + @args = MlxLm::Models::Phi::ModelArgs.from_dict({ + "model_type" => "phi", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "intermediate_size" => 64, + "vocab_size" => 100, + "layer_norm_eps" => 1e-5, + "partial_rotary_factor" => 0.5, + "rope_theta" => 10_000.0, + }) + end + + def test_phi_model_instantiates + model = MlxLm::Models::Phi::Model.new(@args) + assert_instance_of MlxLm::Models::Phi::Model, model + assert_equal 2, model.layers.length + end + + def test_phi_forward_shape + model = MlxLm::Models::Phi::Model.new(@args) + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + assert_equal [1, 3, 100], output.shape + end +end + +class Phase16DenseLaneBExaoneTest < Minitest::Test + def setup + @mx = MLX::Core + @args = MlxLm::Models::Exaone::ModelArgs.from_dict({ + "model_type" => "exaone", + "hidden_size" => 32, + "num_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "vocab_size" => 100, + "rope_theta" => 10_000.0, + "layer_norm_epsilon" => 1e-5, + "tie_word_embeddings" => true, + "attention_bias" => false, + "mlp_bias" => false, + }) + end + + def test_exaone_model_instantiates + model = MlxLm::Models::Exaone::Model.new(@args) + assert_instance_of MlxLm::Models::Exaone::Model, model + assert_equal 2, model.layers.length + end + + def test_exaone_forward_shape + model = MlxLm::Models::Exaone::Model.new(@args) + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + assert_equal [1, 3, 100], output.shape + end +end diff --git a/test/parity/phixtral_minicpm3_models_test.rb b/test/parity/phixtral_minicpm3_models_test.rb new file mode 100644 index 0000000..f06b0f2 --- /dev/null +++ b/test/parity/phixtral_minicpm3_models_test.rb @@ -0,0 +1,151 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/phixtral" +require_relative "../../lib/mlx_lm/models/minicpm3" + +class Phase21DenseLaneXPhixtralTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_phixtral_construct_forward_shape_and_sanitize_stacks_experts + args = MlxLm::Models::Phixtral::ModelArgs.from_dict({ + "model_type" => "phixtral", + "num_vocab" => 79, + "model_dim" => 32, + 
"num_heads" => 4, + "num_layers" => 2, + "rotary_dim" => 4, + "num_experts_per_tok" => 1, + "num_local_experts" => 2, + }) + + model = MlxLm::Models::Phixtral::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 79], output.shape + + weights = { + "transformer.h.0.moe.mlp.0.fc1.weight" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([6, 4]), + "transformer.h.0.moe.mlp.1.fc1.weight" => @mx.array((24...48).to_a, dtype: @mx.float32).reshape([6, 4]), + "transformer.h.0.moe.mlp.0.fc2.weight" => @mx.array((0...24).to_a, dtype: @mx.float32).reshape([4, 6]), + "transformer.h.0.moe.mlp.1.fc2.weight" => @mx.array((24...48).to_a, dtype: @mx.float32).reshape([4, 6]), + "transformer.h.0.moe.mlp.0.fc1.bias" => @mx.array((0...6).to_a, dtype: @mx.float32), + "transformer.h.0.moe.mlp.1.fc1.bias" => @mx.array((6...12).to_a, dtype: @mx.float32), + "transformer.h.0.moe.mlp.0.fc2.bias" => @mx.array((0...4).to_a, dtype: @mx.float32), + "transformer.h.0.moe.mlp.1.fc2.bias" => @mx.array((4...8).to_a, dtype: @mx.float32), + "transformer.h.1.mixer.wqkv.weight" => @mx.zeros([96, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + fc1_stacked = sanitized["transformer.h.0.moe.switch_mlp.fc1.weight"] + fc2_stacked = sanitized["transformer.h.0.moe.switch_mlp.fc2.weight"] + @mx.eval(fc1_stacked, fc2_stacked) + + refute sanitized.key?("transformer.h.0.moe.mlp.0.fc1.weight") + refute sanitized.key?("transformer.h.0.moe.mlp.1.fc1.weight") + refute sanitized.key?("transformer.h.0.moe.mlp.0.fc2.weight") + refute sanitized.key?("transformer.h.0.moe.mlp.1.fc2.weight") + refute sanitized.key?("transformer.h.0.moe.mlp.0.fc1.bias") + refute sanitized.key?("transformer.h.0.moe.mlp.1.fc1.bias") + refute sanitized.key?("transformer.h.0.moe.mlp.0.fc2.bias") + refute sanitized.key?("transformer.h.0.moe.mlp.1.fc2.bias") + assert sanitized.key?("transformer.h.0.moe.switch_mlp.fc1.weight") + assert sanitized.key?("transformer.h.0.moe.switch_mlp.fc2.weight") + assert sanitized.key?("transformer.h.0.moe.switch_mlp.fc1.bias") + assert sanitized.key?("transformer.h.0.moe.switch_mlp.fc2.bias") + assert sanitized.key?("transformer.h.1.mixer.wqkv.weight") + assert_equal [2, 6, 4], fc1_stacked.shape + assert_equal [2, 4, 6], fc2_stacked.shape + end +end + +class Phase21DenseLaneXMiniCPM3Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_minicpm3_construct_forward_shape_and_sanitize + args_hash = { + "model_type" => "minicpm3", + "hidden_size" => 32, + "dim_model_base" => 16, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "rms_norm_eps" => 1e-5, + "vocab_size" => 89, + "q_lora_rank" => 8, + "qk_nope_head_dim" => 4, + "qk_rope_head_dim" => 4, + "kv_lora_rank" => 8, + "scale_depth" => 1.0, + "scale_emb" => 1.25, + "max_position_embeddings" => 256, + "attention_bias" => false, + "rope_theta" => 10_000.0, + "rope_scaling" => { + "original_max_position_embeddings" => 128, + "short_factor" => 1.0, + "long_factor" => 1.0, + }, + "tie_word_embeddings" => false, + } + + args = MlxLm::Models::MiniCPM3::ModelArgs.from_dict(args_hash) + model = MlxLm::Models::MiniCPM3::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + 
assert_equal [1, 3, 89], output.shape + + weights = { + "model.embed_tokens.weight" => @mx.zeros([89, 32]).astype(@mx.float32), + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([8]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + assert sanitized.key?("lm_head.weight") + assert_equal [89, 32], sanitized["lm_head.weight"].shape + + tied_args = MlxLm::Models::MiniCPM3::ModelArgs.from_dict(args_hash.merge("tie_word_embeddings" => true)) + tied_model = MlxLm::Models::MiniCPM3::Model.new(tied_args) + tied_weights = { + "model.embed_tokens.weight" => @mx.zeros([89, 32]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([89, 32]).astype(@mx.float32), + } + tied_sanitized = tied_model.sanitize(tied_weights) + refute tied_sanitized.key?("lm_head.weight") + end +end + +class Phase21DenseLaneXRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("phixtral"), "phixtral should be registered" + assert MlxLm::Models::REGISTRY.key?("minicpm3"), "minicpm3 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "phixtral" }) + assert_equal MlxLm::Models::Phixtral::Model, model_class + assert_equal MlxLm::Models::Phixtral::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "minicpm3" }) + assert_equal MlxLm::Models::MiniCPM3::Model, model_class + assert_equal MlxLm::Models::MiniCPM3::ModelArgs, args_class + end +end diff --git a/test/parity/pixtral_qwen2_vl_models_test.rb b/test/parity/pixtral_qwen2_vl_models_test.rb new file mode 100644 index 0000000..3be2eef --- /dev/null +++ b/test/parity/pixtral_qwen2_vl_models_test.rb @@ -0,0 +1,118 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/llama" +require_relative "../../lib/mlx_lm/models/qwen2" +require_relative "../../lib/mlx_lm/models/pixtral" +require_relative "../../lib/mlx_lm/models/qwen2_vl" + +class Phase19DenseLaneMPixtralTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_pixtral_construct_forward_shape_and_sanitize + args = MlxLm::Models::Pixtral::ModelArgs.from_dict({ + "model_type" => "pixtral", + "text_config" => { + "hidden_size" => 64, + "num_hidden_layers" => 2, + "intermediate_size" => 128, + "vocab_size" => 97, + "rms_norm_eps" => 1e-5, + "max_position_embeddings" => 128, + "rope_theta" => 10_000.0, + }, + }) + + model = MlxLm::Models::Pixtral::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + + assert_equal [1, 3, 97], output.shape + assert_equal false, args.text_config["tie_word_embeddings"] + assert_equal 32, args.text_config["num_attention_heads"] + + weights = { + "language_model.model.embed_tokens.weight" => @mx.zeros([97, 64]).astype(@mx.float32), + "vision_tower.encoder.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + "multi_modal_projector.proj.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("vision_tower.encoder.weight") + refute sanitized.key?("multi_modal_projector.proj.weight") + assert 
sanitized.key?("language_model.model.embed_tokens.weight") + end +end + +class Phase19DenseLaneMQwen2VLTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_qwen2_vl_construct_forward_shape_and_sanitize_prefixes_language_model + args = MlxLm::Models::Qwen2VL::ModelArgs.from_dict({ + "model_type" => "qwen2_vl", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 4, + "intermediate_size" => 64, + "vocab_size" => 89, + "rms_norm_eps" => 1e-5, + "max_position_embeddings" => 128, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => true, + }) + + model = MlxLm::Models::Qwen2VL::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + + assert_equal [1, 4, 89], output.shape + assert_equal 32, args.text_config["hidden_size"] + + weights = { + "model.embed_tokens.weight" => @mx.zeros([89, 32]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([89, 32]).astype(@mx.float32), + "language_model.model.layers.0.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "visual.encoder.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + "vision_tower.blocks.0.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("visual.encoder.weight") + refute sanitized.key?("vision_tower.blocks.0.weight") + assert sanitized.key?("language_model.model.embed_tokens.weight") + assert sanitized.key?("language_model.lm_head.weight") + assert sanitized.key?("language_model.model.layers.0.self_attn.q_proj.weight") + assert sanitized.keys.all? { |key| key.start_with?("language_model.") } + end +end + +class Phase19DenseLaneMRegistryTest < Minitest::Test + def test_models_registered_and_resolved + assert MlxLm::Models::REGISTRY.key?("pixtral"), "pixtral should be registered" + assert MlxLm::Models::REGISTRY.key?("qwen2_vl"), "qwen2_vl should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "pixtral" }) + assert_equal MlxLm::Models::Pixtral::Model, model_class + assert_equal MlxLm::Models::Pixtral::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "qwen2_vl" }) + assert_equal MlxLm::Models::Qwen2VL::Model, model_class + assert_equal MlxLm::Models::Qwen2VL::ModelArgs, args_class + end +end diff --git a/test/parity/phase12_test.rb b/test/parity/prompt_cache_perplexity_benchmark_registry_convert_test.rb similarity index 97% rename from test/parity/phase12_test.rb rename to test/parity/prompt_cache_perplexity_benchmark_registry_convert_test.rb index 9ba2db5..759d720 100644 --- a/test/parity/phase12_test.rb +++ b/test/parity/prompt_cache_perplexity_benchmark_registry_convert_test.rb @@ -50,7 +50,7 @@ def test_prompt_cache_save_load # Create cache and populate it cache = MlxLm::Cache.make_prompt_cache(model) - input = @mx.array([[1, 2, 3]]).astype(@mx.int32) + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) model.call(input, cache: cache) @mx.eval(*cache.map(&:state).flatten.compact) @@ -94,7 +94,7 @@ def test_perplexity_computation @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) # Compute perplexity on a small sequence - tokens = @mx.array([1, 2, 3, 4, 5, 6, 7, 8]).astype(@mx.int32) + tokens = @mx.array([1, 2, 3, 4, 5, 6, 7, 8], dtype: @mx.int32) ppl = MlxLm::Perplexity.compute(model, tokens) assert ppl > 0, "Perplexity should be 
positive, got #{ppl}" assert ppl.is_a?(Numeric), "Perplexity should be numeric" @@ -115,7 +115,7 @@ def test_log_likelihood model = MlxLm::Models::Llama::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - tokens = @mx.array([1, 2, 3, 4, 5]).astype(@mx.int32) + tokens = @mx.array([1, 2, 3, 4, 5], dtype: @mx.int32) ll = MlxLm::Perplexity.log_likelihood(model, tokens) assert ll < 0, "Log-likelihood should be negative, got #{ll}" end diff --git a/test/parity/phase7_test.rb b/test/parity/quantization_pipeline_test.rb similarity index 95% rename from test/parity/phase7_test.rb rename to test/parity/quantization_pipeline_test.rb index dd204b9..bdad977 100644 --- a/test/parity/phase7_test.rb +++ b/test/parity/quantization_pipeline_test.rb @@ -46,7 +46,7 @@ def test_quantized_model_forward @mx.eval(*model.parameters.values) MlxLm::Quantize.quantize_model(model, group_size: 32, bits: 4) - tokens = @mx.array([[1, 2, 3]]).astype(@mx.int32) + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(tokens) assert_equal [1, 3, 128], output.shape, "Quantized model should produce same shape output" end @@ -103,7 +103,7 @@ def test_quantized_embedding qembed = embed.to_quantized(group_size: 32, bits: 4) assert_instance_of MLX::NN::QuantizedEmbedding, qembed - ids = @mx.array([[1, 2, 3]]).astype(@mx.int32) + ids = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = qembed.call(ids) assert_equal [1, 3, 32], output.shape end @@ -135,12 +135,12 @@ def test_quantized_model_with_cache cache = Array.new(2) { MlxLm::KVCache.new } # First token - token1 = @mx.array([[1]]).astype(@mx.int32) + token1 = @mx.array([[1]], dtype: @mx.int32) out1 = model.call(token1, cache: cache) assert_equal [1, 1, 128], out1.shape # Second token (using cache) - token2 = @mx.array([[2]]).astype(@mx.int32) + token2 = @mx.array([[2]], dtype: @mx.int32) out2 = model.call(token2, cache: cache) assert_equal [1, 1, 128], out2.shape end diff --git a/test/parity/qwen2_moe_phimoe_models_test.rb b/test/parity/qwen2_moe_phimoe_models_test.rb new file mode 100644 index 0000000..37a2915 --- /dev/null +++ b/test/parity/qwen2_moe_phimoe_models_test.rb @@ -0,0 +1,173 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/qwen2" +require_relative "../../lib/mlx_lm/models/qwen2_moe" +require_relative "../../lib/mlx_lm/models/phimoe" + +class Phase21DenseLaneWQwen2MoeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_qwen2_moe_construct_forward_shape_and_sanitize_stacks_experts + args = MlxLm::Models::Qwen2Moe::ModelArgs.from_dict({ + "model_type" => "qwen2_moe", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 109, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 256, + "tie_word_embeddings" => false, + "num_experts_per_tok" => 1, + "num_experts" => 2, + "moe_intermediate_size" => 16, + "shared_expert_intermediate_size" => 24, + }) + + model = MlxLm::Models::Qwen2Moe::Model.new(args) + 
@mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 109], output.shape + + weights = { + "model.layers.0.mlp.experts.0.up_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.up_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.0.down_proj.weight" => @mx.array([[9.0, 10.0], [11.0, 12.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.down_proj.weight" => @mx.array([[13.0, 14.0], [15.0, 16.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.0.gate_proj.weight" => @mx.array([[17.0, 18.0], [19.0, 20.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.gate_proj.weight" => @mx.array([[21.0, 22.0], [23.0, 24.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.0.up_proj.scales" => @mx.array([0.5, 1.0], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.up_proj.scales" => @mx.array([1.5, 2.0], dtype: @mx.float32), + "model.layers.1.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + up_stacked = sanitized["model.layers.0.mlp.switch_mlp.up_proj.weight"] + scales_stacked = sanitized["model.layers.0.mlp.switch_mlp.up_proj.scales"] + @mx.eval(up_stacked, scales_stacked) + + refute sanitized.key?("model.layers.0.mlp.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.down_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.down_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.up_proj.scales") + refute sanitized.key?("model.layers.0.mlp.experts.1.up_proj.scales") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.down_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.gate_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.up_proj.scales") + assert sanitized.key?("model.layers.1.self_attn.q_proj.weight") + + assert_equal [2, 2, 2], up_stacked.shape + assert_equal [2, 2], scales_stacked.shape + assert_equal [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], up_stacked.to_a + assert_equal [[0.5, 1.0], [1.5, 2.0]], scales_stacked.to_a + end +end + +class Phase21DenseLaneWPhiMoeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_phimoe_construct_forward_shape_and_sanitize_stacks_experts + args = MlxLm::Models::PhiMoe::ModelArgs.from_dict({ + "model_type" => "phimoe", + "vocab_size" => 113, + "hidden_size" => 32, + "intermediate_size" => 64, + "num_hidden_layers" => 2, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "max_position_embeddings" => 256, + "original_max_position_embeddings" => 64, + "rms_norm_eps" => 1e-5, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + "rope_theta" => 10_000.0, + "rope_scaling" => { + "short_factor" => [1.0, 1.0, 1.0, 1.0], + "long_factor" => [1.0, 1.0, 1.0, 1.0], + "short_mscale" => 1.0, + "long_mscale" => 1.0, + }, + }) + + model = MlxLm::Models::PhiMoe::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + 
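# With num_experts_per_tok set to 1 above, each token should be routed to a single expert during this forward pass. +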
output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 113], output.shape + + weights = { + "model.layers.0.block_sparse_moe.experts.0.w1.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.block_sparse_moe.experts.1.w1.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.0.block_sparse_moe.experts.0.w2.weight" => @mx.array([[9.0, 10.0], [11.0, 12.0]], dtype: @mx.float32), + "model.layers.0.block_sparse_moe.experts.1.w2.weight" => @mx.array([[13.0, 14.0], [15.0, 16.0]], dtype: @mx.float32), + "model.layers.0.block_sparse_moe.experts.0.w3.weight" => @mx.array([[17.0, 18.0], [19.0, 20.0]], dtype: @mx.float32), + "model.layers.0.block_sparse_moe.experts.1.w3.weight" => @mx.array([[21.0, 22.0], [23.0, 24.0]], dtype: @mx.float32), + "model.layers.0.block_sparse_moe.experts.0.w1.scales" => @mx.array([0.25, 0.5], dtype: @mx.float32), + "model.layers.0.block_sparse_moe.experts.1.w1.scales" => @mx.array([0.75, 1.0], dtype: @mx.float32), + "model.layers.1.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + gate_stacked = sanitized["model.layers.0.block_sparse_moe.switch_mlp.gate_proj.weight"] + scales_stacked = sanitized["model.layers.0.block_sparse_moe.switch_mlp.gate_proj.scales"] + @mx.eval(gate_stacked, scales_stacked) + + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.0.w1.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.1.w1.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.0.w2.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.1.w2.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.0.w3.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.1.w3.weight") + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.0.w1.scales") + refute sanitized.key?("model.layers.0.block_sparse_moe.experts.1.w1.scales") + assert sanitized.key?("model.layers.0.block_sparse_moe.switch_mlp.gate_proj.weight") + assert sanitized.key?("model.layers.0.block_sparse_moe.switch_mlp.down_proj.weight") + assert sanitized.key?("model.layers.0.block_sparse_moe.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.block_sparse_moe.switch_mlp.gate_proj.scales") + assert sanitized.key?("model.layers.1.self_attn.q_proj.weight") + + assert_equal [2, 2, 2], gate_stacked.shape + assert_equal [2, 2], scales_stacked.shape + assert_equal [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], gate_stacked.to_a + assert_equal [[0.25, 0.5], [0.75, 1.0]], scales_stacked.to_a + end +end + +class Phase21DenseLaneWRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("qwen2_moe"), "qwen2_moe should be registered" + assert MlxLm::Models::REGISTRY.key?("phimoe"), "phimoe should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "qwen2_moe" }) + assert_equal MlxLm::Models::Qwen2Moe::Model, model_class + assert_equal MlxLm::Models::Qwen2Moe::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "phimoe" }) + assert_equal MlxLm::Models::PhiMoe::Model, model_class + assert_equal MlxLm::Models::PhiMoe::ModelArgs, args_class + end +end diff --git a/test/parity/qwen3_5_qwen3_5_moe_models_test.rb b/test/parity/qwen3_5_qwen3_5_moe_models_test.rb new file mode 100644 index 0000000..d857e74 --- /dev/null +++ 
b/test/parity/qwen3_5_qwen3_5_moe_models_test.rb @@ -0,0 +1,137 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/qwen3" +require_relative "../../lib/mlx_lm/models/qwen3_5" +require_relative "../../lib/mlx_lm/models/qwen3_5_moe" + +class Phase19DenseLaneOQwen35Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_qwen3_5_construct_forward_shape_and_sanitize_key_remap + args = MlxLm::Models::Qwen35::ModelArgs.from_dict({ + "model_type" => "qwen3_5", + "text_config" => { + "model_type" => "qwen3_5", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 113, + "head_dim" => 8, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + }, + }) + + model = MlxLm::Models::Qwen35::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + + assert_equal [1, 3, 113], output.shape + + weights = { + "model.language_model.embed_tokens.weight" => @mx.zeros([113, 32]).astype(@mx.float32), + "model.language_model.layers.0.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "language_model.model.norm.weight" => @mx.zeros([32]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([113, 32]).astype(@mx.float32), + "model.visual.encoder.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("model.visual.encoder.weight") + assert sanitized.key?("language_model.model.embed_tokens.weight") + assert sanitized.key?("language_model.model.layers.0.self_attn.q_proj.weight") + assert sanitized.key?("language_model.model.norm.weight") + assert sanitized.key?("language_model.lm_head.weight") + assert sanitized.keys.all? 
{ |k| k.start_with?("language_model.") } + end +end + +class Phase19DenseLaneOQwen35MoeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_qwen3_5_moe_construct_forward_shape_and_sanitize_moe_key_remap + args = MlxLm::Models::Qwen35Moe::ModelArgs.from_dict({ + "model_type" => "qwen3_5_moe", + "text_config" => { + "model_type" => "qwen3_5_moe", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 109, + "head_dim" => 8, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + "num_experts" => 2, + "num_experts_per_tok" => 1, + }, + }) + + model = MlxLm::Models::Qwen35Moe::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + + assert_equal [1, 4, 109], output.shape + + gate_up = @mx.array((0...24).to_a, dtype: @mx.float32).reshape([6, 4]) + down_proj = @mx.zeros([3, 4]).astype(@mx.float32) + weights = { + "model.language_model.layers.0.mlp.experts.gate_up_proj" => gate_up, + "model.language_model.layers.0.mlp.experts.down_proj" => down_proj, + "model.visual.patch_embed.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("model.visual.patch_embed.weight") + refute sanitized.key?("language_model.model.layers.0.mlp.experts.gate_up_proj") + refute sanitized.key?("language_model.model.layers.0.mlp.experts.down_proj") + assert sanitized.key?("language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight") + assert sanitized.key?("language_model.model.layers.0.mlp.switch_mlp.up_proj.weight") + assert sanitized.key?("language_model.model.layers.0.mlp.switch_mlp.down_proj.weight") + + assert_equal [3, 4], sanitized["language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight"].shape + assert_equal [3, 4], sanitized["language_model.model.layers.0.mlp.switch_mlp.up_proj.weight"].shape + assert_equal [3, 4], sanitized["language_model.model.layers.0.mlp.switch_mlp.down_proj.weight"].shape + end +end + +class Phase19DenseLaneORegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("qwen3_5"), "qwen3_5 should be registered" + assert MlxLm::Models::REGISTRY.key?("qwen3_5_moe"), "qwen3_5_moe should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "qwen3_5" }) + assert_equal MlxLm::Models::Qwen35::Model, model_class + assert_equal MlxLm::Models::Qwen35::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "qwen3_5_moe" }) + assert_equal MlxLm::Models::Qwen35Moe::Model, model_class + assert_equal MlxLm::Models::Qwen35Moe::ModelArgs, args_class + end +end diff --git a/test/parity/qwen3_moe_qwen3_vl_moe_models_test.rb b/test/parity/qwen3_moe_qwen3_vl_moe_models_test.rb new file mode 100644 index 0000000..b1896ca --- /dev/null +++ b/test/parity/qwen3_moe_qwen3_vl_moe_models_test.rb @@ -0,0 +1,164 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative 
"../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/qwen3" +require_relative "../../lib/mlx_lm/models/qwen3_moe" +require_relative "../../lib/mlx_lm/models/qwen3_vl_moe" + +class Phase20DenseLaneSQwen3MoeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_qwen3_moe_construct_forward_shape_and_sanitize_stacks_experts + args = MlxLm::Models::Qwen3Moe::ModelArgs.from_dict({ + "model_type" => "qwen3_moe", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 111, + "head_dim" => 8, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 256, + "tie_word_embeddings" => true, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "decoder_sparse_step" => 1, + "mlp_only_layers" => [], + "moe_intermediate_size" => 16, + "norm_topk_prob" => true, + }) + + model = MlxLm::Models::Qwen3Moe::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 111], output.shape + + weights = { + "lm_head.weight" => @mx.zeros([111, 32]).astype(@mx.float32), + "model.layers.0.mlp.experts.0.up_proj.weight" => @mx.array([[1.0, 2.0], [3.0, 4.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.up_proj.weight" => @mx.array([[5.0, 6.0], [7.0, 8.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.0.down_proj.weight" => @mx.array([[9.0, 10.0], [11.0, 12.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.down_proj.weight" => @mx.array([[13.0, 14.0], [15.0, 16.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.0.gate_proj.weight" => @mx.array([[17.0, 18.0], [19.0, 20.0]], dtype: @mx.float32), + "model.layers.0.mlp.experts.1.gate_proj.weight" => @mx.array([[21.0, 22.0], [23.0, 24.0]], dtype: @mx.float32), + "model.layers.1.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + up_stacked = sanitized["model.layers.0.mlp.switch_mlp.up_proj.weight"] + @mx.eval(up_stacked) + + refute sanitized.key?("lm_head.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.up_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.down_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.down_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.0.gate_proj.weight") + refute sanitized.key?("model.layers.0.mlp.experts.1.gate_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.down_proj.weight") + assert sanitized.key?("model.layers.0.mlp.switch_mlp.gate_proj.weight") + assert sanitized.key?("model.layers.1.self_attn.q_proj.weight") + + assert_equal [2, 2, 2], up_stacked.shape + assert_equal [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], up_stacked.to_a + end +end + +class Phase20DenseLaneSQwen3VLMoeTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_qwen3_vl_moe_construct_forward_shape_and_sanitize_remaps_gate_up + args = MlxLm::Models::Qwen3VLMoe::ModelArgs.from_dict({ + "model_type" => "qwen3_vl_moe", + "text_config" => { + "model_type" => "qwen3_moe", + "hidden_size" => 24, + "num_hidden_layers" => 2, + "intermediate_size" => 48, + 
"num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 101, + "head_dim" => 6, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 256, + "tie_word_embeddings" => false, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "decoder_sparse_step" => 1, + "mlp_only_layers" => [], + "moe_intermediate_size" => 12, + "norm_topk_prob" => true, + }, + }) + + model = MlxLm::Models::Qwen3VLMoe::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 101], output.shape + + gate_up = @mx.array((0...24).to_a, dtype: @mx.float32).reshape([2, 3, 4]) + down_proj = @mx.array((0...12).to_a, dtype: @mx.float32).reshape([2, 3, 2]) + weights = { + "visual.encoder.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + "language_model.model.layers.0.mlp.experts.gate_up_proj" => gate_up, + "language_model.model.layers.0.mlp.experts.down_proj" => down_proj, + "language_model.lm_head.weight" => @mx.zeros([101, 24]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + gate_stacked = sanitized["language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight"] + up_stacked = sanitized["language_model.model.layers.0.mlp.switch_mlp.up_proj.weight"] + down_stacked = sanitized["language_model.model.layers.0.mlp.switch_mlp.down_proj.weight"] + @mx.eval(gate_stacked, up_stacked, down_stacked) + + refute sanitized.key?("visual.encoder.weight") + refute sanitized.key?("language_model.model.layers.0.mlp.experts.gate_up_proj") + refute sanitized.key?("language_model.model.layers.0.mlp.experts.down_proj") + assert sanitized.key?("language_model.model.layers.0.mlp.switch_mlp.gate_proj.weight") + assert sanitized.key?("language_model.model.layers.0.mlp.switch_mlp.up_proj.weight") + assert sanitized.key?("language_model.model.layers.0.mlp.switch_mlp.down_proj.weight") + assert sanitized.key?("language_model.lm_head.weight") + + assert_equal [2, 2, 3], gate_stacked.shape + assert_equal [2, 2, 3], up_stacked.shape + assert_equal [2, 2, 3], down_stacked.shape + end +end + +class Phase20DenseLaneSRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("qwen3_moe"), "qwen3_moe should be registered" + assert MlxLm::Models::REGISTRY.key?("qwen3_vl_moe"), "qwen3_vl_moe should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "qwen3_moe" }) + assert_equal MlxLm::Models::Qwen3Moe::Model, model_class + assert_equal MlxLm::Models::Qwen3Moe::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "qwen3_vl_moe" }) + assert_equal MlxLm::Models::Qwen3VLMoe::Model, model_class + assert_equal MlxLm::Models::Qwen3VLMoe::ModelArgs, args_class + end +end diff --git a/test/parity/qwen3_vl_smollm3_models_test.rb b/test/parity/qwen3_vl_smollm3_models_test.rb new file mode 100644 index 0000000..8c4ce23 --- /dev/null +++ b/test/parity/qwen3_vl_smollm3_models_test.rb @@ -0,0 +1,113 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" 
+require_relative "../../lib/mlx_lm/models/qwen3" +require_relative "../../lib/mlx_lm/models/llama" +require_relative "../../lib/mlx_lm/models/qwen3_vl" +require_relative "../../lib/mlx_lm/models/smollm3" + +class Phase19DenseLaneNQwen3VLTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_qwen3_vl_construct_forward_shape_and_sanitize + args = MlxLm::Models::Qwen3VL::ModelArgs.from_dict({ + "model_type" => "qwen3_vl", + "text_config" => { + "model_type" => "qwen3", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 127, + "head_dim" => 8, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + }, + }) + + model = MlxLm::Models::Qwen3VL::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + input = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 4, 127], output.shape + + weights = { + "model.embed_tokens.weight" => @mx.zeros([127, 32]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([127, 32]).astype(@mx.float32), + "language_model.model.layers.0.self_attn.q_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "vision_tower.blocks.0.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + + refute sanitized.key?("vision_tower.blocks.0.weight") + assert sanitized.key?("language_model.model.embed_tokens.weight") + assert sanitized.key?("language_model.lm_head.weight") + assert sanitized.key?("language_model.model.layers.0.self_attn.q_proj.weight") + end +end + +class Phase19DenseLaneNSmolLM3Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_smollm3_construct_forward_shape_and_nope_placement + args = MlxLm::Models::SmolLM3::ModelArgs.from_dict({ + "model_type" => "smollm3", + "hidden_size" => 48, + "num_hidden_layers" => 4, + "intermediate_size" => 96, + "num_attention_heads" => 4, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 101, + "tie_word_embeddings" => false, + "no_rope_layers" => [1, 0, 1, 0], + }) + + model = MlxLm::Models::SmolLM3::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) + + assert_instance_of MlxLm::Models::SmolLM3::NoPE, model.layers[1].self_attn.rope + assert_instance_of MlxLm::Models::SmolLM3::NoPE, model.layers[3].self_attn.rope + refute_instance_of MlxLm::Models::SmolLM3::NoPE, model.layers[0].self_attn.rope + refute_instance_of MlxLm::Models::SmolLM3::NoPE, model.layers[2].self_attn.rope + + input = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(input) + @mx.eval(output) + + assert_equal [1, 3, 101], output.shape + end +end + +class Phase19DenseLaneNRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("qwen3_vl"), "qwen3_vl should be registered" + assert MlxLm::Models::REGISTRY.key?("smollm3"), "smollm3 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "qwen3_vl" }) + assert_equal MlxLm::Models::Qwen3VL::Model, model_class + assert_equal MlxLm::Models::Qwen3VL::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "smollm3" }) + assert_equal MlxLm::Models::SmolLM3::Model, model_class + assert_equal MlxLm::Models::SmolLM3::ModelArgs, args_class + end +end diff --git 
a/test/parity/qwen_qwen3_models_test.rb b/test/parity/qwen_qwen3_models_test.rb new file mode 100644 index 0000000..5f83ba1 --- /dev/null +++ b/test/parity/qwen_qwen3_models_test.rb @@ -0,0 +1,117 @@ +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/qwen" +require_relative "../../lib/mlx_lm/models/qwen3" + +class Phase16DenseLaneAQwenTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_qwen_instantiation_and_forward_shape + args = MlxLm::Models::Qwen::ModelArgs.from_dict({ + "model_type" => "qwen", + "hidden_size" => 32, + "num_attention_heads" => 2, + "num_hidden_layers" => 2, + "kv_channels" => 16, + "intermediate_size" => 64, + "vocab_size" => 100, + "no_bias" => true, + }) + model = MlxLm::Models::Qwen::Model.new(args) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + + assert_instance_of MlxLm::Models::Qwen::Model, model + assert_equal [1, 3, 100], output.shape + end + + def test_qwen_sanitize_removes_rotary_freqs_only + args = MlxLm::Models::Qwen::ModelArgs.from_dict({ + "model_type" => "qwen", + "hidden_size" => 32, + "num_attention_heads" => 2, + "num_hidden_layers" => 1, + "kv_channels" => 16, + "intermediate_size" => 64, + "vocab_size" => 100, + }) + model = MlxLm::Models::Qwen::Model.new(args) + weights = { + "transformer.wte.weight" => @mx.zeros([100, 32]).astype(@mx.float32), + "transformer.h.0.attn.rotary_emb.inv_freq" => @mx.zeros([16]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([100, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + + refute sanitized.key?("transformer.h.0.attn.rotary_emb.inv_freq") + assert sanitized.key?("transformer.wte.weight") + assert sanitized.key?("lm_head.weight") + end +end + +class Phase16DenseLaneAQwen3Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_qwen3_instantiation_and_forward_shape + args = MlxLm::Models::Qwen3::ModelArgs.from_dict({ + "model_type" => "qwen3", + "hidden_size" => 32, + "num_hidden_layers" => 2, + "intermediate_size" => 64, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-6, + "vocab_size" => 100, + "max_position_embeddings" => 256, + "rope_theta" => 10_000.0, + "head_dim" => 16, + "tie_word_embeddings" => true, + }) + model = MlxLm::Models::Qwen3::Model.new(args) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + + assert_instance_of MlxLm::Models::Qwen3::Model, model + assert_equal [1, 3, 100], output.shape + end + + def test_qwen3_sanitize_drops_lm_head_when_tied + args = MlxLm::Models::Qwen3::ModelArgs.from_dict({ + "model_type" => "qwen3", + "hidden_size" => 32, + "num_hidden_layers" => 1, + "intermediate_size" => 64, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-6, + "vocab_size" => 100, + "head_dim" => 16, + "tie_word_embeddings" => true, + }) + model = MlxLm::Models::Qwen3::Model.new(args) + weights = { + "model.embed_tokens.weight" => @mx.zeros([100, 32]).astype(@mx.float32), + "lm_head.weight" => @mx.zeros([100, 32]).astype(@mx.float32), + "model.layers.0.self_attn.rotary_emb.inv_freq" => @mx.zeros([16]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + + refute 
sanitized.key?("lm_head.weight") + assert sanitized.key?("model.embed_tokens.weight") + assert sanitized.key?("model.layers.0.self_attn.rotary_emb.inv_freq") + end +end diff --git a/test/parity/recurrent_gemma_step3p5_models_test.rb b/test/parity/recurrent_gemma_step3p5_models_test.rb new file mode 100644 index 0000000..5b43077 --- /dev/null +++ b/test/parity/recurrent_gemma_step3p5_models_test.rb @@ -0,0 +1,190 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/activations" +require_relative "../../lib/mlx_lm/models/rope_utils" +require_relative "../../lib/mlx_lm/models/switch_layers" +require_relative "../../lib/mlx_lm/models/recurrent_gemma" +require_relative "../../lib/mlx_lm/models/step3p5" + +class Phase24DenseLaneANRecurrentGemmaTest < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_recurrent_gemma_construct_forward_shape_sanitize_and_make_cache + args = MlxLm::Models::RecurrentGemma::ModelArgs.from_dict({ + "model_type" => "recurrent_gemma", + "hidden_size" => 32, + "attention_bias" => false, + "conv1d_width" => 3, + "intermediate_size" => 64, + "logits_soft_cap" => 1.5, + "num_attention_heads" => 4, + "num_hidden_layers" => 3, + "num_key_value_heads" => 2, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "attention_window_size" => 4, + "vocab_size" => 97, + "block_types" => ["recurrent", "attention"], + }) + + model = MlxLm::Models::RecurrentGemma::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3, 4]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 4, 97], output.shape + + conv_weight = @mx.array((0...24).to_a, dtype: @mx.float32).reshape([4, 1, 6]) + weights = { + "model.layers.0.temporal_block.conv_1d.weight" => conv_weight, + "model.layers.0.temporal_block.linear_x.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + sanitized_conv = sanitized["model.layers.0.temporal_block.conv_1d.weight"] + @mx.eval(sanitized_conv) + + assert_equal [4, 6, 1], sanitized_conv.shape + assert_equal @mx.swapaxes(conv_weight, 1, 2).to_a, sanitized_conv.to_a + assert sanitized.key?("model.layers.0.temporal_block.linear_x.weight") + assert_nil model.lm_head + + tied_output = model.call(tokens) + @mx.eval(tied_output) + assert_equal [1, 4, 97], tied_output.shape + + cache = model.make_cache + assert_equal 3, cache.length + assert_instance_of MlxLm::ArraysCache, cache[0] + assert_instance_of MlxLm::RotatingKVCache, cache[1] + assert_instance_of MlxLm::ArraysCache, cache[2] + assert_equal 2, cache[0].cache.length + assert_equal 2, cache[2].cache.length + end +end + +class Phase24DenseLaneANStep3p5Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_step3p5_construct_forward_shape_sanitize_and_make_cache + args = MlxLm::Models::Step3p5::ModelArgs.from_dict({ + "model_type" => "step3p5", + "hidden_size" => 32, + "num_hidden_layers" => 3, + "vocab_size" => 103, + "num_attention_heads" => 4, + "num_attention_groups" => 2, + "head_dim" => 8, + "intermediate_size" => 64, + "rms_norm_eps" => 1e-5, + "rope_theta" => [10_000.0, 10_000.0, 10_000.0], + "sliding_window" => 4, + "layer_types" => 
["full_attention", "sliding_attention", "full_attention"], + "partial_rotary_factors" => [0.5, 1.0, 0.5], + "attention_other_setting" => { + "num_attention_heads" => 4, + "num_attention_groups" => 2, + }, + "use_head_wise_attn_gate" => true, + "moe_num_experts" => 2, + "moe_top_k" => 1, + "moe_intermediate_size" => 48, + "share_expert_dim" => 48, + "moe_layers_enum" => "1,2", + }) + + model = MlxLm::Models::Step3p5::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 103], output.shape + + norm_weight = @mx.array([1.0, 2.0], dtype: @mx.float32) + weights = { + "model.layers.1.moe.gate_proj.weight" => @mx.array((0...48).to_a, dtype: @mx.float32).reshape([2, 6, 4]), + "model.layers.1.moe.up_proj.weight" => @mx.array((48...96).to_a, dtype: @mx.float32).reshape([2, 6, 4]), + "model.layers.1.moe.down_proj.weight" => @mx.array((0...48).to_a, dtype: @mx.float32).reshape([2, 4, 6]), + "model.layers.1.moe.gate.weight" => @mx.zeros([2, 32]).astype(@mx.float32), + "model.layers.1.moe.router_bias" => @mx.zeros([2]).astype(@mx.float32), + "model.layers.1.share_expert.gate_proj.weight" => @mx.zeros([48, 32]).astype(@mx.float32), + "model.layers.0.input_layernorm.weight" => norm_weight, + "model.layers.5.mlp.up_proj.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + "model.mtp.layers.0.weight" => @mx.zeros([2, 2]).astype(@mx.float32), + } + + sanitized = model.sanitize(weights) + remapped_gate = sanitized["model.layers.1.mlp.switch_mlp.gate_proj.weight"] + remapped_up = sanitized["model.layers.1.mlp.switch_mlp.up_proj.weight"] + remapped_down = sanitized["model.layers.1.mlp.switch_mlp.down_proj.weight"] + remapped_router = sanitized["model.layers.1.mlp.gate.gate.weight"] + remapped_router_bias = sanitized["model.layers.1.mlp.gate.router_bias"] + remapped_shared = sanitized["model.layers.1.mlp.share_expert.gate_proj.weight"] + sanitized_norm = sanitized["model.layers.0.input_layernorm.weight"] + @mx.eval( + remapped_gate, + remapped_up, + remapped_down, + remapped_router, + remapped_router_bias, + remapped_shared, + sanitized_norm + ) + + refute sanitized.key?("model.layers.1.moe.gate_proj.weight") + refute sanitized.key?("model.layers.1.moe.up_proj.weight") + refute sanitized.key?("model.layers.1.moe.down_proj.weight") + refute sanitized.key?("model.layers.1.moe.gate.weight") + refute sanitized.key?("model.layers.1.moe.router_bias") + refute sanitized.key?("model.layers.1.share_expert.gate_proj.weight") + refute sanitized.key?("model.layers.5.mlp.up_proj.weight") + refute sanitized.key?("model.mtp.layers.0.weight") + + assert sanitized.key?("model.layers.1.mlp.switch_mlp.gate_proj.weight") + assert sanitized.key?("model.layers.1.mlp.switch_mlp.up_proj.weight") + assert sanitized.key?("model.layers.1.mlp.switch_mlp.down_proj.weight") + assert sanitized.key?("model.layers.1.mlp.gate.gate.weight") + assert sanitized.key?("model.layers.1.mlp.gate.router_bias") + assert sanitized.key?("model.layers.1.mlp.share_expert.gate_proj.weight") + + assert_equal [2, 6, 4], remapped_gate.shape + assert_equal [2, 6, 4], remapped_up.shape + assert_equal [2, 4, 6], remapped_down.shape + assert_equal [2, 32], remapped_router.shape + assert_equal [2], remapped_router_bias.shape + assert_equal [48, 32], remapped_shared.shape + assert_equal [2.0, 3.0], sanitized_norm.to_a + + cache = model.make_cache + assert_equal 3, cache.length + cache.each { |entry| 
assert_instance_of MlxLm::KVCache, entry } + end +end + +class Phase24DenseLaneANRegistryTest < Minitest::Test + def test_models_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("recurrent_gemma"), "recurrent_gemma should be registered" + assert MlxLm::Models::REGISTRY.key?("step3p5"), "step3p5 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "recurrent_gemma" }) + assert_equal MlxLm::Models::RecurrentGemma::Model, model_class + assert_equal MlxLm::Models::RecurrentGemma::ModelArgs, args_class + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "step3p5" }) + assert_equal MlxLm::Models::Step3p5::Model, model_class + assert_equal MlxLm::Models::Step3p5::ModelArgs, args_class + end +end diff --git a/test/parity/phase6_test.rb b/test/parity/registry_gemma_qwen2_phi3_starcoder2_test.rb similarity index 95% rename from test/parity/phase6_test.rb rename to test/parity/registry_gemma_qwen2_phi3_starcoder2_test.rb index ff232a9..9a69055 100644 --- a/test/parity/phase6_test.rb +++ b/test/parity/registry_gemma_qwen2_phi3_starcoder2_test.rb @@ -63,7 +63,7 @@ def test_gemma_instantiates # Test 5: Gemma forward pass produces correct output shape def test_gemma_forward_shape model = MlxLm::Models::Gemma::Model.new(@args) - tokens = @mx.array([[1, 2, 3]]).astype(@mx.int32) + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(tokens) assert_equal [1, 3, 100], output.shape, "Output should be [batch, seq_len, vocab_size]" end @@ -72,7 +72,7 @@ def test_gemma_forward_shape def test_gemma_forward_with_cache model = MlxLm::Models::Gemma::Model.new(@args) cache = Array.new(2) { MlxLm::KVCache.new } - token = @mx.array([[5]]).astype(@mx.int32) + token = @mx.array([[5]], dtype: @mx.int32) output = model.call(token, cache: cache) assert_equal [1, 1, 100], output.shape end @@ -105,7 +105,7 @@ def test_qwen2_instantiates # Test 8: Qwen2 forward pass produces correct output shape def test_qwen2_forward_shape model = MlxLm::Models::Qwen2::Model.new(@args) - tokens = @mx.array([[1, 2, 3]]).astype(@mx.int32) + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(tokens) assert_equal [1, 3, 100], output.shape end @@ -152,7 +152,7 @@ def test_phi3_instantiates # Test 11: Phi3 forward pass def test_phi3_forward_shape model = MlxLm::Models::Phi3::Model.new(@args) - tokens = @mx.array([[1, 2, 3]]).astype(@mx.int32) + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(tokens) assert_equal [1, 3, 100], output.shape end @@ -185,7 +185,7 @@ def test_starcoder2_instantiates # Test 13: Starcoder2 forward pass def test_starcoder2_forward_shape model = MlxLm::Models::Starcoder2::Model.new(@args) - tokens = @mx.array([[1, 2, 3]]).astype(@mx.int32) + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(tokens) assert_equal [1, 3, 100], output.shape end diff --git a/test/parity/registry_keys_afm7_bailing_moe_linear_falcon_h1_glm4_moe_lite_test.rb b/test/parity/registry_keys_afm7_bailing_moe_linear_falcon_h1_glm4_moe_lite_test.rb new file mode 100644 index 0000000..c839026 --- /dev/null +++ b/test/parity/registry_keys_afm7_bailing_moe_linear_falcon_h1_glm4_moe_lite_test.rb @@ -0,0 +1,112 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase26IntegrationRegistryTest < Minitest::Test + MODEL_TYPES = %w[ + afm7 + bailing_moe_linear + falcon_h1 + glm4_moe_lite + ].freeze + + def test_phase26_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do 
|model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, "#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + def tiny_config(model_type) + case model_type + when "afm7" + { + "model_type" => "afm7", + "vocab_size" => 64, + "hidden_dim" => 16, + "num_layers" => 1, + "num_kv_reuse_layers" => 0, + "num_heads" => 2, + "num_kv_heads" => 1, + "hidden_dim_scale_factor" => 2.0, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 128, + } + when "bailing_moe_linear" + { + "model_type" => "bailing_moe_linear", + "hidden_size" => 16, + "intermediate_size" => 32, + "max_position_embeddings" => 128, + "moe_intermediate_size" => 24, + "num_experts" => 2, + "num_shared_experts" => 1, + "norm_topk_prob" => true, + "num_attention_heads" => 2, + "num_experts_per_tok" => 1, + "num_hidden_layers" => 1, + "num_key_value_heads" => 1, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "vocab_size" => 64, + "first_k_dense_replace" => 0, + "layer_group_size" => 1, + "group_norm_size" => 1, + } + when "falcon_h1" + { + "model_type" => "falcon_h1", + "hidden_size" => 16, + "intermediate_size" => 32, + "num_hidden_layers" => 1, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "mamba_d_conv" => 3, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "vocab_size" => 64, + "max_position_embeddings" => 128, + } + else + { + "model_type" => "glm4_moe_lite", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "moe_intermediate_size" => 24, + "num_hidden_layers" => 1, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "n_shared_experts" => 1, + "n_routed_experts" => 2, + "routed_scaling_factor" => 1.0, + "kv_lora_rank" => 8, + "q_lora_rank" => 8, + "qk_rope_head_dim" => 4, + "qk_nope_head_dim" => 4, + "v_head_dim" => 8, + "topk_method" => "noaux_tc", + "scoring_func" => "sigmoid", + "norm_topk_prob" => true, + "n_group" => 1, + "topk_group" => 1, + "num_experts_per_tok" => 1, + "moe_layer_freq" => 1, + "first_k_dense_replace" => 0, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "attention_bias" => false, + "partial_rotary_factor" => 1.0, + } + end + end +end diff --git a/test/parity/registry_keys_afmoe_bailing_moe_exaone_moe_glm4_moe_minimax_nemotron_nas_recurrent_gemma_step3p5_test.rb b/test/parity/registry_keys_afmoe_bailing_moe_exaone_moe_glm4_moe_minimax_nemotron_nas_recurrent_gemma_step3p5_test.rb new file mode 100644 index 0000000..8cf37f0 --- /dev/null +++ b/test/parity/registry_keys_afmoe_bailing_moe_exaone_moe_glm4_moe_minimax_nemotron_nas_recurrent_gemma_step3p5_test.rb @@ -0,0 +1,217 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase25IntegrationRegistryTest < Minitest::Test + MODEL_TYPES = %w[ + afmoe + bailing_moe + exaone_moe + glm4_moe + minimax + nemotron-nas + recurrent_gemma + step3p5 + ].freeze + + def test_phase25_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do |model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, 
"#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + def tiny_config(model_type) + case model_type + when "afmoe" + { + "model_type" => "afmoe", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "moe_intermediate_size" => 24, + "num_hidden_layers" => 1, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "head_dim" => 8, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "num_shared_experts" => 1, + "num_dense_layers" => 1, + "route_norm" => true, + "route_scale" => 1.0, + "score_func" => "sigmoid", + "n_group" => 1, + "topk_group" => 1, + "sliding_window" => 8, + "mup_enabled" => false, + "layer_types" => ["full_attention"], + } + when "bailing_moe" + { + "model_type" => "bailing_moe", + "hidden_size" => 16, + "intermediate_size" => 32, + "max_position_embeddings" => 128, + "moe_intermediate_size" => 24, + "num_experts" => 2, + "num_shared_experts" => 1, + "norm_topk_prob" => true, + "num_attention_heads" => 2, + "num_experts_per_tok" => 1, + "num_hidden_layers" => 1, + "num_key_value_heads" => 1, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "vocab_size" => 64, + "first_k_dense_replace" => 0, + "use_bias" => false, + "use_qkv_bias" => false, + "score_function" => "softmax", + "n_group" => 1, + "topk_group" => 1, + } + when "exaone_moe" + { + "model_type" => "exaone_moe", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "moe_intermediate_size" => 24, + "num_hidden_layers" => 1, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "head_dim" => 8, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "num_shared_experts" => 1, + "rms_norm_eps" => 1e-5, + "max_position_embeddings" => 128, + "sliding_window" => 8, + "layer_types" => ["full_attention"], + "is_moe_layer" => [true], + "n_group" => 1, + "topk_group" => 1, + "routed_scaling_factor" => 1.0, + "norm_topk_prob" => true, + "rope_theta" => 10_000.0, + } + when "glm4_moe" + { + "model_type" => "glm4_moe", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "max_position_embeddings" => 128, + "moe_intermediate_size" => 24, + "norm_topk_prob" => true, + "num_attention_heads" => 2, + "n_group" => 1, + "head_dim" => 8, + "topk_group" => 1, + "n_shared_experts" => 1, + "n_routed_experts" => 2, + "routed_scaling_factor" => 1.0, + "num_experts_per_tok" => 1, + "first_k_dense_replace" => 0, + "num_hidden_layers" => 1, + "num_key_value_heads" => 1, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "use_qk_norm" => true, + "attention_bias" => false, + "partial_rotary_factor" => 1.0, + } + when "minimax" + { + "model_type" => "minimax", + "hidden_size" => 16, + "intermediate_size" => 32, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "max_position_embeddings" => 128, + "num_experts_per_tok" => 1, + "num_local_experts" => 2, + "shared_intermediate_size" => 24, + "num_hidden_layers" => 1, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "rotary_dim" => 8, + "vocab_size" => 64, + "use_qk_norm" => true, + } + when "nemotron-nas" + { + "model_type" => "nemotron-nas", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 64, + "hidden_act" => "silu", + "attention_bias" => 
false, + "mlp_bias" => false, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 128, + "block_configs" => [ + { + "attention" => { "no_op" => true }, + "ffn" => { "replace_with_linear" => true }, + }, + ], + } + when "recurrent_gemma" + { + "model_type" => "recurrent_gemma", + "hidden_size" => 16, + "attention_bias" => false, + "conv1d_width" => 3, + "intermediate_size" => 32, + "logits_soft_cap" => 1.0, + "num_attention_heads" => 2, + "num_hidden_layers" => 1, + "num_key_value_heads" => 1, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "attention_window_size" => 8, + "vocab_size" => 64, + "block_types" => ["recurrent"], + } + else + { + "model_type" => "step3p5", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "vocab_size" => 64, + "num_attention_heads" => 2, + "num_attention_groups" => 1, + "head_dim" => 8, + "intermediate_size" => 32, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "sliding_window" => 8, + "layer_types" => ["full_attention"], + "partial_rotary_factors" => [1.0], + "attention_other_setting" => { + "num_attention_heads" => 2, + "num_attention_groups" => 1, + }, + "moe_num_experts" => 2, + "moe_top_k" => 1, + "moe_intermediate_size" => 24, + "share_expert_dim" => 24, + "moe_layers_enum" => "0", + } + end + end +end diff --git a/test/parity/registry_keys_bitnet_openelm_lille_mimo_qwen2moe_phimoe_phixtral_minicpm3_test.rb b/test/parity/registry_keys_bitnet_openelm_lille_mimo_qwen2moe_phimoe_phixtral_minicpm3_test.rb new file mode 100644 index 0000000..fdcaa51 --- /dev/null +++ b/test/parity/registry_keys_bitnet_openelm_lille_mimo_qwen2moe_phimoe_phixtral_minicpm3_test.rb @@ -0,0 +1,130 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase21IntegrationRegistryTest < Minitest::Test + MODEL_TYPES = %w[ + bitnet + openelm + lille-130m + mimo + qwen2_moe + phimoe + phixtral + minicpm3 + ].freeze + + def test_phase21_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do |model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, "#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + def tiny_config(model_type) + case model_type + when "openelm" + { + "model_type" => "openelm", + "head_dim" => 8, + "num_transformer_layers" => 1, + "model_dim" => 16, + "vocab_size" => 64, + "ffn_dim_divisor" => 8, + "num_query_heads" => [2], + "num_kv_heads" => [1], + "ffn_multipliers" => [2.0], + } + when "lille-130m" + { + "model_type" => "lille-130m", + "block_size" => 64, + "layer_norm_eps" => 1e-5, + "n_embd" => 16, + "n_head" => 1, + "n_kv_heads" => 1, + "n_layer" => 1, + "rope_theta" => 10_000.0, + "vocab_size" => 64, + } + when "qwen2_moe" + { + "model_type" => "qwen2_moe", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + "num_experts_per_tok" => 1, + "num_experts" => 2, + "moe_intermediate_size" => 8, + "shared_expert_intermediate_size" => 16, + } + when "phimoe" + { + "model_type" => "phimoe", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + 
"max_position_embeddings" => 128, + "original_max_position_embeddings" => 64, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + "rope_scaling" => { + "short_factor" => 1.0, + "long_factor" => 1.0, + }, + } + when "phixtral" + { + "model_type" => "phixtral", + "num_vocab" => 64, + "model_dim" => 16, + "num_heads" => 1, + "num_layers" => 1, + "rotary_dim" => 8, + "num_experts_per_tok" => 1, + "num_local_experts" => 2, + } + when "minicpm3" + { + "model_type" => "minicpm3", + "hidden_size" => 16, + "dim_model_base" => 8, + "num_hidden_layers" => 1, + "intermediate_size" => 32, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "vocab_size" => 64, + "q_lora_rank" => 8, + "qk_nope_head_dim" => 8, + "qk_rope_head_dim" => 8, + "kv_lora_rank" => 8, + "scale_depth" => 1.0, + "scale_emb" => 1.0, + "max_position_embeddings" => 128, + } + else + { + "model_type" => model_type, + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + } + end + end +end diff --git a/test/parity/registry_keys_cohere2_internlm3_glm4_telechat3_granite_minicpm_exaone4_nanochat_test.rb b/test/parity/registry_keys_cohere2_internlm3_glm4_telechat3_granite_minicpm_exaone4_nanochat_test.rb new file mode 100644 index 0000000..b3cf9fe --- /dev/null +++ b/test/parity/registry_keys_cohere2_internlm3_glm4_telechat3_granite_minicpm_exaone4_nanochat_test.rb @@ -0,0 +1,33 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase17IntegrationRegistryTest < Minitest::Test + MODEL_TYPES = %w[cohere2 internlm3 glm4 telechat3 granite minicpm exaone4 nanochat].freeze + + def test_phase17_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do |model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, "#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + def tiny_config(model_type) + { + "model_type" => model_type, + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + } + end +end diff --git a/test/parity/registry_keys_deepseek_glm_moe_kimi_lfm2_test.rb b/test/parity/registry_keys_deepseek_glm_moe_kimi_lfm2_test.rb new file mode 100644 index 0000000..58cc1b4 --- /dev/null +++ b/test/parity/registry_keys_deepseek_glm_moe_kimi_lfm2_test.rb @@ -0,0 +1,96 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase22IntegrationRegistryTest < Minitest::Test + MODEL_TYPES = %w[ + deepseek_v2 + deepseek_v3 + deepseek_v32 + glm_moe_dsa + kimi_k25 + kimi_vl + lfm2 + lfm2-vl + ].freeze + + def test_phase22_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do |model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, "#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + 
def tiny_config(model_type) + case model_type + when "deepseek_v2", "deepseek_v3", "deepseek_v32" + { + "model_type" => model_type, + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + } + when "glm_moe_dsa" + { + "model_type" => "glm_moe_dsa", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + "rope_parameters" => { + "rope_theta" => 10_000.0, + "type" => "linear", + "factor" => 1.0, + }, + } + when "kimi_k25", "kimi_vl" + { + "model_type" => model_type, + "text_config" => { + "model_type" => "deepseek", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + }, + } + when "lfm2" + { + "model_type" => "lfm2", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "block_ff_dim" => 32, + "vocab_size" => 64, + } + else + { + "model_type" => "lfm2-vl", + "text_config" => { + "model_type" => "lfm2", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "block_ff_dim" => 32, + "vocab_size" => 64, + }, + } + end + end +end diff --git a/test/parity/registry_keys_gemma3_ernie45_moe_qwen3_moe_granitemoe_olmoe_test.rb b/test/parity/registry_keys_gemma3_ernie45_moe_qwen3_moe_granitemoe_olmoe_test.rb new file mode 100644 index 0000000..3aeb527 --- /dev/null +++ b/test/parity/registry_keys_gemma3_ernie45_moe_qwen3_moe_granitemoe_olmoe_test.rb @@ -0,0 +1,124 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase20IntegrationRegistryTest < Minitest::Test + MODEL_TYPES = %w[ + gemma3_text + gemma3 + gemma3n + ernie4_5_moe + qwen3_moe + qwen3_vl_moe + granitemoe + olmoe + ].freeze + + def test_phase20_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do |model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, "#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + def tiny_config(model_type) + case model_type + when "gemma3" + { + "model_type" => "gemma3", + "text_config" => { + "hidden_size" => 16, + "num_hidden_layers" => 1, + "intermediate_size" => 32, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "head_dim" => 8, + "vocab_size" => 64, + }, + } + when "gemma3n" + { + "model_type" => "gemma3n", + "text_config" => { + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + "head_dim" => 8, + }, + } + when "ernie4_5_moe" + { + "model_type" => "ernie4_5_moe", + "hidden_size" => 16, + "intermediate_size" => 32, + "max_position_embeddings" => 64, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "num_hidden_layers" => 1, + "rms_norm_eps" => 1e-5, + "vocab_size" => 64, + "rope_theta" => 10_000.0, + "use_bias" => false, + "tie_word_embeddings" => false, + "moe_num_experts" => 2, + } + when "qwen3_vl_moe" + { + "model_type" => "qwen3_vl_moe", + "text_config" 
=> { + "model_type" => "qwen3_moe", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + "head_dim" => 16, + "num_experts" => 2, + "num_experts_per_tok" => 1, + }, + } + when "granitemoe" + { + "model_type" => "granitemoe", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "intermediate_size" => 32, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "rms_norm_eps" => 1e-5, + "vocab_size" => 64, + "logits_scaling" => 1.0, + "attention_multiplier" => 1.0, + "embedding_multiplier" => 1.0, + "residual_multiplier" => 1.0, + "max_position_embeddings" => 64, + "attention_bias" => false, + "mlp_bias" => false, + "rope_theta" => 10_000.0, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + } + else + { + "model_type" => model_type, + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + } + end + end +end diff --git a/test/parity/registry_keys_granitemoehybrid_jamba_kimi_linear_longcat_flash_test.rb b/test/parity/registry_keys_granitemoehybrid_jamba_kimi_linear_longcat_flash_test.rb new file mode 100644 index 0000000..cfb3580 --- /dev/null +++ b/test/parity/registry_keys_granitemoehybrid_jamba_kimi_linear_longcat_flash_test.rb @@ -0,0 +1,135 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase27IntegrationRegistryTest < Minitest::Test + MODEL_TYPES = %w[ + granitemoehybrid + jamba + kimi_linear + longcat_flash + ].freeze + + def test_phase27_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do |model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, "#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + def tiny_config(model_type) + case model_type + when "granitemoehybrid" + { + "model_type" => "granitemoehybrid", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "num_hidden_layers" => 1, + "max_position_embeddings" => 128, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "attention_bias" => false, + "embedding_multiplier" => 1.0, + "attention_multiplier" => 1.0, + "logits_scaling" => 1.0, + "residual_multiplier" => 1.0, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + "shared_intermediate_size" => 24, + "mamba_n_heads" => 2, + "mamba_d_head" => 8, + "mamba_d_conv" => 3, + "layer_types" => ["mamba"], + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => true, + } + when "jamba" + { + "model_type" => "jamba", + "hidden_size" => 16, + "intermediate_size" => 32, + "num_hidden_layers" => 1, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "attn_layer_offset" => 0, + "attn_layer_period" => 1, + "expert_layer_offset" => 0, + "expert_layer_period" => 1, + "mamba_d_conv" => 3, + "mamba_d_state" => 8, + "mamba_expand" => 2, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "rms_norm_eps" => 1e-5, + "max_position_embeddings" => 128, + "rope_theta" => 10_000.0, + "vocab_size" => 64, + "tie_word_embeddings" => true, + } + when "kimi_linear" + { + "model_type" => "kimi_linear", + "vocab_size" => 
64, + "hidden_dim" => 16, + "ffn_hidden_size" => 32, + "moe_intermediate_size" => 24, + "num_layers" => 1, + "num_heads" => 2, + "num_kv_heads" => 1, + "num_local_experts" => 2, + "n_shared_experts" => 1, + "top_k" => 1, + "norm_topk_prob" => true, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "first_k_dense_replace" => 0, + "layer_group_size" => 1, + "group_norm_size" => 1, + "use_bias" => false, + "use_qkv_bias" => false, + } + else + { + "model_type" => "longcat_flash", + "vocab_size" => 64, + "hidden_dim" => 16, + "ffn_hidden_size" => 32, + "moe_intermediate_size" => 24, + "num_layers" => 1, + "num_heads" => 2, + "num_kv_heads" => 1, + "num_local_experts" => 2, + "num_shared_experts" => 1, + "routed_scaling_factor" => 1.0, + "kv_lora_rank" => 8, + "q_lora_rank" => 8, + "qk_rope_head_dim" => 4, + "qk_nope_head_dim" => 4, + "v_head_dim" => 8, + "topk_method" => "noaux_tc", + "score_function" => "sigmoid", + "norm_topk_prob" => true, + "n_group" => 1, + "topk_group" => 1, + "top_k" => 1, + "moe_layer_freq" => 1, + "first_k_dense_replace" => 0, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "attention_bias" => false, + "partial_rotary_factor" => 1.0, + } + end + end +end diff --git a/test/parity/registry_keys_llama4_plamo_mamba_hunyuan_dbrx_klear_iquestloopcoder_test.rb b/test/parity/registry_keys_llama4_plamo_mamba_hunyuan_dbrx_klear_iquestloopcoder_test.rb new file mode 100644 index 0000000..6618c74 --- /dev/null +++ b/test/parity/registry_keys_llama4_plamo_mamba_hunyuan_dbrx_klear_iquestloopcoder_test.rb @@ -0,0 +1,149 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase23IntegrationRegistryTest < Minitest::Test + MODEL_TYPES = %w[ + llama4_text + plamo + mamba + mamba2 + hunyuan_v1_dense + dbrx + Klear + iquestloopcoder + ].freeze + + def test_phase23_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do |model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, "#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + def tiny_config(model_type) + case model_type + when "llama4_text" + { + "model_type" => "llama4_text", + "hidden_size" => 16, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "num_hidden_layers" => 2, + "vocab_size" => 64, + "intermediate_size" => 32, + "intermediate_size_mlp" => 32, + "head_dim" => 8, + "no_rope_layers" => [0, 1], + } + when "plamo" + { + "model_type" => "plamo", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "intermediate_size" => 32, + "num_attention_heads" => 2, + "rms_norm_eps" => 1e-5, + "vocab_size" => 64, + "n_shared_head" => 1, + } + when "mamba" + { + "model_type" => "mamba", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 16, + "state_size" => 8, + "num_hidden_layers" => 1, + "conv_kernel" => 3, + "use_bias" => true, + "use_conv_bias" => true, + "time_step_rank" => "auto", + } + when "mamba2" + { + "model_type" => "mamba2", + "num_heads" => 2, + "head_dim" => 8, + "vocab_size" => 64, + "hidden_size" => 16, + "state_size" => 8, + "num_hidden_layers" => 1, + "conv_kernel" => 3, + "n_groups" => 1, + "time_step_rank" => "auto", + 
"time_step_limit" => [0.001, 10.0], + } + when "hunyuan_v1_dense" + { + "model_type" => "hunyuan_v1_dense", + "vocab_size" => 64, + "hidden_size" => 16, + "num_hidden_layers" => 1, + "intermediate_size" => 32, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + } + when "dbrx" + { + "model_type" => "dbrx", + "vocab_size" => 64, + "d_model" => 16, + "n_layers" => 1, + "n_heads" => 2, + "attn_config" => { + "kv_n_heads" => 1, + "clip_qkv" => 8.0, + "rope_theta" => 10_000.0, + }, + "ffn_config" => { + "ffn_hidden_size" => 32, + "moe_num_experts" => 2, + "moe_top_k" => 1, + }, + } + when "Klear" + { + "model_type" => "Klear", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "intermediate_size" => 32, + "num_attention_heads" => 2, + "num_key_value_heads" => 2, + "attention_bias" => false, + "mlp_only_layers" => [], + "num_experts" => 2, + "num_experts_per_tok" => 1, + "decoder_sparse_step" => 1, + "n_shared_experts" => 1, + "moe_intermediate_size" => 24, + "rms_norm_eps" => 1e-5, + "vocab_size" => 64, + "rope_theta" => 10_000.0, + "max_position_embeddings" => 128, + "norm_topk_prob" => false, + } + else + { + "model_type" => "iquestloopcoder", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "intermediate_size" => 32, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "rms_norm_eps" => 1e-5, + "vocab_size" => 64, + "head_dim" => 8, + "loop_num" => 2, + "loop_window_size" => 4, + } + end + end +end diff --git a/test/parity/registry_keys_longcat_flash_ngram_nemotron_h_plamo2_qwen3_next_rwkv7_test.rb b/test/parity/registry_keys_longcat_flash_ngram_nemotron_h_plamo2_qwen3_next_rwkv7_test.rb new file mode 100644 index 0000000..3ef8522 --- /dev/null +++ b/test/parity/registry_keys_longcat_flash_ngram_nemotron_h_plamo2_qwen3_next_rwkv7_test.rb @@ -0,0 +1,134 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase28IntegrationRegistryTest < Minitest::Test + MODEL_TYPES = %w[ + longcat_flash_ngram + nemotron_h + plamo2 + qwen3_next + rwkv7 + ].freeze + + def test_phase28_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do |model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, "#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + def tiny_config(model_type) + case model_type + when "longcat_flash_ngram" + { + "model_type" => "longcat_flash_ngram", + "vocab_size" => 64, + "hidden_size" => 16, + "ffn_hidden_size" => 32, + "expert_ffn_hidden_size" => 24, + "num_layers" => 1, + "num_attention_heads" => 2, + "n_routed_experts" => 2, + "zero_expert_num" => 1, + "moe_topk" => 1, + "kv_lora_rank" => 8, + "q_lora_rank" => 8, + "qk_rope_head_dim" => 4, + "qk_nope_head_dim" => 4, + "v_head_dim" => 8, + "routed_scaling_factor" => 1.0, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "attention_bias" => false, + "norm_topk_prob" => true, + "n_group" => 1, + "topk_group" => 1, + "first_k_dense_replace" => 0, + } + when "nemotron_h" + { + "model_type" => "nemotron_h", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "num_hidden_layers" => 1, + "max_position_embeddings" => 128, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, 
+ "attention_bias" => false, + "mamba_num_heads" => 2, + "mamba_head_dim" => 8, + "conv_kernel" => 3, + "layer_norm_epsilon" => 1e-5, + "hybrid_override_pattern" => ["M"], + } + when "plamo2" + { + "model_type" => "plamo2", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "rms_norm_eps" => 1e-5, + "tie_word_embeddings" => true, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "hidden_size_per_head" => 8, + "max_position_embeddings" => 128, + "attention_window_size" => 16, + "mamba_d_conv" => 3, + "mamba_step" => 2, + "mamba_enabled" => true, + "intermediate_size" => 32, + "vocab_size" => 64, + } + when "qwen3_next" + { + "model_type" => "qwen3_next", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "moe_intermediate_size" => 24, + "num_hidden_layers" => 1, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "shared_expert_intermediate_size" => 16, + "decoder_sparse_step" => 1, + "mlp_only_layers" => [0], + "linear_num_value_heads" => 1, + "linear_num_key_heads" => 1, + "linear_key_head_dim" => 8, + "linear_value_head_dim" => 8, + "linear_conv_kernel_dim" => 3, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "partial_rotary_factor" => 0.5, + "attention_bias" => false, + } + else + { + "model_type" => "rwkv7", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "norm_eps" => 1e-5, + "head_dim" => 8, + "num_hidden_layers" => 1, + "a_low_rank_dim" => 4, + "v_low_rank_dim" => 4, + "gate_low_rank_dim" => 4, + "decay_low_rank_dim" => 4, + } + end + end +end diff --git a/test/parity/registry_keys_olmo3_gpt2_bigcode_nemotron_apertus_youtu_ernie_baichuan_test.rb b/test/parity/registry_keys_olmo3_gpt2_bigcode_nemotron_apertus_youtu_ernie_baichuan_test.rb new file mode 100644 index 0000000..f34f456 --- /dev/null +++ b/test/parity/registry_keys_olmo3_gpt2_bigcode_nemotron_apertus_youtu_ernie_baichuan_test.rb @@ -0,0 +1,33 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase18IntegrationRegistryTest < Minitest::Test + MODEL_TYPES = %w[olmo3 gpt2 gpt_bigcode nemotron apertus youtu_llm ernie4_5 baichuan_m1].freeze + + def test_phase18_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do |model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, "#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + def tiny_config(model_type) + { + "model_type" => model_type, + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + } + end +end diff --git a/test/parity/registry_keys_phi3small_dots1_llama4_ministral3_hunyuan_gpt_oss_mimo_v2_flash_lfm2_moe_test.rb b/test/parity/registry_keys_phi3small_dots1_llama4_ministral3_hunyuan_gpt_oss_mimo_v2_flash_lfm2_moe_test.rb new file mode 100644 index 0000000..0f590b6 --- /dev/null +++ b/test/parity/registry_keys_phi3small_dots1_llama4_ministral3_hunyuan_gpt_oss_mimo_v2_flash_lfm2_moe_test.rb @@ -0,0 +1,202 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase24IntegrationRegistryTest < 
Minitest::Test + MODEL_TYPES = %w[ + phi3small + dots1 + llama4 + ministral3 + hunyuan + gpt_oss + mimo_v2_flash + lfm2_moe + ].freeze + + def test_phase24_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do |model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, "#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + def tiny_config(model_type) + case model_type + when "phi3small" + { + "model_type" => "phi3small", + "hidden_size" => 16, + "dense_attention_every_n_layers" => 1, + "ff_intermediate_size" => 32, + "gegelu_limit" => 16.0, + "num_hidden_layers" => 1, + "num_attention_heads" => 2, + "layer_norm_epsilon" => 1e-5, + "vocab_size" => 64, + "num_key_value_heads" => 1, + } + when "dots1" + { + "model_type" => "dots1", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "intermediate_size" => 32, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "rms_norm_eps" => 1e-5, + "vocab_size" => 64, + "max_position_embeddings" => 128, + "first_k_dense_replace" => 0, + "moe_intermediate_size" => 24, + "n_routed_experts" => 2, + "n_shared_experts" => 1, + "norm_topk_prob" => false, + "num_experts_per_tok" => 1, + "rope_theta" => 10_000.0, + "routed_scaling_factor" => 1.0, + } + when "llama4" + { + "model_type" => "llama4", + "text_config" => { + "model_type" => "llama4_text", + "hidden_size" => 16, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "num_hidden_layers" => 2, + "vocab_size" => 64, + "intermediate_size" => 24, + "intermediate_size_mlp" => 24, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + "interleave_moe_layer_step" => 1, + "attention_chunk_size" => 4, + "max_position_embeddings" => 128, + "rope_theta" => 10_000.0, + "head_dim" => 8, + "rms_norm_eps" => 1e-5, + }, + } + when "ministral3" + { + "model_type" => "ministral3", + "hidden_size" => 16, + "num_hidden_layers" => 2, + "intermediate_size" => 32, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "head_dim" => 8, + "max_position_embeddings" => 128, + "rms_norm_eps" => 1e-5, + "vocab_size" => 64, + "layer_types" => ["sliding_attention", "full_attention"], + "sliding_window" => 8, + "rope_parameters" => { + "rope_theta" => 10_000.0, + "llama_4_scaling_beta" => 0.0, + "original_max_position_embeddings" => 128, + }, + } + when "hunyuan" + { + "model_type" => "hunyuan", + "vocab_size" => 64, + "hidden_size" => 16, + "num_hidden_layers" => 1, + "intermediate_size" => 32, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "attention_bias" => false, + "moe_topk" => 1, + "num_experts" => 2, + "num_shared_expert" => 1, + "use_mixed_mlp_moe" => true, + "use_qk_norm" => true, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "use_cla" => false, + } + when "gpt_oss" + { + "model_type" => "gpt_oss", + "num_hidden_layers" => 2, + "num_local_experts" => 2, + "num_experts_per_tok" => 1, + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 24, + "head_dim" => 8, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "sliding_window" => 8, + "layer_types" => ["sliding_attention", "full_attention"], + } + when "mimo_v2_flash" + { + "model_type" => "mimo_v2_flash", + "num_experts_per_tok" => 1, + 
"hybrid_layer_pattern" => [0], + "moe_layer_freq" => [0], + "add_swa_attention_sink_bias" => false, + "add_full_attention_sink_bias" => false, + "sliding_window_size" => 8, + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "moe_intermediate_size" => 24, + "num_hidden_layers" => 1, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "n_shared_experts" => 1, + "n_routed_experts" => 2, + "routed_scaling_factor" => 1.0, + "topk_method" => "noaux_tc", + "scoring_func" => "sigmoid", + "norm_topk_prob" => false, + "n_group" => 1, + "topk_group" => 1, + "max_position_embeddings" => 128, + "layernorm_epsilon" => 1e-5, + "rope_theta" => 10_000.0, + "swa_rope_theta" => 10_000.0, + "swa_num_attention_heads" => 2, + "swa_num_key_value_heads" => 1, + "head_dim" => 8, + "v_head_dim" => 8, + "swa_head_dim" => 8, + "swa_v_head_dim" => 8, + "partial_rotary_factor" => 1.0, + } + else + { + "model_type" => "lfm2_moe", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "moe_intermediate_size" => 24, + "num_hidden_layers" => 1, + "num_experts" => 2, + "num_experts_per_tok" => 1, + "norm_topk_prob" => false, + "num_attention_heads" => 2, + "num_key_value_heads" => 1, + "max_position_embeddings" => 128, + "use_expert_bias" => false, + "num_dense_layers" => 1, + "norm_eps" => 1e-5, + "conv_bias" => false, + "conv_L_cache" => 3, + "full_attn_idxs" => [0], + "rope_theta" => 10_000.0, + } + end + end +end diff --git a/test/parity/registry_keys_pixtral_qwen_vl_qwen35_mistral3_solar_test.rb b/test/parity/registry_keys_pixtral_qwen_vl_qwen35_mistral3_solar_test.rb new file mode 100644 index 0000000..0bc443c --- /dev/null +++ b/test/parity/registry_keys_pixtral_qwen_vl_qwen35_mistral3_solar_test.rb @@ -0,0 +1,141 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase19IntegrationRegistryTest < Minitest::Test + MODEL_TYPES = %w[ + pixtral + qwen2_vl + qwen3_vl + smollm3 + qwen3_5 + qwen3_5_moe + mistral3 + solar_open + ].freeze + + def test_phase19_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do |model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, "#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + def tiny_config(model_type) + case model_type + when "pixtral" + { + "model_type" => "pixtral", + "text_config" => llama_text_config, + } + when "qwen2_vl" + { + "model_type" => "qwen2_vl", + "text_config" => qwen2_text_config, + } + when "qwen3_vl" + { + "model_type" => "qwen3_vl", + "text_config" => qwen3_text_config("qwen3"), + } + when "qwen3_5" + { + "model_type" => "qwen3_5", + "text_config" => qwen3_text_config("qwen3_5"), + } + when "qwen3_5_moe" + { + "model_type" => "qwen3_5_moe", + "text_config" => qwen3_text_config("qwen3_5_moe").merge( + "num_experts" => 2, + "num_experts_per_tok" => 1 + ), + } + when "mistral3" + { + "model_type" => "mistral3", + "text_config" => llama_text_config, + } + when "solar_open" + { + "model_type" => "solar_open", + "vocab_size" => 64, + "hidden_size" => 16, + "intermediate_size" => 32, + "moe_intermediate_size" => 8, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "head_dim" => 16, + 
"n_shared_experts" => 1, + "n_routed_experts" => 2, + "routed_scaling_factor" => 1.0, + "num_experts_per_tok" => 1, + "first_k_dense_replace" => 0, + "norm_topk_prob" => false, + "max_position_embeddings" => 64, + "rms_norm_eps" => 1e-5, + "rope_theta" => 10_000.0, + "tie_word_embeddings" => false, + "partial_rotary_factor" => 1.0, + } + else + { + "model_type" => "smollm3", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + } + end + end + + def llama_text_config + { + "model_type" => "llama", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + "tie_word_embeddings" => false, + } + end + + def qwen2_text_config + { + "model_type" => "qwen2", + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + "tie_word_embeddings" => true, + } + end + + def qwen3_text_config(model_type) + { + "model_type" => model_type, + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + "head_dim" => 16, + "tie_word_embeddings" => false, + } + end +end diff --git a/test/parity/registry_keys_qwen_phi_glm_olmo_seed_oss_test.rb b/test/parity/registry_keys_qwen_phi_glm_olmo_seed_oss_test.rb new file mode 100644 index 0000000..e812a43 --- /dev/null +++ b/test/parity/registry_keys_qwen_phi_glm_olmo_seed_oss_test.rb @@ -0,0 +1,33 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase16IntegrationRegistryTest < Minitest::Test + MODEL_TYPES = %w[qwen3 qwen phi exaone glm helium olmo seed_oss].freeze + + def test_dense_model_keys_resolve_with_tiny_configs + MODEL_TYPES.each do |model_type| + assert MlxLm::Models::REGISTRY.key?(model_type), "#{model_type} should be registered" + + model_class, args_class = MlxLm::Models.get_classes(tiny_config(model_type)) + + assert_kind_of Class, model_class, "#{model_type} should resolve to a model class" + assert_kind_of Class, args_class, "#{model_type} should resolve to a model args class" + assert_instance_of args_class, args_class.from_dict(tiny_config(model_type)) + end + end + + private + + def tiny_config(model_type) + { + "model_type" => model_type, + "hidden_size" => 16, + "num_hidden_layers" => 1, + "num_attention_heads" => 1, + "num_key_value_heads" => 1, + "intermediate_size" => 32, + "vocab_size" => 64, + } + end +end diff --git a/test/parity/phase5_test.rb b/test/parity/registry_model_loading_tokenizer_test.rb similarity index 100% rename from test/parity/phase5_test.rb rename to test/parity/registry_model_loading_tokenizer_test.rb diff --git a/test/parity/rope_mla_cache_integration_smoke_test.rb b/test/parity/rope_mla_cache_integration_smoke_test.rb new file mode 100644 index 0000000..e0403c7 --- /dev/null +++ b/test/parity/rope_mla_cache_integration_smoke_test.rb @@ -0,0 +1,26 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase14IntegrationSmokeTest < Minitest::Test + def test_phase14_components_are_available_from_top_level_require + assert defined?(MlxLm::Models::SuScaledRoPE) + assert defined?(MlxLm::Models::Llama3RoPE) + assert defined?(MlxLm::Models::YarnRoPE) + assert defined?(MlxLm::Models::MLA::MultiLinear) + assert defined?(MlxLm::KVCache) + assert 
defined?(MlxLm::RotatingKVCache) + end + + def test_phase14_basic_construction_smoke + rope = MlxLm::Models.initialize_rope(8, 10_000.0, false) + mla = MlxLm::Models::MLA::MultiLinear.new(8, 4, 2) + cache = MlxLm::KVCache.new + rotating_cache = MlxLm::RotatingKVCache.new(max_size: 16, keep: 4) + + assert_instance_of MLX::NN::RoPE, rope + assert_instance_of MlxLm::Models::MLA::MultiLinear, mla + assert_instance_of MlxLm::KVCache, cache + assert_instance_of MlxLm::RotatingKVCache, rotating_cache + end +end diff --git a/test/parity/rope_utils_variants_factory_test.rb b/test/parity/rope_utils_variants_factory_test.rb new file mode 100644 index 0000000..1e755c7 --- /dev/null +++ b/test/parity/rope_utils_variants_factory_test.rb @@ -0,0 +1,286 @@ +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" +require "json" +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/rope_utils" + +class Phase14SuScaledRoPETest < Minitest::Test + include ParityTestHelpers + + def setup + @mx = MLX::Core + end + + def test_su_scaled_rope_matches_python + rope = MlxLm::Models::SuScaledRoPE.new( + 8, + base: 10_000.0, + max_position_embeddings: 131_072, + original_max_position_embeddings: 4096, + short_factor: [1.0, 1.0, 1.0, 1.0], + long_factor: [1.0, 1.1, 1.2, 1.3] + ) + + x = @mx.arange(0, 16, 1, @mx.float32).reshape([1, 1, 2, 8]) + y = rope.call(x, offset: 2) + + @mx.eval(rope._freqs, y) + + py = python_eval(<<~PY) + import json + import sys + import mlx.core as mx + + sys.path.insert(0, "mlx-lm") + from mlx_lm.models.rope_utils import SuScaledRoPE + + rope = SuScaledRoPE( + dims=8, + base=10000.0, + max_position_embeddings=131072, + original_max_position_embeddings=4096, + short_factor=[1.0, 1.0, 1.0, 1.0], + long_factor=[1.0, 1.1, 1.2, 1.3], + ) + x = mx.arange(0, 16, dtype=mx.float32).reshape(1, 1, 2, 8) + y = rope(x, offset=2) + + mx.eval(rope._freqs, y) + print(json.dumps({ + "scale": float(rope._scale), + "freqs": rope._freqs.tolist(), + "output": y.tolist(), + })) + PY + + assert_in_delta py["scale"], rope._scale, 1e-6 + assert_arrays_close py["freqs"], rope._freqs.to_a, atol: 1e-3, msg: "SuScaledRoPE freqs" + assert_arrays_close py["output"], y.to_a, atol: 1e-5, msg: "SuScaledRoPE output" + end +end + +class Phase14Llama3RoPETest < Minitest::Test + include ParityTestHelpers + + def setup + @mx = MLX::Core + end + + def test_llama3_rope_matches_python + scaling_config = { + "type" => "llama3", + "factor" => 8.0, + "low_freq_factor" => 1.0, + "high_freq_factor" => 4.0, + "original_max_position_embeddings" => 8192, + } + + rope = MlxLm::Models::Llama3RoPE.new( + dims: 8, + max_position_embeddings: 2048, + traditional: false, + base: 10_000.0, + scaling_config: scaling_config + ) + + x = @mx.arange(0, 16, 1, @mx.float32).reshape([1, 1, 2, 8]) + y = rope.call(x, offset: 1) + + @mx.eval(rope._freqs, y) + + py = python_eval(<<~PY) + import json + import sys + import mlx.core as mx + + sys.path.insert(0, "mlx-lm") + from mlx_lm.models.rope_utils import Llama3RoPE + + rope = Llama3RoPE( + dims=8, + max_position_embeddings=2048, + traditional=False, + base=10000.0, + scaling_config={ + "type": "llama3", + "factor": 8.0, + "low_freq_factor": 1.0, + "high_freq_factor": 4.0, + "original_max_position_embeddings": 8192, + }, + ) + x = mx.arange(0, 16, dtype=mx.float32).reshape(1, 1, 2, 8) + y = rope(x, offset=1) + + mx.eval(rope._freqs, y) + print(json.dumps({ + "freqs": rope._freqs.tolist(), + "output": y.tolist(), + })) + PY + + 
assert_arrays_close py["freqs"], rope._freqs.to_a, atol: 1e-3, msg: "Llama3RoPE freqs" + assert_arrays_close py["output"], y.to_a, atol: 1e-5, msg: "Llama3RoPE output" + end +end + +class Phase14YarnRoPETest < Minitest::Test + include ParityTestHelpers + + def setup + @mx = MLX::Core + end + + def test_yarn_rope_matches_python + rope = MlxLm::Models::YarnRoPE.new( + 8, + traditional: false, + max_position_embeddings: 2048, + base: 10_000.0, + scaling_factor: 4.0, + original_max_position_embeddings: 4096, + beta_fast: 32, + beta_slow: 1, + mscale: 1.5, + mscale_all_dim: 0.5 + ) + + x = @mx.arange(0, 24, 1, @mx.float32).reshape([1, 1, 2, 12]) + y = rope.call(x, offset: 3) + + @mx.eval(rope._freqs, y) + + py = python_eval(<<~PY) + import json + import sys + import mlx.core as mx + + sys.path.insert(0, "mlx-lm") + from mlx_lm.models.rope_utils import YarnRoPE + + rope = YarnRoPE( + dims=8, + traditional=False, + max_position_embeddings=2048, + base=10000.0, + scaling_factor=4.0, + original_max_position_embeddings=4096, + beta_fast=32, + beta_slow=1, + mscale=1.5, + mscale_all_dim=0.5, + ) + x = mx.arange(0, 24, dtype=mx.float32).reshape(1, 1, 2, 12) + y = rope(x, offset=3) + + mx.eval(rope._freqs, y) + print(json.dumps({ + "mscale": float(rope.mscale), + "freqs": rope._freqs.tolist(), + "output": y.tolist(), + })) + PY + + assert_in_delta py["mscale"], rope.mscale, 1e-6 + assert_arrays_close py["freqs"], rope._freqs.to_a, atol: 1e-3, msg: "YarnRoPE freqs" + assert_arrays_close py["output"], y.to_a, atol: 1e-5, msg: "YarnRoPE output" + end +end + +class Phase14RoPEFactoryTest < Minitest::Test + def test_initialize_rope_default_and_linear + default_rope = MlxLm::Models.initialize_rope(8, 10_000.0, false) + assert_instance_of MLX::NN::RoPE, default_rope + assert_in_delta 1.0, default_rope.scale, 1e-8 + + linear_rope = MlxLm::Models.initialize_rope( + 8, + 10_000.0, + false, + { "type" => "linear", "factor" => 4.0 } + ) + assert_instance_of MLX::NN::RoPE, linear_rope + assert_in_delta 0.25, linear_rope.scale, 1e-8 + end + + def test_initialize_rope_variant_dispatch + llama3 = MlxLm::Models.initialize_rope( + 8, + 10_000.0, + false, + { + "type" => "llama3", + "factor" => 8.0, + "low_freq_factor" => 1.0, + "high_freq_factor" => 4.0, + "original_max_position_embeddings" => 8192, + }, + max_position_embeddings: 2048 + ) + assert_instance_of MlxLm::Models::Llama3RoPE, llama3 + + yarn = MlxLm::Models.initialize_rope( + 8, + 10_000.0, + false, + { + "rope_type" => "deepseek_yarn", + "factor" => 4.0, + }, + max_position_embeddings: 2048 + ) + assert_instance_of MlxLm::Models::YarnRoPE, yarn + + longrope = MlxLm::Models.initialize_rope( + 8, + 10_000.0, + false, + { + "type" => "longrope", + "original_max_position_embeddings" => 4096, + "short_factor" => [1.0, 1.0, 1.0, 1.0], + "long_factor" => [1.0, 1.0, 1.0, 1.0], + }, + max_position_embeddings: 131_072 + ) + assert_instance_of MlxLm::Models::SuScaledRoPE, longrope + end + + def test_initialize_rope_mrope_and_error_paths + mrope = MlxLm::Models.initialize_rope( + 8, + 10_000.0, + false, + { + "rope_type" => "mrope", + "mrope_section" => [16, 16, 16], + } + ) + assert_instance_of MLX::NN::RoPE, mrope + + err = assert_raises(ArgumentError) do + MlxLm::Models.initialize_rope( + 8, + 10_000.0, + false, + { + "rope_type" => "mrope", + "mrope_section" => [16, 16], + } + ) + end + assert_match(/MRoPE currently only supports 3 sections/, err.message) + + unsupported = assert_raises(ArgumentError) do + MlxLm::Models.initialize_rope( + 8, + 10_000.0, + false, + { 
"type" => "unknown_rope" } + ) + end + assert_match(/Unsupported RoPE type unknown_rope/, unsupported.message) + end +end diff --git a/test/parity/rwkv7_model_test.rb b/test/parity/rwkv7_model_test.rb new file mode 100644 index 0000000..307e788 --- /dev/null +++ b/test/parity/rwkv7_model_test.rb @@ -0,0 +1,69 @@ +$LOAD_PATH.unshift File.expand_path("../../lib", __dir__) +$LOAD_PATH.unshift File.expand_path("../../mlx-ruby/lib", __dir__) + +require "mlx" +require "minitest/autorun" + +require_relative "../../lib/mlx_lm/model_args" +require_relative "../../lib/mlx_lm/models" +require_relative "../../lib/mlx_lm/models/cache" +require_relative "../../lib/mlx_lm/models/recurrent_gemma" +require_relative "../../lib/mlx_lm/models/rwkv7" + +class Phase45DenseLaneAURwkv7Test < Minitest::Test + def setup + @mx = MLX::Core + end + + def test_rwkv7_construct_forward_shape_sanitize_and_cache + args = MlxLm::Models::Rwkv7::ModelArgs.from_dict({ + "model_type" => "rwkv7", + "vocab_size" => 67, + "hidden_size" => 32, + "intermediate_size" => 64, + "norm_eps" => 1e-5, + "head_dim" => 8, + "num_hidden_layers" => 2, + "a_low_rank_dim" => 8, + "v_low_rank_dim" => 8, + "gate_low_rank_dim" => 8, + "decay_low_rank_dim" => 8, + "tie_word_embeddings" => false, + }) + + model = MlxLm::Models::Rwkv7::Model.new(args) + @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, value| value }) + + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) + output = model.call(tokens) + @mx.eval(output) + assert_equal [1, 3, 67], output.shape + assert_equal 2, model.layers.length + + conv_weight = @mx.array((0...24).to_a, dtype: @mx.float32).reshape([4, 1, 6]) + weights = { + "blocks.0.time_mix.conv_1d.weight" => conv_weight, + "blocks.0.time_mix.linear_x.weight" => @mx.zeros([32, 32]).astype(@mx.float32), + } + sanitized = model.sanitize(weights) + sanitized_conv = sanitized["model.layers.0.temporal_block.conv_1d.weight"] + @mx.eval(sanitized_conv) + assert_equal [4, 6, 1], sanitized_conv.shape + assert sanitized.key?("model.layers.0.temporal_block.linear_x.weight") + + cache = model.make_cache + assert_equal 2, cache.length + assert_instance_of MlxLm::ArraysCache, cache[0] + assert_instance_of MlxLm::ArraysCache, cache[1] + end +end + +class Phase45DenseLaneAURegistryTest < Minitest::Test + def test_rwkv7_registered_and_resolve + assert MlxLm::Models::REGISTRY.key?("rwkv7"), "rwkv7 should be registered" + + model_class, args_class = MlxLm::Models.get_classes({ "model_type" => "rwkv7" }) + assert_equal MlxLm::Models::Rwkv7::Model, model_class + assert_equal MlxLm::Models::Rwkv7::ModelArgs, args_class + end +end diff --git a/test/parity/phase4_test.rb b/test/parity/sampling_and_generation_test.rb similarity index 90% rename from test/parity/phase4_test.rb rename to test/parity/sampling_and_generation_test.rb index 59da2b9..8ae2eff 100644 --- a/test/parity/phase4_test.rb +++ b/test/parity/sampling_and_generation_test.rb @@ -13,7 +13,7 @@ def setup # Test 1: Greedy sampling (temp=0) returns argmax def test_greedy_sampling sampler = MlxLm::SampleUtils.make_sampler(temp: 0.0) - logprobs = @mx.array([[-1.0, -0.5, -2.0, -0.1]]).astype(@mx.float32) + logprobs = @mx.array([[-1.0, -0.5, -2.0, -0.1]], dtype: @mx.float32) token = sampler.call(logprobs) assert_equal 3, token.item, "Greedy should pick index of max logprob" end @@ -23,7 +23,7 @@ def test_temperature_sampling # With very low temp, should behave close to greedy # We use a fixed seed for reproducibility @mx.random_seed(42) - logprobs = @mx.array([[-1.0, -0.5, -2.0, 
-0.1]]).astype(@mx.float32) + logprobs = @mx.array([[-1.0, -0.5, -2.0, -0.1]], dtype: @mx.float32) # At very low temperature, categorical should pick the argmax (almost always) sampler_low = MlxLm::SampleUtils.make_sampler(temp: 0.01) @@ -38,7 +38,7 @@ def test_temperature_sampling # Test 3: Top-k filtering masks low-probability tokens def test_top_k_filtering - logprobs = @mx.array([[-1.0, -0.5, -2.0, -0.1, -3.0]]).astype(@mx.float32) + logprobs = @mx.array([[-1.0, -0.5, -2.0, -0.1, -3.0]], dtype: @mx.float32) result = MlxLm::SampleUtils.apply_top_k(logprobs, 2) result_list = result.tolist.flatten @@ -54,7 +54,7 @@ def test_top_k_filtering # Test 4: Top-p (nucleus) filtering def test_top_p_filtering # Create logprobs where one token dominates - logprobs = @mx.array([[0.0, -10.0, -10.0, -10.0]]).astype(@mx.float32) + logprobs = @mx.array([[0.0, -10.0, -10.0, -10.0]], dtype: @mx.float32) result = MlxLm::SampleUtils.apply_top_p(logprobs, 0.5) result_list = result.tolist.flatten @@ -65,7 +65,7 @@ def test_top_p_filtering # Test 5: Min-p filtering def test_min_p_filtering # One dominant token, rest very low - logprobs = @mx.array([[-0.1, -5.0, -5.0, -5.0]]).astype(@mx.float32) + logprobs = @mx.array([[-0.1, -5.0, -5.0, -5.0]], dtype: @mx.float32) result = MlxLm::SampleUtils.apply_min_p(logprobs, 0.5, 1) result_list = result.tolist.flatten @@ -79,8 +79,8 @@ def test_min_p_filtering # Test 6: Repetition penalty reduces probability of repeated tokens def test_repetition_penalty processor = MlxLm::SampleUtils.make_repetition_penalty(2.0, 20) - logits = @mx.array([[1.0, 2.0, 3.0, 4.0]]).astype(@mx.float32) - tokens = @mx.array([0, 2]).astype(@mx.int32) # Tokens 0 and 2 were already generated + logits = @mx.array([[1.0, 2.0, 3.0, 4.0]], dtype: @mx.float32) + tokens = @mx.array([0, 2], dtype: @mx.int32) # Tokens 0 and 2 were already generated result = processor.call(tokens, logits) result_list = result.tolist.flatten @@ -99,7 +99,7 @@ def test_repetition_penalty def test_sampler_chaining @mx.random_seed(42) sampler = MlxLm::SampleUtils.make_sampler(temp: 1.0, top_k: 2) - logprobs = @mx.array([[-1.0, -0.5, -2.0, -0.1, -3.0]]).astype(@mx.float32) + logprobs = @mx.array([[-1.0, -0.5, -2.0, -0.1, -3.0]], dtype: @mx.float32) counts = Hash.new(0) 100.times do @@ -126,7 +126,7 @@ def test_logits_processors def test_categorical_sampling_distribution @mx.random_seed(123) # Make one token vastly more probable - logprobs = @mx.array([[-10.0, -10.0, 0.0, -10.0]]).astype(@mx.float32) + logprobs = @mx.array([[-10.0, -10.0, 0.0, -10.0]], dtype: @mx.float32) counts = Hash.new(0) 100.times do @@ -188,14 +188,14 @@ def call(tokens, cache: nil) end end - mx.array(logits_data).astype(mx.float32) + mx.array(logits_data, dtype: mx.float32) end end # Test 10: generate_step produces correct tokens from dummy model def test_generate_step_greedy model = DummyModel.new(vocab_size: 10, num_layers: 2) - prompt = @mx.array([1, 2]).astype(@mx.uint32) + prompt = @mx.array([1, 2], dtype: @mx.uint32) tokens = [] MlxLm::Generate.generate_step(prompt, model, max_tokens: 5).each do |token, logprobs| @@ -210,7 +210,7 @@ def test_generate_step_greedy # Test 11: generate_step respects max_tokens def test_generate_step_max_tokens model = DummyModel.new(vocab_size: 10, num_layers: 2) - prompt = @mx.array([1]).astype(@mx.uint32) + prompt = @mx.array([1], dtype: @mx.uint32) tokens = [] MlxLm::Generate.generate_step(prompt, model, max_tokens: 3).each do |token, _| @@ -223,10 +223,10 @@ def test_generate_step_max_tokens # Test 12: generate_step 
with custom sampler def test_generate_step_custom_sampler model = DummyModel.new(vocab_size: 10, num_layers: 2) - prompt = @mx.array([1, 2]).astype(@mx.uint32) + prompt = @mx.array([1, 2], dtype: @mx.uint32) # Custom sampler that always picks token 7 - custom_sampler = ->(_logprobs) { @mx.array([7]).astype(@mx.int32) } + custom_sampler = ->(_logprobs) { @mx.array([7], dtype: @mx.int32) } tokens = [] MlxLm::Generate.generate_step(prompt, model, max_tokens: 3, sampler: custom_sampler).each do |token, _| diff --git a/test/parity/ssm_attention_update_paths_test.rb b/test/parity/ssm_attention_update_paths_test.rb new file mode 100644 index 0000000..bc70de3 --- /dev/null +++ b/test/parity/ssm_attention_update_paths_test.rb @@ -0,0 +1,93 @@ +# frozen_string_literal: true + +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/ssm" + +class Phase15SsmTest < Minitest::Test + def test_compute_dt_applies_softplus_and_clip + mx = MLX::Core + + dt = mx.array([[-10.0, 0.0, 10.0]], dtype: mx.float32) + dt_bias = mx.array([[0.0, 0.0, 0.0]], dtype: mx.float32) + + out = MlxLm::Models::SSM.compute_dt(dt, dt_bias, [0.1, 1.0]) + actual = out.tolist.flatten + + expected = [-10.0, 0.0, 10.0].map do |v| + softplus = Math.log1p(Math.exp(v)) + [[softplus, 0.1].max, 1.0].min + end + + expected.zip(actual).each do |e, a| + assert_in_delta e, a, 1e-5 + end + end + + def test_segsum_masked_positions_become_negative_infinity + mx = MLX::Core + + x = mx.array([[[1.0, 2.0, 3.0]]], dtype: mx.float32) + mask = mx.array([[1, 1, 0]], dtype: mx.float32) + + out = MlxLm::Models::SSM.segsum(x, mask: mask) + values = out.tolist + + assert_equal 1, values.length + assert_equal 1, values[0].length + + # Any relation touching masked index 2 should be -inf in the masked segsum output. + row = values[0][0] + assert row[2][0].infinite? + assert row[2][1].infinite? + assert row[2][2].infinite? 
+ end + + def test_ssm_attn_shapes + mx = MLX::Core + + batch = 1 + seq = 3 + heads = 2 + head_dim = 2 + groups = 1 + state_dim = 3 + + x = mx.random_uniform([batch, seq, heads, head_dim], -0.5, 0.5, mx.float32) + a_log = mx.random_uniform([heads], -0.2, 0.2, mx.float32) + b = mx.random_uniform([batch, seq, groups, state_dim], -0.5, 0.5, mx.float32) + c = mx.random_uniform([batch, seq, groups, state_dim], -0.5, 0.5, mx.float32) + d = mx.random_uniform([heads], -0.2, 0.2, mx.float32) + dt = mx.random_uniform([batch, seq, heads], 0.01, 0.2, mx.float32) + dt_bias = mx.zeros([heads], mx.float32) + + y, state = MlxLm::Models::SSM.ssm_attn( + x, a_log, b, c, d, dt, dt_bias + ) + + assert_equal [batch, seq, heads, head_dim], y.shape + assert_equal [batch, heads, head_dim, state_dim], state.shape + end + + def test_ssm_update_dispatches_to_attn_path + mx = MLX::Core + + x = mx.random_uniform([1, 2, 2, 2], -0.5, 0.5, mx.float32) + a_log = mx.random_uniform([2], -0.2, 0.2, mx.float32) + b = mx.random_uniform([1, 2, 1, 3], -0.5, 0.5, mx.float32) + c = mx.random_uniform([1, 2, 1, 3], -0.5, 0.5, mx.float32) + d = mx.random_uniform([2], -0.2, 0.2, mx.float32) + dt = mx.random_uniform([1, 2, 2], 0.01, 0.2, mx.float32) + dt_bias = mx.zeros([2], mx.float32) + + y, state = MlxLm::Models::SSM.ssm_update(x, a_log, b, c, d, dt, dt_bias) + + assert_equal [1, 2, 2, 2], y.shape + assert_equal [1, 2, 2, 3], state.shape + end + + def test_ssm_update_kernel_explicitly_not_implemented + assert_raises(NotImplementedError) do + MlxLm::Models::SSM.ssm_update_kernel(nil) + end + end +end diff --git a/test/parity/ssm_gated_delta_bitlinear_pipeline_integration_smoke_test.rb b/test/parity/ssm_gated_delta_bitlinear_pipeline_integration_smoke_test.rb new file mode 100644 index 0000000..7af81ea --- /dev/null +++ b/test/parity/ssm_gated_delta_bitlinear_pipeline_integration_smoke_test.rb @@ -0,0 +1,26 @@ +# frozen_string_literal: true + +require_relative "../test_helper" + +class Phase15IntegrationSmokeTest < Minitest::Test + def test_new_phase15_modules_load + assert defined?(MlxLm::Models::SSM) + assert defined?(MlxLm::Models::GatedDelta) + assert defined?(MlxLm::Models::Activations) + assert defined?(MlxLm::Models::PipelineMixin) + assert defined?(MlxLm::Models::BitLinear) + end + + def test_basic_construction_paths + mx = MLX::Core + + xielu = MlxLm::Models::Activations::XieLU.new + x = mx.array([-1.0, 0.0, 1.0], dtype: mx.float32) + y = xielu.call(x) + assert_equal [3], y.shape + + bitlinear = MlxLm::Models::BitLinear.new(4, 3) + out = bitlinear.call(mx.random_uniform([2, 4], -0.5, 0.5, mx.float32)) + assert_equal [2, 3], out.shape + end +end diff --git a/test/parity/phase8_test.rb b/test/parity/stablelm_cohere_gemma2_test.rb similarity index 92% rename from test/parity/phase8_test.rb rename to test/parity/stablelm_cohere_gemma2_test.rb index 05d66a0..734a619 100644 --- a/test/parity/phase8_test.rb +++ b/test/parity/stablelm_cohere_gemma2_test.rb @@ -37,7 +37,7 @@ def test_stablelm_forward model = MlxLm::Models::StableLM::Model.new(args) @mx.eval(*model.parameters.values.flat_map { |v| v.is_a?(Hash) ? 
v.values : [v] }) - tokens = @mx.array([[1, 2, 3]]).astype(@mx.int32) + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(tokens) assert_equal [1, 3, 128], output.shape end @@ -50,11 +50,11 @@ def test_stablelm_with_cache cache = Array.new(2) { MlxLm::KVCache.new } - token1 = @mx.array([[1]]).astype(@mx.int32) + token1 = @mx.array([[1]], dtype: @mx.int32) out1 = model.call(token1, cache: cache) assert_equal [1, 1, 128], out1.shape - token2 = @mx.array([[2]]).astype(@mx.int32) + token2 = @mx.array([[2]], dtype: @mx.int32) out2 = model.call(token2, cache: cache) assert_equal [1, 1, 128], out2.shape end @@ -79,7 +79,7 @@ def test_stablelm_parallel_residual model = MlxLm::Models::StableLM::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - tokens = @mx.array([[1, 2, 3]]).astype(@mx.int32) + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(tokens) assert_equal [1, 3, 128], output.shape end @@ -91,7 +91,7 @@ def test_stablelm_partial_rotary # Partial rotary factor of 0.5 with head_dim=32 means RoPE on 16 dims # Just verify the model constructs and produces valid output @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - tokens = @mx.array([[1]]).astype(@mx.int32) + tokens = @mx.array([[1]], dtype: @mx.int32) output = model.call(tokens) assert_equal [1, 1, 128], output.shape end @@ -132,7 +132,7 @@ def test_cohere_forward model = MlxLm::Models::Cohere::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - tokens = @mx.array([[1, 2, 3]]).astype(@mx.int32) + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(tokens) assert_equal [1, 3, 128], output.shape end @@ -145,11 +145,11 @@ def test_cohere_with_cache cache = Array.new(2) { MlxLm::KVCache.new } - token1 = @mx.array([[1]]).astype(@mx.int32) + token1 = @mx.array([[1]], dtype: @mx.int32) out1 = model.call(token1, cache: cache) assert_equal [1, 1, 128], out1.shape - token2 = @mx.array([[2]]).astype(@mx.int32) + token2 = @mx.array([[2]], dtype: @mx.int32) out2 = model.call(token2, cache: cache) assert_equal [1, 1, 128], out2.shape end @@ -160,7 +160,7 @@ def test_cohere_parallel_residuals model = MlxLm::Models::Cohere::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - tokens = @mx.array([[1, 2, 3]]).astype(@mx.int32) + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(tokens) # Just verify it works (parallel residual is an internal implementation detail) assert_equal [1, 3, 128], output.shape @@ -172,7 +172,7 @@ def test_cohere_logit_scaling model = MlxLm::Models::Cohere::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - tokens = @mx.array([[1]]).astype(@mx.int32) + tokens = @mx.array([[1]], dtype: @mx.int32) output = model.call(tokens) # Output should be scaled by logit_scale (0.0625) # Values should generally be small due to the scaling @@ -216,7 +216,7 @@ def test_gemma2_forward model = MlxLm::Models::Gemma2::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - tokens = @mx.array([[1, 2, 3]]).astype(@mx.int32) + tokens = @mx.array([[1, 2, 3]], dtype: @mx.int32) output = model.call(tokens) assert_equal [1, 3, 128], output.shape end @@ -229,11 +229,11 @@ def test_gemma2_with_cache cache = Array.new(2) { MlxLm::KVCache.new } - token1 = @mx.array([[1]]).astype(@mx.int32) + token1 = @mx.array([[1]], dtype: @mx.int32) out1 = model.call(token1, cache: cache) assert_equal [1, 
1, 128], out1.shape - token2 = @mx.array([[2]]).astype(@mx.int32) + token2 = @mx.array([[2]], dtype: @mx.int32) out2 = model.call(token2, cache: cache) assert_equal [1, 1, 128], out2.shape end @@ -258,7 +258,7 @@ def test_gemma2_logit_softcapping model = MlxLm::Models::Gemma2::Model.new(args) @mx.eval(*MLX::Utils.tree_flatten(model.parameters).map { |_, v| v }) - tokens = @mx.array([[1]]).astype(@mx.int32) + tokens = @mx.array([[1]], dtype: @mx.int32) output = model.call(tokens) @mx.eval(output) diff --git a/test/parity/switch_layers_pipeline_mixin_test.rb b/test/parity/switch_layers_pipeline_mixin_test.rb new file mode 100644 index 0000000..407c2fe --- /dev/null +++ b/test/parity/switch_layers_pipeline_mixin_test.rb @@ -0,0 +1,116 @@ +require_relative "../test_helper" +require_relative "../../lib/mlx_lm/models/pipeline" + +class Phase15SwitchPipelineTest < Minitest::Test + include ParityTestHelpers + + class FakeGroup + def initialize(rank, size) + @rank = rank + @size = size + end + + def rank + @rank + end + + def size + @size + end + end + + class DummyPipelineModel < MLX::NN::Module + include MlxLm::Models::PipelineMixin + end + + def setup + @mx = MLX::Core + end + + def test_switch_linear_to_quantized_and_forward + @mx.random_seed(7) + linear = MlxLm::Models::SwitchLayers::SwitchLinear.new(32, 16, 4, bias: true) + qlinear = linear.to_quantized(group_size: 32, bits: 4, mode: "affine") + + assert_instance_of MlxLm::Models::SwitchLayers::QuantizedSwitchLinear, qlinear + assert_equal 32, qlinear.group_size + assert_equal 4, qlinear.bits + assert_equal "affine", qlinear.mode + + x = @mx.random_uniform([2, 3, 1, 1, 32], -0.5, 0.5, @mx.float32) + indices = @mx.array( + [ + [[0, 1], [1, 2], [2, 3]], + [[3, 2], [2, 1], [1, 0]] + ], + dtype: @mx.int32 + ) + + expected = linear.call(x, indices) + actual = qlinear.call(x, indices) + @mx.eval(expected, actual) + + assert_equal expected.shape, actual.shape + assert_arrays_close( + expected.tolist, + actual.tolist, + atol: 0.35, + msg: "QuantizedSwitchLinear output should closely track SwitchLinear" + ) + end + + def test_switch_linear_to_quantized_rejects_quantized_input + linear = MlxLm::Models::SwitchLayers::SwitchLinear.new(8, 4, 2) + assert_raises(ArgumentError) do + linear.to_quantized(quantize_input: true) + end + end + + def test_switch_mlp_forward_shape_with_sorted_route_path + @mx.random_seed(123) + mlp = MlxLm::Models::SwitchLayers::SwitchMLP.new(16, 32, 8) + x = @mx.random_uniform([4, 8, 16], -1.0, 1.0, @mx.float32) + + route_data = Array.new(4) do + Array.new(8) do + [rand(8), rand(8)] + end + end + indices = @mx.array(route_data, dtype: @mx.int32) # 4*8*2 = 64 routes -> sorted path + + y = mlp.call(x, indices) + @mx.eval(y) + assert_equal [4, 8, 2, 16], y.shape + end + + def test_pipeline_mixin_defaults_and_partitioning + base = DummyPipelineModel.new + base.layers = (0...6).to_a + + assert_equal 0, base.pipeline_rank + assert_equal 1, base.pipeline_size + assert_equal 0, base.start_idx + assert_nil base.end_idx + assert_equal [0, 1, 2, 3, 4, 5], base.pipeline_layers + + rank1 = DummyPipelineModel.new + rank1.layers = (0...10).to_a + returned = rank1.pipeline(FakeGroup.new(1, 3)) + + assert_same rank1, returned + assert_equal 1, rank1.pipeline_rank + assert_equal 3, rank1.pipeline_size + assert_equal 3, rank1.start_idx + assert_equal 6, rank1.end_idx + assert_equal [nil, nil, nil, 3, 4, 5], rank1.layers + assert_equal [3, 4, 5], rank1.pipeline_layers + + rank0 = DummyPipelineModel.new + rank0.layers = (0...10).to_a + 
rank0.pipeline(FakeGroup.new(0, 3)) + assert_equal 8, rank0.start_idx + assert_equal 12, rank0.end_idx + assert_equal [nil, nil, nil, nil, nil, nil, nil, nil, 8, 9], rank0.layers + assert_equal [8, 9], rank0.pipeline_layers + end +end diff --git a/test/parity/phase2_test.rb b/test/parity/tokenizer_streaming_detokenizer_test.rb similarity index 100% rename from test/parity/phase2_test.rb rename to test/parity/tokenizer_streaming_detokenizer_test.rb diff --git a/test/reports/python_ruby_parity_inventory_snapshot.json b/test/reports/python_ruby_parity_inventory_snapshot.json new file mode 100644 index 0000000..8201d5f --- /dev/null +++ b/test/reports/python_ruby_parity_inventory_snapshot.json @@ -0,0 +1,491 @@ +{ + "inventory_version": 1, + "source_paths": { + "python_models_dir": "mlx-lm/mlx_lm/models", + "ruby_models_dir": "lib/mlx_lm/models", + "ruby_registry_file": "lib/mlx_lm/models.rb" + }, + "python": { + "model_files_total": 116, + "shared_infra_files": [ + "activations.py", + "base.py", + "bitlinear_layers.py", + "cache.py", + "gated_delta.py", + "mla.py", + "pipeline.py", + "rope_utils.py", + "ssm.py", + "switch_layers.py" + ], + "architecture_files_total": 106, + "architecture_files": [ + "Klear.py", + "afm7.py", + "afmoe.py", + "apertus.py", + "baichuan_m1.py", + "bailing_moe.py", + "bailing_moe_linear.py", + "bitnet.py", + "cohere.py", + "cohere2.py", + "dbrx.py", + "deepseek.py", + "deepseek_v2.py", + "deepseek_v3.py", + "deepseek_v32.py", + "dots1.py", + "ernie4_5.py", + "ernie4_5_moe.py", + "exaone.py", + "exaone4.py", + "exaone_moe.py", + "falcon_h1.py", + "gemma.py", + "gemma2.py", + "gemma3.py", + "gemma3_text.py", + "gemma3n.py", + "glm.py", + "glm4.py", + "glm4_moe.py", + "glm4_moe_lite.py", + "glm_moe_dsa.py", + "gpt2.py", + "gpt_bigcode.py", + "gpt_neox.py", + "gpt_oss.py", + "granite.py", + "granitemoe.py", + "granitemoehybrid.py", + "helium.py", + "hunyuan.py", + "hunyuan_v1_dense.py", + "internlm2.py", + "internlm3.py", + "iquestloopcoder.py", + "jamba.py", + "kimi_k25.py", + "kimi_linear.py", + "kimi_vl.py", + "lfm2-vl.py", + "lfm2.py", + "lfm2_moe.py", + "lille-130m.py", + "llama.py", + "llama4.py", + "llama4_text.py", + "longcat_flash.py", + "longcat_flash_ngram.py", + "mamba.py", + "mamba2.py", + "mimo.py", + "mimo_v2_flash.py", + "minicpm.py", + "minicpm3.py", + "minimax.py", + "ministral3.py", + "mistral3.py", + "mixtral.py", + "nanochat.py", + "nemotron-nas.py", + "nemotron.py", + "nemotron_h.py", + "olmo.py", + "olmo2.py", + "olmo3.py", + "olmoe.py", + "openelm.py", + "phi.py", + "phi3.py", + "phi3small.py", + "phimoe.py", + "phixtral.py", + "pixtral.py", + "plamo.py", + "plamo2.py", + "qwen.py", + "qwen2.py", + "qwen2_moe.py", + "qwen2_vl.py", + "qwen3.py", + "qwen3_5.py", + "qwen3_5_moe.py", + "qwen3_moe.py", + "qwen3_next.py", + "qwen3_vl.py", + "qwen3_vl_moe.py", + "recurrent_gemma.py", + "rwkv7.py", + "seed_oss.py", + "smollm3.py", + "solar_open.py", + "stablelm.py", + "starcoder2.py", + "step3p5.py", + "telechat3.py", + "youtu_llm.py" + ] + }, + "ruby": { + "model_files_total": 115, + "shared_infra_files": [ + "cache.rb", + "switch_layers.rb" + ], + "architecture_files_total": 113, + "architecture_files": [ + "activations.rb", + "afm7.rb", + "afmoe.rb", + "apertus.rb", + "baichuan_m1.rb", + "bailing_moe.rb", + "bailing_moe_linear.rb", + "bitlinear_layers.rb", + "bitnet.rb", + "cohere.rb", + "cohere2.rb", + "dbrx.rb", + "deepseek.rb", + "deepseek_v2.rb", + "deepseek_v3.rb", + "deepseek_v32.rb", + "dots1.rb", + "ernie4_5.rb", + "ernie4_5_moe.rb", + 
"exaone.rb", + "exaone4.rb", + "exaone_moe.rb", + "falcon_h1.rb", + "gated_delta.rb", + "gemma.rb", + "gemma2.rb", + "gemma3.rb", + "gemma3_text.rb", + "gemma3n.rb", + "glm.rb", + "glm4.rb", + "glm4_moe.rb", + "glm4_moe_lite.rb", + "glm_moe_dsa.rb", + "gpt2.rb", + "gpt_bigcode.rb", + "gpt_neox.rb", + "gpt_oss.rb", + "granite.rb", + "granitemoe.rb", + "granitemoehybrid.rb", + "helium.rb", + "hunyuan.rb", + "hunyuan_v1_dense.rb", + "internlm2.rb", + "internlm3.rb", + "iquestloopcoder.rb", + "jamba.rb", + "kimi_k25.rb", + "kimi_linear.rb", + "kimi_vl.rb", + "klear.rb", + "lfm2.rb", + "lfm2_moe.rb", + "lfm2_vl.rb", + "lille_130m.rb", + "llama.rb", + "llama4.rb", + "llama4_text.rb", + "longcat_flash.rb", + "longcat_flash_ngram.rb", + "mamba.rb", + "mamba2.rb", + "mimo.rb", + "mimo_v2_flash.rb", + "minicpm.rb", + "minicpm3.rb", + "minimax.rb", + "ministral3.rb", + "mistral3.rb", + "mixtral.rb", + "mla.rb", + "nanochat.rb", + "nemotron.rb", + "nemotron_h.rb", + "nemotron_nas.rb", + "olmo.rb", + "olmo2.rb", + "olmo3.rb", + "olmoe.rb", + "openelm.rb", + "phi.rb", + "phi3.rb", + "phi3small.rb", + "phimoe.rb", + "phixtral.rb", + "pipeline.rb", + "pixtral.rb", + "plamo.rb", + "plamo2.rb", + "qwen.rb", + "qwen2.rb", + "qwen2_moe.rb", + "qwen2_vl.rb", + "qwen3.rb", + "qwen3_5.rb", + "qwen3_5_moe.rb", + "qwen3_moe.rb", + "qwen3_next.rb", + "qwen3_vl.rb", + "qwen3_vl_moe.rb", + "recurrent_gemma.rb", + "rope_utils.rb", + "rwkv7.rb", + "seed_oss.rb", + "smollm3.rb", + "solar_open.rb", + "ssm.rb", + "stablelm.rb", + "starcoder2.rb", + "step3p5.rb", + "telechat3.rb", + "youtu_llm.rb" + ], + "architecture_model_keys_total": 113, + "architecture_model_keys": [ + "activations", + "afm7", + "afmoe", + "apertus", + "baichuan_m1", + "bailing_moe", + "bailing_moe_linear", + "bitlinear_layers", + "bitnet", + "cohere", + "cohere2", + "dbrx", + "deepseek", + "deepseek_v2", + "deepseek_v3", + "deepseek_v32", + "dots1", + "ernie4_5", + "ernie4_5_moe", + "exaone", + "exaone4", + "exaone_moe", + "falcon_h1", + "gated_delta", + "gemma", + "gemma2", + "gemma3", + "gemma3_text", + "gemma3n", + "glm", + "glm4", + "glm4_moe", + "glm4_moe_lite", + "glm_moe_dsa", + "gpt2", + "gpt_bigcode", + "gpt_neox", + "gpt_oss", + "granite", + "granitemoe", + "granitemoehybrid", + "helium", + "hunyuan", + "hunyuan_v1_dense", + "internlm2", + "internlm3", + "iquestloopcoder", + "jamba", + "kimi_k25", + "kimi_linear", + "kimi_vl", + "klear", + "lfm2", + "lfm2_moe", + "lfm2_vl", + "lille_130m", + "llama", + "llama4", + "llama4_text", + "longcat_flash", + "longcat_flash_ngram", + "mamba", + "mamba2", + "mimo", + "mimo_v2_flash", + "minicpm", + "minicpm3", + "minimax", + "ministral3", + "mistral3", + "mixtral", + "mla", + "nanochat", + "nemotron", + "nemotron_h", + "nemotron_nas", + "olmo", + "olmo2", + "olmo3", + "olmoe", + "openelm", + "phi", + "phi3", + "phi3small", + "phimoe", + "phixtral", + "pipeline", + "pixtral", + "plamo", + "plamo2", + "qwen", + "qwen2", + "qwen2_moe", + "qwen2_vl", + "qwen3", + "qwen3_5", + "qwen3_5_moe", + "qwen3_moe", + "qwen3_next", + "qwen3_vl", + "qwen3_vl_moe", + "recurrent_gemma", + "rope_utils", + "rwkv7", + "seed_oss", + "smollm3", + "solar_open", + "ssm", + "stablelm", + "starcoder2", + "step3p5", + "telechat3", + "youtu_llm" + ], + "registered_model_keys_total": 106, + "registered_model_keys": [ + "Klear", + "afm7", + "afmoe", + "apertus", + "baichuan_m1", + "bailing_moe", + "bailing_moe_linear", + "bitnet", + "cohere", + "cohere2", + "dbrx", + "deepseek", + "deepseek_v2", + "deepseek_v3", + "deepseek_v32", + 
"dots1", + "ernie4_5", + "ernie4_5_moe", + "exaone", + "exaone4", + "exaone_moe", + "falcon_h1", + "gemma", + "gemma2", + "gemma3", + "gemma3_text", + "gemma3n", + "glm", + "glm4", + "glm4_moe", + "glm4_moe_lite", + "glm_moe_dsa", + "gpt2", + "gpt_bigcode", + "gpt_neox", + "gpt_oss", + "granite", + "granitemoe", + "granitemoehybrid", + "helium", + "hunyuan", + "hunyuan_v1_dense", + "internlm2", + "internlm3", + "iquestloopcoder", + "jamba", + "kimi_k25", + "kimi_linear", + "kimi_vl", + "lfm2", + "lfm2-vl", + "lfm2_moe", + "lille-130m", + "llama", + "llama4", + "llama4_text", + "longcat_flash", + "longcat_flash_ngram", + "mamba", + "mamba2", + "mimo", + "mimo_v2_flash", + "minicpm", + "minicpm3", + "minimax", + "ministral3", + "mistral3", + "mixtral", + "nanochat", + "nemotron", + "nemotron-nas", + "nemotron_h", + "olmo", + "olmo2", + "olmo3", + "olmoe", + "openelm", + "phi", + "phi3", + "phi3small", + "phimoe", + "phixtral", + "pixtral", + "plamo", + "plamo2", + "qwen", + "qwen2", + "qwen2_moe", + "qwen2_vl", + "qwen3", + "qwen3_5", + "qwen3_5_moe", + "qwen3_moe", + "qwen3_next", + "qwen3_vl", + "qwen3_vl_moe", + "recurrent_gemma", + "rwkv7", + "seed_oss", + "smollm3", + "solar_open", + "stablelm", + "starcoder2", + "step3p5", + "telechat3", + "youtu_llm" + ], + "remappings_total": 2, + "remappings": { + "mistral": "llama", + "falcon_mamba": "mamba" + } + }, + "parity": { + "missing_architecture_file_count": 0, + "missing_architecture_files": [], + "extra_registered_model_keys_count": 0, + "extra_registered_model_keys": [] + } +} diff --git a/test/test_helper.rb b/test/test_helper.rb index 642e753..7436581 100644 --- a/test/test_helper.rb +++ b/test/test_helper.rb @@ -7,6 +7,7 @@ require "json" require "tempfile" require "fileutils" +require "open3" module ParityTestHelpers FIXTURES_DIR = File.expand_path("fixtures", __dir__) @@ -17,9 +18,27 @@ def fixtures_dir # Run a Python snippet, capture JSON output, return parsed result def python_eval(code) - result = `python3 -c '#{code.gsub("'", "'\\\\''")}'` - raise "Python eval failed: #{result}" unless $?.success? - JSON.parse(result) + stdout, stderr, status = Open3.capture3( + { + # Linux mlx Python can fail JIT C++ compilation in CI toolchains. + # Parity checks do not require compiled mode. + "MLX_DISABLE_COMPILE" => "1", + }, + "python3", "-c", code + ) + unless status.success? + raise <<~MSG + Python eval failed (exit #{status.exitstatus}) + STDERR: + #{stderr} + STDOUT: + #{stdout} + MSG + end + + JSON.parse(stdout) + rescue Errno::ENOENT + raise "Python eval failed: python3 is not installed" end # Assert two flat arrays are element-wise close