diff --git a/punctfix/inference.py b/punctfix/inference.py
index d585c81..8502386 100644
--- a/punctfix/inference.py
+++ b/punctfix/inference.py
@@ -189,7 +189,7 @@ def punctuate(self, text: str) -> str:
         If it has punctuatation, it will be removed.
         :return: A punctuated text.
         """
-        words = self._split_input_text(text)
+        words = self.split_input_text(text)
 
         # If we have a long sequence of text (measured by words), we split it into chunks
         chunks = []
@@ -203,7 +203,12 @@ def punctuate(self, text: str) -> str:
         word_prediction_list = self.populate_word_prediction_with_labels(chunks, word_prediction_list)
         return self.combine_word_predictions_into_final_text(word_prediction_list)
 
-    def _split_input_text(self, text: str) -> List[str]:
+    def split_input_text(self, text: str) -> List[str]:
+        """
+        Splits given text into words using whitespace tokenization, also performing normalization
+        :param text: A lowercase text with no punctuation (otherwise normalized)
+        :return: A list of the words in that text, split and normalized.
+        """
         words = text.split(" ")
         if self.skip_normalization:
             return words
diff --git a/punctfix/streaming.py b/punctfix/streaming.py
new file mode 100644
index 0000000..ecf1273
--- /dev/null
+++ b/punctfix/streaming.py
@@ -0,0 +1,124 @@
+from typing import List, Optional
+
+from punctfix.inference import PunctFixer, WordPrediction
+
+
+class PunctFixStreamer:
+    """
+    A stateful streamer that receives text in segments, on-line performing punct-fixing and
+    returning partial results during streaming. These partial results are guaranteed to be
+    final.
+    """
+
+    chunked_words: List[WordPrediction]
+    buffer: List[WordPrediction]
+
+    def __init__(self, punct_fixer: PunctFixer):
+        """
+        Takes in an instantiated punct fixer.
+ """ + self.punct_fixer = punct_fixer + self.clear() + + def __call__(self, new_text_segment: str) -> Optional[str]: + """ + Stream in new text, returning None if this new text did not change anything + and the partial, finalized text if there has been updates to it. + """ + self.buffer.extend( + self.punct_fixer.init_word_prediction_list( + self.punct_fixer.split_input_text(new_text_segment) + ) + ) + if self.process_buffer(): + return self.get_result() + return None + + def finalize(self): + """ + Mark end of stream and return final puncatuated string. + """ + self.process_buffer(is_finalized=True) + punctuated = self.get_result(is_finalized=True) + self.clear() + return punctuated + + def get_result(self, is_finalized=False) -> str: + """ + Returns punctuated string in of all inputs streamed in so far. + If called when not finalized, will only return text that is certain/no longer subject to change + """ + if is_finalized: + finalized_words = self.chunked_words + # These lines perform a tricky calculation in a dumb way: + # When is each word finalized? When it has gotten all the labels that it will get. + # This number of labels is not constant across the sequence and depends on overlap + # size and on chunk size. To avoid trying to be clever, I just calculate the chunks + # and overlaps and sum up how many times each index will be in a chunk. 
+        else:
+            # The + chunk size makes the calculation take into account that there will be more
+            # chunks in future and that we should not finalize prematurely
+            final_num_preds = [0] * (
+                len(self.chunked_words) + self.punct_fixer.word_chunk_size
+            )
+            for chunk in self.punct_fixer.split_words_into_chunks(
+                range(len(self.chunked_words))
+            ):
+                for idx in chunk:
+                    final_num_preds[idx] += 1
+            finalized_words = [
+                word
+                for i, word in enumerate(self.chunked_words)
+                if len(word.labels) == final_num_preds[i]
+            ]
+        return self.punct_fixer.combine_word_predictions_into_final_text(
+            finalized_words
+        )
+
+    def process_buffer(self, is_finalized=False) -> bool:
+        """
+        Performs actual punctfixing of content in buffer, updating internal state such that a maximal number
+        of words get predicted labels. Returns true if new chunks were created and processed and false if not.
+        """
+        new_chunks = []
+        # Save how many words were chunked before this call
+        this_processing_started_at = (
+            len(self.chunked_words) - self.punct_fixer.word_overlap
+            if self.chunked_words
+            else 0
+        )
+        # Whole chunks are appended unless the stream is finalized in which case, the buffer
+        # is completely emptied
+        while len(self.buffer) >= self.punct_fixer.word_chunk_size or (
+            is_finalized and self.buffer
+        ):
+            new_chunks.append(
+                [word.word for word in self.buffer[: self.punct_fixer.word_chunk_size]]
+            )
+            # Not all words are chunked for the first time, we must (except for first time)
+            # skip the first `word_overlap` words to avoid duplicates.
+            already_chunked_idx = (
+                self.punct_fixer.word_overlap if self.chunked_words else 0
+            )
+            self.chunked_words.extend(
+                self.buffer[already_chunked_idx : self.punct_fixer.word_chunk_size]
+            )
+            # We don't remove the entire buffer length from the buffer as we want
+            # to emulate the overlap feature of the punctfixer; we leave some in there for next chunk.
+            self.buffer = self.buffer[
+                self.punct_fixer.word_chunk_size - self.punct_fixer.word_overlap :
+            ]
+        if new_chunks:
+            # Run the forward pass on all new chunks, matching with the words that are included in them
+            self.punct_fixer.populate_word_prediction_with_labels(
+                new_chunks, self.chunked_words[this_processing_started_at:]
+            )
+            return True
+        return False
+
+    def clear(self):
+        """
+        Reset internal state.
+        """
+        self.buffer = []
+        self.chunked_words = []
diff --git a/tests/test_punctuation.py b/tests/test_punctuation.py
index 6a1f784..c6d1270 100644
--- a/tests/test_punctuation.py
+++ b/tests/test_punctuation.py
@@ -3,6 +3,7 @@
 
 from punctfix import PunctFixer
 from punctfix.inference import NonNormalizedTextWarning
+from punctfix.streaming import PunctFixStreamer
 
 
 class CleanupDisableTest(unittest.TestCase):
@@ -206,22 +207,22 @@ def test_do_normalize(self):
         for model_input in ("hejsa, mand", " hejsa mand", "hejsa mand", "Hejsa mand",
                             "hejsa mand", " hejsa mand", " hejsa, Mand", "hejsa % mand ! % "):
-            actual_output = self.model._split_input_text(model_input)
+            actual_output = self.model.split_input_text(model_input)
             self.assertEqual(actual_output, expected_output)
 
     def test_warnings(self):
         self.model.warn_on_normalization = True
         with self.assertWarns(NonNormalizedTextWarning):
             model_input = "hejsa, mand"
-            self.model._split_input_text(model_input)
+            self.model.split_input_text(model_input)
         with self.assertWarns(NonNormalizedTextWarning):
             model_input = "hejsa mand"
-            self.model._split_input_text(model_input)
+            self.model.split_input_text(model_input)
         with self.assertWarns(NonNormalizedTextWarning):
             model_input = "hejsa Mand"
-            self.model._split_input_text(model_input)
+            self.model.split_input_text(model_input)
@@ -232,7 +233,7 @@ def test_do_not_normalize(self):
         model_input = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \
                       "sådan en bliver spillet sådan et form for tom og jerry kispus mellem københavns politi og " \
                       " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \
                       "de store folkemasser der er mødt op her på"
         expected_output = model_input.split(" ")
-        actual_output = self.model._split_input_text(model_input)
+        actual_output = self.model.split_input_text(model_input)
         self.assertEqual(actual_output, expected_output)
 
 class InputParameterTest(unittest.TestCase):
@@ -246,6 +247,99 @@ def test_setting_batch_size(self):
         actual_output = model.punctuate(model_input)
         self.assertEqual(actual_output, expected_output)
 
+class PunctFixStreamerTest(unittest.TestCase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.streamer = PunctFixStreamer(PunctFixer(language="da"))
+
+    def tearDown(self) -> None:
+        super().tearDown()
+        del self.streamer
+
+    def test_sample01(self):
+        model_inputs = "mit navn det er rasmus", "og jeg kommer", "fra firmaet alvenir",\
+            "det er mig", "som har trænet", "denne", "lækre model"
+        expected_output = "Mit navn det er Rasmus og jeg kommer fra firmaet Alvenir. " \
+            "Det er mig som har trænet denne lækre model."
+
+        for input_ in model_inputs:
+            self.streamer(input_)
+        actual_output = self.streamer.finalize()
+        self.assertEqual(actual_output, expected_output)
+
+    def test_sample02(self):
+        model_inputs = "en dag bliver vi sku glade", "for", "at vi nu kan", "sætte punktummer ",\
+            "og kommaer", "i", "en", "sætning det fungerer da meget", "godt ikke"
+        expected_output = "En dag bliver vi sku glade for, at vi nu kan sætte punktummer " \
+            "og kommaer i en sætning. Det fungerer da meget godt, ikke?"
+        for input_ in model_inputs:
+            self.streamer(input_)
+        actual_output = self.streamer.finalize()
+        self.assertEqual(actual_output, expected_output)
+
+    def test_sample03(self):
+        # We want it super loooong
+        model_input = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \
+                      "sådan en bliver spillet sådan et form for tom og jerry kispus mellem københavns politi og " \
+                      "så de har danske demonstranter for tibet og fåfalungongsom meget gerne vil vise deres " \
+                      "utilfredshed med det kinesiske regime og det de opfatter som undertrykkelse af de her " \
+                      "mindretal i kine og lige nu står støttekomiteen for ti bedet bag en demonstration på" \
+                      " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \
+                      "de store folkemasser der er mødt op her på byggepladsen her er ekstra ord " * 2
+        expected_output = self.streamer.punct_fixer.punctuate(model_input)
+        for w in model_input.split():
+            partial_output = self.streamer(w)
+            if partial_output is not None:
+                self.assertIn(partial_output, expected_output)
+        actual_output = self.streamer.finalize()
+        self.assertEqual(actual_output, expected_output)
+
+    def test_repeated_same_input(self):
+        self.streamer("test")
+        self.streamer("test")
+        output = self.streamer.finalize()
+        expected_output = "Test test."
+        self.assertEqual(output, expected_output)
+
+    def test_empty_string_input(self):
+        self.streamer("")
+        output = self.streamer.finalize()
+        self.assertEqual(output, "")
+
+    ### The below tests are isolated method unit tests
+    def test_get_result_method(self):
+        self.streamer("test test")
+        # Call get_result at an intermediate state
+        output = self.streamer.get_result()
+        self.assertEqual(output, "")
+        # Call get_result at a final state but without processing buffer
+        output = self.streamer.get_result(is_finalized=True)
+        self.assertEqual(output, "")
+        # Call get_result at a final state where there is enough data to make the buffer processed
+        self.streamer(" ".join(["test"]*98))  # gives a total of 100=chunk size
+        output = self.streamer.get_result(is_finalized=True)
+        self.assertEqual(output, ("Test "*100)[:-1])
+
+    def test_finalize_method(self):
+        self.streamer("finalizing test")
+        output = self.streamer.finalize()
+        expected_output = "Finalizing Test."
+        self.assertEqual(output, expected_output)
+
+    def test_call_method(self):
+        self.streamer("test input")
+        self.assertEqual(len(self.streamer.buffer), 2)
+        self.assertEqual(len(self.streamer.chunked_words), 0)
+        self.streamer("test " * 100)
+        self.assertEqual(len(self.streamer.buffer), 72)  # Overlap size + 2
+        self.assertEqual(len(self.streamer.chunked_words), 100)  # Chunk size
+
+    def test_clear_method(self):
+        self.streamer("clearing test")
+        self.streamer.clear()
+        self.assertEqual(self.streamer.buffer, [])
+        self.assertEqual(self.streamer.chunked_words, [])
 
 
 if __name__ == '__main__':
     unittest.main()