-
Notifications
You must be signed in to change notification settings - Fork 995
Description
Clearly, this model is not intended for mortals, as it requires NVIDIA H100 and L40S GPUs to train on any problem.
Please republish it and, instead of calling it "tiny", call it "simple". It is simple compared to HRM, and like HRM it effectively implements the approach of
"Iteratively Reweighted Kernel Machines Efficiently Learn Sparse Functions"
by Zhu, Davis, Drusvyatskiy, and Fazel.
We note that the iterative approach is also suitable for Newton-type descent, as proposed in ParaRNN by Apple, and we offer tentative code to replace or extend trm.py (use at your own peril — it is untested and was coded by today's emergent K2 reasoning model) with iterative Newton acceleration for training.
Naturally, if we owned an H100 cluster, we would be happy to run this code. But like so many other "experiments" released to the public by corporate and institutional entities, training this model remains out of reach for individuals who only have access to personal computing machines — probably largely due to how it is designed, rather than any architectural limitation.
from typing import Tuple, Dict, Optional
from dataclasses import dataclass
import math
import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding
IGNORE_LABEL_ID = -100
@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
    # Latent recurrent state threaded between supervision steps.
    # Both tensors are allocated as (batch, seq_len + puzzle_emb_len,
    # hidden_size) in empty_carry — TODO confirm all producers agree.
    z_H: torch.Tensor  # high-level ("answer") latent state
    z_L: torch.Tensor  # low-level ("reasoning") latent state
@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
    # Full ACT loop state carried across forward() calls of the wrapper.
    inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry  # latent (z_H, z_L) pair
    steps: torch.Tensor   # per-sample ACT step counter, int32, shape (batch,)
    halted: torch.Tensor  # per-sample halt flags, bool, shape (batch,)
    current_data: Dict[str, torch.Tensor]  # batch tensors frozen for in-flight (non-halted) samples
class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
    """Hyperparameters for the ACT-wrapped tiny recursive reasoning model (Newton variant)."""
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0  # 0 disables the puzzle embedding entirely
    num_puzzle_identifiers: int
    vocab_size: int
    H_cycles: int  # outer ("high-level") recursion cycles per forward
    L_cycles: int  # inner ("low-level") recursion cycles per H cycle
    H_layers: int  # Retained for API compatibility, ignored in Newton variant
    L_layers: int  # depth of the single shared core network
    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str  # "rope" or "learned" (anything else: no positional encoding)
    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float
    forward_dtype: str = "bfloat16"  # resolved via getattr(torch, ...) in the inner model
    # Newton-specific additions
    newton_max_cg_iters: int = 5          # CG iterations per Newton step
    newton_trust_region_mu: float = 0.1   # damping added to the CG denominator
    newton_enable: bool = True  # Flag to toggle Newton vs original
    # Original config parameters
    mlp_t: bool = False
    puzzle_emb_len: int = 16  # prefix length reserved for the puzzle embedding; 0 = derive by ceil-div
    no_ACT_continue: bool = True  # True: halt on q_halt > 0; False: full Q-learning with continue head
class TinyRecursiveReasoningModel_NewtonVariant(nn.Module):
    """Newton-accelerated drop-in replacement for TinyRecursiveReasoningModel_ACTV1_Inner.

    The latent pair (z_H, z_L) is treated as an approximate fixed point of
        y = f_y(y, z)       (no input injection)
        z = f_z(x, y, z)    (input embedding injected)
    and, when ``config.newton_enable`` is set, is updated with damped Newton
    steps: solve ``J @ delta ~= r`` by a few conjugate-gradient iterations,
    where ``r = [y - f_y; z - f_z]`` and ``J = d f / d(y, z) - I``.

    NOTE(review): CG formally requires a symmetric positive-definite operator;
    the fixed-point Jacobian here is generally neither. With the trust-region
    damping term the solve is a heuristic inexact step — validate convergence
    empirically before relying on it.
    NOTE(review): TinyRecursiveReasoningModel_ACTV1Block and
    TinyRecursiveReasoningModel_ACTV1ReasoningModule are referenced below but
    neither defined nor imported in this chunk — they must exist elsewhere in
    the file for this class to be usable.
    """

    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O layers (identical to original).
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale
        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size,
                                            init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        # Puzzle embedding occupies a prefix of the sequence.  The ceil-div
        # fallback only triggers when puzzle_emb_len is explicitly 0 (the
        # config default is 16) — presumably intentional; confirm vs. original.
        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) \
            if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
        if self.config.puzzle_emb_ndim > 0:
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        # Positional encodings: rotary or learned (anything else: none).
        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size,
                                             init_std=embed_init_std, cast_to=self.forward_dtype)

        # Single shared core network plays the role of both f_y and f_z.
        self.core_net = self._build_core_network()

        # Learned initial latent states (identical to original).
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)

        # Q-head starts at zero weights with a strong negative bias, i.e.
        # biased toward "continue" early in training (identical to original).
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)

    def _build_core_network(self):
        """Build the shared L_layers-deep network used for both reasoning and refinement."""
        layers = [TinyRecursiveReasoningModel_ACTV1Block(self.config)
                  for _ in range(self.config.L_layers)]
        return TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers)

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        """Token + (optional) puzzle + (optional) learned-position embeddings, scaled.

        Identical to the original implementation.
        """
        embedding = self.embed_tokens(input.to(torch.int32))
        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
            # Pad the flat puzzle embedding up to a whole number of positions.
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)
        if self.config.pos_encodings == "learned":
            # 1/sqrt(2) keeps the variance of the sum comparable to each term.
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))
        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        """Allocate an uninitialized carry; contents are overwritten by reset_carry."""
        shape = (batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size)
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(*shape, dtype=self.forward_dtype),
            z_L=torch.empty(*shape, dtype=self.forward_dtype),
        )

    def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
        """Reset the latent state to the learned init wherever reset_flag is True."""
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def _compute_residual_and_jacobian(self, y: torch.Tensor, z: torch.Tensor,
                                       x_emb: torch.Tensor, cos_sin) -> Tuple[torch.Tensor, torch.Tensor]:
        """Residual and a Jacobian-vector-product closure for the Newton step.

        Residual: r = [y - f_y(y,z); z - f_z(x,y,z)], flattened to 1-D.
        Jacobian: J = [[df_y/dy - I, df_y/dz], [df_z/dy, df_z/dz - I]],
        represented as four per-position (H, H) blocks plus a closure that
        applies the assembled operator to a flat vector.

        Returns (residual_flat, jvp_closure).
        """
        batch_size, seq_len, hidden_size = y.shape

        with torch.enable_grad():
            y_req = y.detach().requires_grad_(True)
            z_req = z.detach().requires_grad_(True)

            # f_y(y, z): no input injection.  f_z(x, y, z): x_emb injected.
            f_y_output = self.core_net(y_req, z_req, cos_sin=cos_sin)
            f_z_output = self.core_net(z_req, y_req + x_emb, cos_sin=cos_sin)

            # Residuals (kept attached to the graph so parameter gradients can
            # flow through the final Newton update).
            r_y = y_req - f_y_output
            r_z = z_req - f_z_output

            def compute_jacobian_block(output, input_var):
                """Per-position Jacobian block d(output)/d(input_var), shape (B*S, H, H).

                NOTE(review): grad_outputs sums contributions over all sequence
                positions, so cross-position (attention) terms are folded into
                each per-position block — this is a block-diagonal
                approximation, not the exact Jacobian.  Confirm acceptable.
                """
                output_flat = output.reshape(-1, hidden_size)
                block = torch.zeros(batch_size * seq_len, hidden_size, hidden_size,
                                    device=output.device, dtype=output.dtype)
                for i in range(hidden_size):
                    v = torch.zeros_like(output_flat)
                    v[:, i] = 1.0
                    # BUGFIX: differentiate w.r.t. the ORIGINAL graph tensor
                    # (input_var).  A fresh reshape() view is downstream of the
                    # graph and autograd.grad would reject it as unused.
                    grads = torch.autograd.grad(
                        outputs=output_flat,
                        inputs=input_var,
                        grad_outputs=v,
                        create_graph=False,
                        retain_graph=True,
                    )[0].reshape(-1, hidden_size)
                    # BUGFIX: e_i^T J is ROW i of the Jacobian; the original
                    # stored it as column i, producing J^T in every block.
                    block[:, i, :] = grads
                return block

            eye = torch.eye(hidden_size, device=y.device, dtype=y.dtype) \
                       .unsqueeze(0).expand(batch_size * seq_len, -1, -1)
            J_yy = compute_jacobian_block(f_y_output, y_req) - eye
            J_yz = compute_jacobian_block(f_y_output, z_req)
            J_zy = compute_jacobian_block(f_z_output, y_req)
            J_zz = compute_jacobian_block(f_z_output, z_req) - eye

            residual = torch.cat([r_y.reshape(batch_size * seq_len, hidden_size),
                                  r_z.reshape(batch_size * seq_len, hidden_size)], dim=0)

        def jacobian_vector_product(v):
            """Apply the assembled block operator: J @ [v_y; v_z]."""
            n = batch_size * seq_len
            v_y = v[:n].reshape(n, hidden_size)
            v_z = v[n:].reshape(n, hidden_size)
            out_y = torch.matmul(J_yy, v_y.unsqueeze(-1)).squeeze(-1) + \
                    torch.matmul(J_yz, v_z.unsqueeze(-1)).squeeze(-1)
            out_z = torch.matmul(J_zy, v_y.unsqueeze(-1)).squeeze(-1) + \
                    torch.matmul(J_zz, v_z.unsqueeze(-1)).squeeze(-1)
            return torch.cat([out_y, out_z], dim=0)

        return residual.reshape(-1), jacobian_vector_product

    def _newton_step(self, y: torch.Tensor, z: torch.Tensor, x_emb: torch.Tensor,
                     cos_sin, max_iters: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
        """One damped Newton step on the joint (y, z) fixed-point residual.

        Solves J @ delta ~= r with a few damped CG iterations, then returns
        (y + delta_y, z + delta_z).  For r(y) = y - f(y) Newton gives
        y_new = y - (I - df/dy)^{-1} r = y + J^{-1} r, so '+' is the correct
        sign given J = df/dy - I.
        """
        residual, jvp = self._compute_residual_and_jacobian(y, z, x_emb, cos_sin)
        delta = torch.zeros_like(residual)

        def cg_step(x, r, p):
            """Single CG iteration with trust-region damping on the denominator."""
            Ap = jvp(p)
            # mu * |p|^2 keeps alpha bounded when p'Ap is tiny or negative.
            alpha = torch.sum(r * r) / (torch.sum(p * Ap) + self.config.newton_trust_region_mu * torch.sum(p * p))
            x_new = x + alpha * p
            r_new = r - alpha * Ap
            beta = torch.sum(r_new * r_new) / torch.sum(r * r)
            return x_new, r_new, r_new + beta * p

        r = residual.clone()
        p = r.clone()
        for _ in range(max_iters):
            delta, r, p = cg_step(delta, r, p)
            if torch.norm(r) < 1e-6:
                break

        batch_size, seq_len, hidden_size = y.shape
        state_dim = batch_size * seq_len * hidden_size
        delta_y = delta[:state_dim].reshape(batch_size, seq_len, hidden_size)
        delta_z = delta[state_dim:].reshape(batch_size, seq_len, hidden_size)
        return y + delta_y, z + delta_z

    def _standard_recurse(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry,
                          x_emb: torch.Tensor, cos_sin) -> TinyRecursiveReasoningModel_ACTV1InnerCarry:
        """Original TRM-style recursion, used as fallback when Newton is disabled."""
        z_H, z_L = carry.z_H, carry.z_L
        # First H_cycles-1 cycles run without gradient tracking (state warming).
        with torch.no_grad():
            for _H_step in range(self.config.H_cycles - 1):
                for _L_step in range(self.config.L_cycles):
                    z_L = self.core_net(z_L, z_H + x_emb, cos_sin=cos_sin)
                z_H = self.core_net(z_H, z_L, cos_sin=cos_sin)
        # Final cycle with gradients.
        for _L_step in range(self.config.L_cycles):
            z_L = self.core_net(z_L, z_H + x_emb, cos_sin=cos_sin)
        z_H = self.core_net(z_H, z_L, cos_sin=cos_sin)
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H, z_L=z_L)

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry,
                batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Forward pass with Newton-accelerated recursion or standard fallback.

        API is identical to the original TinyRecursiveReasoningModel_ACTV1_Inner:
        returns (detached new carry, lm logits, (q_halt_logits, q_continue_logits)).
        """
        seq_info = dict(cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None)
        x_emb = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        if self.config.newton_enable:
            z_H, z_L = carry.z_H, carry.z_L
            # H_cycles-1 warm-up Newton steps without gradient tracking.
            with torch.no_grad():
                for _H_step in range(self.config.H_cycles - 1):
                    z_H, z_L = self._newton_step(z_H, z_L, x_emb, seq_info.get('cos_sin'),
                                                 max_iters=self.config.newton_max_cg_iters)
            # Final step outside no_grad.  NOTE(review): only the residual
            # terms stay attached to the graph (the Jacobian blocks use
            # create_graph=False), so parameter gradients flow through the
            # residual path only — confirm this matches the training intent.
            z_H, z_L = self._newton_step(z_H, z_L, x_emb, seq_info.get('cos_sin'),
                                         max_iters=self.config.newton_max_cg_iters)
        else:
            fallback = self._standard_recurse(carry, x_emb, seq_info.get('cos_sin'))
            z_H, z_L = fallback.z_H, fallback.z_L

        # Detach the carried state so each supervision step is 1-step truncated.
        new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=z_H.detach(),
            z_L=z_L.detach(),
        )

        # Output heads (identical to original): LM head skips the puzzle
        # prefix; Q head reads position 0 in float32.
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])
class TinyRecursiveReasoningModel_ACTV1(nn.Module):
    """ACT (adaptive computation time) wrapper — identical API to the original.

    Manages per-sample halting, step counting, and data freezing around the
    Newton-variant inner model.
    """

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
        self.inner = TinyRecursiveReasoningModel_NewtonVariant(self.config)

    @property
    def puzzle_emb(self):
        # Exposed for the sparse-embedding optimizer; only exists on the inner
        # model when puzzle_emb_ndim > 0.
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        """Build a fresh carry for a new batch; all samples start halted so the
        first forward() resets their latent state and data."""
        batch_size = batch["inputs"].shape[0]
        return TinyRecursiveReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),
            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),
            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry,
                batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        """One ACT step: reset halted samples, run the inner model, update halting.

        Returns (new carry, outputs dict with logits and Q-head logits).
        """
        # Samples that halted last step restart: latent state, step counter,
        # and their slice of the batch data are all refreshed.
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
        new_steps = torch.where(carry.halted, 0, carry.steps)
        new_current_data = {
            k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)),
                           batch[k], v)
            for k, v in carry.current_data.items()
        }

        # Forward inner model.
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)
        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps
            halted = is_last_step
            # Q-learned halting is only active during training with ACT enabled;
            # at eval time every sample runs the full halt_max_steps.
            if self.training and (self.config.halt_max_steps > 1):
                if self.config.no_ACT_continue:
                    halted = halted | (q_halt_logits > 0)
                else:
                    halted = halted | (q_halt_logits > q_continue_logits)
                # Exploration: with probability halt_exploration_prob force a
                # random minimum number of steps before halting is allowed.
                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * \
                    torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)
                if not self.config.no_ACT_continue:
                    # Bootstrap target for the continue head from one lookahead step.
                    # BUGFIX: inner() returns a 3-tuple (carry, logits, (q_halt,
                    # q_cont)); the original unpacked five values and raised
                    # ValueError whenever this branch executed.
                    _, _, (next_q_halt_logits, next_q_continue_logits) = self.inner(new_inner_carry, new_current_data)
                    outputs["target_q_continue"] = torch.sigmoid(
                        torch.where(is_last_step, next_q_halt_logits,
                                    torch.maximum(next_q_halt_logits, next_q_continue_logits))
                    )

        return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs