Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vllm #267

Closed
wants to merge 46 commits into from
Closed

Vllm #267

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
54bf39a
.stream
JadenFiotto-Kaufman Sep 20, 2024
c918454
Merge branch 'dev' into streaming-protocol
JadenFiotto-Kaufman Sep 25, 2024
95406ae
Streaming protocol sending
JadenFiotto-Kaufman Sep 25, 2024
5e3d764
Response object returned as pickled bytes vs json.
JadenFiotto-Kaufman Sep 25, 2024
7fb92bf
Bug fixing streaming
JadenFiotto-Kaufman Sep 26, 2024
c50a295
Streaming working!
JadenFiotto-Kaufman Sep 26, 2024
bba6d7d
Streaming upload!
JadenFiotto-Kaufman Sep 26, 2024
490828a
Streaming upload bug fix
JadenFiotto-Kaufman Sep 26, 2024
efca9d2
Complex upload node pre-processing. Almost something amazing.
JadenFiotto-Kaufman Sep 26, 2024
5a7f595
Bug fix
JadenFiotto-Kaufman Sep 26, 2024
7bfd978
My crowning achievement
JadenFiotto-Kaufman Sep 27, 2024
078e079
Streaming documentation. Helper decorators added to nnsight
JadenFiotto-Kaufman Sep 29, 2024
c346dcb
.all()
JadenFiotto-Kaufman Sep 29, 2024
fcf718c
Add LanguageModel Generator.Streamer to get the value of the streamer…
JadenFiotto-Kaufman Sep 29, 2024
0d89440
Bug fix
JadenFiotto-Kaufman Sep 29, 2024
23ef319
Bug fix
JadenFiotto-Kaufman Sep 30, 2024
112dfba
Bug fix
JadenFiotto-Kaufman Sep 30, 2024
5e7d210
Deferring node destruction when execution nodes more than once using …
JadenFiotto-Kaufman Sep 30, 2024
902260d
.all() bug fix
JadenFiotto-Kaufman Oct 3, 2024
af37e0a
Use msgspec and zlib for sending data.
JadenFiotto-Kaufman Oct 5, 2024
2f7f0db
Undo last commit
JadenFiotto-Kaufman Oct 5, 2024
61f0344
Ability to specify Tracer type in NNsight. Ability to set interventio…
JadenFiotto-Kaufman Oct 10, 2024
4e472b0
Remove Protocol compilation
JadenFiotto-Kaufman Oct 10, 2024
25d685e
Refactor Graph copy. Added MultiGraph
JadenFiotto-Kaufman Oct 10, 2024
5790732
Refactor InterventionProtocol compilation. Added ability to shift inv…
JadenFiotto-Kaufman Oct 10, 2024
fe5e16e
Base VLLM working
JadenFiotto-Kaufman Oct 10, 2024
0f22ea5
.next() working. needed to manually set batch size as extra "cuda_gra…
JadenFiotto-Kaufman Oct 10, 2024
c87a237
Pickleable Graph
JadenFiotto-Kaufman Oct 10, 2024
a32d93e
Always send the full intervention graph for every invoker group. Its …
JadenFiotto-Kaufman Oct 10, 2024
f2bfa16
Remove unnecessary vllm classes. Make interleave return output.
JadenFiotto-Kaufman Oct 10, 2024
21e75f6
vllm inner function to capture additional modules.
JadenFiotto-Kaufman Oct 10, 2024
a1d1a20
vllm updates
JadenFiotto-Kaufman Oct 10, 2024
b5a362c
start
JadenFiotto-Kaufman Oct 10, 2024
2dd23c9
Merge pull request #268 from ndif-team/streaming-protocol
JadenFiotto-Kaufman Oct 10, 2024
58a0d6c
Merge pull request #269 from ndif-team/envoy-all
JadenFiotto-Kaufman Oct 10, 2024
2090067
Merge
JadenFiotto-Kaufman Oct 10, 2024
7d3f739
Merge branch '0.4' into vllm
JadenFiotto-Kaufman Oct 10, 2024
d42797f
Merge branch 'vllm' into addons-methods
JadenFiotto-Kaufman Oct 10, 2024
84c757c
Remove addons attr
JadenFiotto-Kaufman Oct 10, 2024
f8aee6a
Big refactor
JadenFiotto-Kaufman Oct 11, 2024
49420c7
Big refactor
JadenFiotto-Kaufman Oct 11, 2024
f85b70f
Merge branch 'addons-methods' into vllm
JadenFiotto-Kaufman Oct 11, 2024
60ea119
feat (vllm): Add support for tensor parallelism with nnsight.VLLM + e…
AdamBelfki3 Oct 11, 2024
bd6227c
vllm stuff
JadenFiotto-Kaufman Oct 11, 2024
78d0bb7
Merge branch 'vllm-tp' into vllm
JadenFiotto-Kaufman Oct 11, 2024
e413131
vllm updates and fixes
JadenFiotto-Kaufman Oct 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 93 additions & 7 deletions src/nnsight/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# #
# :::: ::: :::: ::: :::::::: ::::::::::: :::::::: ::: ::: ::::::::::: ::::::: :::::::: #
# :+:+: :+: :+:+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: :+: #
Expand All @@ -8,10 +8,10 @@
# #+# #+#+# #+# #+#+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #+# #
# ### #### ### #### ######## ########### ######## ### ### ### ####### ### ######## #
# #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
import os
from functools import wraps
from typing import Dict, Union
from typing import Callable, Dict, Union

from importlib.metadata import version, PackageNotFoundError

Expand Down Expand Up @@ -56,11 +56,11 @@
from torch._subclasses.fake_tensor import FakeTensor


def _bool(self):
def fake_bool(self):
return True


DEFAULT_PATCHER.add(Patch(FakeTensor, _bool, "__bool__"))
DEFAULT_PATCHER.add(Patch(FakeTensor, fake_bool, "__bool__"))


def fake_tensor_new_wrapper(fn):
Expand Down Expand Up @@ -118,10 +118,11 @@ def noop(input: torch.Tensor, *args, **kwargs):
)

import warnings

_str = str
_bool = bool

try:



from torch.amp.autocast_mode import autocast, is_autocast_available

Expand Down Expand Up @@ -555,3 +556,88 @@ def set_module_tensor_to_device(
apply = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.apply
log = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.log
cond = GlobalTracingContext.GLOBAL_TRACING_CONTEXT.cond

import inspect

from . import util
from .intervention import InterventionProxy


def trace(fn: Callable):
    """Helper decorator to add a function to the intervention graph via `.apply(...)`.

    Rather than entering `fn` during tracing and tracing every inner operation,
    the whole call is recorded as a single node on the graph.

    Args:
        fn (Callable): Function to apply.

    Returns:
        Callable: Traceable function.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Defer execution: record the call on the intervention graph instead
        # of running `fn` immediately.
        return apply(fn, *args, **kwargs)

    return wrapper


def local(object: Callable | InterventionProxy):
    """Helper decorator to add a function to the intervention graph via `.apply(...)`
    AND convert all input Proxies to local ones via `.local()`.

    If a non-function is passed in, it is assumed to be an `InterventionProxy`
    and `.local()` is called on it and returned.

    Args:
        object ( Callable | InterventionProxy): Function to apply or Proxy to make local.

    Returns:
        Callable | InterventionProxy: Traceable local function or local Proxy.
    """

    if not inspect.isroutine(object):
        # Proxy case: download the value locally.
        return object.local()

    traced = trace(object)

    @wraps(traced)
    def wrapper(*args, **kwargs):
        # Swap every InterventionProxy in the inputs for its `.local()` form
        # before invoking the traced function.
        local_args, local_kwargs = util.apply(
            (args, kwargs), lambda proxy: proxy.local(), InterventionProxy
        )

        return traced(*local_args, **local_kwargs)

    return wrapper


def remote(object: "Callable | Any"):
    """Helper decorator to add a function to the intervention graph via `.apply(...)`
    AND convert all input Proxies to downloaded local ones via `.local()`
    AND convert the output to an uploaded remote one via `remote()`.

    If a non-function is passed in, `remote(object)` is called and returned.

    Args:
        object ( Callable | Any): Function to apply or object to make remote.

    Returns:
        Callable | InterventionProxy: Traceable local -> remote function or remote Proxy.
    """
    # NOTE(review): the annotation is quoted (lazy string) because `Any` is not
    # in the visible module imports (`from typing import Callable, Dict, Union`);
    # evaluating `Callable | Any` at definition time would raise NameError.
    # Confirm `Any` is not imported elsewhere before un-quoting.

    if inspect.isroutine(object):

        # Wrap with `local` so proxy inputs are downloaded before the call...
        fn = local(object)

        @wraps(fn)
        def inner(*args, **kwargs):
            # ...then upload the result back to the remote service.
            return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.remote(
                fn(*args, **kwargs)
            )

        return inner

    return GlobalTracingContext.GLOBAL_TRACING_CONTEXT.remote(object)
28 changes: 17 additions & 11 deletions src/nnsight/contexts/GraphBasedContext.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ def log(self, *data: Any) -> None:
data (Any): Data to print.
"""
self.apply(print, *data)

def remote(self, data:Any) -> InterventionProxy:
    """Streams data remotely when it becomes available locally.
    The remote service will block until the local value is uploaded and received.

    Is a no-op when not executing remotely.

    Args:
        data (Any): Local value (or Proxy) to upload to the remote service.

    Returns:
        InterventionProxy: Proxy representing the uploaded value.
    """

    # Register a streaming-upload node on this context's graph; the remote
    # side waits on this node until the local value arrives.
    return protocols.StreamingUploadProtocol.add(self.graph, data)

def bool(self, *args, **kwargs) -> InterventionProxy:
"""NNsight helper method to create a traceable bool."""
Expand Down Expand Up @@ -208,13 +220,16 @@ def __enter__(self) -> Self:
def __exit__(self, exc_type, exc_val, exc_tb) -> None:

GlobalTracingContext.try_deregister(self)

if isinstance(exc_val, BaseException):
self.graph.alive = False
self.graph = None
raise exc_val

self.backend(self)

if not isinstance(self.graph, weakref.ProxyType):
self.graph = weakref.proxy(self.graph)


### BACKENDS ########

Expand All @@ -225,22 +240,13 @@ def local_backend_execute(self) -> None:
self.graph.execute()
except protocols.EarlyStopProtocol.EarlyStopException as e:
raise e
finally:
graph = self.graph
graph.alive = False

if not isinstance(graph, weakref.ProxyType):
self.graph = weakref.proxy(graph)

def bridge_backend_handle(self, bridge: Bridge) -> None:

bridge.pop_graph()

protocols.LocalBackendExecuteProtocol.add(self, bridge.peek_graph())

self.graph = weakref.proxy(self.graph)


from inspect import getmembers, isclass

from torch.utils import data
Expand Down
30 changes: 16 additions & 14 deletions src/nnsight/contexts/Invoker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ class Invoker(AbstractContextManager):
def __init__(
self,
tracer: "Tracer",
*inputs: Any,
*args,
scan: bool = False,
**kwargs,
) -> None:

self.tracer = tracer
self.inputs = inputs
self.input = (args, kwargs)

self.scan = scan
self.kwargs = kwargs

self.scanning = False

Expand All @@ -71,28 +71,28 @@ def __enter__(self) -> Invoker:
# Set self.inputs to be the proxy_value so we can prepare_inputs, get the batch size, and scan.
if self.tracer.model._session is not None:

self.inputs, has_proxies_in_inputs = check_for_dependencies(
self.inputs
self.input, has_proxies_in_inputs = check_for_dependencies(
self.input
)

with GlobalTracingContext.exit_global_tracing_context():

if not has_proxies_in_inputs:

self.inputs, batch_size = self.tracer.model._prepare_inputs(
*self.inputs, **self.kwargs
self.input, batch_size = self.tracer.model._prepare_input(
*self.input[0], **self.input[1]
)

if self.scan:

inputs = self.inputs
input = self.input

if has_proxies_in_inputs:

inputs = util.apply(inputs, lambda x: x.proxy_value, Node)
input = util.apply(input, lambda x: x.proxy_value, Node)

inputs, batch_size = self.tracer.model._prepare_inputs(
*inputs, **self.kwargs
input, batch_size = self.tracer.model._prepare_input(
*input[0], **input[1]
)

self.tracer.model._envoy._clear()
Expand All @@ -112,8 +112,10 @@ def __enter__(self) -> Invoker:
shape_env=ShapeEnv(assume_static_by_default=True),
) as fake_mode:
with FakeCopyMode(fake_mode):
self.tracer.model._execute(
*copy.deepcopy(inputs),
fn = self.tracer.model._execute if self.tracer.method is None else getattr(self.tracer.model, self.tracer.method)
fn(
*copy.deepcopy(input[0]),
**copy.deepcopy(input[1]),
**copy.deepcopy(self.tracer._kwargs),
)

Expand All @@ -122,7 +124,7 @@ def __enter__(self) -> Invoker:
else:
self.tracer.model._envoy._reset()

self.tracer._invoker_inputs.append(self.inputs)
self.tracer._invoker_inputs.append(self.input)

return self

Expand Down
Loading
Loading