From bbc317850a63057112ee325b5e609d50eed596cf Mon Sep 17 00:00:00 2001 From: Adam Belfki Date: Sun, 13 Oct 2024 13:12:34 -0400 Subject: [PATCH 1/5] feat (api): Add support for intervening on parallelized modules --- src/nnsight/intervention.py | 6 +- src/nnsight/models/NNsightModel.py | 11 +-- src/nnsight/models/vllm/__init__.py | 98 ++++++++++++++++--- .../vllm/model_runners/GPUModelRunner.py | 2 +- src/nnsight/patching.py | 4 +- 5 files changed, 98 insertions(+), 23 deletions(-) diff --git a/src/nnsight/intervention.py b/src/nnsight/intervention.py index 0e63cb81..8d4d3c3b 100755 --- a/src/nnsight/intervention.py +++ b/src/nnsight/intervention.py @@ -438,6 +438,7 @@ def intervene( cls, activations: Any, module_path: str, + module: torch.nn.Module, key: str, intervention_handler: InterventionHandler, ): @@ -458,6 +459,7 @@ def intervene( Args: activations (Any): Either the inputs or outputs of a torch module. module_path (str): Module path of the current relevant module relative to the root model. + module (torch.nn.Module): Module to be intervened on. key (str): Key denoting either "input" or "output" of module. intervention_handler (InterventionHandler): Handler object that stores the intervention graph and keeps track of module call count. @@ -621,7 +623,7 @@ def __enter__(self) -> HookHandler: if hook_type == "input": def input_hook(module, input, kwargs, module_path=module_path): - return self.input_hook((input, kwargs), module_path) + return self.input_hook((input, kwargs), module_path, module) self.handles.append( module.register_forward_pre_hook( @@ -632,7 +634,7 @@ def input_hook(module, input, kwargs, module_path=module_path): elif hook_type == "output": def output_hook(module, input, output, module_path=module_path): - return self.output_hook(output, module_path) + return self.output_hook(output, module_path, module) self.handles.append( module.register_forward_hook(output_hook, prepend=True) diff --git a/src/nnsight/models/NNsightModel.py b/src/nnsight/models/NNsightModel.py index 0c1f8b9a..b66a3410 100755 --- a/src/nnsight/models/NNsightModel.py +++ b/src/nnsight/models/NNsightModel.py @@ -5,11 +5,9 @@ Any, Callable, Dict, - List, Optional, Tuple, Type, - TypeVar, Union, ) @@ -35,6 +33,7 @@ InterventionProxy, ) from ..tracing.Graph import Graph +from ..tracing import protocols class NNsight: @@ -411,11 +410,11 @@ def interleave( with HookHandler( self._model, list(module_paths), - input_hook=lambda activations, module_path: InterventionProtocol.intervene( - activations, module_path, "input", intervention_handler + input_hook=lambda activations, module_path, module: InterventionProtocol.intervene( + activations, module_path, module, "input", intervention_handler ), - output_hook=lambda activations, module_path: InterventionProtocol.intervene( - activations, module_path, "output", intervention_handler + output_hook=lambda activations, module_path, module: InterventionProtocol.intervene( + activations, module_path, module, "output", intervention_handler ), ): try: diff --git a/src/nnsight/models/vllm/__init__.py b/src/nnsight/models/vllm/__init__.py index 86fe9f37..6964111e 100755 --- a/src/nnsight/models/vllm/__init__.py +++ b/src/nnsight/models/vllm/__init__.py @@ -1,10 +1,20 @@ -import weakref -from typing import Any, Callable, Dict, List, Tuple, Union +from functools import wraps +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union +import torch + +from vllm.distributed import (get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + 
split_tensor_along_last_dim, + tensor_model_parallel_all_gather) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from ...envoy import Envoy +from ...intervention import InterventionProtocol +from ...patching import Patch, Patcher from ...tracing import protocols from ...tracing.Graph import Graph from ...util import TypeHint, WrapperModule, hint @@ -13,13 +23,14 @@ from .executors.RayGPUExecutor import NNsightRayGPUExecutor from .sampling import NNsightSamplingParams +if TYPE_CHECKING: + from ...intervention import InterventionHandler + try: - from vllm.distributed import ( - destroy_distributed_environment, - destroy_model_parallel, - init_distributed_environment, - initialize_model_parallel, - ) + from vllm.distributed import (destroy_distributed_environment, + destroy_model_parallel, + init_distributed_environment, + initialize_model_parallel) from vllm.engine.arg_utils import EngineArgs from vllm.entrypoints.llm import LLM from vllm.model_executor.model_loader.loader import _initialize_model @@ -110,13 +121,13 @@ def _load_meta(self, repo_id: str, **kwargs): enable_lora=bool(engine_config_dict["lora_config"]), ).tokenizer - destroy_model_parallel() - destroy_distributed_environment() - return model def _load(self, repo_id: str, **kwargs): + destroy_model_parallel() + destroy_distributed_environment() + distributed_executor_backend = NNsightGPUExecutor if ( "tensor_parallel_size" in kwargs.keys() @@ -204,7 +215,70 @@ def interleave( param.intervention_graph = intervention_graph - fn(prompts, params, **kwargs) + def parallel_intervene(intervene_func: Callable) -> Callable: + """ Create an intervene wrapper that handles tensor parallelism execution of vLLM models. + + Args: + intervene_func (Callable): intervention function. + + Returns + """ + + @wraps(intervene_func) + def parallel_intervene_wrapper( + activations: Any, + module_path: str, + module: torch.nn.Module, + key: str, + intervention_handler: "InterventionHandler" + ) -> Any: + """ InterventionProtocol.intervene wrapper handling the parallelized modules of vLLM. + If some activations were parallelized, then they need to be gathered as a full tensor to intervene on them, + and then split again before returning them. + + Args: + activations (Any): Either the inputs or outputs of a torch module. + module_path (str): Module path of the current relevant module relative to the root model. + module (torch.nn.Module): Module to be intervened on. + key (str): Key denoting either "input" or "output" of module. + intervention_handler (InterventionHandler): Handler object that stores the intervention graph and keeps track of module call count. + + Returns: + Any: The activations, potentially modified by the intervention graph. 
+ """ + # If the activations are parallelized, they must be gathered before intervening on them + if isinstance(module, ColumnParallelLinear) and key == "output" and not module.gather_output: + full_tensor = tensor_model_parallel_all_gather(activations[0]) + activations = (full_tensor, ) + activations[1:] + if isinstance(module, RowParallelLinear) and key == "input" and module.input_is_parallel: + full_tensor = tensor_model_parallel_all_gather(activations[0][0]) + activations = ((full_tensor,) + activations[0][1:], ) + activations[1:] + + activations = intervene_func(activations, module_path, module, key, intervention_handler) + + # If the activations were parallelized originally, they must be split again before returning them + if isinstance(module, ColumnParallelLinear) and key == "output" and not module.gather_output: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim(activations[0], num_partitions=get_tensor_model_parallel_world_size()) + activations = (splitted_input[tp_rank].contiguous(),) + activations[1:] + if isinstance(module, RowParallelLinear) and key == "input" and module.input_is_parallel: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim(activations[0][0], num_partitions=get_tensor_model_parallel_world_size()) + activations = ((splitted_input[tp_rank].contiguous(),) + activations[0][1:],) + activations[1:] + + return activations + + return parallel_intervene_wrapper + + # handle parallelmodules with custom intervene function for inference with tensor parallelism + if get_tensor_model_parallel_world_size() > 1: + intervene_patch = Patch(InterventionProtocol, parallel_intervene(InterventionProtocol.intervene), "intervene") + else: + intervene_patch = Patch(InterventionProtocol, InterventionProtocol.intervene, "intervene") + + with Patcher([intervene_patch]): + + fn(prompts, params, **kwargs) intervention_graph.alive = False diff --git a/src/nnsight/models/vllm/model_runners/GPUModelRunner.py b/src/nnsight/models/vllm/model_runners/GPUModelRunner.py index 6f2d876c..33ca6c99 100755 --- a/src/nnsight/models/vllm/model_runners/GPUModelRunner.py +++ b/src/nnsight/models/vllm/model_runners/GPUModelRunner.py @@ -1,10 +1,10 @@ import dataclasses from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union -from nnsight.models.NNsightModel import NNsight import torch import torch.distributed +from nnsight.models.NNsightModel import NNsight from vllm.distributed import get_pp_group from vllm.forward_context import set_forward_context from vllm.model_executor.layers.sampler import SamplerOutput diff --git a/src/nnsight/patching.py b/src/nnsight/patching.py index 263a187b..cee141b5 100644 --- a/src/nnsight/patching.py +++ b/src/nnsight/patching.py @@ -13,9 +13,9 @@ class Patch: """Class representing a replacement of an attribute on a module. Attributes: - obj (Any): Object to replace. - replacement (Any): Object that replaces. parent (Any): Module or class to replace attribute. + replacement (Any): Object that replaces. + key (Any): Object to replace. 
""" def __init__(self, parent: Any, replacement: Any, key: str) -> None: From 4c410a6f2d66707d576f42c4486718be671d7a9e Mon Sep 17 00:00:00 2001 From: Adam Belfki Date: Sun, 27 Oct 2024 17:01:57 -0400 Subject: [PATCH 2/5] fix support for tensor parallelism on vllm models + some improvements --- src/nnsight/intervention.py | 2 +- src/nnsight/models/vllm/__init__.py | 293 +----------------- .../vllm/model_runners/GPUModelRunner.py | 90 +++++- src/nnsight/models/vllm/vllm.py | 243 +++++++++++++++ 4 files changed, 325 insertions(+), 303 deletions(-) create mode 100644 src/nnsight/models/vllm/vllm.py diff --git a/src/nnsight/intervention.py b/src/nnsight/intervention.py index 8d4d3c3b..8b2466a0 100755 --- a/src/nnsight/intervention.py +++ b/src/nnsight/intervention.py @@ -12,7 +12,7 @@ import inspect from collections import defaultdict from contextlib import AbstractContextManager -from typing import Any, Callable, Collection, Dict, List, Optional, Tuple, Union +from typing import Any, Callable, Collection, Dict, List, Tuple, Union import torch from torch.utils.hooks import RemovableHandle diff --git a/src/nnsight/models/vllm/__init__.py b/src/nnsight/models/vllm/__init__.py index 6964111e..1a180b71 100755 --- a/src/nnsight/models/vllm/__init__.py +++ b/src/nnsight/models/vllm/__init__.py @@ -1,292 +1 @@ -from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union - -import torch - -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - split_tensor_along_last_dim, - tensor_model_parallel_all_gather) -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - RowParallelLinear) -from vllm.transformers_utils.tokenizer import AnyTokenizer -from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs - -from ...envoy import Envoy -from ...intervention import InterventionProtocol -from ...patching import Patch, Patcher -from ...tracing import protocols -from ...tracing.Graph import Graph -from ...util import TypeHint, WrapperModule, hint -from ..mixins import RemoteableMixin -from .executors.GPUExecutor import NNsightGPUExecutor -from .executors.RayGPUExecutor import NNsightRayGPUExecutor -from .sampling import NNsightSamplingParams - -if TYPE_CHECKING: - from ...intervention import InterventionHandler - -try: - from vllm.distributed import (destroy_distributed_environment, - destroy_model_parallel, - init_distributed_environment, - initialize_model_parallel) - from vllm.engine.arg_utils import EngineArgs - from vllm.entrypoints.llm import LLM - from vllm.model_executor.model_loader.loader import _initialize_model -except Exception as e: - - raise type(e)( - "Install vllm in your environment to use it with NNsight. " - + "https://docs.vllm.ai/en/latest/getting_started/installation.html" - ) from e - - -@hint -class VLLM(RemoteableMixin, TypeHint[Union[LLM, Envoy]]): - """NNsight wrapper to conduct interventions on a vLLM inference engine. - - .. 
code-block:: python - from nnsight.models.VLLM import VLLM - from vllm import SamplingParams - - model = VLLM("gpt2") - - prompt = ["The Eiffel Tower is in the city of"] - sampling_params = SamplingParams(temperature=0.0, top_p=0.95, stop=["."]) - - with model.trace(prompt, sampling_params=sampling_params) as tracer: - model.model.transformer.h[8].output[-1][:] = 0 - - outputs = model.output.save() - - for output in outputs.value: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - """ - - __methods__ = {"generate": "_execute"} - - def __init__(self, *args, **kwargs) -> None: - - self.vllm_entrypoint: LLM = None - self.tokenizer: AnyTokenizer = None - - super().__init__(*args, **kwargs) - - self.logits = WrapperModule() - self.tokens = WrapperModule() - - def _load_meta(self, repo_id: str, **kwargs): - - # no parallelism during initialization - kwargs["tensor_parallel_size"] = 1 - kwargs["pipeline_parallel_size"] = 1 - - # creating vLLM Engine args - engine_args = EngineArgs( - model=repo_id, - **kwargs, - ) - - # creating the vllm engine configuration - engine_config_dict = engine_args.create_engine_config().to_dict() - - # starting the distributed environment - init_distributed_environment( - 1, - 0, - "tcp://127.0.0.1:47303", - 0, - backend="gloo", - ) - - # start tensor parallel group - initialize_model_parallel(backend="gloo") - - # initialize the model - model = _initialize_model( - model_config=engine_config_dict["model_config"], - load_config=engine_config_dict["load_config"], - lora_config=None, - cache_config=engine_config_dict["cache_config"], - scheduler_config=engine_config_dict["scheduler_config"], - ) - - self.tokenizer = init_tokenizer_from_configs( - model_config=engine_config_dict["model_config"], - scheduler_config=engine_config_dict["scheduler_config"], - parallel_config=engine_config_dict["parallel_config"], - enable_lora=bool(engine_config_dict["lora_config"]), - ).tokenizer - - return model - - def _load(self, repo_id: str, **kwargs): - - destroy_model_parallel() - destroy_distributed_environment() - - distributed_executor_backend = NNsightGPUExecutor - if ( - "tensor_parallel_size" in kwargs.keys() - and kwargs["tensor_parallel_size"] > 1 - ): - distributed_executor_backend = NNsightRayGPUExecutor - - llm = LLM( - repo_id, - **kwargs, - distributed_executor_backend=distributed_executor_backend, - ) - - self.vllm_entrypoint = llm - - return self._model - - def _prepare_input( - self, *args, **kwargs - ) -> Tuple[Tuple[Tuple[Any], Dict[str, Any]], int]: - - if "processed" in kwargs: - return (args, kwargs), len(args[0]) - - prompts = [] - params = [] - - for arg in args: - - if not type(arg) is list: - arg = [arg] - - for prompt in arg: - - param = NNsightSamplingParams( - **kwargs, - ) - - prompts.append(prompt) - params.append(param) - - return ((prompts, params), {"processed": True}), len(prompts) - - def _batch( - self, - batched_inputs: Tuple[Tuple[Any] | protocols.Dict[str, Any]] | None, - prompts: List[str], - params: List[NNsightSamplingParams], - **kwargs, - ) -> Tuple[Tuple[Any] | protocols.Dict[str, Any]]: - - if batched_inputs is None: - batched_inputs = ([], []), {"invoker_group": 0} - - (bprompts, bparams), kwargs = batched_inputs - - invoker_group = kwargs["invoker_group"] - - for prompt in prompts: - bprompts.append(prompt) - - for param in params: - - param.invoker_group = invoker_group - - bparams.append(param) - - kwargs["invoker_group"] += 1 - - return (bprompts, 
bparams), kwargs - - def interleave( - self, - fn: Callable, - intervention_graph: Graph, - prompts: List[str], - params: List[NNsightSamplingParams], - **kwargs, - ) -> Any: - - if not self.dispatched: - self.dispatch() - - for param in params: - - param.intervention_graph = intervention_graph - - def parallel_intervene(intervene_func: Callable) -> Callable: - """ Create an intervene wrapper that handles tensor parallelism execution of vLLM models. - - Args: - intervene_func (Callable): intervention function. - - Returns - """ - - @wraps(intervene_func) - def parallel_intervene_wrapper( - activations: Any, - module_path: str, - module: torch.nn.Module, - key: str, - intervention_handler: "InterventionHandler" - ) -> Any: - """ InterventionProtocol.intervene wrapper handling the parallelized modules of vLLM. - If some activations were parallelized, then they need to be gathered as a full tensor to intervene on them, - and then split again before returning them. - - Args: - activations (Any): Either the inputs or outputs of a torch module. - module_path (str): Module path of the current relevant module relative to the root model. - module (torch.nn.Module): Module to be intervened on. - key (str): Key denoting either "input" or "output" of module. - intervention_handler (InterventionHandler): Handler object that stores the intervention graph and keeps track of module call count. - - Returns: - Any: The activations, potentially modified by the intervention graph. - """ - # If the activations are parallelized, they must be gathered before intervening on them - if isinstance(module, ColumnParallelLinear) and key == "output" and not module.gather_output: - full_tensor = tensor_model_parallel_all_gather(activations[0]) - activations = (full_tensor, ) + activations[1:] - if isinstance(module, RowParallelLinear) and key == "input" and module.input_is_parallel: - full_tensor = tensor_model_parallel_all_gather(activations[0][0]) - activations = ((full_tensor,) + activations[0][1:], ) + activations[1:] - - activations = intervene_func(activations, module_path, module, key, intervention_handler) - - # If the activations were parallelized originally, they must be split again before returning them - if isinstance(module, ColumnParallelLinear) and key == "output" and not module.gather_output: - tp_rank = get_tensor_model_parallel_rank() - splitted_input = split_tensor_along_last_dim(activations[0], num_partitions=get_tensor_model_parallel_world_size()) - activations = (splitted_input[tp_rank].contiguous(),) + activations[1:] - if isinstance(module, RowParallelLinear) and key == "input" and module.input_is_parallel: - tp_rank = get_tensor_model_parallel_rank() - splitted_input = split_tensor_along_last_dim(activations[0][0], num_partitions=get_tensor_model_parallel_world_size()) - activations = ((splitted_input[tp_rank].contiguous(),) + activations[0][1:],) + activations[1:] - - return activations - - return parallel_intervene_wrapper - - # handle parallelmodules with custom intervene function for inference with tensor parallelism - if get_tensor_model_parallel_world_size() > 1: - intervene_patch = Patch(InterventionProtocol, parallel_intervene(InterventionProtocol.intervene), "intervene") - else: - intervene_patch = Patch(InterventionProtocol, InterventionProtocol.intervene, "intervene") - - with Patcher([intervene_patch]): - - fn(prompts, params, **kwargs) - - intervention_graph.alive = False - - def _execute( - self, - prompts: List[str], - params: List[NNsightSamplingParams], - **kwargs, - ) -> Any: - 
- self.vllm_entrypoint.generate(prompts, sampling_params=params) +from .vllm import VLLM \ No newline at end of file diff --git a/src/nnsight/models/vllm/model_runners/GPUModelRunner.py b/src/nnsight/models/vllm/model_runners/GPUModelRunner.py index 33ca6c99..ee59bba9 100755 --- a/src/nnsight/models/vllm/model_runners/GPUModelRunner.py +++ b/src/nnsight/models/vllm/model_runners/GPUModelRunner.py @@ -1,13 +1,20 @@ import dataclasses -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union +from functools import wraps +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union, Callable import torch import torch.distributed from nnsight.models.NNsightModel import NNsight -from vllm.distributed import get_pp_group +from vllm.distributed import (get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) from vllm.forward_context import set_forward_context from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) from vllm.multimodal import MultiModalInputs from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, @@ -16,7 +23,8 @@ _add_attn_metadata_broadcastable_dict, _init_attn_metadata_from_tensor_dict, dump_input_when_exception) -from ....intervention import InterventionHandler +from ....intervention import InterventionHandler, InterventionProtocol +from ....patching import Patch, Patcher from .. import VLLM from ..sampling import NNsightSamplingMetadata @@ -209,7 +217,7 @@ def execute_model( "finished_requests_ids": model_input.finished_requests_ids, "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, } - if self.has_seqlen_agnostic + if self.has_inner_state else {} ) if ( @@ -331,11 +339,73 @@ def inner(): output.hidden_states = hidden_states return output + + def parallel_intervene(intervene_func: Callable) -> Callable: + """ Create an intervene wrapper that handles tensor parallelism execution of vLLM models. + + Args: + intervene_func (Callable): intervention function. + + Returns + """ + + @wraps(intervene_func) + def parallel_intervene_wrapper( + activations: Any, + module_path: str, + module: torch.nn.Module, + key: str, + intervention_handler: "InterventionHandler" + ) -> Any: + """ InterventionProtocol.intervene wrapper handling the parallelized modules of vLLM. + If some activations were parallelized, then they need to be gathered as a full tensor to intervene on them, + and then split again before returning them. + + Args: + activations (Any): Either the inputs or outputs of a torch module. + module_path (str): Module path of the current relevant module relative to the root model. + module (torch.nn.Module): Module to be intervened on. + key (str): Key denoting either "input" or "output" of module. + intervention_handler (InterventionHandler): Handler object that stores the intervention graph and keeps track of module call count. + + Returns: + Any: The activations, potentially modified by the intervention graph. 
+ """ + # If the activations are parallelized, they must be gathered before intervening on them + if isinstance(module, ColumnParallelLinear) and key == "output" and not module.gather_output: + full_tensor = tensor_model_parallel_all_gather(activations[0]) + activations = (full_tensor, ) + activations[1:] + if isinstance(module, RowParallelLinear) and key == "input" and module.input_is_parallel: + full_tensor = tensor_model_parallel_all_gather(activations[0][0]) + activations = ((full_tensor,) + activations[0][1:], ) + activations[1:] + + activations = intervene_func(activations, module_path, module, key, intervention_handler) + + # If the activations were parallelized originally, they must be split again before returning them + if isinstance(module, ColumnParallelLinear) and key == "output" and not module.gather_output: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim(activations[0], num_partitions=get_tensor_model_parallel_world_size()) + activations = (splitted_input[tp_rank].contiguous(),) + activations[1:] + if isinstance(module, RowParallelLinear) and key == "input" and module.input_is_parallel: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim(activations[0][0], num_partitions=get_tensor_model_parallel_world_size()) + activations = ((splitted_input[tp_rank].contiguous(),) + activations[0][1:],) + activations[1:] + + return activations + + return parallel_intervene_wrapper + + if get_tensor_model_parallel_world_size() > 1: + intervene_patch = Patch(InterventionProtocol, parallel_intervene(InterventionProtocol.intervene), "intervene") + else: + intervene_patch = Patch(InterventionProtocol, InterventionProtocol.intervene, "intervene") + + with Patcher([intervene_patch]): + output = NNsight.interleave( + self.model, + inner, + intervention_graph, + intervention_handler=intervention_handler, + ) - output = NNsight.interleave( - self.model, - inner, - intervention_graph, - intervention_handler=intervention_handler, - ) return [output] diff --git a/src/nnsight/models/vllm/vllm.py b/src/nnsight/models/vllm/vllm.py new file mode 100644 index 00000000..66a559bf --- /dev/null +++ b/src/nnsight/models/vllm/vllm.py @@ -0,0 +1,243 @@ +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Tuple, Union + +from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs + +from ...envoy import Envoy +from ...tracing import protocols +from ...tracing.Graph import Graph +from ...util import TypeHint, WrapperModule, hint +from ..mixins import RemoteableMixin +from .executors.GPUExecutor import NNsightGPUExecutor +from .executors.RayGPUExecutor import NNsightRayGPUExecutor +from .sampling import NNsightSamplingParams + +if TYPE_CHECKING: + from torch.nn import Module + from vllm.transformers_utils.tokenizer import AnyTokenizer + from vllm.config import ModelConfig, SchedulerConfig, ParallelConfig + +try: + from vllm.distributed import (destroy_distributed_environment, + destroy_model_parallel, + init_distributed_environment, + initialize_model_parallel) + from vllm.engine.arg_utils import EngineArgs + from vllm.entrypoints.llm import LLM + from vllm.model_executor.model_loader.loader import _initialize_model +except Exception as e: + raise type(e)( + "Install vllm in your environment to use it with NNsight. 
" + + "https://docs.vllm.ai/en/latest/getting_started/installation.html" + ) from e + + +@hint +class VLLM(RemoteableMixin, TypeHint[Union[LLM, Envoy]]): + """NNsight wrapper to conduct interventions on a vLLM inference engine.\ + + Attributes: + - vllm_entrypoint (vllm.LLM): vLLM language model. + - tokenizer (vllm.transformers_utils.tokenizer.AnyTokenizer): tokenizer. + - logits (nnsight.WrapperModule): logits. + - tokens (nnsight.WrapperModule): tokens. + + .. code-block:: python + from nnsight.models.VLLM import VLLM + from vllm import SamplingParams + + model = VLLM("gpt2") + + prompt = ["The Eiffel Tower is in the city of"] + + with model.trace(prompt, temperature=0.0, top_p=0.95, stop=['.']) as tracer: + model.transformer.h[8].output[-1][:] = 0 + + output = model.output.save() + + print(model.tokenizer.decode(output.value.argmax(dim=-1)[-1])) + """ + + __methods__ = {"generate": "_execute"} + + def __init__(self, *args, **kwargs) -> None: + + self.vllm_entrypoint: LLM = None + self.tokenizer: "AnyTokenizer" = None + + super().__init__(*args, **kwargs) + + self.logits: WrapperModule = WrapperModule() + self.tokens: WrapperModule = WrapperModule() + + def _load_meta(self, repo_id: str, **kwargs) -> "Module": + + # no parallelism during initialization + kwargs["tensor_parallel_size"] = 1 + kwargs["pipeline_parallel_size"] = 1 + + # creating vLLM Engine args + engine_args = EngineArgs( + model=repo_id, + **kwargs, + ) + + # creating the vllm engine configuration + engine_config_dict = engine_args.create_engine_config().to_dict() + + # starting the distributed environment + init_distributed_environment( + 1, + 0, + "tcp://127.0.0.1:47303", + 0, + backend="gloo", + ) + + # start tensor parallel group + initialize_model_parallel(backend="gloo") + + # initialize the model + model = _initialize_model( + model_config=engine_config_dict["model_config"], + load_config=engine_config_dict["load_config"], + lora_config=None, + cache_config=engine_config_dict["cache_config"], + scheduler_config=engine_config_dict["scheduler_config"], + ) + + # load the tokenzier + self.tokenizer = self._load_tokenizer( + model_config=engine_config_dict["model_config"], + scheduler_config=engine_config_dict["scheduler_config"], + parallel_config=engine_config_dict["parallel_config"], + enable_lora=bool(engine_config_dict["lora_config"]), + ) + + return model + + def _load_tokenizer( + self, + model_config: "ModelConfig", + scheduler_config: "SchedulerConfig", + parallel_config: "ParallelConfig", + enable_lora: bool) -> "AnyTokenizer": + + return init_tokenizer_from_configs( + model_config=model_config, + scheduler_config=scheduler_config, + parallel_config=parallel_config, + enable_lora=enable_lora, + ).tokenizer + + def _load(self, repo_id: str, **kwargs) -> "Module": + + destroy_model_parallel() + destroy_distributed_environment() + + distributed_executor_backend = NNsightGPUExecutor + if ( + "tensor_parallel_size" in kwargs.keys() + and kwargs["tensor_parallel_size"] > 1 + ): + distributed_executor_backend = NNsightRayGPUExecutor + + llm = LLM( + repo_id, + **kwargs, + distributed_executor_backend=distributed_executor_backend, + ) + + self.vllm_entrypoint = llm + + # load the tokenizer + self.tokenizer = self._load_tokenizer( + model_config=llm.llm_engine.model_config, + scheduler_config=llm.llm_engine.scheduler_config, + parallel_config=llm.llm_engine.parallel_config, + enable_lora=bool(llm.llm_engine.lora_config), + ) + + return llm.llm_engine.model_executor.driver_worker.model_runner.model + + def 
_prepare_input( + self, *args, **kwargs + ) -> Tuple[Tuple[Tuple[Any], Dict[str, Any]], int]: + + if "processed" in kwargs: + return (args, kwargs), len(args[0]) + + prompts = [] + params = [] + + for arg in args: + + if not type(arg) is list: + arg = [arg] + + for prompt in arg: + + param = NNsightSamplingParams( + **kwargs, + ) + + prompts.append(prompt) + params.append(param) + + return ((prompts, params), {"processed": True}), len(prompts) + + def _batch( + self, + batched_inputs: Tuple[Tuple[Any] | protocols.Dict[str, Any]] | None, + prompts: List[str], + params: List[NNsightSamplingParams], + **kwargs, + ) -> Tuple[Tuple[Any] | protocols.Dict[str, Any]]: + + if batched_inputs is None: + batched_inputs = ([], []), {"invoker_group": 0} + + (bprompts, bparams), kwargs = batched_inputs + + invoker_group = kwargs["invoker_group"] + + for prompt in prompts: + bprompts.append(prompt) + + for param in params: + + param.invoker_group = invoker_group + + bparams.append(param) + + kwargs["invoker_group"] += 1 + + return (bprompts, bparams), kwargs + + def interleave( + self, + fn: Callable, + intervention_graph: Graph, + prompts: List[str], + params: List[NNsightSamplingParams], + **kwargs, + ) -> Any: + + if not self.dispatched: + self.dispatch() + + for param in params: + + param.intervention_graph = intervention_graph + + fn(prompts, params, **kwargs) + + intervention_graph.alive = False + + def _execute( + self, + prompts: List[str], + params: List[NNsightSamplingParams], + **kwargs, + ) -> Any: + + self.vllm_entrypoint.generate(prompts, sampling_params=params) From 5b95b4ff12f63adc9521f922836d7e310ae62b08 Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Tue, 19 Nov 2024 11:02:21 -0500 Subject: [PATCH 3/5] Buf fixes --- src/nnsight/intervention/graph/graph.py | 85 +++++++++++++------------ src/nnsight/modeling/vllm/vllm.py | 13 ++-- 2 files changed, 50 insertions(+), 48 deletions(-) diff --git a/src/nnsight/intervention/graph/graph.py b/src/nnsight/intervention/graph/graph.py index 8ffa2df3..7512aacc 100755 --- a/src/nnsight/intervention/graph/graph.py +++ b/src/nnsight/intervention/graph/graph.py @@ -1,14 +1,15 @@ +import copy import sys from collections import defaultdict -from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union, Tuple +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union from typing_extensions import Self from ...tracing.contexts import Context from ...tracing.graph import SubGraph +from ...util import NNsightError from ..protocols import ApplyModuleProtocol, GradProtocol, InterventionProtocol from . import InterventionNode, InterventionNodeType, InterventionProxyType -from ...util import NNsightError if TYPE_CHECKING: from .. 
import NNsight @@ -36,7 +37,7 @@ def __init__( ) -> None: super().__init__(*args, **kwargs) - + self.model = model self.interventions: Dict[str, List[InterventionNode]] = defaultdict(list) @@ -44,7 +45,7 @@ def __init__( self.compiled = False self.call_counter: Dict[int, int] = defaultdict(int) - self.deferred:Dict[int, List[int]] = defaultdict(list) + self.deferred: Dict[int, List[int]] = defaultdict(list) def __getstate__(self) -> Dict: @@ -108,15 +109,13 @@ def compile(self) -> None: return self.interventions if len(self.nodes) == 1: - return + return intervention_subgraphs: List[SubGraph] = [] start = self[0].index - if isinstance(self[0].target, type) and issubclass( - self[0].target, Context - ): + if isinstance(self[0].target, type) and issubclass(self[0].target, Context): graph = self[0].args[0] if len(graph) > 0: @@ -133,17 +132,17 @@ def compile(self) -> None: node: InterventionNodeType = self.nodes[index] if context_node is None and node.graph is not self: - + context_node = self.nodes[node.graph[-1].index + 1] context_start = self.subset.index(context_node.index) - + defer_start = node.index self.context_dependency(context_node, intervention_subgraphs) if node.target is InterventionProtocol: - + subgraph = SubGraph(self, subset=sorted(list(node.subgraph()))) module_path, *_ = node.args @@ -207,18 +206,24 @@ def compile(self) -> None: self.compiled = True - def execute(self, start: int = 0, grad: bool = False, defer:bool=False, defer_start:int=0) -> None: - + def execute( + self, + start: int = 0, + grad: bool = False, + defer: bool = False, + defer_start: int = 0, + ) -> None: + err: Tuple[int, NNsightError] = None - + if defer_start in self.deferred: - + for index in self.deferred[defer_start]: - + self.nodes[index].reset() - + del self.deferred[defer_start] - + if defer: self.defer_stack.append(defer_start) @@ -227,7 +232,9 @@ def execute(self, start: int = 0, grad: bool = False, defer:bool=False, defer_st if node.executed: continue - elif node.index != self[start].index and node.target is InterventionProtocol: + elif ( + node.index != self[start].index and node.target is InterventionProtocol + ): break elif node.fulfilled: try: @@ -241,7 +248,7 @@ def execute(self, start: int = 0, grad: bool = False, defer:bool=False, defer_st continue else: break - + if defer: self.defer_stack.pop() @@ -252,9 +259,7 @@ def execute(self, start: int = 0, grad: bool = False, defer:bool=False, defer_st self.defer_stack = defer_stack raise err[1] - def count( - self, index: int, iteration: Union[int, List[int], slice] - ) -> bool: + def count(self, index: int, iteration: Union[int, List[int], slice]) -> bool: """Increments the count of times a given Intervention Node has tried to be executed and returns if the Node is ready and if it needs to be deferred. Args: @@ -296,7 +301,7 @@ def count( self.call_counter[index] += 1 return ready, defer - + def clean(self, start: Optional[int] = None): if start is None: @@ -306,15 +311,14 @@ def clean(self, start: Optional[int] = None): # Loop over ALL nodes within the span of this graph. for index in range(start, end): - + node = self.nodes[index] - + if node.executed: break node.update_dependencies() - - + def cleanup(self) -> None: """Because some modules may be executed more than once, and to accommodate memory management just like a loop, intervention graph sections defer updating the remaining listeners of Nodes if this is not the last time this section will be executed. 
@@ -340,8 +344,8 @@ def cleanup(self) -> None: if dependency.redundant: dependency.destroy() - - def copy(self, + def copy( + self, new_graph: Self = None, parent: Optional["GraphType"] = None, memo: Optional[Dict[int, "NodeType"]] = None, @@ -349,30 +353,29 @@ def copy(self, if memo is None: memo = {} - + new_graph = super().copy(new_graph, parent, memo) new_graph.compiled = self.compiled - interventions = {} - for module_path, list_of_nodes in self.interventions.items(): - interventions[module_path] = [new_graph.nodes[memo[node.index]] for node in list_of_nodes] + new_graph.interventions[module_path] = [ + new_graph.nodes[memo[node.index]] for node in list_of_nodes + ] - new_graph.interventions = interventions - - new_graph.call_counter = self.call_counter.copy() + new_graph.call_counter = { + memo[index]: value for index, value in self.call_counter.items() + } - new_graph.deferred = self.deferred.copy() + new_graph.deferred = copy.deepcopy(self.deferred) - new_graph.grad_subgraph = self.grad_subgraph.copy() + new_graph.grad_subgraph = [memo[index] for index in self.grad_subgraph] - new_graph.defer_stack = self.defer_stack.copy() + new_graph.defer_stack = [memo[index] for index in self.defer_stack] return new_graph - # @classmethod # def shift(cls, mgraph: MultiGraph) -> MultiGraph: diff --git a/src/nnsight/modeling/vllm/vllm.py b/src/nnsight/modeling/vllm/vllm.py index e3eb8889..e2d6055c 100644 --- a/src/nnsight/modeling/vllm/vllm.py +++ b/src/nnsight/modeling/vllm/vllm.py @@ -8,9 +8,9 @@ from .executors.RayGPUExecutor import NNsightRayGPUExecutor from .sampling import NNsightSamplingParams from dataclasses import fields +from ...intervention.interleaver import Interleaver if TYPE_CHECKING: - from ...intervention.interleaver import Interleaver from ...intervention.graph import InterventionGraph from torch.nn import Module from vllm.transformers_utils.tokenizer import AnyTokenizer @@ -165,7 +165,7 @@ def _load(self, repo_id: str, **kwargs) -> "Module": enable_lora=bool(llm.llm_engine.lora_config), ) - if kwargs["tensor_parallel_size"] > 1: + if kwargs.get("tensor_parallel_size", 1) > 1: return llm.llm_engine.model_executor.driver_worker.worker.model_runner.model else: return llm.llm_engine.model_executor.driver_worker.model_runner.model @@ -226,11 +226,9 @@ def _batch( def interleave( self, + interleaver: Interleaver, *args, fn: Optional[Union[Callable, str]] = None, - intervention_graph: Optional["InterventionGraph"] = None, - interleaver: Optional["Interleaver"] = None, - batch_groups: Optional[List[Tuple[int, int]]] = None, **kwargs, ) -> Any: @@ -244,13 +242,14 @@ def interleave( fn(prompts, params, **kwargs) intervention_graph.alive = False """ - + + if not self.dispatched: self.dispatch() for param in args[1]: - param.intervention_graph = intervention_graph + param.intervention_graph = interleaver.graph if fn is None: fn = self._execute From f6b65e036e220c8399cef00479961670a88f2882 Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Tue, 19 Nov 2024 12:53:11 -0500 Subject: [PATCH 4/5] MOre fixess --- src/nnsight/intervention/graph/graph.py | 23 ++++++++++--------- .../vllm/model_runners/GPUModelRunner.py | 2 ++ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/nnsight/intervention/graph/graph.py b/src/nnsight/intervention/graph/graph.py index 7512aacc..7a6edf86 100755 --- a/src/nnsight/intervention/graph/graph.py +++ b/src/nnsight/intervention/graph/graph.py @@ -354,25 +354,26 @@ def copy( if memo is None: memo = {} - new_graph = super().copy(new_graph, 
parent, memo) + new_graph = super().copy(new_graph, parent=parent, memo=memo) new_graph.compiled = self.compiled - for module_path, list_of_nodes in self.interventions.items(): + for key, value in self.call_counter.items(): + self.call_counter[memo[key]] = value - new_graph.interventions[module_path] = [ - new_graph.nodes[memo[node.index]] for node in list_of_nodes - ] + if new_graph.compiled: - new_graph.call_counter = { - memo[index]: value for index, value in self.call_counter.items() - } + for module_path, list_of_nodes in self.interventions.items(): + + new_graph.interventions[module_path] = [ + new_graph.nodes[memo[node.index]] for node in list_of_nodes + ] - new_graph.deferred = copy.deepcopy(self.deferred) + for key, values in self.deferred.items(): - new_graph.grad_subgraph = [memo[index] for index in self.grad_subgraph] + new_graph[memo[key]] = [memo[index] for index in values] - new_graph.defer_stack = [memo[index] for index in self.defer_stack] + new_graph.grad_subgraph = [memo[index] for index in self.grad_subgraph] return new_graph diff --git a/src/nnsight/modeling/vllm/model_runners/GPUModelRunner.py b/src/nnsight/modeling/vllm/model_runners/GPUModelRunner.py index 9cffe14b..6cf2e3f5 100755 --- a/src/nnsight/modeling/vllm/model_runners/GPUModelRunner.py +++ b/src/nnsight/modeling/vllm/model_runners/GPUModelRunner.py @@ -233,6 +233,8 @@ def execute_model( ## NNSIGHT ######################################### intervention_graph = model_input.sampling_metadata.intervention_graph + + intervention_graph.set(self.model) batch_groups = model_input.sampling_metadata.batch_groups From f9a56202ebb8fac428605a8df70027a9044fabd2 Mon Sep 17 00:00:00 2001 From: Adam Belfki Date: Tue, 19 Nov 2024 17:00:42 -0500 Subject: [PATCH 5/5] fix how sampling params are passing in through the nnsight api to the vllm executor --- src/nnsight/modeling/vllm/sampling.py | 1 + src/nnsight/modeling/vllm/vllm.py | 11 +++++++++++ src/nnsight/tracing/hacks/conditional.py | 4 +++- src/nnsight/tracing/hacks/iterator.py | 4 +++- 4 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/nnsight/modeling/vllm/sampling.py b/src/nnsight/modeling/vllm/sampling.py index 1f275a15..293bbdbd 100755 --- a/src/nnsight/modeling/vllm/sampling.py +++ b/src/nnsight/modeling/vllm/sampling.py @@ -18,6 +18,7 @@ class NNsightSamplingParams(SamplingParams): intervention_graph: Optional[InterventionGraph] = None invoker_group: Optional[int] = None + is_default_param: bool = True def clone(self) -> "SamplingParams": """Deep copy excluding LogitsProcessor objects. 
diff --git a/src/nnsight/modeling/vllm/vllm.py b/src/nnsight/modeling/vllm/vllm.py index e2d6055c..c3c34c5d 100644 --- a/src/nnsight/modeling/vllm/vllm.py +++ b/src/nnsight/modeling/vllm/vllm.py @@ -191,6 +191,9 @@ def _prepare_input( **kwargs, ) + if kwargs != {}: + param.is_default_param = False + prompts.append(prompt) params.append(param) @@ -265,6 +268,14 @@ def _execute( **kwargs, ) -> Any: + kwargs.pop('invoker_group') + + for param in params: + if param.is_default_param: + for attr, value in kwargs.items(): + if hasattr(NNsightSamplingParams, attr): + setattr(param, attr, value) + self.vllm_entrypoint.generate(prompts, sampling_params=params) if TYPE_CHECKING: diff --git a/src/nnsight/tracing/hacks/conditional.py b/src/nnsight/tracing/hacks/conditional.py index cc2f7927..6a2a76ce 100755 --- a/src/nnsight/tracing/hacks/conditional.py +++ b/src/nnsight/tracing/hacks/conditional.py @@ -14,7 +14,9 @@ def handle_conditional(frame: FrameType, condition: "Proxy"): line_no = frame.f_lineno - source_lines, _ = inspect.getsourcelines(frame) + source_file = inspect.getsourcefile(frame) + with open(source_file, "r") as file: + source_lines = file.readlines() source = "".join(source_lines) tree = ast.parse(source) diff --git a/src/nnsight/tracing/hacks/iterator.py b/src/nnsight/tracing/hacks/iterator.py index f2fe6e91..f045bac8 100755 --- a/src/nnsight/tracing/hacks/iterator.py +++ b/src/nnsight/tracing/hacks/iterator.py @@ -13,7 +13,9 @@ def handle_iterator(frame: FrameType, collection: "Proxy"): line_no = frame.f_lineno - source_lines, _ = inspect.getsourcelines(frame) + source_file = inspect.getsourcefile(frame) + with open(source_file, "r") as file: + source_lines = file.readlines() source = "".join(source_lines) tree = ast.parse(source)