Fix logic in LanguageModel init. The meta model (from_config) was not being used when a tokenizer was provided.
JadenFiotto-Kaufman committed Sep 8, 2024
1 parent 0787ce8 commit 5772f6b
Showing 2 changed files with 10 additions and 5 deletions.
13 changes: 8 additions & 5 deletions src/nnsight/models/DiffusionModel.py
@@ -28,7 +28,7 @@ def __init__(self, *args, **kwargs) -> None:

 class DiffusionModel(GenerationMixin, NNsight):
 
-    def __new__(cls, *args, **kwargs) -> Self | Envoy:
+    def __new__(cls, *args, **kwargs) -> Self | Envoy | Diffuser:
         return object.__new__(cls)
 
     def __init__(self, *args, **kwargs) -> None:
@@ -72,9 +72,9 @@ def _batch_inputs(

         if batched_inputs is None:
 
-            return prepared_inputs
+            return (prepared_inputs, )
 
-        return batched_inputs + prepared_inputs
+        return (batched_inputs + prepared_inputs, )
 
     def _execute_forward(self, prepared_inputs: Any, *args, **kwargs):

@@ -96,8 +96,11 @@ def _execute_generate(

         if seed is not None:
 
-            generator = generator.manual_seed(seed)
+            if isinstance(prepared_inputs, list):
+                generator = [torch.Generator().manual_seed(seed) for _ in range(len(prepared_inputs))]
+            else:
+                generator = generator.manual_seed(seed)
 
         output = self._model.pipeline(
             prepared_inputs, *args, generator=generator, **kwargs
         )
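For context, the generator branch above follows the diffusers convention that a batched call can take one torch.Generator per prompt. A minimal sketch of the same pattern outside nnsight (assuming the diffusers StableDiffusionPipeline API; the checkpoint id and prompts are placeholders):

import torch
from diffusers import StableDiffusionPipeline

pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

prompts = ["a red cube on a table", "a blue sphere on grass"]
seed = 42

# One seeded generator per prompt, mirroring the commit's isinstance branch.
# Note every generator gets the same seed, so each prompt in the batch starts
# from the same initial noise.
if isinstance(prompts, list):
    generator = [torch.Generator().manual_seed(seed) for _ in range(len(prompts))]
else:
    generator = torch.Generator().manual_seed(seed)

images = pipeline(prompts, generator=generator).images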
2 changes: 2 additions & 0 deletions src/nnsight/models/LanguageModel.py
@@ -184,6 +184,8 @@ def _load(

         if not hasattr(self.tokenizer.pad_token, "pad_token"):
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
+        if self._model is None:
+
             if (
                 patch_llama_scan
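The two added lines are the substance of the commit message: building the meta model from the config is now gated on whether a model instance already exists, not on whether the caller supplied a tokenizer. A minimal sketch of the corrected pattern (illustrative only, not nnsight's actual _load; the gpt2 checkpoint and local names are placeholders):

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

repo_id = "gpt2"  # placeholder checkpoint

config = AutoConfig.from_pretrained(repo_id)
tokenizer = AutoTokenizer.from_pretrained(repo_id)  # may also be caller-supplied
model = None  # non-None only if the caller passed in a real model

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# The fix: create the weightless meta model whenever no model exists yet,
# regardless of where the tokenizer came from. Previously this path was
# skipped when a tokenizer was provided.
if model is None:
    with torch.device("meta"):  # one way to build a module without allocating weights
        model = AutoModelForCausalLM.from_config(config)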

0 comments on commit 5772f6b
