extend dataloader to be sharded
karpathy committed May 21, 2024
1 parent 967420d commit 3131028
Showing 3 changed files with 93 additions and 49 deletions.
110 changes: 78 additions & 32 deletions dataloader.h
@@ -2,6 +2,7 @@
Implements a reasonably simple DataLoader for a distributed training setup.
*/

#include <glob.h>
#include <stdio.h>
#include <stdlib.h>
#include <stddef.h>
@@ -23,6 +24,8 @@ typedef struct {
size_t B;
size_t T;
// input handling and its state
glob_t glob_result; // stores the result of glob, for all shards we want to iterate
int current_shard; // the current shard we are reading from
FILE* tokens_file;
long file_size;
long current_position;
@@ -34,25 +37,13 @@ typedef struct {
size_t num_batches;
} DataLoader;

void dataloader_reset(DataLoader *loader) {
// each process starts at a different offset in the file
long header_bytes = HEADER_SIZE * sizeof(int);
long token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t);
loader->current_position = header_bytes + token_bytes_offset;
}

void dataloader_init(DataLoader *loader,
const char* filename,
size_t B,
size_t T,
int process_rank,
int num_processes) {
loader->process_rank = process_rank;
loader->num_processes = num_processes;
loader->B = B;
loader->T = T;

// open the input file for reading
long dataloader_load_shard_(DataLoader *loader, int shard_index) {
// look up the filename of the shard at shard_index in the glob results
const char* filename = loader->glob_result.gl_pathv[shard_index];
// open the input file for reading. also only a single file can be opened at a time
if (loader->tokens_file != NULL) {
fcloseCheck(loader->tokens_file);
}
loader->tokens_file = fopenCheck(filename, "rb");
// validate the header
int header[HEADER_SIZE];
@@ -65,7 +56,7 @@ void dataloader_init(DataLoader *loader,
}
if (header[1] != 1) { printf("Bad version in data file\n"); exit(EXIT_FAILURE); }
long ntok = header[2]; // number of tokens in the file

assert(ntok > 0); // we expect some tokens in the file. this should never trip, right?
// determine the file size and make sure it is consistent with the number of tokens
fseekCheck(loader->tokens_file, 0, SEEK_END); // seek to end of file
loader->file_size = ftell(loader->tokens_file); // read the offset, i.e. file size
@@ -76,31 +67,80 @@ void dataloader_init(DataLoader *loader,
printf("Error: file size is not as expected\n");
exit(EXIT_FAILURE);
}
if (ntok < num_processes * B * T + 1) {
// being too defensive/lazy, we could tolerate as low as T+1 tokens in principle
printf("Error: there are too few tokens\n");
return ntok;
}

void dataloader_reset(DataLoader *loader) {
// fully resets the DataLoader object to init configuration
// each process starts at a different offset in the file
long header_bytes = HEADER_SIZE * sizeof(int);
long token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t);
loader->current_shard = 0;
loader->current_position = header_bytes + token_bytes_offset;
dataloader_load_shard_(loader, loader->current_shard);
}

void dataloader_advance_(DataLoader *loader) {
// advance the loader by loading the next data shard and resetting the position
if (loader->glob_result.gl_pathc > 1) {
// if we have more than one shard, advance to the next one
loader->current_shard = (loader->current_shard + 1) % loader->glob_result.gl_pathc;
dataloader_load_shard_(loader, loader->current_shard);
}
long header_bytes = HEADER_SIZE * sizeof(int);
long token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t);
loader->current_position = header_bytes + token_bytes_offset;
}

void dataloader_init(DataLoader *loader,
const char* filename_pattern,
size_t B,
size_t T,
int process_rank,
int num_processes) {
loader->process_rank = process_rank;
loader->num_processes = num_processes;
loader->B = B;
loader->T = T;
loader->tokens_file = NULL;

// glob to get the list of files matching the pattern, these are our data shards
int glob_status = glob(filename_pattern, 0, NULL, &loader->glob_result);
if (glob_status != 0) {
printf("Error: failed to glob pattern: %s\n", filename_pattern);
exit(EXIT_FAILURE);
}
if (loader->glob_result.gl_pathc == 0) {
printf("Error: no files found matching the pattern: %s\n", filename_pattern);
exit(EXIT_FAILURE);
}

// allocate space for B*T + 1 integers to store the inputs and targets
// inspect and validate all shards so we don't get any runtime errors later
// if too slow / too many shards, may wish to revisit later
long ntok_total = 0;
for (int shard_index = 0; shard_index < loader->glob_result.gl_pathc; shard_index++) {
long shard_ntok = dataloader_load_shard_(loader, shard_index);
// we need at least one batch/shard, the way things are written right now.
// can be relaxed a lot later.
assert(shard_ntok >= num_processes * B * T + 1);
ntok_total += shard_ntok;
}
printf("DataLoader: filename_pattern: %s\n", filename_pattern);
printf("DataLoader: Found %ld tokens across %zu shards\n", ntok_total, loader->glob_result.gl_pathc);

// allocate all the space we'll need
loader->buffer = (uint16_t*)malloc((B * T + 1) * sizeof(uint16_t));
loader->inputs = (int*)malloc(B * T * sizeof(int));
loader->targets = (int*)malloc(B * T * sizeof(int));
// note: we definitely want to advance by B * T; That is the "stride" by which we move
// the window of tokens. We only load B * T + 1 tokens because our targets are offset by 1
loader->num_batches = ntok / (num_processes * B * T);
loader->num_batches = ntok_total / (num_processes * B * T); // useful to know

// reset the loader to the beginning of the file
// reset the loader, to initialize it
dataloader_reset(loader);
}

void dataloader_next_batch(DataLoader *loader) {
size_t B = loader->B;
size_t T = loader->T;
// if we are at the end of the file, loop back to the beginning
if (loader->current_position + (loader->num_processes * B * T + 1) * sizeof(uint16_t) > loader->file_size) {
dataloader_reset(loader);
}
// read B*T+1 uint16_t tokens from the file into buffer
fseekCheck(loader->tokens_file, loader->current_position, SEEK_SET);
freadCheck(loader->buffer, sizeof(uint16_t), B*T+1, loader->tokens_file);
@@ -111,12 +151,18 @@ void dataloader_next_batch(DataLoader *loader) {
}
// advance the current position by B*T*num_processes integers
// note: the "stride" of tokens by which we move each time is definitely B * T
// we only load B * T + 1 tokens at each iteration because the targets are offset by 1
loader->current_position += loader->num_processes * B * T * sizeof(uint16_t);
// if the next batch would go past the end of the file, advance the loader
if (loader->current_position + (loader->num_processes * B * T + 1) * sizeof(uint16_t) > loader->file_size) {
dataloader_advance_(loader);
}
}

void dataloader_free(DataLoader *loader) {
free(loader->buffer);
free(loader->inputs);
free(loader->targets);
fcloseCheck(loader->tokens_file);
globfree(&loader->glob_result);
}
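
For orientation, here is a minimal usage sketch of the sharded loader as it looks after this change: initialize it with a glob pattern that matches one or more .bin token shards, pull batches in a loop, and free it at the end. Because dataloader_advance_ wraps around with a modulo over glob_result.gl_pathc, the loader cycles through the matched shards, so the loop can run for more steps than a single shard holds. The pattern string, the B/T values, and the single-process rank/num_processes below are illustrative placeholders, not settings taken from the repo.

#include <stdio.h>
#include "dataloader.h"

int main(void) {
    DataLoader loader;
    size_t B = 4, T = 64;
    // hypothetical shard pattern; any glob matching one or more token files works
    dataloader_init(&loader, "data/mydataset_train_*.bin", B, T,
                    /*process_rank=*/0, /*num_processes=*/1);
    printf("batches per epoch: %zu\n", loader.num_batches);
    for (int step = 0; step < 100; step++) {
        dataloader_next_batch(&loader);
        // loader.inputs and loader.targets each hold B*T ints,
        // with targets shifted one token relative to inputs
        // ... forward/backward on this batch would go here ...
    }
    dataloader_free(&loader);
    return 0;
}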
3 changes: 2 additions & 1 deletion dev/data/fineweb.py
@@ -72,7 +72,8 @@ def tokenize(doc):

# if we reach shard_size tokens, write shard to disk
if len(all_tokens) >= args.shard_size:
filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{shard_index:06d}.bin")
split = "val" if shard_index == 0 else "train"
filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin")
write_tokens = all_tokens[:args.shard_size]
rest_tokens = all_tokens[args.shard_size:]
write_datafile(filename, write_tokens)
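
With this rename, shard 0 is written out as a "val" shard and every subsequent shard as a "train" shard, so the two splits can be selected purely by filename pattern. The resulting names follow f"fineweb_{split}_{shard_index:06d}.bin" under DATA_CACHE_DIR; an illustrative listing (directory name and shard count are hypothetical):

DATA_CACHE_DIR/fineweb_val_000000.bin
DATA_CACHE_DIR/fineweb_train_000001.bin
DATA_CACHE_DIR/fineweb_train_000002.bin
...

A pattern like "fineweb_train_*.bin" then matches exactly the train shards for the glob-based DataLoader above, while "fineweb_val_*.bin" matches the val shard.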
29 changes: 13 additions & 16 deletions train_gpt2.cu
@@ -2541,12 +2541,10 @@ void logger_free(Logger *logger) {
// CLI, poor man's argparse

void error_usage() {
// default run = debugging run with TinyShakespeare
// bigger run = train on TinyStories! e.g. val/sample less often, but sample more tokens, write to logfile
fprintf(stderr, "Usage: ./train_gpt2cu [options]\n");
fprintf(stderr, "Example: ./train_gpt2cu -i dev/data/tinystories/TinyStories -v 100 -s 100 -g 144 -o stories.log\n");
fprintf(stderr, "Options:\n");
fprintf(stderr, " -i <string> input dataset prefix (default = dev/data/tinyshakespeare/tiny_shakespeare)\n");
fprintf(stderr, " -i <string> train data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_train.bin)\n");
fprintf(stderr, " -j <string> val data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_val.bin)\n");
fprintf(stderr, " -e <string> input model filename (default = gpt2_124M_bf16.bin)\n");
fprintf(stderr, " -o <string> output log file (default = NULL)\n");
fprintf(stderr, " -b <int> (per-GPU, micro) batch size B (default = 4)\n");
@@ -2572,7 +2570,8 @@ int main(int argc, char *argv[]) {
multi_gpu_config = multi_gpu_config_init(&argc, &argv);

// read in the (optional) command line arguments
const char* input_dataset_prefix = "dev/data/tinyshakespeare/tiny_shakespeare"; // or e.g. data/TinyStories
const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin";
const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights of the model
const char* output_log_file = NULL;
int B = 4; // batch size
@@ -2595,7 +2594,8 @@ int main(int argc, char *argv[]) {
if (argv[i][0] != '-') { error_usage(); } // must start with dash
if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
// read in the args
if (argv[i][1] == 'i') { input_dataset_prefix = argv[i+1]; }
if (argv[i][1] == 'i') { train_data_pattern = argv[i+1]; }
else if (argv[i][1] == 'j') { val_data_pattern = argv[i+1]; }
else if (argv[i][1] == 'e') { load_filename = argv[i+1]; }
else if (argv[i][1] == 'o') { output_log_file = argv[i+1]; }
else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } // Per-GPU (micro) batch size
@@ -2617,10 +2617,14 @@ int main(int argc, char *argv[]) {
}
// calculate a sensible default for total batch size by assuming no gradient accumulation
if (total_batch_size == -1) { total_batch_size = B * T * multi_gpu_config.num_processes; }
// if we're only overfitting a single batch for debugging, let's overfit the first batch
// from val instead of train split, because val is smaller and faster. (train_gpt2.py does the same)
if (overfit_single_batch == 1) { train_data_pattern = val_data_pattern; }
printf0("+-----------------------+----------------------------------------------------+\n");
printf0("| Parameter | Value |\n");
printf0("+-----------------------+----------------------------------------------------+\n");
printf0("| input dataset prefix | %-50s |\n", input_dataset_prefix);
printf0("| train data pattern | %-50s |\n", train_data_pattern);
printf0("| val data pattern | %-50s |\n", val_data_pattern);
printf0("| output log file | %-50s |\n", output_log_file == NULL ? "NULL" : output_log_file);
printf0("| micro batch size B | %-50d |\n", B);
printf0("| sequence length T | %-50d |\n", T);
@@ -2663,16 +2667,9 @@ int main(int argc, char *argv[]) {
printf0("+-----------------------+----------------------------------------------------+\n");

// build DataLoaders for both train and val
char train_tokens_filename[128], val_tokens_filename[128];
assert(strlen(input_dataset_prefix) < 100); // being bit lazy here, make sure we don't overflow
// if we're only overfitting a single batch for debugging, let's overfit the first batch
// from val instead of train split, because val is smaller and a bit faster
const char* train_split = (overfit_single_batch == 1) ? "val" : "train";
sprintf(train_tokens_filename, "%s_%s.bin", input_dataset_prefix, train_split);
sprintf(val_tokens_filename, "%s_val.bin", input_dataset_prefix);
DataLoader train_loader, val_loader;
dataloader_init(&train_loader, train_tokens_filename, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes);
dataloader_init(&val_loader, val_tokens_filename, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes);
dataloader_init(&train_loader, train_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes);
dataloader_init(&val_loader, val_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes);
int train_num_batches = (max_steps == -1) ? train_loader.num_batches : max_steps; // default = 1 epoch
int val_num_batches = train_loader.num_batches < val_max_batches ? train_loader.num_batches : val_max_batches;
printf0("| train_num_batches | %-50d |\n", train_num_batches);