1+ import json
12from typing import Any , Dict , List , Literal , Optional
23
34from langchain_core .embeddings import Embeddings
@@ -164,6 +165,46 @@ def get_available_models(
164165 """Get a list of available models that work with `NVIDIAEmbeddings`."""
165166 return cls (** kwargs ).available_models
166167
168+ def _prepare_payload (
169+ self , texts : List [str ], model_type : Literal ["passage" , "query" ]
170+ ) -> Dict [str , Any ]:
171+ """Prepare payload for both sync and async methods.
172+
173+ Args:
174+ texts: List of texts to embed
175+ model_type: Type of embedding ("passage" or "query")
176+
177+ Returns:
178+ Payload dictionary
179+ """
180+ payload : Dict [str , Any ] = {
181+ "input" : texts ,
182+ "model" : self .model ,
183+ "encoding_format" : "float" ,
184+ "input_type" : model_type ,
185+ }
186+ if self .truncate :
187+ payload ["truncate" ] = self .truncate
188+ if self .dimensions :
189+ payload ["dimensions" ] = self .dimensions
190+ return payload
191+
192+ def _process_response (self , result : Dict [str , Any ]) -> List [List [float ]]:
193+ """Process response for both sync and async methods.
194+
195+ Args:
196+ result: Parsed JSON response from the API
197+
198+ Returns:
199+ List of embeddings sorted by index
200+ """
201+ data = result .get ("data" , result )
202+ if not isinstance (data , list ):
203+ raise ValueError (f"Expected data with a list of embeddings. Got: { data } " )
204+ embedding_list = [(res ["embedding" ], res ["index" ]) for res in data ]
205+ self ._invoke_callback_vars (result )
206+ return [x [0 ] for x in sorted (embedding_list , key = lambda x : x [1 ])]
207+
167208 def _embed (
168209 self , texts : List [str ], model_type : Literal ["passage" , "query" ]
169210 ) -> List [List [float ]]:
@@ -177,47 +218,69 @@ def _embed(
177218 # truncate: "NONE" | "START" | "END" -- default "NONE", error raised if
178219 # an input is too long
179220 # dimensions: int -- not supported by all models
180- payload : Dict [str , Any ] = {
181- "input" : texts ,
182- "model" : self .model ,
183- "encoding_format" : "float" ,
184- "input_type" : model_type ,
185- }
186- if self .truncate :
187- payload ["truncate" ] = self .truncate
188- if self .dimensions :
189- payload ["dimensions" ] = self .dimensions
190-
221+ payload = self ._prepare_payload (texts , model_type )
191222 response = self ._client .get_req (
192223 payload = payload ,
193224 extra_headers = self .default_headers ,
194225 )
195226 response .raise_for_status ()
196227 result = response .json ()
197- data = result .get ("data" , result )
198- if not isinstance (data , list ):
199- raise ValueError (f"Expected data with a list of embeddings. Got: { data } " )
200- embedding_list = [(res ["embedding" ], res ["index" ]) for res in data ]
201- self ._invoke_callback_vars (result )
202- return [x [0 ] for x in sorted (embedding_list , key = lambda x : x [1 ])]
228+ return self ._process_response (result )
229+
230+ def _validate_texts (self , texts : List [str ]) -> None :
231+ """Validate that texts is a list of strings.
232+
233+ Args:
234+ texts: List to validate
235+
236+ Raises:
237+ ValueError: If texts is not a list of strings
238+ """
239+ if not isinstance (texts , list ) or not all (
240+ isinstance (text , str ) for text in texts
241+ ):
242+ raise ValueError (f"`texts` must be a list of strings, given: { repr (texts )} " )
203243
204244 def embed_query (self , text : str ) -> List [float ]:
205245 """Input pathway for query embeddings."""
206246 return self ._embed ([text ], model_type = "query" )[0 ]
207247
208248 def embed_documents (self , texts : List [str ]) -> List [List [float ]]:
209249 """Input pathway for document embeddings."""
210- if not isinstance (texts , list ) or not all (
211- isinstance (text , str ) for text in texts
212- ):
213- raise ValueError (f"`texts` must be a list of strings, given: { repr (texts )} " )
250+ self ._validate_texts (texts )
214251
215252 all_embeddings = []
216253 for i in range (0 , len (texts ), self .max_batch_size ):
217254 batch = texts [i : i + self .max_batch_size ]
218255 all_embeddings .extend (self ._embed (batch , model_type = "passage" ))
219256 return all_embeddings
220257
258+ async def _aembed (
259+ self , texts : List [str ], model_type : Literal ["passage" , "query" ]
260+ ) -> List [List [float ]]:
261+ """Async version of _embed."""
262+ payload = self ._prepare_payload (texts , model_type )
263+ response_text = await self ._client .aget_req (
264+ payload = payload ,
265+ extra_headers = self .default_headers ,
266+ )
267+ result = json .loads (response_text )
268+ return self ._process_response (result )
269+
270+ async def aembed_query (self , text : str ) -> List [float ]:
271+ """Async input pathway for query embeddings."""
272+ return (await self ._aembed ([text ], model_type = "query" ))[0 ]
273+
274+ async def aembed_documents (self , texts : List [str ]) -> List [List [float ]]:
275+ """Async input pathway for document embeddings."""
276+ self ._validate_texts (texts )
277+
278+ all_embeddings : List [List [float ]] = []
279+ for i in range (0 , len (texts ), self .max_batch_size ):
280+ batch = texts [i : i + self .max_batch_size ]
281+ all_embeddings .extend (await self ._aembed (batch , model_type = "passage" ))
282+ return all_embeddings
283+
221284 def _invoke_callback_vars (self , response : dict ) -> None :
222285 """Invoke the callback context variables if there are any."""
223286 callback_vars = [
0 commit comments