Skip to content

Commit

Permalink
Implement pipeline batching over text chunks
Browse files Browse the repository at this point in the history
  • Loading branch information
sorenmulli committed Nov 28, 2023
1 parent c181e1a commit 5fc6fbb
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 6 deletions.
20 changes: 14 additions & 6 deletions punctfix/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,23 +52,27 @@ def __init__(self, language: str = "da",
use_auth_token: Optional[Union[bool, str]] = None,
word_overlap: int = 70,
word_chunk_size: int = 100,
device: str = "cpu",
device: Union[str, torch.device] = torch.device("cpu"),
skip_normalization=False,
warn_on_normalization=False,):
warn_on_normalization=False,
batch_size: int = 1
):
"""
:param language: Valid options are "da", "de", "en", for Danish, German and English, respectively.
:param custom_model_path: If you have a trained model yourself. If parsed, then language param will be ignored.
:param word_overlap: How many words should overlap in case text is too long. Defaults to 70.
:param word_chunk_size: How many words should a single pass consist of. Defaults to 100.
:param device: "cpu" or "cuda" to indicate where to run inference. Defaults to "cpu".
:param device: A torch.device on which to perform inference. The strings "cpu" or "cuda" can also be given.
:param skip_normalization: Don't check input text and don't normalize it.
:param warn_on_normalization: Warn the user if the input text was normalized.
:param batch_size: Number of text chunks to pass through token classification pipeline.
"""

self.word_overlap = word_overlap
self.word_chunk_size = word_chunk_size
self.skip_normalization = skip_normalization
self.warn_on_normalization = warn_on_normalization
self.batch_size = batch_size

self.supported_languages = {
"de": "German",
Expand All @@ -90,7 +94,11 @@ def __init__(self, language: str = "da",

self.tokenizer.decoder.cleanup = False
self.model = self.model.eval()
self.device = 0 if device == "cuda" and torch.cuda.is_available() else -1
if isinstance(device, str): # Backwards compatability
self.device = 0 if device == "cuda" and torch.cuda.is_available() else -1
else:
self.device = device


self.pipe = TokenClassificationPipeline(model=self.model,
tokenizer=self.tokenizer,
Expand Down Expand Up @@ -124,8 +132,8 @@ def populate_word_prediction_with_labels(self, chunks: List[List[str]], word_pre
:param word_prediction_list: A list containing word predictions i.e. word and labels.
:return: Word predictions list with all label predictions for each word
"""
for i, chunk_text in enumerate(chunks):
output = self.pipe(" ".join(chunk_text))
outputs = self.pipe([" ".join(chunk_text) for chunk_text in chunks], batch_size=self.batch_size)
for i, output in enumerate(outputs):
word_counter = 0
for entity in output:
label = entity["entity_group"]
Expand Down
37 changes: 37 additions & 0 deletions scripts/test_timing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from time import time
import torch
from punctfix import PunctFixer

# Long Danish ASR-style transcript (lowercase, no punctuation), repeated 10 times
# so the fixer must split it into many overlapping chunks — this is the benchmark
# input used by time_fp() below.
MODEL_INPUT = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \
              "sådan en bliver spillet sådan et form for tom og jerry kispus mellem københavns politi og " \
              "så de har danske demonstranter for tibet og fåfalungongsom meget gerne vil vise deres " \
              "utilfredshed med det kinesiske regime og det de opfatter som undertrykkelse af de her " \
              "mindretal i kine og lige nu står støttekomiteen for ti bedet bag en demonstration på" \
              " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \
              "de store folkemasser der er mødt op her på" * 10

def time_fp(device_str: str, batch_size: int, n_runs: int = 5):
    """Profile PunctFixer inference on a given device and batch size.

    Prints initialization time, then the mean and (sample) standard deviation
    of ``n_runs`` timed ``punctuate`` calls on the long benchmark input.

    :param device_str: Device string forwarded to PunctFixer, e.g. "cpu" or "cuda".
    :param batch_size: Number of text chunks per forward pass in the pipeline.
    :param n_runs: How many timed runs to aggregate over. Defaults to 5.
    """
    print(f">>> Profiling device {device_str} on batch size {batch_size}")
    start = time()
    model = PunctFixer(language="da", device=device_str, batch_size=batch_size)
    print(f"Initialization time {time() - start:f}")

    # Warmup run so one-time costs (e.g. CUDA context/kernel setup) do not
    # pollute the timed measurements.
    model.punctuate(MODEL_INPUT)

    times = []
    for _ in range(n_runs):
        start = time()
        model.punctuate(MODEL_INPUT)
        times.append(time() - start)
    # Build the tensor once instead of twice for mean and std.
    times_tensor = torch.tensor(times)
    print(f"Average time: {times_tensor.mean().item():f}\nStd. time: {times_tensor.std().item():f}")


if __name__ == "__main__":
    # Benchmark every (device, batch size) combination available on this machine;
    # CUDA is only included when a GPU is actually present.
    available_devices = ["cpu"]
    if torch.cuda.is_available():
        available_devices.append("cuda")
    for dev in available_devices:
        for bs in (1, 16, 32, 64):
            time_fp(dev, bs)
13 changes: 13 additions & 0 deletions tests/test_punctuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def test_if_gpu_not_available_default_cpu(self):
device=-1,
ignore_labels=ANY)


def tearDown(self) -> None:
super().tearDown()
self.torch_cuda_patch.stop()
Expand Down Expand Up @@ -234,5 +235,17 @@ def test_do_not_normalize(self):
actual_output = self.model._split_input_text(model_input)
self.assertEqual(actual_output, expected_output)

class InputParameterTest(unittest.TestCase):
    """Checks that constructor parameters do not change punctuation output."""

    def test_setting_batch_size(self):
        # The punctuated result must be identical regardless of batch size.
        raw_text = ("mit navn det er rasmus og jeg kommer fra firmaet alvenir "
                    "det er mig som har trænet denne lækre model")
        punctuated = ("Mit navn det er Rasmus og jeg kommer fra firmaet Alvenir. "
                      "Det er mig som har trænet denne lækre model.")
        for bs in (1, 27, 99):
            fixer = PunctFixer(language="da", batch_size=bs)
            self.assertEqual(fixer.punctuate(raw_text), punctuated)


if __name__ == '__main__':
unittest.main()

0 comments on commit 5fc6fbb

Please sign in to comment.