diff --git a/docs/_src/api/api/question_generator.md b/docs/_src/api/api/question_generator.md
index d2d77408af..378333f8f6 100644
--- a/docs/_src/api/api/question_generator.md
+++ b/docs/_src/api/api/question_generator.md
@@ -23,7 +23,7 @@ come from earlier in the document.
#### QuestionGenerator.\_\_init\_\_
```python
-def __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, use_gpu=True, prompt="generate questions:", batch_size: Optional[int] = None)
+def __init__(model_name_or_path="valhalla/t5-base-e2e-qg", model_version=None, num_beams=4, max_length=256, no_repeat_ngram_size=3, length_penalty=1.5, early_stopping=True, split_length=50, split_overlap=10, use_gpu=True, prompt="generate questions:", num_queries_per_doc=1, batch_size: Optional[int] = None)
```
Uses the valhalla/t5-base-e2e-qg model by default. This class supports any question generation model that is
diff --git a/docs/_src/api/api/retriever.md b/docs/_src/api/api/retriever.md
index 3f7ec06596..5e9f90f71b 100644
--- a/docs/_src/api/api/retriever.md
+++ b/docs/_src/api/api/retriever.md
@@ -1433,6 +1433,45 @@ Create embeddings for a list of documents.
Embeddings, one per input document
+
+
+#### EmbeddingRetriever.train
+
+```python
+def train(training_data: List[Dict[str, Any]], learning_rate: float = 2e-5, n_epochs: int = 1, num_warmup_steps: Optional[int] = None, batch_size: int = 16) -> None
+```
+
+Trains/adapts the underlying embedding model.
+
+Each training data example is a dictionary with the following keys:
+
+* question: the question string
+* pos_doc: the positive document string
+* neg_doc: the negative document string
+* score: the score margin
+
+**Arguments**:
+
+- `training_data` (`List[Dict[str, Any]]`): The training data
+- `learning_rate` (`float`): The learning rate
+- `n_epochs` (`int`): The number of epochs
+- `num_warmup_steps` (`Optional[int]`): The number of warmup steps
+- `batch_size` (`int (optional)`): The batch size to use for the training, defaults to 16
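+
+For example, a minimal sketch (the document store, model name, and data below are illustrative; `train` is only implemented for retrievers with `model_format="sentence_transformers"`):
+
+```python
+from haystack.document_stores import InMemoryDocumentStore
+from haystack.nodes import EmbeddingRetriever
+
+retriever = EmbeddingRetriever(
+    document_store=InMemoryDocumentStore(),
+    embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
+    model_format="sentence_transformers",
+)
+# each example pairs a question with a positive and a negative document plus a score margin
+training_data = [
+    {
+        "question": "What is the capital of Germany?",
+        "pos_doc": "Berlin is the capital and largest city of Germany.",
+        "neg_doc": "Paris is the capital and most populous city of France.",
+        "score": 7.31,
+    },
+]
+retriever.train(training_data, n_epochs=1, batch_size=16)
+```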
+
+
+
+#### EmbeddingRetriever.save
+
+```python
+def save(save_dir: Union[Path, str]) -> None
+```
+
+Save the model to the given directory.
+
+**Arguments**:
+
+- `save_dir` (`Union[Path, str]`): The directory where the model will be saved
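+
+For example (the directory name is illustrative):
+
+```python
+retriever.save(save_dir="adapted_retriever")
+```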
+
# Module text2sparql
diff --git a/haystack/__init__.py b/haystack/__init__.py
index 57ada3489e..8b0d9900ef 100644
--- a/haystack/__init__.py
+++ b/haystack/__init__.py
@@ -79,6 +79,7 @@ def __getattr__(self, attr):
retriever,
summarizer,
translator,
+ label_generator,
)
# Note that we ignore the ImportError here because if the user did not install
diff --git a/haystack/json-schemas/haystack-pipeline-1.1.0.schema.json b/haystack/json-schemas/haystack-pipeline-1.1.0.schema.json
index 31c5cb47fb..9169ca9fdb 100644
--- a/haystack/json-schemas/haystack-pipeline-1.1.0.schema.json
+++ b/haystack/json-schemas/haystack-pipeline-1.1.0.schema.json
@@ -110,6 +110,9 @@
{
"$ref": "#/definitions/PreProcessorComponent"
},
+ {
+ "$ref": "#/definitions/PseudoLabelGeneratorComponent"
+ },
{
"$ref": "#/definitions/QuestionGeneratorComponent"
},
@@ -2439,6 +2442,75 @@
],
"additionalProperties": false
},
+ "PseudoLabelGeneratorComponent": {
+ "type": "object",
+ "properties": {
+ "name": {
+ "title": "Name",
+ "description": "Custom name for the component. Helpful for visualization and debugging.",
+ "type": "string"
+ },
+ "type": {
+ "title": "Type",
+ "description": "Haystack Class name for the component.",
+ "type": "string",
+ "const": "PseudoLabelGenerator"
+ },
+ "params": {
+ "title": "Parameters",
+ "type": "object",
+ "properties": {
+ "question_producer": {
+ "title": "Question Producer",
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ }
+ }
+ }
+ ]
+ },
+ "retriever": {
+ "title": "Retriever",
+ "type": "string"
+ },
+ "cross_encoder_model_name_or_path": {
+ "title": "Cross Encoder Model Name Or Path",
+ "default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
+ "type": "string"
+ },
+ "total_number_of_questions": {
+ "title": "Total Number Of Questions",
+ "default": 9223372036854775807,
+ "type": "integer"
+ },
+ "top_k": {
+ "title": "Top K",
+ "default": 10,
+ "type": "integer"
+ }
+ },
+ "required": [
+ "question_producer",
+ "retriever"
+ ],
+ "additionalProperties": false,
+ "description": "Each parameter can reference other components defined in the same YAML file."
+ }
+ },
+ "required": [
+ "type",
+ "name"
+ ],
+ "additionalProperties": false
+ },
"QuestionGeneratorComponent": {
"type": "object",
"properties": {
diff --git a/haystack/json-schemas/haystack-pipeline-1.3.0.schema.json b/haystack/json-schemas/haystack-pipeline-1.3.0.schema.json
index e611e47b94..0eff0d55ab 100644
--- a/haystack/json-schemas/haystack-pipeline-1.3.0.schema.json
+++ b/haystack/json-schemas/haystack-pipeline-1.3.0.schema.json
@@ -113,6 +113,9 @@
{
"$ref": "#/definitions/PreProcessorComponent"
},
+ {
+ "$ref": "#/definitions/PseudoLabelGeneratorComponent"
+ },
{
"$ref": "#/definitions/QuestionGeneratorComponent"
},
@@ -2618,6 +2621,75 @@
],
"additionalProperties": false
},
+ "PseudoLabelGeneratorComponent": {
+ "type": "object",
+ "properties": {
+ "name": {
+ "title": "Name",
+ "description": "Custom name for the component. Helpful for visualization and debugging.",
+ "type": "string"
+ },
+ "type": {
+ "title": "Type",
+ "description": "Haystack Class name for the component.",
+ "type": "string",
+ "const": "PseudoLabelGenerator"
+ },
+ "params": {
+ "title": "Parameters",
+ "type": "object",
+ "properties": {
+ "question_producer": {
+ "title": "Question Producer",
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ }
+ }
+ }
+ ]
+ },
+ "retriever": {
+ "title": "Retriever",
+ "type": "string"
+ },
+ "cross_encoder_model_name_or_path": {
+ "title": "Cross Encoder Model Name Or Path",
+ "default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
+ "type": "string"
+ },
+ "total_number_of_questions": {
+ "title": "Total Number Of Questions",
+ "default": 9223372036854775807,
+ "type": "integer"
+ },
+ "top_k": {
+ "title": "Top K",
+ "default": 10,
+ "type": "integer"
+ }
+ },
+ "required": [
+ "question_producer",
+ "retriever"
+ ],
+ "additionalProperties": false,
+ "description": "Each parameter can reference other components defined in the same YAML file."
+ }
+ },
+ "required": [
+ "type",
+ "name"
+ ],
+ "additionalProperties": false
+ },
"QuestionGeneratorComponent": {
"type": "object",
"properties": {
diff --git a/haystack/json-schemas/haystack-pipeline-master.schema.json b/haystack/json-schemas/haystack-pipeline-master.schema.json
index fa117e876c..8ea55af33e 100644
--- a/haystack/json-schemas/haystack-pipeline-master.schema.json
+++ b/haystack/json-schemas/haystack-pipeline-master.schema.json
@@ -127,6 +127,9 @@
{
"$ref": "#/definitions/PreProcessorComponent"
},
+ {
+ "$ref": "#/definitions/PseudoLabelGeneratorComponent"
+ },
{
"$ref": "#/definitions/QuestionGeneratorComponent"
},
@@ -3193,6 +3196,85 @@
],
"additionalProperties": false
},
+ "PseudoLabelGeneratorComponent": {
+ "type": "object",
+ "properties": {
+ "name": {
+ "title": "Name",
+ "description": "Custom name for the component. Helpful for visualization and debugging.",
+ "type": "string"
+ },
+ "type": {
+ "title": "Type",
+ "description": "Haystack Class name for the component.",
+ "type": "string",
+ "const": "PseudoLabelGenerator"
+ },
+ "params": {
+ "title": "Parameters",
+ "type": "object",
+ "properties": {
+ "question_producer": {
+ "title": "Question Producer",
+ "anyOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "additionalProperties": {
+ "type": "string"
+ }
+ }
+ }
+ ]
+ },
+ "retriever": {
+ "title": "Retriever",
+ "type": "string"
+ },
+ "cross_encoder_model_name_or_path": {
+ "title": "Cross Encoder Model Name Or Path",
+ "default": "cross-encoder/ms-marco-MiniLM-L-6-v2",
+ "type": "string"
+ },
+ "max_questions_per_document": {
+ "title": "Max Questions Per Document",
+ "default": 3,
+ "type": "integer"
+ },
+ "top_k": {
+ "title": "Top K",
+ "default": 50,
+ "type": "integer"
+ },
+ "batch_size": {
+ "title": "Batch Size",
+ "default": 4,
+ "type": "integer"
+ },
+ "progress_bar": {
+ "title": "Progress Bar",
+ "default": true,
+ "type": "boolean"
+ }
+ },
+ "required": [
+ "question_producer",
+ "retriever"
+ ],
+ "additionalProperties": false,
+ "description": "Each parameter can reference other components defined in the same YAML file."
+ }
+ },
+ "required": [
+ "type",
+ "name"
+ ],
+ "additionalProperties": false
+ },
"QuestionGeneratorComponent": {
"type": "object",
"properties": {
@@ -3254,6 +3336,10 @@
"title": "Prompt",
"default": "generate questions:"
},
+ "num_queries_per_doc": {
+ "title": "Num Queries Per Doc",
+ "default": 1
+ },
"batch_size": {
"title": "Batch Size",
"type": "integer"
diff --git a/haystack/nodes/__init__.py b/haystack/nodes/__init__.py
index 3642707e22..ce59531acc 100644
--- a/haystack/nodes/__init__.py
+++ b/haystack/nodes/__init__.py
@@ -20,6 +20,7 @@
AzureConverter,
ParsrConverter,
)
+from haystack.nodes.label_generator import PseudoLabelGenerator
from haystack.nodes.other import Docs2Answers, JoinDocuments, RouteDocuments, JoinAnswers
from haystack.nodes.preprocessor import BasePreProcessor, PreProcessor
from haystack.nodes.query_classifier import SklearnQueryClassifier, TransformersQueryClassifier
diff --git a/haystack/nodes/label_generator/__init__.py b/haystack/nodes/label_generator/__init__.py
new file mode 100644
index 0000000000..53ec567c4c
--- /dev/null
+++ b/haystack/nodes/label_generator/__init__.py
@@ -0,0 +1 @@
+from haystack.nodes.label_generator.pseudo_label_generator import PseudoLabelGenerator
diff --git a/haystack/nodes/label_generator/pseudo_label_generator.py b/haystack/nodes/label_generator/pseudo_label_generator.py
new file mode 100644
index 0000000000..c4370ecd8d
--- /dev/null
+++ b/haystack/nodes/label_generator/pseudo_label_generator.py
@@ -0,0 +1,233 @@
+import random
+from typing import Dict, Iterable, List, Optional, Tuple, Union
+
+from sentence_transformers import CrossEncoder
+from tqdm.auto import tqdm
+from haystack.nodes.base import BaseComponent
+from haystack.nodes.question_generator import QuestionGenerator
+from haystack.nodes.retriever.base import BaseRetriever
+from haystack.schema import Document
+
+
+class PseudoLabelGenerator(BaseComponent):
+ """
+    The PseudoLabelGenerator is a component that creates Generative Pseudo Labeling (GPL) training data for
+    dense retrievers.
+
+    GPL is an unsupervised domain adaptation method for training dense retrievers. It is based on question
+    generation and pseudo labeling with powerful cross-encoders. To train a domain-adapted model, it needs access
+    to an unlabeled target corpus, usually through a DocumentStore, and a retriever to mine for negatives.
+
+ For more details see [https://github.com/UKPLab/gpl](https://github.com/UKPLab/gpl)
+
+ For example:
+
+ ```python
+ | document_store = DocumentStore(...)
+ | retriever = Retriever(...)
+ | qg = QuestionGenerator(model_name_or_path="doc2query/msmarco-t5-base-v1")
+ | plg = PseudoLabelGenerator(qg, retriever)
+    | output, output_id = plg.run(documents=document_store.get_all_documents())
+ |
+ ```
+ """
+
+ def __init__(
+ self,
+ question_producer: Union[QuestionGenerator, List[Dict[str, str]]],
+ retriever: BaseRetriever,
+ cross_encoder_model_name_or_path: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
+ max_questions_per_document: int = 3,
+ top_k: int = 50,
+ batch_size: int = 4,
+ progress_bar: bool = True,
+ ):
+ """
+        Loads the cross-encoder model and prepares the PseudoLabelGenerator.
+
+        :param question_producer: The question producer used to generate questions, or a list of already produced
+        question/document pairs in Dict format {"question": "question text ...", "document": "document text ..."}.
+        :type question_producer: Union[QuestionGenerator, List[Dict[str, str]]]
+        :param retriever: The retriever used to query document stores
+        :type retriever: BaseRetriever
+        :param cross_encoder_model_name_or_path: The path to the cross-encoder model, defaults to
+        cross-encoder/ms-marco-MiniLM-L-6-v2
+        :type cross_encoder_model_name_or_path: str (optional)
+        :param max_questions_per_document: The maximum number of questions generated per document, defaults to 3
+        :type max_questions_per_document: int
+        :param top_k: The number of documents retrieved for each question, defaults to 50
+        :type top_k: int (optional)
+        :param batch_size: The number of documents to process at a time, defaults to 4
+        :type batch_size: int (optional)
+        :param progress_bar: Whether to show a progress bar, defaults to True
+        :type progress_bar: bool (optional)
+        """
+
+ super().__init__()
+ self.question_document_pairs = None
+ self.question_generator = None # type: ignore
+ if isinstance(question_producer, QuestionGenerator):
+ self.question_generator = question_producer
+ elif isinstance(question_producer, list) and len(question_producer) > 0:
+ example = question_producer[0]
+ if isinstance(example, dict) and "question" in example and "document" in example:
+ self.question_document_pairs = question_producer
+ else:
+ raise ValueError("question_producer list must contain dicts with keys 'question' and 'document'")
+ else:
+ raise ValueError("Provide either a QuestionGenerator or nonempty list of questions/document pairs")
+
+ self.retriever = retriever
+ self.cross_encoder = CrossEncoder(cross_encoder_model_name_or_path)
+ self.max_questions_per_document = max_questions_per_document
+ self.top_k = top_k
+ self.batch_size = batch_size
+ self.progress_bar = progress_bar
+
+ def generate_questions(self, documents: List[Document], batch_size: Optional[int] = None) -> List[Dict[str, str]]:
+ """
+        Takes a list of documents and generates a list of question-document pairs.
+
+ :param documents: A list of documents to generate questions from
+ :type documents: List[Document]
+ :param batch_size: Number of documents to process at a time.
+ :type batch_size: Optional[int]
+ :return: A list of question-document pairs.
+ """
+ question_doc_pairs: List[Dict[str, str]] = []
+ if self.question_document_pairs:
+ question_doc_pairs = self.question_document_pairs
+ else:
+ batch_size = batch_size if batch_size else self.batch_size
+ questions: List[List[str]] = self.question_generator.generate_batch( # type: ignore
+ [d.content for d in documents], batch_size=batch_size
+ )
+ for idx, question_list_per_doc in enumerate(questions):
+ for q in question_list_per_doc[: self.max_questions_per_document]: # type: ignore
+ question_doc_pairs.append({"question": q.strip(), "document": documents[idx].content})
+ return question_doc_pairs
+
+ def mine_negatives(
+ self, question_doc_pairs: List[Dict[str, str]], batch_size: Optional[int] = None
+ ) -> List[Dict[str, str]]:
+ """
+ Given a list of question and pos_doc pairs, this function returns a list of question/pos_doc/neg_doc
+ dictionaries.
+
+ :param question_doc_pairs: A list of question/pos_doc pairs
+ :type question_doc_pairs: List[Dict[str, str]]
+ :param batch_size: The number of queries to run in a batch
+ :type batch_size: int (optional)
+ :return: A list of dictionaries, where each dictionary contains the question, positive document,
+ and negative document.
+ """
+ question_pos_doc_neg_doc: List[Dict[str, str]] = []
+ batch_size = batch_size if batch_size else self.batch_size
+
+ for i in tqdm(
+ range(0, len(question_doc_pairs), batch_size), disable=not self.progress_bar, desc="Mine negatives"
+ ):
+            # query in batches to minimize network latency
+ i_end = min(i + batch_size, len(question_doc_pairs))
+ queries: List[str] = [e["question"] for e in question_doc_pairs[i:i_end]]
+ pos_docs: List[str] = [e["document"] for e in question_doc_pairs[i:i_end]]
+
+ docs: List[List[Document]] = self.retriever.retrieve_batch(
+ queries=queries, top_k=self.top_k, batch_size=batch_size
+ )
+
+ # iterate through queries and find negatives
+ for question, pos_doc, top_docs in zip(queries, pos_docs, docs):
+ random.shuffle(top_docs)
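+                # after shuffling, the first retrieved document that differs from the positive
+                # serves as a randomly sampled negative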
+ for doc_item in top_docs:
+ neg_doc = doc_item.content
+ if neg_doc != pos_doc:
+ question_pos_doc_neg_doc.append({"question": question, "pos_doc": pos_doc, "neg_doc": neg_doc})
+ break
+ return question_pos_doc_neg_doc
+
+ def generate_margin_scores(
+ self, mined_negatives: List[Dict[str, str]], batch_size: Optional[int] = None
+ ) -> List[Dict]:
+ """
+ Given a list of mined negatives, predict the score margin between the positive and negative document using
+        the cross-encoder.
+
+ The function returns a list of examples, where each example is a dictionary with the following keys:
+
+ * question: the question string
+ * pos_doc: the positive document string
+ * neg_doc: the negative document string
+ * score: the score margin
+
+ :param mined_negatives: List of mined negatives
+ :type mined_negatives: List[Dict[str, str]]
+ :param batch_size: The number of mined negative lists to run in a batch
+ :type batch_size: int (optional)
+ :return: A list of dictionaries, each of which has the following keys:
+ - question: The question string
+ - pos_doc: The positive document string
+ - neg_doc: The negative document string
+ - score: The score margin
+ """
+ examples: List[Dict] = []
+ batch_size = batch_size if batch_size else self.batch_size
+ for i in tqdm(range(0, len(mined_negatives), batch_size), disable=not self.progress_bar, desc="Score margin"):
+ negatives_batch = mined_negatives[i : i + batch_size]
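+            # interleave (question, pos_doc) and (question, neg_doc) pairs so that the scores
+            # for example idx land at positions 2 * idx and 2 * idx + 1 after prediction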
+ pb = []
+ for item in negatives_batch:
+ pb.append([item["question"], item["pos_doc"]])
+ pb.append([item["question"], item["neg_doc"]])
+ scores = self.cross_encoder.predict(pb)
+ for idx, item in enumerate(negatives_batch):
+ scores_idx = idx * 2
+ score_margin = scores[scores_idx] - scores[scores_idx + 1]
+ examples.append(
+ {
+ "question": item["question"],
+ "pos_doc": item["pos_doc"],
+ "neg_doc": item["neg_doc"],
+ "score": score_margin,
+ }
+ )
+ return examples
+
+ def generate_pseudo_labels(self, documents: List[Document], batch_size: Optional[int] = None) -> Tuple[dict, str]:
+ """
+        Given a list of documents, generates question-document pairs, mines for negatives, and scores the
+        positive/negative margin with the cross-encoder. The output is training data for adapting dense
+        retriever models.
+
+        :param documents: A list of documents to generate training data from
+        :type documents: List[Document]
+ :param batch_size: The number of documents to process in a batch
+ :type batch_size: Optional[int]
+ :return: A dictionary with a single key 'gpl_labels' representing a list of dictionaries, where each
+ dictionary contains the following keys:
+ - question: the question
+ - pos_doc: the positive document for the given question
+ - neg_doc: the negative document for the given question
+ - score: the margin score (a float)
+ """
+        # see https://github.com/UKPLab/gpl for more information about the GPL algorithm
+ batch_size = batch_size if batch_size else self.batch_size
+
+ # step 1: generate questions
+ question_doc_pairs = self.generate_questions(documents=documents, batch_size=batch_size)
+
+ # step 2: negative mining
+ mined_negatives = self.mine_negatives(question_doc_pairs=question_doc_pairs, batch_size=batch_size)
+
+ # step 3: pseudo labeling (scoring) with cross-encoder
+        pseudo_labels: List[Dict] = self.generate_margin_scores(mined_negatives, batch_size=batch_size)
+ return {"gpl_labels": pseudo_labels}, "output_1"
+
+ def run(self, documents: List[Document]) -> Tuple[dict, str]: # type: ignore
+ return self.generate_pseudo_labels(documents=documents)
+
+ def run_batch(self, documents: Union[List[Document], List[List[Document]]]) -> Tuple[dict, str]: # type: ignore
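+        # documents may arrive as a flat list or as a list of lists; flatten before generating labels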
+ flat_list_of_documents = []
+ for sub_list_documents in documents:
+ if isinstance(sub_list_documents, Iterable):
+ flat_list_of_documents += sub_list_documents
+ else:
+ flat_list_of_documents.append(sub_list_documents)
+ return self.generate_pseudo_labels(documents=flat_list_of_documents)
diff --git a/haystack/nodes/question_generator/question_generator.py b/haystack/nodes/question_generator/question_generator.py
index d2ab7081a9..7ca6a47735 100644
--- a/haystack/nodes/question_generator/question_generator.py
+++ b/haystack/nodes/question_generator/question_generator.py
@@ -37,6 +37,7 @@ def __init__(
split_overlap=10,
use_gpu=True,
prompt="generate questions:",
+ num_queries_per_doc=1,
batch_size: Optional[int] = None,
):
"""
@@ -65,6 +66,7 @@ def __init__(
self.split_overlap = split_overlap
self.preprocessor = PreProcessor()
self.prompt = prompt
+ self.num_queries_per_doc = num_queries_per_doc
self.batch_size = batch_size
def run(self, documents: List[Document]): # type: ignore
@@ -122,6 +124,7 @@ def generate(self, text: str) -> List[str]:
no_repeat_ngram_size=self.no_repeat_ngram_size,
length_penalty=self.length_penalty,
early_stopping=self.early_stopping,
+ num_return_sequences=self.num_queries_per_doc,
)
string_output = self.tokenizer.batch_decode(tokens_output)
@@ -190,6 +193,7 @@ def generate_batch(
no_repeat_ngram_size=self.no_repeat_ngram_size,
length_penalty=self.length_penalty,
early_stopping=self.early_stopping,
+ num_return_sequences=self.num_queries_per_doc,
)
string_output = self.tokenizer.batch_decode(tokens_output)
diff --git a/haystack/nodes/retriever/_embedding_encoder.py b/haystack/nodes/retriever/_embedding_encoder.py
index 918d3277d3..c434959bcd 100644
--- a/haystack/nodes/retriever/_embedding_encoder.py
+++ b/haystack/nodes/retriever/_embedding_encoder.py
@@ -1,17 +1,20 @@
-from typing import TYPE_CHECKING, Callable, List, Union, Dict
-
import logging
from abc import abstractmethod
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+
import numpy as np
-from tqdm.auto import tqdm
import torch
+from sentence_transformers import InputExample, losses
+from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler
-from transformers import AutoTokenizer, AutoModel
+from tqdm.auto import tqdm
+from transformers import AutoModel, AutoTokenizer
-from haystack.schema import Document
+from haystack.modeling.data_handler.dataloader import NamedDataLoader
from haystack.modeling.data_handler.dataset import convert_features_to_dataset, flatten_rename
from haystack.modeling.infer import Inferencer
-from haystack.modeling.data_handler.dataloader import NamedDataLoader
+from haystack.schema import Document
if TYPE_CHECKING:
from haystack.nodes.retriever import EmbeddingRetriever
@@ -41,6 +44,47 @@ def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
"""
pass
+ def train(
+ self,
+ training_data: List[Dict[str, Any]],
+ learning_rate: float = 2e-5,
+ n_epochs: int = 1,
+        num_warmup_steps: Optional[int] = None,
+ batch_size: int = 16,
+ ):
+ """
+ Trains/adapts the underlying embedding model.
+
+ Each training data example is a dictionary with the following keys:
+
+ * question: the question string
+ * pos_doc: the positive document string
+ * neg_doc: the negative document string
+ * score: the score margin
+
+
+ :param training_data: The training data
+ :type training_data: List[Dict[str, Any]]
+ :param learning_rate: The learning rate
+ :type learning_rate: float
+ :param n_epochs: The number of training epochs
+ :type n_epochs: int
+ :param num_warmup_steps: The number of warmup steps
+        :type num_warmup_steps: Optional[int]
+ :param batch_size: The batch size to use for the training, defaults to 16
+ :type batch_size: int (optional)
+ """
+ pass
+
+ def save(self, save_dir: Union[Path, str]):
+ """
+        Save the model to the given directory.
+
+ :param save_dir: The directory where the model will be saved
+ :type save_dir: Union[Path, str]
+ """
+ pass
+
class _DefaultEmbeddingEncoder(_BaseEmbeddingEncoder):
def __init__(self, retriever: "EmbeddingRetriever"):
@@ -87,6 +131,19 @@ def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
passages = [d.content for d in docs] # type: ignore
return self.embed(passages)
+ def train(
+ self,
+ training_data: List[Dict[str, Any]],
+ learning_rate: float = 2e-5,
+ n_epochs: int = 1,
+        num_warmup_steps: Optional[int] = None,
+ batch_size: int = 16,
+ ):
+ raise NotImplementedError("train method can only be used with sentence-transformers EmbeddingRetriever(s)")
+
+ def save(self, save_dir: Union[Path, str]):
+ raise NotImplementedError("save method can only be used with sentence-transformers EmbeddingRetriever(s)")
+
class _SentenceTransformersEmbeddingEncoder(_BaseEmbeddingEncoder):
def __init__(self, retriever: "EmbeddingRetriever"):
@@ -127,6 +184,33 @@ def embed_documents(self, docs: List[Document]) -> List[np.ndarray]:
passages = [[d.meta["name"] if d.meta and "name" in d.meta else "", d.content] for d in docs] # type: ignore
return self.embed(passages)
+ def train(
+ self,
+ training_data: List[Dict[str, Any]],
+ learning_rate: float = 2e-5,
+ n_epochs: int = 1,
+        num_warmup_steps: Optional[int] = None,
+ batch_size: int = 16,
+ ):
+
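+        # MarginMSELoss trains on (query, positive, negative) triplets whose label is the
+        # positive/negative score margin, as produced by PseudoLabelGenerator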
+ train_examples = [
+ InputExample(texts=[i["question"], i["pos_doc"], i["neg_doc"]], label=i["score"]) for i in training_data
+ ]
+ logger.info(f"GPL training/adapting {self.embedding_model} with {len(train_examples)} examples")
+ train_dataloader = DataLoader(train_examples, batch_size=batch_size, drop_last=True, shuffle=True)
+ train_loss = losses.MarginMSELoss(self.embedding_model)
+
+ # Tune the model
+ self.embedding_model.fit(
+ train_objectives=[(train_dataloader, train_loss)],
+ epochs=n_epochs,
+ optimizer_params={"lr": learning_rate},
+ warmup_steps=int(len(train_dataloader) * 0.1) if num_warmup_steps is None else num_warmup_steps,
+ )
+
+ def save(self, save_dir: Union[Path, str]):
+ self.embedding_model.save(path=str(save_dir))
+
class _RetribertEmbeddingEncoder(_BaseEmbeddingEncoder):
def __init__(self, retriever: "EmbeddingRetriever"):
@@ -208,6 +292,19 @@ def dataset_from_dicts(self, dicts: List[dict]):
dataset, tensornames = convert_features_to_dataset(features=features_flat)
return dataset, tensornames
+ def train(
+ self,
+ training_data: List[Dict[str, Any]],
+ learning_rate: float = 2e-5,
+ n_epochs: int = 1,
+        num_warmup_steps: Optional[int] = None,
+ batch_size: int = 16,
+ ):
+ raise NotImplementedError("train method can only be used with sentence-transformers EmbeddingRetriever(s)")
+
+ def save(self, save_dir: Union[Path, str]):
+ raise NotImplementedError("save method can only be used with sentence-transformers EmbeddingRetriever(s)")
+
_EMBEDDING_ENCODERS: Dict[str, Callable] = {
"farm": _DefaultEmbeddingEncoder,
diff --git a/haystack/nodes/retriever/dense.py b/haystack/nodes/retriever/dense.py
index ed568da4cf..c721cbb0b3 100644
--- a/haystack/nodes/retriever/dense.py
+++ b/haystack/nodes/retriever/dense.py
@@ -1,4 +1,4 @@
-from typing import List, Dict, Union, Optional
+from typing import List, Dict, Union, Optional, Any
import logging
from pathlib import Path
@@ -1851,3 +1851,50 @@ def _infer_model_format(model_name_or_path: str) -> str:
# Model is neither sentence-transformers nor retribert model -> use _DefaultEmbeddingEncoder
return "farm"
+
+ def train(
+ self,
+ training_data: List[Dict[str, Any]],
+ learning_rate: float = 2e-5,
+ n_epochs: int = 1,
+        num_warmup_steps: Optional[int] = None,
+ batch_size: int = 16,
+ ) -> None:
+ """
+ Trains/adapts the underlying embedding model.
+
+ Each training data example is a dictionary with the following keys:
+
+ * question: the question string
+ * pos_doc: the positive document string
+ * neg_doc: the negative document string
+ * score: the score margin
+
+
+ :param training_data: The training data
+ :type training_data: List[Dict[str, Any]]
+ :param learning_rate: The learning rate
+ :type learning_rate: float
+ :param n_epochs: The number of epochs
+ :type n_epochs: int
+ :param num_warmup_steps: The number of warmup steps
+        :type num_warmup_steps: Optional[int]
+ :param batch_size: The batch size to use for the training, defaults to 16
+ :type batch_size: int (optional)
+ """
+ self.embedding_encoder.train(
+ training_data,
+ learning_rate=learning_rate,
+ n_epochs=n_epochs,
+ num_warmup_steps=num_warmup_steps,
+ batch_size=batch_size,
+ )
+
+ def save(self, save_dir: Union[Path, str]) -> None:
+ """
+        Save the model to the given directory.
+
+ :param save_dir: The directory where the model will be saved
+ :type save_dir: Union[Path, str]
+ """
+ self.embedding_encoder.save(save_dir=save_dir)
diff --git a/test/conftest.py b/test/conftest.py
index ded06bb3e1..c4eb70f303 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -672,6 +672,13 @@ def get_retriever(retriever_type, document_store):
retriever = EmbeddingRetriever(
document_store=document_store, embedding_model="deepset/sentence_bert", use_gpu=False
)
+ elif retriever_type == "embedding_sbert":
+ retriever = EmbeddingRetriever(
+ document_store=document_store,
+ embedding_model="sentence-transformers/msmarco-distilbert-base-tas-b",
+ model_format="sentence_transformers",
+ use_gpu=False,
+ )
elif retriever_type == "retribert":
retriever = EmbeddingRetriever(
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False
diff --git a/test/nodes/test_label_generator.py b/test/nodes/test_label_generator.py
new file mode 100644
index 0000000000..070fd94d6c
--- /dev/null
+++ b/test/nodes/test_label_generator.py
@@ -0,0 +1,62 @@
+from pathlib import Path
+
+import pytest
+
+from haystack.nodes import QuestionGenerator, EmbeddingRetriever, PseudoLabelGenerator
+from test.conftest import DOCS_WITH_EMBEDDINGS
+
+
+@pytest.mark.slow
+@pytest.mark.generator
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+@pytest.mark.parametrize("retriever", ["embedding_sbert"], indirect=True)
+def test_pseudo_label_generator(
+ document_store, retriever: EmbeddingRetriever, question_generator: QuestionGenerator, tmp_path: Path
+):
+ document_store.write_documents(DOCS_WITH_EMBEDDINGS)
+ psg = PseudoLabelGenerator(question_generator, retriever)
+ train_examples = []
+    for doc in document_store.get_all_documents():
+ output, stream = psg.run(documents=[doc])
+ assert "gpl_labels" in output
+ for item in output["gpl_labels"]:
+ assert "question" in item and "pos_doc" in item and "neg_doc" in item and "score" in item
+ train_examples.append(item)
+
+ assert len(train_examples) > 0
+ retriever.train(train_examples)
+ retriever.save(tmp_path)
+
+
+@pytest.mark.slow
+@pytest.mark.generator
+@pytest.mark.parametrize("document_store", ["memory"], indirect=True)
+@pytest.mark.parametrize("retriever", ["embedding_sbert"], indirect=True)
+def test_pseudo_label_generator_using_question_document_pairs(
+ document_store, retriever: EmbeddingRetriever, tmp_path: Path
+):
+ document_store.write_documents(DOCS_WITH_EMBEDDINGS)
+ docs = [
+ {
+ "question": "What is the capital of Germany?",
+ "document": "Berlin is the capital and largest city of Germany by both area and population.",
+ },
+ {
+ "question": "What is the largest city in Germany by population and area?",
+ "document": "Berlin is the capital and largest city of Germany by both area and population.",
+ },
+ ]
+ psg = PseudoLabelGenerator(docs, retriever)
+ train_examples = []
+    for doc in document_store.get_all_documents():
+        # the documents passed here are ignored because question/document pairs were provided in the constructor
+ output, stream = psg.run(documents=[doc])
+ assert "gpl_labels" in output
+ for item in output["gpl_labels"]:
+ assert "question" in item and "pos_doc" in item and "neg_doc" in item and "score" in item
+ train_examples.append(item)
+
+ assert len(train_examples) > 0
+
+ retriever.train(train_examples)
+ retriever.save(tmp_path)