
added residual attention, removed normal residual and gated attention #11

Open
kar-m wants to merge 1 commit into main from karen/residual-attn

Conversation

Contributor

@kar-m kar-m commented Mar 27, 2026

No description provided.


Copilot AI left a comment


Pull request overview

This PR changes the Transformer encoder/decoder block structure to derive each layer’s input from an accumulated history of prior layer states (“residual attention”), while removing the prior direct residual-add + gated-residual pattern.

Changes:

  • Replace standard residual connections in EncoderBlock/DecoderBlock with a learned, per-layer aggregation over x_accum (history of layer states).
  • Update encoder/decoder nn.scan bodies to carry and update x_accum, and pass layer_idx through the scan.
  • Add explicit x_accum allocation and final-state extraction (x_accum[:, :, num_layers, :]) in both encoder and decoder.
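For reference, the aggregation step can be sketched in plain JAX (a minimal reconstruction from the snippets quoted in this review; `rms_norm` is a simplified stand-in for the PR's `ZCRMSNorm`, and the shapes are illustrative):

```python
import jax
import jax.numpy as jnp

def rms_norm(x, eps=1e-6):
    # Simplified stand-in for ZCRMSNorm: normalize over the feature axis.
    return x * jax.lax.rsqrt(jnp.mean(x * x, axis=-1, keepdims=True) + eps)

def residual_attention(x_accum, q, layer_idx):
    """Aggregate the accumulated history of layer states into one input.

    x_accum:   (B, T, L+1, D) accumulated layer outputs (slot 0 = embeddings).
    q:         (D,) learned per-layer query (`residual_query` in the PR).
    layer_idx: index of the current layer; later slots are masked out.
    """
    d_model = x_accum.shape[-1]
    k = x_accum.astype(jnp.float32)
    v = rms_norm(x_accum)
    logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(d_model))
    valid = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :]
    logits = jnp.where(valid, logits, jnp.finfo(logits.dtype).min)
    # Sigmoid gating (not softmax), per the quoted snippets: each retained
    # layer state gets an independent weight in [0, 1].
    alphas = jax.nn.sigmoid(logits)
    return jnp.einsum("btl,btld->btd", alphas, v)

B, T, L, D = 2, 5, 4, 8
x_accum = jnp.zeros((B, T, L + 1, D)).at[:, :, 0, :].set(1.0)
q = jnp.ones((D,))
out = residual_attention(x_accum, q, layer_idx=0)
print(out.shape)  # (2, 5, 8)
```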


Comment thread src/model.py
Comment on lines +211 to +216
k = x_accum.astype(jnp.float32)
v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum)

logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))
valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :]
logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min)

Copilot AI Mar 27, 2026


This residual-attention computation reads the full x_accum tensor (B×T×(L+1)×D) and computes logits/values over all layers each iteration, making the encoder block’s cost scale ~O(L^2) in number of layers and increasing per-step memory bandwidth substantially. If this is not strictly required, consider restricting the aggregation to a bounded window of previous layers or maintaining a smaller running summary so each layer doesn’t repeatedly process the entire accumulated state.

Suggested change
-k = x_accum.astype(jnp.float32)
-v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum)
-logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))
-valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :]
-logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min)
+# Only use accumulated states up to and including the current layer to avoid
+# repeatedly processing future layers that are always masked out.
+x_prefix = x_accum[:, :, : layer_idx + 1, :]
+k = x_prefix.astype(jnp.float32)
+v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_prefix)
+logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))

Copilot uses AI. Check for mistakes.
Comment thread src/model.py
Comment on lines +269 to +270
x_accum = jnp.zeros((x.shape[0], x.shape[1], cfg.num_encoder_layers + 1, cfg.d_model), dtype=dt)
x_accum = x_accum.at[:, :, 0, :].set(x)

Copilot AI Mar 27, 2026


Allocating and carrying x_accum with shape (B, T, num_encoder_layers+1, d_model) through the scan significantly increases activation memory (and likely compilation time) vs. a standard scan carry. If you only need the final layer output and a limited history for residual attention, consider reducing the stored history (e.g., windowed buffer) or computing the needed aggregation without materializing the full 4D tensor.
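As a concrete illustration of the windowed-buffer idea (hypothetical; the window size `W` and the helper names are not part of the PR), the scan carry could retain only the last `W` layer states at a fixed size:

```python
import jax.numpy as jnp

W = 4  # number of retained layer states -- an assumed window size

def init_window(x):
    # (B, T, W, D) buffer with the embeddings placed in the newest slot.
    buf = jnp.zeros(x.shape[:2] + (W, x.shape[-1]), dtype=x.dtype)
    return buf.at[:, :, -1, :].set(x)

def push(buf, x_new):
    # Shift out the oldest slot and append the newest layer output,
    # keeping the carry at (B, T, W, D) regardless of network depth.
    return jnp.concatenate([buf[:, :, 1:, :], x_new[:, :, None, :]], axis=2)
```

The residual-attention mask would then range over the `W` window slots instead of all `num_encoder_layers + 1` history slots.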

Comment thread src/model.py
Comment on lines +302 to +311
q = self.param("residual_query", default_init(), (self.d_model,)).astype(jnp.float32)
k = x_accum.astype(jnp.float32)
v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum)

logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))
valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :]
logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min)
alphas = nn.sigmoid(logits).astype(self.dtype)
x = jnp.einsum("btl,btld->btd", alphas, v)


Copilot AI Mar 27, 2026


The decoder block also performs residual aggregation over the entire x_accum each layer iteration, which similarly scales compute ~O(L^2) with num_decoder_layers and can become a bottleneck for deeper models/longer sequences. If possible, consider limiting the number of previous layers attended to or using a more incremental aggregation strategy to keep per-layer work closer to O(L).

Comment thread src/model.py
Comment on lines +367 to +368
x_accum = jnp.zeros((x.shape[0], x.shape[1], cfg.num_decoder_layers + 1, cfg.d_model), dtype=dt)
x_accum = x_accum.at[:, :, 0, :].set(x)

Copilot AI Mar 27, 2026


Similar to the encoder, x_accum is allocated as a full (B, T, num_decoder_layers+1, d_model) buffer and threaded through the scan, which can be very memory-intensive for large batch/sequence/layer counts. If memory is a concern, consider storing fewer layer states (windowed) or restructuring so scan outputs are collected separately instead of keeping the entire history in the carry.
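One way to collect scan outputs separately, as suggested above, is to return each layer's output as the stacked `ys` rather than carrying the history (a sketch with `jax.lax.scan`, not the PR's flax `nn.scan` setup; the layer body is a placeholder):

```python
import jax
import jax.numpy as jnp

def layer(carry, _):
    # Placeholder for the real decoder block computation.
    x = carry + 1.0
    # Returning x as the second element lets scan stack it into `ys`,
    # so the per-layer history is materialized once, outside the carry.
    return x, x

x0 = jnp.zeros((2, 3, 4))  # (B, T, D)
x_final, history = jax.lax.scan(layer, x0, xs=None, length=5)
# history: (num_layers, B, T, D); the carry itself stays (B, T, D).
```

Note this only applies where layers do not need to read the history themselves; with residual attention as written, each layer consumes the accumulated states, so some history must remain in the carry.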
