diff --git a/configs/acoustic.yaml b/configs/acoustic.yaml
index 935d6e16..24d1e4d4 100644
--- a/configs/acoustic.yaml
+++ b/configs/acoustic.yaml
@@ -65,6 +65,8 @@ max_beta: 0.02
 enc_ffn_kernel_size: 3
 use_rope: true
 rope_interleaved: false
+use_gated_attn: false
+use_qk_norm: false
 use_stretch_embed: true
 use_variance_scaling: true
 rel_pos: true
diff --git a/configs/templates/config_acoustic.yaml b/configs/templates/config_acoustic.yaml
index 9d63028f..17a1fe12 100644
--- a/configs/templates/config_acoustic.yaml
+++ b/configs/templates/config_acoustic.yaml
@@ -72,6 +72,8 @@ diffusion_type: reflow
 enc_ffn_kernel_size: 3
 use_rope: true
 rope_interleaved: false
+use_gated_attn: true
+use_qk_norm: true
 use_stretch_embed: true
 use_variance_scaling: true
 use_shallow_diffusion: true
diff --git a/configs/templates/config_variance.yaml b/configs/templates/config_variance.yaml
index 40f4c532..b15c7bdd 100644
--- a/configs/templates/config_variance.yaml
+++ b/configs/templates/config_variance.yaml
@@ -66,6 +66,8 @@ tension_logit_max: 10.0
 enc_ffn_kernel_size: 3
 use_rope: true
 rope_interleaved: false
+use_gated_attn: true
+use_qk_norm: true
 use_stretch_embed: false
 use_variance_scaling: true
 hidden_size: 384
diff --git a/configs/variance.yaml b/configs/variance.yaml
index a819c1c4..3422d550 100644
--- a/configs/variance.yaml
+++ b/configs/variance.yaml
@@ -37,6 +37,8 @@ predict_tension: false
 enc_ffn_kernel_size: 3
 use_rope: true
 rope_interleaved: false
+use_gated_attn: false
+use_qk_norm: false
 use_stretch_embed: false
 use_variance_scaling: true
 rel_pos: true
diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py
index 0012b99c..af6210f8 100644
--- a/modules/commons/common_layers.py
+++ b/modules/commons/common_layers.py
@@ -211,8 +211,20 @@ def forward(self, x):
         return x
 
 
+class AtanSigmoid(nn.Module):
+    def __init__(self):
+        super(AtanSigmoid, self).__init__()
+        self.pi = math.pi
+        self.pi_half = math.pi / 2
+        self.inv_pi = 1.0 / math.pi
+
+    def forward(self, x):
+        return (torch.atan(x) + self.pi_half) * self.inv_pi
+
+
 class MultiheadSelfAttentionWithRoPE(nn.Module):
-    def __init__(self, embed_dim, num_heads, dropout=0.1, bias=False, rotary_embed=None):
+    def __init__(self, embed_dim, num_heads, dropout=0.1, bias=False, rotary_embed=None,
+                 use_gated_attn=False, use_qk_norm=False):
         super().__init__()
 
         assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
@@ -223,6 +235,12 @@ def __init__(self, embed_dim, num_heads, dropout=0.1, bias=False, rotary_embed=N
         # Linear layers for Q, K, V projections
         self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
 
+        # QK-Norm on the per-head queries and keys (cf. Qwen3)
+        self.use_qk_norm = use_qk_norm
+        if self.use_qk_norm:
+            self.q_norm = LayerNorm(embed_dim // num_heads)
+            self.k_norm = LayerNorm(embed_dim // num_heads)
+
         # Final linear layer after concatenation
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
@@ -232,6 +250,16 @@ def __init__(self, embed_dim, num_heads, dropout=0.1, bias=False, rotary_embed=N
         # Rotary Embeddings
         self.rotary_embed = rotary_embed
 
+        # refer to the NeurIPS 2025 best paper "Gated Attention for Large Language Models:
+        # Non-linearity, Sparsity, and Attention-Sink-Free" (arXiv: 2505.06708)
+        self.use_gated_attn = use_gated_attn
+        if self.use_gated_attn:
+            self.gate_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+            # self.atan_sigmoid = AtanSigmoid()
+            nn.init.xavier_uniform_(self.gate_proj.weight)
+            if bias:
+                nn.init.constant_(self.gate_proj.bias, 0.0)
+
         # Initialization parameters
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.xavier_uniform_(self.out_proj.weight)
@@ -252,6 +280,11 @@ def forward(self, x, key_padding_mask=None):
         K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, L, D)
         V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, L, D)
 
+        # Query-Key Normalization
+        if self.use_qk_norm:
+            Q = self.q_norm(Q)
+            K = self.k_norm(K)
+
         # Apply RoPE
         if self.rotary_embed is not None:
             Q = self.rotary_embed.rotate_queries_or_keys(Q)
@@ -276,6 +309,12 @@ def forward(self, x, key_padding_mask=None):
         # Reshape and concatenate heads
         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)  # (B, L, C)
 
+        if self.use_gated_attn:
+            # Formula (5): Y' = Y ⊙ σ(XW_θ)
+            # gate_score = self.atan_sigmoid(self.gate_proj(x))  # (B, L, C)
+            gate_score = torch.sigmoid(self.gate_proj(x))  # (B, L, C)
+            attn_output = attn_output * gate_score
+
         # Final linear projection
         output = self.out_proj(attn_output)  # (B, L, C)
 
@@ -284,7 +323,8 @@ def forward(self, x, key_padding_mask=None):
 
 class EncSALayer(nn.Module):
     def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
-                 relu_dropout=0.1, kernel_size=9, act='gelu', rotary_embed=None):
+                 relu_dropout=0.1, kernel_size=9, act='gelu', rotary_embed=None,
+                 use_gated_attn=False, use_qk_norm=False):
         super().__init__()
         self.dropout = dropout
         self.layer_norm1 = LayerNorm(c)
@@ -295,7 +335,8 @@ def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
             self.use_rope = False
         else:
            self.self_attn = MultiheadSelfAttentionWithRoPE(
-                c, num_heads, dropout=attention_dropout, bias=False, rotary_embed=rotary_embed
+                c, num_heads, dropout=attention_dropout, bias=False, rotary_embed=rotary_embed,
+                use_gated_attn=use_gated_attn, use_qk_norm=use_qk_norm
            )
            self.use_rope = True
         self.layer_norm2 = LayerNorm(c)
diff --git a/modules/fastspeech/acoustic_encoder.py b/modules/fastspeech/acoustic_encoder.py
index 868d383f..4b304e37 100644
--- a/modules/fastspeech/acoustic_encoder.py
+++ b/modules/fastspeech/acoustic_encoder.py
@@ -38,7 +38,8 @@ def __init__(self, vocab_size):
             ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
             dropout=hparams['dropout'], num_heads=hparams['num_heads'],
             use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
-            use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True)
+            use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True),
+            use_gated_attn=hparams.get('use_gated_attn', False), use_qk_norm=hparams.get('use_qk_norm', False)
         )
         self.pitch_embed = Linear(1, hparams['hidden_size'])
diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py
index cc840aed..d01ad8e7 100644
--- a/modules/fastspeech/tts_modules.py
+++ b/modules/fastspeech/tts_modules.py
@@ -12,13 +12,15 @@
 
 
 class TransformerEncoderLayer(nn.Module):
-    def __init__(self, hidden_size, dropout, kernel_size=None, act='gelu', num_heads=2, rotary_embed=None):
+    def __init__(self, hidden_size, dropout, kernel_size=None, act='gelu', num_heads=2, rotary_embed=None,
+                 use_gated_attn=False, use_qk_norm=False):
         super().__init__()
         self.op = EncSALayer(
             hidden_size, num_heads, dropout=dropout,
             attention_dropout=0.0, relu_dropout=dropout,
             kernel_size=kernel_size,
-            act=act, rotary_embed=rotary_embed
+            act=act, rotary_embed=rotary_embed,
+            use_gated_attn=use_gated_attn, use_qk_norm=use_qk_norm
         )
 
     def forward(self, x, **kwargs):
@@ -369,7 +371,8 @@ def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
 class FastSpeech2Encoder(nn.Module):
     def __init__(self, hidden_size, num_layers,
                  ffn_kernel_size=9, ffn_act='gelu',
-                 dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False, rope_interleaved=True):
+                 dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False, rope_interleaved=True,
+                 use_gated_attn=False, use_qk_norm=False):
         super().__init__()
         self.num_layers = num_layers
         embed_dim = self.hidden_size = hidden_size
@@ -383,7 +386,8 @@ def __init__(self, hidden_size, num_layers,
             TransformerEncoderLayer(
                 self.hidden_size, self.dropout,
                 kernel_size=ffn_kernel_size, act=ffn_act,
-                num_heads=num_heads, rotary_embed=rotary_embed
+                num_heads=num_heads, rotary_embed=rotary_embed,
+                use_gated_attn=use_gated_attn, use_qk_norm=use_qk_norm
             )
             for _ in range(self.num_layers)
         ])
diff --git a/modules/fastspeech/variance_encoder.py b/modules/fastspeech/variance_encoder.py
index ba6994c1..76fa9b9f 100644
--- a/modules/fastspeech/variance_encoder.py
+++ b/modules/fastspeech/variance_encoder.py
@@ -33,7 +33,8 @@ def __init__(self, vocab_size):
             ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
             dropout=hparams['dropout'], num_heads=hparams['num_heads'],
             use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
-            use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True)
+            use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True),
+            use_gated_attn=hparams.get('use_gated_attn', False), use_qk_norm=hparams.get('use_qk_norm', False)
         )
 
         dur_hparams = hparams['dur_prediction_args']
@@ -127,7 +128,8 @@ def get_hparam(key):
             ffn_kernel_size=get_hparam('enc_ffn_kernel_size'), ffn_act=get_hparam('ffn_act'),
             dropout=get_hparam('dropout'), num_heads=get_hparam('num_heads'),
             use_pos_embed=get_hparam('use_pos_embed'), rel_pos=get_hparam('rel_pos'),
-            use_rope=get_hparam('use_rope'), rope_interleaved=hparams.get('rope_interleaved', True)
+            use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True),
+            use_gated_attn=hparams.get('use_gated_attn', False), use_qk_norm=hparams.get('use_qk_norm', False)
         )
         self.out_proj = Linear(hidden_size, hparams['hidden_size'])
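
Reviewer note on the gated-attention path above: the gate follows Eq. (5) of arXiv:2505.06708, Y' = Y ⊙ σ(X W_θ), computed from the layer input x and multiplied into the concatenated heads before out_proj; AtanSigmoid is kept as a commented-out alternative to the sigmoid gate. Below is a minimal, self-contained sketch of that output path together with QK-Norm. It is not the repo's MultiheadSelfAttentionWithRoPE (it omits RoPE, attention dropout, and key-padding masking), and the class and variable names are illustrative only.

import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedSelfAttentionSketch(nn.Module):
    """Illustrative sketch: output-gated self-attention with optional per-head QK-Norm."""

    def __init__(self, embed_dim, num_heads, use_gated_attn=True, use_qk_norm=True):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.use_qk_norm = use_qk_norm
        if use_qk_norm:
            # LayerNorm over the per-head channel dimension, mirroring the patch
            self.q_norm = nn.LayerNorm(self.head_dim)
            self.k_norm = nn.LayerNorm(self.head_dim)
        self.use_gated_attn = use_gated_attn
        if use_gated_attn:
            # Gate is computed from the layer input x and applied to the concatenated heads
            self.gate_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):  # x: (B, L, C)
        B, L, C = x.shape
        q, k, v = self.in_proj(x).chunk(3, dim=-1)
        q = q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)  # (B, H, L, D)
        k = k.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        if self.use_qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)
        y = F.scaled_dot_product_attention(q, k, v)        # (B, H, L, D)
        y = y.transpose(1, 2).contiguous().view(B, L, C)   # concat heads -> (B, L, C)
        if self.use_gated_attn:
            y = y * torch.sigmoid(self.gate_proj(x))       # Eq. (5): Y' = Y ⊙ σ(X W_θ)
        return self.out_proj(y)


if __name__ == "__main__":
    attn = GatedSelfAttentionSketch(embed_dim=384, num_heads=4)
    print(attn(torch.randn(2, 16, 384)).shape)  # torch.Size([2, 16, 384])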