Merge pull request #301 from ndif-team/vllm-testing
NNsight.VLLM tests
JadenFiotto-Kaufman authored Dec 5, 2024
2 parents 9290dd0 + 76fb4ce commit 043e0ec
Showing 5 changed files with 225 additions and 17 deletions.
8 changes: 7 additions & 1 deletion conftest.py
@@ -3,7 +3,13 @@

def pytest_addoption(parser):
    parser.addoption("--device", action="store", default="cuda:0")

    parser.addoption(
        "--tp",
        action="store",
        type=int,
        default=1,
        help="Number of GPUs vLLM should use for tensor parallelism",
    )


def pytest_generate_tests(metafunc):
    # This is called for every test. Only get/set command line arguments
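With the new option, the tensor-parallel degree can be chosen from the pytest command line; a usage sketch, assuming the tests are run from the repository root against the test file added below:

pytest tests/test_vllm.py --device cuda:0 --tp 2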
4 changes: 2 additions & 2 deletions src/nnsight/intervention/graph/graph.py
@@ -370,7 +370,7 @@ def copy(
        new_graph.compiled = self.compiled

        for key, value in self.call_counter.items():
-           self.call_counter[memo[key]] = value
+           new_graph.call_counter[memo[key]] = value

        if new_graph.compiled:

@@ -382,7 +382,7 @@ def copy(

        for key, values in self.deferred.items():

-           new_graph[memo[key]] = [memo[index] for index in values]
+           new_graph.deferred[memo[key]] = [memo[index] for index in values]

        new_graph.grad_subgraph = [memo[index] for index in self.grad_subgraph]

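Both corrected lines fix the same slip: `copy` was writing remapped per-node state back into the source graph instead of into `new_graph`. A minimal standalone sketch of the memo-remapping pattern (simplified names; `memo` maps source node indices to their copies, as in the diff above):

# Per-node state on the source graph, keyed by source node index.
call_counter = {0: 2, 1: 5}
# memo maps each source index to the index of the corresponding copied node.
memo = {0: 10, 1: 11}

# Correct: populate the copy's own dict under the remapped keys;
# the buggy version mutated the source graph's dict instead.
new_call_counter = {}
for key, value in call_counter.items():
    new_call_counter[memo[key]] = value

assert new_call_counter == {10: 2, 11: 5}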
5 changes: 3 additions & 2 deletions src/nnsight/modeling/vllm/sampling.py
@@ -138,10 +138,11 @@ def prepare(
            intervention_graph = None
        elif n_graphs == 1:
            intervention_graph = intervention_graphs[0]
-       else:
+
+       """ else:
            intervention_graph = MultiGraph(intervention_graphs.values())
-           InterventionProtocol.shift(intervention_graph)
+           InterventionProtocol.shift(intervention_graph) """

###########################################

12 changes: 0 additions & 12 deletions src/nnsight/modeling/vllm/vllm.py
@@ -82,13 +82,8 @@ def _load_meta(self, repo_id: str, **kwargs) -> "Module":
        )

        # creating the vllm engine configuration
-       """ dict(
-           (field.name, getattr(self, field.name)) for field in fields(self)) """
        vllm_config = engine_args.create_engine_config()
-       vllm_config_dict = {field.name: getattr(vllm_config, field.name) for field in fields(type(vllm_config))}

-       #dict((field.name, getattr(self, field.name)) for field in fields(engine_args.create_engine_config()))
-       #vllm_config_dict = engine_args.create_engine_config().to_dict()

        # starting the distributed environment
        init_distributed_environment(
@@ -103,13 +98,6 @@ def _load_meta(self, repo_id: str, **kwargs) -> "Module":
        initialize_model_parallel(backend="gloo")

        # initialize the model
-       """ model = _initialize_model(
-           model_config=vllm_config_dict["model_config"],
-           load_config=vllm_config_dict["load_config"],
-           lora_config=None,
-           cache_config=vllm_config_dict["cache_config"],
-           scheduler_config=vllm_config_dict["scheduler_config"],
-       ) """

        model = _initialize_model(vllm_config)

213 changes: 213 additions & 0 deletions tests/test_vllm.py
@@ -0,0 +1,213 @@
import pytest
import nnsight
import torch
from typing import TYPE_CHECKING

from nnsight.tracing.backends import Backend
from nnsight.tracing.protocols import StopProtocol

if TYPE_CHECKING:
    from nnsight.tracing.graph import Graph

# vLLM is an optional dependency; skip the whole module if it isn't importable.
try:
    from nnsight.modeling.vllm import VLLM
except Exception:
    pytest.skip("Skipping VLLM tests", allow_module_level=True)


class AssertSavedLenBackend(Backend):
    """Executes the graph, then asserts that exactly `len` nodes finished with a value."""

    def __init__(self, len: int) -> None:
        self.len = len

    def __call__(self, graph: "Graph") -> None:

        try:
            graph.nodes[-1].execute()
        except StopProtocol.StopException:
            pass
        finally:
            # Count the nodes that completed, i.e. the values that were saved.
            assert self.len == len([node for node in graph.nodes if node.done])

            graph.nodes.clear()
            graph.stack.clear()


@pytest.fixture(scope="module")
def tp(request):
tp = request.config.getoption("--tp")
if tp > torch.cuda.device_count() or tp < 1:
pytest.exit("--tp can't be higher than the number of availale GPUs.")
return tp

@pytest.fixture(scope="module")
def vllm_gpt2(tp: int):
return VLLM("gpt2", tensor_parallel_size=tp, dispatch=True)

@pytest.fixture
def ET_prompt():
return "The Eiffel Tower is located in the city of"

@pytest.fixture
def MSG_prompt():
return "Madison Square Garden is located in the city of"


def test_single_logit(vllm_gpt2, ET_prompt: str):
    with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1, backend=AssertSavedLenBackend(1)):
        logits = vllm_gpt2.logits.output.save()

    next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
    assert next_token == " Paris"


def test_multi_token_generation(vllm_gpt2, MSG_prompt: str):
    with vllm_gpt2.trace(MSG_prompt, temperature=0.0, top_p=1.0, max_tokens=3):
        logits = nnsight.list().save()
        for ii in range(3):
            logits.append(vllm_gpt2.logits.output)
            vllm_gpt2.logits.next()

    assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits.value]) == [" New", " York", " City"]


""" def test_max_token_generation(vllm_gpt2, ET_prompt: str):
with vllm_gpt2.trace(ET_prompt, max_tokens=10):
logits = nnsight.list().save()
with vllm_gpt2.logits.all():
logits.append(vllm_gpt2.logits.output)
assert len(logits) == 10 """


""" def test_sampling(vllm_gpt2, ET_prompt:str):
with vllm_gpt2.trace(ET_prompt, temperature=0.8, top_p=0.95, max_tokens=3):
samples = nnsight.list().save()
with vllm_gpt2.sample.all():
li.append(vllm_gpt2.sample.output)
samples = vllm_gpt2.batch_decode([sample.argmax(dim=-1) for sample in samples])
assert samples == [' Canary', ' Wh', 'arf'] """


def test_intervention(vllm_gpt2, ET_prompt: str):
    with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1, backend=AssertSavedLenBackend(2)) as tracer:
        vllm_gpt2.transformer.h[-2].mlp.output[:] = 0
        hs = vllm_gpt2.transformer.h[-2].mlp.output.save()
        logits = vllm_gpt2.logits.output.save()

    next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
    assert next_token == " London"
    assert torch.all(hs == 0)


def test_swap_intervention(vllm_gpt2, ET_prompt: str):
    with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1, backend=AssertSavedLenBackend(2)) as tracer:
        vllm_gpt2.transformer.h[-2].mlp.output = torch.zeros_like(vllm_gpt2.transformer.h[-2].mlp.output)
        hs = vllm_gpt2.transformer.h[-2].mlp.output.save()
        logits = vllm_gpt2.logits.output.save()

    next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
    assert next_token == " London"
    assert torch.all(hs == 0)


def test_batched_intervention(vllm_gpt2, ET_prompt: str):
    with vllm_gpt2.trace(temperature=0.0, top_p=1, backend=AssertSavedLenBackend(4)) as tracer:

        with tracer.invoke(ET_prompt):
            clean_hs = vllm_gpt2.transformer.h[-2].mlp.output.save()
            clean_logits = vllm_gpt2.logits.output.save()

        with tracer.invoke(ET_prompt):
            vllm_gpt2.transformer.h[-2].mlp.output[:] = 0
            corrupted_hs = vllm_gpt2.transformer.h[-2].mlp.output.save()
            corrupted_logits = vllm_gpt2.logits.output.save()

    clean_token = vllm_gpt2.tokenizer.decode(clean_logits.argmax(dim=-1))
    corrupted_token = vllm_gpt2.tokenizer.decode(corrupted_logits.argmax(dim=-1))

    assert clean_token == " Paris"
    assert corrupted_token == " London"
    assert not torch.all(clean_hs == 0)
    assert torch.all(corrupted_hs == 0)


def test_batched_multi_token_generation(vllm_gpt2, ET_prompt: str, MSG_prompt: str):
    with vllm_gpt2.trace() as tracer:

        with tracer.invoke(ET_prompt, max_tokens=3):
            ET_logits = nnsight.list().save()
            for ii in range(3):
                ET_logits.append(vllm_gpt2.logits.output)
                vllm_gpt2.logits.next()

        with tracer.invoke(MSG_prompt, max_tokens=5):
            MSG_logits = nnsight.list().save()
            for ii in range(5):
                MSG_logits.append(vllm_gpt2.logits.output)
                vllm_gpt2.logits.next()

    assert len(ET_logits) == 3
    assert len(MSG_logits) == 5


""" def test_batched_multi_token_generation_with_iter(vllm_gpt2, ET_prompt: str, MSG_prompt: str):
with vllm_gpt2.trace(max_tokens=10) as tracer:
with tracer.invoke(ET_prompt):
ET_logits = nnsight.list().save()
with vllm_gpt2.logits.iter[:3]:
ET_logits.append(vllm_gpt2.logits.output)
#vllm_gpt2.output.save()
with tracer.invoke(MSG_prompt, max_tokens=5):
MSG_logits = nnsight.list().save()
with vllm_gpt2.logits.iter[:5]:
MSG_logits.append(vllm_gpt2.logits.output)
assert len(ET_logits.value) == 3
assert len(MSG_logits.value) == 5 """


def test_multi_token_generation_with_intervention(tp, vllm_gpt2, MSG_prompt: str):
    with vllm_gpt2.trace(MSG_prompt, temperature=0.0, top_p=1.0, max_tokens=5) as tracer:
        logits = nnsight.list().save()
        hs_list = nnsight.list().save()

        for ii in range(5):
            # Zero out the hidden states on the third generated token only.
            if ii == 2:
                vllm_gpt2.transformer.h[-2].output[0][:] = 0
            hs_list.append(vllm_gpt2.transformer.h[-2].output[0])
            vllm_gpt2.transformer.h[-2].next()
            logits.append(vllm_gpt2.logits.output)
            vllm_gpt2.logits.next()

    assert [torch.all(hs == 0) for hs in hs_list.value] == [False, False, True, False, False]

    if tp == 1:
        assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits.value]) == [' New', ' York', '\n', '\n', 'The']


""" def test_multi_referenced_module(vllm_gpt2, ET_prompt: str):
with vllm_gpt2.trace(ET_prompt):
act_in = vllm_gpt2.transformer.h[0].mlp.act.input.save()
vllm_gpt2.transformer.h[0].mlp.act.next()
act_in_other = vllm_gpt2.transformer.h[1].mlp.act.input.save()
assert not torch.equal(act_in, act_in_other) """


def test_tensor_parallelism(tp, vllm_gpt2, ET_prompt: str):
    if tp < 2:
        pytest.skip("Test requires tensor parallelism (--tp > 1).")

    with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1.0):
        vllm_gpt2.transformer.h[5].mlp.c_fc.output[0][:, 2000:] = 0
        hs = vllm_gpt2.transformer.h[5].mlp.c_fc.output[0].save()
        logit = vllm_gpt2.logits.output.save()

    next_token = vllm_gpt2.tokenizer.decode(logit.argmax(dim=-1))

    #assert next_token != " Paris"
    assert hs.shape == torch.Size([11, 3072])
    assert torch.all(hs[:, 2000:] == 0)
