Commit b859edd

Switch to sentence-transformers from transformers
Why these changes are being introduced:

The first approach for the embedding class OSNeuralSparseDocV3GTE used the `transformers` library for creating embeddings. Some early exploratory code seemed to indicate this lower-level library would provide more flexibility in response formats and better performance. However, while exploring multiprocessing with the `transformers` library, the alternative of using the `sentence-transformers` library was considered, given its more out-of-the-box multiprocessing support. During that exploration, and based on learnings since the original spike code, it was determined that the `sentence-transformers` library is likely a better fit overall for our purposes. Pivoting to this library simplifies our actual embedding logic, while providing out-of-the-box tuning capabilities that should be sufficient for our needs.

How this addresses that need:

The embedding class `OSNeuralSparseDocV3GTE` is reworked to use `sentence-transformers` instead of `transformers` for creating embeddings. This considerably reduces complexity in the actual creation of embeddings, while also exposing an API for multiprocessing. It's worth noting that testing indicates multiprocessing will *not* speed up embeddings, at least for the contexts in which we aim to create them, but the `sentence-transformers` library also better handles parallelism without explicit multiprocessing. In summary, switching to `sentence-transformers` yields a simpler API for creating embeddings and better out-of-the-box performance, while still allowing for more tuning later.

Side effects of this change:
* None

Relevant ticket(s):
* https://mitlibraries.atlassian.net/browse/USE-137
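For reference, the heart of the new approach is only a few calls. A minimal sketch (the model path and input text are placeholders; the calls mirror those introduced in the diff below):

from sentence_transformers.sparse_encoder import SparseEncoder

# load a local copy of the model (path is a placeholder)
model = SparseEncoder("/path/to/model", trust_remote_code=True, device="cpu")

# encode a document into a sparse vector over the model vocabulary
sparse_vector = model.encode_document("coffee roasters in seattle")

# decode non-zero entries back into a {token: weight} dictionary
token_weights = dict(model.decode(sparse_vector))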
1 parent 0b13f95 commit b859edd

File tree

7 files changed (+1288, -393 lines)


embeddings/models/base.py

Lines changed: 2 additions & 3 deletions
@@ -58,13 +58,12 @@ def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
             embedding_input: EmbeddingInput instance
         """
 
+    @abstractmethod
     def create_embeddings(
         self, embedding_inputs: Iterator[EmbeddingInput]
     ) -> Iterator[Embedding]:
-        """Yield Embeddings for a batch of EmbeddingInputs.
+        """Yield Embeddings for multiple EmbeddingInputs.
 
         Args:
             embedding_inputs: iterator of EmbeddingInputs
         """
-        for embedding_input in embedding_inputs:
-            yield self.create_embedding(embedding_input)
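Because create_embeddings() is now abstract and its default pass-through body has been removed, every concrete model must implement it. A minimal sketch of what a subclass now supplies (class name hypothetical; the other required methods are omitted here):

from collections.abc import Iterator

from embeddings.embedding import Embedding, EmbeddingInput
from embeddings.models.base import BaseEmbeddingModel


class SequentialEmbeddingModel(BaseEmbeddingModel):
    # hypothetical subclass: embeds inputs one at a time

    def create_embeddings(
        self, embedding_inputs: Iterator[EmbeddingInput]
    ) -> Iterator[Embedding]:
        # equivalent to the default the base class previously provided
        for embedding_input in embedding_inputs:
            yield self.create_embedding(embedding_input)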

embeddings/models/os_neural_sparse_doc_v3_gte.py

Lines changed: 100 additions & 189 deletions
@@ -2,25 +2,21 @@
 
 import json
 import logging
+import os
 import shutil
 import tempfile
 import time
+from collections.abc import Iterator
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import cast
 
-import torch
 from huggingface_hub import snapshot_download
-from transformers import AutoModelForMaskedLM, AutoTokenizer
+from sentence_transformers.sparse_encoder import SparseEncoder
+from torch import Tensor
 
 from embeddings.embedding import Embedding, EmbeddingInput
 from embeddings.models.base import BaseEmbeddingModel
 
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
-    from transformers.models.distilbert.tokenization_distilbert_fast import (
-        DistilBertTokenizerFast,
-    )
-
 logger = logging.getLogger(__name__)
 
 
@@ -42,10 +38,8 @@ def __init__(self, model_path: str | Path) -> None:
             model_path: Path where the model will be downloaded to and loaded from.
         """
         super().__init__(model_path)
-        self._model: PreTrainedModel | None = None
-        self._tokenizer: DistilBertTokenizerFast | None = None
-        self._special_token_ids: list[int] | None = None
-        self._device: torch.device = torch.device("cpu")
+        self.device = os.getenv("TE_TORCH_DEVICE", "cpu")
+        self._model: SparseEncoder = None  # type: ignore[assignment]
 
     def download(self) -> Path:
         """Download and prepare model, saving to self.model_path.
@@ -139,209 +133,126 @@ def load(self) -> None:
         start_time = time.perf_counter()
         logger.info(f"Loading model from: {self.model_path}")
 
-        # ensure model exists locally
         if not self.model_path.exists():
             raise FileNotFoundError(f"Model not found at path: {self.model_path}")
 
-        # setup device (use CUDA if available, otherwise CPU)
-        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        # load tokenizer
-        self._tokenizer = AutoTokenizer.from_pretrained(  # type: ignore[no-untyped-call]
-            self.model_path,
-            local_files_only=True,
-        )
-
-        # load model as AutoModelForMaskedLM (required for sparse embeddings)
-        self._model = AutoModelForMaskedLM.from_pretrained(
-            self.model_path,
+        # load model as SparseEncoder
+        self._model = SparseEncoder(
+            str(self.model_path),
             trust_remote_code=True,
-            local_files_only=True,
+            model_kwargs={},
+            device="cpu",
         )
-        self._model.to(self._device)  # type: ignore[arg-type]
-        self._model.eval()
 
-        # set special token IDs (following model card pattern)
-        # these will be zeroed out in the sparse vectors
-        self._special_token_ids = [
-            self._tokenizer.vocab[token]  # type: ignore[index]
-            for token in self._tokenizer.special_tokens_map.values()
-        ]
-
-        logger.info(
-            f"Model loaded successfully on {self._device}, "
-            f"{time.perf_counter() - start_time:.2f}s"
-        )
+        logger.info(f"Model loaded successfully, {time.perf_counter() - start_time:.2f}s")
 
     def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
-        """Create sparse vector and decoded token weight embeddings for an input text.
+        """Create an Embedding for an EmbeddingInput.
+
+        This model is configured to return a sparse vector of vocabulary token indices
+        and weights, and a dictionary of decoded tokens and weights that had a weight
+        > 0 in the sparse vector.
 
         Args:
-            embedding_input: EmbeddingInput object with a .text attribute
+            embedding_input: EmbeddingInput instance
         """
-        # generate the sparse embeddings
-        sparse_vector, decoded_tokens = self._encode_documents([embedding_input.text])[0]
-
-        # coerce sparse vector tensor into list[float]
-        sparse_vector_list = sparse_vector.cpu().numpy().tolist()
-
-        return Embedding(
-            timdex_record_id=embedding_input.timdex_record_id,
-            run_id=embedding_input.run_id,
-            run_record_offset=embedding_input.run_record_offset,
-            model_uri=self.model_uri,
-            embedding_strategy=embedding_input.embedding_strategy,
-            embedding_vector=sparse_vector_list,
-            embedding_token_weights=decoded_tokens,
-        )
+        sparse_vector = self._model.encode_document(embedding_input.text)
+        sparse_vector = cast("Tensor", sparse_vector)
+        return self._get_embedding_from_sparse_vector(embedding_input, sparse_vector)
 
-    def _encode_documents(
+    def create_embeddings(
         self,
-        texts: list[str],
-    ) -> list[tuple[torch.Tensor, dict[str, float]]]:
-        """Encode documents into sparse vectors and decoded token weights.
-
-        This follows the pattern outlined on the HuggingFace model card for document
-        encoding.
-
-        This method will accommodate MULTIPLE text inputs, and return a list of
-        embeddings, but the calling context of create_embedding() is a SINGULAR input +
-        output. This method keeps the ability to handle multiple inputs + outputs, in the
-        event we want something like a create_multiple_embeddings() method in the future,
-        but only returns a single result.
-
-        At a very high level, the following is performed:
-
-        1. We tokenize the input text into "features" using the model's tokenizer.
+        embedding_inputs: Iterator[EmbeddingInput],
+    ) -> Iterator[Embedding]:
+        """Yield Embeddings for multiple EmbeddingInputs.
 
-        2. The features are fed to the model returning model output logits. These logits
-        are "dense" in the sense there are few zeros, but they are not "dense vectors"
-        (embeddings) in the sense that they meaningfully represent the input document in
-        geometric space; two logit tensors cannot be compared with something like cosine
-        similarity.
+        If env var TE_NUM_WORKERS is set and >1, the encoding lib sentence-transformers
+        will automatically create a pool of worker processes to work in parallel.
 
-        3. The logits are then converted into a sparse vector, which is a numeric
-        array of floats with the same number of values as the model's vocabulary. Each
-        value's position in the sparse array corresponds to the token id in the
-        vocabulary, and the value itself is the "weight" of this token in the input text.
+        Note: currently 2+ workers in amd64 and arm64 Docker contexts immediately exits
+        due to a "Bus Error". It is recommended to omit the env var TE_NUM_WORKERS, or
+        set to "1", in Docker contexts.
 
-        4. Lastly, we convert this sparse vector into a {token:weight} dictionary of the
-        actual token strings and their numerical weight. This dictionary may contain
-        tokens not present in the original text, but will be considerably shorter than
-        the model vocabulary length given all zero and low scoring tokens are dropped.
-        This is the final form that we will ultimately index into OpenSearch.
+        Currently, we also fully consume the input EmbeddingInputs before we start
+        embedding work. This may change in future iterations if we move to batching
+        embedding creation, so until then it's assumed that inputs to this method are
+        memory safe for the full run.
 
         Args:
-            texts: list of strings to create embeddings for
+            embedding_inputs: iterator of EmbeddingInputs
         """
-        if self._model is None or self._tokenizer is None:
-            raise RuntimeError("Model not loaded. Call load() before create_embedding.")
+        # consume input EmbeddingInputs
+        embedding_inputs_list = list(embedding_inputs)
+        if not embedding_inputs_list:
+            return
+
+        # extract texts from all inputs
+        texts = [embedding_input.text for embedding_input in embedding_inputs_list]
+
+        # read env vars for configurations
+        num_workers = int(os.getenv("TE_NUM_WORKERS", "1"))
+        batch_size = int(
+            os.getenv("TE_BATCH_SIZE", "32")
+        )  # sentence-transformers default
+
+        # configure device and worker pool based on number of workers requested
+        if num_workers > 1 or self.device == "mps":
+            device = None
+            pool = self._model.start_multi_process_pool(
+                [self.device for _ in range(num_workers)]
+            )
+        else:
+            device = self.device
+            pool = None
+        logger.info(
+            f"Num workers: {num_workers}, batch size: {batch_size}, "
+            f"device: {device}, pool: {pool}"
+        )
 
-        # tokenize the input texts
-        features = self._tokenizer(
+        # get sparse vector embedding for input text(s)
+        sparse_vectors = self._model.encode_document(
             texts,
-            padding=True,
-            truncation=True,
-            return_tensors="pt",  # returns PyTorch tensors instead of Python lists
-            return_token_type_ids=False,
+            batch_size=batch_size,
+            device=device,
+            pool=pool,
+            save_to_cpu=True,
         )
+        sparse_vectors = cast("list[Tensor]", sparse_vectors)
 
-        # move to CPU or GPU device, depending on what's available
-        features = {k: v.to(self._device) for k, v in features.items()}
-
-        # pass features to the model and receive model output logits as a tensor
-        with torch.no_grad():
-            output = self._model(**features)[0]
-
-        # generate sparse vectors from model logits tensor
-        sparse_vectors = self._get_sparse_vectors(features, output)
-
-        # decode sparse vectors to token-weight dictionaries
-        decoded = self._decode_sparse_vectors(sparse_vectors)
-
-        # return list of tuple(vector, decoded token weights) embedding results
-        return [(sparse_vectors[i], decoded[i]) for i in range(len(texts))]
+        for i, embedding_input in enumerate(embedding_inputs_list):
+            sparse_vector = sparse_vectors[i]
+            sparse_vector = cast("Tensor", sparse_vector)
+            yield self._get_embedding_from_sparse_vector(embedding_input, sparse_vector)
 
-    def _get_sparse_vectors(
-        self, features: dict[str, torch.Tensor], output: torch.Tensor
-    ) -> torch.Tensor:
-        """Convert model logits output to sparse vectors.
-
-        This follows the HuggingFace model card exactly: https://huggingface.co/
-        opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte#usage-huggingface
-
-        This implements the get_sparse_vector function from the model card:
-        1. Max pooling with attention mask
-        2. log(1 + log(1 + relu())) transformation
-        3. Zero out special tokens
-
-        The end result is a sparse vector with a length of the model vocabulary, with each
-        position representing a token in the model vocabulary and each value representing
-        that token's weight relative to the input text.
-
-        Args:
-            features: Tokenizer output with attention_mask
-            output: Model logits of shape (batch_size, seq_len, vocab_size)
-
-        Returns:
-            Sparse vectors of shape (batch_size, vocab_size)
-        """
-        # collapse sequence positions: take max logit for each vocab token across all
-        # positions (also masks out padding tokens)
-        values, _ = torch.max(output * features["attention_mask"].unsqueeze(-1), dim=1)
-
-        # compress values to create sparsity: ReLU removes negatives,
-        # double-log shrinks large values
-        values = torch.log(1 + torch.log(1 + torch.relu(values)))
-
-        # remove special tokens like [CLS], [SEP], [PAD]
-        values[:, self._special_token_ids] = 0
-
-        return values
-
-    def _decode_sparse_vectors(
-        self, sparse_vectors: torch.Tensor
-    ) -> list[dict[str, float]]:
-        """Convert sparse vectors to token-weight dictionaries.
+    def _get_embedding_from_sparse_vector(
+        self,
+        embedding_input: EmbeddingInput,
+        sparse_vector: Tensor,
+    ) -> Embedding:
+        """Prepare Embedding from EmbeddingInput and calculated sparse vector.
 
-        Handles both single vectors and batches, returning a list of dictionaries mapping
-        token strings to their weights.
+        This shared method is used by create_embedding() and create_embeddings() to
+        prepare and return an Embedding. A sparse vector is provided, which is decoded
+        into a dictionary of tokens:weights, and a final Embedding instance is returned.
 
         Args:
-            sparse_vectors: Tensor of shape (batch_size, vocab_size) or (vocab_size,)
-
-        Returns:
-            List of dictionaries with token-weight pairs
+            embedding_input: EmbeddingInput
+            sparse_vector: sparse vector returned by model
         """
-        if sparse_vectors.dim() == 1:
-            sparse_vectors = sparse_vectors.unsqueeze(0)
-
-        # move to CPU for processing
-        sparse_vectors_cpu = sparse_vectors.cpu()
-
-        results: list[dict] = []
-        for vector in sparse_vectors_cpu:
-
-            # find non-zero indices and values
-            nonzero_indices = torch.nonzero(vector, as_tuple=False).squeeze(-1)
-
-            if nonzero_indices.numel() == 0:
-                results.append({})
-                continue
-
-            # get weights
-            weights = vector[nonzero_indices].tolist()
-
-            # convert indices to token strings
-            token_ids = nonzero_indices.tolist()
-            tokens = self._tokenizer.convert_ids_to_tokens(token_ids)  # type: ignore[union-attr]
+        # get decoded dictionary of tokens:weights
+        decoded_token_weights = self._model.decode(sparse_vector)
+        decoded_token_weights = cast("list[tuple[str, float]]", decoded_token_weights)
+        embedding_token_weights = dict(decoded_token_weights)
 
-            # create token:weight dictionary
-            token_dict = {
-                token: weight
-                for token, weight in zip(tokens, weights, strict=True)
-                if token is not None
-            }
-            results.append(token_dict)
+        # prepare sparse vector for JSON serialization
+        embedding_vector = sparse_vector.to_dense().tolist()
 
-        return results
+        return Embedding(
+            timdex_record_id=embedding_input.timdex_record_id,
+            run_id=embedding_input.run_id,
+            run_record_offset=embedding_input.run_record_offset,
+            model_uri=self.model_uri,
+            embedding_strategy=embedding_input.embedding_strategy,
+            embedding_vector=embedding_vector,
+            embedding_token_weights=embedding_token_weights,
+        )
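All tuning for the new create_embeddings() flows through environment variables. A hedged usage sketch (the model path, record values, and strategy name are illustrative, and the model is assumed to already be downloaded to the placeholder path):

import os

from embeddings.embedding import EmbeddingInput
from embeddings.models.os_neural_sparse_doc_v3_gte import OSNeuralSparseDocV3GTE

# set env vars before constructing the model, since the device is read in __init__
os.environ["TE_TORCH_DEVICE"] = "cpu"  # or "cuda" / "mps" where available
os.environ["TE_BATCH_SIZE"] = "64"  # default is 32
# TE_NUM_WORKERS left unset: 2+ workers can hit a "Bus Error" in Docker contexts

model = OSNeuralSparseDocV3GTE("/path/to/model")  # placeholder path
model.load()

inputs = [
    EmbeddingInput(
        timdex_record_id="alma:123",  # illustrative values throughout
        run_id="run-1",
        run_record_offset=0,
        embedding_strategy="full_record",  # hypothetical strategy name
        text="coffee roasters in seattle",
    )
]
for embedding in model.create_embeddings(iter(inputs)):
    print(embedding.embedding_token_weights)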

tests/conftest.py

Lines changed: 7 additions & 0 deletions
@@ -1,6 +1,7 @@
 import json
 import logging
 import zipfile
+from collections.abc import Iterator
 from pathlib import Path
 
 import pytest
@@ -56,6 +57,12 @@ def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
             embedding_token_weights={"coffee": 0.9, "seattle": 0.5},
         )
 
+    def create_embeddings(
+        self, embedding_inputs: Iterator[EmbeddingInput]
+    ) -> Iterator[Embedding]:
+        for embedding_input in embedding_inputs:
+            yield self.create_embedding(embedding_input)
+
 
 @pytest.fixture
 def mock_model(tmp_path):
