Skip to content

Commit

Permalink
Make normalization warnings opt-in
Browse files Browse the repository at this point in the history
  • Loading branch information
sorenmulli committed Feb 10, 2023
1 parent e9a0930 commit 8d099ea
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
2 changes: 1 addition & 1 deletion linting_config/pylint-configuration.pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ known-third-party=enchant
max-args=10

# Maximum number of attributes for a class (see R0902).
max-attributes=7
max-attributes=10

# Maximum number of boolean expressions in an if statement.
max-bool-expr=5
Expand Down
15 changes: 8 additions & 7 deletions punctfix/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
get_danish_model_and_tokenizer, get_german_model_and_tokenizer


WORD_NORMALIZATION_PATTERN = re.compile(r"[\W_]+")

class NoLanguageOrModelSelect(Exception):
"""
Exception raised if you fail to specify either a language or custom model path.
Expand Down Expand Up @@ -44,30 +46,28 @@ class PunctFixer:
"""
PunctFixer used to punctuate a given text.
"""
word_normalization_pattern = re.compile(r"[\W_]+")

def __init__(self, language: str = "da",
custom_model_path: str = None,
word_overlap: int = 70,
word_chunk_size: int = 100,
device: str = "cpu",
skip_normalization=False,
suppress_normalization_warning=False,):
warn_on_normalization=False,):
"""
:param language: Valid options are "da", "de", "en", for Danish, German and English, respectively.
:param custom_model_path: If you have a trained model yourself. If parsed, then language param will be ignored.
:param word_overlap: How many words should overlap in case text is too long. Defaults to 70.
:param word_chunk_size: How many words should a single pass consist of. Defaults to 100.
:param device: "cpu" or "cuda" to indicate where to run inference. Defaults to "cpu".
:param skip_normalization: Don't check input text and don't normalize it.
:param suppress_normalization_warning: Don't warn about normalizing input text.
No effect if skip_normalization=False.
:param warn_on_normalization: Warn the user if the input text was normalized.
"""

self.word_overlap = word_overlap
self.word_chunk_size = word_chunk_size
self.skip_normalization = skip_normalization
self.suppress_normalization_warning = suppress_normalization_warning
self.warn_on_normalization = warn_on_normalization

self.supported_languages = {
"de": "German",
Expand Down Expand Up @@ -177,6 +177,7 @@ def punctuate(self, text: str) -> str:
Punctuates given text.
:param text: A lowercase text with no punctuation.
If it has punctuatation, it will be removed.
:return: A punctuated text.
"""
words = self._split_input_text(text)
Expand All @@ -203,7 +204,7 @@ def _split_input_text(self, text: str) -> List[str]:
for word in words:
if not word:
to_warn.append("Additional whitespace was removed.")
norm_word = self.word_normalization_pattern.sub("", word)
norm_word = WORD_NORMALIZATION_PATTERN.sub("", word)
if not word:
continue
if len(norm_word) < len(word):
Expand All @@ -214,7 +215,7 @@ def _split_input_text(self, text: str) -> List[str]:
normalized_words.append(norm_word)

# Warn once for each type of normalization
if to_warn and not self.suppress_normalization_warning:
if self.warn_on_normalization and to_warn:
warnings.warn(
"The input text was modified to follow model normalization: " +
" ".join(sorted(set(to_warn))) +
Expand Down
4 changes: 2 additions & 2 deletions tests/test_punctuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,15 +200,15 @@ def tearDown(self) -> None:
self.model = None

def test_do_normalize(self):
self.model.suppress_normalization_warning = True
self.model.warn_on_normalization = False
expected_output = ["hejsa", "mand"]
for model_input in ("hejsa, mand", " hejsa mand", "hejsa mand",
"Hejsa mand", "hejsa mand", " hejsa mand", " hejsa, Mand"):
actual_output = self.model._split_input_text(model_input)
self.assertEqual(actual_output, expected_output)

def test_warnings(self):
self.model.suppress_normalization_warning = False
self.model.warn_on_normalization = True
with self.assertWarns(NonNormalizedTextWarning):
model_input = "hejsa, mand"
self.model._split_input_text(model_input)
Expand Down

0 comments on commit 8d099ea

Please sign in to comment.