Commit 587506d

torch tensor can't handle uint16 so let's convert to int32, which is silly because we'll convert to .long right after but ok
1 parent 170c92b commit 587506d

1 file changed (+2, -0)

train_gpt2.py

Lines changed: 2 additions & 0 deletions
@@ -514,6 +514,8 @@ def print0(*args, **kwargs):
 ntok = header[2] # number of tokens (claimed)
 # the rest of it are tokens, stored as uint16
 tokens = np.frombuffer(f.read(), dtype=np.uint16)
+# convert tokens to int32 because torch can't handle uint16 sad
+tokens = tokens.astype(np.int32)
 assert len(tokens) == ntok, "number of tokens read does not match header?"
 
 # np -> tensor, long, on device
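
A minimal sketch of the conversion this commit adds, for illustration only. It assumes the tokens arrive as raw uint16 bytes, as in the hunk above; the helper name widen_tokens and the sample token ids are hypothetical and not part of train_gpt2.py. It also shows why int16 would not be a safe target type: uint16 values above 32767 wrap to negative numbers, while int32 holds every uint16 value unchanged. The torch step is optional and only mirrors the "np -> tensor, long" comment from the diff; per the commit message, the widening was needed because torch could not take uint16 arrays at the time.

```python
import numpy as np


def widen_tokens(raw: bytes) -> np.ndarray:
    """Interpret raw bytes as uint16 token ids and widen them to int32.

    Hypothetical helper; it mirrors the two lines touched by the commit.
    """
    tokens = np.frombuffer(raw, dtype=np.uint16)
    # astype returns a new, writable array with the same values in a wider dtype
    return tokens.astype(np.int32)


# Sample ids (assumed for illustration): 50256 is GPT-2's end-of-text id,
# 65535 is the largest value a uint16 can hold.
raw = np.array([0, 1, 50256, 65535], dtype=np.uint16).tobytes()
tokens = widen_tokens(raw)

# For contrast: int16 is too narrow, so ids above 32767 wrap around.
wrapped = np.frombuffer(raw, dtype=np.uint16).astype(np.int16)

# Optional: hand the int32 array to torch and convert to long (int64).
try:
    import torch

    tokens_t = torch.from_numpy(tokens).long()
except ImportError:
    tokens_t = None  # torch not installed; the numpy part still runs

print(tokens.dtype, tokens.tolist())
print(wrapped.dtype, wrapped.tolist())
```

The int32 step doubles the in-memory size of the token array for a moment, which matches the commit message's remark that it is redundant with the later conversion to long, but it keeps every token id intact on the way into the tensor.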
