
Commit

first draft, apparently this works. needs cleanups, and also we are not yet utilizing the full batch dimension. we actually have to load in multiple examples and fully utilize batch
karpathy committed May 22, 2024
1 parent 69f1221 commit 051f3ca
Showing 5 changed files with 317 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -7,7 +7,7 @@

# data directories
dev/data/__pycache__/
-dev/data/fineweb/
+dev/data/fineweb10B/
dev/data/hellaswag/
dev/data/mmlu/
dev/data/tinyshakespeare/
170 changes: 170 additions & 0 deletions dataloader.h
@@ -195,3 +195,173 @@ void dataloader_free(DataLoader *loader) {
    fcloseCheck(loader->tokens_file);
    globfree(&loader->glob_result);
}

// ----------------------------------------------------------------------------
// Distributed Eval Loader
// Many evals (like HellaSwag and MMLU) are multiple-choice,
// where there are 4 possible continuations and a label for the correct one.
// We want to load and serve this style of eval.
/*
Copy pasting the section on the eval datafile format, from data_common.py:
- First comes a header with 256 int32s
- The examples follow, each example is a stream of uint16_t:
- <START_EXAMPLE> delimiter of 2**16-1, i.e. 65,535
- <EXAMPLE_BYTES>, bytes encoding this example, allowing efficient skip to next
- <EXAMPLE_INDEX>, the index of the example in the dataset
- <LABEL>, the index of the correct completion
- <NUM_COMPLETIONS>, indicating the number of completions (usually 4)
- <NUM><CONTEXT_TOKENS>, where <NUM> is the number of tokens in the context
- <NUM><COMPLETION_TOKENS>, repeated NUM_COMPLETIONS times
*/
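// As a concrete example of the layout above: an example with a 5-token context and 4 completions
// of 3 tokens each is a stream of 5 + (1 + 5) + 4*(1 + 3) = 27 uint16 values, i.e. 54 bytes,
// so its <EXAMPLE_BYTES> field is 54:
// 65535, 54, <idx>, <label>, 4, 5, <5 context tokens>, 3, <3 tokens>, 3, <3 tokens>, 3, <3 tokens>, 3, <3 tokens>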

typedef struct {
    // variables related to distributed training
    // each process/worker has to access different parts of the data
    int process_rank;
    int num_processes;
    // hyperparameters. use size_t to prevent overflow
    size_t B;
    size_t T;
    // input handling and its state
    FILE* eval_file;
    long file_size;
    uint16_t* buffer; // we fread data from file into this buffer
    // public variables that could be accessed from outside
    int num_examples; // in total across all processes
    int start_example_index; // the assignment of work for this process, start
    int end_example_index; // and end. start is inclusive, end is exclusive
    int* inputs; // input tokens into transformer
    int* targets; // target tokens for the transformer
    char* mask; // mask=1 at all completion token locations
    int label; // the correct completion label
    int num_completions; // number of completions for this example
} EvalLoader;

void evalloader_reset(EvalLoader *loader) {
    // we have to be careful that each process starts at the correct offset.
    // For example if there are N examples in the file and 4 processes,
    // then process 0 should start at 0, process 1 at N/4, process 2 at N/2, etc.
    long header_bytes = HEADER_SIZE * sizeof(int);
    // determine which example we want this process to start at
    int process_stride = loader->num_examples / loader->num_processes;
    loader->start_example_index = process_stride * loader->process_rank;
    loader->end_example_index = process_stride * (loader->process_rank + 1);
    if (loader->end_example_index > loader->num_examples) {
        loader->end_example_index = loader->num_examples;
    }
    // now seek through the file to the start of that example
    // utilize <EXAMPLE_BYTES> for efficiency
    fseekCheck(loader->eval_file, header_bytes, SEEK_SET);
    for (int i = 0; i < loader->start_example_index; i++) {
        uint16_t example_header[3];
        // read 3 uint16_t values: <START_EXAMPLE>, <EXAMPLE_BYTES>, <EXAMPLE_INDEX>
        freadCheck(&example_header[0], sizeof(uint16_t), 3, loader->eval_file);
        // validate the <START_EXAMPLE> delimiter
        assert(example_header[0] == 65535); // <START_EXAMPLE> delimiter
        // validate the <EXAMPLE_INDEX>
        assert(example_header[2] == i); // <EXAMPLE_INDEX> should match the loop index
        // skip to the next example, keeping in mind that we already read the header
        size_t remaining_bytes = example_header[1] - sizeof(uint16_t) * 3;
        assert(remaining_bytes > 0); // we expect some bytes in the example
        fseekCheck(loader->eval_file, remaining_bytes, SEEK_CUR);
    }
    // now we are at the start of the example we want to start at, pointing at <START_EXAMPLE>
}
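// note on the split above: with the 10,042 HellaSwag val examples and 8 processes, the stride is
// 10042 / 8 = 1255, so ranks cover [0,1255), [1255,2510), ..., [8785,10040); the last
// num_examples % num_processes examples (here 2) are not assigned to any rank.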

void evalloader_init(EvalLoader *loader,
                     const char* filename,
                     size_t B,
                     size_t T,
                     int process_rank,
                     int num_processes) {
    loader->process_rank = process_rank;
    loader->num_processes = num_processes;
    loader->B = B;
    loader->T = T;

    // open the file and validate the header
    loader->eval_file = fopenCheck(filename, "rb");
    // validate the header
    int header[HEADER_SIZE];
    freadCheck(header, sizeof(int), HEADER_SIZE, loader->eval_file);
    if (header[0] != 20240522) { printf("Bad magic in eval file\n"); exit(EXIT_FAILURE); }
    if (header[1] != 1) { printf("Bad version in eval file\n"); exit(EXIT_FAILURE); }
    loader->num_examples = header[2]; // number of examples in the file
    assert(loader->num_examples >= num_processes); // avoid headaches for now
    size_t longest_example_bytes = header[3]; // longest example in the file, in bytes
    // basic sensibility check we could relax later. but roughly it's mostly
    // the prompt/context and 4 completions, 2 bytes/token, so the longest example
    // should be well below 5 times the context length or so (approx. napkin math)
    assert(longest_example_bytes > 0 && longest_example_bytes < 5*T*2);

    // allocate all the space we'll need
    loader->buffer = (uint16_t*)malloc(longest_example_bytes);
    loader->inputs = (int*)malloc(B * T * sizeof(int));
    loader->targets = (int*)malloc(B * T * sizeof(int));
    loader->mask = (char*)malloc(B * T * sizeof(char));
    loader->label = -1; // initialize the label to an invalid value

    // reset the loader, to initialize it
    evalloader_reset(loader);
}

void evalloader_next_batch(EvalLoader *loader) {
    // this function populates the inputs, targets, mask, and label fields
    size_t B = loader->B;
    size_t T = loader->T;
    // read the current example header
    uint16_t example_header[3];
    freadCheck(&example_header[0], sizeof(uint16_t), 3, loader->eval_file);
    // validate the <START_EXAMPLE> delimiter
    assert(example_header[0] == 65535); // <START_EXAMPLE> delimiter
    // validate the <EXAMPLE_INDEX>
    assert(example_header[2] >= loader->start_example_index && example_header[2] < loader->end_example_index);
    // read the rest of the example (we have space for 3 more uint16_t values in buffer, it's ok)
    size_t example_bytes = example_header[1] - sizeof(uint16_t) * 3;
    // read example_bytes into buffer. careful that this is actually in the units of bytes
    freadCheck(loader->buffer, sizeof(char), example_bytes, loader->eval_file);
    // process the example label
    int label = (int)loader->buffer[0];
    assert(label >= 0 && label < 4); // we expect the label to be in [0, 4) for right now
    loader->label = label; // store for output
    // process the number of completions
    int num_completions = (int)loader->buffer[1];
    assert(num_completions == 4); // we expect 4 completions for now
    loader->num_completions = num_completions; // store for output
    // init all inputs, targets, mask to zeros
    memset(loader->inputs, 0, B * T * sizeof(int));
    memset(loader->targets, 0, B * T * sizeof(int));
    memset(loader->mask, 0, B * T * sizeof(char));
    // process the context
    // the context is shared for all completions, so we insert it into all data rows equally
    int context_length = (int)loader->buffer[2];
    uint16_t *context_tokens_start = &loader->buffer[3]; // where the tokens start
    assert(context_length > 0 && context_length < T); // context is non-empty and up to T
    for (int b = 0; b < num_completions; b++) {
        for (int i = 0; i < context_length; i++) {
            int tok_cur = (int)context_tokens_start[i];
            loader->inputs[b * T + i] = tok_cur;
        }
    }
    // process the completions, insert them in their row, right after the (shared) context
    uint16_t *completions_iter = loader->buffer + 3 + context_length;
    for (int c = 0; c < num_completions; c++) {
        int completion_length = (int)completions_iter[0];
        uint16_t *completion_tokens_start = completions_iter + 1;
        assert(completion_length > 0 && context_length + completion_length < T); // things fit?
        for (int i = 0; i < completion_length; i++) {
            int tok_cur = (int)completion_tokens_start[i];
            // at inputs, the completions simply follow the context
            loader->inputs[c * T + context_length + i] = tok_cur;
            // at targets things start to get tricky
            // we expect the last context token to predict the first completion token
            // and then onwards from there.
            loader->targets[c * T + context_length + i - 1] = tok_cur;
            // and at these positions, we want to set mask=1, because these are the
            // positions where we want to average the loss, in each row, to determine
            // its overall probability of following the context.
            loader->mask[c * T + context_length + i - 1] = 1;
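            // e.g. with a 3-token context [c0,c1,c2] and a 2-token completion [e0,e1], this row's
            // inputs are [c0,c1,c2,e0,e1,0,...]; targets get e0,e1 at positions 2,3 and mask=1 there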
        }
        completions_iter += 1 + completion_length; // move to the next completion
    }
}
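For orientation, here is a minimal sketch of how this loader might be driven in an eval loop. It is illustrative only and not taken from the diff above: forward_per_token_losses() is a hypothetical helper standing in for the model forward pass (it is assumed to write the cross-entropy loss at every target position into losses[b * T + t]), and the model handle and B, T, rank, world_size come from the surrounding program; the EvalLoader API and the hellaswag_val.bin path are the ones introduced in this commit.

// illustrative sketch, assuming dataloader.h is included
void run_eval_sketch(void *model, size_t B, size_t T, int rank, int world_size) {
    EvalLoader loader;
    evalloader_init(&loader, "dev/data/hellaswag/hellaswag_val.bin", B, T, rank, world_size);
    float *losses = (float*)malloc(B * T * sizeof(float));
    int correct = 0, total = 0;
    for (int i = loader.start_example_index; i < loader.end_example_index; i++) {
        evalloader_next_batch(&loader);
        forward_per_token_losses(model, loader.inputs, loader.targets, losses); // assumed helper
        // score each completion row by its mean loss over the masked (completion) positions
        int pred = -1;
        float best = 0.0f;
        for (int c = 0; c < loader.num_completions; c++) {
            float sum = 0.0f;
            int count = 0;
            for (int t = 0; t < (int)T; t++) {
                if (loader.mask[c * T + t]) { sum += losses[c * T + t]; count++; }
            }
            float mean_loss = sum / count;
            if (pred == -1 || mean_loss < best) { best = mean_loss; pred = c; }
        }
        total++;
        if (pred == loader.label) { correct++; }
    }
    printf("acc_norm: %.4f (%d/%d)\n", (double)correct / total, correct, total);
    free(losses);
}

Scoring by the mean loss over masked positions corresponds to the length-normalized acc_norm metric; summing the losses instead would give the unnormalized acc variant that hellaswag.py also reports.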
61 changes: 61 additions & 0 deletions dev/data/data_common.py
@@ -45,3 +45,64 @@ def write_datafile(filename, toks):
    with open(filename, "wb") as f:
        f.write(header.tobytes())
        f.write(toks_np.tobytes())

def write_evalfile(filename, datas):
    """
    Saves eval data as a .bin file, for reading in C.
    Used for multiple-choice style evals, e.g. HellaSwag and MMLU
    - First comes a header with 256 int32s
    - The examples follow, each example is a stream of uint16_t:
    - <START_EXAMPLE> delimiter of 2**16-1, i.e. 65,535
    - <EXAMPLE_BYTES>, bytes encoding this example, allowing efficient skip to next
    - <EXAMPLE_INDEX>, the index of the example in the dataset
    - <LABEL>, the index of the correct completion
    - <NUM_COMPLETIONS>, indicating the number of completions (usually 4)
    - <NUM><CONTEXT_TOKENS>, where <NUM> is the number of tokens in the context
    - <NUM><COMPLETION_TOKENS>, repeated NUM_COMPLETIONS times
    """
    # construct the header
    header = np.zeros(256, dtype=np.int32)
    header[0] = 20240522 # magic
    header[1] = 1 # version
    header[2] = len(datas) # number of examples
    header[3] = 0 # reserved for longest_example_bytes, fill in later
    # now write the individual examples
    longest_example_bytes = 0 # in units of bytes
    full_stream = [] # the stream of uint16s, we'll write a single time at the end
    assert len(datas) < 2**16, "too many examples?"
    for idx, data in enumerate(datas):
        stream = []
        # header of the example
        stream.append(2**16-1) # <START_EXAMPLE>
        stream.append(0) # <EXAMPLE_BYTES> (fill in later)
        stream.append(idx) # <EXAMPLE_INDEX>
        stream.append(data["label"]) # <LABEL>
        ending_tokens = data["ending_tokens"]
        assert len(ending_tokens) == 4, "expected 4 completions for now? can relax later"
        stream.append(len(ending_tokens)) # <NUM_COMPLETIONS>
        # the (shared) context tokens
        ctx_tokens = data["ctx_tokens"]
        assert all(0 <= t < 2**16-1 for t in ctx_tokens), "bad context token"
        stream.append(len(ctx_tokens))
        stream.extend(ctx_tokens)
        # the completion tokens
        for end_tokens in ending_tokens:
            assert all(0 <= t < 2**16-1 for t in end_tokens), "bad completion token"
            stream.append(len(end_tokens))
            stream.extend(end_tokens)
        # write to full stream
        nbytes = len(stream)*2 # 2 bytes per uint16
        assert nbytes < 2**16, "example too large?"
        stream[1] = nbytes # fill in the <EXAMPLE_BYTES> field
        longest_example_bytes = max(longest_example_bytes, nbytes)
        full_stream.extend(stream)
    # construct the numpy array
    stream_np = np.array(full_stream, dtype=np.uint16)
    # fill in the longest_example field
    assert 0 < longest_example_bytes < 2**16, f"bad longest_example"
    header[3] = longest_example_bytes
    # write to file (for HellaSwag val this is 10,042 examples, 3.6MB file)
    print(f"writing {len(datas):,} examples to {filename}")
    with open(filename, "wb") as f:
        f.write(header.tobytes())
        f.write(stream_np.tobytes())
33 changes: 28 additions & 5 deletions dev/data/hellaswag.py
@@ -1,6 +1,7 @@
"""
Downloads and evaluates HellaSwag in Python.
This then acts as the reference file for llm.c
Also writes the data (tokens, labels) to .bin files for parallel evaluation in C.
https://github.com/rowanz/hellaswag
Example HellaSwag json item:
@@ -22,6 +23,8 @@
gpt2-xl (1558M)
- eleuther harness reports acc 40.04%, acc_norm 50.89% (multiple choice style)
- this script: 10042 acc: 0.3842 acc_norm: 0.4893 (completion style)
The validation set of HellaSwag has a total of 10,042 examples.
"""

import os
@@ -33,7 +36,7 @@
import torch.nn as nn
from torch.nn import functional as F
from transformers import GPT2LMHeadModel
-from data_common import download_file
+from data_common import download_file, write_evalfile

# -----------------------------------------------------------------------------
DATA_CACHE_DIR = os.path.join(os.path.dirname(__file__), "hellaswag")
@@ -68,14 +71,23 @@ def render_example(example):
    label = example["label"]
    endings = example["endings"]

    # data needed to reproduce this eval on the C side
    data = {
        "label": label,
        "ctx_tokens": None,
        "ending_tokens": [],
    }

    # gather up all the tokens
    ctx_tokens = enc.encode(ctx)
    data["ctx_tokens"] = ctx_tokens
    tok_rows = []
    mask_rows = []
    for end in endings:
        end_tokens = enc.encode(" " + end) # note: prepending " " because GPT-2 tokenizer
        tok_rows.append(ctx_tokens + end_tokens)
        mask_rows.append([0]*len(ctx_tokens) + [1]*len(end_tokens))
        data["ending_tokens"].append(end_tokens)

    # have to be careful during the collation because the number of tokens in each row can differ
    max_len = max(len(row) for row in tok_rows)
@@ -85,17 +97,22 @@ def render_example(example):
        tokens[i, :len(tok_row)] = torch.tensor(tok_row)
        mask[i, :len(mask_row)] = torch.tensor(mask_row)

-    return tokens, mask, label
+    return data, tokens, mask, label

def iterate_examples(split):
    # there are 10,042 examples in total in val

    n = 0
    download(split)
    with open(os.path.join(DATA_CACHE_DIR, f"hellaswag_{split}.jsonl"), "r") as f:
        for line in f:
            example = json.loads(line)
            n += 1
            yield example

            # DEBUGGING, TODO REMOVE
            if n >= 100:
                break

@torch.no_grad()
def evaluate(model_type, device):

@@ -105,11 +122,13 @@ def evaluate(model_type, device):
    model.to(device)
    # model = torch.compile(model)

    datas = []
    num_correct_norm = 0
    num_correct = 0
    num_total = 0
    for example in iterate_examples("val"):
-        tokens, mask, label = render_example(example)
+        data, tokens, mask, label = render_example(example)
        datas.append(data)
        tokens = tokens.to(device)
        mask = mask.to(device)

@@ -146,7 +165,11 @@ def evaluate(model_type, device):
        print(f"Endings:")
        for i, end in enumerate(example["endings"]):
            print(f"{i} (loss: {avg_loss[i].item():.4f}) {end}")
-        print(f"predicted: {pred}, actual: {label}")
+        print(f"predicted: {pred_norm}, actual: {label}")

    # now write the data to a .bin file
    filename = os.path.join(DATA_CACHE_DIR, f"hellaswag_val.bin")
    write_evalfile(filename, datas)

if __name__ == "__main__":
    import argparse