Fix bug with non-alphanumerical words
As of this commit, normalization of non-alphanumeric characters such as % or !
no longer raises an error when they are input with spacing around them.

Fixes #17.
sorenmulli committed May 17, 2023
1 parent 2a70f80 commit c181e1a
Showing 4 changed files with 9 additions and 3 deletions.
4 changes: 3 additions & 1 deletion README.md
@@ -30,6 +30,7 @@ It's quite simple to use!
'En dag bliver vi sku glade for, at vi nu kan sætte punktummer og kommaer i en sætning. Det fungerer da meget godt, ikke?'
```

Note that, by default, the input text will be normalized. See the next section for more details.

## Parameters for PunctFixer
* Pass `device="cuda"` or `device="cpu"` to indicate where to run inference. Default is `device="cpu"`
@@ -38,7 +39,8 @@ lower accuracy use a chunk size of 150-200 and very little overlap, i.e. 5-10. The
default values `word_chunk_size=100`, `word_overlap=70` which makes it run a bit slow. The default parameters
will be updated when we have some results on variations.
* Supported languages are "en" for English, "da" for Danish and "de" for German. Default is `language="da"`.

* Note that the fixer has been trained on normalized text (lowercase letters and numbers) and will by default normalize input text. You can instantiate the model with `skip_normalization=True` to disable this, but doing so may cause errors on some input text.
* To raise warnings every time the input is normalized, set `warn_on_normalization=True`.
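The normalization described above can be sketched in plain Python. This is an illustrative approximation, not the library's actual code; `normalize_word` is a hypothetical helper name:

```python
import re

def normalize_word(word: str) -> str:
    # Drop non-word (r"\W") characters, then lowercase the remainder,
    # mirroring the normalization the fixer applies by default.
    return re.sub(r"\W", "", word).lower()
```

For example, `normalize_word("Hejsa,")` yields `"hejsa"`, while a bare `"%"` normalizes to the empty string.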

## Contribute
If you encounter issues, feel free to open an issue in the repo and we will fix it. Even better, create an issue and
3 changes: 3 additions & 0 deletions punctfix/inference.py
@@ -210,6 +210,9 @@ def _split_input_text(self, text: str) -> List[str]:
continue
if len(norm_word) < len(word):
to_warn.append(r"Non-word (r'\W') characters were removed.")
# We might have removed the entire word
if not norm_word:
continue
if not norm_word.islower():
norm_word = norm_word.lower()
to_warn.append("Text was lowercased.")
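The effect of the fix can be sketched as a standalone function. This is a simplified approximation of `_split_input_text`, under the assumption that normalization uses an `r"\W"` substitution; `split_input_text` is a hypothetical stand-in, not the library's actual signature:

```python
import re
from typing import List

def split_input_text(text: str) -> List[str]:
    words: List[str] = []
    for word in text.split():
        norm_word = re.sub(r"\W", "", word)
        # As of this commit: the entire word may have been removed
        # (e.g. a bare "%" or "!"), so skip it instead of keeping an
        # empty string that previously caused an error downstream.
        if not norm_word:
            continue
        words.append(norm_word.lower())
    return words
```

With the skip in place, input like `"hejsa % mand ! % "` splits cleanly into `["hejsa", "mand"]`.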
2 changes: 1 addition & 1 deletion setup.py
@@ -9,7 +9,7 @@

setuptools.setup(
name="punctfix",
version="0.10.0",
version="0.10.1",
author="Martin Carsten Nielsen",
author_email="[email protected]",
description="Punctuation restoration library",
3 changes: 2 additions & 1 deletion tests/test_punctuation.py
@@ -203,7 +203,8 @@ def test_do_normalize(self):
self.model.warn_on_normalization = False
expected_output = ["hejsa", "mand"]
for model_input in ("hejsa, mand", " hejsa mand", "hejsa mand",
"Hejsa mand", "hejsa mand", " hejsa mand", " hejsa, Mand"):
"Hejsa mand", "hejsa mand", " hejsa mand", " hejsa, Mand",
"hejsa % mand ! % "):
actual_output = self.model._split_input_text(model_input)
self.assertEqual(actual_output, expected_output)

