Adds unit tests which run on GPU for open-instruct. #905
Open: finbarrtimbers wants to merge 139 commits into main from gpu-tests.
Changes from all commits (139 commits)
b133e2a finbarrtimbers: Cleaned up evals to have same names as training data.
1265789 finbarrtimbers: Refactored evals to use a batch.
aa76132 finbarrtimbers: Now, we accumulate eval results.
4e30841 finbarrtimbers: Merge branch 'main' into fix-eval
66af972 finbarrtimbers: Updated scripts so they run.
cfa55c9 finbarrtimbers: More refactoring.
433242a finbarrtimbers: Now, use the minimum of the number of requested samples and the actua…
0836fca finbarrtimbers: Ran linter, and fixed extra arg issue.
8028a31 finbarrtimbers: Always insert into pending_queries_map.
9816f34 finbarrtimbers: Update signature in eval.
97b8de9 finbarrtimbers: Merge branch 'main' into fix-eval
e862a14 finbarrtimbers: Another attempted fix.
9676db3 finbarrtimbers: Ran linter.
d044278 finbarrtimbers: Now, eval requests use the eval params, and normal ones use the norma…
6a694bf finbarrtimbers: Now, tests should pass.
f45b951 finbarrtimbers: Merge branch 'main' into fix-eval
96df985 finbarrtimbers: Remove simple config and pass generation_config through.
b931a35 finbarrtimbers: Now, generation config is passed through.
aa0facb finbarrtimbers: Ran linter.
9dd0711 finbarrtimbers: Ran linter.
cbf7aa7 finbarrtimbers: Added a while loop.
84b9a4c finbarrtimbers: Added a while loop with retries.
93c0a97 finbarrtimbers: Merge branch 'main' into fix-eval
87aa0fa finbarrtimbers: Added logs.
b636127 finbarrtimbers: Fix queue issue.
d0f8870 finbarrtimbers: Add progress bars to all ray.get calls.
9f9e644 finbarrtimbers: Merge branch 'main' into fix-eval
08de6ea finbarrtimbers: Cleaned up some of the logging.
634e1fb finbarrtimbers: Changed how we handle full queues.
ada6556 finbarrtimbers: Ran linter.
c29d1d0 finbarrtimbers: Clean up for PR.
95960bc finbarrtimbers: Switched LLMRayActor to use LLMEngine.
4be2693 finbarrtimbers: Fixes expected output.
d2c1db7 finbarrtimbers: Keep backwards compatibility for tool use.
c1fdd90 finbarrtimbers: Remove manual reorganization.
45791df finbarrtimbers: Cleaned up implementation.
ad45c6a finbarrtimbers: Now, we use a generation loop.
5e1c2a6 finbarrtimbers: Uses an ActorManager to manage weight updates.
c13f951 finbarrtimbers: Cleaned up code to use actor manager.
c4dde78 finbarrtimbers: Now, tests pass.
b0fd4f2 finbarrtimbers: Fixed error when calling process_from_queue.
d93d4a8 finbarrtimbers: Added ActorManager.
a444508 finbarrtimbers: Added a test for the actor manager.
10bc07a finbarrtimbers: Tests pass. Fixed another issue.
4805d46 finbarrtimbers: Ran linter.
30bdce2 finbarrtimbers: Added better error handling.
28ceca9 finbarrtimbers: Potential fix to hanging forever issue.
b1053b8 finbarrtimbers: Another attempt to fix the deadlock.
a7be9bf finbarrtimbers: Fix code so that it no longer expects process_from_queue to return a …
8343f92 finbarrtimbers: Fixed cleanup code.
54357cd finbarrtimbers: Fixed issue; now should exit.
27198ef finbarrtimbers: Added test scripts.
ed467aa finbarrtimbers: Break out requests into N separate ones.
1276c0f finbarrtimbers: Found why LLMEngine behaviour differs from LLM. Fixed issue.
4811ca4 finbarrtimbers: Code runs now.
72bfb59 finbarrtimbers: Merge branch 'main' into async-updates
a01b414 finbarrtimbers: Linter passes.
bdb376f finbarrtimbers: Merge branch 'main' into async-updates
16dad63 finbarrtimbers: Now, linter passes.
0ce50cb finbarrtimbers: Now, tests should pass.
8a45ad9 finbarrtimbers: Cleanup.
807cd30 finbarrtimbers: More cleanup
fac1d8a finbarrtimbers: Removed tests.
620752a finbarrtimbers: Removed test.
7614a02 finbarrtimbers: Removed debugging code.
8fb5209 finbarrtimbers: Now, we share classes.
8f1ef3b finbarrtimbers: Update ToolLLMRayActor.
f0d7cc1 finbarrtimbers: Cleaned up PR.
8ea458c finbarrtimbers: Incorporated tool use.
678ece1 finbarrtimbers: Combined loops.
069eccb finbarrtimbers: Clean up.
a2a1db4 finbarrtimbers: More cleanup.
5edded3 finbarrtimbers: More clean up.
a67d197 finbarrtimbers: more clean up.
3b7aab7 finbarrtimbers: Ran linter.
b4957b6 finbarrtimbers: More clean up.
85277c4 finbarrtimbers: Merge branch 'main' into async-updates
315b32d finbarrtimbers: Tests pass.
5cea20a finbarrtimbers: Now, linter passes.
a3361f9 finbarrtimbers: Merge branch 'main' into async-updates
14757e3 finbarrtimbers: Now, tests pass.
463faab finbarrtimbers: Attempt at fix.
5c34471 finbarrtimbers: Changes.
509467c hamishivi: add tool debug script
d14195c finbarrtimbers: Merge branch 'main' into async-updates
de3e466 finbarrtimbers: Added logging.
3332123 finbarrtimbers: Merge branch 'main' into async-updates
aef83b3 finbarrtimbers: Now, flag order is consistent in files.
d3b2b65 finbarrtimbers: Added local test script.
3e4fa13 finbarrtimbers: Updated script.
86cf3dd finbarrtimbers: Script runs.
fc941ae finbarrtimbers: Updated script
cf21c85 finbarrtimbers: Added tool test script.
b3eb1c9 finbarrtimbers: Added single gpu test script.
f0f3127 finbarrtimbers: Updated local_grpo_test.py.
4228488 finbarrtimbers: Fixed hanging issue.
a50624b finbarrtimbers: Updated cleanup.
23d7226 finbarrtimbers: More debugging.
bd65acb finbarrtimbers: Restored ToolUseLLM.
a7a514f finbarrtimbers: Cleaned up logging.
48ea7a3 finbarrtimbers: Updated file.
c55faf2 finbarrtimbers: now runs
ace8e51 finbarrtimbers: Updated tool script.
65eb61f finbarrtimbers: Ran linter.
0b742fe finbarrtimbers: Ran linter, cleaned up code.
074b510 finbarrtimbers: Merge branch 'main' into async-updates
f00ec80 finbarrtimbers: Cleaned up code.
fe408d7 finbarrtimbers: Cleaned up PR.
5aaa2de finbarrtimbers: Cleaned up PR.
a7edefa finbarrtimbers: Cleaned up code.
82852af finbarrtimbers: Rearranged Dockerfile for better caching.
d7bda0d finbarrtimbers: Fixed error
9241c9b finbarrtimbers: now, passes in args to generate_thread.
0d1cc2f finbarrtimbers: moved git commands to bottom of Dockerfile
c5c2064 finbarrtimbers: Merge branch 'main' into async-updates
b862ae2 finbarrtimbers: Trying to fix flashinfer
d640258 finbarrtimbers: Merge branch 'main' into async-updates
abe6854 finbarrtimbers: Ran linter.
446b5e0 finbarrtimbers: Broke tests out into gpu enabled tests.
fe5e695 finbarrtimbers: Added GPU tests.
8f56f7a finbarrtimbers: Added tests.
efd6279 finbarrtimbers: Updated workflow.
791166d finbarrtimbers: update image name.
951d218 finbarrtimbers: Fixed GPU tests (hopefully).
075b847 finbarrtimbers: Moved to use smaller GPUs.
d167db3 finbarrtimbers: Fixed GPU tests.
d296bf4 finbarrtimbers: Updated timeout.
f7e2af0 finbarrtimbers: Another attempt to fix tests.
3b87be0 finbarrtimbers: Merge branch 'main' into gpu-tests
ee5a927 finbarrtimbers: Cleaned up PR.
5dc1a29 finbarrtimbers: Cleaned up PR.
b072843 finbarrtimbers: Merge branch 'main' into gpu-tests
dcd555b finbarrtimbers: Cleaned up PR...
1cab101 finbarrtimbers: Cleaned up PR.
82ec769 finbarrtimbers: Merge branch 'main' into gpu-tests
8297c9e finbarrtimbers: Linter passes.
ead5ba6 finbarrtimbers: Merge branch 'main' into gpu-tests
22365a0 finbarrtimbers: Now use gpu runner.
810c948 finbarrtimbers: Updated code to trigger tests.
New file: GitHub Actions workflow "GPU Tests" (31 lines added).
```yaml
name: GPU Tests

on:
  pull_request:
    branches:
      - main
    paths:
      - 'open_instruct/grpo_fast.py'
      - 'open_instruct/vllm_utils3.py'
  merge_group:
  workflow_dispatch:

jobs:
  gpu-tests:
    name: Run GPU tests
    runs-on: GPU-Enabled-Runner
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          version: "0.8.6"

      - name: Set up Python
        run: uv sync --frozen

      - name: Run GPU tests
        run: |
          uv run --frozen pytest -xvs tests/*_gpu.py open_instruct/*_gpu.py
```
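To reproduce the same test selection locally, here is a minimal sketch, assuming it runs from the repository root with dependencies synced as in the workflow. The helper name is hypothetical; it simply expands the same `*_gpu.py` globs that the shell expands in the `run:` step above before handing them to pytest:

```python
# local_gpu_tests.py: hypothetical local mirror of the workflow's
# `pytest -xvs tests/*_gpu.py open_instruct/*_gpu.py` step.
# The shell expands the globs before pytest sees them, so we do the same here.
import sys
from pathlib import Path

import pytest

if __name__ == "__main__":
    targets = [
        str(path)
        for directory in ("tests", "open_instruct")
        for path in Path(directory).glob("*_gpu.py")
    ]
    # -x: stop on first failure, -v: verbose, -s: show test stdout
    sys.exit(pytest.main(["-xvs", *targets]))
```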
New file: GPU unit tests for the queue-based vLLM generation path (83 lines added).
```python
import gc
import unittest

import ray
import torch
from ray.util import queue as ray_queue
from transformers import AutoTokenizer
from vllm import SamplingParams

from open_instruct import utils, vllm_utils3
from open_instruct.queue_types import GenerationResult, PromptRequest
from open_instruct.vllm_utils3 import create_vllm_engines


class TestGrpoFastGPUBase(unittest.TestCase):
    """Base class with common test utilities for GPU tests."""

    def setUp(self):
        """Initialize Ray and check for pre-existing leaks."""
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available, skipping test")

        ray.init(include_dashboard=False)

    def tearDown(self):
        ray.shutdown()
        gc.collect()
        utils.check_runtime_leaks()


class TestGrpoFastVLLMGPU(TestGrpoFastGPUBase):
    def test_vllm_queue_system_single_prompt(self):
        """Test the new queue-based vLLM system with a single prompt."""
        tokenizer_name = "EleutherAI/pythia-14m"
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        test_prompt = "What is the capital of France?"
        prompt_token_ids = tokenizer.encode(test_prompt, return_tensors="pt").tolist()[0]
        param_prompt_Q = ray_queue.Queue(maxsize=1)
        inference_results_Q = ray_queue.Queue(maxsize=1)
        actor_manager = vllm_utils3.ActorManager.remote()
        vllm_engines = create_vllm_engines(
            num_engines=1,
            tensor_parallel_size=1,
            enforce_eager=True,
            tokenizer_name_or_path=tokenizer_name,
            pretrain=tokenizer_name,
            revision="main",
            seed=42,
            enable_prefix_caching=False,
            max_model_len=512,
            vllm_gpu_memory_utilization=0.5,
            prompt_queue=param_prompt_Q,
            results_queue=inference_results_Q,
            actor_manager=actor_manager,
        )

        generation_config = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=5, n=1)

        param_prompt_Q.put(
            PromptRequest(
                prompts=[prompt_token_ids], generation_config=generation_config, dataset_index=[0], training_step=0
            )
        )

        ray.get(vllm_engines[0].process_from_queue.remote(timeout=30))
        result = inference_results_Q.get_nowait()

        self.assertIsInstance(result, GenerationResult)
        self.assertIsNotNone(result.responses)
        self.assertEqual(len(result.responses), 1)
        self.assertEqual(result.dataset_index, [0])

        response_ids = result.responses[0]
        generated_text = tokenizer.decode(response_ids, skip_special_tokens=True)

        self.assertIsInstance(generated_text, str)
        self.assertGreater(len(generated_text), 0)

        for queue in [param_prompt_Q, inference_results_Q]:
            queue.shutdown()
```

Contributor comment (on the `SamplingParams` line): do you think it's possible to get a deterministic output for Pythia 14m, and specifically test for the expected completion for this prompt? Not sure how important that is.
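For what it's worth, a minimal sketch of the check that comment asks about: with `temperature=0.0` (greedy decoding), a pinned `revision="main"`, and `enforce_eager=True`, the completion should in principle be reproducible, so the assertion could be tightened to an exact match. `EXPECTED_COMPLETION` is a hypothetical placeholder, recorded once from a known-good run; it may still drift across vLLM or weight-revision upgrades:

```python
# Hypothetical tightening of the assertions above. EXPECTED_COMPLETION is a
# placeholder: record it once from a trusted run of this exact model revision.
EXPECTED_COMPLETION = "<record from a known-good run>"


def assert_expected_completion(generated_text: str) -> None:
    # Greedy decoding should be reproducible for a pinned model revision,
    # but a vLLM version bump can still change the output.
    assert generated_text == EXPECTED_COMPLETION, (
        f"Completion drifted from the recorded baseline: {generated_text!r}"
    )
```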
One additional file renamed without changes.
Reviewer comment (on the test class): not sure, but it feels like the tokenizer and vLLM engine could be created in the setUp of the test.
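A minimal sketch of that suggestion, reusing the imports and base class from the test file above. Since `ray.init()` happens in the base class's `setUp`, the engine has to be built after `super().setUp()` rather than in `setUpClass`; the class name here is hypothetical:

```python
# Hypothetical refactor per the reviewer comment: hoist tokenizer, queues,
# and engine creation into setUp so each test body only exercises the
# queue protocol. Imports are as in the test file above.
class TestGrpoFastVLLMGPUShared(TestGrpoFastGPUBase):
    def setUp(self):
        super().setUp()  # initializes Ray; skips if CUDA is unavailable
        self.tokenizer_name = "EleutherAI/pythia-14m"
        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name)
        # create_vllm_engines needs the queues and actor manager up front.
        self.param_prompt_Q = ray_queue.Queue(maxsize=1)
        self.inference_results_Q = ray_queue.Queue(maxsize=1)
        self.actor_manager = vllm_utils3.ActorManager.remote()
        self.vllm_engines = create_vllm_engines(
            num_engines=1,
            tensor_parallel_size=1,
            enforce_eager=True,
            tokenizer_name_or_path=self.tokenizer_name,
            pretrain=self.tokenizer_name,
            revision="main",
            seed=42,
            enable_prefix_caching=False,
            max_model_len=512,
            vllm_gpu_memory_utilization=0.5,
            prompt_queue=self.param_prompt_Q,
            results_queue=self.inference_results_Q,
            actor_manager=self.actor_manager,
        )

    def tearDown(self):
        # Shut down the queues before Ray itself goes down in the base class.
        for q in (self.param_prompt_Q, self.inference_results_Q):
            q.shutdown()
        super().tearDown()
```

The trade-off: per-test engine construction keeps tests isolated but repays the engine startup cost on every test, so sharing only pays off once the class has several tests.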