Simplify loading of EmbeddingRetriever (#2619)
* Infer model format for EmbeddingRetriever automatically

* Update Documentation & Code Style

* Adapt conftest to automatic inference of model_format

* Update Documentation & Code Style

* Fix tests

* Update Documentation & Code Style

* Fix tests

* Adapt tutorials

* Update Documentation & Code Style

* Add test for similarity scores with sentence transformers

* Adapt doc string and warning message

* Update Documentation & Code Style

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
bogdankostic and github-actions[bot] authored Jun 2, 2022
1 parent ca19521 commit 61d9429
Showing 14 changed files with 115 additions and 73 deletions.
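In practice, the change collapses retriever setup to the model name alone. A minimal before/after sketch distilled from the tutorial diffs in this commit (the document store and its embedding dimension are assumptions for illustration, not part of the diff):

```python
from haystack.document_stores import InMemoryDocumentStore
from haystack.nodes import EmbeddingRetriever

document_store = InMemoryDocumentStore(embedding_dim=768)  # dim assumed for this model

# Before this commit, the framework had to be named explicitly:
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
    model_format="sentence_transformers",
)

# After this commit, model_format defaults to None and is inferred from the
# model's configuration files; the explicit argument remains as an override.
retriever = EmbeddingRetriever(
    document_store=document_store,
    embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
)
```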
14 changes: 9 additions & 5 deletions docs/_src/api/api/retriever.md
@@ -1171,7 +1171,7 @@ class EmbeddingRetriever(BaseRetriever)
#### EmbeddingRetriever.\_\_init\_\_

```python
def __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, batch_size: int = 32, max_seq_len: int = 512, model_format: str = "farm", pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True, embed_meta_fields: List[str] = [])
def __init__(document_store: BaseDocumentStore, embedding_model: str, model_version: Optional[str] = None, use_gpu: bool = True, batch_size: int = 32, max_seq_len: int = 512, model_format: Optional[str] = None, pooling_strategy: str = "reduce_mean", emb_extraction_layer: int = -1, top_k: int = 10, progress_bar: bool = True, devices: Optional[List[Union[str, torch.device]]] = None, use_auth_token: Optional[Union[str, bool]] = None, scale_score: bool = True, embed_meta_fields: List[str] = [])
```

**Arguments**:
@@ -1182,10 +1182,14 @@ def __init__(document_store: BaseDocumentStore, embedding_model: str, model_vers
- `use_gpu`: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
- `batch_size`: Number of documents to encode at once.
- `max_seq_len`: Longest length of each document sequence. Maximum number of tokens for the document text. Longer ones will be cut down.
- `model_format`: Name of framework that was used for saving the model. Options:
- ``'farm'``
- ``'transformers'``
- ``'sentence_transformers'``
- `model_format`: Name of framework that was used for saving the model or model type. If no model_format is
provided, it will be inferred automatically from the model configuration files.
Options:

- ``'farm'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder)
- ``'transformers'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder)
- ``'sentence_transformers'`` (will use `_SentenceTransformersEmbeddingEncoder` as embedding encoder)
- ``'retribert'`` (will use `_RetribertEmbeddingEncoder` as embedding encoder)
- `pooling_strategy`: Strategy for combining the embeddings from the model (for farm / transformers models only).
Options:

1 change: 0 additions & 1 deletion haystack/json-schemas/haystack-pipeline-master.schema.json
@@ -2265,7 +2265,6 @@
},
"model_format": {
"title": "Model Format",
"default": "farm",
"type": "string"
},
"pooling_strategy": {
56 changes: 45 additions & 11 deletions haystack/nodes/retriever/dense.py
@@ -3,6 +3,7 @@
import logging
from pathlib import Path
from copy import deepcopy
from requests.exceptions import HTTPError

import numpy as np
from tqdm.auto import tqdm
@@ -11,6 +12,8 @@
from torch.nn import DataParallel
from torch.utils.data.sampler import SequentialSampler
import pandas as pd
from huggingface_hub import hf_hub_download
from transformers import AutoConfig

from haystack.errors import HaystackError
from haystack.schema import Document
@@ -1452,7 +1455,7 @@ def __init__(
use_gpu: bool = True,
batch_size: int = 32,
max_seq_len: int = 512,
model_format: str = "farm",
model_format: Optional[str] = None,
pooling_strategy: str = "reduce_mean",
emb_extraction_layer: int = -1,
top_k: int = 10,
@@ -1469,11 +1472,14 @@
:param use_gpu: Whether to use all available GPUs or the CPU. Falls back on CPU if no GPU is available.
:param batch_size: Number of documents to encode at once.
:param max_seq_len: Longest length of each document sequence. Maximum number of tokens for the document text. Longer ones will be cut down.
:param model_format: Name of framework that was used for saving the model. Options:
- ``'farm'``
- ``'transformers'``
- ``'sentence_transformers'``
:param model_format: Name of framework that was used for saving the model or model type. If no model_format is
provided, it will be inferred automatically from the model configuration files.
Options:
- ``'farm'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder)
- ``'transformers'`` (will use `_DefaultEmbeddingEncoder` as embedding encoder)
- ``'sentence_transformers'`` (will use `_SentenceTransformersEmbeddingEncoder` as embedding encoder)
- ``'retribert'`` (will use `_RetribertEmbeddingEncoder` as embedding encoder)
:param pooling_strategy: Strategy for combining the embeddings from the model (for farm / transformers models only).
Options:
@@ -1514,7 +1520,6 @@ def __init__(

self.document_store = document_store
self.embedding_model = embedding_model
self.model_format = model_format
self.model_version = model_version
self.use_gpu = use_gpu
self.batch_size = batch_size
@@ -1525,19 +1530,26 @@
self.progress_bar = progress_bar
self.use_auth_token = use_auth_token
self.scale_score = scale_score
self.model_format = self._infer_model_format(embedding_model) if model_format is None else model_format

logger.info(f"Init retriever using embeddings of model {embedding_model}")

if model_format not in _EMBEDDING_ENCODERS.keys():
if self.model_format not in _EMBEDDING_ENCODERS.keys():
raise ValueError(f"Unknown retriever embedding model format {model_format}")

if self.embedding_model.startswith("sentence-transformers") and self.model_format != "sentence_transformers":
if (
self.embedding_model.startswith("sentence-transformers")
and model_format
and model_format != "sentence_transformers"
):
logger.warning(
f"You seem to be using a Sentence Transformer embedding model but 'model_format' is set to '{self.model_format}'."
f" You may need to set 'model_format='sentence_transformers' to ensure correct loading of model."
f" You may need to set model_format='sentence_transformers' to ensure correct loading of model."
f"As an alternative, you can let Haystack derive the format automatically by not setting the "
f"'model_format' parameter at all."
)

self.embedding_encoder = _EMBEDDING_ENCODERS[model_format](self)
self.embedding_encoder = _EMBEDDING_ENCODERS[self.model_format](self)
self.embed_meta_fields = embed_meta_fields

def retrieve(
@@ -1817,3 +1829,25 @@ def _preprocess_documents(self, docs: List[Document]) -> List[Document]:
doc.content = "\n".join(meta_data_fields + [doc.content])
linearized_docs.append(doc)
return linearized_docs

@staticmethod
def _infer_model_format(model_name_or_path: str) -> str:
# Check if model name is a local directory with sentence transformers config file in it
if Path(model_name_or_path).exists():
if Path(f"{model_name_or_path}/config_sentence_transformers.json").exists():
return "sentence_transformers"
# Check if sentence transformers config file in model hub
else:
try:
hf_hub_download(repo_id=model_name_or_path, filename="config_sentence_transformers.json")
return "sentence_transformers"
except HTTPError:
pass

# Check if retribert model
config = AutoConfig.from_pretrained(model_name_or_path)
if config.model_type == "retribert":
return "retribert"

# Model is neither sentence-transformers nor retribert model -> use _DefaultEmbeddingEncoder
return "farm"
5 changes: 1 addition & 4 deletions test/conftest.py
@@ -674,10 +674,7 @@ def get_retriever(retriever_type, document_store)
)
elif retriever_type == "retribert":
retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model="yjernite/retribert-base-uncased",
model_format="retribert",
use_gpu=False,
document_store=document_store, embedding_model="yjernite/retribert-base-uncased", use_gpu=False
)
elif retriever_type == "dpr_lfqa":
retriever = DensePassageRetriever(
24 changes: 23 additions & 1 deletion test/document_stores/test_document_store.py
@@ -1387,14 +1387,33 @@ def test_elasticsearch_synonyms():
"document_store_with_docs", ["memory", "faiss", "milvus1", "weaviate", "elasticsearch"], indirect=True
)
@pytest.mark.embedding_dim(384)
def test_similarity_score(document_store_with_docs):
def test_similarity_score_sentence_transformers(document_store_with_docs):
retriever = EmbeddingRetriever(
document_store=document_store_with_docs, embedding_model="sentence-transformers/paraphrase-MiniLM-L3-v2"
)
document_store_with_docs.update_embeddings(retriever)
pipeline = DocumentSearchPipeline(retriever)
prediction = pipeline.run("Paul lives in New York")
scores = [document.score for document in prediction["documents"]]
assert scores == pytest.approx(
[0.8497486114501953, 0.6622999012470245, 0.6077829301357269, 0.5928314849734306, 0.5614184625446796], abs=1e-3
)


@pytest.mark.parametrize(
"document_store_with_docs", ["memory", "faiss", "milvus1", "weaviate", "elasticsearch"], indirect=True
)
@pytest.mark.embedding_dim(384)
def test_similarity_score(document_store_with_docs):
retriever = EmbeddingRetriever(
document_store=document_store_with_docs,
embedding_model="sentence-transformers/paraphrase-MiniLM-L3-v2",
model_format="farm",
)
document_store_with_docs.update_embeddings(retriever)
pipeline = DocumentSearchPipeline(retriever)
prediction = pipeline.run("Paul lives in New York")
scores = [document.score for document in prediction["documents"]]
assert scores == pytest.approx(
[0.9102507941407827, 0.6937791467877008, 0.6491682889305038, 0.6321622491318529, 0.5909129441370939], abs=1e-3
)
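Note (an inference from the diff, not something it states): the two tests above pin different expected scores for the same model because the choice of encoder changes pooling. The sentence-transformers encoder applies the model's own pooling head, while the farm path applies the retriever's pooling_strategy (reduce_mean by default) to raw token embeddings, so the same weights yield different document vectors.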
@@ -1409,6 +1428,7 @@ def test_similarity_score_without_scaling(document_store_with_docs):
document_store=document_store_with_docs,
embedding_model="sentence-transformers/paraphrase-MiniLM-L3-v2",
scale_score=False,
model_format="farm",
)
document_store_with_docs.update_embeddings(retriever)
pipeline = DocumentSearchPipeline(retriever)
@@ -1428,6 +1448,7 @@ def test_similarity_score_dot_product(document_store_dot_product_with_docs):
retriever = EmbeddingRetriever(
document_store=document_store_dot_product_with_docs,
embedding_model="sentence-transformers/paraphrase-MiniLM-L3-v2",
model_format="farm",
)
document_store_dot_product_with_docs.update_embeddings(retriever)
pipeline = DocumentSearchPipeline(retriever)
@@ -1447,6 +1468,7 @@ def test_similarity_score_dot_product_without_scaling(document_store_dot_product
document_store=document_store_dot_product_with_docs,
embedding_model="sentence-transformers/paraphrase-MiniLM-L3-v2",
scale_score=False,
model_format="farm",
)
document_store_dot_product_with_docs.update_embeddings(retriever)
pipeline = DocumentSearchPipeline(retriever)
6 changes: 4 additions & 2 deletions test/nodes/test_retriever.py
@@ -591,11 +591,13 @@ def test_embeddings_encoder_of_embedding_retriever_should_warn_about_model_forma

with caplog.at_level(logging.WARNING):
EmbeddingRetriever(
document_store=document_store, embedding_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
document_store=document_store,
embedding_model="sentence-transformers/paraphrase-multilingual-mpnet-base-v2",
model_format="farm",
)

assert (
"You may need to set 'model_format='sentence_transformers' to ensure correct loading of model."
"You may need to set model_format='sentence_transformers' to ensure correct loading of model."
in caplog.text
)

4 changes: 1 addition & 3 deletions tutorials/Tutorial11_Pipelines.ipynb
@@ -224,9 +224,7 @@
"\n",
"# Initialize dense retriever\n",
"embedding_retriever = EmbeddingRetriever(\n",
" document_store,\n",
" model_format=\"sentence_transformers\",\n",
" embedding_model=\"sentence-transformers/multi-qa-mpnet-base-dot-v1\",\n",
" document_store, embedding_model=\"sentence-transformers/multi-qa-mpnet-base-dot-v1\"\n",
")\n",
"document_store.update_embeddings(embedding_retriever, update_existing_embeddings=False)\n",
"\n",
4 changes: 1 addition & 3 deletions tutorials/Tutorial11_Pipelines.py
@@ -33,9 +33,7 @@ def tutorial11_pipelines():

# Initialize dense retriever
embedding_retriever = EmbeddingRetriever(
document_store,
model_format="sentence_transformers",
embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
document_store, embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1"
)
document_store.update_embeddings(embedding_retriever, update_existing_embeddings=False)

6 changes: 2 additions & 4 deletions tutorials/Tutorial14_Query_Classifier.ipynb
@@ -408,9 +408,7 @@
"\n",
"# Initialize dense retriever\n",
"embedding_retriever = EmbeddingRetriever(\n",
" document_store=document_store,\n",
" model_format=\"sentence_transformers\",\n",
" embedding_model=\"sentence-transformers/multi-qa-mpnet-base-dot-v1\",\n",
" document_store=document_store, embedding_model=\"sentence-transformers/multi-qa-mpnet-base-dot-v1\"\n",
")\n",
"document_store.update_embeddings(embedding_retriever, update_existing_embeddings=False)\n",
"\n",
@@ -6782,4 +6780,4 @@
},
"nbformat": 4,
"nbformat_minor": 1
}
}
4 changes: 1 addition & 3 deletions tutorials/Tutorial14_Query_Classifier.py
@@ -38,9 +38,7 @@ def tutorial14_query_classifier():

# Initialize dense retriever
embedding_retriever = EmbeddingRetriever(
document_store=document_store,
model_format="sentence_transformers",
embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1",
document_store=document_store, embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1"
)
document_store.update_embeddings(embedding_retriever, update_existing_embeddings=False)

52 changes: 24 additions & 28 deletions tutorials/Tutorial15_TableQA.ipynb

Large diffs are not rendered by default.

6 changes: 1 addition & 5 deletions tutorials/Tutorial15_TableQA.py
@@ -54,11 +54,7 @@ def read_tables(filename):
# **Here:** We use the EmbeddingRetriever capable of retrieving relevant content among a database
# of texts and tables using dense embeddings.

retriever = EmbeddingRetriever(
document_store=document_store,
embedding_model="deepset/all-mpnet-base-v2-table",
model_format="sentence_transformers",
)
retriever = EmbeddingRetriever(document_store=document_store, embedding_model="deepset/all-mpnet-base-v2-table")

# Add table embeddings to the tables in DocumentStore
document_store.update_embeddings(retriever=retriever)
4 changes: 2 additions & 2 deletions tutorials/Tutorial5_Evaluation.ipynb
@@ -263,7 +263,7 @@
"# For more information and suggestions on different models check out the documentation at: https://www.sbert.net/docs/pretrained_models.html\n",
"\n",
"# from haystack.retriever import EmbeddingRetriever, DensePassageRetriever\n",
"# retriever = EmbeddingRetriever(document_store=document_store, model_format=\"sentence_transformers\",\n",
"# retriever = EmbeddingRetriever(document_store=document_store,\n",
"# embedding_model=\"sentence-transformers/multi-qa-mpnet-base-dot-v1\")\n",
"# retriever = DensePassageRetriever(document_store=document_store,\n",
"# query_embedding_model=\"facebook/dpr-question_encoder-single-nq-base\",\n",
@@ -15737,4 +15737,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}
2 changes: 1 addition & 1 deletion tutorials/Tutorial5_Evaluation.py
@@ -77,7 +77,7 @@ def tutorial5_evaluation():
# For more information and suggestions on different models check out the documentation at: https://www.sbert.net/docs/pretrained_models.html

# from haystack.retriever import EmbeddingRetriever, DensePassageRetriever
# retriever = EmbeddingRetriever(document_store=document_store, model_format="sentence_transformers",
# retriever = EmbeddingRetriever(document_store=document_store,
# embedding_model="sentence-transformers/multi-qa-mpnet-base-dot-v1")
# retriever = DensePassageRetriever(document_store=document_store,
# query_embedding_model="facebook/dpr-question_encoder-single-nq-base",
