16 changes: 8 additions & 8 deletions torchtitan/distributed/utils.py
@@ -106,6 +106,14 @@ def set_determinism(
# https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Compile flex_attention without max-autotune. This is needed for reproducibility,
# since autotune results may not be deterministic.
from torch.nn.attention.flex_attention import flex_attention

from torchtitan.models.attention import FlexAttentionWrapper

FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention)

if not world_mesh:
if seed is not None:
torch.manual_seed(seed)
@@ -199,14 +207,6 @@ def context(cp_context: Generator[None, None, None] | None = None):
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)

if cp_context is not None:
from torch.nn.attention import SDPBackend

from torchtitan.models.attention import ScaledDotProductAttention

if SDPBackend.MATH in ScaledDotProductAttention.backends:
ScaledDotProductAttention.backends.remove(SDPBackend.MATH)

stack.enter_context(cp_context)

yield
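Note on the set_determinism change: overriding FlexAttentionWrapper._compiled_flex_attn with a plain torch.compile(flex_attention) keeps kernel selection stable across runs, because max-autotune benchmarking can pick different kernels (and therefore different numerics) from run to run. A minimal standalone sketch of the same idea, assuming PyTorch >= 2.5 with flex_attention available on a CUDA device; the shapes and the causal mask_mod below are illustrative, not part of this PR:

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    # mask_mod: a query position may only attend to itself and earlier positions.
    return q_idx >= kv_idx

# Plain compile, no mode="max-autotune": kernel choice (and numerics) stays stable
# across runs, which is what the determinism path relies on.
compiled_flex = torch.compile(flex_attention)

B, H, S, D = 2, 8, 128, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda") for _ in range(3))
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
out = compiled_flex(q, k, v, block_mask=block_mask)  # (B, H, S, D)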
4 changes: 2 additions & 2 deletions torchtitan/experiments/llama4/infra/parallelize.py
@@ -238,8 +238,8 @@ def apply_non_moe_tp(
layer_plan = {
"attention_norm": SequenceParallel(),
"attention": prepare_module_input(
input_layouts=(Shard(1), None),
desired_input_layouts=(Replicate(), None),
input_layouts=(Shard(1), None, None),
desired_input_layouts=(Replicate(), None, None),
),
"attention.wq": colwise_parallel(),
"attention.wk": colwise_parallel(),
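The layout tuples above are positional, one entry per argument of Attention.forward, which now takes (x, freqs_cis, attention_masks); the extra None tells tensor parallelism to pass the new mask argument through untouched. A hedged sketch of the same plan written directly against torch's PrepareModuleInput (the variable name is illustrative):

from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.parallel import PrepareModuleInput

# One layout per positional argument of the wrapped module's forward():
#   x               -> arrives sharded on the sequence dim, redistributed to Replicate
#   freqs_cis       -> None: not a DTensor input, passed through unchanged
#   attention_masks -> None: likewise passed through unchanged
attention_input = PrepareModuleInput(
    input_layouts=(Shard(1), None, None),
    desired_input_layouts=(Replicate(), None, None),
)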
68 changes: 58 additions & 10 deletions torchtitan/experiments/llama4/model/model.py
@@ -4,14 +4,23 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
import torch.nn.functional as F
from torch import nn

from torchtitan.models.attention import build_attention
from torch.nn.attention.flex_attention import and_masks

from torchtitan.components.tokenizer import BaseTokenizer
from torchtitan.models.attention import (
create_attention_mask,
FlexAttentionWrapper,
get_causal_mask_mod,
get_document_mask_mod,
get_fixed_block_mask_mod,
ScaledDotProductAttentionWrapper,
)
from torchtitan.models.moe import MoE
from torchtitan.protocols import ModelProtocol
from torchtitan.protocols.model import AttentionMasksType
from torchtitan.protocols.train_spec import ModelProtocol

from .args import TransformerModelArgs

@@ -155,9 +164,11 @@ def __init__(
# values of these two variables.
self.use_rope = use_rope

self.sdpa = build_attention(
model_args.use_flex_attn, model_args.attn_mask_type, fixed_block_size
)
self.use_flex_attn = model_args.use_flex_attn
if self.use_flex_attn:
self.inner_attention = FlexAttentionWrapper()
else:
self.inner_attention = ScaledDotProductAttentionWrapper()

def init_weights(self, init_std: float):
for linear in (self.wq, self.wk, self.wv):
@@ -168,6 +179,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_masks: AttentionMasksType | None,
):
"""
Forward pass of the attention module.
@@ -202,7 +214,13 @@ def forward(
xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)

output = self.sdpa(xq, xk, xv)
if self.use_flex_attn:
assert isinstance(attention_masks, dict), attention_masks
attention_mask = attention_masks["rope" if self.use_rope else "nope"]
output = self.inner_attention(xq, xk, xv, block_mask=attention_mask)
else:
assert attention_masks is None
output = self.inner_attention(xq, xk, xv)

output = output.transpose(
1, 2
@@ -335,6 +353,7 @@ def forward(
self,
x: torch.Tensor,
freqs_cis: torch.Tensor,
attention_masks: AttentionMasksType | None,
):
"""
Perform a forward pass through the TransformerBlock.
@@ -347,7 +366,7 @@ def forward(
torch.Tensor: Output tensor after applying attention and feedforward layers.

"""
h = x + self.attention(self.attention_norm(x), freqs_cis)
h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
if self.moe_enabled:
out = h + self.moe(self.ffn_norm(h))
else:
@@ -447,9 +466,38 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
self.model_args.rope_theta,
)

def get_attention_masks(
self,
input_batch: torch.Tensor,
tokenizer: BaseTokenizer,
extra_inputs: dict[str, torch.Tensor] | None = None,
) -> AttentionMasksType:
mask_mods = [get_causal_mask_mod()]
match self.model_args.attn_mask_type:
case "causal":
B = 1
case "block_causal":
mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
B = input_batch.shape[0]
case _:
raise ValueError(f"Unknown attention mask type: {self.model_args.attn_mask_type}")

rope_mask_mod = and_masks(
*mask_mods,
get_fixed_block_mask_mod(self.model_args.fixed_attn_block_size),
)
nope_mask_mod = and_masks(*mask_mods)

seqlen = input_batch.shape[1]
return {
"rope": create_attention_mask(rope_mask_mod, B, None, seqlen, seqlen),
"nope": create_attention_mask(nope_mask_mod, B, None, seqlen, seqlen),
}
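For readers new to flex_attention mask composition: and_masks ANDs the boolean mask_mods together, so the "rope" entry is causal (plus the document mask when attn_mask_type is "block_causal") further restricted to fixed-size local blocks, while the "nope" entry keeps global causal attention. A hedged sketch of that composition using only stock torch APIs; the fixed_block helper below is an illustrative stand-in, not torchtitan's get_fixed_block_mask_mod:

from torch.nn.attention.flex_attention import and_masks, create_block_mask

def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def fixed_block(block_size):
    # Illustrative local-attention rule: query and key must fall in the same
    # block of `block_size` tokens.
    def mask_mod(b, h, q_idx, kv_idx):
        return q_idx // block_size == kv_idx // block_size
    return mask_mod

seqlen = 2048
rope_mask = create_block_mask(
    and_masks(causal, fixed_block(256)), B=None, H=None, Q_LEN=seqlen, KV_LEN=seqlen
)
nope_mask = create_block_mask(causal, B=None, H=None, Q_LEN=seqlen, KV_LEN=seqlen)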

def forward(
self,
tokens: torch.Tensor,
attention_masks: AttentionMasksType | None = None,
input_batch: torch.Tensor | None = None,
):
"""
@@ -473,7 +521,7 @@
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

for layer in self.layers.values():
h = layer(h, self.freqs_cis)
h = layer(h, self.freqs_cis, attention_masks)

h = self.norm(h) if self.norm else h
output = self.output(h) if self.output else h
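Putting the model.py pieces together, the expected call pattern is that the trainer builds the masks once per batch via get_attention_masks and threads them through forward, where each Attention layer picks the "rope" or "nope" entry based on its use_rope flag. A hedged usage sketch; model, tokenizer, and tokens stand in for the real trainer objects and are not defined in this PR:

# `model` is a llama4 Transformer built with use_flex_attn=True, `tokens` is the
# (batch_size, seqlen) tensor of input ids, and `tokenizer` supplies eos_id for the
# block_causal document mask.
attention_masks = model.get_attention_masks(input_batch=tokens, tokenizer=tokenizer)
logits = model(tokens, attention_masks=attention_masks)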