diff --git a/conftest.py b/conftest.py
index 0cb41963..9402c375 100755
--- a/conftest.py
+++ b/conftest.py
@@ -4,7 +4,7 @@ def pytest_addoption(parser):
     parser.addoption("--device", action="store", default="cuda:0")
 
     parser.addoption(
-        "--gpus",
+        "--tp",
         action="store",
         type=int,
         default="1",
diff --git a/tests/test_vllm.py b/tests/test_vllm.py
index fda505e7..f9bb9e12 100644
--- a/tests/test_vllm.py
+++ b/tests/test_vllm.py
@@ -39,15 +39,15 @@ def __call__(self, graph: "Graph") -> None:
 
 
 @pytest.fixture(scope="module")
-def gpus(request):
-    gpus = request.config.getoption("--gpus")
-    if gpus > torch.cuda.device_count():
-        pytest.exit("--gpus can be higher than the number of availale GPUs.")
-    return gpus
+def tp(request):
+    tp = request.config.getoption("--tp")
+    if tp > torch.cuda.device_count() or tp < 1:
+        pytest.exit("--tp must be between 1 and the number of available GPUs.")
+    return tp
 
 @pytest.fixture(scope="module")
-def vllm_gpt2(gpus: int):
-    return VLLM("gpt2", tensor_parallel_size=gpus, dispatch=True)
+def vllm_gpt2(tp: int):
+    return VLLM("gpt2", tensor_parallel_size=tp, dispatch=True)
 
 @pytest.fixture
 def ET_prompt():
@@ -170,7 +170,7 @@ def test_batched_multi_token_generation(vllm_gpt2, ET_prompt: str, MSG_prompt: s
     assert len(MSG_logits.value) == 5
 """
 
-def test_mutli_token_generation_with_intervention(gpus, vllm_gpt2, MSG_prompt: str):
+def test_multi_token_generation_with_intervention(tp, vllm_gpt2, MSG_prompt: str):
     with vllm_gpt2.trace(MSG_prompt, temperature=0.0, top_p=1.0, max_tokens=5) as tracer:
         logits = nnsight.list().save()
         hs_list = nnsight.list().save()
@@ -184,7 +184,7 @@ def test_mutli_token_generation_with_intervention(gpus, vllm_gpt2, MSG_prompt: s
 
     assert [torch.all(hs == 0) for hs in hs_list.value] == [False, False, True, False, False]
 
-    if gpus == 1:
+    if tp == 1:
         assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits.value]) == [' New', ' York', '\n', '\n', 'The']
 
 
@@ -197,9 +197,9 @@ def test_mutli_token_generation_with_intervention(gpus, vllm_gpt2, MSG_prompt: s
     assert not torch.equal(act_in, act_in_other)
 """
 
-def test_tensor_parallelism(gpus, vllm_gpt2, ET_prompt: str):
-    if gpus < 2:
-        pytest.skip("Must sp")
+def test_tensor_parallelism(tp, vllm_gpt2, ET_prompt: str):
+    if tp < 2:
+        pytest.skip("Tensor parallelism test requires --tp >= 2.")
     with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1.0):
         vllm_gpt2.transformer.h[5].mlp.c_fc.output[0][:, 2000:] = 0
 