-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathconvert_wav2vec2_from_fairseq.py
60 lines (52 loc) · 2.1 KB
/
convert_wav2vec2_from_fairseq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
"""Convert fairseq's wav2vec2 to our format."""
import torch
import fairseq
from torchaudio.models.wav2vec2.utils import import_fairseq_model
from wav2vec2.model import wav2vec2_model
# Default locations used when the script is run directly.
DEFAULT_FAIRSEQ_CKPT = "pretrained/fairseq/wav2vec_small.pt"
DEFAULT_OUT_NAME = "pretrained/wav2vec2-base-ls960.fairseq.pth"


def wav2vec2_base_config():
    """Return the default config of wav2vec2 base as kwargs for ``wav2vec2_model``.

    A fresh dict (with fresh lists) is built on every call, so callers may
    mutate the result without affecting later calls.

    NOTE: these values are hard-coded for the *base* architecture
    (12 layers, 768-dim embeddings); they are not read from the fairseq
    checkpoint.
    """
    num_layers = 12
    return dict(
        extractor_mode="group_norm",  # hubert/w2v2 base only uses a group norm at the first conv layer
        extractor_conv_layer_config=[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512, 2, 2)] * 2,
        extractor_conv_bias=False,
        encoder_embed_dim=768,
        encoder_projection_dropout=0.1,
        encoder_pos_conv_kernel=128,
        encoder_pos_conv_groups=16,
        encoder_num_layers=num_layers,
        # Per-layer settings are derived from num_layers so they stay consistent.
        encoder_use_attention=[True] * num_layers,
        encoder_use_feed_forward=[True] * num_layers,
        encoder_num_heads=[12] * num_layers,
        encoder_head_dim=64,
        encoder_attention_dropout=0.1,
        encoder_ff_interm_features=[3072] * num_layers,
        encoder_ff_interm_dropout=0.0,
        encoder_dropout=0.1,
        encoder_layer_norm_first=False,  # hubert/w2v2 base uses post norm
        encoder_layer_drop=0.05,
        aux_num_out=None,
        normalize_waveform=False,
        extractor_prune_conv_channels=False,
        encoder_prune_attention_heads=False,
        encoder_prune_attention_layer=False,
        encoder_prune_feed_forward_intermediate=False,
        encoder_prune_feed_forward_layer=False,
    )


def convert(fairseq_ckpt=DEFAULT_FAIRSEQ_CKPT, out_name=DEFAULT_OUT_NAME, config=None):
    """Convert a fairseq wav2vec2 checkpoint and save it in our format.

    The saved file is a dict with keys ``'state_dict'`` (weights of the
    torchaudio model built by ``import_fairseq_model``) and ``'config'``.

    Args:
        fairseq_ckpt: path to the fairseq checkpoint.
        out_name: destination path; missing parent directories are created.
        config: model config stored next to the weights. Defaults to
            ``wav2vec2_base_config()``. It must describe the same architecture
            as ``fairseq_ckpt``; this function does not check that.

    Returns:
        The model returned by ``import_fairseq_model``.
    """
    import os  # local import: only needed to create the output directory

    if config is None:
        config = wav2vec2_base_config()

    # Only the first model of the ensemble is converted; cfg and task are unused.
    ensemble, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([fairseq_ckpt])
    imported = import_fairseq_model(ensemble[0])
    print(imported)

    # torch.save does not create directories; create them here so the script
    # does not fail after the (slow) checkpoint load above. A bare filename
    # has an empty dirname, which os.makedirs would reject.
    out_dir = os.path.dirname(out_name)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    torch.save(
        {
            'state_dict': imported.state_dict(),
            'config': config,
        },
        out_name
    )
    return imported


def verify(out_name=DEFAULT_OUT_NAME):
    """Reload a converted checkpoint into ``wav2vec2_model`` and print key mismatches.

    Loading is non-strict, so missing/unexpected parameter names are only
    printed, not raised.

    Returns:
        The result of ``load_state_dict`` (has ``missing_keys`` and
        ``unexpected_keys``).
    """
    ckpt = torch.load(out_name, map_location="cpu")
    model = wav2vec2_model(**ckpt['config'])
    res = model.load_state_dict(ckpt['state_dict'], strict=False)
    print(f"Missing: {res.missing_keys}\nUnexpected: {res.unexpected_keys}")
    return res


if __name__ == "__main__":
    convert(DEFAULT_FAIRSEQ_CKPT, DEFAULT_OUT_NAME)
    # verify the saved ckpt
    verify(DEFAULT_OUT_NAME)