
Commit

change naming for vllm test var

AdamBelfki3 committed Dec 5, 2024
1 parent 11baa38 commit 76fb4ce
Showing 2 changed files with 13 additions and 13 deletions.
2 changes: 1 addition & 1 deletion conftest.py
@@ -4,7 +4,7 @@
 def pytest_addoption(parser):
     parser.addoption("--device", action="store", default="cuda:0")
     parser.addoption(
-        "--gpus",
+        "--tp",
         action="store",
         type=int,
         default="1",
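For reference, the renamed flag is registered once in conftest.py and read back anywhere in the suite through request.config.getoption. A minimal, self-contained sketch of that wiring, with an illustrative invocation (the flag values are examples, not repository defaults):

    # Example invocation: 2-way tensor parallelism on the default CUDA device.
    #   pytest tests/test_vllm.py --device cuda:0 --tp 2

    import pytest

    @pytest.fixture(scope="module")
    def tp(request):
        # Returns the parsed value of the flag registered in conftest.py.
        return request.config.getoption("--tp")

Note that argparse-style parsers apply the type= conversion to a string default, so default="1" should still reach the fixture as the integer 1.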
24 changes: 12 additions & 12 deletions tests/test_vllm.py
@@ -39,15 +39,15 @@ def __call__(self, graph: "Graph") -> None:
 
 
 @pytest.fixture(scope="module")
-def gpus(request):
-    gpus = request.config.getoption("--gpus")
-    if gpus > torch.cuda.device_count():
-        pytest.exit("--gpus can be higher than the number of availale GPUs.")
-    return gpus
+def tp(request):
+    tp = request.config.getoption("--tp")
+    if tp > torch.cuda.device_count() or tp < 1:
+        pytest.exit("--tp must be between 1 and the number of available GPUs.")
+    return tp
 
 @pytest.fixture(scope="module")
-def vllm_gpt2(gpus: int):
-    return VLLM("gpt2", tensor_parallel_size=gpus, dispatch=True)
+def vllm_gpt2(tp: int):
+    return VLLM("gpt2", tensor_parallel_size=tp, dispatch=True)
 
 @pytest.fixture
 def ET_prompt():
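Outside the test harness, the same construction can be used directly; the module-scoped fixture simply memoizes it. A hedged sketch (the import path for nnsight's vLLM wrapper is an assumption and may differ between nnsight versions):

    import torch
    from nnsight.modeling.vllm import VLLM  # assumed import path; verify against your nnsight version

    # Shard GPT-2 across up to two GPUs, clamped to what the machine actually has.
    tp = max(1, min(2, torch.cuda.device_count()))
    model = VLLM("gpt2", tensor_parallel_size=tp, dispatch=True)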
@@ -170,7 +170,7 @@ def test_batched_multi_token_generation(vllm_gpt2, ET_prompt: str, MSG_prompt: s
     assert len(MSG_logits.value) == 5 """
 
 
-def test_mutli_token_generation_with_intervention(gpus, vllm_gpt2, MSG_prompt: str):
+def test_mutli_token_generation_with_intervention(tp, vllm_gpt2, MSG_prompt: str):
     with vllm_gpt2.trace(MSG_prompt, temperature=0.0, top_p=1.0, max_tokens=5) as tracer:
         logits = nnsight.list().save()
         hs_list = nnsight.list().save()
@@ -184,7 +184,7 @@ def test_mutli_token_generation_with_intervention(gpus, vllm_gpt2, MSG_prompt: s
 
     assert [torch.all(hs == 0) for hs in hs_list.value] == [False, False, True, False, False]
 
-    if gpus == 1:
+    if tp == 1:
         assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits.value]) == [' New', ' York', '\n', '\n', 'The']
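The test combines two nnsight primitives visible in this diff: in-place assignment to a module's .output proxy (the intervention) and .save() to read tensors back once the trace exits. A hedged sketch of that combination in isolation (prompt, layer index, and slice are illustrative, and proxy semantics may vary by nnsight version; run with tp == 1 so the full c_fc output width is visible):

    import torch

    with vllm_gpt2.trace("Madison Square Garden is located in", temperature=0.0, top_p=1.0):
        # Zero the upper part of one MLP up-projection, then save the mutated activation.
        vllm_gpt2.transformer.h[5].mlp.c_fc.output[0][:, 2000:] = 0
        hs = vllm_gpt2.transformer.h[5].mlp.c_fc.output[0].save()

    # After the trace exits, saved proxies expose concrete tensors through .value.
    assert torch.all(hs.value[:, 2000:] == 0)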


@@ -197,9 +197,9 @@ def test_mutli_token_generation_with_intervention(gpus, vllm_gpt2, MSG_prompt: s
     assert not torch.equal(act_in, act_in_other) """
 
 
-def test_tensor_parallelism(gpus, vllm_gpt2, ET_prompt: str):
-    if gpus < 2:
-        pytest.skip("Must sp")
+def test_tensor_parallelism(tp, vllm_gpt2, ET_prompt: str):
+    if tp < 2:
+        pytest.skip("Tensor parallelism test requires tp >= 2.")
 
     with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1.0):
         vllm_gpt2.transformer.h[5].mlp.c_fc.output[0][:, 2000:] = 0
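Because test_tensor_parallelism self-skips below two GPUs, exercising it requires passing the new flag explicitly, for example:

    pytest tests/test_vllm.py::test_tensor_parallelism --tp 2

The node-id selection is standard pytest; the flag value just has to stay within the machine's GPU count, or the tp fixture aborts the run via pytest.exit.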
