continued changes for sharded dataloader
karpathy committed May 22, 2024
1 parent 3131028 commit edb0df9
Showing 4 changed files with 47 additions and 45 deletions.
16 changes: 8 additions & 8 deletions dataloader.h
@@ -16,8 +16,8 @@ Implements a medium simple DataLoader for a distributed training setup.
 #define HEADER_SIZE 256
 
 typedef struct {
-    // Distributed data parallel specifics.
-    // Each worker loads it's own chunk of data.
+    // variables related to distributed training
+    // each process/worker has to access different parts of the data
     int process_rank;
     int num_processes;
     // hyperparameters. use size_t to prevent overflow
@@ -29,12 +29,11 @@ typedef struct {
     FILE* tokens_file;
     long file_size;
     long current_position;
-    // outputs
-    uint16_t* buffer; // used to fread data from file into
+    uint16_t* buffer; // we fread data from file into this buffer
+    // public variables that could be accessed from outside
+    size_t num_batches;
     int* inputs; // input tokens into transformer
     int* targets; // target tokens for the transformer
-    // convenience variables
-    size_t num_batches;
 } DataLoader;
 
 long dataloader_load_shard_(DataLoader *loader, int shard_index) {
@@ -125,8 +124,9 @@ void dataloader_init(DataLoader *loader,
         assert(shard_ntok >= num_processes * B * T + 1);
         ntok_total += shard_ntok;
     }
-    printf("DataLoader: filename_pattern: %s\n", filename_pattern);
-    printf("DataLoader: Found %ld tokens across %zu shards\n", ntok_total, loader->glob_result.gl_pathc);
+    // debugging prints
+    // printf("DataLoader: filename_pattern: %s\n", filename_pattern);
+    // printf("DataLoader: Found %ld tokens across %zu shards\n", ntok_total, loader->glob_result.gl_pathc);
 
     // allocate all the space we'll need
     loader->buffer = (uint16_t*)malloc((B * T + 1) * sizeof(uint16_t));
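
To make the sharded format concrete, here is a minimal Python sketch (not code from this commit) of how a single process could read such shard files and slice them into inputs/targets batches. It assumes the header layout implied by HEADER_SIZE 256 above, i.e. 256 int32 values with the token count at index 2 followed by the tokens as uint16, and the rank/stride scheme suggested by the process_rank and num_processes fields; the function name and exact header indices are assumptions, not part of the repository.

import glob
import numpy as np

def iter_batches(filename_pattern, B, T, process_rank=0, num_processes=1):
    # hypothetical reader for the shard format sketched above
    for shard in sorted(glob.glob(filename_pattern)):
        with open(shard, "rb") as f:
            header = np.frombuffer(f.read(256 * 4), dtype=np.int32)  # HEADER_SIZE int32 values
            ntok = int(header[2])  # assumed: token count stored at header index 2
            tokens = np.frombuffer(f.read(ntok * 2), dtype=np.uint16)
        pos = process_rank * B * T  # each rank starts at its own offset in the shard
        while pos + B * T + 1 <= ntok:
            buf = tokens[pos : pos + B * T + 1].astype(np.int32)
            inputs = buf[:-1].reshape(B, T)   # input tokens into the transformer
            targets = buf[1:].reshape(B, T)   # targets are the inputs shifted by one
            yield inputs, targets
            pos += num_processes * B * T      # stride past the other ranks' batches

The B * T + 1 tokens read per batch mirror the (B * T + 1) * sizeof(uint16_t) buffer allocated in dataloader_init above.
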
50 changes: 28 additions & 22 deletions dev/data/fineweb.py
@@ -55,30 +55,36 @@ def tokenize(doc):
     return enc.encode_ordinary(doc["text"])
 
 # main loop write files
-pool = mp.Pool()
-shard_index = 0
-all_tokens = []
-progress_bar = None
-for tokens in pool.imap(tokenize, fw):
+with mp.Pool() as pool:
+    shard_index = 0
+    all_tokens = []
+    progress_bar = None
+    for tokens in pool.imap(tokenize, fw):
 
-    # record the tokens and make sure to separate documents
-    all_tokens.append(eot)
-    all_tokens.extend(tokens)
+        # record the tokens and make sure to separate documents
+        all_tokens.append(eot)
+        all_tokens.extend(tokens)
 
-    # update progress bar
-    if progress_bar is None:
-        progress_bar = tqdm(total=args.shard_size, unit="tokens", desc=f"Shard {shard_index}")
-    progress_bar.update(len(tokens))
+        # update progress bar
+        if progress_bar is None:
+            progress_bar = tqdm(total=args.shard_size, unit="tokens", desc=f"Shard {shard_index}")
+        progress_bar.update(len(tokens))
 
-    # if we reach shard_size tokens, write shard to disk
-    if len(all_tokens) >= args.shard_size:
+        # if we reach shard_size tokens, write shard to disk
+        if len(all_tokens) >= args.shard_size:
+            split = "val" if shard_index == 0 else "train"
+            filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin")
+            write_tokens = all_tokens[:args.shard_size]
+            rest_tokens = all_tokens[args.shard_size:]
+            write_datafile(filename, write_tokens)
+            shard_index += 1
+            progress_bar = None
+            # note: create a copy so Python can free the all_tokens memory above
+            # the list rest_tokens is expected to be very small
+            all_tokens = [t for t in rest_tokens]
+
+    # write any remaining tokens as the last shard
+    if len(all_tokens) > 0:
         split = "val" if shard_index == 0 else "train"
         filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin")
-        write_tokens = all_tokens[:args.shard_size]
-        rest_tokens = all_tokens[args.shard_size:]
-        write_datafile(filename, write_tokens)
-        shard_index += 1
-        progress_bar = None
-        # note: create a copy so Python can free the all_tokens memory above
-        # the list rest_tokens is expected to be very small
-        all_tokens = [t for t in rest_tokens]
+        write_datafile(filename, all_tokens)
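
The write_datafile helper that this loop calls is not shown in this diff. As a rough sketch of what a compatible writer could look like (an assumption to illustrate the file format, not the repository's actual helper), it would emit a fixed-size header carrying the token count, followed by the tokens as uint16, so the C DataLoader above can fread them; the magic and version values below are placeholders.

import numpy as np

def write_datafile_sketch(filename, tokens):
    # hypothetical writer matching the header layout assumed in the reader sketch above
    header = np.zeros(256, dtype=np.int32)    # 256 int32 values, matching HEADER_SIZE
    header[0] = 20240520                      # placeholder magic number (assumption)
    header[1] = 1                             # placeholder format version (assumption)
    header[2] = len(tokens)                   # number of uint16 tokens that follow
    toks = np.array(tokens, dtype=np.uint16)  # GPT-2 token ids fit in 16 bits
    with open(filename, "wb") as f:
        f.write(header.tobytes())
        f.write(toks.tobytes())

Under that layout each shard costs two bytes per token plus a 1 KB header, and the rest_tokens carry-over above keeps every shard except the last at exactly shard_size tokens.
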
2 changes: 1 addition & 1 deletion train_gpt2.cu
@@ -2545,7 +2545,7 @@ void error_usage() {
fprintf(stderr, "Options:\n");
fprintf(stderr, " -i <string> train data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_train.bin)\n");
fprintf(stderr, " -j <string> val data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_val.bin)\n");
fprintf(stderr, " -e <string> input model filename (default = gpt2_124M_bf16.bin)\n");
fprintf(stderr, " -e <string> input from model at this filename (default = gpt2_124M_bf16.bin)\n");
fprintf(stderr, " -o <string> output log file (default = NULL)\n");
fprintf(stderr, " -b <int> (per-GPU, micro) batch size B (default = 4)\n");
fprintf(stderr, " -t <int> sequence length T (default = 1024)\n");
24 changes: 10 additions & 14 deletions train_gpt2_fp32.cu
@@ -1525,12 +1525,10 @@ void logger_free(Logger *logger) {
 // CLI, poor man's argparse
 
 void error_usage() {
-    // default run = debugging run with TinyShakespeare
-    // bigger run = train on TinyStories! e.g. val/sample less often, but sample more tokens, write to logfile
     fprintf(stderr, "Usage: ./train_gpt2fp32cu [options]\n");
-    fprintf(stderr, "Example: ./train_gpt2fp32cu -i dev/data/tinystories/TinyStories -v 100 -s 100 -g 144 -o stories.log\n");
     fprintf(stderr, "Options:\n");
-    fprintf(stderr, " -i <string> input dataset prefix (default = data/tiny_shakespeare)\n");
+    fprintf(stderr, " -i <string> train data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_train.bin)\n");
+    fprintf(stderr, " -j <string> val data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_val.bin)\n");
     fprintf(stderr, " -o <string> output log file (default = NULL)\n");
     fprintf(stderr, " -b <int> batch size B (default = 4)\n");
     fprintf(stderr, " -t <int> sequence length T (default = 1024)\n");
@@ -1547,7 +1545,8 @@ void error_usage() {
 int main(int argc, char *argv[]) {
 
     // read in the (optional) command line arguments
-    const char* input_dataset_prefix = "dev/data/tinyshakespeare/tiny_shakespeare"; // or e.g. data/TinyStories
+    const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin";
+    const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
     const char* output_log_file = NULL;
     int B = 4; // batch size
     int T = 1024; // sequence length max
@@ -1561,7 +1560,8 @@ int main(int argc, char *argv[]) {
         if (argv[i][0] != '-') { error_usage(); } // must start with dash
         if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
         // read in the args
-        if (argv[i][1] == 'i') { input_dataset_prefix = argv[i+1]; }
+        if (argv[i][1] == 'i') { train_data_pattern = argv[i+1]; }
+        else if (argv[i][1] == 'j') { val_data_pattern = argv[i+1]; }
         else if (argv[i][1] == 'o') { output_log_file = argv[i+1]; }
         else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); }
         else if (argv[i][1] == 't') { T = atoi(argv[i+1]); }
@@ -1575,7 +1575,8 @@ int main(int argc, char *argv[]) {
printf("+-----------------------+----------------------------------------------------+\n");
printf("| Parameter | Value |\n");
printf("+-----------------------+----------------------------------------------------+\n");
printf("| input dataset prefix | %-50s |\n", input_dataset_prefix);
printf("| train data pattern | %-50s |\n", train_data_pattern);
printf("| val data pattern | %-50s |\n", val_data_pattern);
printf("| output log file | %-50s |\n", output_log_file == NULL ? "NULL" : output_log_file);
printf("| batch size B | %-50d |\n", B);
printf("| sequence length T | %-50d |\n", T);
@@ -1617,14 +1618,9 @@ int main(int argc, char *argv[]) {
printf("+-----------------------+----------------------------------------------------+\n");

// build DataLoaders for both train and val
char train_tokens_filename[128];
char val_tokens_filename[128];
assert(strlen(input_dataset_prefix) < 100); // being bit lazy here, make sure we don't overflow
sprintf(train_tokens_filename, "%s_train.bin", input_dataset_prefix);
sprintf(val_tokens_filename, "%s_val.bin", input_dataset_prefix);
DataLoader train_loader, val_loader;
dataloader_init(&train_loader, train_tokens_filename, B, T, 0, 1);
dataloader_init(&val_loader, val_tokens_filename, B, T, 0, 1);
dataloader_init(&train_loader, train_data_pattern, B, T, 0, 1);
dataloader_init(&val_loader, val_data_pattern, B, T, 0, 1);
int train_num_batches = train_loader.num_batches; // let's do 1 epoch by default for now
int val_num_batches = train_loader.num_batches < val_max_batches ? train_loader.num_batches : val_max_batches;
printf("| train_num_batches | %-50d |\n", train_num_batches);
