diff --git a/src/backend.py b/src/backend.py
index 8e8a361e..8473da0e 100644
--- a/src/backend.py
+++ b/src/backend.py
@@ -115,7 +115,7 @@ def prefixed_name(ctx: Context, name: str):
 
 def assign(ctx: Context, name: str, inp: jax.Array):
     name = prefixed_name(ctx, name)
-    ctx.parameters[name] = inp
+    ctx.parameters[name] = inp.astype(ctx.parameters[name].dtype)
 
 
 def normal(ctx: Context, shape: Sequence[int]):
diff --git a/src/context.py b/src/context.py
index 71cc0665..85b63338 100644
--- a/src/context.py
+++ b/src/context.py
@@ -116,6 +116,9 @@ class Optimizer(DataClass):
     weight_decay: float = 0.01
     warmup_end: int = 16384
     exponential_decay: float = 3e-6
+    svd_components: int = 8
+    fisher_decay: float = 0.99
+    log_matrix_power: int = 5  # 2^x+1 is actual power
 
 
 class Normalization(DataClass):
diff --git a/src/optimizer.py b/src/optimizer.py
index 52686fbd..a2292d8d 100644
--- a/src/optimizer.py
+++ b/src/optimizer.py
@@ -3,7 +3,7 @@
 import jax
 from jax import lax, numpy as jnp
 
-from src.backend import add_sq, assign, default, get_param, is_stacked, stable_rsqrt, with_context
+from src.backend import assign, default, get_param, is_stacked, with_context, stable_rsqrt
 from src.constants import MomentumType
 from src.context import Context
 
@@ -63,19 +63,6 @@ def graft(param_name: str, magnitude: jax.Array, direction: jax.Array) -> jax.Ar
     return direction * jnp.sqrt(norm(param_name, magnitude) / jnp.maximum(norm(param_name, direction), 1e-16))
 
 
-def tg_adam(ctx: Context, param_name: str, grad: jax.Array, tg_grad: jax.Array, step: jax.Array) -> jax.Array:
-    ema_g = ema(ctx, grad, step, 1 - ctx.optimizer.adam_beta1)
-    ema_gsq = ema(ctx, grad ** 2, step, 1 - ctx.optimizer.adam_beta2)
-    ema_tgsq = ema(ctx, tg_grad, step, 1 - ctx.optimizer.adam_beta3)
-
-    if ctx.is_initializing:
-        return grad
-
-    adam_update = ema_g * stable_rsqrt(ema_gsq, ctx.optimizer.epsilon)
-    tg_update = ema_g * stable_rsqrt(ema_tgsq, ctx.optimizer.epsilon)
-    return graft(param_name, adam_update, tg_update)
-
-
 def get_current_lr(ctx: Context, step: jax.Array) -> jax.Array:
     opt = ctx.optimizer
     learning_rate = opt.learning_rate
@@ -85,26 +72,61 @@ def get_current_lr(ctx: Context, step: jax.Array) -> jax.Array:
     return learning_rate.astype(ctx.model.storage_dtype)
 
 
+def normalize(x: jax.Array) -> jax.Array:
+    return x * lax.rsqrt(lax.square(x).sum(1))
+
+
+def svd_fisher(ctx: Context, grad: jax.Array):
+    key = jax.random.PRNGKey(ctx.seed)
+    vectors = normalize(jax.random.normal(key, (ctx.optimizer.svd_components, grad.shape[0])).astype(jnp.float64))
+    u = get_param(ctx, "u", vectors.shape[::-1], dtype=ctx.optimizer.momentum_dtype, tied=True,
+                  init_val=jnp.zeros_like(vectors)).astype(jnp.float64)
+    v = get_param(ctx, "v", vectors.shape, dtype=ctx.optimizer.momentum_dtype, tied=True,
+                  init_val=jnp.zeros_like(vectors)).astype(jnp.float64)
+
+    mid = jnp.eye(ctx.optimizer.svd_components * 2 + 1)
+    mid = mid.at[:ctx.optimizer.svd_components, :ctx.optimizer.svd_components].set(jnp.transpose(u, (1, 0)) @ u)
+    grad = grad * (1 - ctx.optimizer.fisher_decay)
+    x0 = jnp.concatenate([u * ctx.optimizer.fisher_decay, grad], 1)
+    x0t = jnp.concatenate([v * ctx.optimizer.fisher_decay, grad], 0)
+    grad = grad - ((grad @ x0) @ jnp.linalg.inv(jnp.eye(ctx.optimizer.svd_components + 1) + x0t @ x0)) @ x0t
+
+    for i, v in enumerate(vectors, 1):
+        local_mid = mid[:ctx.optimizer.svd_components + i, :ctx.optimizer.svd_components + i]
+        b0 = normalize(x0 @ local_mid)
+        b1 = normalize(x0t)
+        inner = b1 @ b0
+        for _ in range(ctx.optimizer.log_matrix_power):
+            inner = inner @ inner
+        v = b0 @ (inner @ (b1 @ v))  # brackets for speed (V=[N,1], b1=[N,K], inner=[K,K], b0=[K,N])
+        u = x0 @ (local_mid @ (x0t @ v))
+        x0 = jnp.concatenate([x0, u.reshape(-1, 1)], 1)
+        x0t = jnp.concatenate([x0t, v.reshape(-1, 1)], 0)
+    assign(ctx, "u", x0[:, -ctx.optimizer.svd_components:])
+    assign(ctx, "v", x0t[-ctx.optimizer.svd_components:, :])
+    return grad
+
+
 def update(ctx: Context, grads: Dict[str, jax.Array], step: jax.Array):
-    outer_ctx = ctx.add_to_prefix("optimizer")
+    ctx = ctx.add_to_prefix("optimizer")
     lr = -get_current_lr(ctx, step)
+    keys = [k for k in grads.keys() if "optimizer" not in k and not k.endswith('_sq') and not k.endswith('_sq_stacked')]
+    grads = jnp.concatenate([adaptive_gradient_clipping(ctx, k, grads[k].reshape(-1), False) for k in keys], 0)
 
-    for param_name, grad in grads.items():
-        if "optimizer" in param_name or param_name.endswith('_sq') or param_name.endswith('_sq_stacked'):
-            continue
-        ctx = outer_ctx.add_to_prefix(param_name, count=False)
-        ctx.name_cache = {}
-        dtype = ctx.parameters[param_name].dtype
-        parameter_lr = lr * ctx.parameter_variance.get(param_name, 1)
-
-        grad = adaptive_gradient_clipping(ctx, param_name, grad, False)
-        grad_sq = adaptive_gradient_clipping(ctx, param_name, grads[add_sq(param_name)], True)
-        weight_update = tg_adam(ctx, param_name, grad, grad_sq, step) * parameter_lr
+    ctx.name_cache = {}
+    ema_gsq = ema(ctx, lax.square(grads), step, 1 - ctx.optimizer.adam_beta2)
+    adam = ema(ctx, grads / stable_rsqrt(ema_gsq, ctx.optimizer.epsilon), step, 1 - ctx.optimizer.adam_beta1)
+    prec = svd_fisher(ctx, grads)
 
-        if ctx.is_initializing:
-            continue
+    if ctx.is_initializing:
+        return
 
-        param = ctx.parameters[param_name].astype(jnp.float64)
+    offset = 0
+    for param_name in keys:
+        param = ctx.parameters[param_name]
+        dtype = ctx.parameters[param_name].dtype
+        parameter_lr = lr * ctx.parameter_variance.get(param_name, 1)
+        grad = graft(param_name, adam[offset:offset + param.size], prec[offset:offset + param.size]) * parameter_lr
         if not small_parameter(param_name, grad):
             param *= 1 + ctx.optimizer.weight_decay * parameter_lr
-        ctx.parameters[param_name] = (param + weight_update).astype(dtype)
+        ctx.parameters[param_name] = (param + grad).astype(dtype)
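
The new svd_fisher appears to keep a low-rank factorization of the accumulated gradient outer products (a Fisher-style second moment) and to precondition the flattened gradient with its inverse through the Woodbury identity, so only a (svd_components + 1)-sized system is ever solved. Below is a minimal standalone sketch of that step, simplified to a single symmetric factor u (the PR tracks separate u/v factors plus extra bookkeeping); the name woodbury_precondition and the decay argument are illustrative and not part of this PR.

import jax
import jax.numpy as jnp


def woodbury_precondition(grad: jax.Array, u: jax.Array, decay: float) -> jax.Array:
    """Apply (I + X @ X.T)^-1 to `grad`, where X stacks the decayed low-rank history
    with the current (scaled) gradient. Shapes: grad=[N], u=[N, K]."""
    x = jnp.concatenate([u * decay, grad[:, None] * (1 - decay)], 1)  # [N, K + 1]
    small = jnp.eye(x.shape[1]) + x.T @ x                             # [K + 1, K + 1]
    # Woodbury identity: (I + X X^T)^-1 g = g - X (I + X^T X)^-1 X^T g
    return grad - x @ jnp.linalg.solve(small, x.T @ grad)


# Toy usage with hypothetical sizes; in the PR, N would be the flattened parameter
# count and K would be ctx.optimizer.svd_components.
key = jax.random.PRNGKey(0)
n, k = 1024, 8
grad = jax.random.normal(key, (n,))
u = 0.1 * jax.random.normal(jax.random.fold_in(key, 1), (n, k))
preconditioned = woodbury_precondition(grad, u, decay=0.99)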
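
For intuition on the inner loop: repeatedly squaring inner = b1 @ b0 is a matrix-power trick, so one pass behaves like (b0 @ b1) raised to 2^log_matrix_power + 1 applied to the probe vector (hence the "2^x+1 is actual power" comment) while only ever powering a small K x K matrix. A standalone sketch under clean shape assumptions; matrix_power_probe and probe are illustrative names, not from the PR.

import jax
import jax.numpy as jnp


def matrix_power_probe(b0: jax.Array, b1: jax.Array, probe: jax.Array, log_power: int) -> jax.Array:
    """Compute (b0 @ b1) ** (2**log_power + 1) @ probe without forming the N x N product.

    b0: [N, K], b1: [K, N], probe: [N]. Only the K x K matrix `inner` is powered, so the
    cost is O(N * K) per matmul plus O(K**3) per squaring."""
    inner = b1 @ b0                     # [K, K] core of the low-rank operator
    for _ in range(log_power):
        inner = inner @ inner           # after the loop: inner == (b1 @ b0) ** (2**log_power)
    return b0 @ (inner @ (b1 @ probe))  # bracketed right-to-left to keep intermediates K-sized


# Toy usage: the high power pushes `probe` toward the dominant eigendirection of
# b0 @ b1 (assuming one eigenvalue dominates), which is what a power-iteration step wants.
key = jax.random.PRNGKey(0)
n, k = 512, 8
b0 = jax.random.normal(key, (n, k))
b1 = jax.random.normal(jax.random.fold_in(key, 1), (k, n))
probe = jax.random.normal(jax.random.fold_in(key, 2), (n,))
out = matrix_power_probe(b0, b1, probe, log_power=5)
out = out / jnp.linalg.norm(out)        # renormalize each step, as the PR's normalize helper does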