Skip to content

Commit

Permalink
resolve merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
karpathy committed May 25, 2024
2 parents 16b364d + 4ff0412 commit f2ee356
Show file tree
Hide file tree
Showing 10 changed files with 745 additions and 192 deletions.
57 changes: 19 additions & 38 deletions dataloader.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
/*
Implements a medium simple DataLoader for a distributed training setup.
*/
#ifndef DATALOADER_H
#define DATALOADER_H

#include <stdio.h>
#include <stdlib.h>
Expand All @@ -13,33 +15,10 @@ Implements a medium simple DataLoader for a distributed training setup.
#include "utils.h"

// ----------------------------------------------------------------------------
// we need glob to list files matching a pattern
// windows does not have glob, so we fall back on a very simple implementation
// this implementation doesn't actually do a glob, it assumes that the "pattern"
// is exactly the single file of interest
// implementation of glob for Windows is in dev/unistd.h
#ifndef _WIN32
#include <glob.h>
#else

typedef struct glob_t {
size_t gl_pathc;
char **gl_pathv;
} glob_t;

int glob(const char *pattern, int flags, void *unused, glob_t *pglob) {
assert(strstr(pattern, "*") == NULL); // we don't support * here
pglob->gl_pathc = 1;
pglob->gl_pathv = (char **)malloc(sizeof(char *));
if (pglob->gl_pathv == NULL) { exit(EXIT_FAILURE); } // ??? oom?
pglob->gl_pathv[0] = (char *)pattern;
return 0;
}

void globfree(glob_t* pglob) {
free(pglob->gl_pathv);
}
#endif

// ----------------------------------------------------------------------------
// Distributed Data Loader
#define HEADER_SIZE 256
Expand All @@ -56,16 +35,16 @@ typedef struct {
glob_t glob_result; // stores the result of glob, for all shards we want to iterate
int current_shard; // the current shard we are reading from
FILE* tokens_file;
long file_size;
long current_position;
int64_t file_size;
int64_t current_position;
uint16_t* buffer; // we fread data from file into this buffer
// public variables that could be accessed from outside
size_t num_batches;
int* inputs; // input tokens into transformer
int* targets; // target tokens for the transformer
} DataLoader;

long dataloader_load_shard_(DataLoader *loader, int shard_index) {
int64_t dataloader_load_shard_(DataLoader *loader, int shard_index) {
// use the first glob match as the filename for now
const char* filename = loader->glob_result.gl_pathv[shard_index];
// open the input file for reading. also only a single file can be opened at a time
Expand All @@ -83,14 +62,14 @@ long dataloader_load_shard_(DataLoader *loader, int shard_index) {
exit(EXIT_FAILURE);
}
if (header[1] != 1) { printf("Bad version in data file\n"); exit(EXIT_FAILURE); }
long ntok = header[2]; // number of tokens in the file
int64_t ntok = header[2]; // number of tokens in the file
assert(ntok > 0); // we expect some tokens in the file. this should never trip, right?
// determine the file size and make sure it is consistent with the number of tokens
fseekCheck(loader->tokens_file, 0, SEEK_END); // seek to end of file
loader->file_size = ftell(loader->tokens_file); // read the offset, i.e. file size
fseekCheck(loader->tokens_file, 0, SEEK_SET); // seek back to the beginning
// we expect ntok in the file to be consistent with filesize, assert that is the case
long expected_file_size = HEADER_SIZE * sizeof(int) + ntok * sizeof(uint16_t);
int64_t expected_file_size = HEADER_SIZE * sizeof(int) + ntok * sizeof(uint16_t);
if (loader->file_size != expected_file_size) {
printf("Error: file size is not as expected\n");
exit(EXIT_FAILURE);
Expand All @@ -101,8 +80,8 @@ long dataloader_load_shard_(DataLoader *loader, int shard_index) {
void dataloader_reset(DataLoader *loader) {
// fully resets the DataLoader object to init configuration
// each process starts at a different offset in the file
long header_bytes = HEADER_SIZE * sizeof(int);
long token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t);
int64_t header_bytes = HEADER_SIZE * sizeof(int);
int64_t token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t);
loader->current_shard = 0;
loader->current_position = header_bytes + token_bytes_offset;
dataloader_load_shard_(loader, loader->current_shard);
Expand All @@ -115,8 +94,8 @@ void dataloader_advance_(DataLoader *loader) {
loader->current_shard = (loader->current_shard + 1) % loader->glob_result.gl_pathc;
dataloader_load_shard_(loader, loader->current_shard);
}
long header_bytes = HEADER_SIZE * sizeof(int);
long token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t);
int64_t header_bytes = HEADER_SIZE * sizeof(int);
int64_t token_bytes_offset = loader->process_rank * loader->B * loader->T * sizeof(uint16_t);
loader->current_position = header_bytes + token_bytes_offset;
}

Expand Down Expand Up @@ -145,9 +124,9 @@ void dataloader_init(DataLoader *loader,

// inspect and validate all shards so we don't get any runtime errors later
// if too slow / too many shards, may wish to revisit later
long ntok_total = 0;
int64_t ntok_total = 0;
for (int shard_index = 0; shard_index < loader->glob_result.gl_pathc; shard_index++) {
long shard_ntok = dataloader_load_shard_(loader, shard_index);
int64_t shard_ntok = dataloader_load_shard_(loader, shard_index);
// we need at least one batch/shard, the way things are written right now.
// can be relaxed a lot later.
assert(shard_ntok >= num_processes * B * T + 1);
Expand Down Expand Up @@ -229,7 +208,7 @@ typedef struct {
size_t T; // maximum context length of the model
// input handling and its state
FILE* eval_file;
long file_size;
int64_t file_size;
uint16_t* buffer; // we fread data from file into this buffer
// public variables that could be accessed from outside
int num_examples; // in total across all processes
Expand Down Expand Up @@ -261,7 +240,7 @@ void evalloader_reset(EvalLoader *loader) {
}
// now seek through the file to the start of that example
// utilize <EXAMPLE_BYTES> for efficiency
long header_bytes = HEADER_SIZE * sizeof(int);
int64_t header_bytes = HEADER_SIZE * sizeof(int);
fseekCheck(loader->eval_file, header_bytes, SEEK_SET);
for (int i = 0; i < loader->start_example_index; i++) {
uint16_t example_header[3];
Expand Down Expand Up @@ -460,4 +439,6 @@ void evalloader_free(EvalLoader *loader) {
free(loader->mask);
free(loader->label);
fcloseCheck(loader->eval_file);
}
}

#endif // DATALOADER_H
62 changes: 29 additions & 33 deletions dev/cuda/classifier_fused.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ typedef Packed128<floatX> x128;
void softmax_forward_cpu(float* out, const float* inp, int N, int C) {
// inp is (N, C)
// out is (N, C), each row of inp will get softmaxed
for (int i = 0; i < N; i++) {
for (int64_t i = 0; i < N; i++) {
const float* inp_row = inp + i * C;
float* out_row = out + i * C;

Expand Down Expand Up @@ -66,31 +66,27 @@ void crossentropy_forward_cpu(float* losses,
// output: losses is (B,T) of the individual losses at each position
// input: probs are (B,T,V) of the probabilities
// input: targets is (B,T) of integers giving the correct index in logits
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
// loss = -log(probs[target])
const float* probs_bt = probs + b * T * V + t * V;
int ix = targets[b * T + t];
losses[b * T + t] = -logf(probs_bt[ix]);
}
for (int64_t bt = 0; bt < B * T; bt++) {
// loss = -log(probs[target])
const float* probs_bt = probs + bt * V;
int ix = targets[bt];
losses[bt] = -logf(probs_bt[ix]);
}
}

void crossentropy_softmax_backward_cpu(float* dlogits,
const float* dlosses, const float* probs, const int* targets,
int B, int T, int V) {
// backwards through both softmax and crossentropy
for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
float* dlogits_bt = dlogits + b * T * V + t * V;
const float* probs_bt = probs + b * T * V + t * V;
float dloss = dlosses[b * T + t];
int ix = targets[b * T + t];
for (int i = 0; i < V; i++) {
float p = probs_bt[i];
float indicator = i == ix ? 1.0f : 0.0f;
dlogits_bt[i] = (p - indicator) * dloss;
}
for (int64_t bt = 0; bt < B * T; bt++) {
float* dlogits_bt = dlogits + bt * V;
const float* probs_bt = probs + bt * V;
float dloss = dlosses[bt];
int ix = targets[bt];
for (int i = 0; i < V; i++) {
float p = probs_bt[i];
float indicator = i == ix ? 1.0f : 0.0f;
dlogits_bt[i] = (p - indicator) * dloss;
}
}
}
Expand All @@ -115,7 +111,7 @@ struct SoftmaxParams {
};
namespace cg = cooperative_groups;
__device__ SoftmaxParams prepare_softmax(cg::thread_block_tile<32>& warp,
int idx, const float* inp, int V, int P) {
int64_t idx, const float* inp, int V, int P) {
// this warp (of 32) threads processes one row of inp, i.e. inp[idx, :] of shape (V,)
// note that inp is actually (B * T, P) but we only use the first V elements
// this function tehen calculates:
Expand Down Expand Up @@ -155,7 +151,7 @@ __global__ void fused_classifier_kernel1(float* dlogits, float* losses,
// each block of 4 warps is in charge of 4 rows of the input, one warp per row
// meta_group_size is the number of warps per block (e.g. 4)
// meta_group_rank is the index of the warp in the block (e.g. 0, 1, 2, 3)
int idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
int64_t idx = blockIdx.x * warp.meta_group_size() + warp.meta_group_rank();
if (idx >= B * T) { // there are B * T rows in the input
return;
}
Expand Down Expand Up @@ -192,7 +188,7 @@ __device__ float vec_at(const float4& vec, int index) {
}

__device__ SoftmaxParams prepare_softmax_blockwide(cg::thread_block_tile<32>& warp,
int idx, const float* inp, int V, int P) {
int64_t idx, const float* inp, int V, int P) {
// one row of inp, i.e. inp[idx, :] of shape (V,)
// float4 to get 128-bit loads and memory level parallelism
const float4* x_vec4 = reinterpret_cast<const float4*>(inp + idx * P);
Expand Down Expand Up @@ -256,7 +252,7 @@ __global__ void fused_classifier_kernel2(float* dlogits, float* losses, float* p
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
int idx = blockIdx.x;
int64_t idx = blockIdx.x;
int ix = targets[idx];

// softmax (reading B * T * V, same logits read again below, hopefully still in cache)
Expand Down Expand Up @@ -297,7 +293,7 @@ __global__ void fused_classifier_kernel2(float* dlogits, float* losses, float* p
}

__device__ SoftmaxParams prepare_softmax_blockwide_nofloat4(cg::thread_block_tile<32>& warp,
int idx, const float* inp, int V, int P) {
int64_t idx, const float* inp, int V, int P) {
// same but not float4
// one row of inp, i.e. inp[idx, :] of shape (V,)

Expand Down Expand Up @@ -353,7 +349,7 @@ __global__ void fused_classifier_kernel3(float* dlogits, float* losses, float* p
namespace cg = cooperative_groups;
cg::thread_block block = cg::this_thread_block();
cg::thread_block_tile<32> warp = cg::tiled_partition<32>(block);
int idx = blockIdx.x;
int64_t idx = blockIdx.x;
int ix = targets[idx];

// softmax (reading B * T * V, same logits read again below, hopefully still in cache)
Expand Down Expand Up @@ -385,7 +381,7 @@ __global__ void fused_classifier_kernel3(float* dlogits, float* losses, float* p
}
}

__device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const floatX* inp, int V, int P) {
__device__ SoftmaxParams prepare_softmax_blockwide2(int64_t idx, const floatX* inp, int V, int P) {
// one row of inp, i.e. inp[idx, :] of shape (V,)

const floatX* x = inp + idx * P;
Expand Down Expand Up @@ -443,7 +439,7 @@ __device__ SoftmaxParams prepare_softmax_blockwide2(int idx, const floatX* inp,
__global__ void fused_classifier_kernel4(floatX* dlogits, floatX* losses, floatX* probs,
const floatX* logits, const floatX* dlosses, const int* targets,
int B, int T, int V, int P) {
int idx = blockIdx.x;
int64_t idx = blockIdx.x;
int ix = targets[idx];

// softmax (reading B * T * V, same logits read again below, hopefully still in cache)
Expand Down Expand Up @@ -512,7 +508,7 @@ __device__ float blockReduce(float val, bool final_sync=false, float out_of_boun
return block_val;
}

__device__ SoftmaxParams prepare_softmax_blockwide3(int idx, const floatX* inp, int V, int P) {
__device__ SoftmaxParams prepare_softmax_blockwide3(int64_t idx, const floatX* inp, int V, int P) {
// same but not float4
// one row of inp, i.e. inp[idx, :] of shape (V,)

Expand Down Expand Up @@ -566,7 +562,7 @@ __global__ void __launch_bounds__(1024, MAX_1024_THREADS_BLOCKS)
fused_classifier_kernel5(floatX* dlogits, floatX* losses, floatX* probs,
const floatX* logits, const floatX* dlosses, const int* targets,
int B, int T, int V, int P) {
int idx = blockIdx.x;
int64_t idx = blockIdx.x;
int ix = targets[idx];

// softmax (reading B * T * V, same logits read again below, hopefully still in cache)
Expand Down Expand Up @@ -702,10 +698,10 @@ void fused_classifier(int kernel_num, float* dlogits, float* losses,
int main(int argc, char **argv) {
srand(0);

int B = 8; // batch size
int T = 1024; // sequence length
int V = 50257; // vocab size
int P = (V + 63) & ~63; // padded vocab size, up to nearest multiple of 64
int64_t B = 8; // batch size
int64_t T = 1024; // sequence length
int64_t V = 50257; // vocab size
int64_t P = (V + 63) & ~63; // padded vocab size, up to nearest multiple of 64

int deviceIdx = 0;
cudaCheck(cudaSetDevice(deviceIdx));
Expand Down
Loading

0 comments on commit f2ee356

Please sign in to comment.