NNsight.VLLM tests #301

Merged
merged 3 commits on Dec 5, 2024
8 changes: 7 additions & 1 deletion conftest.py
@@ -3,7 +3,13 @@

def pytest_addoption(parser):
parser.addoption("--device", action="store", default="cuda:0")

parser.addoption(
    "--tp",
    action="store",
    type=int,
    default=1,
    help="Number of GPUs to be used by vLLM (tensor parallel size)",
)

def pytest_generate_tests(metafunc):
# This is called for every test. Only get/set command line arguments
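For reference, a minimal sketch of how a test module can consume the new --tp option (this mirrors the fixture added in tests/test_vllm.py further down; names are illustrative):

import pytest

@pytest.fixture(scope="module")
def tp(request):
    # Read the tensor-parallel degree registered in conftest.py (defaults to 1).
    return request.config.getoption("--tp")

With this in place the suite can be run with, e.g., pytest tests/test_vllm.py --device cuda:0 --tp 2.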
4 changes: 2 additions & 2 deletions src/nnsight/intervention/graph/graph.py
@@ -370,7 +370,7 @@ def copy(
new_graph.compiled = self.compiled

for key, value in self.call_counter.items():
self.call_counter[memo[key]] = value
new_graph.call_counter[memo[key]] = value

if new_graph.compiled:

@@ -382,7 +382,7 @@

for key, values in self.deferred.items():

new_graph[memo[key]] = [memo[index] for index in values]
new_graph.deferred[memo[key]] = [memo[index] for index in values]

new_graph.grad_subgraph = [memo[index] for index in self.grad_subgraph]

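The fix above writes the remapped bookkeeping into the copied graph instead of mutating the source graph. A standalone sketch of the intended remapping (hypothetical helper, not the nnsight API):

def remap_bookkeeping(call_counter: dict, deferred: dict, memo: dict):
    # Translate keys (and deferred index lists) from the source graph into the
    # copy via memo, leaving the source graph's dictionaries untouched.
    new_call_counter = {memo[key]: value for key, value in call_counter.items()}
    new_deferred = {memo[key]: [memo[i] for i in values]
                    for key, values in deferred.items()}
    return new_call_counter, new_deferred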
5 changes: 3 additions & 2 deletions src/nnsight/modeling/vllm/sampling.py
@@ -138,10 +138,11 @@ def prepare(
intervention_graph = None
elif n_graphs == 1:
intervention_graph = intervention_graphs[0]
else:

""" else:
intervention_graph = MultiGraph(intervention_graphs.values())

InterventionProtocol.shift(intervention_graph)
InterventionProtocol.shift(intervention_graph) """

###########################################

12 changes: 0 additions & 12 deletions src/nnsight/modeling/vllm/vllm.py
@@ -82,13 +82,8 @@ def _load_meta(self, repo_id: str, **kwargs) -> "Module":
)

# creating the vllm engine configuration
""" dict(
(field.name, getattr(self, field.name)) for field in fields(self)) """
vllm_config = engine_args.create_engine_config()
vllm_config_dict = {field.name: getattr(vllm_config, field.name) for field in fields(type(vllm_config))}

#dict((field.name, getattr(self, field.name)) for field in fields(engine_args.create_engine_config()))
#vllm_config_dict = engine_args.create_engine_config().to_dict()

# starting the distributed environment
init_distributed_environment(
@@ -103,13 +98,6 @@ def _load_meta(self, repo_id: str, **kwargs) -> "Module":
initialize_model_parallel(backend="gloo")

# initialize the model
""" model = _initialize_model(
model_config=vllm_config_dict["model_config"],
load_config=vllm_config_dict["load_config"],
lora_config=None,
cache_config=vllm_config_dict["cache_config"],
scheduler_config=vllm_config_dict["scheduler_config"],
) """

model = _initialize_model(vllm_config)

213 changes: 213 additions & 0 deletions tests/test_vllm.py
@@ -0,0 +1,213 @@
import pytest
import nnsight
import torch
from typing import TYPE_CHECKING

from nnsight.tracing.backends import Backend
from nnsight.tracing.protocols import StopProtocol

if TYPE_CHECKING:
from nnsight.tracing.graph import Graph

try:
    from nnsight.modeling.vllm import VLLM
except Exception:
    pytest.skip("Skipping VLLM tests: vLLM is not available.", allow_module_level=True)


class AssertSavedLenBackend(Backend):

def __init__(self, len:int) -> None:
self.len = len

def __call__(self, graph: "Graph") -> None:

try:

graph.nodes[-1].execute()

except StopProtocol.StopException:

pass

finally:

assert self.len == len([node for node in graph.nodes if node.done])

graph.nodes.clear()
graph.stack.clear()


@pytest.fixture(scope="module")
def tp(request):
tp = request.config.getoption("--tp")
if tp > torch.cuda.device_count() or tp < 1:
pytest.exit("--tp must be between 1 and the number of available GPUs.")
return tp

@pytest.fixture(scope="module")
def vllm_gpt2(tp: int):
return VLLM("gpt2", tensor_parallel_size=tp, dispatch=True)

@pytest.fixture
def ET_prompt():
return "The Eiffel Tower is located in the city of"

@pytest.fixture
def MSG_prompt():
return "Madison Square Garden is located in the city of"


def test_single_logit(vllm_gpt2, ET_prompt: str):
with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1, backend=AssertSavedLenBackend(1)):
logits = vllm_gpt2.logits.output.save()

next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
assert next_token == " Paris"


def test_multi_token_generation(vllm_gpt2, MSG_prompt: str):
with vllm_gpt2.trace(MSG_prompt, temperature=0.0, top_p=1.0, max_tokens=3):
logits = nnsight.list().save()
for ii in range(3):
logits.append(vllm_gpt2.logits.output)
vllm_gpt2.logits.next()

assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits.value]) == [" New", " York", " City"]


""" def test_max_token_generation(vllm_gpt2, ET_prompt: str):
with vllm_gpt2.trace(ET_prompt, max_tokens=10):
logits = nnsight.list().save()
with vllm_gpt2.logits.all():
logits.append(vllm_gpt2.logits.output)

assert len(logits) == 10 """


""" def test_sampling(vllm_gpt2, ET_prompt:str):
with vllm_gpt2.trace(ET_prompt, temperature=0.8, top_p=0.95, max_tokens=3):
samples = nnsight.list().save()
with vllm_gpt2.sample.all():
li.append(vllm_gpt2.sample.output)

samples = vllm_gpt2.batch_decode([sample.argmax(dim=-1) for sample in samples])
assert samples == [' Canary', ' Wh', 'arf'] """


def test_intervention(vllm_gpt2, ET_prompt: str):
with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1, backend=AssertSavedLenBackend(2)) as tracer:
vllm_gpt2.transformer.h[-2].mlp.output[:] = 0
hs = vllm_gpt2.transformer.h[-2].mlp.output.save()
logits = vllm_gpt2.logits.output.save()

next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
assert next_token == " London"
assert torch.all(hs == 0)


def test_swap_intervention(vllm_gpt2, ET_prompt: str):
with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1, backend=AssertSavedLenBackend(2)) as tracer:
vllm_gpt2.transformer.h[-2].mlp.output = torch.zeros_like(vllm_gpt2.transformer.h[-2].mlp.output)
hs = vllm_gpt2.transformer.h[-2].mlp.output.save()
logits = vllm_gpt2.logits.output.save()

next_token = vllm_gpt2.tokenizer.decode(logits.argmax(dim=-1))
assert next_token == " London"
assert torch.all(hs == 0)


def test_batched_intervention(vllm_gpt2, ET_prompt: str):
with vllm_gpt2.trace(temperature=0.0, top_p=1, backend=AssertSavedLenBackend(4)) as tracer:

with tracer.invoke(ET_prompt):
clean_hs = vllm_gpt2.transformer.h[-2].mlp.output.save()
clean_logits = vllm_gpt2.logits.output.save()
with tracer.invoke(ET_prompt):
vllm_gpt2.transformer.h[-2].mlp.output[:] = 0
corrupted_hs = vllm_gpt2.transformer.h[-2].mlp.output.save()
corrupted_logits = vllm_gpt2.logits.output.save()

clean_token = vllm_gpt2.tokenizer.decode(clean_logits.argmax(dim=-1))
corrupted_token = vllm_gpt2.tokenizer.decode(corrupted_logits.argmax(dim=-1))

assert clean_token == " Paris"
assert corrupted_token == " London"
assert not torch.all(clean_hs == 0)
assert torch.all(corrupted_hs == 0)


def test_batched_multi_token_generation(vllm_gpt2, ET_prompt: str, MSG_prompt: str):
with vllm_gpt2.trace() as tracer:
with tracer.invoke(ET_prompt, max_tokens=3):
ET_logits = nnsight.list().save()
for ii in range(3):
ET_logits.append(vllm_gpt2.logits.output)
vllm_gpt2.logits.next()
with tracer.invoke(MSG_prompt, max_tokens=5):
MSG_logits = nnsight.list().save()
for ii in range(5):
MSG_logits.append(vllm_gpt2.logits.output)
vllm_gpt2.logits.next()

assert len(ET_logits) == 3
assert len(MSG_logits) == 5


""" def test_batched_multi_token_generation_with_iter(vllm_gpt2, ET_prompt: str, MSG_prompt: str):
with vllm_gpt2.trace(max_tokens=10) as tracer:
with tracer.invoke(ET_prompt):
ET_logits = nnsight.list().save()
with vllm_gpt2.logits.iter[:3]:
ET_logits.append(vllm_gpt2.logits.output)
#vllm_gpt2.output.save()
with tracer.invoke(MSG_prompt, max_tokens=5):
MSG_logits = nnsight.list().save()
with vllm_gpt2.logits.iter[:5]:
MSG_logits.append(vllm_gpt2.logits.output)

assert len(ET_logits.value) == 3
assert len(MSG_logits.value) == 5 """


def test_multi_token_generation_with_intervention(tp, vllm_gpt2, MSG_prompt: str):
with vllm_gpt2.trace(MSG_prompt, temperature=0.0, top_p=1.0, max_tokens=5) as tracer:
logits = nnsight.list().save()
hs_list = nnsight.list().save()
for ii in range(5):
if ii == 2:
vllm_gpt2.transformer.h[-2].output[0][:] = 0
hs_list.append(vllm_gpt2.transformer.h[-2].output[0])
vllm_gpt2.transformer.h[-2].next()
logits.append(vllm_gpt2.logits.output)
vllm_gpt2.logits.next()

assert [torch.all(hs == 0) for hs in hs_list.value] == [False, False, True, False, False]

if tp == 1:
assert vllm_gpt2.tokenizer.batch_decode([logit.argmax(dim=-1) for logit in logits.value]) == [' New', ' York', '\n', '\n', 'The']


""" def test_multi_referenced_module(vllm_gpt2, ET_prompt: str):
with vllm_gpt2.trace(ET_prompt):
act_in = vllm_gpt2.transformer.h[0].mlp.act.input.save()
vllm_gpt2.transformer.h[0].mlp.act.next()
act_in_other = vllm_gpt2.transformer.h[1].mlp.act.input.save()

assert not torch.equal(act_in, act_in_other) """


def test_tensor_parallelism(tp, vllm_gpt2, ET_prompt: str):
if tp < 2:
pytest.skip("Tensor parallelism test requires --tp >= 2.")

with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1.0):
vllm_gpt2.transformer.h[5].mlp.c_fc.output[0][:, 2000:] = 0
hs = vllm_gpt2.transformer.h[5].mlp.c_fc.output[0].save()
logit = vllm_gpt2.logits.output.save()

next_token = vllm_gpt2.tokenizer.decode(logit.argmax(dim=-1))

#assert next_token != " Paris"
assert hs.shape == torch.Size([11, 3072])
assert torch.all(hs[:, 2000:] == 0)