diff --git a/sam_audio/model/model.py b/sam_audio/model/model.py index 1bf5c586..2ac5ddc7 100644 --- a/sam_audio/model/model.py +++ b/sam_audio/model/model.py @@ -267,6 +267,18 @@ def separate( ), ) + # Refresh anchor conditioning created by predict_spans() + forward_args.update( + { + "anchor_ids": self._repeat_for_reranking( + batch.anchor_ids, reranking_candidates + ), + "anchor_alignment": self._repeat_for_reranking( + batch.anchor_alignment, reranking_candidates + ), + } + ) + audio_features = forward_args["audio_features"] B, T, C = audio_features.shape C = C // 2 # we stack audio_features, so the actual channels is half