# See the License for the specific language governing permissions and
# limitations under the License
import copy
-import math
import unittest
from argparse import Namespace

-import torch
-from torch import nn
+from torch import nn, Tensor, LongTensor

from examples.speech_to_text.models.conformer import conformer_s, ConformerEncoder
from examples.speech_to_text.modules.conformer_attention import MultiHeadedSelfAttentionModule
from examples.speech_to_text.modules.conformer_encoder_layer import ConformerEncoderLayer
from fairseq.data import Dictionary
from fairseq.data.data_utils import lengths_to_padding_mask

+from pangolinn import seq2seq
+
+
+class MultiHeadedSelfAttentionPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper):
+    def build_module(self) -> nn.Module:
+        return MultiHeadedSelfAttentionModule(self.num_input_channels, 2)
+
+    @property
+    def num_input_channels(self) -> int:
+        return 8
+
+    def forward(self, x: Tensor, lengths: LongTensor) -> Tensor:
+        return self._module(x, lengths_to_padding_mask(lengths))
+
+
+class ConformerEncoderLayerPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper):
+    def build_module(self) -> nn.Module:
+        base_args = Namespace()
+        base_args.input_feat_per_channel = self.num_input_channels
+        base_args.input_channels = 1
+        base_args.max_source_positions = 10
+        base_args.no_syncbatchnorm = True
+        base_args.encoder_embed_dim = 8
+        conformer_s(base_args)
+        return ConformerEncoderLayer(base_args)
+
+    @property
+    def num_input_channels(self) -> int:
+        return 8
+
+    def forward(self, x: Tensor, lengths: LongTensor) -> Tensor:
+        return self._module(x.transpose(0, 1), lengths_to_padding_mask(lengths)).transpose(0, 1)
+
+
+class ConformerEncoderPangolinnWrapper(seq2seq.PangolinnSeq2SeqModuleWrapper):
+    def base_args(self) -> Namespace:
+        base_args = Namespace()
+        base_args.input_feat_per_channel = self.num_input_channels
+        base_args.input_channels = 1
+        base_args.max_source_positions = 10
+        base_args.no_syncbatchnorm = True
+        base_args.encoder_embed_dim = 8
+        base_args.encoder_layers = 3
+        base_args.criterion = "ctc_multi_loss"
+        base_args.ctc_compress_strategy = "none"
+        base_args.ctc_encoder_layer = 2
+        conformer_s(base_args)
+        return base_args
+
+    def build_module(self) -> nn.Module:
+        return ConformerEncoder(self.base_args(), Dictionary())
+
+    @property
+    def num_input_channels(self) -> int:
+        return 8
+
+    @property
+    def sequence_downsampling_factor(self) -> int:
+        # the two initial Conv1D layers reduce the sequence length by a factor of 4
+        return 4
+
+    def forward(self, x: Tensor, lengths: LongTensor) -> Tensor:
+        return self._module(x, lengths)["encoder_out"][0].transpose(0, 1)
+
+
+class ConformerEncoderUnsafePangolinnWrapper(ConformerEncoderPangolinnWrapper):
+    def base_args(self) -> Namespace:
+        args = super().base_args()
+        args.batch_unsafe_relative_shift = True
+        return args
+
+
+class MultiHeadedSelfAttentionTestCase(seq2seq.EncoderPaddingTestCase):
+    module_wrapper_class = MultiHeadedSelfAttentionPangolinnWrapper
+
+
+class ConformerEncoderLayerPaddingTestCase(seq2seq.EncoderPaddingTestCase):
+    module_wrapper_class = ConformerEncoderLayerPangolinnWrapper
+
+
+class ConformerEncoderPaddingTestCase(seq2seq.EncoderPaddingTestCase):
+    module_wrapper_class = ConformerEncoderPangolinnWrapper
+
+
+class ConformerEncoderUnsafePaddingTestCase(seq2seq.EncoderPaddingTestCase):
+    module_wrapper_class = ConformerEncoderUnsafePangolinnWrapper
+
+    def test_batch_size_does_not_matter(self):
+        with self.assertRaises(AssertionError):
+            super().test_batch_size_does_not_matter()
+

class ConformerEncoderTestCase(unittest.TestCase):
    @classmethod
@@ -59,104 +148,6 @@ def check_norm(self, args, norm_class):
        self.assertTrue(
            isinstance(encoder._modules["conformer_layers"][layer].conv_module.batchnorm, norm_class))

-    def test_conformer_encoder_layer_padding(self):
-        batchnorm_args = copy.deepcopy(self.base_args)
-        batchnorm_args.no_syncbatchnorm = True
-        batchnorm_args.encoder_embed_dim = 8
-        fake_sample = torch.rand(2, 10, 8)
-        fake_sample[1, 3:, :] = 0
-        fake_lengths = torch.LongTensor([10, 3])
-        padding_mask = lengths_to_padding_mask(fake_lengths)
-        encoder_layer = ConformerEncoderLayer(batchnorm_args)
-        encoder_layer.eval()
-        out = encoder_layer(fake_sample.transpose(0, 1), padding_mask).transpose(0, 1)
-        self.assertTrue(
-            torch.all(out[1, 3:, :] == 0.0), f"non-zero entries in {out[1, 3:, :]}")
-
-    def test_encoder_padding(self):
-        batchnorm_args = copy.deepcopy(self.base_args)
-        batchnorm_args.no_syncbatchnorm = True
-        batchnorm_args.encoder_embed_dim = 8
-        batchnorm_args.input_feat_per_channel = 8
-        batchnorm_args.encoder_layers = 3
-        fake_sample = torch.rand(2, 27, 8)
-        fake_sample[1, 13:, :] = 0
-        fake_lengths = torch.LongTensor([27, 13])
-        encoder = ConformerEncoder(batchnorm_args, self.fake_dict)
-        encoder.eval()
-        net_out = encoder.forward(fake_sample, fake_lengths, return_all_hiddens=True)
-        padding_area = net_out["encoder_out"][0][4:, 1, :]  # output is N x B x C and downsampled by 4
-        self.assertGreater(padding_area.numel(), 0)
-        self.assertTrue(torch.all(padding_area == 0.0), f"non-zero entries in {padding_area}")
-
-    def test_multihead_selfattn(self):
-        batchnorm_args = copy.deepcopy(self.base_args)
-        batchnorm_args.no_syncbatchnorm = True
-        batchnorm_args.encoder_embed_dim = 8
-        fake_sample = torch.rand(2, 10, 8)
-        fake_sample[1, 3:, :] = 0
-        fake_lengths = torch.LongTensor([10, 3])
-        padding_mask = lengths_to_padding_mask(fake_lengths)
-        fake_sample2 = fake_sample[1:, :3, :]
-        padding_mask2 = lengths_to_padding_mask(fake_lengths[1].unsqueeze(0))
-        attn = MultiHeadedSelfAttentionModule(8, 4)
-        attn.eval()
-        attn_out = attn(fake_sample, padding_mask)
-        attn_out2 = attn(fake_sample2, padding_mask2)
-        torch.testing.assert_allclose(attn_out[1, :3, :], attn_out2[0])
-        self.assertTrue(
-            torch.all(attn_out[1, 3:, :] == 0.0), f"non-zero entries in {attn_out[1, 3:, :]}")
-
-    def test_encoder_batch(self):
-        batchnorm_args = copy.deepcopy(self.base_args)
-        batchnorm_args.no_syncbatchnorm = True
-        batchnorm_args.encoder_embed_dim = 8
-        batchnorm_args.input_feat_per_channel = 8
-        batchnorm_args.encoder_layers = 3
-        fake_sample = torch.rand(5, 27, 8)
-        fake_sample[1, 13:, :] = 0
-        fake_sample[2, 8:, :] = 0
-        fake_sample[3, 8:, :] = 0
-        fake_sample[4, 5:, :] = 0
-        fake_lengths = torch.LongTensor([27, 13, 8, 8, 5])
-        encoder = ConformerEncoder(batchnorm_args, self.fake_dict)
-        encoder.eval()
-        net_out = encoder.forward(fake_sample, fake_lengths, return_all_hiddens=True)
-
-        def test_item(item_idx):
-            item_len = fake_lengths[item_idx].item()
-            item_out_len = math.ceil(item_len / 4)
-            fake_sample2 = fake_sample[item_idx, :item_len, :]
-            net_out2 = encoder.forward(
-                fake_sample2.unsqueeze(0), fake_lengths[item_idx].unsqueeze(0), return_all_hiddens=True)
-            torch.testing.assert_allclose(
-                net_out["encoder_out"][0][:item_out_len, item_idx, :],
-                net_out2["encoder_out"][0][:, 0, :])
-
-        for i in range(5):
-            test_item(i)
-
-    def test_encoder_batch_unsafe_fails(self):
-        batchnorm_args = copy.deepcopy(self.base_args)
-        batchnorm_args.no_syncbatchnorm = True
-        batchnorm_args.encoder_embed_dim = 8
-        batchnorm_args.input_feat_per_channel = 8
-        batchnorm_args.encoder_layers = 3
-        batchnorm_args.batch_unsafe_relative_shift = True
-        fake_sample = torch.rand(2, 27, 8)
-        fake_sample[1, 13:, :] = 0
-        fake_lengths = torch.LongTensor([27, 13])
-        encoder = ConformerEncoder(batchnorm_args, self.fake_dict)
-        encoder.eval()
-        net_out = encoder.forward(fake_sample, fake_lengths, return_all_hiddens=True)
-        fake_sample2 = fake_sample[1, :13, :]
-        net_out2 = encoder.forward(fake_sample2.unsqueeze(0), fake_lengths[1].unsqueeze(0), return_all_hiddens=True)
-        with self.assertRaises(AssertionError) as ae:
-            torch.testing.assert_allclose(
-                net_out["encoder_out"][0][:4, 1, :],
-                net_out2["encoder_out"][0][:, 0, :])
-        self.assertTrue("Tensor-likes are not close!" in str(ae.exception))
-

if __name__ == '__main__':
    unittest.main()
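
For readers unfamiliar with pangolinn, the EncoderPaddingTestCase subclasses added above feed each wrapper padded batches and check that padding never leaks into the encoder output. The following is a rough standalone sketch of the property being verified, not pangolinn's actual implementation; it only reuses modules and helpers already present in this file and mirrors the hand-written tests this commit removes:

    import torch
    from fairseq.data.data_utils import lengths_to_padding_mask
    from examples.speech_to_text.modules.conformer_attention import MultiHeadedSelfAttentionModule

    attn = MultiHeadedSelfAttentionModule(8, 2)
    attn.eval()
    x = torch.rand(2, 10, 8)
    x[1, 3:, :] = 0  # zero the padded frames of the shorter item
    lengths = torch.LongTensor([10, 3])
    out = attn(x, lengths_to_padding_mask(lengths))
    # padded positions must stay zero in the output
    assert torch.all(out[1, 3:, :] == 0.0)
    # the shorter item must be unaffected by being batched with a longer one
    out_alone = attn(x[1:, :3, :], lengths_to_padding_mask(lengths[1:]))
    torch.testing.assert_allclose(out[1, :3, :], out_alone[0])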