diff --git a/profile_gpt2.cu b/profile_gpt2.cu index 4b24c8973..5a6764533 100644 --- a/profile_gpt2.cu +++ b/profile_gpt2.cu @@ -54,7 +54,7 @@ int main(int argc, char *argv[]) { gpt2_forward(&model, x, y, B, T); gpt2_zero_grad(&model); gpt2_backward(&model); - gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, 1, &multi_gpu_config); + gpt2_update(&model, 1e-4f, 0.9f, 0.999f, 1e-8f, 0.0f, 1.f, 1, &multi_gpu_config); cudaCheck(cudaDeviceSynchronize()); // finish all CUDA work to get correct precise timings // free