diff --git a/docs/_src/api/api/question_generator.md b/docs/_src/api/api/question_generator.md
index d2d77408af..378333f8f6 100644
--- a/docs/_src/api/api/question_generator.md
+++ b/docs/_src/api/api/question_generator.md
@@ -23,7 +23,7 @@ come from earlier in the document.
 #### QuestionGenerator.\_\_init\_\_
 
 ```python
-def __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, use_gpu=True, prompt="generate questions:", batch_size: Optional[int] = None)
+def __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, use_gpu=True, prompt="generate questions:", num_queries_per_doc=1, batch_size: Optional[int] = None)
 ```
 
 Uses the valhalla/t5-base-e2e-qg model by default. This class supports any question generation model that is
diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md
index 3f7ec06596..5e9f90f71b 100644
--- a/docs/_src/api/api/retriever.md
+++ b/docs/_src/api/api/retriever.md
@@ -1433,6 +1433,45 @@ Create embeddings for a list of documents.
 
 Embeddings, one per input document
 
+
+
+#### EmbeddingRetriever.train
+
+```python
+def train(training_data: List[Dict[str, Any]], learning_rate: float = 2e-5, n_epochs: int = 1, num_warmup_steps: int = None, batch_size: int = 16) -> None
+```
+
+Trains/adapts the underlying embedding model.
+
+Each training data example is a dictionary with the following keys:
+
+* question: the question string
+* pos_doc: the positive document string
+* neg_doc: the negative document string
+* score: the score margin
+
+**Arguments**:
+
+- `training_data` (`List[Dict[str, Any]]`): The training data
+- `learning_rate` (`float`): The learning rate
+- `n_epochs` (`int`): The number of epochs
+- `num_warmup_steps` (`int`): The number of warmup steps
+- `batch_size` (`int (optional)`): The batch size to use for the training, defaults to 16
+
+
+
+#### EmbeddingRetriever.save
+
+```python
+def save(save_dir: Union[Path, str]) -> None
+```
+
+Save the model to the given directory
+
+**Arguments**:
+
+- `save_dir` (`Union[Path, str]`): The directory where the model will be saved
+
 # Module text2sparql
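A quick usage sketch for the two new methods documented above (hypothetical data and illustrative score value; only sentence-transformers `EmbeddingRetriever` instances implement them, the other encoders raise `NotImplementedError`):

```python
# One GPL training example; "score" is the cross-encoder margin
# produced by PseudoLabelGenerator (the value here is made up).
training_data = [
    {
        "question": "What is the capital of Germany?",
        "pos_doc": "Berlin is the capital of Germany.",
        "neg_doc": "Paris is the capital of France.",
        "score": 7.8,
    },
]

retriever.train(training_data, learning_rate=2e-5, n_epochs=1, batch_size=16)
retriever.save("adapted_retriever")
```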
diff --git a/haystack/__init__.py b/haystack/__init__.py
index 57ada3489e..8b0d9900ef 100644
--- a/haystack/__init__.py
+++ b/haystack/__init__.py
@@ -79,6 +79,7 @@ def __getattr__(self, attr):
         retriever,
         summarizer,
         translator,
+        label_generator,
     )
 
     # Note that we ignore the ImportError here because if the user did not install
diff --git a/haystack/json-schemas/haystack-pipeline-1.1.0.schema.json b/haystack/json-schemas/haystack-pipeline-1.1.0.schema.json
index 31c5cb47fb..9169ca9fdb 100644
--- a/haystack/json-schemas/haystack-pipeline-1.1.0.schema.json
+++ b/haystack/json-schemas/haystack-pipeline-1.1.0.schema.json
@@ -110,6 +110,9 @@
             {
               "$ref": "#/definitions/PreProcessorComponent"
             },
+            {
+              "$ref": "#/definitions/PseudoLabelGeneratorComponent"
+            },
             {
               "$ref": "#/definitions/QuestionGeneratorComponent"
             },
@@ -2439,6 +2442,75 @@
       ],
       "additionalProperties": false
     },
+    "PseudoLabelGeneratorComponent": {
+      "type": "object",
+      "properties": {
+        "name": {
+          "title": "Name",
+          "description": "Custom name for the component. Helpful for visualization and debugging.",
+          "type": "string"
+        },
+        "type": {
+          "title": "Type",
+          "description": "Haystack Class name for the component.",
+          "type": "string",
+          "const": "PseudoLabelGenerator"
+        },
+        "params": {
+          "title": "Parameters",
+          "type": "object",
+          "properties": {
+            "question_producer": {
+              "title": "Question Producer",
+              "anyOf": [
+                {
+                  "type": "string"
+                },
+                {
+                  "type": "array",
+                  "items": {
+                    "type": "object",
+                    "additionalProperties": {
+                      "type": "string"
+                    }
+                  }
+                }
+              ]
+            },
+            "retriever": {
+              "title": "Retriever",
+              "type": "string"
+            },
+            "cross_encoder_model_name_or_path": {
+              "title": "Cross Encoder Model Name Or Path",
+              "default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
+              "type": "string"
+            },
+            "total_number_of_questions": {
+              "title": "Total Number Of Questions",
+              "default": 9223372036854775807,
+              "type": "integer"
+            },
+            "top_k": {
+              "title": "Top K",
+              "default": 10,
+              "type": "integer"
+            }
+          },
+          "required": [
+            "question_producer",
+            "retriever"
+          ],
+          "additionalProperties": false,
+          "description": "Each parameter can reference other components defined in the same YAML file."
+        }
+      },
+      "required": [
+        "type",
+        "name"
+      ],
+      "additionalProperties": false
+    },
     "QuestionGeneratorComponent": {
       "type": "object",
       "properties": {
diff --git a/haystack/json-schemas/haystack-pipeline-1.3.0.schema.json b/haystack/json-schemas/haystack-pipeline-1.3.0.schema.json
index e611e47b94..0eff0d55ab 100644
--- a/haystack/json-schemas/haystack-pipeline-1.3.0.schema.json
+++ b/haystack/json-schemas/haystack-pipeline-1.3.0.schema.json
@@ -113,6 +113,9 @@
             {
               "$ref": "#/definitions/PreProcessorComponent"
             },
+            {
+              "$ref": "#/definitions/PseudoLabelGeneratorComponent"
+            },
             {
               "$ref": "#/definitions/QuestionGeneratorComponent"
             },
@@ -2618,6 +2621,75 @@
       ],
       "additionalProperties": false
     },
+    "PseudoLabelGeneratorComponent": {
+      "type": "object",
+      "properties": {
+        "name": {
+          "title": "Name",
+          "description": "Custom name for the component. Helpful for visualization and debugging.",
+          "type": "string"
+        },
+        "type": {
+          "title": "Type",
+          "description": "Haystack Class name for the component.",
+          "type": "string",
+          "const": "PseudoLabelGenerator"
+        },
+        "params": {
+          "title": "Parameters",
+          "type": "object",
+          "properties": {
+            "question_producer": {
+              "title": "Question Producer",
+              "anyOf": [
+                {
+                  "type": "string"
+                },
+                {
+                  "type": "array",
+                  "items": {
+                    "type": "object",
+                    "additionalProperties": {
+                      "type": "string"
+                    }
+                  }
+                }
+              ]
+            },
+            "retriever": {
+              "title": "Retriever",
+              "type": "string"
+            },
+            "cross_encoder_model_name_or_path": {
+              "title": "Cross Encoder Model Name Or Path",
+              "default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
+              "type": "string"
+            },
+            "total_number_of_questions": {
+              "title": "Total Number Of Questions",
+              "default": 9223372036854775807,
+              "type": "integer"
+            },
+            "top_k": {
+              "title": "Top K",
+              "default": 10,
+              "type": "integer"
+            }
+          },
+          "required": [
+            "question_producer",
+            "retriever"
+          ],
+          "additionalProperties": false,
+          "description": "Each parameter can reference other components defined in the same YAML file."
+        }
+      },
+      "required": [
+        "type",
+        "name"
+      ],
+      "additionalProperties": false
+    },
     "QuestionGeneratorComponent": {
       "type": "object",
       "properties": {
diff --git a/haystack/json-schemas/haystack-pipeline-master.schema.json b/haystack/json-schemas/haystack-pipeline-master.schema.json
index fa117e876c..8ea55af33e 100644
--- a/haystack/json-schemas/haystack-pipeline-master.schema.json
+++ b/haystack/json-schemas/haystack-pipeline-master.schema.json
@@ -127,6 +127,9 @@
             {
               "$ref": "#/definitions/PreProcessorComponent"
             },
+            {
+              "$ref": "#/definitions/PseudoLabelGeneratorComponent"
+            },
             {
               "$ref": "#/definitions/QuestionGeneratorComponent"
             },
@@ -3193,6 +3196,85 @@
       ],
       "additionalProperties": false
     },
+    "PseudoLabelGeneratorComponent": {
+      "type": "object",
+      "properties": {
+        "name": {
+          "title": "Name",
+          "description": "Custom name for the component. Helpful for visualization and debugging.",
+          "type": "string"
+        },
+        "type": {
+          "title": "Type",
+          "description": "Haystack Class name for the component.",
+          "type": "string",
+          "const": "PseudoLabelGenerator"
+        },
+        "params": {
+          "title": "Parameters",
+          "type": "object",
+          "properties": {
+            "question_producer": {
+              "title": "Question Producer",
+              "anyOf": [
+                {
+                  "type": "string"
+                },
+                {
+                  "type": "array",
+                  "items": {
+                    "type": "object",
+                    "additionalProperties": {
+                      "type": "string"
+                    }
+                  }
+                }
+              ]
+            },
+            "retriever": {
+              "title": "Retriever",
+              "type": "string"
+            },
+            "cross_encoder_model_name_or_path": {
+              "title": "Cross Encoder Model Name Or Path",
+              "default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
+              "type": "string"
+            },
+            "max_questions_per_document": {
+              "title": "Max Questions Per Document",
+              "default": 3,
+              "type": "integer"
+            },
+            "top_k": {
+              "title": "Top K",
+              "default": 50,
+              "type": "integer"
+            },
+            "batch_size": {
+              "title": "Batch Size",
+              "default": 4,
+              "type": "integer"
+            },
+            "progress_bar": {
+              "title": "Progress Bar",
+              "default": true,
+              "type": "boolean"
+            }
+          },
+          "required": [
+            "question_producer",
+            "retriever"
+          ],
+          "additionalProperties": false,
+          "description": "Each parameter can reference other components defined in the same YAML file."
+        }
+      },
+      "required": [
+        "type",
+        "name"
+      ],
+      "additionalProperties": false
+    },
     "QuestionGeneratorComponent": {
       "type": "object",
       "properties": {
@@ -3254,6 +3336,10 @@
             "title": "Prompt",
             "default": "generate questions:"
           },
+          "num_queries_per_doc": {
+            "title": "Num Queries Per Doc",
+            "default": 1
+          },
           "batch_size": {
             "title": "Batch Size",
             "type": "integer"
diff --git a/haystack/nodes/__init__.py b/haystack/nodes/__init__.py
index 3642707e22..ce59531acc 100644
--- a/haystack/nodes/__init__.py
+++ b/haystack/nodes/__init__.py
@@ -20,6 +20,7 @@
     AzureConverter,
     ParsrConverter,
 )
+from haystack.nodes.label_generator import PseudoLabelGenerator
 from haystack.nodes.other import Docs2Answers, JoinDocuments, RouteDocuments, JoinAnswers
 from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor
 from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
diff --git a/haystack/nodes/label_generator/__init__.py b/haystack/nodes/label_generator/__init__.py
new file mode 100644
index 0000000000..53ec567c4c
--- /dev/null
+++ b/haystack/nodes/label_generator/__init__.py
@@ -0,0 +1 @@
+from haystack.nodes.label_generator.pseudo_label_generator import PseudoLabelGenerator
diff --git a/haystack/nodes/label_generator/pseudo_label_generator.py b/haystack/nodes/label_generator/pseudo_label_generator.py
new file mode 100644
index 0000000000..c4370ecd8d
--- /dev/null
+++ b/haystack/nodes/label_generator/pseudo_label_generator.py
@@ -0,0 +1,235 @@
+import random
+from typing import Dict, Iterable, List, Optional, Tuple, Union
+
+from sentence_transformers import CrossEncoder
+from tqdm.auto import tqdm
+from haystack.nodes.base import BaseComponent
+from haystack.nodes.question_generator import QuestionGenerator
+from haystack.nodes.retriever.base import BaseRetriever
+from haystack.schema import Document
+
+
+class PseudoLabelGenerator(BaseComponent):
+    """
+    The PseudoLabelGenerator is a component that creates Generative Pseudo Labeling (GPL) training data for the
+    adaptation of dense retrievers.
+
+    GPL is an unsupervised domain adaptation method for dense retrievers. It is based on question
+    generation and pseudo labeling with powerful cross-encoders. To train a domain-adapted model, it needs access
+    to an unlabeled target corpus, usually available through a DocumentStore, and a Retriever for mining negatives.
+
+    For more details, see [https://github.com/UKPLab/gpl](https://github.com/UKPLab/gpl).
+
+    For example:
+
+    ```python
+    |   document_store = DocumentStore(...)
+    |   retriever = Retriever(...)
+    |   qg = QuestionGenerator(model_name_or_path="doc2query/msmarco-t5-base-v1")
+    |   psg = PseudoLabelGenerator(qg, retriever)
+    |   output, output_id = psg.run(documents=document_store.get_all_documents())
+    |
+    ```
+    """
+
+    def __init__(
+        self,
+        question_producer: Union[QuestionGenerator, List[Dict[str, str]]],
+        retriever: BaseRetriever,
+        cross_encoder_model_name_or_path: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
+        max_questions_per_document: int = 3,
+        top_k: int = 50,
+        batch_size: int = 4,
+        progress_bar: bool = True,
+    ):
+        """
+        Loads the cross encoder model and prepares the PseudoLabelGenerator.
+
+        :param question_producer: The question producer used to generate questions, or a list of already produced
+            question/document pairs in Dict format {"question": "question text ...", "document": "document text ..."}.
+        :type question_producer: Union[QuestionGenerator, List[Dict[str, str]]]
+        :param retriever: The retriever used to query document stores.
+        :type retriever: BaseRetriever
+        :param cross_encoder_model_name_or_path: The path to the cross encoder model, defaults to
+            cross-encoder/ms-marco-MiniLM-L-6-v2.
+        :type cross_encoder_model_name_or_path: str (optional)
+        :param max_questions_per_document: The maximum number of questions generated per document, defaults to 3.
+        :type max_questions_per_document: int
+        :param top_k: The number of documents retrieved for each question, defaults to 50.
+        :type top_k: int (optional)
+        :param batch_size: The number of documents to process at a time.
+        :type batch_size: int (optional)
+        :param progress_bar: Whether to show a progress bar, defaults to True.
+        :type progress_bar: bool (optional)
+        """
+
+        super().__init__()
+        self.question_document_pairs = None
+        self.question_generator = None  # type: ignore
+        if isinstance(question_producer, QuestionGenerator):
+            self.question_generator = question_producer
+        elif isinstance(question_producer, list) and len(question_producer) > 0:
+            example = question_producer[0]
+            if isinstance(example, dict) and "question" in example and "document" in example:
+                self.question_document_pairs = question_producer
+            else:
+                raise ValueError("question_producer list must contain dicts with keys 'question' and 'document'")
+        else:
+            raise ValueError("Provide either a QuestionGenerator or a nonempty list of question/document pairs")
+
+        self.retriever = retriever
+        self.cross_encoder = CrossEncoder(cross_encoder_model_name_or_path)
+        self.max_questions_per_document = max_questions_per_document
+        self.top_k = top_k
+        self.batch_size = batch_size
+        self.progress_bar = progress_bar
+
+    def generate_questions(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Dict[str, str]]:
+        """
+        Takes a list of documents and generates a list of question-document pairs.
+
+        :param documents: A list of documents to generate questions from.
+        :type documents: List[Document]
+        :param batch_size: The number of documents to process at a time.
+        :type batch_size: Optional[int]
+        :return: A list of question-document pairs.
+        """
+        question_doc_pairs: List[Dict[str, str]] = []
+        if self.question_document_pairs:
+            question_doc_pairs = self.question_document_pairs
+        else:
+            batch_size = batch_size if batch_size else self.batch_size
+            questions: List[List[str]] = self.question_generator.generate_batch(  # type: ignore
+                [d.content for d in documents], batch_size=batch_size
+            )
+            for idx, question_list_per_doc in enumerate(questions):
+                for q in question_list_per_doc[: self.max_questions_per_document]:  # type: ignore
+                    question_doc_pairs.append({"question": q.strip(), "document": documents[idx].content})
+        return question_doc_pairs
+
+    def mine_negatives(
+        self, question_doc_pairs: List[Dict[str, str]], batch_size: Optional[int] = None
+    ) -> List[Dict[str, str]]:
+        """
+        Given a list of question/pos_doc pairs, this function returns a list of question/pos_doc/neg_doc
+        dictionaries.
+
+        :param question_doc_pairs: A list of question/pos_doc pairs.
+        :type question_doc_pairs: List[Dict[str, str]]
+        :param batch_size: The number of queries to run in a batch.
+        :type batch_size: int (optional)
+        :return: A list of dictionaries, where each dictionary contains the question, positive document,
+            and negative document.
+        """
+        question_pos_doc_neg_doc: List[Dict[str, str]] = []
+        batch_size = batch_size if batch_size else self.batch_size
+
+        for i in tqdm(
+            range(0, len(question_doc_pairs), batch_size), disable=not self.progress_bar, desc="Mine negatives"
+        ):
+            # query in batches to minimize network latency
+            i_end = min(i + batch_size, len(question_doc_pairs))
+            queries: List[str] = [e["question"] for e in question_doc_pairs[i:i_end]]
+            pos_docs: List[str] = [e["document"] for e in question_doc_pairs[i:i_end]]
+
+            docs: List[List[Document]] = self.retriever.retrieve_batch(
+                queries=queries, top_k=self.top_k, batch_size=batch_size
+            )
+
+            # iterate through queries and find negatives
+            for question, pos_doc, top_docs in zip(queries, pos_docs, docs):
+                random.shuffle(top_docs)
+                for doc_item in top_docs:
+                    neg_doc = doc_item.content
+                    if neg_doc != pos_doc:
+                        question_pos_doc_neg_doc.append({"question": question, "pos_doc": pos_doc, "neg_doc": neg_doc})
+                        break
+        return question_pos_doc_neg_doc
+
+    def generate_margin_scores(
+        self, mined_negatives: List[Dict[str, str]], batch_size: Optional[int] = None
+    ) -> List[Dict]:
+        """
+        Given a list of mined negatives, predict the score margin between the positive and negative document using
+        the cross encoder.
+
+        The function returns a list of examples, where each example is a dictionary with the following keys:
+
+        * question: the question string
+        * pos_doc: the positive document string
+        * neg_doc: the negative document string
+        * score: the score margin
+
+        :param mined_negatives: The list of mined negatives.
+        :type mined_negatives: List[Dict[str, str]]
+        :param batch_size: The number of mined negatives to process in a batch.
+        :type batch_size: int (optional)
+        :return: A list of dictionaries, each of which has the following keys:
+            - question: The question string
+            - pos_doc: The positive document string
+            - neg_doc: The negative document string
+            - score: The score margin
+        """
+        examples: List[Dict] = []
+        batch_size = batch_size if batch_size else self.batch_size
+        for i in tqdm(range(0, len(mined_negatives), batch_size), disable=not self.progress_bar, desc="Score margin"):
+            negatives_batch = mined_negatives[i : i + batch_size]
+            pb = []
+            for item in negatives_batch:
+                pb.append([item["question"], item["pos_doc"]])
+                pb.append([item["question"], item["neg_doc"]])
+            scores = self.cross_encoder.predict(pb)
+            for idx, item in enumerate(negatives_batch):
+                scores_idx = idx * 2
+                score_margin = scores[scores_idx] - scores[scores_idx + 1]
+                examples.append(
+                    {
+                        "question": item["question"],
+                        "pos_doc": item["pos_doc"],
+                        "neg_doc": item["neg_doc"],
+                        "score": score_margin,
+                    }
+                )
+        return examples
+
+    def generate_pseudo_labels(self, documents: List[Document], batch_size: Optional[int] = None) -> Tuple[dict, str]:
+        """
+        Given a list of documents, generate a list of question-document pairs, mine negatives, and
+        score the positive/negative margin with the cross-encoder. The output is the training data for the
+        adaptation of dense retriever models.
+
+        :param documents: The documents to generate GPL training data from.
+        :type documents: List[Document]
+        :param batch_size: The number of documents to process in a batch.
+        :type batch_size: Optional[int]
+        :return: A dictionary with a single key 'gpl_labels' representing a list of dictionaries, where each
+            dictionary contains the following keys:
+            - question: the question
+            - pos_doc: the positive document for the given question
+            - neg_doc: the negative document for the given question
+            - score: the margin score (a float)
+        """
+        # see https://github.com/UKPLab/gpl for more information about the GPL algorithm
+        batch_size = batch_size if batch_size else self.batch_size
+
+        # step 1: generate questions
+        question_doc_pairs = self.generate_questions(documents=documents, batch_size=batch_size)
+
+        # step 2: negative mining
+        mined_negatives = self.mine_negatives(question_doc_pairs=question_doc_pairs, batch_size=batch_size)
+
+        # step 3: pseudo labeling (scoring) with the cross-encoder
+        pseudo_labels: List[Dict[str, str]] = self.generate_margin_scores(mined_negatives, batch_size=batch_size)
+        return {"gpl_labels": pseudo_labels}, "output_1"
+
+    def run(self, documents: List[Document]) -> Tuple[dict, str]:  # type: ignore
+        return self.generate_pseudo_labels(documents=documents)
+
+    def run_batch(self, documents: Union[List[Document], List[List[Document]]]) -> Tuple[dict, str]:  # type: ignore
+        flat_list_of_documents = []
+        for sub_list_documents in documents:
+            if isinstance(sub_list_documents, Iterable):
+                flat_list_of_documents += sub_list_documents
+            else:
+                flat_list_of_documents.append(sub_list_documents)
+        return self.generate_pseudo_labels(documents=flat_list_of_documents)
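For orientation, the three stages of the new component hand typed dictionaries from one to the next. A minimal sketch of the intermediate shapes (all values are illustrative, not taken from a real run):

```python
# Output of generate_questions(): one dict per generated question.
question_doc_pairs = [
    {"question": "What is the capital of Germany?",
     "document": "Berlin is the capital and largest city of Germany."},
]

# Output of mine_negatives(): each pair gains a retrieved negative document.
mined_negatives = [
    {"question": "What is the capital of Germany?",
     "pos_doc": "Berlin is the capital and largest city of Germany.",
     "neg_doc": "Paris is the capital and most populous city of France."},
]

# Output of generate_margin_scores(): the cross-encoder margin
# CE(question, pos_doc) - CE(question, neg_doc) is added as "score".
gpl_labels = [
    {"question": "What is the capital of Germany?",
     "pos_doc": "Berlin is the capital and largest city of Germany.",
     "neg_doc": "Paris is the capital and most populous city of France.",
     "score": 7.8},
]
```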
diff --git a/haystack/nodes/question_generator/question_generator.py b/haystack/nodes/question_generator/question_generator.py
index d2ab7081a9..7ca6a47735 100644
--- a/haystack/nodes/question_generator/question_generator.py
+++ b/haystack/nodes/question_generator/question_generator.py
@@ -37,6 +37,7 @@ def __init__(
         split_overlap=10,
         use_gpu=True,
         prompt="generate questions:",
+        num_queries_per_doc=1,
         batch_size: Optional[int] = None,
     ):
         """
@@ -65,6 +66,7 @@ def __init__(
         self.split_overlap = split_overlap
         self.preprocessor = PreProcessor()
         self.prompt = prompt
+        self.num_queries_per_doc = num_queries_per_doc
         self.batch_size = batch_size
 
     def run(self, documents: List[Document]):  # type: ignore
@@ -122,6 +124,7 @@ def generate(self, text: str) -> List[str]:
             no_repeat_ngram_size=self.no_repeat_ngram_size,
             length_penalty=self.length_penalty,
             early_stopping=self.early_stopping,
+            num_return_sequences=self.num_queries_per_doc,
         )
 
         string_output = self.tokenizer.batch_decode(tokens_output)
@@ -190,6 +193,7 @@ def generate_batch(
             no_repeat_ngram_size=self.no_repeat_ngram_size,
             length_penalty=self.length_penalty,
             early_stopping=self.early_stopping,
+            num_return_sequences=self.num_queries_per_doc,
         )
 
         string_output = self.tokenizer.batch_decode(tokens_output)
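The new `num_queries_per_doc` parameter is passed straight through to Hugging Face `generate()` as `num_return_sequences`. With beam search, `transformers` requires `num_return_sequences` to be at most `num_beams` (4 by default here), so callers should keep the two consistent. A hypothetical usage sketch:

```python
from haystack.nodes import QuestionGenerator

# Ask the generator for several queries per document; 3 <= num_beams (4),
# so the underlying beam search can return that many sequences.
qg = QuestionGenerator(
    model_name_or_path="doc2query/msmarco-t5-base-v1",
    num_queries_per_doc=3,
)
```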
diff --git a/haystack/nodes/retriever/_embedding_encoder.py b/haystack/nodes/retriever/_embedding_encoder.py
index 918d3277d3..c434959bcd 100644
--- a/haystack/nodes/retriever/_embedding_encoder.py
+++ b/haystack/nodes/retriever/_embedding_encoder.py
@@ -1,17 +1,20 @@
-from typing import TYPE_CHECKING, Callable, List, Union, Dict
-
 import logging
 from abc import abstractmethod
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Union
+
 import numpy as np
-from tqdm.auto import tqdm
 import torch
+from sentence_transformers import InputExample, losses
+from torch.utils.data import DataLoader
 from torch.utils.data.sampler import SequentialSampler
-from transformers import AutoTokenizer, AutoModel
+from tqdm.auto import tqdm
+from transformers import AutoModel, AutoTokenizer
 
-from haystack.schema import Document
+from haystack.modeling.data_handler.dataloader import NamedDataLoader
 from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename
 from haystack.modeling.infer import Inferencer
-from haystack.modeling.data_handler.dataloader import NamedDataLoader
+from haystack.schema import Document
 
 if TYPE_CHECKING:
     from haystack.nodes.retriever import EmbeddingRetriever
@@ -41,6 +44,47 @@ def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
         """
         pass
 
+    def train(
+        self,
+        training_data: List[Dict[str, Any]],
+        learning_rate: float = 2e-5,
+        n_epochs: int = 1,
+        num_warmup_steps: int = None,
+        batch_size: int = 16,
+    ):
+        """
+        Trains/adapts the underlying embedding model.
+
+        Each training data example is a dictionary with the following keys:
+
+        * question: the question string
+        * pos_doc: the positive document string
+        * neg_doc: the negative document string
+        * score: the score margin
+
+
+        :param training_data: The training data
+        :type training_data: List[Dict[str, Any]]
+        :param learning_rate: The learning rate
+        :type learning_rate: float
+        :param n_epochs: The number of training epochs
+        :type n_epochs: int
+        :param num_warmup_steps: The number of warmup steps
+        :type num_warmup_steps: int
+        :param batch_size: The batch size to use for the training, defaults to 16
+        :type batch_size: int (optional)
+        """
+        pass
+
+    def save(self, save_dir: Union[Path, str]):
+        """
+        Save the model to the given directory
+
+        :param save_dir: The directory where the model will be saved
+        :type save_dir: Union[Path, str]
+        """
+        pass
+
 
 class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
     def __init__(self, retriever: "EmbeddingRetriever"):
@@ -87,6 +131,19 @@ def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
         passages = [d.content for d in docs]  # type: ignore
         return self.embed(passages)
 
+    def train(
+        self,
+        training_data: List[Dict[str, Any]],
+        learning_rate: float = 2e-5,
+        n_epochs: int = 1,
+        num_warmup_steps: int = None,
+        batch_size: int = 16,
+    ):
+        raise NotImplementedError("train method can only be used with sentence-transformers EmbeddingRetriever(s)")
+
+    def save(self, save_dir: Union[Path, str]):
+        raise NotImplementedError("save method can only be used with sentence-transformers EmbeddingRetriever(s)")
+
 
 class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
     def __init__(self, retriever: "EmbeddingRetriever"):
@@ -127,6 +184,33 @@ def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
         passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.content] for d in docs]  # type: ignore
         return self.embed(passages)
 
+    def train(
+        self,
+        training_data: List[Dict[str, Any]],
+        learning_rate: float = 2e-5,
+        n_epochs: int = 1,
+        num_warmup_steps: int = None,
+        batch_size: int = 16,
+    ):
+
+        train_examples = [
+            InputExample(texts=[i["question"], i["pos_doc"], i["neg_doc"]], label=i["score"]) for i in training_data
+        ]
+        logger.info(f"GPL training/adapting {self.embedding_model} with {len(train_examples)} examples")
+        train_dataloader = DataLoader(train_examples, batch_size=batch_size, drop_last=True, shuffle=True)
+        train_loss = losses.MarginMSELoss(self.embedding_model)
+
+        # Tune the model
+        self.embedding_model.fit(
+            train_objectives=[(train_dataloader, train_loss)],
+            epochs=n_epochs,
+            optimizer_params={"lr": learning_rate},
+            warmup_steps=int(len(train_dataloader) * 0.1) if num_warmup_steps is None else num_warmup_steps,
+        )
+
+    def save(self, save_dir: Union[Path, str]):
+        self.embedding_model.save(path=str(save_dir))
+
 
 class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
     def __init__(self, retriever: "EmbeddingRetriever"):
@@ -208,6 +292,19 @@ def dataset_from_dicts(self, dicts: List[dict]):
         dataset, tensornames = convert_features_to_dataset(features=features_flat)
         return dataset, tensornames
 
+    def train(
+        self,
+        training_data: List[Dict[str, Any]],
+        learning_rate: float = 2e-5,
+        n_epochs: int = 1,
+        num_warmup_steps: int = None,
+        batch_size: int = 16,
+    ):
+        raise NotImplementedError("train method can only be used with sentence-transformers EmbeddingRetriever(s)")
+
+    def save(self, save_dir: Union[Path, str]):
+        raise NotImplementedError("save method can only be used with sentence-transformers EmbeddingRetriever(s)")
+
 
 _EMBEDDING_ENCODERS: Dict[str, Callable] = {
     "farm": _DefaultEmbeddingEncoder,
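The sentence-transformers training path above distills the cross-encoder's judgment into the bi-encoder: `losses.MarginMSELoss` regresses the bi-encoder's score margin onto the cross-encoder margin stored in `score`. A sketch of the objective only (plain tensors stand in for the bi-encoder embeddings; the real batching and pooling live inside sentence-transformers):

```python
import torch
import torch.nn.functional as F

def margin_mse(q: torch.Tensor, pos: torch.Tensor, neg: torch.Tensor, target_margin: float) -> torch.Tensor:
    # Student margin: dot(q, pos) - dot(q, neg), pulled toward the teacher
    # cross-encoder margin produced by PseudoLabelGenerator (the "score" field).
    student_margin = torch.dot(q, pos) - torch.dot(q, neg)
    return F.mse_loss(student_margin, torch.tensor(target_margin))

# Illustrative call with random 768-dim embeddings and a made-up target margin.
loss = margin_mse(torch.randn(768), torch.randn(768), torch.randn(768), target_margin=7.8)
```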
diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py
index ed568da4cf..c721cbb0b3 100644
--- a/haystack/nodes/retriever/dense.py
+++ b/haystack/nodes/retriever/dense.py
@@ -1,4 +1,4 @@
-from typing import List, Dict, Union, Optional
+from typing import List, Dict, Union, Optional, Any
 
 import logging
 from pathlib import Path
@@ -1851,3 +1851,50 @@ def _infer_model_format(model_name_or_path: str) -> str:
         # Model is neither sentence-transformers nor retribert model -> use _DefaultEmbeddingEncoder
         return "farm"
+
+    def train(
+        self,
+        training_data: List[Dict[str, Any]],
+        learning_rate: float = 2e-5,
+        n_epochs: int = 1,
+        num_warmup_steps: int = None,
+        batch_size: int = 16,
+    ) -> None:
+        """
+        Trains/adapts the underlying embedding model.
+
+        Each training data example is a dictionary with the following keys:
+
+        * question: the question string
+        * pos_doc: the positive document string
+        * neg_doc: the negative document string
+        * score: the score margin
+
+
+        :param training_data: The training data
+        :type training_data: List[Dict[str, Any]]
+        :param learning_rate: The learning rate
+        :type learning_rate: float
+        :param n_epochs: The number of epochs
+        :type n_epochs: int
+        :param num_warmup_steps: The number of warmup steps
+        :type num_warmup_steps: int
+        :param batch_size: The batch size to use for the training, defaults to 16
+        :type batch_size: int (optional)
+        """
+        self.embedding_encoder.train(
+            training_data,
+            learning_rate=learning_rate,
+            n_epochs=n_epochs,
+            num_warmup_steps=num_warmup_steps,
+            batch_size=batch_size,
+        )
+
+    def save(self, save_dir: Union[Path, str]) -> None:
+        """
+        Save the model to the given directory
+
+        :param save_dir: The directory where the model will be saved
+        :type save_dir: Union[Path, str]
+        """
+        self.embedding_encoder.save(save_dir=save_dir)
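Taken together, the new public API composes into a complete GPL loop. A minimal end-to-end sketch (assumes the optional sentence-transformers dependencies are installed; model names follow the defaults used in this diff, and the corpus here is a single toy document):

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever, PseudoLabelGenerator, QuestionGenerator

# Unlabeled target-domain corpus.
document_store = InMemoryDocumentStore()
document_store.write_documents([{"content": "Berlin is the capital and largest city of Germany."}])

# train()/save() require a sentence-transformers retriever (see the encoders above).
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
    model_format="sentence_transformers",
)
document_store.update_embeddings(retriever)

# GPL steps 1-3: question generation, negative mining, cross-encoder margin scoring.
question_generator = QuestionGenerator(model_name_or_path="doc2query/msmarco-t5-base-v1")
pseudo_label_generator = PseudoLabelGenerator(question_generator, retriever)
output, _ = pseudo_label_generator.run(documents=document_store.get_all_documents())

# Adapt the dense retriever on the pseudo labels and persist the result.
retriever.train(output["gpl_labels"])
retriever.save("adapted_retriever")
```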
diff --git a/test/conftest.py b/test/conftest.py
index ded06bb3e1..c4eb70f303 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -672,6 +672,13 @@ def get_retriever(retriever_type, document_store):
         retriever = EmbeddingRetriever(
             document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False
         )
+    elif retriever_type == "embedding_sbert":
+        retriever = EmbeddingRetriever(
+            document_store=document_store,
+            embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
+            model_format="sentence_transformers",
+            use_gpu=False,
+        )
     elif retriever_type == "retribert":
         retriever = EmbeddingRetriever(
             document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False
diff --git a/test/nodes/test_label_generator.py b/test/nodes/test_label_generator.py
new file mode 100644
index 0000000000..070fd94d6c
--- /dev/null
+++ b/test/nodes/test_label_generator.py
@@ -0,0 +1,62 @@
+from pathlib import Path
+
+import pytest
+
+from haystack.nodes import QuestionGenerator, EmbeddingRetriever, PseudoLabelGenerator
+from test.conftest import DOCS_WITH_EMBEDDINGS
+
+
+@pytest.mark.slow
+@pytest.mark.generator
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+@pytest.mark.parametrize("retriever", ["embedding_sbert"], indirect=True)
+def test_pseudo_label_generator(
+    document_store, retriever: EmbeddingRetriever, question_generator: QuestionGenerator, tmp_path: Path
+):
+    document_store.write_documents(DOCS_WITH_EMBEDDINGS)
+    psg = PseudoLabelGenerator(question_generator, retriever)
+    train_examples = []
+    for doc in document_store.get_all_documents():
+        output, _ = psg.run(documents=[doc])
+        assert "gpl_labels" in output
+        for item in output["gpl_labels"]:
+            assert "question" in item and "pos_doc" in item and "neg_doc" in item and "score" in item
+            train_examples.append(item)
+
+    assert len(train_examples) > 0
+    retriever.train(train_examples)
+    retriever.save(tmp_path)
+
+
+@pytest.mark.slow
+@pytest.mark.generator
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+@pytest.mark.parametrize("retriever", ["embedding_sbert"], indirect=True)
+def test_pseudo_label_generator_using_question_document_pairs(
+    document_store, retriever: EmbeddingRetriever, tmp_path: Path
+):
+    document_store.write_documents(DOCS_WITH_EMBEDDINGS)
+    docs = [
+        {
+            "question": "What is the capital of Germany?",
+            "document": "Berlin is the capital and largest city of Germany by both area and population.",
+        },
+        {
+            "question": "What is the largest city in Germany by population and area?",
+            "document": "Berlin is the capital and largest city of Germany by both area and population.",
+        },
+    ]
+    psg = PseudoLabelGenerator(docs, retriever)
+    train_examples = []
+    for doc in document_store.get_all_documents():
+        # the documents passed here are ignored because question/document pairs were given in the constructor
+        output, _ = psg.run(documents=[doc])
+        assert "gpl_labels" in output
+        for item in output["gpl_labels"]:
+            assert "question" in item and "pos_doc" in item and "neg_doc" in item and "score" in item
+            train_examples.append(item)
+
+    assert len(train_examples) > 0
+
+    retriever.train(train_examples)
+    retriever.save(tmp_path)
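After `save()`, the adapted model is a plain sentence-transformers directory, so it can be loaded back the same way the `embedding_sbert` fixture above constructs its retriever. A sketch, continuing the earlier example and assuming `adapted_retriever` is the directory passed to `save()`:

```python
# Load the domain-adapted model from the local directory written by save().
adapted = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="adapted_retriever",
    model_format="sentence_transformers",
)
```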