NNsight <> vLLM #244

Closed
wants to merge 2 commits into from
Conversation

@AdamBelfki3 (Collaborator) commented Sep 23, 2024

This PR adds an NNsight wrapper for conducting interventions on the vLLM inference engine.

Example 1 - vLLM inference generation:

from nnsight.models.VLLM import VLLM
from vllm import SamplingParams

model = VLLM("gpt2")

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

# trace=False performs standard vLLM generation (no tracing/intervention context)
outputs = model.generate(prompts, sampling_params=sampling_params, trace=False)

for output in outputs.value:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

Example 2 - intervention on vLLM inference model:

from nnsight.models.VLLM import VLLM
from vllm import SamplingParams

model = VLLM("gpt2")

prompt = ["The Eiffel Tower is in the city of"]
sampling_params = SamplingParams(temperature=0.0, top_p=0.95, stop=["."])

with model.trace(prompt, sampling_params=sampling_params) as tracer:
    # Zero out the last element of transformer block 8's output
    model.model.transformer.h[8].output[-1][:] = 0

    outputs = model.output.save()

for output in outputs.value:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
>>> Prompt: 'The Eiffel Tower is in the city of', Generated text: ' and near the city of London'
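
For comparison, the same prompt can be run without any intervention to see the unmodified completion. A minimal sketch reusing the model, prompt, and sampling_params from the example above (the baseline output is not reproduced here):

# Baseline run: no interventions inside the trace context
with model.trace(prompt, sampling_params=sampling_params) as tracer:
    outputs = model.output.save()

for output in outputs.value:
    print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")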

Example 3 - initialization loads the model on the meta device:

Tensor parallelism is currently not supported with this model type.

from nnsight.models.VLLM import VLLM

vllm_model = VLLM("meta-llama/Meta-Llama-3.1-405B")
print(vllm_model)
>>> VLLModel(
  (model): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): VocabParallelEmbedding(num_embeddings=128256, embedding_dim=16384, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)
      (layers): ModuleList(
        (0-125): 126 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (qkv_proj): QKVParallelLinear(in_features=16384, output_features=18432, bias=False, tp_size=1, gather_output=False)
            (o_proj): RowParallelLinear(input_features=16384, output_features=16384, bias=False, tp_size=1, reduce_results=True)
            (rotary_emb): Llama3RotaryEmbedding(head_size=128, rotary_dim=128, max_position_embeddings=131072, base=500000.0, is_neox_style=True)
            (attn): Attention(head_size=128, num_heads=128, num_kv_heads=8, scale=0.08838834764831845, backend=XFormersImpl)
          )
          (mlp): LlamaMLP(
            (gate_up_proj): MergedColumnParallelLinear(in_features=16384, output_features=106496, bias=False, tp_size=1, gather_output=False)
            (down_proj): RowParallelLinear(input_features=53248, output_features=16384, bias=False, tp_size=1, reduce_results=True)
            (act_fn): SiluAndMul()
          )
          (input_layernorm): RMSNorm(hidden_size=16384, eps=1e-05)
          (post_attention_layernorm): RMSNorm(hidden_size=16384, eps=1e-05)
        )
      )
      (norm): RMSNorm(hidden_size=16384, eps=1e-05)
    )
    (lm_head): ParallelLMHead(num_embeddings=128256, embedding_dim=16384, org_vocab_size=128256, num_embeddings_padded=128256, tp_size=1)
    (logits_processor): LogitsProcessor(vocab_size=128256, forg_vocab_size=128256, scale=1.0, logits_as_input=False)
    (sampler): Sampler()
  )
)
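
Because initialization happens on the meta device, the module tree can be inspected and addressed before any weights are materialized. A minimal sketch, assuming submodules are reachable through the attribute path shown in the printout above:

# Inspect one decoder layer of the 405B model without loading its weights
print(vllm_model.model.model.layers[0].self_attn)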

Additional work is still needed to provide full support for vLLM with NNsight:

@AdamBelfki3 linked an issue Oct 9, 2024 that may be closed by this pull request
@JadenFiotto-Kaufman deleted the vllm-model branch December 22, 2024 18:43
Successfully merging this pull request may close these issues:

vLLM gets initialized on 'meta' device first