added residual attention, removed normal residual and gated attention #11
kar-m wants to merge 1 commit into
Conversation
Pull request overview
This PR changes the Transformer encoder/decoder block structure to derive each layer’s input from an accumulated history of prior layer states (“residual attention”), while removing the prior direct residual-add + gated-residual pattern.
Changes:
- Replace standard residual connections in EncoderBlock/DecoderBlock with a learned, per-layer aggregation over x_accum (a history of layer states).
- Update the encoder/decoder nn.scan bodies to carry and update x_accum, and pass layer_idx through the scan.
- Add explicit x_accum allocation and final-state extraction (x_accum[:, :, num_layers, :]) in both encoder and decoder.
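For orientation, the carry structure described above can be sketched roughly as follows. This is an illustrative reconstruction, not the PR's code: BlockSketch, EncoderStackSketch and the Dense body are placeholders for the real EncoderBlock (which performs residual attention over all slots plus its usual attention/MLP); only the x_accum allocation, per-layer slot update, and final-slot extraction mirror the diff.

```python
import jax.numpy as jnp
import flax.linen as nn


class BlockSketch(nn.Module):
    """Placeholder for EncoderBlock: reads the history, writes its own slot."""
    d_model: int

    @nn.compact
    def __call__(self, x_accum, layer_idx):
        # The real block gates over every slot of x_accum ("residual attention")
        # before attention/MLP; a plain read of the previous slot stands in here.
        x_in = jnp.take(x_accum, layer_idx, axis=2)
        y = x_in + nn.Dense(self.d_model)(x_in)
        # Layer i's output goes into slot i + 1.
        x_accum = x_accum.at[:, :, layer_idx + 1, :].set(y.astype(x_accum.dtype))
        return x_accum, None


class EncoderStackSketch(nn.Module):
    """Allocate x_accum, scan the blocks over it, extract the last slot."""
    num_layers: int
    d_model: int

    @nn.compact
    def __call__(self, x):
        # Slot 0 holds the embedded input; slot i + 1 receives layer i's output.
        x_accum = jnp.zeros(
            (x.shape[0], x.shape[1], self.num_layers + 1, self.d_model), dtype=x.dtype)
        x_accum = x_accum.at[:, :, 0, :].set(x)

        ScanBlock = nn.scan(
            BlockSketch,
            variable_axes={"params": 0},   # independent parameters per layer
            split_rngs={"params": True},
            in_axes=0,                     # layer_idx varies per step
        )
        x_accum, _ = ScanBlock(self.d_model)(x_accum, jnp.arange(self.num_layers))

        # Final encoder state lives in the last slot, as in the diff.
        return x_accum[:, :, self.num_layers, :]
```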
k = x_accum.astype(jnp.float32)
v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum)

logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))
valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :]
logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min)
This residual-attention computation reads the full x_accum tensor (B×T×(L+1)×D) and computes logits/values over all layers at each iteration, making the encoder block's cost scale roughly O(L^2) with the number of layers and substantially increasing per-step memory bandwidth. If this is not strictly required, consider restricting the aggregation to a bounded window of previous layers or maintaining a smaller running summary so each layer doesn't repeatedly process the entire accumulated state.
Suggested change:
-k = x_accum.astype(jnp.float32)
-v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum)
-logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))
-valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :]
-logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min)
+# Only use accumulated states up to and including the current layer to avoid
+# repeatedly processing future layers that are always masked out.
+x_prefix = x_accum[:, :, : layer_idx + 1, :]
+k = x_prefix.astype(jnp.float32)
+v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_prefix)
+logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))
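One caveat with the prefix slice above is that its length depends on layer_idx, and data-dependent shapes generally do not trace inside nn.scan. If the bounded-window route from the comment is taken instead, a fixed slice size keeps shapes static; a rough sketch, where window_of_history and window_size are illustrative names:

```python
import jax
import jax.numpy as jnp


def window_of_history(x_accum, layer_idx, window_size):
    """Static-shape alternative to x_accum[:, :, : layer_idx + 1, :].

    Returns the last `window_size` slots ending at `layer_idx`, plus a mask
    marking slots that are not yet valid at this depth.
    """
    # The slice size is a Python int, so shapes stay static under scan tracing;
    # only the start index is a traced value.
    start = jnp.maximum(layer_idx + 1 - window_size, 0)
    window = jax.lax.dynamic_slice_in_dim(x_accum, start, window_size, axis=2)
    valid = ((start + jnp.arange(window_size)) <= layer_idx)[None, None, :]
    return window, valid
```

The logits/alphas computation would then run over window, with valid used as the mask, so per-layer cost is bounded by window_size rather than by the full layer count.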
x_accum = jnp.zeros((x.shape[0], x.shape[1], cfg.num_encoder_layers + 1, cfg.d_model), dtype=dt)
x_accum = x_accum.at[:, :, 0, :].set(x)
Allocating and carrying x_accum with shape (B, T, num_encoder_layers+1, d_model) through the scan significantly increases activation memory (and likely compilation time) vs. a standard scan carry. If you only need the final layer output and a limited history for residual attention, consider reducing the stored history (e.g., windowed buffer) or computing the needed aggregation without materializing the full 4D tensor.
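A windowed buffer along these lines could also replace the carry itself, so the scanned state stays at a fixed (B, T, W, D) size; a rough sketch under that assumption, with RMSNorm and Dense standing in for the project's ZCRMSNorm and the real block body:

```python
import jax.numpy as jnp
import flax.linen as nn


class WindowCarryBlock(nn.Module):
    """Block sketch whose scan carry is a fixed (B, T, W, D) window."""
    d_model: int
    window_size: int = 4
    dtype: jnp.dtype = jnp.bfloat16

    @nn.compact
    def __call__(self, buf, layer_idx):
        # buf holds the last `window_size` layer states, newest in the final slot.
        q = self.param("residual_query",
                       nn.initializers.normal(stddev=0.02),
                       (self.d_model,)).astype(jnp.float32)
        k = buf.astype(jnp.float32)
        v = nn.RMSNorm(dtype=self.dtype, name="residual_value_norm")(buf)

        logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))
        # Slots that have not been written yet (early layers) are masked out.
        valid = (jnp.arange(self.window_size)
                 >= self.window_size - 1 - layer_idx)[None, None, :]
        logits = jnp.where(valid, logits, jnp.finfo(logits.dtype).min)
        alphas = nn.sigmoid(logits).astype(self.dtype)
        x_in = jnp.einsum("btl,btld->btd", alphas, v)

        y = x_in + nn.Dense(self.d_model, dtype=self.dtype)(x_in)  # stand-in for attention/MLP

        # Shift the window and store this layer's output in the newest slot.
        buf = jnp.roll(buf, shift=-1, axis=2)
        buf = buf.at[:, :, -1, :].set(y.astype(buf.dtype))
        return buf, None
```

The stack would allocate the buffer as zeros with the embedded input in the last slot and scan this block over jnp.arange(num_layers).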
q = self.param("residual_query", default_init(), (self.d_model,)).astype(jnp.float32)
k = x_accum.astype(jnp.float32)
v = ZCRMSNorm(dtype=self.dtype, name="residual_value_norm")(x_accum)

logits = jnp.einsum("btld,d->btl", k, q) / jnp.sqrt(jnp.float32(self.d_model))
valid_layers = (jnp.arange(x_accum.shape[2]) <= layer_idx)[None, None, :]
logits = jnp.where(valid_layers, logits, jnp.finfo(logits.dtype).min)
alphas = nn.sigmoid(logits).astype(self.dtype)
x = jnp.einsum("btl,btld->btd", alphas, v)
The decoder block also performs residual aggregation over the entire x_accum on every layer iteration, which similarly scales compute roughly O(L^2) with num_decoder_layers and can become a bottleneck for deeper models and longer sequences. If possible, consider limiting the number of previous layers attended to or using a more incremental aggregation strategy to keep per-layer work closer to O(L).
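If exact gating over every individual past layer is not required, one incremental alternative is to carry a single (B, T, D) mixture of past states and fold each new layer state into it with a learned gate, keeping per-layer work constant in depth. A rough sketch of that idea (not the PR's implementation); the gating form is chosen only for illustration and RMSNorm stands in for ZCRMSNorm:

```python
import jax.numpy as jnp
import flax.linen as nn


class RunningResidualSummary(nn.Module):
    """Carry one (B, T, D) mixture of past states instead of (B, T, L+1, D)."""
    d_model: int
    dtype: jnp.dtype = jnp.bfloat16

    @nn.compact
    def __call__(self, summary, x_new):
        # Per-layer query decides how strongly the newest state enters the mix.
        q = self.param("residual_query",
                       nn.initializers.normal(stddev=0.02),
                       (self.d_model,)).astype(jnp.float32)
        v_new = nn.RMSNorm(dtype=self.dtype, name="residual_value_norm")(x_new)

        gate_logits = jnp.einsum("btd,d->bt", x_new.astype(jnp.float32), q)
        gate_logits = gate_logits / jnp.sqrt(jnp.float32(self.d_model))
        gate = nn.sigmoid(gate_logits)[..., None].astype(self.dtype)

        # Convex blend: carry size and per-layer work no longer grow with depth.
        return gate * v_new + (1.0 - gate) * summary
```

The trade-off is that later layers can no longer re-weight individual earlier layers independently.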
x_accum = jnp.zeros((x.shape[0], x.shape[1], cfg.num_decoder_layers + 1, cfg.d_model), dtype=dt)
x_accum = x_accum.at[:, :, 0, :].set(x)
Similar to the encoder, x_accum is allocated as a full (B, T, num_decoder_layers+1, d_model) buffer and threaded through the scan, which can be very memory-intensive for large batch/sequence/layer counts. If memory is a concern, consider storing fewer layer states (windowed) or restructuring so scan outputs are collected separately instead of keeping the entire history in the carry.
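If the per-layer states were only needed after the loop (for example for the final-state extraction), nn.scan can return them as stacked outputs rather than threading a growing buffer through the carry; a schematic sketch, with a hypothetical Block standing in for the real DecoderBlock:

```python
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
    """Hypothetical stand-in for DecoderBlock with a plain (B, T, D) carry."""
    d_model: int

    @nn.compact
    def __call__(self, x, _):
        y = x + nn.Dense(self.d_model)(x)
        # First value is the carry for the next layer; the second is collected
        # by nn.scan into a stacked (num_layers, B, T, D) output.
        return y, y


class DecoderStackSketch(nn.Module):
    num_layers: int
    d_model: int

    @nn.compact
    def __call__(self, x):
        ScanBlock = nn.scan(
            Block,
            variable_axes={"params": 0},   # independent parameters per layer
            split_rngs={"params": True},
            length=self.num_layers,
        )
        # `history` holds every layer's output without ever living in the carry.
        final, history = ScanBlock(self.d_model)(x, None)
        return final, history
```

Here history has shape (num_layers, B, T, D) and is produced by the scan itself while the carry stays (B, T, D); this only applies if the blocks do not need to read the history inside the loop.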
No description provided.