
Commit 6f665ca

feat: Support async embedding/rerank/completion (#204)
1 parent 4b655a0 commit 6f665ca

File tree

10 files changed: 1005 additions & 249 deletions

libs/ai-endpoints/langchain_nvidia_ai_endpoints/_common.py

Lines changed: 2 additions & 1 deletion
@@ -810,7 +810,8 @@ async def aget_req_stream(
                 line = await reader.readline()
                 if not line:  # EOF
                     break
-                if line and line.strip() != b"data: [DONE]":
+                line = line.strip()
+                if line and line != b"data: [DONE]":
                     line_str = line.decode("utf-8")
                     msg, final_line = call.postprocess(line_str)
                     yield msg
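
The practical effect of this change: the line is stripped once before any check, so whitespace-only keep-alive lines are skipped and the decoded payload no longer carries a trailing newline. A standalone sketch of the before/after filtering (illustration only, not the library's code):

    # Illustration of the behavior change in aget_req_stream (not library code).
    lines = [b'data: {"content": "hi"}\n', b"\n", b"data: [DONE]\n"]

    # Old logic: blank keep-alive lines slip through, payloads keep "\n".
    old = [ln.decode("utf-8") for ln in lines if ln and ln.strip() != b"data: [DONE]"]
    # -> ['data: {"content": "hi"}\n', '\n']

    # New logic: strip first, so blanks are dropped and payloads are clean.
    new = [
        s.decode("utf-8")
        for s in (ln.strip() for ln in lines)
        if s and s != b"data: [DONE]"
    ]
    # -> ['data: {"content": "hi"}']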

libs/ai-endpoints/langchain_nvidia_ai_endpoints/embeddings.py

Lines changed: 84 additions & 21 deletions
@@ -1,3 +1,4 @@
+import json
 from typing import Any, Dict, List, Literal, Optional

 from langchain_core.embeddings import Embeddings
@@ -164,6 +165,46 @@ def get_available_models(
         """Get a list of available models that work with `NVIDIAEmbeddings`."""
         return cls(**kwargs).available_models

+    def _prepare_payload(
+        self, texts: List[str], model_type: Literal["passage", "query"]
+    ) -> Dict[str, Any]:
+        """Prepare payload for both sync and async methods.
+
+        Args:
+            texts: List of texts to embed
+            model_type: Type of embedding ("passage" or "query")
+
+        Returns:
+            Payload dictionary
+        """
+        payload: Dict[str, Any] = {
+            "input": texts,
+            "model": self.model,
+            "encoding_format": "float",
+            "input_type": model_type,
+        }
+        if self.truncate:
+            payload["truncate"] = self.truncate
+        if self.dimensions:
+            payload["dimensions"] = self.dimensions
+        return payload
+
+    def _process_response(self, result: Dict[str, Any]) -> List[List[float]]:
+        """Process response for both sync and async methods.
+
+        Args:
+            result: Parsed JSON response from the API
+
+        Returns:
+            List of embeddings sorted by index
+        """
+        data = result.get("data", result)
+        if not isinstance(data, list):
+            raise ValueError(f"Expected data with a list of embeddings. Got: {data}")
+        embedding_list = [(res["embedding"], res["index"]) for res in data]
+        self._invoke_callback_vars(result)
+        return [x[0] for x in sorted(embedding_list, key=lambda x: x[1])]
+
     def _embed(
         self, texts: List[str], model_type: Literal["passage", "query"]
     ) -> List[List[float]]:
@@ -177,47 +218,69 @@ def _embed(
         # truncate: "NONE" | "START" | "END" -- default "NONE", error raised if
        # an input is too long
        # dimensions: int -- not supported by all models
-        payload: Dict[str, Any] = {
-            "input": texts,
-            "model": self.model,
-            "encoding_format": "float",
-            "input_type": model_type,
-        }
-        if self.truncate:
-            payload["truncate"] = self.truncate
-        if self.dimensions:
-            payload["dimensions"] = self.dimensions
-
+        payload = self._prepare_payload(texts, model_type)
         response = self._client.get_req(
             payload=payload,
             extra_headers=self.default_headers,
         )
         response.raise_for_status()
         result = response.json()
-        data = result.get("data", result)
-        if not isinstance(data, list):
-            raise ValueError(f"Expected data with a list of embeddings. Got: {data}")
-        embedding_list = [(res["embedding"], res["index"]) for res in data]
-        self._invoke_callback_vars(result)
-        return [x[0] for x in sorted(embedding_list, key=lambda x: x[1])]
+        return self._process_response(result)
+
+    def _validate_texts(self, texts: List[str]) -> None:
+        """Validate that texts is a list of strings.
+
+        Args:
+            texts: List to validate
+
+        Raises:
+            ValueError: If texts is not a list of strings
+        """
+        if not isinstance(texts, list) or not all(
+            isinstance(text, str) for text in texts
+        ):
+            raise ValueError(f"`texts` must be a list of strings, given: {repr(texts)}")

     def embed_query(self, text: str) -> List[float]:
         """Input pathway for query embeddings."""
         return self._embed([text], model_type="query")[0]

     def embed_documents(self, texts: List[str]) -> List[List[float]]:
         """Input pathway for document embeddings."""
-        if not isinstance(texts, list) or not all(
-            isinstance(text, str) for text in texts
-        ):
-            raise ValueError(f"`texts` must be a list of strings, given: {repr(texts)}")
+        self._validate_texts(texts)

         all_embeddings = []
         for i in range(0, len(texts), self.max_batch_size):
             batch = texts[i : i + self.max_batch_size]
             all_embeddings.extend(self._embed(batch, model_type="passage"))
         return all_embeddings

+    async def _aembed(
+        self, texts: List[str], model_type: Literal["passage", "query"]
+    ) -> List[List[float]]:
+        """Async version of _embed."""
+        payload = self._prepare_payload(texts, model_type)
+        response_text = await self._client.aget_req(
+            payload=payload,
+            extra_headers=self.default_headers,
+        )
+        result = json.loads(response_text)
+        return self._process_response(result)
+
+    async def aembed_query(self, text: str) -> List[float]:
+        """Async input pathway for query embeddings."""
+        return (await self._aembed([text], model_type="query"))[0]
+
+    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Async input pathway for document embeddings."""
+        self._validate_texts(texts)
+
+        all_embeddings: List[List[float]] = []
+        for i in range(0, len(texts), self.max_batch_size):
+            batch = texts[i : i + self.max_batch_size]
+            all_embeddings.extend(await self._aembed(batch, model_type="passage"))
+        return all_embeddings
+
     def _invoke_callback_vars(self, response: dict) -> None:
         """Invoke the callback context variables if there are any."""
         callback_vars = [
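
With these additions, NVIDIAEmbeddings exposes aembed_query and aembed_documents alongside the sync pathways, both built on the shared _prepare_payload / _process_response helpers and batched by max_batch_size. A hypothetical usage sketch (assumes NVIDIA_API_KEY is set and that the default embedding model is available to you):

    # Hypothetical usage of the async embedding pathway added in this commit.
    import asyncio

    from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings

    async def main() -> None:
        embedder = NVIDIAEmbeddings()  # default model assumed; pass model=... to pin one
        # aembed_query / aembed_documents mirror embed_query / embed_documents,
        # awaiting aget_req per batch instead of blocking on get_req.
        query_vec = await embedder.aembed_query("What is CUDA?")
        doc_vecs = await embedder.aembed_documents(["GPU programming", "Parallel computing"])
        print(len(query_vec), len(doc_vecs))

    asyncio.run(main())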

libs/ai-endpoints/langchain_nvidia_ai_endpoints/llm.py

Lines changed: 101 additions & 22 deletions
@@ -1,9 +1,13 @@
 from __future__ import annotations

+import json
 import warnings
-from typing import Any, Dict, Iterator, List, Optional
+from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

-from langchain_core.callbacks.manager import CallbackManagerForLLMRun
+from langchain_core.callbacks.manager import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
 from langchain_core.language_models.llms import LLM
 from langchain_core.outputs import GenerationChunk
 from pydantic import ConfigDict, Field, PrivateAttr
@@ -168,13 +172,22 @@ def _identifying_params(self) -> Dict[str, Any]:
             "base_url": self.base_url,
         }

-    def _call(
+    def _prepare_call_payload(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> str:
+    ) -> Dict[str, Any]:
+        """Prepare payload for non-streaming calls (both sync and async).
+
+        Args:
+            prompt: The prompt to send
+            stop: Stop words
+            kwargs: Additional keyword arguments
+
+        Returns:
+            Payload dictionary
+        """
         payload: Dict[str, Any] = {
             "model": self.model,
             "prompt": prompt,
@@ -188,28 +201,24 @@ def _call(
             warnings.warn("stream set to true for non-streaming call, ignoring")
             del payload["stream"]

-        response = self._client.get_req(payload=payload)
-        response.raise_for_status()
+        return payload

-        # todo: handle response's usage and system_fingerprint
-
-        choices = response.json()["choices"]
-        # todo: write a test for this by setting n > 1 on the request
-        # aug 2024: n > 1 is not supported by endpoints
-        if len(choices) > 1:
-            warnings.warn(
-                f"Multiple choices in response, returning only the first: {choices}"
-            )
-
-        return choices[0]["text"]
-
-    def _stream(
+    def _prepare_stream_payload(
         self,
         prompt: str,
         stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
+    ) -> Dict[str, Any]:
+        """Prepare payload for streaming calls (both sync and async).
+
+        Args:
+            prompt: The prompt to send
+            stop: Stop words
+            kwargs: Additional keyword arguments
+
+        Returns:
+            Payload dictionary
+        """
         payload: Dict[str, Any] = {
             "model": self.model,
             "prompt": prompt,
@@ -226,9 +235,79 @@ def _stream(
             warnings.warn("stream set to false for streaming call, ignoring")
             payload["stream"] = True

+        return payload
+
+    def _process_result(self, result: Dict[str, Any]) -> str:
+        """Process parsed JSON result from both sync and async call methods.
+
+        Args:
+            result: Parsed JSON response
+
+        Returns:
+            Generated text
+        """
+        # todo: handle response's usage and system_fingerprint
+        choices = result["choices"]
+        # todo: write a test for this by setting n > 1 on the request
+        # aug 2024: n > 1 is not supported by endpoints
+        if len(choices) > 1:
+            warnings.warn(
+                f"Multiple choices in response, returning only the first: {choices}"
+            )
+
+        return choices[0]["text"]
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        payload = self._prepare_call_payload(prompt, stop, **kwargs)
+        response = self._client.get_req(payload=payload)
+        response.raise_for_status()
+        result = response.json()
+        return self._process_result(result)
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        payload = self._prepare_stream_payload(prompt, stop, **kwargs)
         for chunk in self._client.get_req_stream(payload=payload):
             content = chunk["content"]
             generation = GenerationChunk(text=content)
             if run_manager:  # todo: add tests for run_manager
                 run_manager.on_llm_new_token(content, chunk=generation)
             yield generation
+
+    async def _acall(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        payload = self._prepare_call_payload(prompt, stop, **kwargs)
+        response_text = await self._client.aget_req(payload=payload)
+        result = json.loads(response_text)
+        return self._process_result(result)
+
+    async def _astream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> AsyncIterator[GenerationChunk]:
+        payload = self._prepare_stream_payload(prompt, stop, **kwargs)
+        async for chunk in self._client.aget_req_stream(payload=payload):
+            content = chunk["content"]
+            generation = GenerationChunk(text=content)
+            if run_manager:
+                await run_manager.on_llm_new_token(content, chunk=generation)
+            yield generation
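
The completion LLM now implements _acall and _astream, so ainvoke and astream can use native async I/O (via aget_req / aget_req_stream) rather than LangChain's default executor-based fallback. A hypothetical usage sketch (assumes the completions class in llm.py is exported as NVIDIA and that NVIDIA_API_KEY is set; the model name is a placeholder to replace with one served by your endpoint):

    # Hypothetical usage of the async completion paths added in this commit.
    import asyncio

    from langchain_nvidia_ai_endpoints import NVIDIA

    async def main() -> None:
        llm = NVIDIA(model="your-completions-model")  # placeholder model name
        # ainvoke routes through _acall; astream routes through _astream.
        text = await llm.ainvoke("Write a haiku about GPUs")
        print(text)
        async for chunk in llm.astream("Count to five"):
            print(chunk, end="", flush=True)

    asyncio.run(main())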

0 commit comments
