Lint
kwen2501 committed Feb 11, 2025
1 parent c185b0d commit 86934e6
Showing 2 changed files with 12 additions and 10 deletions.
9 changes: 6 additions & 3 deletions torchtitan/models/deepseek_v3/attn_mask_utils.py
@@ -68,7 +68,8 @@ def __init__(self, is_causal: bool, sliding_window: Optional[int] = None):

         if self.sliding_window is not None and self.sliding_window <= 0:
             raise ValueError(
-                f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{self.sliding_window}`"
+                "Make sure that when passing `sliding_window` that its value is a strictly positive integer, "
+                f"not `{self.sliding_window}`"
             )

     def to_causal_4d(
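The split relies on Python's implicit concatenation of adjacent string literals, so the raised message is unchanged. A minimal sketch verifying the equivalence (the `sliding_window` value here is hypothetical):

```
sliding_window = 0  # hypothetical value that would trip the check

# Adjacent string literals concatenate at compile time, so the split form
# produces exactly the same message as the original one-line f-string.
split_msg = (
    "Make sure that when passing `sliding_window` that its value is a strictly positive integer, "
    f"not `{sliding_window}`"
)
one_line_msg = f"Make sure that when passing `sliding_window` that its value is a strictly positive integer, not `{sliding_window}`"
assert split_msg == one_line_msg
```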
@@ -126,7 +127,8 @@ def to_4d(
         if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
             if key_value_length is None:
                 raise ValueError(
-                    "This attention mask converter is causal. Make sure to pass `key_value_length` to correctly create a causal mask."
+                    "This attention mask converter is causal. Make sure to pass "
+                    "`key_value_length` to correctly create a causal mask."
                 )

             past_key_values_length = key_value_length - query_length
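For context on the `past_key_values_length` arithmetic: with cached keys and values, each query position may attend to all past positions plus the causal prefix of the current chunk. A rough sketch of that rule (not the converter's actual implementation):

```
import torch

def causal_mask_sketch(query_length: int, key_value_length: int) -> torch.Tensor:
    # Query position i may attend to key positions 0 .. past + i,
    # where past = key_value_length - query_length.
    past_key_values_length = key_value_length - query_length
    q = torch.arange(query_length).unsqueeze(-1)     # [query_length, 1]
    k = torch.arange(key_value_length).unsqueeze(0)  # [1, key_value_length]
    return k <= q + past_key_values_length           # True where attention is allowed

print(causal_mask_sketch(3, 5).int())
# tensor([[1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 0],
#         [1, 1, 1, 1, 1]])
```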
@@ -233,7 +235,8 @@ def _unmask_unattended(
         `expanded_mask` is [bsz, num_masks, tgt_seq_len, src_seq_len] or [bsz, tgt_seq_len, src_seq_len].
         `attention_mask` is [bsz, src_seq_len].
-        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case of alibi attention bias.
+        The dimension num_masks of `expanded_mask` is most often 1, but it can also be the number of heads in the case
+        of alibi attention bias.
         For example, if `expanded_mask` is (e.g. here left-padding case)
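A small sketch of the shapes this docstring describes, expanding a 2D padding mask to the 4D layout (illustrative only, not the function's code):

```
import torch

# `attention_mask` is [bsz, src_seq_len]; adding broadcast dims yields the
# [bsz, num_masks, tgt_seq_len, src_seq_len] layout with num_masks = 1.
attention_mask = torch.tensor([[0, 1, 1]])  # left-padding: first position masked
tgt_seq_len = 2
expanded = attention_mask[:, None, None, :].expand(-1, 1, tgt_seq_len, -1)
print(expanded.shape)  # torch.Size([1, 1, 2, 3])
```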
13 changes: 6 additions & 7 deletions torchtitan/models/deepseek_v3/model.py
@@ -77,7 +77,8 @@ class ModelArgs:
         n_group (`int`, *optional*, defaults to None):
             Number of groups for routed experts.
         topk_group (`int`, *optional*, defaults to None):
-            Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
+            Number of selected groups for each token(for each token, ensuring the selected experts is only within
+            `topk_group` groups).
         num_experts_per_tok (`int`, *optional*, defaults to None):
             Number of selected experts, None means dense model.
         moe_layer_freq (`int`, *optional*, defaults to 1):
@@ -499,7 +500,7 @@ def reset_parameters(self) -> None:

     def forward(self, hidden_states):
         bsz, seq_len, h = hidden_states.shape
-        ### compute gating score
+        # compute gating score
         hidden_states = hidden_states.view(-1, h)
         logits = F.linear(
             hidden_states.type(torch.float32), self.weight.type(torch.float32), None
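The gating step flattens the token dimension and runs the gate projection in float32 for numerical stability. A standalone sketch of that step (the sigmoid scoring is an assumption; the hunk is truncated before the `scoring_func` branch):

```
import torch
import torch.nn.functional as F

bsz, seq_len, h, n_routed_experts = 2, 4, 16, 8
hidden_states = torch.randn(bsz, seq_len, h)
weight = torch.randn(n_routed_experts, h)  # gate weight: one row per expert

# Flatten tokens, then project in float32 before computing per-expert scores.
flat = hidden_states.view(-1, h)  # [bsz * seq_len, h]
logits = F.linear(flat.type(torch.float32), weight.type(torch.float32), None)
scores = logits.sigmoid()  # assumed scoring function, not shown in this hunk
```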
@@ -511,7 +512,7 @@ def forward(self, hidden_states):
                 f"insupportable scoring function for MoE gating: {self.scoring_func}"
             )

-        ### select top-k experts
+        # select top-k experts
         if self.topk_method == "noaux_tc":
             assert not self.training
             scores_for_choice = scores.view(
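The `noaux_tc` branch implements the group-limited routing that the `topk_group` docstring above describes: each token first picks its best `topk_group` groups, then its experts only within those groups. A hedged sketch of that scheme (names and shapes are illustrative, not the file's exact code, which this hunk truncates):

```
import torch

n_routed_experts, n_group, topk_group, num_experts_per_tok = 8, 4, 2, 2
scores = torch.rand(3, n_routed_experts)  # [tokens, experts]

# Score each group by its best expert, keep the topk_group best groups ...
group_scores = scores.view(-1, n_group, n_routed_experts // n_group).amax(dim=-1)
group_idx = group_scores.topk(topk_group, dim=-1).indices
group_mask = torch.zeros_like(group_scores).scatter_(1, group_idx, 1.0)

# ... then mask out experts outside the chosen groups and take the final top-k.
score_mask = (
    group_mask[:, :, None]
    .expand(-1, n_group, n_routed_experts // n_group)
    .reshape(-1, n_routed_experts)
)
masked = scores.masked_fill(score_mask == 0, float("-inf"))
topk_weight, topk_idx = masked.topk(num_experts_per_tok, dim=-1)
```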
@@ -546,7 +547,7 @@ def forward(self, hidden_states):
                 f"insupportable TopK function for MoE gating: {self.topk_method}"
             )

-        ### norm gate to sum 1
+        # norm gate to sum 1
         if self.top_k > 1 and self.norm_topk_prob:
             denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
             topk_weight = topk_weight / denominator
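The normalization rescales each token's selected weights so they sum to 1; the `1e-20` term only guards against a zero denominator. In isolation:

```
import torch

topk_weight = torch.tensor([[0.5, 0.3], [0.2, 0.1]])
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
topk_weight = topk_weight / denominator
print(topk_weight.sum(dim=-1))  # tensor([1., 1.])
```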
@@ -936,9 +937,7 @@ def forward(
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-    ) -> Tuple[
-        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
-    ]:
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
