[Question] Padding influences embedding #43
Comments
This is actually a feature(?), not a bug, of the bidirectional SSM model that
we trained: since it's not an attention-based model, it can implicitly use
those padding tokens to carry information forward between layers.
It's something worth looking into for follow-up work, but we currently
don't have bandwidth. We noticed some marginal performance degradation on
long sequences when explicitly zeroing out those padding tokens.
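Roughly, zeroing out the padding tokens between layers would look like this (a minimal sketch, not the model's actual code; it assumes hidden states of shape (batch, seq_len, dim) and the tokenizer's attention_mask of shape (batch, seq_len)):

import torch

def zero_out_padding(hidden_states, attention_mask):
    # attention_mask is 1 for real tokens and 0 for padding. Zeroing the padded
    # positions keeps later layers from reading information carried in the
    # padding tokens.
    return hidden_states * attention_mask.unsqueeze(-1).to(hidden_states.dtype)

# e.g. applied after every layer:
# hidden_states = zero_out_padding(layer(hidden_states), inputs["attention_mask"])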
Another idea we had but never got a chance to look into was aligning the
embeddings out of the 512, 2k, 8k, and 32k models so that you can
gracefully decrease the padding depending on input sequence length.
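As an illustration of that idea: if the embedding spaces of the four checkpoints were aligned, a caller could pad only up to the smallest context that fits the input. A hypothetical sketch (pick_padding_bucket is an illustrative helper, not an existing API, and the bucket sizes are just the context lengths of the variants mentioned above):

from transformers import AutoTokenizer

# Context lengths of the 512 / 2k / 8k / 32k model variants.
BUCKETS = [512, 2048, 8192, 32768]

def pick_padding_bucket(text, tokenizer):
    # Measure the true token count, then pad only to the smallest bucket that fits.
    n_tokens = len(tokenizer(text, truncation=False)["input_ids"])
    for bucket in BUCKETS:
        if n_tokens <= bucket:
            return bucket
    return BUCKETS[-1]

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
text = "Every morning, I make a cup of coffee to start my day."
bucket = pick_padding_bucket(text, tokenizer)  # 512 for this short sentence
inputs = tokenizer(
    text, return_tensors="pt", padding="max_length",
    truncation=True, max_length=bucket,
)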
On Fri, Jan 3, 2025 at 10:46 AM Eduard Zorita wrote:
When using the tokenizer with padding="max_length" vs. padding="longest",
the generated embeddings are completely different. I believe this is a bug,
since padding tokens should be masked out, but it might be an intrinsic
side effect of how attention is computed.
Passing the same sentence tokenized with and without padding yields a
cosine similarity of 0.35.
This explains the issue described in #42.
Reproduction script:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

max_seq_length = 32768
testing_string = "Every morning, I make a cup of coffee to start my day."

model = AutoModelForSequenceClassification.from_pretrained(
    "togethercomputer/m2-bert-80M-32k-retrieval", trust_remote_code=True
).cuda()
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased", model_max_length=max_seq_length
)

input_ids = tokenizer(
    [testing_string],
    return_tensors="pt",
    padding="max_length",
    return_token_type_ids=False,
    truncation=True,
    max_length=max_seq_length,
).to("cuda")

input_ids_no_padding = tokenizer(
    [testing_string],
    return_tensors="pt",
    padding="longest",
    return_token_type_ids=False,
    truncation=True,
    max_length=max_seq_length,
).to("cuda")

with torch.no_grad():
    outputs = model(**input_ids)
    embeddings = outputs["sentence_embedding"].to("cpu")

with torch.no_grad():
    outputs = model(**input_ids_no_padding)
    embeddings_no_padding = outputs["sentence_embedding"].to("cpu")

allclose = torch.allclose(embeddings, embeddings_no_padding)
print(f"allclose: {allclose}")

# Print cosine similarity between the two embeddings
cosine_similarity = torch.nn.functional.cosine_similarity(
    embeddings, embeddings_no_padding
)
print(f"cosine similarity: {cosine_similarity}")
Output:
allclose: False
cosine similarity: tensor([0.3504])
Thanks for your quick response @DanFu09, I believe you should explicitly state in the Hugging Face docs that the tokenizer must be used with padding="max_length"; a Python comment on that line would be enough. Otherwise users may be tempted to use padding="longest" to increase inference efficiency.
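Something along these lines in the usage example would make the requirement hard to miss (an illustrative sketch of the suggested doc change, mirroring the repro script above with the 32k model's context length):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", model_max_length=32768)
inputs = tokenizer(
    ["your text here"],
    return_tensors="pt",
    # M2-BERT embeddings depend on the padded length: always pad to the model's
    # full context with padding="max_length"; padding="longest" changes the output.
    padding="max_length",
    truncation=True,
    max_length=32768,
)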
Will update soon, thanks for the suggestion!