@@ -2,25 +2,21 @@
 
 import json
 import logging
+import os
 import shutil
 import tempfile
 import time
+from collections.abc import Iterator
 from pathlib import Path
-from typing import TYPE_CHECKING
+from typing import cast
 
-import torch
 from huggingface_hub import snapshot_download
-from transformers import AutoModelForMaskedLM, AutoTokenizer
+from sentence_transformers.sparse_encoder import SparseEncoder
+from torch import Tensor
 
 from embeddings.embedding import Embedding, EmbeddingInput
 from embeddings.models.base import BaseEmbeddingModel
 
-if TYPE_CHECKING:
-    from transformers import PreTrainedModel
-    from transformers.models.distilbert.tokenization_distilbert_fast import (
-        DistilBertTokenizerFast,
-    )
-
 logger = logging.getLogger(__name__)
 
 
@@ -42,10 +38,8 @@ def __init__(self, model_path: str | Path) -> None:
             model_path: Path where the model will be downloaded to and loaded from.
         """
         super().__init__(model_path)
-        self._model: PreTrainedModel | None = None
-        self._tokenizer: DistilBertTokenizerFast | None = None
-        self._special_token_ids: list[int] | None = None
-        self._device: torch.device = torch.device("cpu")
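+        # device is read from env var TE_TORCH_DEVICE (defaults to "cpu"); the model
+        # itself is not loaded until load() is called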
+        self.device = os.getenv("TE_TORCH_DEVICE", "cpu")
+        self._model: SparseEncoder = None  # type: ignore[assignment]
 
     def download(self) -> Path:
         """Download and prepare model, saving to self.model_path.
@@ -139,209 +133,126 @@ def load(self) -> None:
         start_time = time.perf_counter()
         logger.info(f"Loading model from: {self.model_path}")
 
-        # ensure model exists locally
         if not self.model_path.exists():
             raise FileNotFoundError(f"Model not found at path: {self.model_path}")
 
-        # setup device (use CUDA if available, otherwise CPU)
-        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        # load tokenizer
-        self._tokenizer = AutoTokenizer.from_pretrained(  # type: ignore[no-untyped-call]
-            self.model_path,
-            local_files_only=True,
-        )
-
-        # load model as AutoModelForMaskedLM (required for sparse embeddings)
-        self._model = AutoModelForMaskedLM.from_pretrained(
-            self.model_path,
+        # load model as SparseEncoder
+        self._model = SparseEncoder(
+            str(self.model_path),
             trust_remote_code=True,
-            local_files_only=True,
+            model_kwargs={},
+            device="cpu",
         )
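+        # the encoder is constructed on CPU here; create_embeddings() selects the
+        # device (and optional worker pool) per call when invoking encode_document()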
-        self._model.to(self._device)  # type: ignore[arg-type]
-        self._model.eval()
 
-        # set special token IDs (following model card pattern)
-        # these will be zeroed out in the sparse vectors
-        self._special_token_ids = [
-            self._tokenizer.vocab[token]  # type: ignore[index]
-            for token in self._tokenizer.special_tokens_map.values()
-        ]
-
-        logger.info(
-            f"Model loaded successfully on {self._device}, "
-            f"{time.perf_counter() - start_time:.2f}s"
-        )
+        logger.info(f"Model loaded successfully, {time.perf_counter() - start_time:.2f}s")
 
     def create_embedding(self, embedding_input: EmbeddingInput) -> Embedding:
-        """Create sparse vector and decoded token weight embeddings for an input text.
+        """Create an Embedding for an EmbeddingInput.
+
+        This model is configured to return a sparse vector of vocabulary token indices
+        and weights, and a dictionary of decoded tokens and weights for tokens that had
+        a weight > 0 in the sparse vector.
 
         Args:
-            embedding_input: EmbeddingInput object with a .text attribute
+            embedding_input: EmbeddingInput instance
         """
-        # generate the sparse embeddings
-        sparse_vector, decoded_tokens = self._encode_documents([embedding_input.text])[0]
-
-        # coerce sparse vector tensor into list[float]
-        sparse_vector_list = sparse_vector.cpu().numpy().tolist()
-
-        return Embedding(
-            timdex_record_id=embedding_input.timdex_record_id,
-            run_id=embedding_input.run_id,
-            run_record_offset=embedding_input.run_record_offset,
-            model_uri=self.model_uri,
-            embedding_strategy=embedding_input.embedding_strategy,
-            embedding_vector=sparse_vector_list,
-            embedding_token_weights=decoded_tokens,
-        )
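+        # for a single string, encode_document() returns one sparse tensor sized to
+        # the model vocabulary: nonzero positions are token ids, values are weights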
+        sparse_vector = self._model.encode_document(embedding_input.text)
+        sparse_vector = cast("Tensor", sparse_vector)
+        return self._get_embedding_from_sparse_vector(embedding_input, sparse_vector)
 
-    def _encode_documents(
+    def create_embeddings(
         self,
-        texts: list[str],
-    ) -> list[tuple[torch.Tensor, dict[str, float]]]:
-        """Encode documents into sparse vectors and decoded token weights.
-
-        This follows the pattern outlined on the HuggingFace model card for document
-        encoding.
-
-        This method will accommodate MULTIPLE text inputs, and return a list of
-        embeddings, but the calling context of create_embedding() is a SINGULAR input +
-        output. This method keeps the ability to handle multiple inputs + outputs, in the
-        event we want something like a create_multiple_embeddings() method in the future,
-        but only returns a single result.
-
-        At a very high level, the following is performed:
-
-        1. We tokenize the input text into "features" using the model's tokenizer.
+        embedding_inputs: Iterator[EmbeddingInput],
+    ) -> Iterator[Embedding]:
+        """Yield Embeddings for multiple EmbeddingInputs.
 
-        2. The features are fed to the model returning model output logits. These logits
-        are "dense" in the sense there are few zeros, but they are not "dense vectors"
-        (embeddings) in the sense that they meaningfully represent the input document in
-        geometric space; two logit tensors cannot be compared with something like cosine
-        similarity.
+        If env var TE_NUM_WORKERS is set and >1, the encoding lib sentence-transformers
+        will automatically create a pool of worker processes to work in parallel.
 
-        3. The logits are then converted into a sparse vector, which is a numeric
-        array of floats with the same number of values as the model's vocabulary. Each
-        value's position in the sparse array corresponds to the token id in the
-        vocabulary, and the value itself is the "weight" of this token in the input text.
+        Note: currently, running 2+ workers in amd64 and arm64 Docker contexts causes an
+        immediate exit due to a "Bus Error". It is recommended to omit the env var
+        TE_NUM_WORKERS, or set it to "1", in Docker contexts.
 
-        4. Lastly, we convert this sparse vector into a {token:weight} dictionary of the
-        actual token strings and their numerical weight. This dictionary may contain
-        tokens not present in the original text, but will be considerably shorter than
-        the model vocabulary length given all zero and low scoring tokens are dropped.
-        This is the final form that we will ultimately index into OpenSearch.
+        Currently, we also fully consume the input EmbeddingInputs before any embedding
+        work starts. This may change in future iterations if we move to batching
+        embedding creation, so until then it is assumed that inputs to this method are
+        memory-safe for the full run.
 
         Args:
-            texts: list of strings to create embeddings for
+            embedding_inputs: iterator of EmbeddingInputs
         """
-        if self._model is None or self._tokenizer is None:
-            raise RuntimeError("Model not loaded. Call load() before create_embedding.")
+        # consume input EmbeddingInputs
+        embedding_inputs_list = list(embedding_inputs)
+        if not embedding_inputs_list:
+            return
+
+        # extract texts from all inputs
+        texts = [embedding_input.text for embedding_input in embedding_inputs_list]
+
+        # read env vars for configuration
+        num_workers = int(os.getenv("TE_NUM_WORKERS", "1"))
+        # sentence-transformers default batch size is 32
+        batch_size = int(os.getenv("TE_BATCH_SIZE", "32"))
+
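+        # requesting multiple workers (or the "mps" device) routes encoding through a
+        # multi-process pool, with one worker process per entry in the device list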
+        # configure device and worker pool based on number of workers requested
+        if num_workers > 1 or self.device == "mps":
+            device = None
+            pool = self._model.start_multi_process_pool(
+                [self.device for _ in range(num_workers)]
+            )
+        else:
+            device = self.device
+            pool = None
+        logger.info(
+            f"Num workers: {num_workers}, batch size: {batch_size}, "
+            f"device: {device}, pool: {pool}"
+        )
 
-        # tokenize the input texts
-        features = self._tokenizer(
+        # get sparse vector embeddings for input text(s)
+        sparse_vectors = self._model.encode_document(
             texts,
-            padding=True,
-            truncation=True,
-            return_tensors="pt",  # returns PyTorch tensors instead of Python lists
-            return_token_type_ids=False,
+            batch_size=batch_size,
+            device=device,
+            pool=pool,
+            save_to_cpu=True,
         )
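+        # with a list input, encode_document() returns one sparse tensor per text;
+        # save_to_cpu=True keeps the result tensors on CPU rather than a GPU device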
+        sparse_vectors = cast("list[Tensor]", sparse_vectors)
 
-        # move to CPU or GPU device, depending on what's available
-        features = {k: v.to(self._device) for k, v in features.items()}
-
-        # pass features to the model and receive model output logits as a tensor
-        with torch.no_grad():
-            output = self._model(**features)[0]
-
-        # generate sparse vectors from model logits tensor
-        sparse_vectors = self._get_sparse_vectors(features, output)
-
-        # decode sparse vectors to token-weight dictionaries
-        decoded = self._decode_sparse_vectors(sparse_vectors)
-
-        # return list of tuple(vector, decoded token weights) embedding results
-        return [(sparse_vectors[i], decoded[i]) for i in range(len(texts))]
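+        # encode_document() preserves input order, so each sparse vector pairs
+        # positionally with its originating EmbeddingInput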
+        for i, embedding_input in enumerate(embedding_inputs_list):
+            sparse_vector = sparse_vectors[i]
+            sparse_vector = cast("Tensor", sparse_vector)
+            yield self._get_embedding_from_sparse_vector(embedding_input, sparse_vector)
 
-    def _get_sparse_vectors(
-        self, features: dict[str, torch.Tensor], output: torch.Tensor
-    ) -> torch.Tensor:
-        """Convert model logits output to sparse vectors.
-
-        This follows the HuggingFace model card exactly: https://huggingface.co/
-        opensearch-project/opensearch-neural-sparse-encoding-doc-v3-gte#usage-huggingface
-
-        This implements the get_sparse_vector function from the model card:
-        1. Max pooling with attention mask
-        2. log(1 + log(1 + relu())) transformation
-        3. Zero out special tokens
-
-        The end result is a sparse vector with a length of the model vocabulary, with each
-        position representing a token in the model vocabulary and each value representing
-        that token's weight relative to the input text.
-
-        Args:
-            features: Tokenizer output with attention_mask
-            output: Model logits of shape (batch_size, seq_len, vocab_size)
-
-        Returns:
-            Sparse vectors of shape (batch_size, vocab_size)
-        """
-        # collapse sequence positions: take max logit for each vocab token across all
-        # positions (also masks out padding tokens)
-        values, _ = torch.max(output * features["attention_mask"].unsqueeze(-1), dim=1)
-
-        # compress values to create sparsity: ReLU removes negatives,
-        # double-log shrinks large values
-        values = torch.log(1 + torch.log(1 + torch.relu(values)))
-
-        # remove special tokens like [CLS], [SEP], [PAD]
-        values[:, self._special_token_ids] = 0
-
-        return values
-
-    def _decode_sparse_vectors(
-        self, sparse_vectors: torch.Tensor
-    ) -> list[dict[str, float]]:
-        """Convert sparse vectors to token-weight dictionaries.
+    def _get_embedding_from_sparse_vector(
+        self,
+        embedding_input: EmbeddingInput,
+        sparse_vector: Tensor,
+    ) -> Embedding:
+        """Prepare an Embedding from an EmbeddingInput and calculated sparse vector.
 
-        Handles both single vectors and batches, returning a list of dictionaries mapping
-        token strings to their weights.
+        This shared method is used by create_embedding() and create_embeddings() to
+        prepare and return an Embedding. A sparse vector is provided, which is decoded
+        into a dictionary of tokens:weights, and a final Embedding instance is returned.
 
         Args:
-            sparse_vectors: Tensor of shape (batch_size, vocab_size) or (vocab_size,)
-
-        Returns:
-            List of dictionaries with token-weight pairs
+            embedding_input: EmbeddingInput
+            sparse_vector: sparse vector returned by model
         """
-        if sparse_vectors.dim() == 1:
-            sparse_vectors = sparse_vectors.unsqueeze(0)
-
-        # move to CPU for processing
-        sparse_vectors_cpu = sparse_vectors.cpu()
-
-        results: list[dict] = []
-        for vector in sparse_vectors_cpu:
-
-            # find non-zero indices and values
-            nonzero_indices = torch.nonzero(vector, as_tuple=False).squeeze(-1)
-
-            if nonzero_indices.numel() == 0:
-                results.append({})
-                continue
-
-            # get weights
-            weights = vector[nonzero_indices].tolist()
-
-            # convert indices to token strings
-            token_ids = nonzero_indices.tolist()
-            tokens = self._tokenizer.convert_ids_to_tokens(token_ids)  # type: ignore[union-attr]
+        # get decoded dictionary of tokens:weights
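+        # decode() maps each nonzero vocabulary index back to its token string,
+        # returning (token, weight) pairs that are then collapsed into a dict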
+        decoded_token_weights = self._model.decode(sparse_vector)
+        decoded_token_weights = cast("list[tuple[str, float]]", decoded_token_weights)
+        embedding_token_weights = dict(decoded_token_weights)
 
-            # create token:weight dictionary
-            token_dict = {
-                token: weight
-                for token, weight in zip(tokens, weights, strict=True)
-                if token is not None
-            }
-            results.append(token_dict)
+        # prepare sparse vector for JSON serialization
+        embedding_vector = sparse_vector.to_dense().tolist()
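+        # to_dense() expands the sparse tensor to the full vocabulary-length vector
+        # (one float per vocabulary token) before converting it to a Python list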
 
-        return results
+        return Embedding(
+            timdex_record_id=embedding_input.timdex_record_id,
+            run_id=embedding_input.run_id,
+            run_record_offset=embedding_input.run_record_offset,
+            model_uri=self.model_uri,
+            embedding_strategy=embedding_input.embedding_strategy,
+            embedding_vector=embedding_vector,
+            embedding_token_weights=embedding_token_weights,
+        )