
Commit 248135e

sarapapi authored and mgaido91 committed
[!156][STREAMING] Add baseline FixedAudioHistorySelection for StreamST
# Why is the change needed?

To compare against the new StreamAtt policy, it is useful to have a baseline that cuts the audio history based on the number of discarded textual history words multiplied by a fixed word duration (280ms, following previous work).

# What changes does the patch introduce?

Adds a text-first history selection method that implements this logic:
- first, the new textual history is selected based on a fixed number of words to retain;
- second, the number of discarded words (the difference between the textual history words of the previous step and the fixed number of words to retain) is multiplied by a fixed duration (here, 280ms) and the corresponding frames are cut from the audio history.

# How was this patch tested?

UTs
1 parent 2002726 commit 248135e
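The baseline fits in a few lines. The following is a rough, hypothetical sketch of the logic described above, not the code added by this commit: the helper name trim_history and the toy inputs are invented for illustration, while the 280 ms word duration and the 10 ms frame shift come from the patch.

# Hypothetical sketch of the text-first baseline described in the commit message.
FIXED_WORD_DURATION_MS = 280  # fixed duration assigned to each discarded word
FRAME_SHIFT_MS = 10           # duration covered by one audio feature frame

def trim_history(history_words, audio_frames, words_to_retain):
    # Step 1 (text first): keep only the last `words_to_retain` words as textual history.
    new_text_history = history_words[-words_to_retain:] if words_to_retain > 0 else []
    # Step 2 (audio second): cut 280 ms worth of frames for every discarded word.
    n_discarded = len(history_words) - len(new_text_history)
    frames_to_cut = n_discarded * FIXED_WORD_DURATION_MS // FRAME_SHIFT_MS
    return new_text_history, audio_frames[frames_to_cut:]

# Discarding 2 of 3 words cuts 2 * 280 // 10 = 56 frames from the audio history.
text, audio = trim_history(["I", "am", "quokka"], list(range(200)), words_to_retain=1)
assert text == ["quokka"] and len(audio) == 200 - 56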

File tree

2 files changed: +87 -2 lines changed


examples/speech_to_text/simultaneous_translation/agents/v1_1/streaming/text_first_history_selection.py

+35 -1

@@ -16,7 +16,7 @@
 
 import torch
 
-from examples.speech_to_text.simultaneous_translation.agents.speech_utils import BOW_PREFIX
+from examples.speech_to_text.simultaneous_translation.agents.speech_utils import BOW_PREFIX, SHIFT_SIZE
 from examples.speech_to_text.simultaneous_translation.agents.v1_1.streaming.history_selection import HistorySelection
 from fairseq.data import Dictionary
 from fairseq.data.audio.speech_to_text_dataset import SpeechToTextDataset
@@ -194,3 +194,37 @@ def text_history(self, action: Action, states: AgentStates):
                 f"{self.history_max_len}")
         new_history = new_history[-self.history_max_len:]
         return new_history
+
+
+class FixedAudioHistorySelection(FixedWordsHistorySelection):
+    """
+    Audio history selection method that assigns to each token of the textual history a fixed
+    duration of *FIXED_WORD_DURATION* and cuts the audio history, stored in *states.source*,
+    accordingly. The history for the next decoding step is defined as follows:
+    - First, a pre-defined number of words (*history_words*) is retained as textual history from
+      the textual history of the previous decoding step and the *current_hypo* that is determined
+      by the SimulST agent and added to *states.target_indices*;
+    - Second, the new audio history is selected by discarding the audio frames corresponding to
+      the number of words discarded from the textual history multiplied by *FIXED_WORD_DURATION*.
+
+    The implementation works only for SentencePiece up to now.
+    """
+    FIXED_WORD_DURATION = 280  # duration of a word (in ms) as per (Ma et al., 2021)
+
+    def audio_history(self, action: Action, states: AgentStates, new_text_history: List[int]):
+        # Compute the number of tokens discarded from the textual history
+        n_discarded_tokens = len(states.target_indices) - len(new_text_history)
+
+        # If no tokens are discarded, return the original audio
+        if n_discarded_tokens == 0:
+            return states.source[0]
+
+        discarded_tokens = states.target_indices[:n_discarded_tokens]
+
+        n_discarded_words = len(self.tgt_dict.string(
+            discarded_tokens).strip(BOW_PREFIX).split(BOW_PREFIX))
+
+        # Recover the original number of frames considering that each audio feature corresponds
+        # to 10ms (SHIFT_SIZE)
+        frames_to_discard = n_discarded_words * self.FIXED_WORD_DURATION // SHIFT_SIZE
+        return states.source[0][frames_to_discard:]
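The word counting above is what ties the class to SentencePiece: the discarded tokens are detokenized with tgt_dict.string() and split on the word-boundary marker BOW_PREFIX. Below is a minimal sketch of that counting and of the resulting frame computation, assuming BOW_PREFIX is the SentencePiece "▁" (U+2581) marker and feeding a plain detokenized string instead of calling tgt_dict.string().

# Hedged sketch: assumes BOW_PREFIX is the SentencePiece word-boundary marker (U+2581),
# as suggested by the "works only for SentencePiece" note in the docstring above.
BOW_PREFIX = "\u2581"
SHIFT_SIZE = 10            # ms per audio feature, as used in the diff
FIXED_WORD_DURATION = 280  # ms assigned to each discarded word

def n_words(detokenized: str) -> int:
    # "▁I ▁am" -> strip the leading marker, then every remaining marker starts a new word
    return len(detokenized.strip(BOW_PREFIX).split(BOW_PREFIX))

def n_frames_to_discard(detokenized_discarded: str) -> int:
    return n_words(detokenized_discarded) * FIXED_WORD_DURATION // SHIFT_SIZE

# Two discarded words ("I am") -> 2 * 280 // 10 = 56 frames cut from the audio history
assert n_frames_to_discard("\u2581I \u2581am") == 56

Note that subwords of the same word (e.g. "▁quo" followed by "kka") detokenize without an internal marker, so they count as a single word, which is what the fixed per-word duration requires.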

fbk_simul_uts/v1_1/test_streamatt.py

+52 -1

@@ -19,7 +19,7 @@
 from examples.speech_to_text.simultaneous_translation.agents.v1_1.streaming.streaming_st_agent import StreamingSTAgent, \
     get_class_from_string
 from examples.speech_to_text.simultaneous_translation.agents.v1_1.streaming.text_first_history_selection import \
-    PunctuationHistorySelection
+    PunctuationHistorySelection, FixedAudioHistorySelection
 from simuleval.agents import ReadAction, WriteAction
 
 from fbk_simul_uts.v1_1.test_base_simulst_agent import BaseSTAgentTestCaseV2, MockedLoadModelVocab
@@ -218,6 +218,57 @@ def test_prefix_punctuation_selection(self, get_hypo_and_prefix):
         # Check first that no frame is discarded
         self.assertEqual(len(self.states.source[0]), 24)
 
+    @patch('examples.speech_to_text.simultaneous_translation.agents.v1_1.'
+           'simul_offline_alignatt.AlignAttSTAgent._get_hypo_and_prefix')
+    def test_fixed_audio_selection(self, get_hypo_and_prefix):
+        hypo = {
+            "tokens": torch.tensor([4, 5, 7, 8, 0]),  # I am quokka.
+            "attention": torch.tensor([
+                [0.5, 0.05, 0.05, 0.05, 0.05, 0.3],  # first frame mostly attended
+                [0.0, 0.6, 0.05, 0.03, 0.02, 0.3],  # second frame mostly attended
+                [0.05, 0.5, 0.05, 0.05, 0.05, 0.3],  # second frame mostly attended
+                [0.0, 0.6, 0.05, 0.03, 0.02, 0.3],  # second frame mostly attended
+                [0.05, 0.05, 0.05, 0.5, 0.05, 0.3],  # last frame mostly attended
+            ]).transpose(0, 1)
+        }
+
+        self.args.history_words = 1
+        self.agent.history_selection_method = FixedAudioHistorySelection(
+            self.agent.simulst_agent.tgtdict, self.agent.simulst_agent.args)
+
+        # No prefix
+        get_hypo_and_prefix.return_value = hypo, 0
+        self.states.target_indices = []
+        self.states.source = [torch.rand(280 // 10 * 4)]
+        action = self.agent.policy(self.states)
+        self.assertIsInstance(action, WriteAction)
+        self.assertEqual(action.content, "I am")
+        # "I am" should be written but only "am" should be retained as textual history (since
+        # history_words is set to 1), therefore 280ms (corresponding to one word) should be
+        # discarded
+        self.assertEqual(len(self.states.source[0]), 280 // 10 * 3)
+
+        # History len 1: "I"
+        get_hypo_and_prefix.return_value = hypo, 1
+        self.states.target_indices = [4]
+        self.states.source = [torch.rand(280 // 10 * 4)]
+        action = self.agent.policy(self.states)
+        self.assertIsInstance(action, WriteAction)
+        self.assertEqual(action.content, "am")
+        # "am" should be written and retained as textual history (since history_words is set
+        # to 1) while "I" should be discarded, therefore 280ms (corresponding to one word)
+        # should be discarded
+        self.assertEqual(len(self.states.source[0]), 280 // 10 * 3)
+
+        # History len 1: "am"
+        get_hypo_and_prefix.return_value = hypo, 2
+        self.agent.states.target_indices = [5]
+        self.states.source = [torch.rand(280 // 10 * 4)]
+        action = self.agent.policy(self.states)
+        self.assertIsInstance(action, ReadAction)
+        # Check that no frame is discarded
+        self.assertEqual(len(self.states.source[0]), 280 // 10 * 4)
+
     @patch('examples.speech_to_text.simultaneous_translation.agents.v1_1.'
            'simul_offline_alignatt.AlignAttSTAgent._get_hypo_and_prefix')
     def test_no_token_emitted(self, get_hypo_and_prefix):
