diff --git a/finetuning/sft_12hz.py b/finetuning/sft_12hz.py index 6ec80c80..cb0bd187 100644 --- a/finetuning/sft_12hz.py +++ b/finetuning/sft_12hz.py @@ -92,11 +92,6 @@ def train(): input_embeddings = input_text_embedding + input_codec_embedding - for i in range(1, 16): - codec_i_embedding = model.talker.code_predictor.get_input_embeddings()[i - 1](codec_ids[:, :, i]) - codec_i_embedding = codec_i_embedding * codec_mask.unsqueeze(-1) - input_embeddings = input_embeddings + codec_i_embedding - outputs = model.talker( inputs_embeds=input_embeddings[:, :-1, :], attention_mask=attention_mask[:, :-1],