Commit

torch tensor can't handle uint16 so let's convert to int32, which is silly because we'll convert to .long right after but ok
karpathy committed May 21, 2024
1 parent 170c92b commit 587506d
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions train_gpt2.py
@@ -514,6 +514,8 @@ def print0(*args, **kwargs):
  ntok = header[2] # number of tokens (claimed)
  # the rest of it are tokens, stored as uint16
  tokens = np.frombuffer(f.read(), dtype=np.uint16)
+ # convert tokens to int32 because torch can't handle uint16 sad
+ tokens = tokens.astype(np.int32)
  assert len(tokens) == ntok, "number of tokens read does not match header?"

  # np -> tensor, long, on device
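
The change above can be sketched in isolation. This is a minimal illustration, not the code from train_gpt2.py: `load_tokens` and the sample payload are hypothetical names made up for this example, and the torch step is guarded so the sketch still runs where PyTorch is not installed. It shows that widening uint16 to int32 keeps every token id intact, because the full uint16 range (0 to 65535) fits in int32.

```python
import numpy as np


def load_tokens(raw: bytes, ntok: int) -> np.ndarray:
    """Hypothetical helper: decode a uint16 token payload and widen it to int32."""
    # frombuffer returns a read-only uint16 view over the raw bytes
    tokens = np.frombuffer(raw, dtype=np.uint16)
    # astype makes a writable int32 copy; every uint16 value fits in int32
    tokens = tokens.astype(np.int32)
    assert len(tokens) == ntok, "number of tokens read does not match header?"
    return tokens


# Made-up payload that includes the largest uint16 value
sample = np.array([0, 50256, 65535], dtype=np.uint16)
tokens = load_tokens(sample.tobytes(), ntok=3)

try:
    import torch

    # np -> tensor, then long (int64), mirroring the step that follows in the diff
    x = torch.from_numpy(tokens).long()
except ImportError:
    x = None  # PyTorch not available; the NumPy part above still holds
```

As the commit message notes, the int32 copy is only an intermediate step, since the tensor is converted to long immediately afterwards. Its job is to give torch a dtype it accepts.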
