Conversation
Hey @callummcdougall I've pushed:
Got it, sorry for causing undue work - yes, in the future I will make sure to add tests! I wasn't sure about putting it in the SAE config because it's about the SAE's training data (or what inputs make sense for it) rather than, e.g., the SAE's actual architecture. I was basing this on the fact that
@callummcdougall I think the idea is that if you couldn't evaluate the SAE without knowing about this property, then it needs to be in the SAE config. Speaking of which, I don't see any changes to `evals.py`, but presumably we should ensure that evals are only run on the sliced sequence positions? Are you able to do this?
Code-wise this looks good to me, and it looks like a reasonable addition to the library! Will defer to @jbloomAus on whether this is OK to merge. I guess there's a question of whether the expectation is that this would require different evals, or whether this is something that only affects training.
activations = activation_store.get_activations(batch)

assert batch.shape == (1, 10)  # Full context size
assert activations.shape == (1, 6, 1, cfg.d_in)  # Only 6 positions (2 to 7)
nice! Really great test 🥇
I think it does seem valuable to also have the metrics logged during training apply only to the right sequence positions. Is that what you meant, @jbloomAus, or did you mean evals that are applied in a non-training context? Either way, I can likely get to that later this week.
This allows seqpos slicing during training. Basically we add a `seqpos_slice` arg to the `LanguageModelSAERunnerConfig` (in the form of a tuple, which gets converted to a slice via `slice(*seqpos_slice)` - this is because slice objects aren't serializable when we're saving the config).

Apart from this config, the only other file getting changed is `activations_store.py`. It now has a `seqpos_slice` attribute, and it uses this to slice the activations which are fetched from `get_activations` (and which are used in `get_buffer`).

Note that the default behaviour is `seqpos_slice = (None,)`, which slices over all sequence positions. Also note that `seqpos_slice` can be used in conjunction with `context_size` (i.e. one doesn't make the other redundant).
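
For anyone skimming, here is a rough, standalone sketch of the tuple-to-slice idea described above. It is not the code from this PR: `to_slice` is a hypothetical helper, the list of integers is a stand-in for the sequence-position axis of the activations, the `(2, 8)` tuple is an assumed value chosen to match the "positions 2 to 7" comment in the test, and JSON is only one assumed example of a serialization format.

```python
import json


def to_slice(seqpos_slice):
    """Hypothetical helper: turn a config tuple such as (2, 8) or (None,) into a slice."""
    return slice(*seqpos_slice)


# Stand-in for the sequence-position axis of a context of 10 tokens.
positions = list(range(10))

# Default (None,) becomes slice(None), which keeps every position.
assert positions[to_slice((None,))] == positions

# (2, 8) keeps positions 2..7, i.e. 6 of the 10 positions.
kept = positions[to_slice((2, 8))]

# The tuple survives a JSON round trip (as a list); a slice object does not.
restored = tuple(json.loads(json.dumps((2, 8))))
try:
    json.dumps(slice(2, 8))
    slice_is_serializable = True
except TypeError:
    slice_is_serializable = False
```

Storing the tuple rather than the slice keeps the config made of plain values, and the slice is rebuilt only at the point where the activations are indexed.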