
NNsight.VLLM tests #301

Merged
JadenFiotto-Kaufman merged 3 commits into 0.4 from vllm-testing on Dec 5, 2024

Conversation

AdamBelfki3
Collaborator

This PR lays the groundwork for testing the integration and functionality of the NNsight.VLLM model.

To run the tests in question:

pytest tests/test_vllm.py --tp <tp>

Use the --tp argument to set the tensor_parallel_size engine argument on the VLLM model (defaults to 1). If --tp > 1, all of the tests will be run within the model's distributed environment, and additional tests specific to that functionality will also run.
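
For reference, here is a minimal sketch of how such a --tp option and tp fixture are typically wired up in pytest's conftest.py; the fixture names and scope here are assumptions, and the actual vllm_gpt2 fixture in this repo would additionally construct the NNsight.VLLM model with tensor_parallel_size=tp:

# conftest.py -- illustrative sketch only; names and scopes may differ from the repo.
import pytest

def pytest_addoption(parser):
    # Register the --tp flag consumed by the vLLM tests.
    parser.addoption("--tp", action="store", type=int, default=1,
                     help="tensor_parallel_size for the vLLM engine")

@pytest.fixture(scope="session")
def tp(request):
    # Value of --tp, used by the model fixture and the tp-specific tests.
    return request.config.getoption("--tp")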


Future Work:

Some of the tests included in this PR require further improvements before they exhibit the expected behavior. For the time being, these tests are commented out; the goal is to continue working toward a fully functional vLLM integration with minimal change to the user's NNsight experience. Here is a list of future improvements related to these tests (a sketch of how such pending tests could be tracked with pytest markers follows the list):

  • Handle max token generation with the VLLM extra forward pass.
def test_max_token_generation(vllm_gpt2, ET_prompt: str):
    with vllm_gpt2.trace(ET_prompt, max_tokens=10):
        logits = nnsight.list().save()
        with vllm_gpt2.logits.all():
            logits.append(vllm_gpt2.logits.output)

    assert len(logits) == 10
  • Expose the final token sampled by the VLLM model. Something like this:
def test_sampling(vllm_gpt2, ET_prompt:str):
    with vllm_gpt2.trace(ET_prompt, temperature=0.8, top_p=0.95, max_tokens=3):
        samples = nnsight.list().save()
        with vllm_gpt2.sample.all():
            samples.append(vllm_gpt2.sample.output)

    samples = vllm_gpt2.batch_decode([sample.argmax(dim=-1) for sample in samples])
    assert samples == [' Canary', ' Wh', 'arf']
  • Handle single modules with multiple references in the inner model's architecture.
def test_multi_referenced_module(vllm_gpt2, ET_prompt: str):
    with vllm_gpt2.trace(ET_prompt):
        act_in = vllm_gpt2.transformer.h[0].mlp.act.input.save()
        vllm_gpt2.transformer.h[0].mlp.act.next()
        act_in_other = vllm_gpt2.transformer.h[1].mlp.act.input.save()

    assert not torch.equal(act_in, act_in_other)
  • Investigate the scale of potential discrepancies across tensor-parallel (tp) environments and their impact on token generation with interventions:
def test_tensor_parallelism(tp, vllm_gpt2, ET_prompt: str):
    if tp < 2:
        pytest.skip("This test requires tp > 1!")

    with vllm_gpt2.trace(ET_prompt, temperature=0.0, top_p=1.0):
        vllm_gpt2.transformer.h[5].mlp.c_fc.output[0][:, 2000:] = 0
        hs = vllm_gpt2.transformer.h[5].mlp.c_fc.output[0].save()
        logit = vllm_gpt2.logits.output.save()

    next_token = vllm_gpt2.tokenizer.decode(logit.argmax(dim=-1))

    assert next_token != " Paris" # this passes for `tp == 1`
    assert hs.shape == torch.Size([11, 3072])
    assert torch.all(hs[:, 2000:] == 0)
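
As a hypothetical illustration, the pending tests above could also be kept visible in the suite while disabled by using pytest's xfail marker instead of commenting them out; the reason string below is illustrative only:

import pytest

# Illustrative sketch: mark a pending vLLM test as expected-to-fail so it
# still appears in test reports while the underlying support is unfinished.
@pytest.mark.xfail(reason="max token generation with the vLLM extra forward pass not handled yet")
def test_max_token_generation(vllm_gpt2, ET_prompt: str):
    ...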

@JadenFiotto-Kaufman
Member

@AdamBelfki3 Gorgeous 🔥

Base automatically changed from vllm-tp-2 to 0.4 December 5, 2024 17:37
@JadenFiotto-Kaufman merged commit 043e0ec into 0.4 on Dec 5, 2024
1 check failed
@JadenFiotto-Kaufman deleted the vllm-testing branch December 5, 2024 17:37