From 3b2e024f68d8426cead6b1beb649222ada3c65b2 Mon Sep 17 00:00:00 2001
From: sorenmulli
Date: Tue, 28 Nov 2023 22:25:05 +0100
Subject: [PATCH] Implement pipeline batching over text chunks

---
 punctfix/inference.py     | 20 ++++++++++++++------
 scripts/test_timing.py    | 37 +++++++++++++++++++++++++++++++++++++
 tests/test_punctuation.py | 13 +++++++++++++
 3 files changed, 64 insertions(+), 6 deletions(-)
 create mode 100644 scripts/test_timing.py

diff --git a/punctfix/inference.py b/punctfix/inference.py
index 9c6b38b..d585c81 100644
--- a/punctfix/inference.py
+++ b/punctfix/inference.py
@@ -52,23 +52,27 @@ def __init__(self, language: str = "da",
                  use_auth_token: Optional[Union[bool, str]] = None,
                  word_overlap: int = 70,
                  word_chunk_size: int = 100,
-                 device: str = "cpu",
+                 device: Union[str, torch.device] = torch.device("cpu"),
                  skip_normalization=False,
-                 warn_on_normalization=False,):
+                 warn_on_normalization=False,
+                 batch_size: int = 1
+                 ):
         """
         :param language: Valid options are "da", "de", "en", for Danish, German and English, respectively.
         :param custom_model_path: If you have a trained model yourself. If parsed, then language param will be ignored.
         :param word_overlap: How many words should overlap in case text is too long. Defaults to 70.
         :param word_chunk_size: How many words should a single pass consist of. Defaults to 100.
-        :param device: "cpu" or "cuda" to indicate where to run inference. Defaults to "cpu".
+        :param device: A torch.device on which to perform inference. The strings "cpu" or "cuda" can also be given.
         :param skip_normalization: Don't check input text and don't normalize it.
         :param warn_on_normalization: Warn the user if the input text was normalized.
+        :param batch_size: How many text chunks to pass through the token classification pipeline at once. Defaults to 1.
         """
 
         self.word_overlap = word_overlap
         self.word_chunk_size = word_chunk_size
         self.skip_normalization = skip_normalization
         self.warn_on_normalization = warn_on_normalization
+        self.batch_size = batch_size
 
         self.supported_languages = {
             "de": "German",
@@ -90,7 +94,11 @@ def __init__(self, language: str = "da",
             self.tokenizer.decoder.cleanup = False
 
         self.model = self.model.eval()
-        self.device = 0 if device == "cuda" and torch.cuda.is_available() else -1
+        if isinstance(device, str):  # Backwards compatibility
+            self.device = 0 if device == "cuda" and torch.cuda.is_available() else -1
+        else:
+            self.device = device
+
 
         self.pipe = TokenClassificationPipeline(model=self.model,
                                                 tokenizer=self.tokenizer,
@@ -124,8 +132,8 @@ def populate_word_prediction_with_labels(self, chunks: List[List[str]], word_pre
         :param word_prediction_list: A list containing word predictions i.e. word and labels.
         :return: Word predictions list with all label predictions for each word
         """
-        for i, chunk_text in enumerate(chunks):
-            output = self.pipe(" ".join(chunk_text))
+        outputs = self.pipe([" ".join(chunk_text) for chunk_text in chunks], batch_size=self.batch_size)
+        for i, output in enumerate(outputs):
             word_counter = 0
             for entity in output:
                 label = entity["entity_group"]
diff --git a/scripts/test_timing.py b/scripts/test_timing.py
new file mode 100644
index 0000000..4dee87a
--- /dev/null
+++ b/scripts/test_timing.py
@@ -0,0 +1,37 @@
+from time import time
+import torch
+from punctfix import PunctFixer
+
+MODEL_INPUT = "det der sker over de tre dage fra præsident huden tav ankommer til københavn det er at der " \
+              "sådan en bliver spillet sådan et form for tom og jerry kispus mellem københavns politi og " \
+              "så de har danske demonstranter for tibet og fåfalungongsom meget gerne vil vise deres " \
+              "utilfredshed med det kinesiske regime og det de opfatter som undertrykkelse af de her " \
+              "mindretal i kine og lige nu står støttekomiteen for ti bedet bag en demonstration på" \
+              " højbro plads i københavn lisbeth davidsen hvor mange er der kommet det er ikke " \
+              "de store folkemasser der er mødt op her på" * 10
+
+def time_fp(device_str: str, batch_size: int):
+    print(">>> Profiling device %s on batch size %i" % (device_str, batch_size))
+    start = time()
+    model = PunctFixer(language="da", device=device_str, batch_size=batch_size)
+    print("Initialization time %f" % (time() - start))
+
+    # Warmup potential CUDA device
+    model.punctuate(MODEL_INPUT)
+
+    times = []
+    for _ in range(5):
+        start = time()
+        model.punctuate(MODEL_INPUT)
+        times.append(time() - start)
+    print("Average time: %f\nStd. time: %f" % (torch.tensor(times).mean().item(), torch.tensor(times).std().item()))
+
+
+if __name__ == "__main__":
+    devices = ["cpu"]
+    batch_sizes = [1, 16, 32, 64]
+    if torch.cuda.is_available():
+        devices.append("cuda")
+    for device in devices:
+        for batch_size in batch_sizes:
+            time_fp(device, batch_size)
diff --git a/tests/test_punctuation.py b/tests/test_punctuation.py
index da412c1..6a1f784 100644
--- a/tests/test_punctuation.py
+++ b/tests/test_punctuation.py
@@ -184,6 +184,7 @@ def test_if_gpu_not_available_default_cpu(self):
                                                          device=-1,
                                                          ignore_labels=ANY)
 
+
     def tearDown(self) -> None:
         super().tearDown()
         self.torch_cuda_patch.stop()
@@ -234,5 +235,17 @@ def test_do_not_normalize(self):
         actual_output = self.model._split_input_text(model_input)
         self.assertEqual(actual_output, expected_output)
 
+class InputParameterTest(unittest.TestCase):
+    def test_setting_batch_size(self):
+        model_input = "mit navn det er rasmus og jeg kommer fra firmaet alvenir " \
+                      "det er mig som har trænet denne lækre model"
+        expected_output = "Mit navn det er Rasmus og jeg kommer fra firmaet Alvenir. " \
+                          "Det er mig som har trænet denne lækre model."
+        for batch_size in (1, 27, 99):
+            model = PunctFixer(language="da", batch_size=batch_size)
+            actual_output = model.punctuate(model_input)
+            self.assertEqual(actual_output, expected_output)
+
+
 if __name__ == '__main__':
     unittest.main()
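
For reviewers who want to exercise the change, a minimal usage sketch (not part of the patch). The sample sentence and its expected output are taken from the new test above; the batch size of 32 is an arbitrary illustration, to be tuned with scripts/test_timing.py:

    import torch

    from punctfix import PunctFixer

    # With this patch, the word chunks that long inputs are split into
    # (100 words by default) go through the token classification pipeline
    # batch_size chunks at a time, instead of one forward pass per chunk.
    fixer = PunctFixer(
        language="da",
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        batch_size=32,  # illustrative value; batch_size=1 reproduces the old behaviour
    )

    text = "mit navn det er rasmus og jeg kommer fra firmaet alvenir " \
           "det er mig som har trænet denne lækre model"
    print(fixer.punctuate(text))
    # "Mit navn det er Rasmus og jeg kommer fra firmaet Alvenir. Det er mig som har trænet denne lækre model."

Since only the chunks of a single call are batched together, larger batch sizes should mainly pay off for long inputs and on GPU; the new timing script sweeps batch sizes 1, 16, 32 and 64 on CPU and, when available, CUDA.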