Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Envoy.iter! #274

Merged
merged 2 commits into from
Nov 1, 2024
Merged

Envoy.iter! #274

merged 2 commits into from
Nov 1, 2024

Conversation

JadenFiotto-Kaufman
Copy link
Member

New Feature ! @AdamBelfki3 @cadentj

New paradigm to specify module iterations.

Here is me specifying I want an intervention to apply to all iterations using a global .all():

from nnsight import LanguageModel
from nnsight.intervention import InterventionProtocol
from nnsight import list
import torch

model =  LanguageModel("meta-llama/Meta-Llama-3.1-8B", device_map="auto", torch_dtype=torch.bfloat16)

from nnsight import list

with model.generate("hello world", max_new_tokens=5):
    values = list().save()
        
    model.all()
        
    values.append(model.lm_head.output)

print(len(values.value)) # Prints 5

Now if I only wanted to do some interventions every iteration, I can use it like a context manager:

with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()

    with model.all():
    
        values.append(model.lm_head.output)
            
    other_values.append(model.lm_head.output)

print(len(values.value)) # Prints 5
print(len(other_values.value)) # Prints 1

.all() is an alias for .iter[:]. Yes thats right, you can specify a specific iteration with an int, multiple iterations with a list of ints, or a range using a slice:

with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()

    with model.iter[2:4]:
    
        values.append(model.lm_head.output)
            
    other_values.append(model.lm_head.output)

print(len(values.value)) # Prints 2
print(len(other_values.value)) # Prints 1
with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()

    with model.iter[[0,1,4]]:
    
        values.append(model.lm_head.output)
            
    other_values.append(model.lm_head.output)

print(len(values.value)) # Prints 3
print(len(other_values.value)) # Prints 1

This also works inline:

with model.generate("hello world", max_new_tokens=5):
    values = list().save()
    other_values = list().save()

    values.append(model.lm_head.iter[2:4].output)
            
    other_values.append(model.lm_head.output)

print(len(values.value)) # Prints 2
print(len(other_values.value)) # Prints 1

same thing for .all() applies to .next()

@JadenFiotto-Kaufman JadenFiotto-Kaufman merged commit cad7015 into 0.4 Nov 1, 2024
1 check passed
@JadenFiotto-Kaufman JadenFiotto-Kaufman deleted the envoy-iter branch December 4, 2024 21:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant