Commit b3be58d

SGD (stochastic gradient descent) in ggml (examples/finetune)
1 parent aa59aa3 commit b3be58d

File tree

15 files changed: +290 -57 lines

CMakeLists.txt

Lines changed: 2 additions & 0 deletions

@@ -12,6 +12,8 @@ if (NOT XCODE AND NOT MSVC AND NOT CMAKE_BUILD_TYPE)
     set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "Release" "MinSizeRel" "RelWithDebInfo")
 endif()
 
+message("CMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}")
+
 # Add path to modules
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/")

common/arg.cpp

Lines changed: 4 additions & 4 deletions

@@ -1238,7 +1238,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     sampler_type_names.pop_back();
 
     params.optimize = ggml_opt_get_default_optimizer_params(NULL);
-    params.optimize.adamw.alpha = 1e-8; // default 1e-3 is much too high for LLAMA_EXAMPLE_FINETUNE
+    // default 1e-3 is much too high for LLAMA_EXAMPLE_FINETUNE -opt adamw
+    // (but could be ok for -opt sgd)
+    params.optimize.adamw.alpha = 1e-5;
 
     /**
      * filter options by example
@@ -2183,7 +2185,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     }
     ).set_examples({LLAMA_EXAMPLE_PERPLEXITY}));
     add_opt(common_arg({ "-lr", "--learning-rate" }, "ALPHA",
-        string_format("adamw optimizer alpha (default: %.1f)", (double) params.optimize.adamw.alpha),
+        string_format("adamw optimizer alpha (default: %.2g)", (double) params.optimize.adamw.alpha),
         [](common_params & params, const std::string & value) {
             params.optimize.adamw.alpha = std::stof(value);
         })
@@ -2193,8 +2195,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.optimize.optimizer = named_ggml_opt_optimizer(name.c_str());
             if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_COUNT) {
                 throw std::invalid_argument("invalid --optimizer (try adamw)");
-            } else if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_SGD) {
-                throw std::invalid_argument("TODO: implement SGD");
             }
         })
         .set_examples({ LLAMA_EXAMPLE_FINETUNE }));

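For orientation, here is a minimal sketch of what the option handlers above amount to when called outside the parser. It only uses names that appear in this diff (common_params, params.optimize, ggml_opt_get_default_optimizer_params, named_ggml_opt_optimizer), and the concrete "sgd" / 1e-4f inputs are illustrative, not defaults.

    // Sketch, not part of the commit: mirrors the -lr / --optimizer handlers above.
    #include <stdexcept>
    #include <string>

    void configure_optimizer(common_params & params, const std::string & name, float lr) {
        params.optimize           = ggml_opt_get_default_optimizer_params(NULL); // library defaults
        params.optimize.optimizer = named_ggml_opt_optimizer(name.c_str());      // "adamw" or "sgd"
        if (params.optimize.optimizer == GGML_OPT_OPTIMIZER_COUNT) {
            throw std::invalid_argument("invalid --optimizer (try adamw)");
        }
        // -lr / --learning-rate; the SGD kernels below read the same alpha slot
        params.optimize.adamw.alpha = lr;
    }
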
ggml/include/ggml-opt.h

Lines changed: 2 additions & 0 deletions

@@ -126,6 +126,8 @@ extern "C" {
 
         ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters
        void * get_opt_pars_ud;                      // userdata for calculating optimizer parameters
+        struct ggml_opt_optimizer_params
+            opt_params; // holds result of get_opt_pars(get_opt_pars_ud) after ggml_opt_init (could call get_opt_pars repeatedly instead)
     };
 
     // get parameters for an optimization context with defaults set where possible

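For context, a sketch of the kind of callback that get_opt_pars points at and whose result the new opt_params field can cache after ggml_opt_init (per the comment in the diff); passing the learning rate through userdata is an assumption made for this example.

    // Illustrative callback returning constant optimizer parameters.
    static struct ggml_opt_optimizer_params my_opt_pars(void * userdata) {
        struct ggml_opt_optimizer_params p = ggml_opt_get_default_optimizer_params(NULL);
        p.adamw.alpha = *(const float *) userdata; // learning rate supplied via get_opt_pars_ud
        return p;
    }
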
ggml/include/ggml.h

Lines changed: 8 additions & 2 deletions

@@ -450,7 +450,7 @@ extern "C" {
         GGML_OP_REPEAT_BACK,
         GGML_OP_CONCAT,
         GGML_OP_SILU_BACK,
-        GGML_OP_NORM, // normalize
+        GGML_OP_NORM,       // normalize
         GGML_OP_RMS_NORM,
         GGML_OP_RMS_NORM_BACK,
         GGML_OP_GROUP_NORM,
@@ -486,7 +486,7 @@ extern "C" {
         GGML_OP_POOL_1D,
         GGML_OP_POOL_2D,
         GGML_OP_POOL_2D_BACK,
-        GGML_OP_UPSCALE, // nearest interpolate
+        GGML_OP_UPSCALE,    // nearest interpolate
         GGML_OP_PAD,
         GGML_OP_PAD_REFLECT_1D,
         GGML_OP_ARANGE,
@@ -517,6 +517,7 @@ extern "C" {
         GGML_OP_CROSS_ENTROPY_LOSS,
         GGML_OP_CROSS_ENTROPY_LOSS_BACK,
         GGML_OP_OPT_STEP_ADAMW,
+        GGML_OP_OPT_STEP_SGD,
 
         GGML_OP_COUNT,
     };
@@ -2063,6 +2064,11 @@ extern "C" {
             struct ggml_tensor * v,
             struct ggml_tensor * adamw_params); // parameters such a the learning rate
 
+    // SGD (with weight decay) step
+    GGML_API struct ggml_tensor * ggml_opt_step_sgd(
+            struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * grad,
+            struct ggml_tensor * adamw_params); // parameters: alpha, the learning rate, and wd, weight decay
+
     //
     // automatic differentiation
     //

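As a usage sketch (not part of the commit), the new op slots into a graph the same way ggml_opt_step_adamw does, taking a small F32 parameter tensor; the 8-element size and the meaning of slots 0 (alpha) and 7 (keep) are taken from the CPU/CUDA kernels further down, while ctx, gf, weight and grad are assumed to exist already.

    // Sketch: wiring one SGD step into a compute graph.
    struct ggml_tensor * sgd_pars = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);   // filled on the host each step
    struct ggml_tensor * step     = ggml_opt_step_sgd(ctx, weight, grad, sgd_pars);
    ggml_build_forward_expand(gf, step);
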
ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 0 deletions

@@ -2057,6 +2057,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 ggml_compute_forward_opt_step_adamw(params, tensor);
             }
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                ggml_compute_forward_opt_step_sgd(params, tensor);
+            }
+            break;
         case GGML_OP_NONE:
             {
                 // nop
@@ -2341,6 +2346,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             {
                 n_tasks = n_threads;
             } break;

ggml/src/ggml-cpu/ops.cpp

Lines changed: 64 additions & 4 deletions

@@ -8831,7 +8831,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
-    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
 
     const int ith = params->ith;
     const int nth = params->nth;
@@ -8849,14 +8849,14 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir1 = MIN(ir0 + dr, nr);
 
     const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+
     const float alpha  = adamw_params_ptr[0];
     const float beta1  = adamw_params_ptr[1];
     const float beta2  = adamw_params_ptr[2];
     const float eps    = adamw_params_ptr[3];
-    const float wd     = adamw_params_ptr[4];
     const float beta1h = adamw_params_ptr[5];
     const float beta2h = adamw_params_ptr[6];
-
+    const float keep   = adamw_params_ptr[7];
     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
         const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
@@ -8879,7 +8879,7 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
+            w[i00] = w[i00] * keep - alpha * mh / vh;
         }
     }
 }
@@ -8901,3 +8901,63 @@ void ggml_compute_forward_opt_step_adamw(
             }
     }
 }
+
+static void ggml_compute_forward_opt_step_sgd_f32(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * adamw_params = dst->src[2];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src0);
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+    GGML_ASSERT(nb00 == sizeof(float));
+
+    // rows per thread
+    const int dr = (nr + nth - 1) / nth;
+
+    // row range for this thread
+    const int ir0 = dr * ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    // using adamw param subset we care about - alpha, wd - could have a separate struct
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float   alpha            = adamw_params_ptr[0];
+    const float   keep             = adamw_params_ptr[7];
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        const int64_t i03 = ir / (ne02 * ne01);
+        const int64_t i02 = (ir - i03 * ne02 * ne01) / ne01;
+        const int64_t i01 = (ir - i03 * ne02 * ne01 - i02 * ne01);
+
+        const size_t offset = i03 * nb03 + i02 * nb02 + i01 * nb01;
+
+        float       * w = (float       *) ((char       *) src0->data      + offset); // weight
+        const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
+
+        for (int i00 = 0; i00 < ne00; ++i00) {
+            w[i00] = w[i00] * keep - alpha * g[i00];
+        }
+    }
+}
+
+void ggml_compute_forward_opt_step_sgd(const ggml_compute_params * params, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_opt_step_sgd_f32(params, dst);
+            }
+            break;
+        default:
+            {
+                GGML_ABORT("fatal error - sgd is F32 only");
+            }
+    }
+}

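To make the per-element update above concrete, here is a tiny standalone illustration; the keep = 1 - alpha*wd relationship is an assumption inferred from the AdamW expression this commit replaces, not something the SGD kernel itself states.

    #include <cstdio>

    int main() {
        const float alpha = 1e-4f;             // learning rate (adamw_params_ptr[0])
        const float wd    = 0.01f;             // weight decay, illustrative value
        const float keep  = 1.0f - alpha*wd;   // assumed to be precomputed into adamw_params_ptr[7]
        float       w     = 0.5f;              // one weight element
        const float g     = 2.0f;              // its gradient
        w = w*keep - alpha*g;                  // same update as ggml_compute_forward_opt_step_sgd_f32
        std::printf("w after one step: %f\n", w); // about 0.4998
        return 0;
    }
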
ggml/src/ggml-cpu/ops.h

Lines changed: 1 addition & 1 deletion

@@ -104,7 +104,7 @@ void ggml_compute_forward_custom(const struct ggml_compute_params * params, stru
 void ggml_compute_forward_cross_entropy_loss(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_cross_entropy_loss_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_opt_step_adamw(const struct ggml_compute_params * params, struct ggml_tensor * dst);
-
+void ggml_compute_forward_opt_step_sgd(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 #ifdef __cplusplus
 }
 #endif

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions

@@ -24,6 +24,7 @@
 #include "ggml-cuda/mmvq.cuh"
 #include "ggml-cuda/norm.cuh"
 #include "ggml-cuda/opt-step-adamw.cuh"
+#include "ggml-cuda/opt-step-sgd.cuh"
 #include "ggml-cuda/out-prod.cuh"
 #include "ggml-cuda/pad.cuh"
 #include "ggml-cuda/pool2d.cuh"
@@ -2352,6 +2353,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_OPT_STEP_ADAMW:
             ggml_cuda_opt_step_adamw(ctx, dst);
             break;
+        case GGML_OP_OPT_STEP_SGD:
+            ggml_cuda_opt_step_sgd(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -3256,6 +3260,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return true;
         default:
             return false;

ggml/src/ggml-cuda/opt-step-adamw.cu

Lines changed: 3 additions & 4 deletions

@@ -17,9 +17,9 @@ static __global__ void opt_step_adamw_f32(
     const float beta1  = pars[1];
     const float beta2  = pars[2];
     const float eps    = pars[3];
-    const float wd     = pars[4];
     const float beta1h = pars[5];
     const float beta2h = pars[6];
+    const float keep   = pars[7];
 
     const float gi  = g[i];
     const float gmi = g_m[i]*beta1 + gi*(1.0f - beta1);
@@ -31,7 +31,7 @@ static __global__ void opt_step_adamw_f32(
     const float mh =       gmi*beta1h;
     const float vh = sqrtf(gvi*beta2h) + eps;
 
-    x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh;
+    x[i] = x[i] * keep - alpha * mh / vh;
 }
 
 static void opt_step_adamw_f32_cuda(
@@ -62,14 +62,13 @@ void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
-    GGML_ASSERT(ggml_nelements(adamw_params) == 7);
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
 
     float * src0_d = (float *) src0->data;
     const float * src0_grad_d = (const float *) src0_grad->data;
     float * src0_grad_m_d = (float *) src0_grad_m->data;
     float * src0_grad_v_d = (float *) src0_grad_v->data;
     const float * adamw_params_d = (const float *) adamw_params->data;
-
     cudaStream_t stream = ctx.stream();
 
     const int64_t ne = ggml_nelements(src0);

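The switch from wd (pars[4]) to a precomputed keep factor (pars[7]) implies the host now packs 8 floats instead of 7. A hypothetical packing is sketched below; the index meanings come from the kernels in this commit, the keep = 1 - alpha*wd formula is an assumption, and beta1h/beta2h are left to the caller since their exact computation lives in ggml-opt code not shown in this excerpt.

    // Hypothetical host-side packing of the 8-float buffer read by the AdamW and SGD kernels.
    static void pack_opt_pars(float pars[8],
                              float alpha, float beta1, float beta2, float eps,
                              float wd, float beta1h, float beta2h) {
        pars[0] = alpha;             // learning rate
        pars[1] = beta1;
        pars[2] = beta2;
        pars[3] = eps;
        pars[4] = wd;                // raw weight decay (kernels now use the folded form below)
        pars[5] = beta1h;            // AdamW bias-correction term, computed by the caller
        pars[6] = beta2h;            // AdamW bias-correction term, computed by the caller
        pars[7] = 1.0f - alpha * wd; // "keep" factor multiplied into the weights by both kernels
    }
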
ggml/src/ggml-cuda/opt-step-sgd.cu

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
+#include "ggml-impl.h"
+#include "opt-step-sgd.cuh"
+
+#include <cstdint>
+
+static __global__ void opt_step_sgd_f32(
+    float * __restrict__ x, const float * __restrict__ g,
+    const float * __restrict__ pars, const int64_t k) {
+
+    const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x;
+
+    if (i >= k)
+        return;
+    x[i] = x[i] * pars[7] - pars[0] * g[i];
+}
+
+static void opt_step_sgd_f32_cuda(
+    float * x, const float * g, const float * __restrict__ pars, const int64_t k, cudaStream_t stream) {
+
+    const dim3 block_dims(CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    const dim3 block_nums((k + CUDA_OPT_STEP_SGD_BLOCK_SIZE - 1) / CUDA_OPT_STEP_SGD_BLOCK_SIZE, 1, 1);
+    opt_step_sgd_f32<<<block_nums, block_dims, 0, stream>>>(x, g, pars, k);
+}
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0         = dst->src[0];
+    const ggml_tensor * src0_grad    = dst->src[1];
+    const ggml_tensor * adamw_params = dst->src[2];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0_grad->type == GGML_TYPE_F32);
+    GGML_ASSERT(adamw_params->type == GGML_TYPE_F32);
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    GGML_ASSERT(ggml_is_contiguous(src0_grad));
+    GGML_ASSERT(ggml_is_contiguous(adamw_params));
+    GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 8);
+
+    float * src0_d = (float *) src0->data;
+    const float * src0_grad_d    = (const float *) src0_grad->data;
+    const float * adamw_params_d = (const float *) adamw_params->data;
+
+    cudaStream_t stream = ctx.stream();
+
+    const int64_t ne = ggml_nelements(src0);
+
+    opt_step_sgd_f32_cuda(src0_d, src0_grad_d, adamw_params_d, ne, stream);
+}

ggml/src/ggml-cuda/opt-step-sgd.cuh

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_OPT_STEP_SGD_BLOCK_SIZE 256
+
+void ggml_cuda_opt_step_sgd(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
