Description
The NNSight object raises an IndexError when unbatched token IDs are used as input while tracing in a loop. This bug is a landmine, and the error message is not helpful. It would be nice if the trace invocation checked that the inputs are correctly batched, e.g. via len(input_ids.shape), before applying the tensor concatenation. A simple workaround is to call tracer.invoke(input_ids.unsqueeze(0)).
Root Cause
When an NNSight object batches inputs, it uses torch.concatenate() to stack the input_ids tensors along dimension 0. However, if a single-dimension tensor is passed as input, i.e. len(input_ids.shape) == 1, the concatenation appends each subsequent input to the original input_ids tensor rather than stacking them into a batch. After several concatenations the length of the "batched" input can exceed the model's context window, so the downstream forward pass raises an IndexError during embedding: the position indices exceed the number of rows in the positional-embedding matrix.
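The dimension-dependent behavior of torch.concatenate() can be reproduced in isolation (the tensors below are illustrative, not from NNSight itself):

```python
import torch

# Two unbatched (1-D) token-id tensors, e.g. two 4-token prompts.
a = torch.tensor([1, 2, 3, 4])
b = torch.tensor([5, 6, 7, 8])

# With 1-D inputs, concatenation along dim 0 *appends* the ids,
# producing one ever-growing sequence instead of a batch.
flat = torch.concatenate([a, b], dim=0)
print(flat.shape)  # torch.Size([8])

# Correctly batched (2-D) inputs stack into a batch dimension.
batched = torch.concatenate([a.unsqueeze(0), b.unsqueeze(0)], dim=0)
print(batched.shape)  # torch.Size([2, 4])
```

Every additional 1-D invoke therefore lengthens a single sequence, which is why the failure only surfaces after enough loop iterations.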
Working Example
Failing Example
Info