22Implements a medium simple DataLoader for a distributed training setup. 
33*/ 
44
5+ #include  <glob.h> 
56#include  <stdio.h> 
67#include  <stdlib.h> 
78#include  <stddef.h> 
@@ -23,6 +24,8 @@ typedef struct {
2324    size_t  B ;
2425    size_t  T ;
2526    // input handling and its state 
27+     glob_t  glob_result ; // stores the result of glob, for all shards we want to iterate 
28+     int  current_shard ; // the current shard we are reading from 
2629    FILE *  tokens_file ;
2730    long  file_size ;
2831    long  current_position ;
@@ -34,25 +37,13 @@ typedef struct {
3437    size_t  num_batches ;
3538} DataLoader ;
3639
37- void  dataloader_reset (DataLoader  * loader ) {
38-     // each process starts at a different offset in the file 
39-     long  header_bytes  =  HEADER_SIZE  *  sizeof (int );
40-     long  token_bytes_offset  =  loader -> process_rank  *  loader -> B  *  loader -> T  *  sizeof (uint16_t );
41-     loader -> current_position  =  header_bytes  +  token_bytes_offset ;
42- }
43- 
44- void  dataloader_init (DataLoader  * loader ,
45-                      const  char *  filename ,
46-                      size_t  B ,
47-                      size_t  T ,
48-                      int  process_rank ,
49-                      int  num_processes ) {
50-     loader -> process_rank  =  process_rank ;
51-     loader -> num_processes  =  num_processes ;
52-     loader -> B  =  B ;
53-     loader -> T  =  T ;
54- 
55-     // open the input file for reading 
40+ long  dataloader_load_shard_ (DataLoader  * loader , int  shard_index ) {
41+     // use the first glob match as the filename for now 
42+     const  char *  filename  =  loader -> glob_result .gl_pathv [shard_index ];
43+     // open the input file for reading. also only a single file can be opened at a time 
44+     if  (loader -> tokens_file  !=  NULL ) {
45+         fcloseCheck (loader -> tokens_file );
46+     }
5647    loader -> tokens_file  =  fopenCheck (filename , "rb" );
5748    // validate the header 
5849    int  header [HEADER_SIZE ];
@@ -65,7 +56,7 @@ void dataloader_init(DataLoader *loader,
6556    }
6657    if  (header [1 ] !=  1 ) { printf ("Bad version in data file\n" ); exit (EXIT_FAILURE ); }
6758    long  ntok  =  header [2 ]; // number of tokens in the file 
68- 
59+      assert ( ntok   >   0 );  // we expect some tokens in the file. this should never trip, right? 
6960    // determine the file size and make sure it is consistent with the number of tokens 
7061    fseekCheck (loader -> tokens_file , 0 , SEEK_END ); // seek to end of file 
7162    loader -> file_size  =  ftell (loader -> tokens_file ); // read the offset, i.e. file size 
@@ -76,31 +67,80 @@ void dataloader_init(DataLoader *loader,
7667        printf ("Error: file size is not as expected\n" );
7768        exit (EXIT_FAILURE );
7869    }
79-     if  (ntok  <  num_processes  *  B  *  T  +  1 ) {
80-         // being too defensive/lazy, we could tolerate as low as T+1 tokens in principle 
81-         printf ("Error: there are too few tokens\n" );
70+     return  ntok ;
71+ }
72+ 
73+ void  dataloader_reset (DataLoader  * loader ) {
74+     // fully resets the DataLoader object to init configuration 
75+     // each process starts at a different offset in the file 
76+     long  header_bytes  =  HEADER_SIZE  *  sizeof (int );
77+     long  token_bytes_offset  =  loader -> process_rank  *  loader -> B  *  loader -> T  *  sizeof (uint16_t );
78+     loader -> current_shard  =  0 ;
79+     loader -> current_position  =  header_bytes  +  token_bytes_offset ;
80+     dataloader_load_shard_ (loader , loader -> current_shard );
81+ }
82+ 
83+ void  dataloader_advance_ (DataLoader  * loader ) {
84+     // advance the loader by loading the next data shard and resetting the position 
85+     if  (loader -> glob_result .gl_pathc  >  1 ) {
86+         // if we have more than one shard, advance to the next one 
87+         loader -> current_shard  =  (loader -> current_shard  +  1 ) % loader -> glob_result .gl_pathc ;
88+         dataloader_load_shard_ (loader , loader -> current_shard );
89+     }
90+     long  header_bytes  =  HEADER_SIZE  *  sizeof (int );
91+     long  token_bytes_offset  =  loader -> process_rank  *  loader -> B  *  loader -> T  *  sizeof (uint16_t );
92+     loader -> current_position  =  header_bytes  +  token_bytes_offset ;
93+ }
94+ 
95+ void  dataloader_init (DataLoader  * loader ,
96+                      const  char *  filename_pattern ,
97+                      size_t  B ,
98+                      size_t  T ,
99+                      int  process_rank ,
100+                      int  num_processes ) {
101+     loader -> process_rank  =  process_rank ;
102+     loader -> num_processes  =  num_processes ;
103+     loader -> B  =  B ;
104+     loader -> T  =  T ;
105+     loader -> tokens_file  =  NULL ;
106+ 
107+     // glob to get the list of files matching the pattern, these are our data shards 
108+     int  glob_status  =  glob (filename_pattern , 0 , NULL , & loader -> glob_result );
109+     if  (glob_status  !=  0 ) {
110+         printf ("Error: failed to glob pattern: %s\n" , filename_pattern );
111+         exit (EXIT_FAILURE );
112+     }
113+     if  (loader -> glob_result .gl_pathc  ==  0 ) {
114+         printf ("Error: no files found matching the pattern: %s\n" , filename_pattern );
82115        exit (EXIT_FAILURE );
83116    }
84117
85-     // allocate space for B*T + 1 integers to store the inputs and targets 
118+     // inspect and validate all shards so we don't get any runtime errors later 
119+     // if too slow / too many shards, may wish to revisit later 
120+     long  ntok_total  =  0 ;
121+     for  (int  shard_index  =  0 ; shard_index  <  loader -> glob_result .gl_pathc ; shard_index ++ ) {
122+         long  shard_ntok  =  dataloader_load_shard_ (loader , shard_index );
123+         // we need at least one batch/shard, the way things are written right now. 
124+         // can be relaxed a lot later. 
125+         assert (shard_ntok  >= num_processes  *  B  *  T  +  1 );
126+         ntok_total  +=  shard_ntok ;
127+     }
128+     printf ("DataLoader: filename_pattern: %s\n" , filename_pattern );
129+     printf ("DataLoader: Found %ld tokens across %zu shards\n" , ntok_total , loader -> glob_result .gl_pathc );
130+ 
131+     // allocate all the space we'll need 
86132    loader -> buffer  =  (uint16_t * )malloc ((B  *  T  +  1 ) *  sizeof (uint16_t ));
87133    loader -> inputs  =  (int * )malloc (B  *  T  *  sizeof (int ));
88134    loader -> targets  =  (int * )malloc (B  *  T  *  sizeof (int ));
89-     // note: we definitely want to advance by B * T; That is the "stride" by which we move 
90-     // the window of tokens. We only load B * T + 1 tokens because our targets are offset by 1 
91-     loader -> num_batches  =  ntok  / (num_processes  *  B  *  T );
135+     loader -> num_batches  =  ntok_total  / (num_processes  *  B  *  T ); // useful to know 
92136
93-     // reset the loader to the beginning of the file  
137+     // reset the loader,  to initialize it  
94138    dataloader_reset (loader );
95139}
96140
97141void  dataloader_next_batch (DataLoader  * loader ) {
98142    size_t  B  =  loader -> B ;
99143    size_t  T  =  loader -> T ;
100-     // if we are at the end of the file, loop back to the beginning 
101-     if  (loader -> current_position  +  (loader -> num_processes  *  B  *  T  +  1 ) *  sizeof (uint16_t ) >  loader -> file_size ) {
102-         dataloader_reset (loader );
103-     }
104144    // read B*T+1 uint16_t tokens from the file into buffer 
105145    fseekCheck (loader -> tokens_file , loader -> current_position , SEEK_SET );
106146    freadCheck (loader -> buffer , sizeof (uint16_t ), B * T + 1 , loader -> tokens_file );
@@ -111,12 +151,18 @@ void dataloader_next_batch(DataLoader *loader) {
111151    }
112152    // advance the current position by B*T*num_processes integers 
113153    // note: the "stride" of tokens by which we move each time is definitely B * T 
154+     // we only load B * T + 1 tokens at each iteration because the targets are offset by 1 
114155    loader -> current_position  +=  loader -> num_processes  *  B  *  T  *  sizeof (uint16_t );
156+     // if the next batch would go past the end of the file, advance the loader 
157+     if  (loader -> current_position  +  (loader -> num_processes  *  B  *  T  +  1 ) *  sizeof (uint16_t ) >  loader -> file_size ) {
158+         dataloader_advance_ (loader );
159+     }
115160}
116161
117162void  dataloader_free (DataLoader  * loader ) {
118163    free (loader -> buffer );
119164    free (loader -> inputs );
120165    free (loader -> targets );
121166    fcloseCheck (loader -> tokens_file );
167+     globfree (& loader -> glob_result );
122168}
0 commit comments