
Conversation


@fegin fegin commented Sep 30, 2025


Status
The PR is not landable yet but serves as an RFC. If people are okay with this design, this PR requires the following changes and verifications:

  1. Change all models, including the experimental ones.
  2. E2E loss verification (a functional check has been done, but loss verification is not done yet).
  3. We should add a unit test for attention. But since we don't have GPU unit tests, this can be done in a separate PR.

Summary
This PR aims to refactor how TorchTitan builds the attention masks and passes them to the model. Before this PR, init_attention_masks() is called in the Trainer, but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks.

The previous design has several issues; one in particular is #1723.

Now that pytorch/pytorch#164111 proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward().

The new design:

  1. The model needs to provide `get_attention_masks()`, which accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, this API should return None, as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of ints that represents the mask.

Justification: attention logic is technically a part of the model, but it requires some information from the trainer/dataloader. So it is the model author's responsibility to provide an API that the trainer can call to get the masks.

  2. `get_attention_masks()` will be called from the trainer, and the resulting masks are passed to `model.forward()`.

Justification: this will allow us to fix #1723 with pytorch/pytorch#164111 and this PR.

  3. Provide a single AttentionOp instead of two.

Justification: since the masking logic is moved outside, we don't need to do bookkeeping of masks in FlexAttentionWrapper. The logic is so simple that one AttentionOp makes things cleaner.

Note: we still have two very thin op wrappers that are used for CP. I keep these two for CP education purposes, but this can certainly be confusing for Titan's users. I'm open to merging them into AttentionOp.

See the discussion in #1723.
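
A minimal sketch of the proposed flow, under stated assumptions: `ToyModel`, its `use_flex_attn` flag, and the trainer snippet at the bottom are illustrative stand-ins rather than the actual TorchTitan code; only `create_block_mask` and `BlockMask` are existing PyTorch FlexAttention APIs.

```python
from functools import partial
from typing import Optional

import torch
from torch.nn.attention.flex_attention import BlockMask, create_block_mask


class ToyModel(torch.nn.Module):
    # Illustrative model implementing the proposed get_attention_masks() contract.
    use_flex_attn: bool = True

    def get_attention_masks(
        self, create_mask_fn, batch: torch.Tensor, eos_id: int
    ) -> Optional[BlockMask]:
        # SDPA has no varlen support yet, so the trainer simply gets None back.
        if not self.use_flex_attn:
            return None

        # Block-causal masking: a token attends only to earlier tokens in the
        # same document, where documents are delimited by `eos_id`.
        doc_id = (batch == eos_id).int().cumsum(dim=-1)

        def block_causal(b, h, q_idx, kv_idx):
            return (q_idx >= kv_idx) & (doc_id[b, q_idx] == doc_id[b, kv_idx])

        bsz, seq_len = batch.shape
        return create_mask_fn(block_causal, bsz, None, seq_len, seq_len)

    def forward(self, tokens: torch.Tensor, attention_masks=None):
        ...  # thread `attention_masks` down to the attention layers


# Trainer side: build the masks once per step and pass them to forward().
model = ToyModel()
inputs = torch.randint(1, 100, (2, 128))
create_mask_fn = partial(create_block_mask, device="cpu")  # device chosen by the trainer
masks = model.get_attention_masks(create_mask_fn, inputs, eos_id=0)
out = model(inputs, attention_masks=masks)
```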

@fegin fegin requested review from XilunWu and drisspg September 30, 2025 21:15

@tianyu-l tianyu-l left a comment


Very nice refactor! Left many comments lol.

@wwwjn for sliding window attention, you could just create another mask_mod following the examples here.
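
For reference, a hedged sketch of what such a mask_mod could look like for sliding-window attention; `WINDOW_SIZE` and the names below are assumptions, not code from this PR.

```python
from torch.nn.attention.flex_attention import and_masks, create_block_mask

WINDOW_SIZE = 1024  # assumed window size


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def sliding_window_mask(b, h, q_idx, kv_idx):
    # Each query looks back at most WINDOW_SIZE positions.
    return q_idx - kv_idx <= WINDOW_SIZE


# Join both constraints before building the BlockMask, e.g.:
# block_mask = create_block_mask(sliding_window_causal, B, H, S, S)
sliding_window_causal = and_masks(causal_mask, sliding_window_mask)
```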

return _FlexAttentionWrapper._flex_attn(*args, **kwargs)


class _ScaledDotProductAttentionWrapper(torch.nn.Module):
Contributor


would be good to add comments why we have such wrappers


@fegin fegin Oct 2, 2025


I'll evaluate whether we can merge the two wrappers into AttentionOp. This seems to cause a lot of confusion. Even if we enable FlexCP in the future, people may still be confused.

Comment on lines 223 to 225
@functools.lru_cache(4)
def create_block_mask_fn(*args, **kwargs):
    return _compiled_create_block_mask(*args, **kwargs)
Contributor


Are you trying to handle the case where CP is not applied to every Attention module (which I would somehow prefer to delay until people definitely need the complexity)? Otherwise, I don't see why we need to do this for every iteration, or why we need a cache here.

Contributor Author


No I'm not handling that case. I think this shouldn't be a closure. Let me change it.
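
A sketch of the non-closure alternative being discussed, under the assumption that the compiled helper can simply live at module scope (names are illustrative):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

# Compile once at import time and reuse the same callable everywhere,
# instead of rebuilding an lru_cache'd closure on each call.
_compiled_create_block_mask = torch.compile(create_block_mask)
```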

if not self.model_args.use_flex_attn:
    return None

nope_mask_mod = get_causal_mask_mod()
Contributor


I kinda find the rope and nope mask_mods a little confusing; maybe causal, and then mix in the right one, either nope or rope?

Contributor Author


My understanding is that nope and rope will both be causal or block_causal depending on the setting. And the fixed block size will be applied to rope as well.


fegin commented Oct 2, 2025

Looks like the biggest concerns of this PR are:

  1. Don't generalize mask creation for SDPA; explicitly do an if/else so that mask creation is only called when use_flex_attn is set, and the mask is explicitly set to None if SDPA is used.
  2. The thin wrappers that are used for CP still confuse people. I can try to merge them with InnerAttention if possible.

cc @drisspg @tianyu-l @XilunWu


@lkhphuc lkhphuc left a comment


Very nice refactor.
I have a small suggestion around the get_attention_masks call in train.py.
As with all things called in train.py, it would be better if it were a bit more flexible.

Comment on lines 426 to 429
attention_masks = model_parts[0].get_attention_masks(
    create_mask_fn=create_mask_fn,
    batch=inputs,
    eos_id=self.tokenizer.eos_id,
Contributor


Can we make this more general for other models such as VLMs?
For example, we might need pixel_values (and grid_thw for the mask) in extra_inputs, as well as some other special tokens like img_id, soi_id, eoi_id.
Locally, we sometimes need to pass in a bunch more special tokens for a specific model/mask type, so passing in the entire tokenizer (or a dict/dataclass of special tokens) would be much more ergonomic.

Suggested change
-attention_masks = model_parts[0].get_attention_masks(
-    create_mask_fn=create_mask_fn,
-    batch=inputs,
-    eos_id=self.tokenizer.eos_id,
+attention_masks = model_parts[0].get_attention_masks(
+    create_mask_fn=create_mask_fn,
+    batch=inputs,
+    tokenizer=self.tokenizer,
+    **extra_inputs,


@fegin fegin Oct 3, 2025


Nice suggestion. I was wondering how to do that for VLM, as I don't have much experience there. I'll definitely take a look.

return _causal_mask


def get_document_mask_mod(
Contributor


Instead of passing in mask_mod, can we use and_masks in the get_attention_masks functions? https://github.com/pytorch/pytorch/blob/9fff8155c362da777e7ce31b85fb2dc7cfced2d5/torch/nn/attention/flex_attention.py#L922
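
A small sketch of the suggested pattern (`tokens` and `eos_id` are assumed inputs): define each mask_mod independently and join them with `and_masks`, rather than passing one mask_mod into another.

```python
import torch
from torch.nn.attention.flex_attention import and_masks


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def make_document_mask_mod(tokens: torch.Tensor, eos_id: int):
    # Tokens belong to the same document if they share a document index,
    # where documents are delimited by `eos_id`.
    doc_id = (tokens == eos_id).int().cumsum(dim=-1)

    def document(b, h, q_idx, kv_idx):
        return doc_id[b, q_idx] == doc_id[b, kv_idx]

    return document


# block_causal = and_masks(causal, make_document_mask_mod(tokens, eos_id))
```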


adil-a commented Oct 7, 2025

Hi, unsure where the best place to ask this is but this seems like a relevant recent PR. I have two questions:

  • Is CP + packed sequences (where we need to provide a custom mask to prevent separate sequences from attending to each other) + SDPA backend supported? It seems no, but I wanted to confirm.
  • If not, are there plans to support it?
  • What is the ETA for FlexAttention + CP? Will this support packing sequences together (as explained above)?


fegin commented Oct 7, 2025

> Is CP + packed sequences (where we need to provide a custom mask to prevent separate sequences from attending to each other) + SDPA backend supported? It seems no, but I wanted to confirm.
> If not, are there plans to support it?

SDPA doesn't support this, so CP + SDPA doesn't support this. The current plan is to wait until SDPA supports packed sequences.

> What is the ETA for FlexAttention + CP? Will this support packing sequences together (as explained above)?

Yes, it will support packed sequences, but the current implementation will use allgather only.


@tianyu-l tianyu-l left a comment


sounds good to me

"""Wrapper around `F.scaled_dot_product_attention` to make it CP compatible.
This wrapper is needed because `F.scaled_dot_product_attention` is not
a torch.nn.Module, and thus cannot be applied _ContextParallel.
Contributor


Suggested change
-    a torch.nn.Module, and thus cannot be applied _ContextParallel.
+    a torch.nn.Module, and thus cannot be applied _ContextParallel to.


# Add CuDNN on B200 w/ highest priority
cls.backends = [
# Always make CuDNN as the highest priority if available.
Contributor


can remove this comment because of set_priority=True below


self.use_flex_attn = model_args.use_flex_attn
if self.use_flex_attn:
    self.inner_attention = FlexAttentionWrapper()
Contributor


another option is to call it self.kernel, as used by some internal codebases

raise ValueError(f"Unknown attention mask type: {self.attn_mask_type}")

rope_mask_mod = get_fixed_block_mask_mod(
    nope_mask_mod, self.model_args.fixed_attn_block_size
Contributor


I'd prefer defining these masks independently and using and_masks to join them.


adil-a commented Oct 7, 2025

> SDPA doesn't support this, so CP + SDPA doesn't support this. The current plan is to wait until SDPA supports packed sequences.

Sorry, could you elaborate why it is unsupported even in the non-CP case? According to the pseudocode, can't we pass in is_causal=False and then a block-causal mask for attn_mask?
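
For concreteness, a sketch of what the question describes for the plain (non-CP) case, with assumed shapes and names; whether and where TorchTitan uses this path is a separate question:

```python
import torch
import torch.nn.functional as F

B, H, S, D = 2, 4, 16, 8
eos_id = 0
tokens = torch.randint(0, 10, (B, S))

# Block-causal boolean mask built from document ids (True = may attend).
doc_id = (tokens == eos_id).int().cumsum(dim=-1)
same_doc = doc_id.unsqueeze(-1) == doc_id.unsqueeze(-2)  # (B, S, S)
causal = torch.tril(torch.ones(S, S, dtype=torch.bool))  # (S, S)
attn_mask = (same_doc & causal).unsqueeze(1)             # (B, 1, S, S), broadcast over heads

q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=False)
```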
