train_jepa.py
What was added
Hyperparameters — four new env-vars (zero cost when disabled):
JEPA_LAMBDA — weight of the auxiliary loss (0 = fully off, try 0.05–0.15)
JEPA_MAX_SPAN_FRAC — maximum patch length as a fraction of seq_len (default 0.5)
JEPA_START_FRAC — delay JEPA until this fraction of training has elapsed (useful for warmup)
JEPA_LAYER — which block's hidden state to use; -1 = final pre-norm representation
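As a rough sketch, the env-vars above could be read once at startup like this (only the variable names, the off-at-0 behaviour, and the 0.5 default come from the list; the parsing style and the defaults shown for JEPA_START_FRAC and JEPA_LAYER are assumptions):

import os

# JEPA is fully off while JEPA_LAMBDA stays at 0.0, so the defaults cost nothing.
jepa_lambda = float(os.environ.get("JEPA_LAMBDA", "0.0"))
jepa_max_span_frac = float(os.environ.get("JEPA_MAX_SPAN_FRAC", "0.5"))
jepa_start_frac = float(os.environ.get("JEPA_START_FRAC", "0.0"))   # assumed default
jepa_layer = int(os.environ.get("JEPA_LAYER", "-1"))                # assumed default: final pre-norm tap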
GPT._run_blocks — single shared forward kernel that optionally captures a tapped hidden state (zero extra compute for the -1 case — it's the same tensor). Replaces the duplicated encoder/decoder loops.
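A hedged sketch of what such a shared loop could look like (the self.blocks attribute, the tap_layer argument name, and returning an (output, tapped) pair are assumptions, not the script's actual signature):

def _run_blocks(self, x, tap_layer=-1):
    # Run every transformer block once; optionally keep the hidden state
    # emitted after block `tap_layer` for the JEPA loss.
    tapped = None
    for i, block in enumerate(self.blocks):
        x = block(x)
        if i == tap_layer:
            tapped = x
    if tap_layer == -1:
        tapped = x  # final pre-norm representation: the same tensor, zero extra compute
    return x, tapped

Because every path calls this one helper, the previously duplicated encoder/decoder loops collapse into a single code path.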
GPT._jepa_loss_fn — static method implementing a loss built on the linear representation hypothesis:
context = (h[a-1] - h[0]) + (h[T-1] - h[b]) # before + after
patch = h[b] - h[a-1] # the patch span
loss = 1 - cosine_similarity(context, patch)
context + patch = h[T-1] − h[0] — the two vectors perfectly partition the full-sequence encoding, so they carry complementary information. Maximising their cosine similarity forces the hidden states to encode spans linearly, which improves representational geometry and can also aid quantisation.
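Expanding the three lines above into a runnable static method, shown out of its class for brevity (the (B, T, D) shape of h, the batched index tensors a and b, and the reduction to a mean are assumptions; F is torch.nn.functional):

import torch
import torch.nn.functional as F

@staticmethod
def _jepa_loss_fn(h, a, b):
    # h: (B, T, D) tapped hidden states; a, b: (B,) span boundaries with 1 <= a <= b <= T-1.
    rows = torch.arange(h.size(0), device=h.device)
    h_a1 = h[rows, a - 1]   # hidden state just before the patch
    h_b = h[rows, b]        # hidden state at the end of the patch
    context = (h_a1 - h[:, 0]) + (h[:, -1] - h_b)   # before + after
    patch = h_b - h_a1                              # the patch span
    return (1.0 - F.cosine_similarity(context, patch, dim=-1)).mean()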
GPT.forward — the self._jepa_lambda > 0 guard is a Python float comparison, so torch.compile resolves it at trace time. When JEPA is off, the compiled graph is identical to the original. When on (and span_a is not None), a separate specialised graph is compiled exactly once.
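In context, the guard could reduce to a few lines like these (the surrounding forward signature, embedding, head, and cross-entropy plumbing are assumptions; only the lambda guard and the span_a is not None check come from the description above):

def forward(self, idx, targets, span_a=None, span_b=None):
    x, tapped = self._run_blocks(self.embed(idx), tap_layer=self._jepa_layer)
    logits = self.lm_head(self.norm(x))
    loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
    # Plain Python float comparison: torch.compile specialises on it at trace
    # time, so the JEPA branch never appears in the compiled graph when off.
    if self._jepa_lambda > 0 and span_a is not None:
        loss = loss + self._jepa_lambda * self._jepa_loss_fn(tapped, span_a, span_b)
    return loss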
_sample_jepa_spans — span sampling lives in Python-land, outside the fullgraph=True trace, so torch.randint never appears inside the compiled graph.
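A plausible shape for the Python-side sampler (the function name comes from above; using the standard-library random module and these exact boundary conventions are assumptions):

import random
import torch

def _sample_jepa_spans(batch_size, seq_len, max_span_frac, device):
    # Runs outside the fullgraph=True region, so no RNG op (torch.randint or
    # otherwise) is ever traced into the compiled graph.
    max_span = max(1, int(seq_len * max_span_frac))
    a = torch.empty(batch_size, dtype=torch.long)
    b = torch.empty(batch_size, dtype=torch.long)
    for i in range(batch_size):
        span = random.randint(1, max_span)
        start = random.randint(1, seq_len - 1 - span)  # keep h[a-1] and h[b] in range
        a[i], b[i] = start, start + span
    return a.to(device), b.to(device)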
step_fn receives frac so spans are only sampled after jepa_start_frac — enabling a "train normally first, then add JEPA" curriculum.
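So the curriculum gate in the step function might be as small as this (step_fn's other arguments and the way inputs and targets reach the model are assumptions):

def step_fn(model, inputs, targets, frac, cfg):
    # frac = fraction of total training elapsed; spans only appear once it
    # passes jepa_start_frac, so early steps are pure next-token training.
    span_a = span_b = None
    if cfg.jepa_lambda > 0 and frac >= cfg.jepa_start_frac:
        span_a, span_b = _sample_jepa_spans(
            inputs.size(0), inputs.size(1), cfg.jepa_max_span_frac, inputs.device
        )
    return model(inputs, targets, span_a=span_a, span_b=span_b)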
Suggested sweep starting points:
JEPA_LAMBDA=0.1 JEPA_MAX_SPAN_FRAC=0.5 JEPA_START_FRAC=0.0 # full run
JEPA_LAMBDA=0.05 JEPA_START_FRAC=0.3 # delayed start
JEPA_LAYER=5 # mid-layer tap