2
2
Implements a medium simple DataLoader for a distributed training setup.
3
3
*/
4
4
5
+ #include <glob.h>
5
6
#include <stdio.h>
6
7
#include <stdlib.h>
7
8
#include <stddef.h>
@@ -23,6 +24,8 @@ typedef struct {
23
24
size_t B ;
24
25
size_t T ;
25
26
// input handling and its state
27
+ glob_t glob_result ; // stores the result of glob, for all shards we want to iterate
28
+ int current_shard ; // the current shard we are reading from
26
29
FILE * tokens_file ;
27
30
long file_size ;
28
31
long current_position ;
@@ -34,25 +37,13 @@ typedef struct {
34
37
size_t num_batches ;
35
38
} DataLoader ;
36
39
37
- void dataloader_reset (DataLoader * loader ) {
38
- // each process starts at a different offset in the file
39
- long header_bytes = HEADER_SIZE * sizeof (int );
40
- long token_bytes_offset = loader -> process_rank * loader -> B * loader -> T * sizeof (uint16_t );
41
- loader -> current_position = header_bytes + token_bytes_offset ;
42
- }
43
-
44
- void dataloader_init (DataLoader * loader ,
45
- const char * filename ,
46
- size_t B ,
47
- size_t T ,
48
- int process_rank ,
49
- int num_processes ) {
50
- loader -> process_rank = process_rank ;
51
- loader -> num_processes = num_processes ;
52
- loader -> B = B ;
53
- loader -> T = T ;
54
-
55
- // open the input file for reading
40
+ long dataloader_load_shard_ (DataLoader * loader , int shard_index ) {
41
+ // use the first glob match as the filename for now
42
+ const char * filename = loader -> glob_result .gl_pathv [shard_index ];
43
+ // open the input file for reading. also only a single file can be opened at a time
44
+ if (loader -> tokens_file != NULL ) {
45
+ fcloseCheck (loader -> tokens_file );
46
+ }
56
47
loader -> tokens_file = fopenCheck (filename , "rb" );
57
48
// validate the header
58
49
int header [HEADER_SIZE ];
@@ -65,7 +56,7 @@ void dataloader_init(DataLoader *loader,
65
56
}
66
57
if (header [1 ] != 1 ) { printf ("Bad version in data file\n" ); exit (EXIT_FAILURE ); }
67
58
long ntok = header [2 ]; // number of tokens in the file
68
-
59
+ assert ( ntok > 0 ); // we expect some tokens in the file. this should never trip, right?
69
60
// determine the file size and make sure it is consistent with the number of tokens
70
61
fseekCheck (loader -> tokens_file , 0 , SEEK_END ); // seek to end of file
71
62
loader -> file_size = ftell (loader -> tokens_file ); // read the offset, i.e. file size
@@ -76,31 +67,80 @@ void dataloader_init(DataLoader *loader,
76
67
printf ("Error: file size is not as expected\n" );
77
68
exit (EXIT_FAILURE );
78
69
}
79
- if (ntok < num_processes * B * T + 1 ) {
80
- // being too defensive/lazy, we could tolerate as low as T+1 tokens in principle
81
- printf ("Error: there are too few tokens\n" );
70
+ return ntok ;
71
+ }
72
+
73
+ void dataloader_reset (DataLoader * loader ) {
74
+ // fully resets the DataLoader object to init configuration
75
+ // each process starts at a different offset in the file
76
+ long header_bytes = HEADER_SIZE * sizeof (int );
77
+ long token_bytes_offset = loader -> process_rank * loader -> B * loader -> T * sizeof (uint16_t );
78
+ loader -> current_shard = 0 ;
79
+ loader -> current_position = header_bytes + token_bytes_offset ;
80
+ dataloader_load_shard_ (loader , loader -> current_shard );
81
+ }
82
+
83
+ void dataloader_advance_ (DataLoader * loader ) {
84
+ // advance the loader by loading the next data shard and resetting the position
85
+ if (loader -> glob_result .gl_pathc > 1 ) {
86
+ // if we have more than one shard, advance to the next one
87
+ loader -> current_shard = (loader -> current_shard + 1 ) % loader -> glob_result .gl_pathc ;
88
+ dataloader_load_shard_ (loader , loader -> current_shard );
89
+ }
90
+ long header_bytes = HEADER_SIZE * sizeof (int );
91
+ long token_bytes_offset = loader -> process_rank * loader -> B * loader -> T * sizeof (uint16_t );
92
+ loader -> current_position = header_bytes + token_bytes_offset ;
93
+ }
94
+
95
+ void dataloader_init (DataLoader * loader ,
96
+ const char * filename_pattern ,
97
+ size_t B ,
98
+ size_t T ,
99
+ int process_rank ,
100
+ int num_processes ) {
101
+ loader -> process_rank = process_rank ;
102
+ loader -> num_processes = num_processes ;
103
+ loader -> B = B ;
104
+ loader -> T = T ;
105
+ loader -> tokens_file = NULL ;
106
+
107
+ // glob to get the list of files matching the pattern, these are our data shards
108
+ int glob_status = glob (filename_pattern , 0 , NULL , & loader -> glob_result );
109
+ if (glob_status != 0 ) {
110
+ printf ("Error: failed to glob pattern: %s\n" , filename_pattern );
111
+ exit (EXIT_FAILURE );
112
+ }
113
+ if (loader -> glob_result .gl_pathc == 0 ) {
114
+ printf ("Error: no files found matching the pattern: %s\n" , filename_pattern );
82
115
exit (EXIT_FAILURE );
83
116
}
84
117
85
- // allocate space for B*T + 1 integers to store the inputs and targets
118
+ // inspect and validate all shards so we don't get any runtime errors later
119
+ // if too slow / too many shards, may wish to revisit later
120
+ long ntok_total = 0 ;
121
+ for (int shard_index = 0 ; shard_index < loader -> glob_result .gl_pathc ; shard_index ++ ) {
122
+ long shard_ntok = dataloader_load_shard_ (loader , shard_index );
123
+ // we need at least one batch/shard, the way things are written right now.
124
+ // can be relaxed a lot later.
125
+ assert (shard_ntok >= num_processes * B * T + 1 );
126
+ ntok_total += shard_ntok ;
127
+ }
128
+ printf ("DataLoader: filename_pattern: %s\n" , filename_pattern );
129
+ printf ("DataLoader: Found %ld tokens across %zu shards\n" , ntok_total , loader -> glob_result .gl_pathc );
130
+
131
+ // allocate all the space we'll need
86
132
loader -> buffer = (uint16_t * )malloc ((B * T + 1 ) * sizeof (uint16_t ));
87
133
loader -> inputs = (int * )malloc (B * T * sizeof (int ));
88
134
loader -> targets = (int * )malloc (B * T * sizeof (int ));
89
- // note: we definitely want to advance by B * T; That is the "stride" by which we move
90
- // the window of tokens. We only load B * T + 1 tokens because our targets are offset by 1
91
- loader -> num_batches = ntok / (num_processes * B * T );
135
+ loader -> num_batches = ntok_total / (num_processes * B * T ); // useful to know
92
136
93
- // reset the loader to the beginning of the file
137
+ // reset the loader, to initialize it
94
138
dataloader_reset (loader );
95
139
}
96
140
97
141
void dataloader_next_batch (DataLoader * loader ) {
98
142
size_t B = loader -> B ;
99
143
size_t T = loader -> T ;
100
- // if we are at the end of the file, loop back to the beginning
101
- if (loader -> current_position + (loader -> num_processes * B * T + 1 ) * sizeof (uint16_t ) > loader -> file_size ) {
102
- dataloader_reset (loader );
103
- }
104
144
// read B*T+1 uint16_t tokens from the file into buffer
105
145
fseekCheck (loader -> tokens_file , loader -> current_position , SEEK_SET );
106
146
freadCheck (loader -> buffer , sizeof (uint16_t ), B * T + 1 , loader -> tokens_file );
@@ -111,12 +151,18 @@ void dataloader_next_batch(DataLoader *loader) {
111
151
}
112
152
// advance the current position by B*T*num_processes integers
113
153
// note: the "stride" of tokens by which we move each time is definitely B * T
154
+ // we only load B * T + 1 tokens at each iteration because the targets are offset by 1
114
155
loader -> current_position += loader -> num_processes * B * T * sizeof (uint16_t );
156
+ // if the next batch would go past the end of the file, advance the loader
157
+ if (loader -> current_position + (loader -> num_processes * B * T + 1 ) * sizeof (uint16_t ) > loader -> file_size ) {
158
+ dataloader_advance_ (loader );
159
+ }
115
160
}
116
161
117
162
void dataloader_free (DataLoader * loader ) {
118
163
free (loader -> buffer );
119
164
free (loader -> inputs );
120
165
free (loader -> targets );
121
166
fcloseCheck (loader -> tokens_file );
167
+ globfree (& loader -> glob_result );
122
168
}
0 commit comments