From 18f01ac5dfe95f5991ad4962a467e19f57b12cc7 Mon Sep 17 00:00:00 2001 From: sorenmulli Date: Wed, 6 Dec 2023 15:05:01 +0100 Subject: [PATCH] Implement eager, streaming punct fixer As of this commit, users can import the PunctFixStreamer which allows for inputting unfinished segments and getting partial results which can be trusted as corresponding to a subset of the final result --- punctfix/inference.py | 9 ++- punctfix/streaming.py | 124 ++++++++++++++++++++++++++++++++++++++ tests/test_punctuation.py | 104 ++++++++++++++++++++++++++++++-- 3 files changed, 230 insertions(+), 7 deletions(-) create mode 100644 punctfix/streaming.py diff --git a/punctfix/inference.py b/punctfix/inference.py index d585c81..8502386 100644 --- a/punctfix/inference.py +++ b/punctfix/inference.py @@ -189,7 +189,7 @@ def punctuate(self, text: str) -> str: If it has punctuatation, it will be removed. :return: A punctuated text. """ - words = self._split_input_text(text) + words = self.split_input_text(text) # If we have a long sequence of text (measured by words), we split it into chunks chunks = [] @@ -203,7 +203,12 @@ def punctuate(self, text: str) -> str: word_prediction_list = self.populate_word_prediction_with_labels(chunks, word_prediction_list) return self.combine_word_predictions_into_final_text(word_prediction_list) - def _split_input_text(self, text: str) -> List[str]: + def split_input_text(self, text: str) -> List[str]: + """ + Splits given text into words using whitespace tokenization, also performing normalization + :param text: A lowercase text with no punctuation (otherwise normalized) + :return: A list of the words in that text, split and normalized. 
+ """ words = text.split(" ") if self.skip_normalization: return words diff --git a/punctfix/streaming.py b/punctfix/streaming.py new file mode 100644 index 0000000..ecf1273 --- /dev/null +++ b/punctfix/streaming.py @@ -0,0 +1,124 @@ +from typing import List, Optional + +from punctfix.inference import PunctFixer, WordPrediction + + +class PunctFixStreamer: + """ + A stateful streamer that receives text in segments, on-line performing punct-fixing and + returning partial results during streaming. These partial results are guaranteed to be + final. + """ + + chunked_words: List[WordPrediction] + buffer: List[WordPrediction] + + def __init__(self, punct_fixer: PunctFixer): + """ + Takes in an instantiated punct fixer. + """ + self.punct_fixer = punct_fixer + self.clear() + + def __call__(self, new_text_segment: str) -> Optional[str]: + """ + Stream in new text, returning None if this new text did not change anything + and the partial, finalized text if there have been updates to it. + """ + self.buffer.extend( + self.punct_fixer.init_word_prediction_list( + self.punct_fixer.split_input_text(new_text_segment) + ) + ) + if self.process_buffer(): + return self.get_result() + return None + + def finalize(self): + """ + Mark end of stream and return final punctuated string. + """ + self.process_buffer(is_finalized=True) + punctuated = self.get_result(is_finalized=True) + self.clear() + return punctuated + + def get_result(self, is_finalized=False) -> str: + """ + Returns punctuated string of all inputs streamed in so far. + If called when not finalized, will only return text that is certain/no longer subject to change + """ + if is_finalized: + finalized_words = self.chunked_words + # These lines perform a tricky calculation in a dumb way: + # When is each word finalized? When it has gotten all the labels that it will get. + # This number of labels is not constant across the sequence and depends on overlap + # size and on chunk size. 
To avoid trying to be clever, I just calculate the chunks + # and overlaps and sum up how many times each index will be in a chunk. + else: + # The + chunk size makes the calculation take into account that there will be more + # chunks in the future and that we should not finalize prematurely + final_num_preds = [0] * ( + len(self.chunked_words) + self.punct_fixer.word_chunk_size + ) + for chunk in self.punct_fixer.split_words_into_chunks( + range(len(self.chunked_words)) + ): + for idx in chunk: + final_num_preds[idx] += 1 + finalized_words = [ + word + for i, word in enumerate(self.chunked_words) + if len(word.labels) == final_num_preds[i] + ] + return self.punct_fixer.combine_word_predictions_into_final_text( + finalized_words + ) + + def process_buffer(self, is_finalized=False) -> bool: + """ + Performs actual punctfixing of content in buffer, updating internal state such that a maximal number + of words get predicted labels. Returns true if new chunks were created and processed and false if not. + """ + new_chunks = [] + # Save how many words were chunked before this call + this_processing_started_at = ( + len(self.chunked_words) - self.punct_fixer.word_overlap + if self.chunked_words + else 0 + ) + # Whole chunks are appended unless the stream is finalized in which case, the buffer + # is completely emptied + while len(self.buffer) >= self.punct_fixer.word_chunk_size or ( + is_finalized and self.buffer + ): + new_chunks.append( + [word.word for word in self.buffer[: self.punct_fixer.word_chunk_size]] + ) + # Not all words are chunked for the first time, we must (except for first time) + # skip the first `word_overlap` words to avoid duplicates. 
+ already_chunked_idx = ( + self.punct_fixer.word_overlap if self.chunked_words else 0 + ) + self.chunked_words.extend( + self.buffer[already_chunked_idx : self.punct_fixer.word_chunk_size] + ) + # We don't remove the entire buffer length from the buffer as we want + # to emulate the overlap feature of the punctfixer; we leave some in there for next chunk. + self.buffer = self.buffer[ + self.punct_fixer.word_chunk_size - self.punct_fixer.word_overlap : + ] + if new_chunks: + # Run the forward pass on all new chunks, matching with the words that are included in them + self.punct_fixer.populate_word_prediction_with_labels( + new_chunks, self.chunked_words[this_processing_started_at:] + ) + return True + return False + + def clear(self): + """ + Reset internal state. + """ + self.buffer = [] + self.chunked_words = [] diff --git a/tests/test_punctuation.py b/tests/test_punctuation.py index 6a1f784..c6d1270 100644 --- a/tests/test_punctuation.py +++ b/tests/test_punctuation.py @@ -3,6 +3,7 @@ from punctfix import PunctFixer from punctfix.inference import NonNormalizedTextWarning +from punctfix.streaming import PunctFixStreamer class CleanupDisableTest(unittest.TestCase): @@ -206,22 +207,22 @@ def test_do_normalize(self): for model_input in ("hejsa, mand", " hejsa mand", "hejsa mand", "Hejsa mand", "hejsa mand", " hejsa mand", " hejsa, Mand", "hejsa % mand ! 
% "): - actual_output = self.model._split_input_text(model_input) + actual_output = self.model.split_input_text(model_input) self.assertEqual(actual_output, expected_output) def test_warnings(self): self.model.warn_on_normalization = True with self.assertWarns(NonNormalizedTextWarning): model_input = "hejsa, mand" - self.model._split_input_text(model_input) + self.model.split_input_text(model_input) with self.assertWarns(NonNormalizedTextWarning): model_input = "hejsa mand" - self.model._split_input_text(model_input) + self.model.split_input_text(model_input) with self.assertWarns(NonNormalizedTextWarning): model_input = "hejsa Mand" - self.model._split_input_text(model_input) + self.model.split_input_text(model_input) def test_do_not_normalize(self): model_input = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \ @@ -232,7 +233,7 @@ def test_do_not_normalize(self): " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \ "de store folkemasser der er mødt op her på" expected_output = model_input.split(" ") - actual_output = self.model._split_input_text(model_input) + actual_output = self.model.split_input_text(model_input) self.assertEqual(actual_output, expected_output) class InputParameterTest(unittest.TestCase): @@ -246,6 +247,99 @@ def test_setting_batch_size(self): actual_output = model.punctuate(model_input) self.assertEqual(actual_output, expected_output) +class PunctFixStreamerTest(unittest.TestCase): + + def setUp(self) -> None: + super().setUp() + self.streamer = PunctFixStreamer(PunctFixer(language="da")) + + def tearDown(self) -> None: + super().tearDown() + del self.streamer + + def test_sample01(self): + model_inputs = "mit navn det er rasmus", "og jeg kommer", "fra firmaet alvenir",\ + "det er mig", "som har trænet", "denne", "lækre model" + expected_output = "Mit navn det er Rasmus og jeg kommer fra firmaet Alvenir. " \ + "Det er mig som har trænet denne lækre model." 
+ + for input_ in model_inputs: + self.streamer(input_) + actual_output = self.streamer.finalize() + self.assertEqual(actual_output, expected_output) + + def test_sample02(self): + model_inputs = "en dag bliver vi sku glade", "for", "at vi nu kan", "sætte punktummer ",\ + "og kommaer", "i", "en", "sætning det fungerer da meget", "godt ikke" + expected_output = "En dag bliver vi sku glade for, at vi nu kan sætte punktummer " \ + "og kommaer i en sætning. Det fungerer da meget godt, ikke?" + for input_ in model_inputs: + self.streamer(input_) + actual_output = self.streamer.finalize() + self.assertEqual(actual_output, expected_output) + + def test_sample03(self): + # We want it super loooong + model_input = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \ + "sådan en bliver spillet sådan et form for tom og jerry kispus mellem københavns politi og " \ + "så de har danske demonstranter for tibet og fåfalungongsom meget gerne vil vise deres " \ + "utilfredshed med det kinesiske regime og det de opfatter som undertrykkelse af de her " \ + "mindretal i kine og lige nu står støttekomiteen for ti bedet bag en demonstration på" \ + " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \ + "de store folkemasser der er mødt op her på byggepladsen her er ekstra ord " * 2 + expected_output = self.streamer.punct_fixer.punctuate(model_input) + for w in model_input.split(): + partial_output = self.streamer(w) + if partial_output is not None: + self.assertIn(partial_output, expected_output) + actual_output = self.streamer.finalize() + self.assertEqual(actual_output, expected_output) + + def test_repeated_same_input(self): + self.streamer("test") + self.streamer("test") + output = self.streamer.finalize() + expected_output = "Test test." 
+ self.assertEqual(output, expected_output) + + def test_empty_string_input(self): + self.streamer("") + output = self.streamer.finalize() + self.assertEqual(output, "") + + ### The below tests are isolated method unit tests + def test_get_result_method(self): + self.streamer("test test") + # Call get_result at an intermediate state + output = self.streamer.get_result() + self.assertEqual(output, "") + # Call get_result at a final state but without processing buffer + output = self.streamer.get_result(is_finalized=True) + self.assertEqual(output, "") + # Call get_result at a final state where there is enough data to make the buffer processed + self.streamer(" ".join(["test"]*98)) # gives a total of 100=chunk size + output = self.streamer.get_result(is_finalized= True) + self.assertEqual(output, ("Test "*100)[:-1]) + + def test_finalize_method(self): + self.streamer("finalizing test") + output = self.streamer.finalize() + expected_output = "Finalizing Test." + self.assertEqual(output, expected_output) + + def test_call_method(self): + self.streamer("test input") + self.assertEqual(len(self.streamer.buffer), 2) + self.assertEqual(len(self.streamer.chunked_words), 0) + self.streamer("test " * 100) + self.assertEqual(len(self.streamer.buffer), 72) # Overlap size +2 + self.assertEqual(len(self.streamer.chunked_words), 100) # Chunk size + + def test_clear_method(self): + self.streamer("clearing test") + self.streamer.clear() + self.assertEqual(self.streamer.buffer, []) + self.assertEqual(self.streamer.chunked_words, []) if __name__ == '__main__': unittest.main()