Skip to content

Commit

Permalink
use pytorch rand and fix dumb bug lol
Browse files Browse the repository at this point in the history
  • Loading branch information
karpathy committed May 23, 2024
1 parent bc1ebc1 commit 70a9c75
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 12 deletions.
84 changes: 83 additions & 1 deletion rand.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,85 @@
/*
Mersenne Twisters implementation, numerically identical to torch.
Example usage:
mt19937_state state;
manual_seed(&state, 137);
printf("%u\n", randint32(&state));
printf("%u\n", randint32(&state));
printf("%u\n", randint32(&state));
printf("%u\n", randint32(&state));
printf("%u\n", randint32(&state));
float t8[8];
normal_(t8, 8, 0, 1, &state);
for (int i = 0; i < 8; i++) {
printf("%f\n", t8[i]);
}
printf("%u\n", randint32(&state));
float t16[16];
normal_(t16, 16, 0, 1, &state);
for (int i = 0; i < 16; i++) {
printf("%f\n", t16[i]);
}
printf("%u\n", randint32(&state));
PyTorch reference (producing identical results):
import torch
torch.manual_seed(137)
print(torch.randint(0, 0xFFFFFFFF, [1]).item())
print(torch.randint(0, 0xFFFFFFFF, [1]).item())
print(torch.randint(0, 0xFFFFFFFF, [1]).item())
print(torch.randint(0, 0xFFFFFFFF, [1]).item())
print(torch.randint(0, 0xFFFFFFFF, [1]).item())
t = torch.zeros(8);
t.normal_()
for i in range(len(t)) :
print(t[i].item())
print(torch.randint(0, 0xFFFFFFFF, [1]).item())
t = torch.zeros(16);
t.normal_()
for i in range(len(t)) :
print(t[i].item())
print(torch.randint(0, 0xFFFFFFFF, [1]).item())
Both output:
4053805790
2173880614
380293709
1237255315
2986595568
0.7947664260864258
1.4369317293167114
-0.2292192131280899
0.47556325793266296
-0.6334410905838013
-0.5791953802108765
-0.0925704762339592
-0.8659197092056274
2186503452
-1.2813878059387207
-2.646395683288574
-0.06569503247737885
0.2180829495191574
-0.46536165475845337
-0.33108410239219666
2.5485482215881348
0.10425379872322083
0.8460659980773926
0.9462448358535767
-0.2913765013217926
0.34313806891441345
-1.1186704635620117
-0.18305328488349915
-2.3153159618377686
0.3961987793445587
2756748748
*/

#ifndef RAND_H
#define RAND_H

Expand Down Expand Up @@ -81,7 +163,7 @@ void uniform_(float* data, unsigned int numel, float from, float to, mt19937_sta
}
}

// BoxMuller transform
// BoxMuller transform

void normal_fill_16(float* data, float mean, float std, mt19937_state* state) {
#define EPSILONE 1e-12
Expand Down
26 changes: 15 additions & 11 deletions train_gpt2.cu
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ This reads & runs in fp32, B=4, T=64, LR=1e-4, val/sample never (200),
// defines: dataloader_init, dataloader_reset, dataloader_next_batch, dataloader_free
// defines: evalloader_init, evalloader_reset, evalloader_next_batch, evalloader_free
#include "dataloader.h"

// defines: manual_seed, normal_
// numerically identical to PyTorch's torch.manual_seed and torch.normal
#include "rand.h"
// ----------------------------------------------------------------------------
// CUDA precision settings

Expand Down Expand Up @@ -2042,10 +2044,10 @@ void gpt2_build_from_random(GPT2 *model, int depth) {
model->config.num_layers = depth;
// follows GPT-2 sizes
int channels, num_heads;
if (depth == 12) { channels = 12; num_heads = 12; } // gpt2 (124M)
else if (depth == 24) { channels = 16; num_heads = 16; } // gpt2-medium (350M)
else if (depth == 36) { channels = 20; num_heads = 20; } // gpt2-large (774M)
else if (depth == 48) { channels = 25; num_heads = 25; } // gpt2-xl (1558M)
if (depth == 12) { channels = 768; num_heads = 12; } // gpt2 (124M)
else if (depth == 24) { channels = 1024; num_heads = 16; } // gpt2-medium (350M)
else if (depth == 36) { channels = 1280; num_heads = 20; } // gpt2-large (774M)
else if (depth == 48) { channels = 1600; num_heads = 25; } // gpt2-xl (1558M)
else { fprintf(stderr, "Unsupported depth for now\n"); exit(EXIT_FAILURE); }
model->config.channels = channels;
model->config.num_heads = num_heads;
Expand All @@ -2068,7 +2070,8 @@ void gpt2_build_from_random(GPT2 *model, int depth) {
// allocate and random init the memory for all the parameters with GPT-2 schema
// weights ~N(0, 0.02), biases 0, c_proj weights ~N(0, 0.02/(2*L)**0.5)
// NOTE: assuming all parameters are of the type floatX, could be relaxed later
unsigned long long init_rng_state = 42;
mt19937_state rng_state;
manual_seed(&rng_state, 42);
floatX* params_memory_cpu = (floatX*)mallocCheck(model->num_parameters_bytes);
memset(params_memory_cpu, 0, model->num_parameters_bytes);
// fill in all the weights with random values
Expand All @@ -2081,12 +2084,13 @@ void gpt2_build_from_random(GPT2 *model, int depth) {
// in GPT-2, the projections back into the residual stream are additionally
// scaled by 1/sqrt(2*L) for training stability
float scale = (i == 6 || i == 12) ? 0.02f * residual_scale : 0.02f;
for (size_t j = 0; j < model->param_elements[i]; j++) {
float f = random_f32(&init_rng_state); // random float in [0, 1]
f *= scale;
f -= 0.5f * scale; // mean 0
params_memory_cpu[offset + j] = (floatX)f;
int n = model->param_elements[i];
float *fp32_buffer = (float*)mallocCheck(n * sizeof(float));
normal_(fp32_buffer, n, 0.0f, scale, &rng_state);
for (size_t j = 0; j < n; j++) {
params_memory_cpu[offset + j] = (floatX)fp32_buffer[j];
}
free(fp32_buffer);
}
offset += model->param_elements[i];
}
Expand Down

0 comments on commit 70a9c75

Please sign in to comment.