configs/acoustic.yaml (2 additions, 0 deletions)
@@ -65,6 +65,8 @@ max_beta: 0.02
 enc_ffn_kernel_size: 3
 use_rope: true
 rope_interleaved: false
+use_gated_attn: false
+use_qk_norm: false
 use_stretch_embed: true
 use_variance_scaling: true
 rel_pos: true
configs/templates/config_acoustic.yaml (2 additions, 0 deletions)
@@ -72,6 +72,8 @@ diffusion_type: reflow
 enc_ffn_kernel_size: 3
 use_rope: true
 rope_interleaved: false
+use_gated_attn: true
+use_qk_norm: true
 use_stretch_embed: true
 use_variance_scaling: true
 use_shallow_diffusion: true
configs/templates/config_variance.yaml (2 additions, 0 deletions)
@@ -66,6 +66,8 @@ tension_logit_max: 10.0
 enc_ffn_kernel_size: 3
 use_rope: true
 rope_interleaved: false
+use_gated_attn: true
+use_qk_norm: true
 use_stretch_embed: false
 use_variance_scaling: true
 hidden_size: 384
configs/variance.yaml (2 additions, 0 deletions)
@@ -37,6 +37,8 @@ predict_tension: false
 enc_ffn_kernel_size: 3
 use_rope: true
 rope_interleaved: false
+use_gated_attn: false
+use_qk_norm: false
 use_stretch_embed: false
 use_variance_scaling: true
 rel_pos: true
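All four configs gain the same two keys: `use_gated_attn` toggles a sigmoid output gate on the self-attention, and `use_qk_norm` toggles per-head normalization of queries and keys (both implemented in `modules/commons/common_layers.py` below). The encoders read the keys with `hparams.get(..., False)`, so configs that do not define them keep the previous attention behavior. A minimal sketch of that lookup pattern, where `hparams` is only a stand-in for the project's hyperparameter mapping:

```python
# Sketch only: `hparams` stands in for the project's global hyperparameter dict.
hparams = {
    "use_gated_attn": True,  # sigmoid output gating in self-attention
    "use_qk_norm": True,     # per-head LayerNorm on queries and keys
}

# Missing keys fall back to False, matching the defaults used in
# acoustic_encoder.py and variance_encoder.py further down.
use_gated_attn = hparams.get("use_gated_attn", False)
use_qk_norm = hparams.get("use_qk_norm", False)
print(use_gated_attn, use_qk_norm)  # True True
```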
modules/commons/common_layers.py (44 additions, 3 deletions)
@@ -211,8 +211,20 @@ def forward(self, x):
         return x
 
 
+class AtanSigmoid(nn.Module):
+    def __init__(self):
+        super(AtanSigmoid, self).__init__()
+        self.pi = math.pi
+        self.pi_half = math.pi / 2
+        self.inv_pi = 1.0 / math.pi
+
+    def forward(self, x):
+        return (torch.atan(x) + self.pi_half) * self.inv_pi
+
+
 class MultiheadSelfAttentionWithRoPE(nn.Module):
-    def __init__(self, embed_dim, num_heads, dropout=0.1, bias=False, rotary_embed=None):
+    def __init__(self, embed_dim, num_heads, dropout=0.1, bias=False, rotary_embed=None,
+                 use_gated_attn=False, use_qk_norm=False):
         super().__init__()
         assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by number of heads"
 
@@ -223,6 +235,12 @@ def __init__(self, embed_dim, num_heads, dropout=0.1, bias=False, rotary_embed=N
         # Linear layers for Q, K, V projections
         self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias=bias)
 
+        # refer to Qwen 3
+        self.use_qk_norm = use_qk_norm
+        if self.use_qk_norm:
+            self.q_norm = LayerNorm(embed_dim // num_heads)
+            self.k_norm = LayerNorm(embed_dim // num_heads)
+
         # Final linear layer after concatenation
         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
 
@@ -232,6 +250,16 @@ def __init__(self, embed_dim, num_heads, dropout=0.1, bias=False, rotary_embed=N
         # Rotary Embeddings
         self.rotary_embed = rotary_embed
 
+        # refer to NIPS 2025 best paper: "Gated Attention for Large Language Models: Non-linearity, Sparsity, and Attention-Sink-Free"
+        # (arxiv: 2505.06708)
+        self.use_gated_attn = use_gated_attn
+        if self.use_gated_attn:
+            self.gate_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+            # self.atan_sigmoid = AtanSigmoid()
+            nn.init.xavier_uniform_(self.gate_proj.weight)
+            if bias:
+                nn.init.constant_(self.gate_proj.bias, 0.0)
+
         # Initialization parameters
         nn.init.xavier_uniform_(self.in_proj.weight)
         nn.init.xavier_uniform_(self.out_proj.weight)
@@ -252,6 +280,11 @@ def forward(self, x, key_padding_mask=None):
         K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, L, D)
         V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # (B, H, L, D)
 
+        # Query-Key Normalization
+        if self.use_qk_norm:
+            Q = self.q_norm(Q)
+            K = self.k_norm(K)
+
         # Apply RoPE
         if self.rotary_embed is not None:
             Q = self.rotary_embed.rotate_queries_or_keys(Q)
@@ -276,6 +309,12 @@ def __init__(self, embed_dim, num_heads, dropout=0.1, bias=False, rotary_embed=N
         # Reshape and concatenate heads
         attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim) # (B, L, C)
 
+        if self.use_gated_attn:
+            # Formula (5): Y' = Y ⊙ σ(XW_θ)
+            # gate_score = self.atan_sigmoid(self.gate_proj(x)) # (B, L, C)
+            gate_score = torch.sigmoid(self.gate_proj(x)) # (B, L, C)
+            attn_output = attn_output * gate_score
+
         # Final linear projection
         output = self.out_proj(attn_output) # (B, L, C)
 
@@ -284,7 +323,8 @@ def forward(self, x, key_padding_mask=None):
 
 class EncSALayer(nn.Module):
     def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
-                 relu_dropout=0.1, kernel_size=9, act='gelu', rotary_embed=None):
+                 relu_dropout=0.1, kernel_size=9, act='gelu', rotary_embed=None,
+                 use_gated_attn=False, use_qk_norm=False):
         super().__init__()
         self.dropout = dropout
         self.layer_norm1 = LayerNorm(c)
@@ -295,7 +335,8 @@ def __init__(self, c, num_heads, dropout, attention_dropout=0.1,
             self.use_rope = False
         else:
             self.self_attn = MultiheadSelfAttentionWithRoPE(
-                c, num_heads, dropout=attention_dropout, bias=False, rotary_embed=rotary_embed
+                c, num_heads, dropout=attention_dropout, bias=False, rotary_embed=rotary_embed,
+                use_gated_attn=use_gated_attn, use_qk_norm=use_qk_norm
             )
             self.use_rope = True
         self.layer_norm2 = LayerNorm(c)
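For reference, the two mechanisms added above reduce to a small self-contained sketch: per-head normalization of queries and keys before attention (the Qwen3-style QK-Norm referenced in the comment), and element-wise output gating Y' = Y ⊙ σ(XW_θ) from "Gated Attention for Large Language Models" (arXiv:2505.06708), formula (5). This is an illustration under simplifying assumptions, not the module above: it uses plain `nn.LayerNorm` and `F.scaled_dot_product_attention` and omits RoPE, dropout and padding masks. The commented-out `AtanSigmoid`, (atan(x) + π/2) / π, would be a drop-in replacement for `torch.sigmoid` in the gate.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedQKNormSelfAttention(nn.Module):
    """Illustrative sketch: head-wise QK-Norm plus sigmoid output gating."""

    def __init__(self, embed_dim, num_heads, use_gated_attn=True, use_qk_norm=True):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.in_proj = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.use_qk_norm = use_qk_norm
        if use_qk_norm:
            # Normalize each head's query/key vectors over the head dimension.
            self.q_norm = nn.LayerNorm(self.head_dim)
            self.k_norm = nn.LayerNorm(self.head_dim)
        self.use_gated_attn = use_gated_attn
        if use_gated_attn:
            # Gate computed from the attention input x: Y' = Y * sigmoid(x @ W_theta).
            self.gate_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):  # x: (B, L, C)
        B, L, C = x.shape
        q, k, v = self.in_proj(x).chunk(3, dim=-1)
        # Split heads: (B, L, C) -> (B, H, L, D)
        q = q.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        if self.use_qk_norm:
            q, k = self.q_norm(q), self.k_norm(k)
        y = F.scaled_dot_product_attention(q, k, v)    # (B, H, L, D)
        y = y.transpose(1, 2).reshape(B, L, C)         # merge heads
        if self.use_gated_attn:
            y = y * torch.sigmoid(self.gate_proj(x))   # element-wise output gate
        return self.out_proj(y)


x = torch.randn(2, 16, 256)
print(GatedQKNormSelfAttention(embed_dim=256, num_heads=4)(x).shape)  # torch.Size([2, 16, 256])
```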
modules/fastspeech/acoustic_encoder.py (2 additions, 1 deletion)
@@ -38,7 +38,8 @@ def __init__(self, vocab_size):
             ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
             dropout=hparams['dropout'], num_heads=hparams['num_heads'],
             use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
-            use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True)
+            use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True),
+            use_gated_attn=hparams.get('use_gated_attn', False), use_qk_norm=hparams.get('use_qk_norm', False)
         )
 
         self.pitch_embed = Linear(1, hparams['hidden_size'])
modules/fastspeech/tts_modules.py (8 additions, 4 deletions)
@@ -12,13 +12,15 @@
 
 
 class TransformerEncoderLayer(nn.Module):
-    def __init__(self, hidden_size, dropout, kernel_size=None, act='gelu', num_heads=2, rotary_embed=None):
+    def __init__(self, hidden_size, dropout, kernel_size=None, act='gelu', num_heads=2, rotary_embed=None,
+                 use_gated_attn=False, use_qk_norm=False):
         super().__init__()
         self.op = EncSALayer(
             hidden_size, num_heads, dropout=dropout,
             attention_dropout=0.0, relu_dropout=dropout,
             kernel_size=kernel_size,
-            act=act, rotary_embed=rotary_embed
+            act=act, rotary_embed=rotary_embed,
+            use_gated_attn=use_gated_attn, use_qk_norm=use_qk_norm
         )
 
     def forward(self, x, **kwargs):
@@ -369,7 +371,8 @@ def mel2ph_to_dur(mel2ph, T_txt, max_dur=None):
 class FastSpeech2Encoder(nn.Module):
     def __init__(self, hidden_size, num_layers,
                  ffn_kernel_size=9, ffn_act='gelu',
-                 dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False, rope_interleaved=True):
+                 dropout=None, num_heads=2, use_pos_embed=True, rel_pos=True, use_rope=False, rope_interleaved=True,
+                 use_gated_attn=False, use_qk_norm=False):
         super().__init__()
         self.num_layers = num_layers
         embed_dim = self.hidden_size = hidden_size
@@ -383,7 +386,8 @@ def __init__(self, hidden_size, num_layers,
             TransformerEncoderLayer(
                 self.hidden_size, self.dropout,
                 kernel_size=ffn_kernel_size, act=ffn_act,
-                num_heads=num_heads, rotary_embed=rotary_embed
+                num_heads=num_heads, rotary_embed=rotary_embed,
+                use_gated_attn=use_gated_attn, use_qk_norm=use_qk_norm
             )
             for _ in range(self.num_layers)
         ])
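The changes to tts_modules.py and the two encoder files are plumbing: each constructor forwards `use_gated_attn` and `use_qk_norm` one level down, from `FastSpeech2Encoder` through `TransformerEncoderLayer` and `EncSALayer` into `MultiheadSelfAttentionWithRoPE`. A toy version of the pattern (the class names below are stand-ins, not the real modules) shows why defaulting the new keyword arguments to `False` at every level keeps existing call sites and old configs working unchanged:

```python
class ToyAttention:
    def __init__(self, dim, use_gated_attn=False, use_qk_norm=False):
        self.use_gated_attn = use_gated_attn
        self.use_qk_norm = use_qk_norm


class ToyEncoderLayer:
    def __init__(self, dim, use_gated_attn=False, use_qk_norm=False):
        # Forward the flags one level down, as EncSALayer does above.
        self.attn = ToyAttention(dim, use_gated_attn=use_gated_attn, use_qk_norm=use_qk_norm)


class ToyEncoder:
    def __init__(self, dim, num_layers, use_gated_attn=False, use_qk_norm=False):
        self.layers = [
            ToyEncoderLayer(dim, use_gated_attn=use_gated_attn, use_qk_norm=use_qk_norm)
            for _ in range(num_layers)
        ]


enc = ToyEncoder(256, num_layers=4, use_gated_attn=True, use_qk_norm=True)
print(enc.layers[0].attn.use_gated_attn)                         # True
print(ToyEncoder(256, num_layers=4).layers[0].attn.use_qk_norm)  # False (defaults preserved)
```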
modules/fastspeech/variance_encoder.py (4 additions, 2 deletions)
@@ -33,7 +33,8 @@ def __init__(self, vocab_size):
             ffn_kernel_size=hparams['enc_ffn_kernel_size'], ffn_act=hparams['ffn_act'],
             dropout=hparams['dropout'], num_heads=hparams['num_heads'],
             use_pos_embed=hparams['use_pos_embed'], rel_pos=hparams.get('rel_pos', False),
-            use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True)
+            use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True),
+            use_gated_attn=hparams.get('use_gated_attn', False), use_qk_norm=hparams.get('use_qk_norm', False)
         )
 
         dur_hparams = hparams['dur_prediction_args']
@@ -127,7 +128,8 @@ def get_hparam(key):
             ffn_kernel_size=get_hparam('enc_ffn_kernel_size'), ffn_act=get_hparam('ffn_act'),
             dropout=get_hparam('dropout'), num_heads=get_hparam('num_heads'),
             use_pos_embed=get_hparam('use_pos_embed'), rel_pos=get_hparam('rel_pos'),
-            use_rope=get_hparam('use_rope'), rope_interleaved=hparams.get('rope_interleaved', True)
+            use_rope=hparams.get('use_rope', False), rope_interleaved=hparams.get('rope_interleaved', True),
+            use_gated_attn=hparams.get('use_gated_attn', False), use_qk_norm=hparams.get('use_qk_norm', False)
         )
         self.out_proj = Linear(hidden_size, hparams['hidden_size'])
 