diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 8d2b2b424..cb2efa981 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -36,6 +36,8 @@ repos:
     rev: 5.12.0
     hooks:
       - id: isort
+        args: ["--profile", "black"]
+        files: '\.py$'
   - repo: https://github.com/sondrelg/pep585-upgrade
     rev: v1.0.1
diff --git a/prompt2model/utils/__init__.py b/prompt2model/utils/__init__.py
index 52c32b6d8..03a4544ea 100644
--- a/prompt2model/utils/__init__.py
+++ b/prompt2model/utils/__init__.py
@@ -1,11 +1,16 @@
 """Import utility functions."""
-from prompt2model.utils.openai_tools import ChatGPTAgent  # noqa: F401
-from prompt2model.utils.openai_tools import OPENAI_ERRORS, handle_openai_error
+from prompt2model.utils.openai_tools import (
+    OPENAI_ERRORS,
+    ChatGPTAgent,
+    handle_openai_error,
+)
 from prompt2model.utils.rng import seed_generator
+from prompt2model.utils.tevatron_utils import encode_text
 
 __all__ = (  # noqa: F401
-    "seed_generator",
     "ChatGPTAgent",
-    "OPENAI_ERRORS",
+    "encode_text",
     "handle_openai_error",
+    "OPENAI_ERRORS",
+    "seed_generator",
 )
diff --git a/prompt2model/utils/tevatron_utils/__init__.py b/prompt2model/utils/tevatron_utils/__init__.py
new file mode 100644
index 000000000..5973c85e6
--- /dev/null
+++ b/prompt2model/utils/tevatron_utils/__init__.py
@@ -0,0 +1,4 @@
+"""Import Tevatron utility functions."""
+from prompt2model.utils.tevatron_utils.encode import encode_text
+
+__all__ = ["encode_text"]
diff --git a/prompt2model/utils/tevatron_utils/encode.py b/prompt2model/utils/tevatron_utils/encode.py
new file mode 100644
index 000000000..019879281
--- /dev/null
+++ b/prompt2model/utils/tevatron_utils/encode.py
@@ -0,0 +1,168 @@
+"""Tools for encoding and serializing a search index with a contextual encoder."""
+
+from __future__ import annotations  # noqa FI58
+
+import json
+import os
+import pickle
+import tempfile
+from contextlib import nullcontext
+
+import numpy as np
+import torch
+from tevatron.arguments import DataArguments
+from tevatron.data import EncodeCollator, EncodeDataset
+from tevatron.datasets import HFCorpusDataset, HFQueryDataset
+from tevatron.modeling import DenseModelForInference
+from torch.utils.data import DataLoader
+from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizerBase
+
+
+def load_tevatron_model(
+    model_name_or_path: str, model_cache_dir: str | None = None
+) -> tuple[DenseModelForInference, PreTrainedTokenizerBase]:
+    """Load a Tevatron model from a model name/path.
+
+    Args:
+        model_name_or_path: The HuggingFace model name or path to the model.
+        model_cache_dir: The directory to cache the model.
+
+    Returns:
+        A Tevatron dense retrieval model and its associated tokenizer.
+ """ + config = AutoConfig.from_pretrained( + model_name_or_path, + cache_dir=model_cache_dir, + ) + tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, + cache_dir=model_cache_dir, + use_fast=False, + ) + model = DenseModelForInference.build( + model_name_or_path=model_name_or_path, + config=config, + cache_dir=model_cache_dir, + ) + return model, tokenizer + + +def encode_text( + model_name_or_path: str, + file_to_encode: str | None = None, + text_to_encode: list[str] | str | None = None, + encode_query: bool = False, + encoding_file: str | None = None, + max_len: int = 128, + device: torch.device = torch.device("cpu"), + dataloader_num_workers: int = 0, + model_cache_dir: str | None = None, + data_cache_dir: str = "~/.cache/huggingface/datasets", + batch_size=8, + fp16: bool = False, +) -> np.ndarray: + """Encode a query or documents. + + This code is mostly duplicated from tevatron/driver/encode.py in the Tevatron + repository. + + Args: + model_name_or_path: The HuggingFace model name or path to the model. + file_to_encode: JSON or JSONL file containing `"text"` fields to encode. + text_to_encode: String or list of strings to encode. + encode_query: Whether or not we are encoding a query or documents. + encoding_file: If given, store the encoded data in this file. + max_len: Truncate the input to this length (in tokens). + device: Device that Torch will use to encode the text. + dataloader_num_workers: Number of workers to use for the dataloader. + model_cache_dir: The directory to cache the model. + data_cache_dir: The directory to cache the tokenized dataset. + batch_size: Batch size to use for encoding. + fp16: Whether or not to run inference in fp16 for more-efficient encoding. + + Returns: + A numpy array of shape `(num_examples, embedding_dim)` containing text + encoded by the specified model. 
+ """ + model, tokenizer = load_tevatron_model(model_name_or_path, model_cache_dir) + + if file_to_encode is None and text_to_encode is None: + raise ValueError("Must provide either a dataset file or text to encode.") + elif file_to_encode is not None and text_to_encode is not None: + raise ValueError("Provide either a dataset file or text to encode, not both.") + + with tempfile.TemporaryDirectory() as temp_dir: + if text_to_encode is not None: + if isinstance(text_to_encode, str): + text_to_encode = [text_to_encode] + with open( + os.path.join(temp_dir, "text_to_encode.json"), "w" + ) as temporary_file: + text_rows = [ + {"text_id": i, "text": text} + for i, text in enumerate(text_to_encode) + ] + json.dump(text_rows, temporary_file) + file_to_encode = temporary_file.name + temporary_file.close() + + data_args = DataArguments( + encoded_save_path=encoding_file, + encode_in_path=file_to_encode, + encode_is_qry=encode_query, + data_cache_dir=data_cache_dir, + ) + if encode_query: + data_args.q_max_len = max_len + hf_dataset = HFQueryDataset( + tokenizer=tokenizer, + data_args=data_args, + cache_dir=data_args.data_cache_dir or model_cache_dir, + ) + else: + data_args.p_max_len = max_len + hf_dataset = HFCorpusDataset( + tokenizer=tokenizer, + data_args=data_args, + cache_dir=data_args.data_cache_dir or model_cache_dir, + ) + + encode_dataset = EncodeDataset( + hf_dataset.process(1, 0), tokenizer, max_len=max_len + ) + + encode_loader = DataLoader( + encode_dataset, + batch_size=batch_size, + collate_fn=EncodeCollator( + tokenizer, max_length=max_len, padding="max_length" + ), + shuffle=False, + drop_last=False, + num_workers=dataloader_num_workers, + ) + encoded = [] + lookup_indices = [] + model = model.to(device) + model.eval() + + for batch_ids, batch in encode_loader: + lookup_indices.extend(batch_ids) + with torch.cuda.amp.autocast() if fp16 else nullcontext(): + with torch.no_grad(): + for k, v in batch.items(): + batch[k] = v.to(device) + if data_args.encode_is_qry: + model_output = model(query=batch) + encoded.append(model_output.q_reps.cpu().detach().numpy()) + else: + model_output = model(passage=batch) + encoded.append(model_output.p_reps.cpu().detach().numpy()) + + encoded = np.concatenate(encoded) + + if encoding_file: + with open(encoding_file, "wb") as f: + pickle.dump((encoded, lookup_indices), f) + + return encoded diff --git a/pyproject.toml b/pyproject.toml index 2702809ca..58213c041 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ dependencies = [ "bert_score==0.3.13", "sacrebleu==2.3.1", "evaluate==0.4.0", + "tevatron==0.1.0", + "faiss-cpu==1.7.4" ] dynamic = ["version"] diff --git a/tests/dataset_generator_test.py b/tests/dataset_generator_test.py index fe8031a25..0c2cac2ea 100644 --- a/tests/dataset_generator_test.py +++ b/tests/dataset_generator_test.py @@ -103,7 +103,7 @@ def check_generate_dataset_dict(dataset_generator: OpenAIDatasetGenerator): "prompt2model.utils.ChatGPTAgent.generate_openai_chat_completion", side_effect=MOCK_CLASSIFICATION_EXAMPLE, ) -def test_limited_and_unlimited_generation(mocked_generate_example): +def test_encode_text(mocked_generate_example): """Test classification dataset generation using the OpenAIDatasetGenerator. This function first test the unlimited generation. 
diff --git a/tests/tevatron_utils_test.py b/tests/tevatron_utils_test.py
new file mode 100644
index 000000000..820a57247
--- /dev/null
+++ b/tests/tevatron_utils_test.py
@@ -0,0 +1,78 @@
+"""Testing the Tevatron utilities for encoding text with a dense encoder."""
+
+import json
+import os
+import pickle
+import tempfile
+
+import pytest
+from tevatron.modeling import DenseModelForInference
+from transformers import PreTrainedTokenizerBase
+
+from prompt2model.utils.tevatron_utils import encode_text
+from prompt2model.utils.tevatron_utils.encode import load_tevatron_model
+
+TINY_MODEL_NAME = "google/bert_uncased_L-2_H-128_A-2"
+
+
+def test_load_tevatron_model():
+    """Test loading a small Tevatron model."""
+    model, tokenizer = load_tevatron_model(TINY_MODEL_NAME)
+    assert isinstance(model, DenseModelForInference)
+    assert isinstance(tokenizer, PreTrainedTokenizerBase)
+
+
+def test_encode_text_from_string():
+    """Test encoding text from a string into a vector."""
+    text = "This is an example sentence"
+    encoded = encode_text(TINY_MODEL_NAME, text_to_encode=text)
+    assert encoded.shape == (1, 128)
+
+
+def test_encode_text_from_file():
+    """Test encoding text from a file into a vector."""
+    text_rows = [
+        {"text_id": 0, "text": "This is an example sentence"},
+        {"text_id": 1, "text": "This is another example sentence"},
+    ]
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as f:
+        json.dump(text_rows, f)
+        f.seek(0)
+        encoded = encode_text(TINY_MODEL_NAME, file_to_encode=f.name)
+        assert encoded.shape == (2, 128)
+
+
+def test_encode_text_from_file_store_to_file():
+    """Test encoding text from a file into a vector, then stored to file."""
+    text_rows = [
+        {"text_id": 0, "text": "This is an example sentence"},
+        {"text_id": 1, "text": "This is another example sentence"},
+    ]
+    with tempfile.TemporaryDirectory() as tempdir:
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".json") as f:
+            json.dump(text_rows, f)
+            f.seek(0)
+            encoding_file_path = os.path.join(tempdir, "encoding.pkl")
+            encoded = encode_text(
+                TINY_MODEL_NAME, file_to_encode=f.name, encoding_file=encoding_file_path
+            )
+            assert encoded.shape == (2, 128)
+            encoded_vectors, encoded_indices = pickle.load(
+                open(encoding_file_path, "rb")
+            )
+            assert (encoded == encoded_vectors).all()
+            assert encoded_indices == [0, 1]
+
+
+def test_encode_text_error_from_no_string_or_file():
+    """Test that either a string or a file must be passed to encode."""
+    with pytest.raises(ValueError):
+        _ = encode_text(TINY_MODEL_NAME)
+
+
+def test_encode_text_error_from_both_string_and_file():
+    """Test that either a string or a file, but not both, must be passed to encode."""
+    text = "This is an example sentence"
+    file = "/tmp/test.txt"
+    with pytest.raises(ValueError):
+        _ = encode_text(TINY_MODEL_NAME, file_to_encode=file, text_to_encode=text)
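
A minimal usage sketch of the new encode_text utility, for reviewers who want to see how the exported API is meant to be called. This snippet is not part of the diff; the checkpoint name is the tiny BERT model used in the tests above, and the (2, 128) output shape assumes that model's 128-dimensional embeddings.

# Sketch only: exercises the encode_text API introduced in this PR.
from prompt2model.utils import encode_text

# Tiny BERT checkpoint, as in tests/tevatron_utils_test.py (128-dim embeddings).
tiny_model = "google/bert_uncased_L-2_H-128_A-2"

# Encode two documents from in-memory strings; returns a (2, 128) numpy array.
doc_vectors = encode_text(
    tiny_model,
    text_to_encode=[
        "This is an example sentence",
        "This is another example sentence",
    ],
)
print(doc_vectors.shape)

# Passing encoding_file="encoding.pkl" would additionally pickle
# (vectors, lookup_indices) to that path, as the new tests verify.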