Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add chunk batching #20

Merged
merged 1 commit into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions punctfix/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,27 @@ def __init__(self, language: str = "da",
use_auth_token: Optional[Union[bool, str]] = None,
word_overlap: int = 70,
word_chunk_size: int = 100,
device: str = "cpu",
device: Union[str, torch.device] = torch.device("cpu"),
skip_normalization=False,
warn_on_normalization=False,):
warn_on_normalization=False,
batch_size: int = 1
):
"""
:param language: Valid options are "da", "de", "en", for Danish, German and English, respectively.
:param custom_model_path: If you have a trained model yourself. If parsed, then language param will be ignored.
:param word_overlap: How many words should overlap in case text is too long. Defaults to 70.
:param word_chunk_size: How many words should a single pass consist of. Defaults to 100.
:param device: "cpu" or "cuda" to indicate where to run inference. Defaults to "cpu".
:param device: A torch.device on which to perform inference. The strings "cpu" or "cuda" can also be given.
:param skip_normalization: Don't check input text and don't normalize it.
:param warn_on_normalization: Warn the user if the input text was normalized.
:param batch_size: Number of text chunks to pass through token classification pipeline.
"""

self.word_overlap = word_overlap
self.word_chunk_size = word_chunk_size
self.skip_normalization = skip_normalization
self.warn_on_normalization = warn_on_normalization
self.batch_size = batch_size

self.supported_languages = {
"de": "German",
Expand All @@ -90,7 +94,11 @@ def __init__(self, language: str = "da",

self.tokenizer.decoder.cleanup = False
self.model = self.model.eval()
self.device = 0 if device == "cuda" and torch.cuda.is_available() else -1
if isinstance(device, str): # Backwards compatibility
self.device = 0 if device == "cuda" and torch.cuda.is_available() else -1
else:
self.device = device


self.pipe = TokenClassificationPipeline(model=self.model,
tokenizer=self.tokenizer,
Expand Down Expand Up @@ -124,8 +132,8 @@ def populate_word_prediction_with_labels(self, chunks: List[List[str]], word_pre
:param word_prediction_list: A list containing word predictions i.e. word and labels.
:return: Word predictions list with all label predictions for each word
"""
for i, chunk_text in enumerate(chunks):
output = self.pipe(" ".join(chunk_text))
outputs = self.pipe([" ".join(chunk_text) for chunk_text in chunks], batch_size=self.batch_size)
for i, output in enumerate(outputs):
word_counter = 0
for entity in output:
label = entity["entity_group"]
Expand Down
37 changes: 37 additions & 0 deletions scripts/test_timing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from time import time
import torch
from punctfix import PunctFixer

# Benchmark input: a deliberately raw, unpunctuated Danish ASR transcript
# (spelling artifacts are intentional), repeated 10x so a single punctuate()
# call spans many word chunks and exercises chunk batching.
MODEL_INPUT = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \
              "sådan en bliver spillet sådan et form for tom og jerry kispus mellem københavns politi og " \
              "så de har danske demonstranter for tibet og fåfalungongsom meget gerne vil vise deres " \
              "utilfredshed med det kinesiske regime og det de opfatter som undertrykkelse af de her " \
              "mindretal i kine og lige nu står støttekomiteen for ti bedet bag en demonstration på" \
              " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \
              "de store folkemasser der er mødt op her på" * 10

def time_fp(device_str: str, batch_size: int) -> None:
    """Benchmark PunctFixer.punctuate on one device / batch-size combination.

    Prints initialization time, then the mean and standard deviation of the
    wall-clock time over 5 punctuate() calls on MODEL_INPUT.

    :param device_str: Device passed to PunctFixer, e.g. "cpu" or "cuda".
    :param batch_size: Number of text chunks per forward pass.
    """
    print(f">>> Profiling device {device_str} on batch size {batch_size}")
    start = time()
    model = PunctFixer(language="da", device=device_str, batch_size=batch_size)
    print(f"Initialization time {time() - start:f}")

    # Warmup potential CUDA device
    model.punctuate(MODEL_INPUT)

    times = []
    for _ in range(5):
        start = time()
        model.punctuate(MODEL_INPUT)
        times.append(time() - start)
    # Build the tensor once instead of twice; .std() uses Bessel's correction.
    timings = torch.tensor(times)
    print(f"Average time: {timings.mean().item():f}\nStd. time: {timings.std().item():f}")


if __name__ == "__main__":
devices = ["cpu"]
batch_sizes = [1, 16, 32, 64]
if torch.cuda.is_available():
devices.append("cuda")
for device in devices:
for batch_size in batch_sizes:
time_fp(device, batch_size)
13 changes: 13 additions & 0 deletions tests/test_punctuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def test_if_gpu_not_available_default_cpu(self):
device=-1,
ignore_labels=ANY)


def tearDown(self) -> None:
super().tearDown()
self.torch_cuda_patch.stop()
Expand Down Expand Up @@ -234,5 +235,17 @@ def test_do_not_normalize(self):
actual_output = self.model._split_input_text(model_input)
self.assertEqual(actual_output, expected_output)

class InputParameterTest(unittest.TestCase):
    """Tests for constructor parameters that must not change model output."""

    def test_setting_batch_size(self):
        """Punctuation output must be identical for every batch size."""
        model_input = "mit navn det er rasmus og jeg kommer fra firmaet alvenir " \
                      "det er mig som har trænet denne lækre model"
        expected_output = "Mit navn det er Rasmus og jeg kommer fra firmaet Alvenir. " \
                          "Det er mig som har trænet denne lækre model."
        for batch_size in 1, 27, 99:
            # subTest reports which batch size failed and keeps checking the
            # remaining sizes instead of aborting on the first mismatch.
            with self.subTest(batch_size=batch_size):
                model = PunctFixer(language="da", batch_size=batch_size)
                actual_output = model.punctuate(model_input)
                self.assertEqual(actual_output, expected_output)


# Allow running this test module directly: python tests/test_punctuation.py
if __name__ == '__main__':
    unittest.main()
Loading