More meta wrappers (forced by Mixtral 8bx7)
JadenFiotto-Kaufman committed Dec 21, 2023
1 parent e33b2e6 commit f6c078a
Showing 1 changed file with 33 additions and 5 deletions.
src/nnsight/__init__.py
@@ -60,19 +60,48 @@ def repeat_interleave(
 )
 
 
-def cpu_wrapper(fn):
+def noop_wrapper(fn):
     @wraps(fn)
-    def cpu(input: torch.Tensor, *args, **kwargs):
+    def noop(input: torch.Tensor, *args, **kwargs):
         if input.device.type == "meta":
             return input
 
         else:
             return fn(input, *args, **kwargs)
 
-    return cpu
+    return noop
 
 
-DEFAULT_PATCHER.add(Patch(torch.Tensor, cpu_wrapper(torch.Tensor.cpu), "cpu"))
+DEFAULT_PATCHER.add(Patch(torch.Tensor, noop_wrapper(torch.Tensor.cpu), "cpu"))
+
+
+def onehot_wrapper(fn):
+    @wraps(fn)
+    def onehot(input: torch.Tensor, num_classes=-1):
+        if input.device.type == "meta":
+            return torch.zeros((*input.shape, num_classes), device='meta')
+
+        else:
+            return fn(input, num_classes=num_classes)
+
+    return onehot
+
+
+DEFAULT_PATCHER.add(Patch(torch.nn.functional, onehot_wrapper(torch.nn.functional.one_hot), "one_hot"))
+
+
+def where_wrapper(fn):
+    @wraps(fn)
+    def where(input: torch.Tensor, *args, **kwargs):
+        if input.device.type == "meta":
+            return input.to(torch.int)
+
+        else:
+            return fn(input, *args, **kwargs)
+
+    return where
+
+
+DEFAULT_PATCHER.add(Patch(torch, where_wrapper(torch.where), "where"))
+
+DEFAULT_PATCHER.add(Patch(torch.Tensor, noop_wrapper(torch.Tensor.tolist), "tolist"))
 
 DEFAULT_PATCHER.__enter__()

@@ -95,5 +124,4 @@ def activate_recent_meta():
     def local_scalar_dense_meta(A):
         return 0
 
-
 activate_recent_meta()
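
Note on what these patches do (a sketch, not part of the commit): like the existing cpu patch, each new wrapper checks whether the input tensor lives on the "meta" device — a tensor carrying only shape and dtype, with no real storage, as used when nnsight builds a model without loading its weights — and, if so, returns a cheap placeholder instead of dispatching to the real kernel, which would otherwise fail on meta tensors. The sketch below assumes the patches are active, which should hold once nnsight is imported since DEFAULT_PATCHER.__enter__() runs at import time (see above).

import torch
import nnsight  # importing nnsight enters DEFAULT_PATCHER, activating the patches above

# A meta tensor: shape and dtype only, no data. Ordinarily .cpu(), one_hot,
# where, and .tolist() on it would fail or force materialization.
x = torch.empty(2, 3, dtype=torch.long, device="meta")

x.cpu()                                        # noop_wrapper: returns x unchanged
torch.nn.functional.one_hot(x, num_classes=7)  # onehot_wrapper: meta zeros of shape (2, 3, 7)
torch.where(x)                                 # where_wrapper: x cast to int, still on "meta"
x.tolist()                                     # noop_wrapper: returns x instead of building a list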
