diff --git a/.github/workflows/gpu-tests.yml b/.github/workflows/gpu-tests.yml
new file mode 100644
index 0000000000..642f36da23
--- /dev/null
+++ b/.github/workflows/gpu-tests.yml
@@ -0,0 +1,31 @@
+name: GPU Tests
+
+on:
+  pull_request:
+    branches:
+      - main
+    paths:
+      - 'open_instruct/grpo_fast.py'
+      - 'open_instruct/vllm_utils3.py'
+  merge_group:
+  workflow_dispatch:
+
+jobs:
+  gpu-tests:
+    name: Run GPU tests
+    runs-on: GPU-Enabled-Runner
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
+        with:
+          version: "0.8.6"
+
+      - name: Set up Python
+        run: uv sync --frozen
+
+      - name: Run GPU tests
+        run: |
+          uv run --frozen pytest -xvs tests/*_gpu.py open_instruct/*_gpu.py
\ No newline at end of file
diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index 7319659c7d..5e8c1795f8 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -194,7 +194,7 @@ class Args:
     warmup_ratio: float = 0.0
     """Ratio of warmup steps to total steps (takes precedence over `warm_up_steps`)"""
     weight_decay: float = 0.0
-    """Weight decay for AdamW if we apply some."""
+    """Weight decay for AdamW."""
     set_weight_decay_on_bias_and_norm: bool = True
     """Whether to set weight decay on bias and norm layers"""
     fused_optimizer: bool = False
diff --git a/open_instruct/test_grpo_fast.py b/open_instruct/test_grpo_fast.py
index 6d8929a012..b848af4be1 100644
--- a/open_instruct/test_grpo_fast.py
+++ b/open_instruct/test_grpo_fast.py
@@ -4,15 +4,11 @@
 from unittest.mock import Mock
 
 import ray
-import torch
 from parameterized import parameterized
 from ray.util import queue as ray_queue
-from transformers import AutoTokenizer
-from vllm import SamplingParams
 
 from open_instruct import grpo_fast, model_utils, utils
 from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo
-from open_instruct.vllm_utils3 import create_vllm_engines
 
 
 class TestGrpoFastBase(unittest.TestCase):
@@ -177,78 +173,6 @@ def setup_and_split_batch(self, queries, ground_truths, datasets, indices, num_e
 
 
 class TestGrpoFastVLLM(TestGrpoFastBase):
-    def test_vllm_queue_system_single_prompt(self):
-        """Test the new queue-based vLLM system with a single prompt 'What is the capital of France?'"""
-        # Check if CUDA is available
-        if not torch.cuda.is_available():
-            self.skipTest("CUDA is not available, skipping test")
-
-        # Set up tokenizer
-        tokenizer_name = "EleutherAI/pythia-14m"  # Using a small model for testing
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
-
-        # Tokenize the test prompt
-        test_prompt = "What is the capital of France?"
-        prompt_token_ids = tokenizer.encode(test_prompt, return_tensors="pt").tolist()[0]
-
-        # Create Ray queues
-        param_prompt_Q = ray_queue.Queue(maxsize=1)
-        inference_results_Q = ray_queue.Queue(maxsize=1)
-
-        # Track queues for cleanup
-        self._ray_queues.extend([param_prompt_Q, inference_results_Q])
-
-        # Create vLLM engines with queues
-        vllm_engines = create_vllm_engines(
-            num_engines=1,
-            tensor_parallel_size=1,
-            enforce_eager=True,
-            tokenizer_name_or_path=tokenizer_name,
-            pretrain=tokenizer_name,
-            revision="main",
-            seed=42,
-            enable_prefix_caching=False,
-            max_model_len=512,
-            vllm_gpu_memory_utilization=0.5,  # Use less GPU memory for testing
-            prompt_queue=param_prompt_Q,
-            results_queue=inference_results_Q,
-        )
-
-        # Set up generation config
-        generation_config = SamplingParams(
-            temperature=0.0,  # Deterministic generation
-            top_p=1.0,
-            max_tokens=5,
-            seed=42,
-        )
-
-        # Start vLLM engines to process from queues
-        [e.process_from_queue.remote() for e in vllm_engines]
-
-        # Put the test prompt in the queue using PromptRequest
-        param_prompt_Q.put(
-            PromptRequest(prompts=[prompt_token_ids], dataset_index=0, sampling_params=generation_config)
-        )
-
-        # Get the result
-        result = inference_results_Q.get()
-
-        # Verify it's a GenerationResult dataclass
-        self.assertIsInstance(result, GenerationResult)
-
-        # Check that we got a response
-        self.assertGreater(len(result.responses), 0)
-        response_ids = result.responses[0]
-
-        # Decode the response
-        generated_text = tokenizer.decode(response_ids, skip_special_tokens=True)
-
-        self.assertIsInstance(generated_text, str)
-        self.assertGreater(len(generated_text), 0)
-
-        # Send stop signal
-        param_prompt_Q.put(None)
-
     @parameterized.expand([(1, 16), (2, 32), (4, 64), (8, 128)])
     def test_batch_splitting_and_engine_configurations(self, vllm_num_engines: int, num_unique_prompts_rollout: int):
         """Test batch splitting and accumulation with various engine configurations."""
diff --git a/open_instruct/test_grpo_fast_gpu.py b/open_instruct/test_grpo_fast_gpu.py
new file mode 100644
index 0000000000..ccd274eccb
--- /dev/null
+++ b/open_instruct/test_grpo_fast_gpu.py
@@ -0,0 +1,83 @@
+import gc
+import unittest
+
+import ray
+import torch
+from ray.util import queue as ray_queue
+from transformers import AutoTokenizer
+from vllm import SamplingParams
+
+from open_instruct import utils, vllm_utils3
+from open_instruct.queue_types import GenerationResult, PromptRequest
+from open_instruct.vllm_utils3 import create_vllm_engines
+
+
+class TestGrpoFastGPUBase(unittest.TestCase):
+    """Base class with common test utilities for GPU tests."""
+
+    def setUp(self):
+        """Initialize Ray and check for pre-existing leaks."""
+        if not torch.cuda.is_available():
+            self.skipTest("CUDA is not available, skipping test")
+
+        ray.init(include_dashboard=False)
+
+    def tearDown(self):
+        ray.shutdown()
+
+        gc.collect()
+
+        utils.check_runtime_leaks()
+
+
+class TestGrpoFastVLLMGPU(TestGrpoFastGPUBase):
+    def test_vllm_queue_system_single_prompt(self):
+        """Test the new queue-based vLLM system with a single prompt."""
+        tokenizer_name = "EleutherAI/pythia-14m"
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        test_prompt = "What is the capital of France?"
+        prompt_token_ids = tokenizer.encode(test_prompt, return_tensors="pt").tolist()[0]
+        param_prompt_Q = ray_queue.Queue(maxsize=1)
+        inference_results_Q = ray_queue.Queue(maxsize=1)
+        actor_manager = vllm_utils3.ActorManager.remote()
+        vllm_engines = create_vllm_engines(
+            num_engines=1,
+            tensor_parallel_size=1,
+            enforce_eager=True,
+            tokenizer_name_or_path=tokenizer_name,
+            pretrain=tokenizer_name,
+            revision="main",
+            seed=42,
+            enable_prefix_caching=False,
+            max_model_len=512,
+            vllm_gpu_memory_utilization=0.5,
+            prompt_queue=param_prompt_Q,
+            results_queue=inference_results_Q,
+            actor_manager=actor_manager,
+        )
+
+        generation_config = SamplingParams(temperature=0.0, top_p=1.0, max_tokens=5, n=1)
+
+        param_prompt_Q.put(
+            PromptRequest(
+                prompts=[prompt_token_ids], generation_config=generation_config, dataset_index=[0], training_step=0
+            )
+        )
+
+        ray.get(vllm_engines[0].process_from_queue.remote(timeout=30))
+        result = inference_results_Q.get_nowait()
+
+        self.assertIsInstance(result, GenerationResult)
+        self.assertIsNotNone(result.responses)
+        self.assertEqual(len(result.responses), 1)
+        self.assertEqual(result.dataset_index, [0])
+
+        response_ids = result.responses[0]
+
+        generated_text = tokenizer.decode(response_ids, skip_special_tokens=True)
+
+        self.assertIsInstance(generated_text, str)
+        self.assertGreater(len(generated_text), 0)
+
+        for queue in [param_prompt_Q, inference_results_Q]:
+            queue.shutdown()
diff --git a/pyproject.toml b/pyproject.toml
index 67553aa760..e46327f8b7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -88,7 +88,7 @@ dev = [
 ]
 
 [tool.pytest.ini_options]
-addopts = "--ignore=oe-eval-internal/"
+addopts = "--ignore=oe-eval-internal/ --ignore-glob='**/*_gpu.py'"
 
 
 [tool.black]
diff --git a/tests/test_padding_free.py b/tests/test_padding_free_gpu.py
similarity index 100%
rename from tests/test_padding_free.py
rename to tests/test_padding_free_gpu.py