Prevent large values in conv module
danpovey committed Apr 13, 2024
1 parent ed6bc20 commit 0eccb2b
Showing 2 changed files with 8 additions and 1 deletion.
7 changes: 7 additions & 0 deletions egs/librispeech/SSL/zipformer/wav2vec2_module.py
@@ -22,12 +22,15 @@
 from typing import List, Tuple

 import numpy as np
+import random
+from scaling import penalize_abs_values_gt
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from utils import Fp32GroupNorm, Fp32LayerNorm, TransposeLast

+

 class ConvFeatureExtractionModel(nn.Module):
     def __init__(
         self,
@@ -105,4 +108,8 @@ def forward(self, x):
         for conv in self.conv_layers:
             x = conv(x)

+        if self.training and random.random() < 0.2:
+            x = penalize_abs_values_gt(x, limit=1000.0, penalty=1.0e-05,
+                                       name=(self.name if hasattr(self, 'name') else 'ConvFeatureExtractionModel'))
+
         return x
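For context (not part of the diff): penalize_abs_values_gt comes from the local scaling module. As used here, it acts as an identity in the forward pass and only adds an extra gradient term in the backward pass for elements whose absolute value exceeds the limit, so conv-module activations above 1000 get pushed back toward range on roughly 20% of training batches while inference is untouched. Below is a minimal sketch of that idea, assuming a custom autograd function; it is an illustration only, not the actual scaling.py implementation, and it omits the name argument (which is presumably used for diagnostics).

import torch

class _PenalizeAbsGtSketch(torch.autograd.Function):
    # Identity in the forward pass; in the backward pass, add
    # penalty * d/dx sum(relu(|x| - limit)), i.e. a small push back toward
    # [-limit, limit] for every element whose magnitude exceeds the limit.
    @staticmethod
    def forward(ctx, x, limit, penalty):
        ctx.save_for_backward(x)
        ctx.limit = limit
        ctx.penalty = penalty
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        extra_grad = ctx.penalty * torch.sign(x) * (x.abs() > ctx.limit).to(x.dtype)
        return grad_output + extra_grad, None, None

def penalize_abs_values_gt_sketch(x, limit=1000.0, penalty=1.0e-05):
    return _PenalizeAbsGtSketch.apply(x, limit, penalty)

Gating the call with random.random() < 0.2 keeps the extra backward-pass cost limited to about one training batch in five.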
2 changes: 1 addition & 1 deletion egs/librispeech/SSL/zipformer/zipformer.py
@@ -789,7 +789,7 @@ def forward(
         selected_attn_weights = attn_weights[0:1]
         if torch.jit.is_scripting() or torch.jit.is_tracing():
             pass
-        elif not self.training and random.random() < float(self.const_attention_rate):
+        elif self.training and random.random() < float(self.const_attention_rate):
             # Make attention weights constant. The intention is to
             # encourage these modules to do something similar to an
             # averaging-over-time operation.
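The second change fixes an inverted condition: with "not self.training", the constant-attention substitution fired during evaluation rather than training, so the intended regularization never applied while the model was learning. After the fix it applies only on a random fraction (const_attention_rate) of training batches. The body of this branch is not shown in the hunk, but "averaging-over-time" plausibly means replacing the selected attention weights with a uniform distribution over the positions that currently receive any weight, roughly as in this hypothetical sketch (function name and details are illustrative, not taken from zipformer.py):

import torch

def average_over_time_sketch(selected_attn_weights: torch.Tensor) -> torch.Tensor:
    # Keep the mask of key positions that receive non-zero attention, then
    # renormalize so every such position gets equal weight: the attention head
    # degenerates into a simple average over time.
    mask = (selected_attn_weights > 0.0).to(selected_attn_weights.dtype)
    return mask / mask.sum(dim=-1, keepdim=True).clamp(min=1.0)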
