Experimental JEPA on Robby's PR #1412 #1772

@Qeuph

Description

train_jepa.py
What was added

Hyperparameters — four new env-vars (zero cost when disabled):

  • JEPA_LAMBDA — weight of the auxiliary loss (0 = fully off, try 0.05–0.15)
  • JEPA_MAX_SPAN_FRAC — maximum patch length as a fraction of seq_len (default 0.5)
  • JEPA_START_FRAC — delay JEPA until this fraction of training has elapsed (useful for warmup)
  • JEPA_LAYER — which block's hidden state to use; -1 = final pre-norm representation
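
As a rough illustration of how these four variables might be read, the sketch below parses them from the environment. The `JepaConfig` dataclass and the `load_jepa_config` helper are hypothetical names, and the defaults for `JEPA_LAMBDA`, `JEPA_START_FRAC` and `JEPA_LAYER` are assumptions; only the `JEPA_MAX_SPAN_FRAC` default of 0.5 is stated above.

```python
import os
from dataclasses import dataclass


@dataclass(frozen=True)
class JepaConfig:
    """Hypothetical container; not necessarily how train_jepa.py stores these."""
    lam: float            # JEPA_LAMBDA: auxiliary-loss weight, 0 disables JEPA
    max_span_frac: float  # JEPA_MAX_SPAN_FRAC: max patch length / seq_len
    start_frac: float     # JEPA_START_FRAC: training fraction before JEPA starts
    layer: int            # JEPA_LAYER: block to tap, -1 = final pre-norm state

    @property
    def enabled(self) -> bool:
        return self.lam > 0.0


def load_jepa_config(env=None) -> JepaConfig:
    """Read the four JEPA_* variables from a mapping (os.environ by default)."""
    env = os.environ if env is None else env
    return JepaConfig(
        lam=float(env.get("JEPA_LAMBDA", "0")),                # assumed default: off
        max_span_frac=float(env.get("JEPA_MAX_SPAN_FRAC", "0.5")),
        start_frac=float(env.get("JEPA_START_FRAC", "0")),     # assumed default
        layer=int(env.get("JEPA_LAYER", "-1")),                # assumed default
    )
```

Passing a plain dict instead of `os.environ` makes the parsing easy to check in isolation, for example `load_jepa_config({"JEPA_LAMBDA": "0.1"})`.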

GPT._run_blocks — single shared forward kernel that optionally captures a tapped hidden state (zero extra compute for the -1 case — it's the same tensor). Replaces the duplicated encoder/decoder loops.
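
The tapping idea can be sketched without any framework. The function below is a simplified stand-in, not the actual `GPT._run_blocks`: the name `run_blocks`, the 0-based "output of block i" indexing and the `(output, tapped)` return shape are all assumptions made for illustration.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


def run_blocks(
    x: Any,
    blocks: Sequence[Callable[[Any], Any]],
    tap_layer: Optional[int] = None,
) -> Tuple[Any, Optional[Any]]:
    """Apply blocks in order and optionally return one intermediate state.

    tap_layer=None  -> nothing is captured
    tap_layer=i>=0  -> the output of block i is captured (assumed 0-based)
    tap_layer=-1    -> the final state itself is returned as the tap
    """
    tapped = None
    for i, block in enumerate(blocks):
        x = block(x)
        if tap_layer is not None and i == tap_layer:
            tapped = x
    if tap_layer == -1:
        # No copy and no extra work: the tap is the same object as the output.
        tapped = x
    return x, tapped
```

Because the `-1` case returns the output object itself, tapping the final state adds no computation, which matches the "same tensor" remark above.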

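GPT._jepa_loss_fn — static method implementing an auxiliary loss motivated by the linear representation hypothesis. For a patch span [a, b] in a sequence of length T with hidden states h: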

context = (h[a-1] − h[0])  +  (h[T-1] − h[b])   # before + after
patch   =  h[b]   − h[a-1]                         # the patch span
loss    = 1 − cosine_similarity(context, patch)

By construction, context + patch = h[T-1] − h[0], so the two vectors sum to the full-sequence displacement and carry complementary information. Minimising the loss maximises their cosine similarity, which is intended to push the hidden states toward encoding spans linearly. The expectation is that this improves representational geometry and may also aid quantisation; neither effect has been measured here.

GPT.forward — the self._jepa_lambda > 0 guard is a Python float comparison, so torch.compile resolves it at trace time. When JEPA is off, the compiled graph is identical to the original. When on (and span_a is not None), a separate specialised graph is compiled exactly once.

_sample_jepa_spans — span sampling lives in Python-land, outside the fullgraph=True trace, so torch.randint never appears inside the compiled graph.
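
One way such host-side sampling could look is sketched below. This is not the actual `_sample_jepa_spans`: it uses the standard-library `random` module rather than whatever RNG the script uses, returns a single `(a, b)` pair, and keeps at least one position on each side of the patch. All three choices are assumptions.

```python
import random
from typing import Tuple


def sample_jepa_span(seq_len: int, max_span_frac: float = 0.5, rng=random) -> Tuple[int, int]:
    """Sample an inclusive span [a, b] with 1 <= a <= b <= seq_len - 2.

    The patch length is drawn uniformly from 1..floor(seq_len * max_span_frac),
    capped so that the span fits strictly inside the sequence.
    """
    if seq_len < 3:
        raise ValueError("seq_len must be at least 3")
    max_len = max(1, int(seq_len * max_span_frac))
    max_len = min(max_len, seq_len - 2)        # leave one position on each side
    length = rng.randint(1, max_len)           # randint is inclusive on both ends
    a = rng.randint(1, seq_len - 1 - length)   # guarantees b <= seq_len - 2
    b = a + length - 1
    return a, b
```

The resulting plain Python integers can then be passed into the compiled forward as arguments, so no random op needs to be traced.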

step_fn receives frac so spans are only sampled after jepa_start_frac — enabling a "train normally first, then add JEPA" curriculum.
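
The gating logic might be sketched as follows. The function name `maybe_sample_span` and its signature are hypothetical; the only behaviour taken from the description is that no span is produced while JEPA is disabled or before `jepa_start_frac` of training has passed.

```python
from typing import Callable, Optional, Tuple

Span = Tuple[int, int]


def maybe_sample_span(
    frac: float,
    seq_len: int,
    jepa_lambda: float,
    jepa_start_frac: float,
    sampler: Callable[[int], Span],
) -> Optional[Span]:
    """Return a span once the curriculum allows it, otherwise None.

    frac is the fraction of training completed so far (0.0 to 1.0).
    sampler is any callable mapping seq_len to an (a, b) span.
    """
    if jepa_lambda <= 0.0:
        return None                 # JEPA disabled entirely
    if frac < jepa_start_frac:
        return None                 # still in the "train normally" phase
    return sampler(seq_len)
```

Returning `None` before the start fraction lines up with the `span_a is not None` check in `GPT.forward`: the JEPA branch is only taken once spans begin to arrive.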

Suggested sweep starting points:

JEPA_LAMBDA=0.1 JEPA_MAX_SPAN_FRAC=0.5 JEPA_START_FRAC=0.0  # full run
JEPA_LAMBDA=0.05 JEPA_START_FRAC=0.3                          # delayed start
JEPA_LAYER=5                                                    # mid-layer tap
