Skip to content

Commit

Permalink
Implement eager, streaming punct fixer
Browse files Browse the repository at this point in the history
As of this commit, users can import the PunctFixStreamer which allows
for inputting unfinished segments and getting partial results which can
be trusted as corresponding to a subset of the final result
  • Loading branch information
sorenmulli committed Dec 6, 2023
1 parent 3b2e024 commit 18f01ac
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 7 deletions.
9 changes: 7 additions & 2 deletions punctfix/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def punctuate(self, text: str) -> str:
If it has punctuation, it will be removed.
:return: A punctuated text.
"""
words = self._split_input_text(text)
words = self.split_input_text(text)

# If we have a long sequence of text (measured by words), we split it into chunks
chunks = []
Expand All @@ -203,7 +203,12 @@ def punctuate(self, text: str) -> str:
word_prediction_list = self.populate_word_prediction_with_labels(chunks, word_prediction_list)
return self.combine_word_predictions_into_final_text(word_prediction_list)

def _split_input_text(self, text: str) -> List[str]:
def split_input_text(self, text: str) -> List[str]:
"""
Splits given text into words using whitespace tokenization, also performing normalization
:param text: A lowercase text with no punctuation (otherwise normalized)
:return: A list of the words in that text, split and normalized.
"""
words = text.split(" ")
if self.skip_normalization:
return words
Expand Down
124 changes: 124 additions & 0 deletions punctfix/streaming.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
from typing import List, Optional

from punctfix.inference import PunctFixer, WordPrediction


class PunctFixStreamer:
    """
    A stateful streamer that receives text in segments, on-line performing punct-fixing and
    returning partial results during streaming. These partial results are guaranteed to be
    final: once a word has been emitted, its text and punctuation never change.
    """

    # Words that have been assigned to at least one model chunk. A word here may
    # still be awaiting labels from future, overlapping chunks.
    chunked_words: List[WordPrediction]
    # Words streamed in but not yet placed into any chunk.
    buffer: List[WordPrediction]

    def __init__(self, punct_fixer: PunctFixer):
        """
        Takes in an instantiated punct fixer; its chunking configuration
        (word_chunk_size / word_overlap) drives the streaming behaviour.

        :param punct_fixer: A ready-to-use PunctFixer instance.
        """
        self.punct_fixer = punct_fixer
        self.clear()

    def __call__(self, new_text_segment: str) -> Optional[str]:
        """
        Stream in new text, returning None if this new text did not change anything
        and the partial, finalized text if there has been updates to it.

        :param new_text_segment: A (possibly unfinished) piece of text; it is
            split/normalized by the punct fixer before buffering.
        :return: The finalized partial result so far, or None if no new chunk
            was completed by this segment.
        """
        self.buffer.extend(
            self.punct_fixer.init_word_prediction_list(
                self.punct_fixer.split_input_text(new_text_segment)
            )
        )
        # Only produce output if at least one new chunk was processed.
        if self.process_buffer():
            return self.get_result()
        return None

    def finalize(self) -> str:
        """
        Mark end of stream and return the final punctuated string.
        Internal state is reset, so the streamer can be reused afterwards.

        :return: The punctuated text for everything streamed in.
        """
        self.process_buffer(is_finalized=True)
        punctuated = self.get_result(is_finalized=True)
        self.clear()
        return punctuated

    def get_result(self, is_finalized: bool = False) -> str:
        """
        Returns the punctuated string of all inputs streamed in so far.
        If called when not finalized, will only return text that is certain/no
        longer subject to change.

        :param is_finalized: If True, every chunked word is treated as final.
        :return: Punctuated text built from the finalized words.
        """
        if is_finalized:
            finalized_words = self.chunked_words
        # These lines perform a tricky calculation in a dumb way:
        # When is each word finalized? When it has gotten all the labels that it will get.
        # This number of labels is not constant across the sequence and depends on overlap
        # size and on chunk size. To avoid trying to be clever, I just calculate the chunks
        # and overlaps and sum up how many times each index will be in a chunk.
        else:
            # The + chunk size makes the calculation take into account that there will be more
            # chunks in future and that we should not finalize prematurely
            final_num_preds = [0] * (
                len(self.chunked_words) + self.punct_fixer.word_chunk_size
            )
            for chunk in self.punct_fixer.split_words_into_chunks(
                range(len(self.chunked_words))
            ):
                for idx in chunk:
                    final_num_preds[idx] += 1
            # A word is final once it has received every label it will ever get.
            finalized_words = [
                word
                for i, word in enumerate(self.chunked_words)
                if len(word.labels) == final_num_preds[i]
            ]
        return self.punct_fixer.combine_word_predictions_into_final_text(
            finalized_words
        )

    def process_buffer(self, is_finalized: bool = False) -> bool:
        """
        Performs actual punctfixing of content in buffer, updating internal state such that a maximal number
        of words get predicted labels. Returns true if new chunks were created and processed and false if not.

        :param is_finalized: If True, the buffer is drained completely, even if
            the last chunk is smaller than the configured chunk size.
        :return: True if at least one new chunk was run through the model.
        """
        new_chunks = []
        # Save how many words were chunked before this call; the trailing
        # `word_overlap` words will receive additional labels from new chunks.
        this_processing_started_at = (
            len(self.chunked_words) - self.punct_fixer.word_overlap
            if self.chunked_words
            else 0
        )
        # Whole chunks are appended unless the stream is finalized in which case, the buffer
        # is completely emptied
        while len(self.buffer) >= self.punct_fixer.word_chunk_size or (
            is_finalized and self.buffer
        ):
            new_chunks.append(
                [word.word for word in self.buffer[: self.punct_fixer.word_chunk_size]]
            )
            # Not all words are chunked for the first time, we must (except for first time)
            # skip the first `word_overlap` words to avoid duplicates.
            already_chunked_idx = (
                self.punct_fixer.word_overlap if self.chunked_words else 0
            )
            self.chunked_words.extend(
                self.buffer[already_chunked_idx : self.punct_fixer.word_chunk_size]
            )
            # We don't remove the entire buffer length from the buffer as we want
            # to emulate the overlap feature of the punctfixer; we leave some in there for next chunk.
            self.buffer = self.buffer[
                self.punct_fixer.word_chunk_size - self.punct_fixer.word_overlap :
            ]
        if new_chunks:
            # Run the forward pass on all new chunks, matching with the words that are included in them
            self.punct_fixer.populate_word_prediction_with_labels(
                new_chunks, self.chunked_words[this_processing_started_at:]
            )
            return True
        return False

    def clear(self) -> None:
        """
        Reset internal state.
        """
        self.buffer = []
        self.chunked_words = []
104 changes: 99 additions & 5 deletions tests/test_punctuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from punctfix import PunctFixer
from punctfix.inference import NonNormalizedTextWarning
from punctfix.streaming import PunctFixStreamer

class CleanupDisableTest(unittest.TestCase):

Expand Down Expand Up @@ -206,22 +207,22 @@ def test_do_normalize(self):
for model_input in ("hejsa, mand", " hejsa mand", "hejsa mand",
"Hejsa mand", "hejsa mand", " hejsa mand", " hejsa, Mand",
"hejsa % mand ! % "):
actual_output = self.model._split_input_text(model_input)
actual_output = self.model.split_input_text(model_input)
self.assertEqual(actual_output, expected_output)

def test_warnings(self):
self.model.warn_on_normalization = True
with self.assertWarns(NonNormalizedTextWarning):
model_input = "hejsa, mand"
self.model._split_input_text(model_input)
self.model.split_input_text(model_input)

with self.assertWarns(NonNormalizedTextWarning):
model_input = "hejsa mand"
self.model._split_input_text(model_input)
self.model.split_input_text(model_input)

with self.assertWarns(NonNormalizedTextWarning):
model_input = "hejsa Mand"
self.model._split_input_text(model_input)
self.model.split_input_text(model_input)

def test_do_not_normalize(self):
model_input = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \
Expand All @@ -232,7 +233,7 @@ def test_do_not_normalize(self):
" højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \
"de store folkemasser der er mødt op her på"
expected_output = model_input.split(" ")
actual_output = self.model._split_input_text(model_input)
actual_output = self.model.split_input_text(model_input)
self.assertEqual(actual_output, expected_output)

class InputParameterTest(unittest.TestCase):
Expand All @@ -246,6 +247,99 @@ def test_setting_batch_size(self):
actual_output = model.punctuate(model_input)
self.assertEqual(actual_output, expected_output)

class PunctFixStreamerTest(unittest.TestCase):
    """Exercises the PunctFixStreamer, end-to-end and per public method."""

    def setUp(self) -> None:
        super().setUp()
        self.streamer = PunctFixStreamer(PunctFixer(language="da"))

    def tearDown(self) -> None:
        super().tearDown()
        del self.streamer

    def test_sample01(self):
        segments = (
            "mit navn det er rasmus",
            "og jeg kommer",
            "fra firmaet alvenir",
            "det er mig",
            "som har trænet",
            "denne",
            "lækre model",
        )
        expected = (
            "Mit navn det er Rasmus og jeg kommer fra firmaet Alvenir. "
            "Det er mig som har trænet denne lækre model."
        )
        for segment in segments:
            self.streamer(segment)
        self.assertEqual(self.streamer.finalize(), expected)

    def test_sample02(self):
        segments = (
            "en dag bliver vi sku glade",
            "for",
            "at vi nu kan",
            "sætte punktummer ",
            "og kommaer",
            "i",
            "en",
            "sætning det fungerer da meget",
            "godt ikke",
        )
        expected = (
            "En dag bliver vi sku glade for, at vi nu kan sætte punktummer "
            "og kommaer i en sætning. Det fungerer da meget godt, ikke?"
        )
        for segment in segments:
            self.streamer(segment)
        self.assertEqual(self.streamer.finalize(), expected)

    def test_sample03(self):
        # A deliberately long input (doubled) to force many overlapping chunks
        model_input = (
            "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der "
            "sådan en bliver spillet sådan et form for tom og jerry kispus mellem københavns politi og "
            "så de har danske demonstranter for tibet og fåfalungongsom meget gerne vil vise deres "
            "utilfredshed med det kinesiske regime og det de opfatter som undertrykkelse af de her "
            "mindretal i kine og lige nu står støttekomiteen for ti bedet bag en demonstration på"
            " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke "
            "de store folkemasser der er mødt op her på byggepladsen her er ekstra ord "
        ) * 2
        expected = self.streamer.punct_fixer.punctuate(model_input)
        # Every partial result must be contained in the full offline result
        for token in model_input.split():
            partial = self.streamer(token)
            if partial is not None:
                self.assertIn(partial, expected)
        self.assertEqual(self.streamer.finalize(), expected)

    def test_repeated_same_input(self):
        self.streamer("test")
        self.streamer("test")
        self.assertEqual(self.streamer.finalize(), "Test test.")

    def test_empty_string_input(self):
        self.streamer("")
        self.assertEqual(self.streamer.finalize(), "")

    # The tests below target single methods in isolation
    def test_get_result_method(self):
        self.streamer("test test")
        # Intermediate state: nothing can be finalized yet
        self.assertEqual(self.streamer.get_result(), "")
        # Final state, but the buffer has not been processed yet
        self.assertEqual(self.streamer.get_result(is_finalized=True), "")
        # Final state with enough data for the buffer to have been processed
        self.streamer(" ".join(["test"] * 98))  # gives a total of 100=chunk size
        self.assertEqual(
            self.streamer.get_result(is_finalized=True), ("Test " * 100)[:-1]
        )

    def test_finalize_method(self):
        self.streamer("finalizing test")
        self.assertEqual(self.streamer.finalize(), "Finalizing Test.")

    def test_call_method(self):
        self.streamer("test input")
        self.assertEqual(len(self.streamer.buffer), 2)
        self.assertEqual(len(self.streamer.chunked_words), 0)
        self.streamer("test " * 100)
        self.assertEqual(len(self.streamer.buffer), 72)  # overlap size + 2
        self.assertEqual(len(self.streamer.chunked_words), 100)  # chunk size

    def test_clear_method(self):
        self.streamer("clearing test")
        self.streamer.clear()
        self.assertEqual(self.streamer.buffer, [])
        self.assertEqual(self.streamer.chunked_words, [])

# Allow running this test module directly, e.g. `python tests/test_punctuation.py`
if __name__ == '__main__':
    unittest.main()

0 comments on commit 18f01ac

Please sign in to comment.