Skip to content

Commit

Permalink
Unify vector_dim and embedding_dim parameter in Document Store (#1922)
Browse files Browse the repository at this point in the history
* Refactored code to unify vector_dim and embedding_dim parameter in DocumentStores

* Unit test cases updated to use `embedding_dim` instead of `vector_dim`

* Unit test case update to use embedding_dim instead of vector_dim

* Add latest docstring and tutorial changes

* Put usage of `vector_dim` param in same if-block as corresponding warning

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: bogdankostic <[email protected]>
  • Loading branch information
3 people authored Jan 10, 2022
1 parent 00dc30a commit a44b6c1
Show file tree
Hide file tree
Showing 12 changed files with 86 additions and 55 deletions.
14 changes: 8 additions & 6 deletions docs/_src/api/api/document_store.md
Original file line number Diff line number Diff line change
Expand Up @@ -1202,14 +1202,15 @@ the vector embeddings are indexed in a FAISS Index.
#### \_\_init\_\_

```python
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs, ,)
| __init__(sql_url: str = "sqlite:///faiss_document_store.db", vector_dim: int = None, embedding_dim: int = 768, faiss_index_factory_str: str = "Flat", faiss_index: Optional["faiss.swigfaiss.Index"] = None, return_embedding: bool = False, index: str = "document", similarity: str = "dot_product", embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', faiss_index_path: Union[str, Path] = None, faiss_config_path: Union[str, Path] = None, **kwargs, ,)
```

**Arguments**:

- `sql_url`: SQL connection URL for database. It defaults to local file based SQLite DB. For large scale
deployment, Postgres is recommended.
- `vector_dim`: the embedding vector size.
- `vector_dim`: Deprecated. Use embedding_dim instead.
- `embedding_dim`: The embedding vector size. Default: 768.
- `faiss_index_factory_str`: Create a new FAISS index of the specified type.
The type is determined from the given string following the conventions
of the original FAISS index factory.
Expand All @@ -1231,7 +1232,7 @@ the vector embeddings are indexed in a FAISS Index.
- `index`: Name of index in document store to use.
- `similarity`: The similarity function used to compare document vectors. 'dot_product' is the default since it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
For `dot_product`: expit(np.asarray(raw_score / 100))
FOr `cosine`: (raw_score + 1) / 2
- `embedding_field`: Name of field containing an embedding vector.
Expand Down Expand Up @@ -1424,7 +1425,7 @@ Save FAISS Index to the specified file.
- `config_path`: Path to save the initial configuration parameters to.
Defaults to the same as the file path, save the extension (.json).
This file contains all the parameters passed to FAISSDocumentStore()
at creation time (for example the SQL path, vector_dim, etc), and will be
at creation time (for example the SQL path, embedding_dim, etc), and will be
used by the `load` method to restore the index with the appropriate configuration.

**Returns**:
Expand Down Expand Up @@ -1478,7 +1479,7 @@ Usage:
#### \_\_init\_\_

```python
| __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
| __init__(sql_url: str = "sqlite:///", milvus_url: str = "tcp://localhost:19530", connection_pool: str = "SingletonThread", index: str = "document", vector_dim: int = None, embedding_dim: int = 768, index_file_size: int = 1024, similarity: str = "dot_product", index_type: IndexType = IndexType.FLAT, index_param: Optional[Dict[str, Any]] = None, search_param: Optional[Dict[str, Any]] = None, return_embedding: bool = False, embedding_field: str = "embedding", progress_bar: bool = True, duplicate_documents: str = 'overwrite', **kwargs, ,)
```

**Arguments**:
Expand All @@ -1491,7 +1492,8 @@ Usage:
See https://milvus.io/docs/v1.0.0/install_milvus.md for instructions to start a Milvus instance.
- `connection_pool`: Connection pool type to connect with Milvus server. Default: "SingletonThread".
- `index`: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
- `vector_dim`: The embedding vector size. Default: 768.
- `vector_dim`: Deprecated. Use embedding_dim instead.
- `embedding_dim`: The embedding vector size. Default: 768.
- `index_file_size`: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
Expand Down
2 changes: 1 addition & 1 deletion docs/_src/tutorials/tutorials/12.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ For more info on which suits your use case: https://github.com/facebookresearch/
```python
from haystack.document_stores import FAISSDocumentStore

document_store = FAISSDocumentStore(vector_dim=128, faiss_index_factory_str="Flat")
document_store = FAISSDocumentStore(embedding_dim=128, faiss_index_factory_str="Flat")
```

### Cleaning & indexing documents
Expand Down
46 changes: 28 additions & 18 deletions haystack/document_stores/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pathlib import Path
from typing import Union, List, Optional, Dict, Generator
from tqdm.auto import tqdm
import warnings

try:
import faiss
Expand Down Expand Up @@ -37,7 +38,8 @@ class FAISSDocumentStore(SQLDocumentStore):
def __init__(
self,
sql_url: str = "sqlite:///faiss_document_store.db",
vector_dim: int = 768,
vector_dim: int = None,
embedding_dim: int = 768,
faiss_index_factory_str: str = "Flat",
faiss_index: Optional["faiss.swigfaiss.Index"] = None,
return_embedding: bool = False,
Expand All @@ -53,7 +55,8 @@ def __init__(
"""
:param sql_url: SQL connection URL for database. It defaults to local file based SQLite DB. For large scale
deployment, Postgres is recommended.
:param vector_dim: the embedding vector size.
:param vector_dim: Deprecated. Use embedding_dim instead.
:param embedding_dim: The embedding vector size. Default: 768.
:param faiss_index_factory_str: Create a new FAISS index of the specified type.
The type is determined from the given string following the conventions
of the original FAISS index factory.
Expand All @@ -75,7 +78,7 @@ def __init__(
:param index: Name of index in document store to use.
:param similarity: The similarity function used to compare document vectors. 'dot_product' is the default since it is
more performant with DPR embeddings. 'cosine' is recommended if you are using a Sentence-Transformer model.
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
In both cases, the returned values in Document.score are normalized to be in range [0,1]:
For `dot_product`: expit(np.asarray(raw_score / 100))
FOr `cosine`: (raw_score + 1) / 2
:param embedding_field: Name of field containing an embedding vector.
Expand All @@ -89,7 +92,7 @@ def __init__(
exists.
:param faiss_index_path: Stored FAISS index file. Can be created via calling `save()`.
If specified no other params besides faiss_config_path must be specified.
:param faiss_config_path: Stored FAISS initial configuration parameters.
:param faiss_config_path: Stored FAISS initial configuration parameters.
Can be created via calling `save()`
"""
# special case if we want to load an existing index from disk
Expand All @@ -103,14 +106,15 @@ def __init__(

# save init parameters to enable export of component config as YAML
self.set_config(
sql_url=sql_url,
vector_dim=vector_dim,
sql_url=sql_url,
vector_dim=vector_dim,
embedding_dim=embedding_dim,
faiss_index_factory_str=faiss_index_factory_str,
return_embedding=return_embedding,
duplicate_documents=duplicate_documents,
index=index,
duplicate_documents=duplicate_documents,
index=index,
similarity=similarity,
embedding_field=embedding_field,
embedding_field=embedding_field,
progress_bar=progress_bar
)

Expand All @@ -124,14 +128,20 @@ def __init__(
raise ValueError("The FAISS document store can currently only support dot_product, cosine and l2 similarity. "
"Please set similarity to one of the above.")

self.vector_dim = vector_dim
if vector_dim is not None:
warnings.warn("The 'vector_dim' parameter is deprecated, "
"use 'embedding_dim' instead.", DeprecationWarning, 2)
self.embedding_dim = vector_dim
else:
self.embedding_dim = embedding_dim

self.faiss_index_factory_str = faiss_index_factory_str
self.faiss_indexes: Dict[str, faiss.swigfaiss.Index] = {}
if faiss_index:
self.faiss_indexes[index] = faiss_index
else:
self.faiss_indexes[index] = self._create_new_index(
vector_dim=self.vector_dim,
embedding_dim=self.embedding_dim,
index_factory=faiss_index_factory_str,
metric_type=self.metric_type,
**kwargs
Expand All @@ -158,7 +168,7 @@ def _validate_params_load_from_disk(self, sig: Signature, locals: dict, kwargs:
if param.name not in allowed_params and param.default != locals[param.name]:
invalid_param_set = True
break

if invalid_param_set or len(kwargs) > 0:
raise ValueError("if faiss_index_path is passed no other params besides faiss_config_path are allowed.")

Expand All @@ -172,20 +182,20 @@ def _validate_index_sync(self):
"configuration file correctly points to the same database that "
"was used when creating the original index.")

def _create_new_index(self, vector_dim: int, metric_type, index_factory: str = "Flat", **kwargs):
def _create_new_index(self, embedding_dim: int, metric_type, index_factory: str = "Flat", **kwargs):
if index_factory == "HNSW":
# faiss index factory doesn't give the same results for HNSW IP, therefore direct init.
# defaults here are similar to DPR codebase (good accuracy, but very high RAM consumption)
n_links = kwargs.get("n_links", 64)
index = faiss.IndexHNSWFlat(vector_dim, n_links, metric_type)
index = faiss.IndexHNSWFlat(embedding_dim, n_links, metric_type)
index.hnsw.efSearch = kwargs.get("efSearch", 20)#20
index.hnsw.efConstruction = kwargs.get("efConstruction", 80)#80
if "ivf" in index_factory.lower(): # enable reconstruction of vectors for inverted index
self.faiss_indexes[index].set_direct_map_type(faiss.DirectMap.Hashtable)

logger.info(f"HNSW params: n_links: {n_links}, efSearch: {index.hnsw.efSearch}, efConstruction: {index.hnsw.efConstruction}")
else:
index = faiss.index_factory(vector_dim, index_factory, metric_type)
index = faiss.index_factory(embedding_dim, index_factory, metric_type)
return index

def write_documents(self, documents: Union[List[dict], List[Document]], index: Optional[str] = None,
Expand Down Expand Up @@ -217,7 +227,7 @@ def write_documents(self, documents: Union[List[dict], List[Document]], index: O

if not self.faiss_indexes.get(index):
self.faiss_indexes[index] = self._create_new_index(
vector_dim=self.vector_dim,
embedding_dim=self.embedding_dim,
index_factory=self.faiss_index_factory_str,
metric_type=faiss.METRIC_INNER_PRODUCT,
)
Expand Down Expand Up @@ -544,7 +554,7 @@ def save(self, index_path: Union[str, Path], config_path: Optional[Union[str, Pa
:param config_path: Path to save the initial configuration parameters to.
Defaults to the same as the file path, save the extension (.json).
This file contains all the parameters passed to FAISSDocumentStore()
at creation time (for example the SQL path, vector_dim, etc), and will be
at creation time (for example the SQL path, embedding_dim, etc), and will be
used by the `load` method to restore the index with the appropriate configuration.
:return: None
"""
Expand Down Expand Up @@ -574,7 +584,7 @@ def _load_init_params_from_config(self, index_path: Union[str, Path], config_pat

# Add other init params to override the ones defined in the init params file
init_params["faiss_index"] = faiss_index
init_params["vector_dim"] = faiss_index.d
init_params["embedding_dim"] = faiss_index.d

return init_params

Expand Down
20 changes: 15 additions & 5 deletions haystack/document_stores/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from haystack.nodes.retriever import BaseRetriever

import logging
import warnings
import numpy as np
from tqdm import tqdm
from scipy.special import expit
Expand Down Expand Up @@ -41,7 +42,8 @@ def __init__(
milvus_url: str = "tcp://localhost:19530",
connection_pool: str = "SingletonThread",
index: str = "document",
vector_dim: int = 768,
vector_dim: int = None,
embedding_dim: int = 768,
index_file_size: int = 1024,
similarity: str = "dot_product",
index_type: IndexType = IndexType.FLAT,
Expand All @@ -62,7 +64,8 @@ def __init__(
See https://milvus.io/docs/v1.0.0/install_milvus.md for instructions to start a Milvus instance.
:param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread".
:param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
:param vector_dim: The embedding vector size. Default: 768.
:param vector_dim: Deprecated. Use embedding_dim instead.
:param embedding_dim: The embedding vector size. Default: 768.
:param index_file_size: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
Expand Down Expand Up @@ -98,13 +101,20 @@ def __init__(
# save init parameters to enable export of component config as YAML
self.set_config(
sql_url=sql_url, milvus_url=milvus_url, connection_pool=connection_pool, index=index, vector_dim=vector_dim,
index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
embedding_dim=embedding_dim, index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
search_param=search_param, duplicate_documents=duplicate_documents,
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
)

self.milvus_server = Milvus(uri=milvus_url, pool=connection_pool)
self.vector_dim = vector_dim

if vector_dim is not None:
warnings.warn("The 'vector_dim' parameter is deprecated, "
"use 'embedding_dim' instead.", DeprecationWarning, 2)
self.embedding_dim = vector_dim
else:
self.embedding_dim = embedding_dim

self.index_file_size = index_file_size

if similarity in ("dot_product", "cosine"):
Expand Down Expand Up @@ -147,7 +157,7 @@ def _create_collection_and_index_if_not_exist(
if not ok:
collection_param = {
'collection_name': index,
'dimension': self.vector_dim,
'dimension': self.embedding_dim,
'index_file_size': self.index_file_size,
'metric_type': self.metric_type
}
Expand Down
19 changes: 14 additions & 5 deletions haystack/document_stores/milvus2x.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import warnings
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union

if TYPE_CHECKING:
Expand Down Expand Up @@ -59,7 +60,8 @@ def __init__(
port: str = "19530",
connection_pool: str = "SingletonThread",
index: str = "document",
vector_dim: int = 768,
vector_dim: int = None,
embedding_dim: int = 768,
index_file_size: int = 1024,
similarity: str = "dot_product",
index_type: str = "IVF_FLAT",
Expand All @@ -81,7 +83,8 @@ def __init__(
See https://milvus.io/docs/v1.0.0/install_milvus.md for instructions to start a Milvus instance.
:param connection_pool: Connection pool type to connect with Milvus server. Default: "SingletonThread".
:param index: Index name for text, embedding and metadata (in Milvus terms, this is the "collection name").
:param vector_dim: The embedding vector size. Default: 768.
:param vector_dim: Deprecated. Use embedding_dim instead.
:param embedding_dim: The embedding vector size. Default: 768.
:param index_file_size: Specifies the size of each segment file that is stored by Milvus and its default value is 1024 MB.
When the size of newly inserted vectors reaches the specified volume, Milvus packs these vectors into a new segment.
Milvus creates one index file for each segment. When conducting a vector search, Milvus searches all index files one by one.
Expand Down Expand Up @@ -120,7 +123,7 @@ def __init__(
# save init parameters to enable export of component config as YAML
self.set_config(
sql_url=sql_url, host=host, port=port, connection_pool=connection_pool, index=index, vector_dim=vector_dim,
index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
embedding_dim=embedding_dim, index_file_size=index_file_size, similarity=similarity, index_type=index_type, index_param=index_param,
search_param=search_param, duplicate_documents=duplicate_documents, id_field=id_field,
return_embedding=return_embedding, embedding_field=embedding_field, progress_bar=progress_bar,
custom_fields=custom_fields,
Expand All @@ -135,7 +138,13 @@ def __init__(
connections.add_connection(default={"host": host, "port": port})
connections.connect()

self.vector_dim = vector_dim
if vector_dim is not None:
warnings.warn("The 'vector_dim' parameter is deprecated, "
"use 'embedding_dim' instead.", DeprecationWarning, 2)
self.embedding_dim = vector_dim
else:
self.embedding_dim = embedding_dim

self.index_file_size = index_file_size

if similarity == "dot_product":
Expand Down Expand Up @@ -187,7 +196,7 @@ def _create_collection_and_index_if_not_exist(
if not has_collection:
fields = [
FieldSchema(name=self.id_field, dtype=DataType.INT64, is_primary=True, auto_id=True),
FieldSchema(name=self.embedding_field, dtype=DataType.FLOAT_VECTOR, dim=self.vector_dim)
FieldSchema(name=self.embedding_field, dtype=DataType.FLOAT_VECTOR, dim=self.embedding_dim)
]

for field in custom_fields:
Expand Down
Loading

0 comments on commit a44b6c1

Please sign in to comment.