
Commit

Merge pull request karpathy#457 from karpathy/feature/write_checkpoints
add checkpoint function write to file
karpathy authored May 25, 2024
2 parents 4ff0412 + 2a0f78d commit fe698b3
Showing 2 changed files with 64 additions and 15 deletions.
9 changes: 5 additions & 4 deletions dev/unistd.h
@@ -6,7 +6,6 @@
#define _USE_MATH_DEFINES

#include <stdio.h>

#include <math.h>
//#define gen_max_length 64 // compile as C++ to skip this VLA issue
#include <time.h>
@@ -18,14 +17,16 @@ static inline int clock_gettime(int ignore_variable, struct timespec* tv)
}

#define OMP /* turn it on */
#include <io.h> /* needed for access below */
#define F_OK 0
#define access _access

#define TURN_OFF_FP_FAST __pragma(float_control( precise, on, push )) // Save current setting and turn on /fp:precise
#define TURN_ON_FP_FAST __pragma(float_control(pop)) // Restore file's default settings

#define mkdir _mkdir // add mkdir into namespace for windows
#include <direct.h> /* for _mkdir and _stat */
#define mkdir(path, mode) _mkdir(path) /* sketchy way to get mkdir to work on windows */
#define stat _stat
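
An aside, not part of this diff: POSIX mkdir(path, mode) takes a mode argument that Windows' _mkdir lacks, so the function-like macro above simply discards it. For illustration:

    // a POSIX-style call written in portable code, e.g. in train_gpt2.cu:
    mkdir(dir, 0700);
    // expands on Windows (via the macro above) to:
    _mkdir(dir); // the 0700 mode bits are dropped; _mkdir takes only a path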

typedef struct glob_t {
    size_t gl_pathc; // Count of matched pathnames
@@ -57,7 +58,7 @@ static inline int glob(const char* pattern, int ignored_flags, int (*ignored_err
    strncpy_s(pattern_copy, sizeof(pattern_copy) - 1, pattern, sizeof(pattern_copy) - 1);

    replace_forward_slashes (pattern_copy); // Replace forward slashes with backslashes

    if (strchr(pattern_copy, '\\') != NULL) {
        strncpy_s(directory_path, sizeof(directory_path) - 1, pattern_copy, strrchr(pattern_copy, '\\') - pattern_copy + 1);
        directory_path[strrchr(pattern_copy, '\\') - pattern_copy + 1] = '\0';
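Not part of this diff, but for context: replace_forward_slashes, called above, is defined elsewhere in dev/unistd.h. A minimal sketch of its assumed behavior, per the comment next to the call:

    // assumed behavior of the helper: rewrite '/' to '\' in place
    static inline void replace_forward_slashes(char* str) {
        for (; *str != '\0'; str++) {
            if (*str == '/') { *str = '\\'; }
        }
    }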
70 changes: 59 additions & 11 deletions train_gpt2.cu
@@ -36,8 +36,11 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200),

#include <unistd.h>
#include <stdio.h>
#include <stdlib.h>
#include <stdarg.h>
#include <string>
#include <sys/stat.h>
#include <sys/types.h>
#include <vector>
#include <algorithm>
#include <functional>
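
The sys/stat.h and sys/types.h includes added here supply stat() and mkdir() for the new create_dir_if_not_exists helper further down; on Windows, the dev/unistd.h changes above map them to _stat and _mkdir.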
@@ -2130,6 +2133,31 @@ typedef struct {
    int4* bucket_info; // encoder_backward, B*T*num_c_groups (int4) - size for worst case
} GPT2;

void gpt2_write_to_checkpoint(GPT2 *model, const char* checkpoint_path) {
    // write the model to a checkpoint file
    printf0("Writing model to %s\n", checkpoint_path);
    FILE *model_file = fopenCheck(checkpoint_path, "wb");
    // write the header first
    int model_header[256];
    model_header[0] = 20240326;
    assert(PRECISION_MODE == PRECISION_FP32 || PRECISION_MODE == PRECISION_BF16);
    model_header[1] = PRECISION_MODE == PRECISION_FP32 ? 3 : 5;
    model_header[2] = model->config.max_seq_len;
    model_header[3] = model->config.vocab_size;
    model_header[4] = model->config.num_layers;
    model_header[5] = model->config.num_heads;
    model_header[6] = model->config.channels;
    model_header[7] = model->config.padded_vocab_size;
    fwrite(model_header, sizeof(int), 256, model_file);
    // write the parameters
    void* params_memory_cpu = (void*)mallocCheck(model->num_parameters_bytes);
    cudaCheck(cudaMemcpy(params_memory_cpu, model->params_memory, model->num_parameters_bytes, cudaMemcpyDeviceToHost));
    fwrite(params_memory_cpu, 1, model->num_parameters_bytes, model_file);
    free(params_memory_cpu);
    // close file, we're done
    fcloseCheck(model_file);
}
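
For orientation, not part of the diff: the header written above mirrors what gpt2_build_from_checkpoint below expects to parse. A minimal read-back sketch, reusing this file's own check helpers:

    // sketch: re-read and sanity-check the header emitted by the writer above
    FILE *check_file = fopenCheck(checkpoint_path, "rb");
    int check_header[256];
    freadCheck(check_header, sizeof(int), 256, check_file);
    assert(check_header[0] == 20240326);                  // magic number
    assert(check_header[1] == 3 || check_header[1] == 5); // 3 = fp32, 5 = bf16 params
    // check_header[2..7]: max_seq_len, vocab_size, num_layers, num_heads,
    //                     channels, padded_vocab_size
    fcloseCheck(check_file);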

void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {

    if (PRECISION_MODE == PRECISION_FP16) {
@@ -2187,7 +2215,7 @@ void gpt2_build_from_checkpoint(GPT2 *model, const char* checkpoint_path) {
    model->params_memory = malloc_and_point_parameters(&model->params, model->param_elements, model->param_sizeof);

    // read in all the parameters from file and copy them to device
    float* params_memory_cpu = (float*)mallocCheck(model->num_parameters_bytes);
    void* params_memory_cpu = (void*)mallocCheck(model->num_parameters_bytes);
    freadCheck(params_memory_cpu, 1, model->num_parameters_bytes, model_file);
    cudaCheck(cudaMemcpy(model->params_memory, params_memory_cpu, model->num_parameters_bytes, cudaMemcpyHostToDevice));
    free(params_memory_cpu);
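
Note the type change above: the CPU staging buffer is now a void* rather than a float*, since with this commit the parameter stream is not necessarily fp32; the checkpoint may carry bf16 weights (header version 5), and the bytes are copied to the device untyped either way.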
@@ -2830,17 +2858,30 @@ int sample_softmax(const float* logits, int n, float coin) {
// ----------------------------------------------------------------------------
// Logger lite, will probably grow/change some over time

void create_dir_if_not_exists(const char *dir) {
    struct stat st = {0};
    if (stat(dir, &st) == -1) {
        if (mkdir(dir, 0700) == -1) {
            printf0("ERROR: could not create directory: %s\n", dir);
            exit(EXIT_FAILURE);
        }
        printf0("created directory: %s\n", dir);
    }
}
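
One caveat, not addressed in this diff: main() below passes output_log_dir straight into this function, and output_log_dir defaults to NULL when -o is not given, so stat() can be handed a NULL path. A one-line guard at the top of the function, e.g. if (dir == NULL) { return; }, would avoid that.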

typedef struct {
    FILE *logfile;
    int flush_every; // every how many steps to flush the log
} Logger;

void logger_init(Logger *logger, const char *filename) {
void logger_init(Logger *logger, const char *log_dir, int process_rank) {
    logger->flush_every = 10;
    logger->logfile = NULL;
    // only rank 0 process will log
    if (filename != NULL && multi_gpu_config.process_rank == 0) {
        logger->logfile = fopenCheck(filename, "w");
    if (log_dir != NULL && process_rank == 0) {
        char output_log_file[256];
        assert(strlen(log_dir) < 200); // being a bit lazy, can relax later maybe
        snprintf(output_log_file, 256, "%s/main.log", log_dir);
        logger->logfile = fopenCheck(output_log_file, "w");
    }
}
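
Worth noting: the rank is now passed in explicitly instead of being read from the global multi_gpu_config, so the Logger no longer depends on that global; only rank 0 opens <log_dir>/main.log, and every other process keeps logfile NULL.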

@@ -2876,7 +2917,7 @@ void error_usage() {
    fprintf(stderr, " -i <string> train data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_train.bin)\n");
    fprintf(stderr, " -j <string> val data filename pattern (default = dev/data/tinyshakespeare/tiny_shakespeare_val.bin)\n");
    fprintf(stderr, " -e <string> input from model at this filename (default = gpt2_124M_bf16.bin)\n");
    fprintf(stderr, " -o <string> output log file (default = NULL)\n");
    fprintf(stderr, " -o <string> output log dir (default = NULL, no logging)\n");
    fprintf(stderr, " -b <int> (per-GPU, micro) batch size B (default = 4)\n");
    fprintf(stderr, " -t <int> sequence length T (default = 1024)\n");
    fprintf(stderr, " -d <int> total desired batch size (default = B * T * num_processes, i.e. no grad accumulation)\n");
@@ -2907,7 +2948,7 @@ int main(int argc, char *argv[]) {
    const char* train_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_train.bin";
    const char* val_data_pattern = "dev/data/tinyshakespeare/tiny_shakespeare_val.bin";
    const char* load_filename = "gpt2_124M_bf16.bin"; // bf16 weights of the model
    const char* output_log_file = NULL;
    const char* output_log_dir = NULL;
    int B = 4; // batch size
    int T = 1024; // sequence length max
    int total_batch_size = -1; // will be calculated down below later, if not provided
@@ -2935,7 +2976,7 @@
        if (argv[i][1] == 'i') { train_data_pattern = argv[i+1]; }
        else if (argv[i][1] == 'j') { val_data_pattern = argv[i+1]; }
        else if (argv[i][1] == 'e') { load_filename = argv[i+1]; }
        else if (argv[i][1] == 'o') { output_log_file = argv[i+1]; }
        else if (argv[i][1] == 'o') { output_log_dir = argv[i+1]; }
        else if (argv[i][1] == 'b') { B = atoi(argv[i+1]); } // Per-GPU (micro) batch size
        else if (argv[i][1] == 't') { T = atoi(argv[i+1]); }
        else if (argv[i][1] == 'd') { total_batch_size = atoi(argv[i+1]); }
@@ -2959,6 +3000,12 @@
    }
    // should do a bit more error checking here
    assert(warmup_iterations >= 0);
    // check if output_log_dir has a "." in it, because this behavior changed May 24, 2024. take out later
    if (output_log_dir != NULL && strstr(output_log_dir, ".") != NULL) {
        fprintf(stderr, "-o (output_log_dir) has a '.', are you specifying a file instead of dir?\n");
        fprintf(stderr, "(note that this option changed recently, -o used to be file, became dir.)\n");
        exit(EXIT_FAILURE);
    }
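    // (example of the new behavior: "-o log124M" creates the directory log124M/
    // and writes log124M/main.log; previously "-o main.log" named the log file
    // itself. the name log124M here is just illustrative.)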
    // calculate a sensible default for total batch size by assuming no gradient accumulation
    if (total_batch_size == -1) { total_batch_size = B * T * multi_gpu_config.num_processes; }
    // if we're only overfitting a single batch for debugging, let's overfit the first batch
@@ -2969,7 +3016,7 @@
    printf0("+-----------------------+----------------------------------------------------+\n");
    printf0("| train data pattern | %-50s |\n", train_data_pattern);
    printf0("| val data pattern | %-50s |\n", val_data_pattern);
    printf0("| output log file | %-50s |\n", output_log_file == NULL ? "NULL" : output_log_file);
    printf0("| output log dir | %-50s |\n", output_log_dir == NULL ? "NULL" : output_log_dir);
    printf0("| micro batch size B | %-50d |\n", B);
    printf0("| sequence length T | %-50d |\n", T);
    printf0("| total batch size | %-50d |\n", total_batch_size);
@@ -3071,9 +3118,10 @@
            B, T, multi_gpu_config.num_processes, total_batch_size);
    printf0("=> setting grad_accum_steps=%d\n", grad_accum_steps);

    // set up the Logger
    // set up logging
    create_dir_if_not_exists(output_log_dir);
    Logger logger;
    logger_init(&logger, output_log_file);
    logger_init(&logger, output_log_dir, multi_gpu_config.process_rank);

    // set up the Tokenizer
    Tokenizer tokenizer;
