8 changes: 7 additions & 1 deletion .github/workflows/ci.yml
@@ -4,7 +4,13 @@ on:
push:
branches:
- master
paths-ignore:
- README.md
- docs/**
pull_request:
paths-ignore:
- README.md
- docs/**

jobs:
test:
@@ -55,4 +61,4 @@ jobs:
echo "${GITHUB_WORKSPACE}/.venv-test/bin" >> "$GITHUB_PATH"

- name: Run all tests
run: bundle exec rake test
run: bundle exec rake test:all
2 changes: 2 additions & 0 deletions README.md
@@ -1,5 +1,7 @@
# mlx-ruby-lm

[![Tests](https://github.com/skryl/mlx-ruby-lm/actions/workflows/ci.yml/badge.svg)](https://github.com/skryl/mlx-ruby-lm/actions/workflows/ci.yml) [![Gem Version](https://badge.fury.io/rb/mlx-ruby-lm.svg)](https://rubygems.org/gems/mlx-ruby-lm)

Ruby LLM inference toolkit built on the `mlx` gem.

## Included tools
76 changes: 75 additions & 1 deletion Rakefile
@@ -5,12 +5,69 @@ require_relative "tasks/parity_inventory_task"
VENV_DIR = File.expand_path(".venv-test", __dir__)
VENV_PYTHON = File.join(VENV_DIR, "bin", "python")
REQUIREMENTS_FILE = File.expand_path("requirements.txt", __dir__)
TEST_DEVICE_CHOICES = %w[cpu gpu].freeze
DEFAULT_TEST_DEVICES = %w[cpu gpu].freeze

Rake::TestTask.new(:test) do |t|
def parse_test_devices(args)
raw_values = []
raw_values << args[:devices] if args[:devices]
raw_values.concat(args.extras) if args.respond_to?(:extras)

values = if raw_values.empty?
DEFAULT_TEST_DEVICES.dup
else
raw_values
.flat_map { |value| value.to_s.split(",") }
.map(&:strip)
.reject(&:empty?)
.map(&:downcase)
end

values = DEFAULT_TEST_DEVICES.dup if values.empty?
invalid = values - TEST_DEVICE_CHOICES
unless invalid.empty?
raise ArgumentError, "invalid test device(s): #{invalid.join(', ')} (supported: #{TEST_DEVICE_CHOICES.join(', ')})"
end

values.uniq
end

Rake::TestTask.new("test:run") do |t|
t.libs << "test" << "lib"
t.test_files = FileList["test/**/*_test.rb"]
end

desc "Run tests on devices (default: cpu then gpu). Examples: rake test, rake \"test[cpu]\", rake \"test[cpu,gpu]\""
task :test, [:devices] do |_task, args|
devices = parse_test_devices(args)

devices.each do |device|
puts "\n==> Running test suite on #{device.upcase}"

previous_mlx_default_device = ENV["MLX_DEFAULT_DEVICE"]
previous_device = ENV["DEVICE"]

begin
ENV["MLX_DEFAULT_DEVICE"] = device
ENV["DEVICE"] = device
Rake::Task["test:run"].reenable
Rake::Task["test:run"].invoke
ensure
if previous_mlx_default_device.nil?
ENV.delete("MLX_DEFAULT_DEVICE")
else
ENV["MLX_DEFAULT_DEVICE"] = previous_mlx_default_device
end

if previous_device.nil?
ENV.delete("DEVICE")
else
ENV["DEVICE"] = previous_device
end
end
end
end

namespace :test do
desc "Install Python dependencies required by parity tests"
task :deps do
@@ -25,6 +82,23 @@ namespace :test do
t.libs << "test" << "lib"
t.test_files = FileList["test/parity/**/*_test.rb"]
end

desc "Run full test suite including ONNX full export tests"
task :all do
previous_full_export = ENV["ONNX_FULL_EXPORT"]
ENV["ONNX_FULL_EXPORT"] = "1"

begin
Rake::Task[:test].reenable
Rake::Task[:test].invoke
ensure
if previous_full_export.nil?
ENV.delete("ONNX_FULL_EXPORT")
else
ENV["ONNX_FULL_EXPORT"] = previous_full_export
end
end
end
end

namespace :parity do
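A minimal usage sketch, not part of the diff: how parse_test_devices normalizes task arguments, assuming the Rakefile above has been loaded. FakeArgs is a hypothetical stand-in for Rake::TaskArguments (Struct#[] accepts member names, so args[:devices] works as in the helper).

FakeArgs = Struct.new(:devices, :extras)

parse_test_devices(FakeArgs.new(nil, []))         #=> ["cpu", "gpu"] (defaults)
parse_test_devices(FakeArgs.new("CPU, gpu", []))  #=> ["cpu", "gpu"] (split, strip, downcase, dedup)
parse_test_devices(FakeArgs.new("cpu", ["gpu"]))  #=> ["cpu", "gpu"] (positional extras appended)
parse_test_devices(FakeArgs.new("tpu", []))       # raises ArgumentError (unsupported device)

With this in place, rake "test[cpu]" runs the suite once on CPU, plain rake test runs CPU then GPU, and rake test:all wraps the same task with ONNX_FULL_EXPORT=1 so the ONNX full export tests are included, which is what CI now invokes.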
10 changes: 10 additions & 0 deletions lib/mlx_lm.rb
@@ -2,6 +2,8 @@
require_relative "mlx_lm/model_args"
require_relative "mlx_lm/weight_utils"
require_relative "mlx_lm/config"
require_relative "mlx_lm/gguf"
require_relative "mlx_lm/share"
require_relative "mlx_lm/tokenizer_utils"
require_relative "mlx_lm/sample_utils"
require_relative "mlx_lm/models"
@@ -120,9 +122,17 @@
require_relative "mlx_lm/models/plamo2"
require_relative "mlx_lm/models/qwen3_next"
require_relative "mlx_lm/models/rwkv7"
require_relative "mlx_lm/parity_aliases"
require_relative "mlx_lm/generate"
require_relative "mlx_lm/quantize"
require_relative "mlx_lm/quant/awq"
require_relative "mlx_lm/quant/gptq"
require_relative "mlx_lm/load_utils"
require_relative "mlx_lm/evaluate"
require_relative "mlx_lm/tuner/callbacks"
require_relative "mlx_lm/tuner/datasets"
require_relative "mlx_lm/tuner/dora"
require_relative "mlx_lm/tuner/trainer"
require_relative "mlx_lm/tuner/lora"
require_relative "mlx_lm/chat_template"
require_relative "mlx_lm/cli"
50 changes: 50 additions & 0 deletions lib/mlx_lm/evaluate.rb
@@ -0,0 +1,50 @@
module MlxLm
class MLXLM
DEFAULT_MAX_TOKENS = 8192

attr_reader :model, :tokenizer, :max_tokens, :batch_size, :use_chat_template

def initialize(
path_or_hf_repo,
max_tokens: nil,
batch_size: 8,
use_chat_template: nil,
trust_remote_code: false,
sampler: nil
)
tokenizer_config = trust_remote_code ? { "trust_remote_code" => true } : nil
@model, @tokenizer = LoadUtils.load(path_or_hf_repo, tokenizer_config: tokenizer_config)
@max_tokens = max_tokens
@batch_size = batch_size
@sampler = sampler
@use_chat_template = if use_chat_template.nil?
tokenizer.respond_to?(:has_chat_template) && tokenizer.has_chat_template
else
use_chat_template
end
end

def tokenizer_name
name = if tokenizer.respond_to?(:name_or_path)
tokenizer.name_or_path
else
tokenizer.class.name
end
name.to_s.gsub("/", "__")
end

def generate(prompt, max_tokens: nil, sampler: nil, **kwargs)
options = kwargs.dup
options[:max_tokens] = max_tokens || self.max_tokens || DEFAULT_MAX_TOKENS
options[:sampler] = sampler || @sampler
Generate.generate(model, tokenizer, prompt, **options)
end

def stream_generate(prompt, max_tokens: nil, sampler: nil, **kwargs)
options = kwargs.dup
options[:max_tokens] = max_tokens || self.max_tokens || DEFAULT_MAX_TOKENS
options[:sampler] = sampler || @sampler
Generate.stream_generate(model, tokenizer, prompt, **options)
end
end
end
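A hedged usage sketch of the new evaluation wrapper, not part of the diff. The model path below is a placeholder; anything accepted by LoadUtils.load should behave the same way.

require "mlx_lm"

# Placeholder local path or Hugging Face repo id.
lm = MlxLm::MLXLM.new("path/to/mlx-model", max_tokens: 128, batch_size: 4)

puts lm.tokenizer_name              # tokenizer id with "/" replaced by "__"
puts lm.generate("Write a haiku")   # delegates to Generate.generate

lm.stream_generate("Count to five").each do |resp|
  # resp items come straight from Generate.stream_generate
end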
73 changes: 70 additions & 3 deletions lib/mlx_lm/generate.rb
Expand Up @@ -13,6 +13,54 @@ module MlxLm
keyword_init: true
)

# Batch-level generation statistics.
BatchStats = Struct.new(
:prompt_tokens,
:generation_tokens,
:prompt_tps,
:generation_tps,
:peak_memory,
keyword_init: true
)

# Response item for batched generation streams.
BatchResponse = Struct.new(
:index,
:response,
keyword_init: true
)

# Python API compatibility alias for generation response objects.
Response = GenerationResponse

# Small batch container used by batch generation APIs.
class Batch
attr_reader :items

def initialize(items)
@items = Array(items)
end

def size
@items.size
end
end

# Enumerable wrapper for batched response streams.
class BatchGenerator
include Enumerable

def initialize(enum)
@enum = enum
end

def each(&block)
return @enum.each unless block

@enum.each(&block)
end
end

module Generate
module_function

@@ -127,16 +175,13 @@ def stream_generate(model, tokenizer, prompt, max_tokens: 256, **kwargs)

Enumerator.new do |yielder|
n = 0
last_token = nil
token_generator.each do |token, logprobs|
if n == 0
prompt_time = Process.clock_gettime(Process::CLOCK_MONOTONIC) - tic
prompt_tps = prompt.size.to_f / [prompt_time, 1e-9].max
tic = Process.clock_gettime(Process::CLOCK_MONOTONIC)
end

last_token = token

if tokenizer.eos_token_ids.include?(token)
detokenizer.finalize
elapsed = [Process.clock_gettime(Process::CLOCK_MONOTONIC) - tic, 1e-9].max
@@ -200,5 +245,27 @@ def generate(model, tokenizer, prompt, verbose: false, **kwargs)
end
text
end

# Simple batched generation helper.
# Returns an Array<String> with one completion per prompt.
def batch_generate(model, tokenizer, prompts, **kwargs)
batch = Batch.new(prompts)
batch.items.map do |prompt|
generate(model, tokenizer, prompt, **kwargs)
end
end

# Batched streaming helper that yields BatchResponse items.
def stream_batch_generate(model, tokenizer, prompts, **kwargs)
batch = Batch.new(prompts)
enum = Enumerator.new do |yielder|
batch.items.each_with_index do |prompt, idx|
stream_generate(model, tokenizer, prompt, **kwargs).each do |resp|
yielder << BatchResponse.new(index: idx, response: resp)
end
end
end
BatchGenerator.new(enum)
end
end
end
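A sketch of the new batch helpers, not part of the diff; assume model and tokenizer were obtained via LoadUtils.load, as in MLXLM#initialize.

prompts = ["First prompt", "Second prompt"]

# batch_generate maps Generate.generate over the batch,
# returning one completion string per prompt.
texts = MlxLm::Generate.batch_generate(model, tokenizer, prompts, max_tokens: 64)
texts.each { |t| puts t }

# stream_batch_generate yields BatchResponse items, each tagging a streamed
# response with the index of the prompt it belongs to.
MlxLm::Generate.stream_batch_generate(model, tokenizer, prompts).each do |item|
  warn "prompt #{item.index}: #{item.response.inspect}"
end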