diff --git a/rand.h b/rand.h
index f69340d20..e60e5e6a9 100644
--- a/rand.h
+++ b/rand.h
@@ -1,3 +1,85 @@
+/*
+Mersenne Twister implementation, numerically identical to torch.
+
+Example usage:
+
+    mt19937_state state;
+    manual_seed(&state, 137);
+    printf("%u\n", randint32(&state));
+    printf("%u\n", randint32(&state));
+    printf("%u\n", randint32(&state));
+    printf("%u\n", randint32(&state));
+    printf("%u\n", randint32(&state));
+
+    float t8[8];
+    normal_(t8, 8, 0, 1, &state);
+    for (int i = 0; i < 8; i++) {
+        printf("%f\n", t8[i]);
+    }
+    printf("%u\n", randint32(&state));
+
+    float t16[16];
+    normal_(t16, 16, 0, 1, &state);
+    for (int i = 0; i < 16; i++) {
+        printf("%f\n", t16[i]);
+    }
+    printf("%u\n", randint32(&state));
+
+PyTorch reference (producing identical results):
+
+    import torch
+    torch.manual_seed(137)
+    print(torch.randint(0, 0xFFFFFFFF, [1]).item())
+    print(torch.randint(0, 0xFFFFFFFF, [1]).item())
+    print(torch.randint(0, 0xFFFFFFFF, [1]).item())
+    print(torch.randint(0, 0xFFFFFFFF, [1]).item())
+    print(torch.randint(0, 0xFFFFFFFF, [1]).item())
+    t = torch.zeros(8)
+    t.normal_()
+    for i in range(len(t)):
+        print(t[i].item())
+    print(torch.randint(0, 0xFFFFFFFF, [1]).item())
+    t = torch.zeros(16)
+    t.normal_()
+    for i in range(len(t)):
+        print(t[i].item())
+    print(torch.randint(0, 0xFFFFFFFF, [1]).item())
+
+Both output:
+
+    4053805790
+    2173880614
+    380293709
+    1237255315
+    2986595568
+    0.7947664260864258
+    1.4369317293167114
+    -0.2292192131280899
+    0.47556325793266296
+    -0.6334410905838013
+    -0.5791953802108765
+    -0.0925704762339592
+    -0.8659197092056274
+    2186503452
+    -1.2813878059387207
+    -2.646395683288574
+    -0.06569503247737885
+    0.2180829495191574
+    -0.46536165475845337
+    -0.33108410239219666
+    2.5485482215881348
+    0.10425379872322083
+    0.8460659980773926
+    0.9462448358535767
+    -0.2913765013217926
+    0.34313806891441345
+    -1.1186704635620117
+    -0.18305328488349915
+    -2.3153159618377686
+    0.3961987793445587
+    2756748748
+*/
+
 #ifndef RAND_H
 #define RAND_H
 
@@ -81,7 +163,7 @@ void uniform_(float* data, unsigned int numel, float from, float to, mt19937_sta
     }
 }
 
-// Box–Muller transform
+// Box-Muller transform
 void normal_fill_16(float* data, float mean, float std, mt19937_state* state) {
 #define EPSILONE 1e-12
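The parity claim in the new rand.h header is easy to verify end to end. Below is a
minimal sketch of such a check, assuming only what the header comment above shows
(manual_seed, randint32, normal_ and mt19937_state); the file name check_rand.c is
made up for illustration:

    // check_rand.c -- reproduce the first reference values from the rand.h header
    #include <stdio.h>
    #include "rand.h"

    int main(void) {
        mt19937_state state;
        manual_seed(&state, 137);
        // per the header, the first five draws after seeding with 137 are
        // 4053805790, 2173880614, 380293709, 1237255315, 2986595568
        for (int i = 0; i < 5; i++) {
            printf("%u\n", randint32(&state));
        }
        // the next 8 normals should then start 0.794766, 1.436932, -0.229219, ...
        float t8[8];
        normal_(t8, 8, 0, 1, &state);
        for (int i = 0; i < 8; i++) {
            printf("%f\n", t8[i]);
        }
        return 0;
    }

Compiled with something like "cc check_rand.c -lm -o check_rand", the output should
match the torch reference listed in the header line for line.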
diff --git a/train_gpt2.cu b/train_gpt2.cu
index 31409a1bd..edd316508 100644
--- a/train_gpt2.cu
+++ b/train_gpt2.cu
@@ -57,7 +57,9 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200),
 // defines: dataloader_init, dataloader_reset, dataloader_next_batch, dataloader_free
 // defines: evalloader_init, evalloader_reset, evalloader_next_batch, evalloader_free
 #include "dataloader.h"
-
+// defines: manual_seed, normal_
+// numerically identical to PyTorch's torch.manual_seed and torch.normal
+#include "rand.h"
 
 // ----------------------------------------------------------------------------
 // CUDA precision settings
@@ -2042,10 +2044,10 @@ void gpt2_build_from_random(GPT2 *model, int depth) {
     model->config.num_layers = depth;
     // follows GPT-2 sizes
     int channels, num_heads;
-    if (depth == 12) { channels = 12; num_heads = 12; } // gpt2 (124M)
-    else if (depth == 24) { channels = 16; num_heads = 16; } // gpt2-medium (350M)
-    else if (depth == 36) { channels = 20; num_heads = 20; } // gpt2-large (774M)
-    else if (depth == 48) { channels = 25; num_heads = 25; } // gpt2-xl (1558M)
+    if (depth == 12) { channels = 768; num_heads = 12; } // gpt2 (124M)
+    else if (depth == 24) { channels = 1024; num_heads = 16; } // gpt2-medium (350M)
+    else if (depth == 36) { channels = 1280; num_heads = 20; } // gpt2-large (774M)
+    else if (depth == 48) { channels = 1600; num_heads = 25; } // gpt2-xl (1558M)
     else { fprintf(stderr, "Unsupported depth for now\n"); exit(EXIT_FAILURE); }
     model->config.channels = channels;
     model->config.num_heads = num_heads;
@@ -2068,7 +2070,8 @@ void gpt2_build_from_random(GPT2 *model, int depth) {
     // allocate and random init the memory for all the parameters with GPT-2 schema
     // weights ~N(0, 0.02), biases 0, c_proj weights ~N(0, 0.02/(2*L)**0.5)
     // NOTE: assuming all parameters are of the type floatX, could be relaxed later
-    unsigned long long init_rng_state = 42;
+    mt19937_state rng_state;
+    manual_seed(&rng_state, 42);
     floatX* params_memory_cpu = (floatX*)mallocCheck(model->num_parameters_bytes);
     memset(params_memory_cpu, 0, model->num_parameters_bytes);
     // fill in all the weights with random values
@@ -2081,12 +2084,13 @@ void gpt2_build_from_random(GPT2 *model, int depth) {
             // in GPT-2, the projections back into the residual stream are additionally
             // scaled by 1/sqrt(2*L) for training stability
             float scale = (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f;
-            for (size_t j = 0; j < model->param_elements[i]; j++) {
-                float f = random_f32(&init_rng_state); // random float in [0, 1]
-                f *= scale;
-                f -= 0.5f * scale; // mean 0
-                params_memory_cpu[offset + j] = (floatX)f;
+            size_t n = model->param_elements[i];
+            float* fp32_buffer = (float*)mallocCheck(n * sizeof(float));
+            normal_(fp32_buffer, n, 0.0f, scale, &rng_state);
+            for (size_t j = 0; j < n; j++) {
+                params_memory_cpu[offset + j] = (floatX)fp32_buffer[j];
             }
+            free(fp32_buffer);
         }
         offset += model->param_elements[i];
     }
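For intuition on the scale the new loop passes to normal_: every weight tensor is
drawn from N(0, 0.02), except the two c_proj weight tensors (indices 6 and 12 in the
GPT-2 parameter schema), whose std is additionally shrunk by residual_scale =
1/sqrt(2*L). A minimal sketch of that rule, factored out of the hunk above
(init_std is a hypothetical helper, not part of the patch):

    #include <math.h>

    // std for parameter tensor i in an L-layer model, mirroring the schema
    // "weights ~N(0, 0.02), biases 0, c_proj weights ~N(0, 0.02/(2*L)**0.5)"
    float init_std(int i, int num_layers) {
        float residual_scale = 1.0f / sqrtf(2.0f * num_layers);
        return (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f;
    }

For the 12-layer gpt2 (124M) config this gives 0.02 / sqrt(24) ~= 0.0041 for the
residual projections, which keeps the variance of the residual stream roughly
constant as the 2*L per-layer contributions (attention and MLP) accumulate.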