-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement eager, streaming punct fixer #21
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
from typing import List, Optional | ||
|
||
from punctfix.inference import PunctFixer, WordPrediction | ||
|
||
|
||
class PunctFixStreamer:
    """
    A stateful streamer that receives text in segments, performing punct-fixing
    on-line and returning partial results during streaming. These partial
    results are guaranteed to be final, i.e. they will not change as more
    text is streamed in.
    """

    # Words that have been assigned to at least one chunk and therefore carry
    # (some of) their predicted labels.
    chunked_words: List[WordPrediction]
    # Words received but not yet assigned to a whole chunk.
    buffer: List[WordPrediction]

    def __init__(self, punct_fixer: PunctFixer):
        """
        Takes in an instantiated punct fixer; its chunk-size and overlap
        settings drive the streaming behaviour.
        """
        self.punct_fixer = punct_fixer
        self.clear()

    def __call__(self, new_text_segment: str) -> Optional[str]:
        """
        Stream in new text, returning None if this new text did not change
        anything and the partial, finalized text if there have been updates
        to it.
        """
        self.buffer.extend(
            self.punct_fixer.init_word_prediction_list(
                self.punct_fixer.split_input_text(new_text_segment)
            )
        )
        if self.process_buffer():
            return self.get_result()
        return None

    def finalize(self) -> str:
        """
        Mark end of stream and return the final punctuated string,
        resetting internal state afterwards.
        """
        self.process_buffer(is_finalized=True)
        punctuated = self.get_result(is_finalized=True)
        self.clear()
        return punctuated

    def get_result(self, is_finalized: bool = False) -> str:
        """
        Returns the punctuated string of all inputs streamed in so far.
        If called when not finalized, will only return text that is
        certain/no longer subject to change.
        """
        if is_finalized:
            finalized_words = self.chunked_words
        else:
            # These lines perform a tricky calculation in a dumb way:
            # When is each word finalized? When it has gotten all the labels
            # that it will get. This number of labels is not constant across
            # the sequence and depends on overlap size and on chunk size. To
            # avoid trying to be clever, we just calculate the chunks and
            # overlaps and sum up how many times each index will be in a chunk.
            # The + chunk size makes the calculation take into account that
            # there will be more chunks in future and that we should not
            # finalize prematurely.
            final_num_preds = [0] * (
                len(self.chunked_words) + self.punct_fixer.word_chunk_size
            )
            for chunk in self.punct_fixer.split_words_into_chunks(
                range(len(self.chunked_words))
            ):
                for idx in chunk:
                    final_num_preds[idx] += 1
            finalized_words = [
                word
                for i, word in enumerate(self.chunked_words)
                if len(word.labels) == final_num_preds[i]
            ]
        return self.punct_fixer.combine_word_predictions_into_final_text(
            finalized_words
        )

    def process_buffer(self, is_finalized: bool = False) -> bool:
        """
        Performs actual punctfixing of content in buffer, updating internal
        state such that a maximal number of words get predicted labels.
        Returns True if new chunks were created and processed and False if not.
        """
        new_chunks = []
        # Save how many words were chunked before this call: the overlap tail
        # of previously chunked words will receive additional labels, so the
        # forward pass below must start from there (0 on the very first call).
        this_processing_started_at = (
            len(self.chunked_words) - self.punct_fixer.word_overlap
            if self.chunked_words
            else 0
        )
        # Whole chunks are appended unless the stream is finalized, in which
        # case the buffer is completely emptied.
        while len(self.buffer) >= self.punct_fixer.word_chunk_size or (
            is_finalized and self.buffer
        ):
            new_chunks.append(
                [word.word for word in self.buffer[: self.punct_fixer.word_chunk_size]]
            )
            # Not all words are chunked for the first time; we must (except for
            # the first chunk) skip the first `word_overlap` words to avoid
            # duplicates.
            already_chunked_idx = (
                self.punct_fixer.word_overlap if self.chunked_words else 0
            )
            self.chunked_words.extend(
                self.buffer[already_chunked_idx : self.punct_fixer.word_chunk_size]
            )
            # We don't remove the entire buffer length from the buffer as we
            # want to emulate the overlap feature of the punct fixer; we leave
            # some in there for the next chunk.
            self.buffer = self.buffer[
                self.punct_fixer.word_chunk_size - self.punct_fixer.word_overlap :
            ]
        if new_chunks:
            # Run the forward pass on all new chunks, matching with the words
            # that are included in them.
            self.punct_fixer.populate_word_prediction_with_labels(
                new_chunks, self.chunked_words[this_processing_started_at:]
            )
            return True
        return False

    def clear(self):
        """
        Reset internal state.
        """
        self.buffer = []
        self.chunked_words = []
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,7 @@ | |
|
||
from punctfix import PunctFixer | ||
from punctfix.inference import NonNormalizedTextWarning | ||
from punctfix.streaming import PunctFixStreamer | ||
|
||
class CleanupDisableTest(unittest.TestCase): | ||
|
||
|
@@ -206,22 +207,22 @@ def test_do_normalize(self): | |
for model_input in ("hejsa, mand", " hejsa mand", "hejsa mand", | ||
"Hejsa mand", "hejsa mand", " hejsa mand", " hejsa, Mand", | ||
"hejsa % mand ! % "): | ||
actual_output = self.model._split_input_text(model_input) | ||
actual_output = self.model.split_input_text(model_input) | ||
self.assertEqual(actual_output, expected_output) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Switch to pytest instead of unittest? 😊 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I agree - I have made an issue for that #23 |
||
|
||
def test_warnings(self): | ||
self.model.warn_on_normalization = True | ||
with self.assertWarns(NonNormalizedTextWarning): | ||
model_input = "hejsa, mand" | ||
self.model._split_input_text(model_input) | ||
self.model.split_input_text(model_input) | ||
|
||
with self.assertWarns(NonNormalizedTextWarning): | ||
model_input = "hejsa mand" | ||
self.model._split_input_text(model_input) | ||
self.model.split_input_text(model_input) | ||
|
||
with self.assertWarns(NonNormalizedTextWarning): | ||
model_input = "hejsa Mand" | ||
self.model._split_input_text(model_input) | ||
self.model.split_input_text(model_input) | ||
|
||
def test_do_not_normalize(self): | ||
model_input = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \ | ||
|
@@ -232,7 +233,7 @@ def test_do_not_normalize(self): | |
" højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \ | ||
"de store folkemasser der er mødt op her på" | ||
expected_output = model_input.split(" ") | ||
actual_output = self.model._split_input_text(model_input) | ||
actual_output = self.model.split_input_text(model_input) | ||
self.assertEqual(actual_output, expected_output) | ||
|
||
class InputParameterTest(unittest.TestCase): | ||
|
@@ -246,6 +247,99 @@ def test_setting_batch_size(self): | |
actual_output = model.punctuate(model_input) | ||
self.assertEqual(actual_output, expected_output) | ||
|
||
class PunctFixStreamerTest(unittest.TestCase):
    """End-to-end and isolated per-method tests for PunctFixStreamer."""

    def setUp(self) -> None:
        super().setUp()
        # A fresh Danish streamer for every test keeps them independent.
        self.streamer = PunctFixStreamer(PunctFixer(language="da"))

    def tearDown(self) -> None:
        super().tearDown()
        del self.streamer

    def test_sample01(self):
        model_inputs = "mit navn det er rasmus", "og jeg kommer", "fra firmaet alvenir", \
            "det er mig", "som har trænet", "denne", "lækre model"
        expected_output = "Mit navn det er Rasmus og jeg kommer fra firmaet Alvenir. " \
                          "Det er mig som har trænet denne lækre model."

        for input_ in model_inputs:
            self.streamer(input_)
        actual_output = self.streamer.finalize()
        self.assertEqual(actual_output, expected_output)

    def test_sample02(self):
        model_inputs = "en dag bliver vi sku glade", "for", "at vi nu kan", "sætte punktummer ", \
            "og kommaer", "i", "en", "sætning det fungerer da meget", "godt ikke"
        expected_output = "En dag bliver vi sku glade for, at vi nu kan sætte punktummer " \
                          "og kommaer i en sætning. Det fungerer da meget godt, ikke?"
        for input_ in model_inputs:
            self.streamer(input_)
        actual_output = self.streamer.finalize()
        self.assertEqual(actual_output, expected_output)

    def test_sample03(self):
        # We want it super loooong
        model_input = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \
                      "sådan en bliver spillet sådan et form for tom og jerry kispus mellem københavns politi og " \
                      "så de har danske demonstranter for tibet og fåfalungongsom meget gerne vil vise deres " \
                      "utilfredshed med det kinesiske regime og det de opfatter som undertrykkelse af de her " \
                      "mindretal i kine og lige nu står støttekomiteen for ti bedet bag en demonstration på" \
                      " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \
                      "de store folkemasser der er mødt op her på byggepladsen her er ekstra ord " * 2
        expected_output = self.streamer.punct_fixer.punctuate(model_input)
        # Every partial result must be a final substring of the full output.
        for w in model_input.split():
            partial_output = self.streamer(w)
            if partial_output is not None:
                self.assertIn(partial_output, expected_output)
        actual_output = self.streamer.finalize()
        self.assertEqual(actual_output, expected_output)

    def test_repeated_same_input(self):
        self.streamer("test")
        self.streamer("test")
        output = self.streamer.finalize()
        expected_output = "Test test."
        self.assertEqual(output, expected_output)

    def test_empty_string_input(self):
        self.streamer("")
        output = self.streamer.finalize()
        self.assertEqual(output, "")

    # The below tests are isolated method unit tests
    def test_get_result_method(self):
        self.streamer("test test")
        # Call get_result at an intermediate state
        output = self.streamer.get_result()
        self.assertEqual(output, "")
        # Call get_result at a final state but without processing buffer
        output = self.streamer.get_result(is_finalized=True)
        self.assertEqual(output, "")
        # Call get_result at a final state where there is enough data to make the buffer processed
        self.streamer(" ".join(["test"] * 98))  # gives a total of 100=chunk size
        output = self.streamer.get_result(is_finalized=True)
        self.assertEqual(output, ("Test " * 100)[:-1])

    def test_finalize_method(self):
        self.streamer("finalizing test")
        output = self.streamer.finalize()
        expected_output = "Finalizing Test."
        self.assertEqual(output, expected_output)

    def test_call_method(self):
        self.streamer("test input")
        self.assertEqual(len(self.streamer.buffer), 2)
        self.assertEqual(len(self.streamer.chunked_words), 0)
        self.streamer("test " * 100)
        self.assertEqual(len(self.streamer.buffer), 72)  # Overlap size +2
        self.assertEqual(len(self.streamer.chunked_words), 100)  # Chunk size

    def test_clear_method(self):
        self.streamer("clearing test")
        self.streamer.clear()
        self.assertEqual(self.streamer.buffer, [])
        self.assertEqual(self.streamer.chunked_words, [])


if __name__ == '__main__':
    unittest.main()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How large can this buffer get? - just memory wise
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As large as one entire text (+ some storage for labels), so it would not use any more memory than the normal punct fixer; it just keeps the memory for longer