From f578859de6a8f32df812e7ffb1c62d32212b1a12 Mon Sep 17 00:00:00 2001
From: Adam Belfki
Date: Thu, 5 Dec 2024 11:55:23 -0500
Subject: [PATCH 1/2] test (vllm): Adding vllm dedicated test folder

---
 conftest.py                             |   8 +-
 src/nnsight/intervention/graph/graph.py |   4 +-
 src/nnsight/modeling/vllm/sampling.py   |   5 +-
 src/nnsight/modeling/vllm/vllm.py       |  12 --
 tests/test_vllm.py                      | 213 ++++++++++++++++++++++++
 5 files changed, 225 insertions(+), 17 deletions(-)
 create mode 100644 tests/test_vllm.py

diff --git a/conftest.py b/conftest.py
index 86a26183..0cb41963 100755
--- a/conftest.py
+++ b/conftest.py
@@ -3,7 +3,13 @@
 def pytest_addoption(parser):
     parser.addoption("--device", action="store", default="cuda:0")
-
+    parser.addoption(
+        "--gpus",
+        action="store",
+        type=int,
+        default="1",
+        help="An argument for specifying the number of gpus to be used by VLLM"
+    )

 def pytest_generate_tests(metafunc):
     # This is called for every test. Only get/set command line arguments
diff --git a/src/nnsight/intervention/graph/graph.py b/src/nnsight/intervention/graph/graph.py
index 7a6edf86..a4b5b4ae 100755
--- a/src/nnsight/intervention/graph/graph.py
+++ b/src/nnsight/intervention/graph/graph.py
@@ -359,7 +359,7 @@ def copy(
         new_graph.compiled = self.compiled

         for key, value in self.call_counter.items():
-            self.call_counter[memo[key]] = value
+            new_graph.call_counter[memo[key]] = value

         if new_graph.compiled:
@@ -371,7 +371,7 @@ def copy(

         for key, values in self.deferred.items():

-            new_graph[memo[key]] = [memo[index] for index in values]
+            new_graph.deferred[memo[key]] = [memo[index] for index in values]

         new_graph.grad_subgraph = [memo[index] for index in self.grad_subgraph]
diff --git a/src/nnsight/modeling/vllm/sampling.py b/src/nnsight/modeling/vllm/sampling.py
index 293bbdbd..8df9726b 100755
--- a/src/nnsight/modeling/vllm/sampling.py
+++ b/src/nnsight/modeling/vllm/sampling.py
@@ -138,10 +138,11 @@ def prepare(
             intervention_graph = None
         elif n_graphs == 1:
             intervention_graph =intervention_graphs[0]
-        else:
+
+        """ else:
             intervention_graph = MultiGraph(intervention_graphs.values())

-            InterventionProtocol.shift(intervention_graph)
+            InterventionProtocol.shift(intervention_graph) """

         ###########################################
diff --git a/src/nnsight/modeling/vllm/vllm.py b/src/nnsight/modeling/vllm/vllm.py
index c3c34c5d..08b6448f 100644
--- a/src/nnsight/modeling/vllm/vllm.py
+++ b/src/nnsight/modeling/vllm/vllm.py
@@ -82,13 +82,8 @@ def _load_meta(self, repo_id: str, **kwargs) -> "Module":
         )

         # creating the vllm engine configuration
-        """ dict(
-            (field.name, getattr(self, field.name)) for field in fields(self)) """
         vllm_config = engine_args.create_engine_config()
         vllm_config_dict = {field.name: getattr(vllm_config, field.name) for field in fields(type(vllm_config))}
-
-        #dict((field.name, getattr(self, field.name)) for field in fields(engine_args.create_engine_config()))
-        #vllm_config_dict = engine_args.create_engine_config().to_dict()

         # starting the distributed environment
         init_distributed_environment(
@@ -103,13 +98,6 @@ def _load_meta(self, repo_id: str, **kwargs) -> "Module":
         initialize_model_parallel(backend="gloo")

         # initialize the model
-        """ model = _initialize_model(
-            model_config=vllm_config_dict["model_config"],
-            load_config=vllm_config_dict["load_config"],
-            lora_config=None,
-            cache_config=vllm_config_dict["cache_config"],
-            scheduler_config=vllm_config_dict["scheduler_config"],
-        ) """

         model = _initialize_model(vllm_config)
diff --git a/tests/test_vllm.py b/tests/test_vllm.py
new file mode 100644
index 00000000..fda505e7
--- /dev/null
+++ b/tests/test_vllm.py
@@ -0,0 +1,213 @@
+import pytest
+import nnsight
+import torch
+from typing import TYPE_CHECKING
+
+from nnsight.tracing.backends import Backend
+from nnsight.tracing.protocols import StopProtocol
+
+if TYPE_CHECKING:
+    from nnsight.tracing.graph import Graph
+
+try:
+    from nnsight.modeling.vllm import VLLM
+except:
+    pytest.skip("Skipping VLLM tests", allow_module_level=True)
+
+
+class AssertSavedLenBackend(Backend):
+
+    def __init__(self, len: int) -> None:
+        self.len = len
+
+    def __call__(self, graph: "Graph") -> None:
+
+        try:
+
+            graph.nodes[-1].execute()
+
+        except StopProtocol.StopException:
+
+            pass
+
+        finally:
+
+            assert self.len == len([node for node in graph.nodes if node.done])
+
+            graph.nodes.clear()
+            graph.stack.clear()
+
+
+@pytest.fixture(scope="module")
+def gpus(request):
+    gpus = request.config.getoption("--gpus")
+    if gpus > torch.cuda.device_count():
+        pytest.exit("--gpus can't be higher than the number of available GPUs.")
+    return gpus
+
+@pytest.fixture(scope="module")
+def vllm_gpt2(gpus: int):
+    return VLLM("gpt2", tensor_parallel_size=gpus, dispatch=True)
+
+@pytest.fixture
+def ET_prompt():
+    return "The Eiffel Tower is located in the city of"
+
+@pytest.fixture
+def MSG_prompt():
+    return "Madison Square Garden is located in the city of"
+
+
+def test_single_logit(vllm_gpt2, ET_prompt: str):
+    with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1, backend=AssertSavedLenBackend(1)):
+        logits = vllm_gpt2.logits.output.save()
+
+    next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
+    assert next_token == " Paris"
+
+
+def test_multi_token_generation(vllm_gpt2, MSG_prompt: str):
+    with vllm_gpt2.trace(MSG_prompt, temperature=0.0, top_p=1.0, max_tokens=3):
+        logits = nnsight.list().save()
+        for ii in range(3):
+            logits.append(vllm_gpt2.logits.output)
+            vllm_gpt2.logits.next()
+
+    assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits.value]) == [" New", " York", " City"]
+
+
+""" def test_max_token_generation(vllm_gpt2, ET_prompt: str):
+    with vllm_gpt2.trace(ET_prompt, max_tokens=10):
+        logits = nnsight.list().save()
+        with vllm_gpt2.logits.all():
+            logits.append(vllm_gpt2.logits.output)
+
+    assert len(logits) == 10 """
+
+
+""" def test_sampling(vllm_gpt2, ET_prompt:str):
+    with vllm_gpt2.trace(ET_prompt, temperature=0.8, top_p=0.95, max_tokens=3):
+        samples = nnsight.list().save()
+        with vllm_gpt2.sample.all():
+            samples.append(vllm_gpt2.sample.output)
+
+    samples = vllm_gpt2.tokenizer.batch_decode([sample.argmax(dim=-1) for sample in samples])
+    assert samples == [' Canary', ' Wh', 'arf'] """
+
+
+def test_intervention(vllm_gpt2, ET_prompt: str):
+    with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1, backend=AssertSavedLenBackend(2)) as tracer:
+        vllm_gpt2.transformer.h[-2].mlp.output[:] = 0
+        hs = vllm_gpt2.transformer.h[-2].mlp.output.save()
+        logits = vllm_gpt2.logits.output.save()
+
+    next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
+    assert next_token == " London"
+    assert torch.all(hs == 0)
+
+
+def test_swap_intervention(vllm_gpt2, ET_prompt: str):
+    with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1, backend=AssertSavedLenBackend(2)) as tracer:
+        vllm_gpt2.transformer.h[-2].mlp.output = torch.zeros_like(vllm_gpt2.transformer.h[-2].mlp.output)
+        hs = vllm_gpt2.transformer.h[-2].mlp.output.save()
+        logits = vllm_gpt2.logits.output.save()
+
+    next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
+    assert next_token == " London"
+    assert torch.all(hs == 0)
+
+
+def test_batched_intervention(vllm_gpt2, ET_prompt: str,):
+    with vllm_gpt2.trace(temperature=0.0, top_p=1, backend=AssertSavedLenBackend(4)) as tracer:
+
+        with tracer.invoke(ET_prompt):
+            clean_hs = vllm_gpt2.transformer.h[-2].mlp.output.save()
+            clean_logits = vllm_gpt2.logits.output.save()
+        with tracer.invoke(ET_prompt):
+            vllm_gpt2.transformer.h[-2].mlp.output[:] = 0
+            corrupted_hs = vllm_gpt2.transformer.h[-2].mlp.output.save()
+            corrupted_logits = vllm_gpt2.logits.output.save()
+
+    clean_token = vllm_gpt2.tokenizer.decode(clean_logits.argmax(dim=-1))
+    corrupted_token = vllm_gpt2.tokenizer.decode(corrupted_logits.argmax(dim=-1))
+
+    assert clean_token == " Paris"
+    assert corrupted_token == " London"
+    assert not torch.all(clean_hs == 0)
+    assert torch.all(corrupted_hs == 0)
+
+
+def test_batched_multi_token_generation(vllm_gpt2, ET_prompt: str, MSG_prompt: str):
+    with vllm_gpt2.trace() as tracer:
+        with tracer.invoke(ET_prompt, max_tokens=3):
+            ET_logits = nnsight.list().save()
+            for ii in range(3):
+                ET_logits.append(vllm_gpt2.logits.output)
+                vllm_gpt2.logits.next()
+        with tracer.invoke(MSG_prompt, max_tokens=5):
+            MSG_logits = nnsight.list().save()
+            for ii in range(5):
+                MSG_logits.append(vllm_gpt2.logits.output)
+                vllm_gpt2.logits.next()
+
+    assert len(ET_logits) == 3
+    assert len(MSG_logits) == 5
+
+
+""" def test_batched_multi_token_generation_with_iter(vllm_gpt2, ET_prompt: str, MSG_prompt: str):
+    with vllm_gpt2.trace(max_tokens=10) as tracer:
+        with tracer.invoke(ET_prompt):
+            ET_logits = nnsight.list().save()
+            with vllm_gpt2.logits.iter[:3]:
+                ET_logits.append(vllm_gpt2.logits.output)
+                #vllm_gpt2.output.save()
+        with tracer.invoke(MSG_prompt, max_tokens=5):
+            MSG_logits = nnsight.list().save()
+            with vllm_gpt2.logits.iter[:5]:
+                MSG_logits.append(vllm_gpt2.logits.output)
+
+    assert len(ET_logits.value) == 3
+    assert len(MSG_logits.value) == 5 """
+
+
+def test_multi_token_generation_with_intervention(gpus, vllm_gpt2, MSG_prompt: str):
+    with vllm_gpt2.trace(MSG_prompt, temperature=0.0, top_p=1.0, max_tokens=5) as tracer:
+        logits = nnsight.list().save()
+        hs_list = nnsight.list().save()
+        for ii in range(5):
+            if ii == 2:
+                vllm_gpt2.transformer.h[-2].output[0][:] = 0
+            hs_list.append(vllm_gpt2.transformer.h[-2].output[0])
+            vllm_gpt2.transformer.h[-2].next()
+            logits.append(vllm_gpt2.logits.output)
+            vllm_gpt2.logits.next()
+
+    assert [torch.all(hs == 0) for hs in hs_list.value] == [False, False, True, False, False]
+
+    if gpus == 1:
+        assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits.value]) == [' New', ' York', '\n', '\n', 'The']
+
+
+""" def test_multi_referenced_module(vllm_gpt2, ET_prompt: str):
+    with vllm_gpt2.trace(ET_prompt):
+        act_in = vllm_gpt2.transformer.h[0].mlp.act.input.save()
+        vllm_gpt2.transformer.h[0].mlp.act.next()
+        act_in_other = vllm_gpt2.transformer.h[1].mlp.act.input.save()
+
+    assert not torch.equal(act_in, act_in_other) """
+
+
+def test_tensor_parallelism(gpus, vllm_gpt2, ET_prompt: str):
+    if gpus < 2:
+        pytest.skip("Must sp")
+
+    with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1.0):
+        vllm_gpt2.transformer.h[5].mlp.c_fc.output[0][:, 2000:] = 0
+        hs = vllm_gpt2.transformer.h[5].mlp.c_fc.output[0].save()
+        logit = vllm_gpt2.logits.output.save()
+
+    next_token = vllm_gpt2.tokenizer.decode(logit.argmax(dim=-1))
+
+    #assert next_token != " Paris"
+    assert hs.shape == torch.Size([11, 3072])
+    assert torch.all(hs[:, 2000:] == 0)

From 76fb4cec638af15699fbc1b0ea1485140947d5a5 Mon Sep 17 00:00:00 2001
From: Adam Belfki
Date: Thu, 5 Dec 2024 12:08:52 -0500
Subject: [PATCH 2/2] change naming for vllm test var

---
 conftest.py        |  2 +-
 tests/test_vllm.py | 24 ++++++++++++------------
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/conftest.py b/conftest.py
index 0cb41963..9402c375 100755
--- a/conftest.py
+++ b/conftest.py
@@ -4,7 +4,7 @@ def pytest_addoption(parser):
     parser.addoption("--device", action="store", default="cuda:0")
     parser.addoption(
-        "--gpus",
+        "--tp",
         action="store",
         type=int,
         default="1",
         help="An argument for specifying the number of gpus to be used by VLLM"
diff --git a/tests/test_vllm.py b/tests/test_vllm.py
index fda505e7..f9bb9e12 100644
--- a/tests/test_vllm.py
+++ b/tests/test_vllm.py
@@ -39,15 +39,15 @@ def __call__(self, graph: "Graph") -> None:


 @pytest.fixture(scope="module")
-def gpus(request):
-    gpus = request.config.getoption("--gpus")
-    if gpus > torch.cuda.device_count():
-        pytest.exit("--gpus can't be higher than the number of available GPUs.")
-    return gpus
+def tp(request):
+    tp = request.config.getoption("--tp")
+    if tp > torch.cuda.device_count() or tp < 1:
+        pytest.exit("--tp can't be higher than the number of available GPUs.")
+    return tp

 @pytest.fixture(scope="module")
-def vllm_gpt2(gpus: int):
-    return VLLM("gpt2", tensor_parallel_size=gpus, dispatch=True)
+def vllm_gpt2(tp: int):
+    return VLLM("gpt2", tensor_parallel_size=tp, dispatch=True)

 @pytest.fixture
 def ET_prompt():
@@ -170,7 +170,7 @@ def test_batched_multi_token_generation(vllm_gpt2, ET_prompt: str, MSG_prompt: s
     assert len(MSG_logits.value) == 5 """


-def test_multi_token_generation_with_intervention(gpus, vllm_gpt2, MSG_prompt: str):
+def test_multi_token_generation_with_intervention(tp, vllm_gpt2, MSG_prompt: str):
     with vllm_gpt2.trace(MSG_prompt, temperature=0.0, top_p=1.0, max_tokens=5) as tracer:
         logits = nnsight.list().save()
         hs_list = nnsight.list().save()
@@ -184,7 +184,7 @@ def test_multi_token_generation_with_intervention(vllm_gpt2, ET_prompt: s
     assert [torch.all(hs == 0) for hs in hs_list.value] == [False, False, True, False, False]

-    if gpus == 1:
+    if tp == 1:
         assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits.value]) == [' New', ' York', '\n', '\n', 'The']
@@ -197,9 +197,9 @@ def test_multi_token_generation_with_intervention(vllm_gpt2, ET_prompt: s
     assert not torch.equal(act_in, act_in_other) """


-def test_tensor_parallelism(gpus, vllm_gpt2, ET_prompt: str):
-    if gpus < 2:
-        pytest.skip("Must sp")
+def test_tensor_parallelism(tp, vllm_gpt2, ET_prompt: str):
+    if tp < 2:
+        pytest.skip("Skipping test: requires tp > 1.")

     with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1.0):
         vllm_gpt2.transformer.h[5].mlp.c_fc.output[0][:, 2000:] = 0
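
Usage sketch (an assumption based on the conftest.py changes above, not part of
either patch): the dedicated vLLM tests read the tensor-parallel degree from the
--tp pytest option (introduced as --gpus in PATCH 1/2 and renamed in PATCH 2/2),
defaulting to 1, in which case test_tensor_parallelism is skipped. Run from the
repository root, for example:

    # Hypothetical invocation on a host with at least two visible CUDA GPUs:
    #   pytest tests/test_vllm.py --tp 2
    # or, equivalently, from Python:
    import pytest

    # Forward the tensor-parallel degree to the "tp" fixture via --tp;
    # drop the flag to fall back to the default of a single GPU.
    raise SystemExit(pytest.main(["tests/test_vllm.py", "--tp", "2"]))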