Commits
139 commits
b133e2a
Cleaned up evals to have same names as training data.
finbarrtimbers Jul 29, 2025
1265789
Refactored evals to use a batch.
finbarrtimbers Jul 29, 2025
aa76132
Now, we accumulate eval results.
finbarrtimbers Jul 29, 2025
4e30841
Merge branch 'main' into fix-eval
finbarrtimbers Jul 29, 2025
66af972
Updated scripts so they run.
finbarrtimbers Jul 29, 2025
cfa55c9
More refactoring.
finbarrtimbers Jul 29, 2025
433242a
Now, use the minimum of the number of requested samples and the actua…
finbarrtimbers Jul 29, 2025
0836fca
Ran linter, and fixed extra arg issue.
finbarrtimbers Jul 29, 2025
8028a31
Always insert into pending_queries_map.
finbarrtimbers Jul 29, 2025
9816f34
Update signature in eval.
finbarrtimbers Jul 29, 2025
97b8de9
Merge branch 'main' into fix-eval
finbarrtimbers Jul 29, 2025
e862a14
Another attempted fix.
finbarrtimbers Jul 29, 2025
9676db3
Ran linter.
finbarrtimbers Jul 29, 2025
d044278
Now, eval requests use the eval params, and normal ones use the norma…
finbarrtimbers Jul 29, 2025
6a694bf
Now, tests should pass.
finbarrtimbers Jul 29, 2025
f45b951
Merge branch 'main' into fix-eval
finbarrtimbers Jul 29, 2025
96df985
Remove simple config and pass generation_config through.
finbarrtimbers Jul 30, 2025
b931a35
Now, generation config is passed through.
finbarrtimbers Jul 30, 2025
aa0facb
Ran linter.
finbarrtimbers Jul 30, 2025
9dd0711
Ran linter.
finbarrtimbers Jul 30, 2025
cbf7aa7
Added a while loop.
finbarrtimbers Jul 30, 2025
84b9a4c
Added a while loop with retries.
finbarrtimbers Jul 30, 2025
93c0a97
Merge branch 'main' into fix-eval
finbarrtimbers Jul 30, 2025
87aa0fa
Added logs.
finbarrtimbers Jul 30, 2025
b636127
Fix queue issue.
finbarrtimbers Jul 30, 2025
d0f8870
Add progress bars to all ray.get calls.
finbarrtimbers Jul 30, 2025
9f9e644
Merge branch 'main' into fix-eval
finbarrtimbers Jul 30, 2025
08de6ea
Cleaned up some of the logging.
finbarrtimbers Jul 30, 2025
634e1fb
Changed how we handle full queues.
finbarrtimbers Jul 30, 2025
ada6556
Ran linter.
finbarrtimbers Jul 30, 2025
c29d1d0
Clean up for PR.
finbarrtimbers Jul 30, 2025
95960bc
Switched LLMRayActor to use LLMEngine.
finbarrtimbers Jul 29, 2025
4be2693
Fixes expected output.
finbarrtimbers Jul 29, 2025
d2c1db7
Keep backwards compatibility for tool use.
finbarrtimbers Jul 29, 2025
c1fdd90
Remove manual reorganization.
finbarrtimbers Jul 29, 2025
45791df
Cleaned up implementation.
finbarrtimbers Jul 29, 2025
ad45c6a
Now, we use a generation loop.
finbarrtimbers Jul 29, 2025
5e1c2a6
Uses an ActorManager to manage weight updates.
finbarrtimbers Jul 30, 2025
c13f951
Cleaned up code to use actor manager.
finbarrtimbers Jul 30, 2025
c4dde78
Now, tests pass.
finbarrtimbers Jul 30, 2025
b0fd4f2
Fixed error when calling process_from_queue.
finbarrtimbers Jul 31, 2025
d93d4a8
Added ActorManager.
finbarrtimbers Jul 31, 2025
a444508
Added a test for the actor manager.
finbarrtimbers Jul 31, 2025
10bc07a
Tests pass. Fixed another issue.
finbarrtimbers Jul 31, 2025
4805d46
Ran linter.
finbarrtimbers Jul 31, 2025
30bdce2
Added better error handling.
finbarrtimbers Jul 31, 2025
28ceca9
Potential fix to hanging forever issue.
finbarrtimbers Jul 31, 2025
b1053b8
Another attempt to fix the deadlock.
finbarrtimbers Jul 31, 2025
a7be9bf
Fix code so that it no longer expects process_from_queue to return a …
finbarrtimbers Jul 31, 2025
8343f92
Fixed cleanup code.
finbarrtimbers Jul 31, 2025
54357cd
Fixed issue; now should exit.
finbarrtimbers Jul 31, 2025
27198ef
Added test scripts.
finbarrtimbers Aug 1, 2025
ed467aa
Break out requests into N separate ones.
finbarrtimbers Aug 1, 2025
1276c0f
Found why LLMEngine behaviour differs from LLM. Fixed issue.
finbarrtimbers Aug 1, 2025
4811ca4
Code runs now.
finbarrtimbers Aug 1, 2025
72bfb59
Merge branch 'main' into async-updates
finbarrtimbers Aug 7, 2025
a01b414
Linter passes.
finbarrtimbers Aug 7, 2025
bdb376f
Merge branch 'main' into async-updates
finbarrtimbers Aug 7, 2025
16dad63
Now, linter passes.
finbarrtimbers Aug 7, 2025
0ce50cb
Now, tests should pass.
finbarrtimbers Aug 7, 2025
8a45ad9
Cleanup.
finbarrtimbers Aug 7, 2025
807cd30
More cleanup
finbarrtimbers Aug 7, 2025
fac1d8a
Removed tests.
finbarrtimbers Aug 7, 2025
620752a
Removed test.
finbarrtimbers Aug 7, 2025
7614a02
Removed debugging code.
finbarrtimbers Aug 7, 2025
8fb5209
Now, we share classes.
finbarrtimbers Aug 8, 2025
8f1ef3b
Update ToolLLMRayActor.
finbarrtimbers Aug 8, 2025
f0d7cc1
Cleaned up PR.
finbarrtimbers Aug 8, 2025
8ea458c
Incorporated tool use.
finbarrtimbers Aug 11, 2025
678ece1
Combined loops.
finbarrtimbers Aug 12, 2025
069eccb
Clean up.
finbarrtimbers Aug 12, 2025
a2a1db4
More cleanup.
finbarrtimbers Aug 12, 2025
5edded3
More clean up.
finbarrtimbers Aug 12, 2025
a67d197
more clean up.
finbarrtimbers Aug 12, 2025
3b7aab7
Ran linter.
finbarrtimbers Aug 12, 2025
b4957b6
More clean up.
finbarrtimbers Aug 12, 2025
85277c4
Merge branch 'main' into async-updates
finbarrtimbers Aug 12, 2025
315b32d
Tests pass.
finbarrtimbers Aug 12, 2025
5cea20a
Now, linter passes.
finbarrtimbers Aug 12, 2025
a3361f9
Merge branch 'main' into async-updates
finbarrtimbers Aug 12, 2025
14757e3
Now, tests pass.
finbarrtimbers Aug 12, 2025
463faab
Attempt at fix.
finbarrtimbers Aug 12, 2025
5c34471
Changes.
finbarrtimbers Aug 12, 2025
509467c
add tool debug script
hamishivi Aug 13, 2025
d14195c
Merge branch 'main' into async-updates
finbarrtimbers Aug 13, 2025
de3e466
Added logging.
finbarrtimbers Aug 13, 2025
3332123
Merge branch 'main' into async-updates
finbarrtimbers Aug 13, 2025
aef83b3
Now, flag order is consistent in files.
finbarrtimbers Aug 13, 2025
d3b2b65
Added local test script.
finbarrtimbers Aug 14, 2025
3e4fa13
Updated script.
finbarrtimbers Aug 14, 2025
86cf3dd
Script runs.
finbarrtimbers Aug 14, 2025
fc941ae
Updated script
finbarrtimbers Aug 14, 2025
cf21c85
Added tool test script.
finbarrtimbers Aug 14, 2025
b3eb1c9
Added single gpu test script.
finbarrtimbers Aug 14, 2025
f0f3127
Updated local_grpo_test.py.
finbarrtimbers Aug 14, 2025
4228488
Fixed hanging issue.
finbarrtimbers Aug 14, 2025
a50624b
Updated cleanup.
finbarrtimbers Aug 14, 2025
23d7226
More debugging.
finbarrtimbers Aug 14, 2025
bd65acb
Restored ToolUseLLM.
finbarrtimbers Aug 15, 2025
a7a514f
Cleaned up logging.
finbarrtimbers Aug 15, 2025
48ea7a3
Updated file.
finbarrtimbers Aug 15, 2025
c55faf2
now runs
finbarrtimbers Aug 15, 2025
ace8e51
Updated tool script.
finbarrtimbers Aug 15, 2025
65eb61f
Ran linter.
finbarrtimbers Aug 15, 2025
0b742fe
Ran linter, cleaned up code.
finbarrtimbers Aug 15, 2025
074b510
Merge branch 'main' into async-updates
finbarrtimbers Aug 15, 2025
f00ec80
Cleaned up code.
finbarrtimbers Aug 15, 2025
fe408d7
Cleaned up PR.
finbarrtimbers Aug 15, 2025
5aaa2de
Cleaned up PR.
finbarrtimbers Aug 15, 2025
a7edefa
Cleaned up code.
finbarrtimbers Aug 15, 2025
82852af
Rearranged Dockerfile for better caching.
finbarrtimbers Aug 15, 2025
d7bda0d
Fixed error
finbarrtimbers Aug 15, 2025
9241c9b
now, passes in args to generate_thread.
finbarrtimbers Aug 15, 2025
0d1cc2f
moved git commands to bottom of Dockerfile
finbarrtimbers Aug 15, 2025
c5c2064
Merge branch 'main' into async-updates
finbarrtimbers Aug 15, 2025
b862ae2
Trying to fix flashinfer
finbarrtimbers Aug 15, 2025
d640258
Merge branch 'main' into async-updates
finbarrtimbers Aug 16, 2025
abe6854
Ran linter.
finbarrtimbers Aug 16, 2025
446b5e0
Broke tests out into gpu enabled tests.
finbarrtimbers Aug 18, 2025
fe5e695
Added GPU tests.
finbarrtimbers Aug 18, 2025
8f56f7a
Added tests.
finbarrtimbers Aug 18, 2025
efd6279
Updated workflow.
finbarrtimbers Aug 18, 2025
791166d
update image name.
finbarrtimbers Aug 18, 2025
951d218
Fixed GPU tests (hopefully).
finbarrtimbers Aug 18, 2025
075b847
Moved to use smaller GPUs.
finbarrtimbers Aug 18, 2025
d167db3
Fixed GPU tests.
finbarrtimbers Aug 18, 2025
d296bf4
Updated timeout.
finbarrtimbers Aug 18, 2025
f7e2af0
Another attempt to fix tests.
finbarrtimbers Aug 18, 2025
3b87be0
Merge branch 'main' into gpu-tests
finbarrtimbers Aug 18, 2025
ee5a927
Cleaned up PR.
finbarrtimbers Aug 18, 2025
5dc1a29
Cleaned up PR.
finbarrtimbers Aug 18, 2025
b072843
Merge branch 'main' into gpu-tests
finbarrtimbers Aug 18, 2025
dcd555b
Cleaned up PR...
finbarrtimbers Aug 18, 2025
1cab101
Cleaned up PR.
finbarrtimbers Aug 18, 2025
82ec769
Merge branch 'main' into gpu-tests
finbarrtimbers Aug 18, 2025
8297c9e
Linter passes.
finbarrtimbers Aug 18, 2025
ead5ba6
Merge branch 'main' into gpu-tests
finbarrtimbers Aug 22, 2025
22365a0
Now use gpu runner.
finbarrtimbers Aug 22, 2025
810c948
Updated code to trigger tests.
finbarrtimbers Aug 22, 2025
31 changes: 31 additions & 0 deletions .github/workflows/gpu-tests.yml
@@ -0,0 +1,31 @@
name: GPU Tests

on:
  pull_request:
    branches:
      - main
    paths:
      - 'open_instruct/grpo_fast.py'
      - 'open_instruct/vllm_utils3.py'
  merge_group:
  workflow_dispatch:

jobs:
  gpu-tests:
    name: Run GPU tests
    runs-on: GPU-Enabled-Runner
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "0.8.6"

      - name: Set up Python
        run: uv sync --frozen

      - name: Run GPU tests
        run: |
          uv run --frozen pytest -xvs tests/*_gpu.py open_instruct/*_gpu.py
2 changes: 1 addition & 1 deletion open_instruct/grpo_fast.py
@@ -194,7 +194,7 @@ class Args:
    warmup_ratio: float = 0.0
    """Ratio of warmup steps to total steps (takes precedence over `warm_up_steps`)"""
    weight_decay: float = 0.0
    """Weight decay for AdamW if we apply some."""
    """Weight decay for AdamW."""
    set_weight_decay_on_bias_and_norm: bool = True
    """Whether to set weight decay on bias and norm layers"""
    fused_optimizer: bool = False
76 changes: 0 additions & 76 deletions open_instruct/test_grpo_fast.py
@@ -4,15 +4,11 @@
from unittest.mock import Mock

import ray
import torch
from parameterized import parameterized
from ray.util import queue as ray_queue
from transformers import AutoTokenizer
from vllm import SamplingParams

from open_instruct import grpo_fast, model_utils, utils
from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo
from open_instruct.vllm_utils3 import create_vllm_engines


class TestGrpoFastBase(unittest.TestCase):
@@ -177,78 +173,6 @@ def setup_and_split_batch(self, queries, ground_truths, datasets, indices, num_e


class TestGrpoFastVLLM(TestGrpoFastBase):
    def test_vllm_queue_system_single_prompt(self):
        """Test the new queue-based vLLM system with a single prompt 'What is the capital of France?'"""
        # Check if CUDA is available
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available, skipping test")

        # Set up tokenizer
        tokenizer_name = "EleutherAI/pythia-14m"  # Using a small model for testing
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Tokenize the test prompt
        test_prompt = "What is the capital of France?"
        prompt_token_ids = tokenizer.encode(test_prompt, return_tensors="pt").tolist()[0]

        # Create Ray queues
        param_prompt_Q = ray_queue.Queue(maxsize=1)
        inference_results_Q = ray_queue.Queue(maxsize=1)

        # Track queues for cleanup
        self._ray_queues.extend([param_prompt_Q, inference_results_Q])

        # Create vLLM engines with queues
        vllm_engines = create_vllm_engines(
            num_engines=1,
            tensor_parallel_size=1,
            enforce_eager=True,
            tokenizer_name_or_path=tokenizer_name,
            pretrain=tokenizer_name,
            revision="main",
            seed=42,
            enable_prefix_caching=False,
            max_model_len=512,
            vllm_gpu_memory_utilization=0.5,  # Use less GPU memory for testing
            prompt_queue=param_prompt_Q,
            results_queue=inference_results_Q,
        )

        # Set up generation config
        generation_config = SamplingParams(
            temperature=0.0,  # Deterministic generation
            top_p=1.0,
            max_tokens=5,
            seed=42,
        )

        # Start vLLM engines to process from queues
        [e.process_from_queue.remote() for e in vllm_engines]

        # Put the test prompt in the queue using PromptRequest
        param_prompt_Q.put(
            PromptRequest(prompts=[prompt_token_ids], dataset_index=0, sampling_params=generation_config)
        )

        # Get the result
        result = inference_results_Q.get()

        # Verify it's a GenerationResult dataclass
        self.assertIsInstance(result, GenerationResult)

        # Check that we got a response
        self.assertGreater(len(result.responses), 0)
        response_ids = result.responses[0]

        # Decode the response
        generated_text = tokenizer.decode(response_ids, skip_special_tokens=True)

        self.assertIsInstance(generated_text, str)
        self.assertGreater(len(generated_text), 0)

        # Send stop signal
        param_prompt_Q.put(None)

    @parameterized.expand([(1, 16), (2, 32), (4, 64), (8, 128)])
    def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int, num_unique_prompts_rollout: int):
        """Test batch splitting and accumulation with various engine configurations."""
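The test removed here reappears, updated for the new PromptRequest/ActorManager interface, in the new open_instruct/test_grpo_fast_gpu.py file below.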
83 changes: 83 additions & 0 deletions open_instruct/test_grpo_fast_gpu.py
@@ -0,0 +1,83 @@
import gc
import unittest

import ray
import torch
from ray.util import queue as ray_queue
from transformers import AutoTokenizer
from vllm import SamplingParams

from open_instruct import utils, vllm_utils3
from open_instruct.queue_types import GenerationResult, PromptRequest
from open_instruct.vllm_utils3 import create_vllm_engines


class TestGrpoFastGPUBase(unittest.TestCase):
"""Base class with common test utilities for GPU tests."""

def setUp(self):
"""Initialize Ray and check for pre-existing leaks."""
if not torch.cuda.is_available():
self.skipTest("CUDA is not available, skipping test")

ray.init(include_dashboard=False)

def tearDown(self):
ray.shutdown()

gc.collect()

utils.check_runtime_leaks()


class TestGrpoFastVLLMGPU(TestGrpoFastGPUBase):
def test_vllm_queue_system_single_prompt(self):
"""Test the new queue-based vLLM system with a single prompt."""
tokenizer_name = "EleutherAI/pythia-14m"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
test_prompt = "What is the capital of France?"
prompt_token_ids = tokenizer.encode(test_prompt, return_tensors="pt").tolist()[0]
param_prompt_Q = ray_queue.Queue(maxsize=1)
inference_results_Q = ray_queue.Queue(maxsize=1)
actor_manager = vllm_utils3.ActorManager.remote()
vllm_engines = create_vllm_engines(
num_engines=1,
tensor_parallel_size=1,
enforce_eager=True,
tokenizer_name_or_path=tokenizer_name,
pretrain=tokenizer_name,
revision="main",
seed=42,
enable_prefix_caching=False,
max_model_len=512,
vllm_gpu_memory_utilization=0.5,
prompt_queue=param_prompt_Q,
results_queue=inference_results_Q,
actor_manager=actor_manager,
)
Comment on lines +36 to +57 (Contributor): not sure but feels like the tokenizer and vllm engine could be in the setup of the test

        generation_config = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=5, n=1)
Comment (Contributor): do you think its possible to get a deterministic output for Pythia 14m and specifically test for the expected completion for this prompt? Not sure how important that is

        param_prompt_Q.put(
            PromptRequest(
                prompts=[prompt_token_ids], generation_config=generation_config, dataset_index=[0], training_step=0
            )
        )

        ray.get(vllm_engines[0].process_from_queue.remote(timeout=30))
        result = inference_results_Q.get_nowait()

        self.assertIsInstance(result, GenerationResult)
        self.assertIsNotNone(result.responses)
        self.assertEqual(len(result.responses), 1)
        self.assertEqual(result.dataset_index, [0])

        response_ids = result.responses[0]

        generated_text = tokenizer.decode(response_ids, skip_special_tokens=True)

        self.assertIsInstance(generated_text, str)
        self.assertGreater(len(generated_text), 0)

        for queue in [param_prompt_Q, inference_results_Q]:
            queue.shutdown()
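
On the reviewer suggestion above about moving the tokenizer and vLLM engine into test setup, a minimal sketch of that refactor follows. It reuses the exact create_vllm_engines, ActorManager, and queue wiring shown in this diff; the setUpClass/tearDownClass hooks and the class-attribute names are illustrative assumptions, not part of this PR.

import unittest

import ray
import torch
from ray.util import queue as ray_queue
from transformers import AutoTokenizer

from open_instruct import vllm_utils3
from open_instruct.vllm_utils3 import create_vllm_engines


class TestGrpoFastVLLMGPU(unittest.TestCase):
    """Sketch: build the expensive fixtures once per test class rather than per test."""

    @classmethod
    def setUpClass(cls):
        if not torch.cuda.is_available():
            # Raising SkipTest in setUpClass skips every test in the class.
            raise unittest.SkipTest("CUDA is not available, skipping GPU tests")
        ray.init(include_dashboard=False)

        # Shared tokenizer, queues, and engine, reused by every test method.
        cls.tokenizer_name = "EleutherAI/pythia-14m"
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.tokenizer_name)
        cls.param_prompt_Q = ray_queue.Queue(maxsize=1)
        cls.inference_results_Q = ray_queue.Queue(maxsize=1)
        cls.actor_manager = vllm_utils3.ActorManager.remote()
        cls.vllm_engines = create_vllm_engines(
            num_engines=1,
            tensor_parallel_size=1,
            enforce_eager=True,
            tokenizer_name_or_path=cls.tokenizer_name,
            pretrain=cls.tokenizer_name,
            revision="main",
            seed=42,
            enable_prefix_caching=False,
            max_model_len=512,
            vllm_gpu_memory_utilization=0.5,
            prompt_queue=cls.param_prompt_Q,
            results_queue=cls.inference_results_Q,
            actor_manager=cls.actor_manager,
        )

    @classmethod
    def tearDownClass(cls):
        # Tear the shared fixtures down once all tests in the class have run.
        for queue in [cls.param_prompt_Q, cls.inference_results_Q]:
            queue.shutdown()
        ray.shutdown()

On the determinism question, with temperature=0.0 the pythia-14m completion should be stable for a fixed vLLM version and seed, so the length check could in principle be tightened to an assertEqual against a recorded string; the expected text would first need to be captured from a real run.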
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -88,7 +88,7 @@ dev = [
]

[tool.pytest.ini_options]
addopts = "--ignore=oe-eval-internal/"
addopts = "--ignore=oe-eval-internal/ --ignore-glob='**/*_gpu.py'"


[tool.black]
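With this ignore-glob in place, a default pytest (or uv run pytest) invocation skips every *_gpu.py file, so the GPU suite only runs when those files are targeted explicitly, as the new workflow does with `uv run --frozen pytest -xvs tests/*_gpu.py open_instruct/*_gpu.py`.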
File renamed without changes.