Skip to content

misuse of the word "tiny" #45

@falseywinchnet

Description

@falseywinchnet

clearly, this model is not intended for mortals, as it requires nvidia H100 and L40S gpus to train any problem
please republish and instead of calling it tiny, call it simple. it is simple compared to HRM and like HRM actually implements Iteratively reweighted kernel machines
efficiently learn sparse functions

by Zhu, Davis, Drusvyatskiy and Fazel

We note that the iterative approach is actually suitable for newtonian descent as proposed in ParaRNN by Apple, and we offer tentative code to replace or extend trm.py (use at own peril- not tested, coded by today's emergent k2-reasoning model) with iteratively newton acceleration for training.

Naturally, if we owned a H100 cluster, we would be happy to run this code, but like so many other "experiments" released to the public by corporate and institutional entities, training this model remains out of the reach of individuals who only have access to personal computing machines, probably largely due to how it is designed and not due to any architectural limitations.

from typing import Tuple, Dict, Optional
from dataclasses import dataclass
import math
import torch
import torch.nn.functional as F
from torch import nn
from pydantic import BaseModel
from models.common import trunc_normal_init_
from models.layers import rms_norm, SwiGLU, Attention, RotaryEmbedding, CastedEmbedding, CastedLinear
from models.sparse_embedding import CastedSparseEmbedding

IGNORE_LABEL_ID = -100

@dataclass
class TinyRecursiveReasoningModel_ACTV1InnerCarry:
    z_H: torch.Tensor
    z_L: torch.Tensor

@dataclass
class TinyRecursiveReasoningModel_ACTV1Carry:
    inner_carry: TinyRecursiveReasoningModel_ACTV1InnerCarry
    steps: torch.Tensor
    halted: torch.Tensor
    current_data: Dict[str, torch.Tensor]

class TinyRecursiveReasoningModel_ACTV1Config(BaseModel):
    batch_size: int
    seq_len: int
    puzzle_emb_ndim: int = 0
    num_puzzle_identifiers: int
    vocab_size: int

    H_cycles: int
    L_cycles: int

    H_layers: int  # Retained for API compatibility, ignored in Newton variant
    L_layers: int

    # Transformer config
    hidden_size: int
    expansion: float
    num_heads: int
    pos_encodings: str

    rms_norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    
    # Halting Q-learning config
    halt_max_steps: int
    halt_exploration_prob: float

    forward_dtype: str = "bfloat16"

    # Newton-specific additions
    newton_max_cg_iters: int = 5
    newton_trust_region_mu: float = 0.1
    newton_enable: bool = True  # Flag to toggle Newton vs original

    # Original config parameters
    mlp_t: bool = False
    puzzle_emb_len: int = 16
    no_ACT_continue: bool = True

class TinyRecursiveReasoningModel_NewtonVariant(nn.Module):
    """Newton-accelerated drop-in replacement for TinyRecursiveReasoningModel_ACTV1_Inner"""
    
    def __init__(self, config: TinyRecursiveReasoningModel_ACTV1Config) -> None:
        super().__init__()
        self.config = config
        self.forward_dtype = getattr(torch, self.config.forward_dtype)

        # I/O layers (identical to original)
        self.embed_scale = math.sqrt(self.config.hidden_size)
        embed_init_std = 1.0 / self.embed_scale
        self.embed_tokens = CastedEmbedding(self.config.vocab_size, self.config.hidden_size, 
                                            init_std=embed_init_std, cast_to=self.forward_dtype)
        self.lm_head = CastedLinear(self.config.hidden_size, self.config.vocab_size, bias=False)
        self.q_head = CastedLinear(self.config.hidden_size, 2, bias=True)

        # Position and puzzle embeddings (identical to original)
        self.puzzle_emb_len = -(self.config.puzzle_emb_ndim // -self.config.hidden_size) \
                              if self.config.puzzle_emb_len == 0 else self.config.puzzle_emb_len
        if self.config.puzzle_emb_ndim > 0:
            self.puzzle_emb = CastedSparseEmbedding(self.config.num_puzzle_identifiers, self.config.puzzle_emb_ndim,
                                                    batch_size=self.config.batch_size, init_std=0, cast_to=self.forward_dtype)

        if self.config.pos_encodings == "rope":
            self.rotary_emb = RotaryEmbedding(dim=self.config.hidden_size // self.config.num_heads,
                                              max_position_embeddings=self.config.seq_len + self.puzzle_emb_len,
                                              base=self.config.rope_theta)
        elif self.config.pos_encodings == "learned":
            self.embed_pos = CastedEmbedding(self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, 
                                             init_std=embed_init_std, cast_to=self.forward_dtype)

        # Core network (single network for both f_y and f_z)
        self.core_net = self._build_core_network()

        # Initial states (identical to original)
        self.H_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        self.L_init = nn.Buffer(trunc_normal_init_(torch.empty(self.config.hidden_size, dtype=self.forward_dtype), std=1), persistent=True)
        
        # Q-head initialization (identical to original)
        with torch.no_grad():
            self.q_head.weight.zero_()
            self.q_head.bias.fill_(-5)

    def _build_core_network(self):
        """Build the core 2-layer network used for both reasoning and refinement"""
        layers = []
        for _ in range(self.config.L_layers):
            layers.append(TinyRecursiveReasoningModel_ACTV1Block(self.config))
        return TinyRecursiveReasoningModel_ACTV1ReasoningModule(layers)

    def _input_embeddings(self, input: torch.Tensor, puzzle_identifiers: torch.Tensor):
        """Identical to original implementation"""
        embedding = self.embed_tokens(input.to(torch.int32))

        if self.config.puzzle_emb_ndim > 0:
            puzzle_embedding = self.puzzle_emb(puzzle_identifiers)
            pad_count = self.puzzle_emb_len * self.config.hidden_size - puzzle_embedding.shape[-1]
            if pad_count > 0:
                puzzle_embedding = F.pad(puzzle_embedding, (0, pad_count))
            embedding = torch.cat((puzzle_embedding.view(-1, self.puzzle_emb_len, self.config.hidden_size), embedding), dim=-2)

        if self.config.pos_encodings == "learned":
            embedding = 0.707106781 * (embedding + self.embed_pos.embedding_weight.to(self.forward_dtype))

        return self.embed_scale * embedding

    def empty_carry(self, batch_size: int):
        """Identical API to original"""
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
            z_L=torch.empty(batch_size, self.config.seq_len + self.puzzle_emb_len, self.config.hidden_size, dtype=self.forward_dtype),
        )
        
    def reset_carry(self, reset_flag: torch.Tensor, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry):
        """Identical API to original"""
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=torch.where(reset_flag.view(-1, 1, 1), self.H_init, carry.z_H),
            z_L=torch.where(reset_flag.view(-1, 1, 1), self.L_init, carry.z_L),
        )

    def _compute_residual_and_jacobian(self, y: torch.Tensor, z: torch.Tensor, 
                                       x_emb: torch.Tensor, cos_sin) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute residual vector and Jacobian blocks for Newton step
        Residual: r = [y - f_y(y,z); z - f_z(x,y,z)]
        Jacobian: J = [∂f_y/∂y - I, ∂f_y/∂z; ∂f_z/∂y, ∂f_z/∂z - I]
        """
        batch_size, seq_len, hidden_size = y.shape
        state_dim = batch_size * seq_len * hidden_size
        
        # Enable grad for Jacobian computation
        with torch.enable_grad():
            y_req = y.detach().requires_grad_(True)
            z_req = z.detach().requires_grad_(True)
            
            # Compute f_y(y,z) - note: x is NOT injected here
            f_y_input = y_req  # net expects (y, z) when x absent
            f_y_output = self.core_net(f_y_input, z_req, cos_sin=cos_sin)
            
            # Compute f_z(x,y,z) - note: x IS injected here
            f_z_input = z_req
            injection = y_req + x_emb
            f_z_output = self.core_net(f_z_input, injection, cos_sin=cos_sin)
            
            # Compute residuals
            r_y = y_req - f_y_output
            r_z = z_req - f_z_output
            
            # Compute Jacobian blocks via autograd
            def compute_jacobian_block(output, input_var):
                """Compute Jacobian block ∂output/∂input_var"""
                # Flatten for batch processing
                output_flat = output.reshape(-1, hidden_size)
                input_flat = input_var.reshape(-1, hidden_size)
                
                # Initialize Jacobian block
                jacobian_block = torch.zeros(batch_size * seq_len, hidden_size, hidden_size, 
                                             device=output.device, dtype=output.dtype)
                
                # Compute column-by-column using vmap-like approach
                for i in range(hidden_size):
                    v = torch.zeros_like(input_flat)
                    v[:, i] = 1.0
                    grad_outputs = torch.autograd.grad(
                        outputs=output_flat,
                        inputs=input_flat,
                        grad_outputs=v,
                        create_graph=False,
                        retain_graph=True
                    )[0]
                    jacobian_block[:, :, i] = grad_outputs
                
                return jacobian_block
            
            # Compute blocks
            J_yy = compute_jacobian_block(f_y_output, y_req) - torch.eye(hidden_size, device=y.device, dtype=y.dtype).unsqueeze(0).expand(batch_size*seq_len, -1, -1)
            J_yz = compute_jacobian_block(f_y_output, z_req)
            J_zy = compute_jacobian_block(f_z_output, y_req)
            J_zz = compute_jacobian_block(f_z_output, z_req) - torch.eye(hidden_size, device=z.device, dtype=z.dtype).unsqueeze(0).expand(batch_size*seq_len, -1, -1)
            
            # Assemble full Jacobian as sparse block operator
            residual = torch.cat([r_y.reshape(batch_size*seq_len, hidden_size), 
                                 r_z.reshape(batch_size*seq_len, hidden_size)], dim=0)
            
            def jacobian_vector_product(v):
                """J * v where v is [v_y; v_z] concatenated"""
                v_y = v[:batch_size*seq_len].reshape(batch_size*seq_len, hidden_size)
                v_z = v[batch_size*seq_len:].reshape(batch_size*seq_len, hidden_size)
                
                out_y = torch.matmul(J_yy, v_y.unsqueeze(-1)).squeeze(-1) + \
                        torch.matmul(J_yz, v_z.unsqueeze(-1)).squeeze(-1)
                out_z = torch.matmul(J_zy, v_y.unsqueeze(-1)).squeeze(-1) + \
                        torch.matmul(J_zz, v_z.unsqueeze(-1)).squeeze(-1)
                
                return torch.cat([out_y, out_z], dim=0)
            
            return residual.reshape(-1), jacobian_vector_product

    def _newton_step(self, y: torch.Tensor, z: torch.Tensor, x_emb: torch.Tensor, 
                     cos_sin, max_iters: int = 5) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform Newton step with CG solver for simultaneous (y,z) update"""
        residual, jvp = self._compute_residual_and_jacobian(y, z, x_emb, cos_sin)
        
        # Initial guess for delta
        delta = torch.zeros_like(residual)
        
        # Trust-region CG solver
        def cg_step(k, x, r, p):
            """Single CG iteration with trust region damping"""
            Ap = jvp(p)
            alpha = torch.sum(r * r) / (torch.sum(p * Ap) + self.config.newton_trust_region_mu * torch.sum(p * p))
            x_new = x + alpha * p
            r_new = r - alpha * Ap
            beta = torch.sum(r_new * r_new) / torch.sum(r * r)
            p_new = r_new + beta * p
            return x_new, r_new, p_new
        
        # Run CG iterations
        r = residual.clone()
        p = r.clone()
        for _ in range(max_iters):
            delta, r, p = cg_step(0, delta, r, p)
            if torch.norm(r) < 1e-6:
                break
        
        # Apply update
        batch_size, seq_len, hidden_size = y.shape
        state_dim = batch_size * seq_len * hidden_size
        delta_y = delta[:state_dim].reshape(batch_size, seq_len, hidden_size)
        delta_z = delta[state_dim:].reshape(batch_size, seq_len, hidden_size)
        
        return y + delta_y, z + delta_z

    def _standard_recurse(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, 
                         x_emb: torch.Tensor, cos_sin) -> TinyRecursiveReasoningModel_ACTV1InnerCarry:
        """Original recursion method for comparison/backup"""
        z_H, z_L = carry.z_H, carry.z_L
        
        # H_cycles-1 without grad
        with torch.no_grad():
            for _H_step in range(self.config.H_cycles-1):
                for _L_step in range(self.config.L_cycles):
                    z_L = self.core_net(z_L, z_H + x_emb, cos_sin=cos_sin)
                z_H = self.core_net(z_H, z_L, cos_sin=cos_sin)
        
        # 1 with grad
        for _L_step in range(self.config.L_cycles):
            z_L = self.core_net(z_L, z_H + x_emb, cos_sin=cos_sin)
        z_H = self.core_net(z_H, z_L, cos_sin=cos_sin)
        
        return TinyRecursiveReasoningModel_ACTV1InnerCarry(z_H=z_H, z_L=z_L)

    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1InnerCarry, 
                batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1InnerCarry, torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Forward pass with Newton-accelerated recursion or standard fallback.
        API is identical to original TinyRecursiveReasoningModel_ACTV1_Inner.
        """
        seq_info = dict(cos_sin=self.rotary_emb() if hasattr(self, "rotary_emb") else None)
        x_emb = self._input_embeddings(batch["inputs"], batch["puzzle_identifiers"])

        if self.config.newton_enable:
            # Newton-accelerated recursion
            z_H, z_L = carry.z_H, carry.z_L
            
            # H_cycles-1 without grad (use standard recursion for state warming)
            with torch.no_grad():
                for _H_step in range(self.config.H_cycles - 1):
                    # Single Newton step per H_cycle for state improvement
                    z_H, z_L = self._newton_step(z_H, z_L, x_emb, seq_info.get('cos_sin'), 
                                                 max_iters=self.config.newton_max_cg_iters)
            
            # Final H_cycle with grad: perform Newton step and compute outputs
            z_H, z_L = self._newton_step(z_H, z_L, x_emb, seq_info.get('cos_sin'), 
                                         max_iters=self.config.newton_max_cg_iters)
        else:
            # Standard recursion fallback
            new_carry = self._standard_recurse(carry, x_emb, seq_info.get('cos_sin'))
            z_H, z_L = new_carry.z_H, new_carry.z_L

        # Detach for supervision step
        new_carry = TinyRecursiveReasoningModel_ACTV1InnerCarry(
            z_H=z_H.detach(), 
            z_L=z_L.detach()
        )
        
        # Output heads (identical to original)
        output = self.lm_head(z_H)[:, self.puzzle_emb_len:]
        q_logits = self.q_head(z_H[:, 0]).to(torch.float32)
        
        return new_carry, output, (q_logits[..., 0], q_logits[..., 1])

class TinyRecursiveReasoningModel_ACTV1(nn.Module):
    """ACT wrapper - identical API to original"""

    def __init__(self, config_dict: dict):
        super().__init__()
        self.config = TinyRecursiveReasoningModel_ACTV1Config(**config_dict)
        self.inner = TinyRecursiveReasoningModel_NewtonVariant(self.config)

    @property
    def puzzle_emb(self):
        return self.inner.puzzle_emb

    def initial_carry(self, batch: Dict[str, torch.Tensor]):
        """Identical API to original"""
        batch_size = batch["inputs"].shape[0]
        return TinyRecursiveReasoningModel_ACTV1Carry(
            inner_carry=self.inner.empty_carry(batch_size),
            steps=torch.zeros((batch_size, ), dtype=torch.int32),
            halted=torch.ones((batch_size, ), dtype=torch.bool),
            current_data={k: torch.empty_like(v) for k, v in batch.items()}
        )
        
    def forward(self, carry: TinyRecursiveReasoningModel_ACTV1Carry, 
                batch: Dict[str, torch.Tensor]) -> Tuple[TinyRecursiveReasoningModel_ACTV1Carry, Dict[str, torch.Tensor]]:
        """Identical API to original ACT wrapper"""
        new_inner_carry = self.inner.reset_carry(carry.halted, carry.inner_carry)
        new_steps = torch.where(carry.halted, 0, carry.steps)
        new_current_data = {
            k: torch.where(carry.halted.view((-1,) + (1,) * (batch[k].ndim - 1)), 
                          batch[k], v) 
            for k, v in carry.current_data.items()
        }

        # Forward inner model
        new_inner_carry, logits, (q_halt_logits, q_continue_logits) = self.inner(new_inner_carry, new_current_data)

        outputs = {
            "logits": logits,
            "q_halt_logits": q_halt_logits,
            "q_continue_logits": q_continue_logits
        }

        with torch.no_grad():
            new_steps = new_steps + 1
            is_last_step = new_steps >= self.config.halt_max_steps
            halted = is_last_step

            if self.training and (self.config.halt_max_steps > 1):
                if self.config.no_ACT_continue:
                    halted = halted | (q_halt_logits > 0)
                else:
                    halted = halted | (q_halt_logits > q_continue_logits)

                min_halt_steps = (torch.rand_like(q_halt_logits) < self.config.halt_exploration_prob) * \
                               torch.randint_like(new_steps, low=2, high=self.config.halt_max_steps + 1)
                halted = halted & (new_steps >= min_halt_steps)

                if not self.config.no_ACT_continue:
                    _, _, (next_q_halt_logits, next_q_continue_logits), _, _ = self.inner(new_inner_carry, new_current_data)
                    outputs["target_q_continue"] = torch.sigmoid(
                        torch.where(is_last_step, next_q_halt_logits, 
                                  torch.maximum(next_q_halt_logits, next_q_continue_logits))
                    )

        return TinyRecursiveReasoningModel_ACTV1Carry(new_inner_carry, new_steps, halted, new_current_data), outputs

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions