We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 170c92b commit 587506dCopy full SHA for 587506d
train_gpt2.py
@@ -514,6 +514,8 @@ def print0(*args, **kwargs):
514
ntok = header[2] # number of tokens (claimed)
515
# the rest of it are tokens, stored as uint16
516
tokens = np.frombuffer(f.read(), dtype=np.uint16)
517
+ # convert tokens to int32 because torch can't handle uint16 sad
518
+ tokens = tokens.astype(np.int32)
519
assert len(tokens) == ntok, "number of tokens read does not match header?"
520
521
# np -> tensor, long, on device
0 commit comments