Commit fdf5e1d

[!205][RELEASE] Switch to released pangolinn (ACL 2024)
Now that pangolinn has been released and the related paper has been accepted at ACL 2024, the existing unit tests should use pangolinn and the paper reference should be updated. This commit refactors the Conformer and Hyena unit tests to rely on pangolinn and updates the reference to the related paper.
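
For context, the pangolinn-based tests follow a single pattern, visible in the diff below: the module under test is wrapped in a `seq2seq.PangolinnSeq2SeqModuleWrapper` subclass that tells pangolinn how to build and call it, and a `seq2seq.EncoderPaddingTestCase` subclass points at that wrapper to obtain the generated padding and batch-invariance checks. A minimal sketch of the pattern, assuming only the wrapper API as it is used in this commit (the `nn.Linear` "encoder" is a stand-in for illustration, not part of the commit):

import unittest

from torch import nn, Tensor, LongTensor

from pangolinn import seq2seq


class ToyEncoderPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper):
    def build_module(self) -> nn.Module:
        # tell pangolinn how to instantiate the module under test
        return nn.Linear(self.num_input_channels, self.num_input_channels)

    @property
    def num_input_channels(self) -> int:
        # feature size of the batch-first inputs pangolinn generates
        return 8

    def forward(self, x: Tensor, lengths: LongTensor) -> Tensor:
        # adapt pangolinn's (x, lengths) call to the module's own interface
        return self._module(x)


class ToyEncoderPaddingTestCase(seq2seq.EncoderPaddingTestCase):
    # the generated tests check that padded positions do not leak into the
    # output and that results do not depend on how items are batched together
    module_wrapper_class = ToyEncoderPangolinnWrapper


if __name__ == '__main__':
    unittest.main()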
1 parent e1e04bb commit fdf5e1d

6 files changed: +227 -226 lines

README.md (+1 -1)

@@ -5,6 +5,7 @@ Dedicated README for each work can be found in the `fbk_works` directory.
 
 ### 2024
 
+- [[ACL 2024] **When Good and Reproducible Results are a Giant with Feet of Clay: The Importance of Software Quality in NLP**](fbk_works/BUGFREE_CONFORMER.md)
 - [[LREC-COLING 2024] **How do Hyenas deal with Human Speech? Speech Recognition and Translation with ConfHyena**](fbk_works/HYENA_COLING2024.md)
 
 ### 2023
@@ -18,7 +19,6 @@ Dedicated README for each work can be found in the `fbk_works` directory.
 - [[INTERSPEECH 2023] **Joint Speech Translation and Named Entity Recognition**](fbk_works/JOINT_ST_NER2023.md)
 - [[ACL 2023] **Attention as a Guide for Simultaneous Speech Translation**](fbk_works/EDATT_SIMULST_AGENT_ACL2023.md)
 - [[IWSLT 2023] **Direct Models for Simultaneous Translation and Automatic Subtitling: FBK@IWSLT2023**](fbk_works/IWSLT_2023.md)
-- [**Reproducibility is Nothing Without Correctness: The Importance of Testing Code in NLP**](fbk_works/BUGFREE_CONFORMER.md)
 
 ### 2022
 
fbk_uts/conformer/test_conformer_encoder.py (+92 -101)

@@ -12,19 +12,108 @@
 # See the License for the specific language governing permissions and
 # limitations under the License
 import copy
-import math
 import unittest
 from argparse import Namespace
 
-import torch
-from torch import nn
+from torch import nn, Tensor, LongTensor
 
 from examples.speech_to_text.models.conformer import conformer_s, ConformerEncoder
 from examples.speech_to_text.modules.conformer_attention import MultiHeadedSelfAttentionModule
 from examples.speech_to_text.modules.conformer_encoder_layer import ConformerEncoderLayer
 from fairseq.data import Dictionary
 from fairseq.data.data_utils import lengths_to_padding_mask
 
+from pangolinn import seq2seq
+
+
+class MultiHeadedSelfAttentionPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper):
+    def build_module(self) -> nn.Module:
+        return MultiHeadedSelfAttentionModule(self.num_input_channels, 2)
+
+    @property
+    def num_input_channels(self) -> int:
+        return 8
+
+    def forward(self, x: Tensor, lengths: LongTensor) -> Tensor:
+        return self._module(x, lengths_to_padding_mask(lengths))
+
+
+class ConformerEncoderLayerPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper):
+    def build_module(self) -> nn.Module:
+        base_args = Namespace()
+        base_args.input_feat_per_channel = self.num_input_channels
+        base_args.input_channels = 1
+        base_args.max_source_positions = 10
+        base_args.no_syncbatchnorm = True
+        base_args.encoder_embed_dim = 8
+        conformer_s(base_args)
+        return ConformerEncoderLayer(base_args)
+
+    @property
+    def num_input_channels(self) -> int:
+        return 8
+
+    def forward(self, x: Tensor, lengths: LongTensor) -> Tensor:
+        return self._module(x.transpose(0, 1), lengths_to_padding_mask(lengths)).transpose(0, 1)
+
+
+class ConformerEncoderPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper):
+    def base_args(self) -> Namespace:
+        base_args = Namespace()
+        base_args.input_feat_per_channel = self.num_input_channels
+        base_args.input_channels = 1
+        base_args.max_source_positions = 10
+        base_args.no_syncbatchnorm = True
+        base_args.encoder_embed_dim = 8
+        base_args.encoder_layers = 3
+        base_args.criterion = "ctc_multi_loss"
+        base_args.ctc_compress_strategy = "none"
+        base_args.ctc_encoder_layer = 2
+        conformer_s(base_args)
+        return base_args
+
+    def build_module(self) -> nn.Module:
+        return ConformerEncoder(self.base_args(), Dictionary())
+
+    @property
+    def num_input_channels(self) -> int:
+        return 8
+
+    @property
+    def sequence_downsampling_factor(self) -> int:
+        # the two initial Conv1D reduce sequence length by a factor of 4
+        return 4
+
+    def forward(self, x: Tensor, lengths: LongTensor) -> Tensor:
+        return self._module(x, lengths)["encoder_out"][0].transpose(0, 1)
+
+
+class ConformerEncoderUnsafePangolinnWrapper(ConformerEncoderPangolinnWrapper):
+    def base_args(self) -> Namespace:
+        args = super().base_args()
+        args.batch_unsafe_relative_shift = True
+        return args
+
+
+class MultiHeadedSelfAttentionTestCase(seq2seq.EncoderPaddingTestCase):
+    module_wrapper_class = MultiHeadedSelfAttentionPangolinnWrapper
+
+
+class ConformerEncoderLayerPaddingTestCase(seq2seq.EncoderPaddingTestCase):
+    module_wrapper_class = ConformerEncoderLayerPangolinnWrapper
+
+
+class ConformerEncoderPaddingTestCase(seq2seq.EncoderPaddingTestCase):
+    module_wrapper_class = ConformerEncoderPangolinnWrapper
+
+
+class ConformerEncoderUnsafePaddingTestCase(seq2seq.EncoderPaddingTestCase):
+    module_wrapper_class = ConformerEncoderUnsafePangolinnWrapper
+
+    def test_batch_size_does_not_matter(self):
+        with self.assertRaises(AssertionError):
+            super().test_batch_size_does_not_matter()
+
 
 class ConformerEncoderTestCase(unittest.TestCase):
     @classmethod
@@ -59,104 +148,6 @@ def check_norm(self, args, norm_class):
         self.assertTrue(
             isinstance(encoder._modules["conformer_layers"][layer].conv_module.batchnorm, norm_class))
 
-    def test_conformer_encoder_layer_padding(self):
-        batchnorm_args = copy.deepcopy(self.base_args)
-        batchnorm_args.no_syncbatchnorm = True
-        batchnorm_args.encoder_embed_dim = 8
-        fake_sample = torch.rand(2, 10, 8)
-        fake_sample[1, 3:, :] = 0
-        fake_lengths = torch.LongTensor([10, 3])
-        padding_mask = lengths_to_padding_mask(fake_lengths)
-        encoder_layer = ConformerEncoderLayer(batchnorm_args)
-        encoder_layer.eval()
-        out = encoder_layer(fake_sample.transpose(0, 1), padding_mask).transpose(0, 1)
-        self.assertTrue(
-            torch.all(out[1, 3:, :] == 0.0), f"non-zero entries in {out[1, 3:, :]}")
-
-    def test_encoder_padding(self):
-        batchnorm_args = copy.deepcopy(self.base_args)
-        batchnorm_args.no_syncbatchnorm = True
-        batchnorm_args.encoder_embed_dim = 8
-        batchnorm_args.input_feat_per_channel = 8
-        batchnorm_args.encoder_layers = 3
-        fake_sample = torch.rand(2, 27, 8)
-        fake_sample[1, 13:, :] = 0
-        fake_lengths = torch.LongTensor([27, 13])
-        encoder = ConformerEncoder(batchnorm_args, self.fake_dict)
-        encoder.eval()
-        net_out = encoder.forward(fake_sample, fake_lengths, return_all_hiddens=True)
-        padding_area = net_out["encoder_out"][0][4:, 1, :]  # output is N x B x C and downsampled by 4
-        self.assertGreater(padding_area.numel(), 0)
-        self.assertTrue(torch.all(padding_area == 0.0), f"non-zero entries in {padding_area}")
-
-    def test_multihead_selfattn(self):
-        batchnorm_args = copy.deepcopy(self.base_args)
-        batchnorm_args.no_syncbatchnorm = True
-        batchnorm_args.encoder_embed_dim = 8
-        fake_sample = torch.rand(2, 10, 8)
-        fake_sample[1, 3:, :] = 0
-        fake_lengths = torch.LongTensor([10, 3])
-        padding_mask = lengths_to_padding_mask(fake_lengths)
-        fake_sample2 = fake_sample[1:, :3, :]
-        padding_mask2 = lengths_to_padding_mask(fake_lengths[1].unsqueeze(0))
-        attn = MultiHeadedSelfAttentionModule(8, 4)
-        attn.eval()
-        attn_out = attn(fake_sample, padding_mask)
-        attn_out2 = attn(fake_sample2, padding_mask2)
-        torch.testing.assert_allclose(attn_out[1, :3, :], attn_out2[0])
-        self.assertTrue(
-            torch.all(attn_out[1, 3:, :] == 0.0), f"non-zero entries in {attn_out[1, 3:, :]}")
-
-    def test_encoder_batch(self):
-        batchnorm_args = copy.deepcopy(self.base_args)
-        batchnorm_args.no_syncbatchnorm = True
-        batchnorm_args.encoder_embed_dim = 8
-        batchnorm_args.input_feat_per_channel = 8
-        batchnorm_args.encoder_layers = 3
-        fake_sample = torch.rand(5, 27, 8)
-        fake_sample[1, 13:, :] = 0
-        fake_sample[2, 8:, :] = 0
-        fake_sample[3, 8:, :] = 0
-        fake_sample[4, 5:, :] = 0
-        fake_lengths = torch.LongTensor([27, 13, 8, 8, 5])
-        encoder = ConformerEncoder(batchnorm_args, self.fake_dict)
-        encoder.eval()
-        net_out = encoder.forward(fake_sample, fake_lengths, return_all_hiddens=True)
-
-        def test_item(item_idx):
-            item_len = fake_lengths[item_idx].item()
-            item_out_len = math.ceil(item_len / 4)
-            fake_sample2 = fake_sample[item_idx, :item_len, :]
-            net_out2 = encoder.forward(
-                fake_sample2.unsqueeze(0), fake_lengths[item_idx].unsqueeze(0), return_all_hiddens=True)
-            torch.testing.assert_allclose(
-                net_out["encoder_out"][0][:item_out_len, item_idx, :],
-                net_out2["encoder_out"][0][:, 0, :])
-
-        for i in range(5):
-            test_item(i)
-
-    def test_encoder_batch_unsafe_fails(self):
-        batchnorm_args = copy.deepcopy(self.base_args)
-        batchnorm_args.no_syncbatchnorm = True
-        batchnorm_args.encoder_embed_dim = 8
-        batchnorm_args.input_feat_per_channel = 8
-        batchnorm_args.encoder_layers = 3
-        batchnorm_args.batch_unsafe_relative_shift = True
-        fake_sample = torch.rand(2, 27, 8)
-        fake_sample[1, 13:, :] = 0
-        fake_lengths = torch.LongTensor([27, 13])
-        encoder = ConformerEncoder(batchnorm_args, self.fake_dict)
-        encoder.eval()
-        net_out = encoder.forward(fake_sample, fake_lengths, return_all_hiddens=True)
-        fake_sample2 = fake_sample[1, :13, :]
-        net_out2 = encoder.forward(fake_sample2.unsqueeze(0), fake_lengths[1].unsqueeze(0), return_all_hiddens=True)
-        with self.assertRaises(AssertionError) as ae:
-            torch.testing.assert_allclose(
-                net_out["encoder_out"][0][:4, 1, :],
-                net_out2["encoder_out"][0][:, 0, :])
-        self.assertTrue("Tensor-likes are not close!" in str(ae.exception))
-
 
 if __name__ == '__main__':
     unittest.main()
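
Note that ConformerEncoderUnsafePaddingTestCase deliberately inverts one of the generated checks: with `batch_unsafe_relative_shift = True` the relative shift is, as the flag name suggests, not batch-safe, so the encoder output for an item changes with the amount of padding in the batch and `test_batch_size_does_not_matter` is expected to raise an AssertionError (mirroring the removed `test_encoder_batch_unsafe_fails`). Since the `unittest.main()` entry point is kept, the refactored suite can still be run directly; a typical invocation (assuming the repository root is on PYTHONPATH) would be:

python fbk_uts/conformer/test_conformer_encoder.py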
