
Commit 3131028

extend dataloader to be sharded
1 parent 967420d

3 files changed: +93 -49 lines


dataloader.h

Lines changed: 78 additions & 32 deletions
@@ -2,6 +2,7 @@
 Implements a medium simple DataLoader for a distributed training setup.
 */
 
+#include <glob.h>
 #include <stdio.h>
 #include <stdlib.h>
 #include <stddef.h>
@@ -23,6 +24,8 @@ typedef struct {
     size_t B;
     size_t T;
     // input handling and its state
+    glob_t glob_result; // stores the result of glob, for all shards we want to iterate
+    int current_shard; // the current shard we are reading from
     FILE* tokens_file;
     long file_size;
     long current_position;
@@ -34,25 +37,13 @@ typedef struct {
     size_t num_batches;
 } DataLoader;
 
-void dataloader_reset(DataLoader *loader) {
-    // each process starts at a different offset in the file
-    long header_bytes = HEADER_SIZE * sizeof(int);
-    long token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t);
-    loader->current_position = header_bytes + token_bytes_offset;
-}
-
-void dataloader_init(DataLoader *loader,
-                     const char* filename,
-                     size_t B,
-                     size_t T,
-                     int process_rank,
-                     int num_processes) {
-    loader->process_rank = process_rank;
-    loader->num_processes = num_processes;
-    loader->B = B;
-    loader->T = T;
-
-    // open the input file for reading
+long dataloader_load_shard_(DataLoader *loader, int shard_index) {
+    // use the first glob match as the filename for now
+    const char* filename = loader->glob_result.gl_pathv[shard_index];
+    // open the input file for reading. also only a single file can be opened at a time
+    if (loader->tokens_file != NULL) {
+        fcloseCheck(loader->tokens_file);
+    }
     loader->tokens_file = fopenCheck(filename, "rb");
     // validate the header
     int header[HEADER_SIZE];
@@ -65,7 +56,7 @@ void dataloader_init(DataLoader *loader,
     }
     if (header[1] != 1) { printf("Bad version in data file\n"); exit(EXIT_FAILURE); }
     long ntok = header[2]; // number of tokens in the file
-
+    assert(ntok > 0); // we expect some tokens in the file. this should never trip, right?
     // determine the file size and make sure it is consistent with the number of tokens
     fseekCheck(loader->tokens_file, 0, SEEK_END); // seek to end of file
     loader->file_size = ftell(loader->tokens_file); // read the offset, i.e. file size
@@ -76,31 +67,80 @@ void dataloader_init(DataLoader *loader,
         printf("Error: file size is not as expected\n");
         exit(EXIT_FAILURE);
     }
-    if (ntok < num_processes * B * T + 1) {
-        // being too defensive/lazy, we could tolerate as low as T+1 tokens in principle
-        printf("Error: there are too few tokens\n");
+    return ntok;
+}
+
+void dataloader_reset(DataLoader *loader) {
+    // fully resets the DataLoader object to init configuration
+    // each process starts at a different offset in the file
+    long header_bytes = HEADER_SIZE * sizeof(int);
+    long token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t);
+    loader->current_shard = 0;
+    loader->current_position = header_bytes + token_bytes_offset;
+    dataloader_load_shard_(loader, loader->current_shard);
+}
+
+void dataloader_advance_(DataLoader *loader) {
+    // advance the loader by loading the next data shard and resetting the position
+    if (loader->glob_result.gl_pathc > 1) {
+        // if we have more than one shard, advance to the next one
+        loader->current_shard = (loader->current_shard + 1) % loader->glob_result.gl_pathc;
+        dataloader_load_shard_(loader, loader->current_shard);
+    }
+    long header_bytes = HEADER_SIZE * sizeof(int);
+    long token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t);
+    loader->current_position = header_bytes + token_bytes_offset;
+}
+
+void dataloader_init(DataLoader *loader,
+                     const char* filename_pattern,
+                     size_t B,
+                     size_t T,
+                     int process_rank,
+                     int num_processes) {
+    loader->process_rank = process_rank;
+    loader->num_processes = num_processes;
+    loader->B = B;
+    loader->T = T;
+    loader->tokens_file = NULL;
+
+    // glob to get the list of files matching the pattern, these are our data shards
+    int glob_status = glob(filename_pattern, 0, NULL, &loader->glob_result);
+    if (glob_status != 0) {
+        printf("Error: failed to glob pattern: %s\n", filename_pattern);
+        exit(EXIT_FAILURE);
+    }
+    if (loader->glob_result.gl_pathc == 0) {
+        printf("Error: no files found matching the pattern: %s\n", filename_pattern);
         exit(EXIT_FAILURE);
     }
 
-    // allocate space for B*T + 1 integers to store the inputs and targets
+    // inspect and validate all shards so we don't get any runtime errors later
+    // if too slow / too many shards, may wish to revisit later
+    long ntok_total = 0;
+    for (int shard_index = 0; shard_index < loader->glob_result.gl_pathc; shard_index++) {
+        long shard_ntok = dataloader_load_shard_(loader, shard_index);
+        // we need at least one batch/shard, the way things are written right now.
+        // can be relaxed a lot later.
+        assert(shard_ntok >= num_processes * B * T + 1);
+        ntok_total += shard_ntok;
+    }
+    printf("DataLoader: filename_pattern: %s\n", filename_pattern);
+    printf("DataLoader: Found %ld tokens across %zu shards\n", ntok_total, loader->glob_result.gl_pathc);
+
+    // allocate all the space we'll need
     loader->buffer = (uint16_t*)malloc((B * T + 1) * sizeof(uint16_t));
     loader->inputs = (int*)malloc(B * T * sizeof(int));
     loader->targets = (int*)malloc(B * T * sizeof(int));
-    // note: we definitely want to advance by B * T; That is the "stride" by which we move
-    // the window of tokens. We only load B * T + 1 tokens because our targets are offset by 1
-    loader->num_batches = ntok / (num_processes * B * T);
+    loader->num_batches = ntok_total / (num_processes * B * T); // useful to know
 
-    // reset the loader to the beginning of the file
+    // reset the loader, to initialize it
     dataloader_reset(loader);
 }
 
 void dataloader_next_batch(DataLoader *loader) {
     size_t B = loader->B;
     size_t T = loader->T;
-    // if we are at the end of the file, loop back to the beginning
-    if (loader->current_position + (loader->num_processes * B * T + 1) * sizeof(uint16_t) > loader->file_size) {
-        dataloader_reset(loader);
-    }
     // read B*T+1 uint16_t tokens from the file into buffer
     fseekCheck(loader->tokens_file, loader->current_position, SEEK_SET);
     freadCheck(loader->buffer, sizeof(uint16_t), B*T+1, loader->tokens_file);
@@ -111,12 +151,18 @@ void dataloader_next_batch(DataLoader *loader) {
     }
     // advance the current position by B*T*num_processes integers
     // note: the "stride" of tokens by which we move each time is definitely B * T
+    // we only load B * T + 1 tokens at each iteration because the targets are offset by 1
     loader->current_position += loader->num_processes * B * T * sizeof(uint16_t);
+    // if the next batch would go past the end of the file, advance the loader
+    if (loader->current_position + (loader->num_processes * B * T + 1) * sizeof(uint16_t) > loader->file_size) {
+        dataloader_advance_(loader);
+    }
 }
 
 void dataloader_free(DataLoader *loader) {
     free(loader->buffer);
     free(loader->inputs);
     free(loader->targets);
     fcloseCheck(loader->tokens_file);
+    globfree(&loader->glob_result);
 }
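
To make the new contract concrete, here is a minimal usage sketch (not part of this commit): a single-process caller with arbitrarily chosen B, T, and step count. dataloader_init now receives a filename pattern that may match one file or many shards, and dataloader_next_batch hops to the next shard instead of looping within one file.

// minimal usage sketch of the sharded DataLoader, assuming a single-process run;
// B, T and the step count are arbitrary, and we assume dataloader.h's own
// includes (fopenCheck, freadCheck, ...) resolve as in the rest of the repo
#include "dataloader.h"

int main(void) {
    DataLoader loader;
    size_t B = 4, T = 64;
    // rank 0 of 1 process; the pattern may name a single file or glob many shards
    dataloader_init(&loader, "dev/data/tinyshakespeare/tiny_shakespeare_train.bin", B, T, 0, 1);
    for (int step = 0; step < 10; step++) {
        // fills loader.inputs and loader.targets with B*T ints each; once the
        // current shard is exhausted, the loader advances to the next one internally
        dataloader_next_batch(&loader);
    }
    dataloader_free(&loader); // closes the tokens file, frees buffers and glob results
    return 0;
}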

dev/data/fineweb.py

Lines changed: 2 additions & 1 deletion
@@ -72,7 +72,8 @@ def tokenize(doc):
 
         # if we reach shard_size tokens, write shard to disk
         if len(all_tokens) >= args.shard_size:
-            filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{shard_index:06d}.bin")
+            split = "val" if shard_index == 0 else "train"
+            filename = os.path.join(DATA_CACHE_DIR, f"fineweb_{split}_{shard_index:06d}.bin")
             write_tokens = all_tokens[:args.shard_size]
             rest_tokens = all_tokens[args.shard_size:]
             write_datafile(filename, write_tokens)
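
With this change, shard 0 becomes the val split and every later shard the train split, so the data directory holds names like fineweb_val_000000.bin, fineweb_train_000001.bin, fineweb_train_000002.bin, and so on. A self-contained sketch of listing such shards via glob, mirroring what dataloader_init now does internally (the path below is an assumption, not from the commit):

// lists the train shards matching an assumed FineWeb output path
#include <glob.h>
#include <stdio.h>

int main(void) {
    glob_t g;
    if (glob("dev/data/fineweb/fineweb_train_*.bin", 0, NULL, &g) != 0) {
        printf("no train shards found\n");
        return 1;
    }
    for (size_t i = 0; i < g.gl_pathc; i++) {
        printf("shard %zu: %s\n", i, g.gl_pathv[i]); // e.g. fineweb_train_000001.bin
    }
    globfree(&g); // release the match list
    return 0;
}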

train_gpt2.cu

Lines changed: 13 additions & 16 deletions
@@ -2541,12 +2541,10 @@ void logger_free(Logger *logger) {
 // CLI, poor man's argparse
 
 void error_usage() {
-    // default run = debugging run with TinyShakespeare
-    // bigger run = train on TinyStories! e.g. val/sample less often, but sample more tokens, write to logfile
     fprintf(stderr, "Usage:   ./train_gpt2cu [options]\n");
-    fprintf(stderr, "Example: ./train_gpt2cu -i dev/data/tinystories/TinyStories -v 100 -s 100 -g 144 -o stories.log\n");
     fprintf(stderr, "Options:\n");
-    fprintf(stderr, "  -i <string> input dataset prefix (default = dev/data/tinyshakespeare/tiny_shakespeare)\n");
+    fprintf(stderr, "  -i <string> train data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_train.bin)\n");
+    fprintf(stderr, "  -j <string> val data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_val.bin)\n");
     fprintf(stderr, "  -e <string> input model filename (default = gpt2_124M_bf16.bin)\n");
     fprintf(stderr, "  -o <string> output log file (default = NULL)\n");
     fprintf(stderr, "  -b <int>    (per-GPU, micro) batch size B (default = 4)\n");
@@ -2572,7 +2570,8 @@ int main(int argc, char *argv[]) {
     multi_gpu_config = multi_gpu_config_init(&argc, &argv);
 
     // read in the (optional) command line arguments
-    const char* input_dataset_prefix = "dev/data/tinyshakespeare/tiny_shakespeare"; // or e.g. data/TinyStories
+    const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin";
+    const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
     const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights of the model
     const char* output_log_file = NULL;
     int B = 4; // batch size
@@ -2595,7 +2594,8 @@ int main(int argc, char *argv[]) {
         if (argv[i][0] != '-') { error_usage(); } // must start with dash
         if (strlen(argv[i]) != 2) { error_usage(); } // must be -x (one dash, one letter)
         // read in the args
-        if (argv[i][1] == 'i') { input_dataset_prefix = argv[i+1]; }
+        if (argv[i][1] == 'i') { train_data_pattern = argv[i+1]; }
+        else if (argv[i][1] == 'j') { val_data_pattern = argv[i+1]; }
         else if (argv[i][1] == 'e') { load_filename = argv[i+1]; }
         else if (argv[i][1] == 'o') { output_log_file = argv[i+1]; }
         else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } // Per-GPU (micro) batch size
@@ -2617,10 +2617,14 @@ int main(int argc, char *argv[]) {
     }
     // calculate a sensible default for total batch size by assuming no gradient accumulation
     if (total_batch_size == -1) { total_batch_size = B * T * multi_gpu_config.num_processes; }
+    // if we're only overfitting a single batch for debugging, let's overfit the first batch
+    // from val instead of train split, because val is smaller and faster. (train_gpt2.py does the same)
+    if (overfit_single_batch == 1) { train_data_pattern = val_data_pattern; }
     printf0("+-----------------------+----------------------------------------------------+\n");
     printf0("| Parameter             | Value                                              |\n");
     printf0("+-----------------------+----------------------------------------------------+\n");
-    printf0("| input dataset prefix  | %-50s |\n", input_dataset_prefix);
+    printf0("| train data pattern    | %-50s |\n", train_data_pattern);
+    printf0("| val data pattern      | %-50s |\n", val_data_pattern);
     printf0("| output log file       | %-50s |\n", output_log_file == NULL ? "NULL" : output_log_file);
     printf0("| micro batch size B    | %-50d |\n", B);
     printf0("| sequence length T     | %-50d |\n", T);
@@ -2663,16 +2667,9 @@ int main(int argc, char *argv[]) {
     printf0("+-----------------------+----------------------------------------------------+\n");
 
     // build DataLoaders for both train and val
-    char train_tokens_filename[128], val_tokens_filename[128];
-    assert(strlen(input_dataset_prefix) < 100); // being bit lazy here, make sure we don't overflow
-    // if we're only overfitting a single batch for debugging, let's overfit the first batch
-    // from val instead of train split, because val is smaller and a bit faster
-    const char* train_split = (overfit_single_batch == 1) ? "val" : "train";
-    sprintf(train_tokens_filename, "%s_%s.bin", input_dataset_prefix, train_split);
-    sprintf(val_tokens_filename, "%s_val.bin", input_dataset_prefix);
     DataLoader train_loader, val_loader;
-    dataloader_init(&train_loader, train_tokens_filename, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes);
-    dataloader_init(&val_loader, val_tokens_filename, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes);
+    dataloader_init(&train_loader, train_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes);
+    dataloader_init(&val_loader, val_data_pattern, B, T, multi_gpu_config.process_rank, multi_gpu_config.num_processes);
     int train_num_batches = (max_steps == -1) ? train_loader.num_batches : max_steps; // default = 1 epoch
     int val_num_batches = train_loader.num_batches < val_max_batches ? train_loader.num_batches : val_max_batches;
     printf0("| train_num_batches     | %-50d |\n", train_num_batches);
