From f497e329606340de31379cc0ba564610898fd0bf Mon Sep 17 00:00:00 2001 From: James Briggs Date: Sun, 28 Apr 2024 13:11:15 +0800 Subject: [PATCH] feat: add token limit to openai encoder --- .github/workflows/test.yml | 2 +- Makefile | 7 + coverage.xml | 2673 +++++++++++++++------------- semantic_router/encoders/openai.py | 59 +- semantic_router/schema.py | 43 +- tests/unit/encoders/test_openai.py | 1 - 6 files changed, 1552 insertions(+), 1233 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2a957e0..6e80f47 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -45,7 +45,7 @@ jobs: - name: Download nltk data run: | python -m nltk.downloader punkt stopwords wordnet - - name: Pytest + - name: Pytest All env: PINECONE_API_KEY: ${{ secrets.PINECONE_API_KEY }} OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} diff --git a/Makefile b/Makefile index a7c6964..3d76b7b 100644 --- a/Makefile +++ b/Makefile @@ -13,3 +13,10 @@ lint lint_diff: test: poetry run pytest -vv -n 20 --cov=semantic_router --cov-report=term-missing --cov-report=xml + +test_functional: + poetry run pytest -vv -n 20 tests/functional +test_unit: + poetry run pytest -vv -n 20 tests/unit +test_integration: + poetry run pytest -vv -n 20 tests/integration \ No newline at end of file diff --git a/coverage.xml b/coverage.xml index 321f6c5..bcccea3 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,24 +1,24 @@ - - + + - /Users/andreped/workspace/semantic-router/semantic_router + /Users/jamesbriggs/Documents/projects/aurelio-labs/semantic-router/semantic_router - + - + - - - + + + - + @@ -33,116 +33,116 @@ - - - + + + - - - - + + + + - - + + - - - - - - + + + + + + - - - - - - + + + + + + - + - - - - - - - - - - - - - + + + + + + + + + + + + + - - - - - - - + + + + + + + - - + + - - + + - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + - - - + + + - - - - - - + + + + + + - - - - - - - - - - - - - + + + + + + + + + + + + + - - - + + + - + @@ -159,289 +159,324 @@ - - - - - - - + + + + + + + - - + + - + - - - - - - - - + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - + + - - - - - - - - - + + + + + + + + + - - - - - - - - - - - - + + + + + + + + + + + - + - - - + + + - - - - - - - - + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + - - - - - - - - - - - - + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -462,7 +497,7 @@ - + @@ -475,8 +510,8 @@ - - + + @@ -485,17 +520,17 @@ - - + + - - + + - - - + + + @@ -513,8 +548,8 @@ - - + + @@ -554,88 +589,75 @@ - + - + + + + + + - + + - - - + + - + - - + - - - - - + + + - - - - + + + + + - - + + - - - - - - - - - - - - - - - - - + - - - - - - - - - - - - - - + + + + + + + + + + + + + + @@ -649,142 +671,265 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - + + + + + + + + + - - + + - - + + - + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - + - - - - - - - - - - - - - - - - - - - + + + + - - - + + - - - - - - - - - - - - - - - - - - + @@ -798,310 +943,314 @@ - - - + - + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + - - - - - - - - - - - + + + + + + + + + - - - - - - + + + + + + + - + - - - - - - - - - - - - + + + - - + + + - + + - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - + - - + + - - - - + + - - - - - - - + + + + + + + + + - + + - - + + + + + + + + + + + + + + + + + + + - + + + + - - - - - - - - - - - - - - - - - - - + + + + + + + + + - + + + - + + - + + - - - - - + + + + + - - + + - - + + + - - + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + @@ -1173,7 +1322,7 @@ - + @@ -1192,57 +1341,57 @@ - - + + - - - - - - - - - - - + + + + + + + + + + + - - - - - - - - + + + + + + + + - + - + - - + + - + - - - - + + + + - - - + + + - - - - - - - - + + + + + + + + @@ -1325,7 +1474,7 @@ - + @@ -1333,7 +1482,8 @@ - + + @@ -1355,14 +1505,14 @@ - - - - - + + + + + - + @@ -1371,57 +1521,69 @@ - - + + - - - - + + + + - - - - - - - + + + + + + + - + - - - - - - - - - - - - + + + + + + + + + + + + + + + + + - - - - + + + + + - - - - - - + + + + + + + + + + + + - + @@ -1458,117 +1620,222 @@ - - - - - - - - - - + + + + + + - - + + + + + - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - + + - - - - - - - - - - - - - + + + - - - - - - - - + + + + + + + - - - - - - - - - - - - - - + + - - - - + + + + + + + + + + + + + + - + @@ -1579,10 +1846,11 @@ - + + - + @@ -1606,8 +1874,8 @@ - - + + @@ -1617,15 +1885,15 @@ - + - - - - - - - + + + + + + + @@ -1662,7 +1930,7 @@ - + @@ -1677,37 +1945,43 @@ - - - - - - - - - - - - + + + + + + + + + + + + + + - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + - + @@ -1721,31 +1995,41 @@ - - + + - - - - - + + + + + + - - - - - - - - - + + + + + + + + + + - + + + + + + + + + @@ -1911,144 +2195,143 @@ - + - + - - - - - + + + + + - + - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + - + - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + - - - + + @@ -2059,191 +2342,198 @@ - - - + + + + - - - + + + - - - - + + - - - - - - - - + + - - + + + + - + + - + + - - - + + - + + + + + + + + + - - - - - - - - - - + + + + + + + + + + - - - - + - - - - - - - + + + + + + + + - - - - - - - + + + + - - - + + + + + - - - - - - - + + + + - + - - + + + + + + + + - - - - - - - - - - - + + + + + + + + + - - - - - - + - - - - - - - - - + + + + + + + - - - + + + + + - - - - - + + + + + + + + + + + + + + + - + - - - + + + - + @@ -2251,7 +2541,7 @@ - + @@ -2268,6 +2558,7 @@ + diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index 7712b19..14b9bb7 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -1,20 +1,43 @@ import os from time import sleep -from typing import List, Optional, Union +from typing import Any, List, Optional, Union import openai from openai import OpenAIError from openai._types import NotGiven from openai.types import CreateEmbeddingResponse +import tiktoken from semantic_router.encoders import BaseEncoder +from semantic_router.schema import EncoderInfo from semantic_router.utils.defaults import EncoderDefault from semantic_router.utils.logger import logger +model_configs = { + "text-embedding-ada-002": EncoderInfo( + name="text-embedding-ada-002", + type="openai", + token_limit=4000 + ), + "text-embed-3-small": EncoderInfo( + name="text-embed-3-small", + type="openai", + token_limit=8192 + ), + "text-embed-3-large": EncoderInfo( + name="text-embed-3-large", + type="openai", + token_limit=8192 + ) +} + + class OpenAIEncoder(BaseEncoder): client: Optional[openai.Client] dimensions: Union[int, NotGiven] = NotGiven() + token_limit: Optional[int] = None + token_encoder: Optional[Any] = None type: str = "openai" def __init__( @@ -44,13 +67,32 @@ def __init__( ) from e # set dimensions to support openai embed 3 dimensions param self.dimensions = dimensions + # if model name is known, set token limit + if name in model_configs: + self.token_limit = model_configs[name].token_limit + # get token encoder + self.token_encoder = tiktoken.encoding_for_model(name) - def __call__(self, docs: List[str]) -> List[List[float]]: + def __call__(self, docs: List[str], truncate: bool = True) -> List[List[float]]: + """Encode a list of text documents into embeddings using OpenAI API. + + :param docs: List of text documents to encode. + :param truncate: Whether to truncate the documents to token limit. If + False and a document exceeds the token limit, an error will be + raised. + :return: List of embeddings for each document.""" if self.client is None: raise ValueError("OpenAI client is not initialized.") embeds = None error_message = "" + if truncate: + # check if any document exceeds token limit and truncate if so + for i in range(len(docs)): + logger.info(f"Document {i+1} length: {len(docs[i])}") + docs[i] = self._truncate(docs[i]) + logger.info(f"Document {i+1} trunc length: {len(docs[i])}") + # Exponential backoff for j in range(1, 7): try: @@ -74,7 +116,20 @@ def __call__(self, docs: List[str]) -> List[List[float]]: or not isinstance(embeds, CreateEmbeddingResponse) or not embeds.data ): + logger.info(f"Returned embeddings: {embeds}") raise ValueError(f"No embeddings returned. Error: {error_message}") embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] return embeddings + + def _truncate(self, text: str) -> str: + tokens = self.token_encoder.encode(text) + if len(tokens) > self.token_limit: + logger.warning( + f"Document exceeds token limit: {len(tokens)} > {self.token_limit}" + "\nTruncating document..." + ) + text = self.token_encoder.decode(tokens[:self.token_limit-1]) + logger.info(f"Trunc length: {len(self.token_encoder.encode(text))}") + return text + return text diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 6b7485a..daf6081 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -2,16 +2,6 @@ from typing import List, Optional from pydantic.v1 import BaseModel -from pydantic.v1.dataclasses import dataclass - -from semantic_router.encoders import ( - BaseEncoder, - CohereEncoder, - FastEmbedEncoder, - GoogleEncoder, - MistralEncoder, - OpenAIEncoder, -) class EncoderType(Enum): @@ -23,40 +13,17 @@ class EncoderType(Enum): GOOGLE = "google" +class EncoderInfo(BaseModel): + name: str + type: EncoderType + token_limit: int + class RouteChoice(BaseModel): name: Optional[str] = None function_call: Optional[dict] = None similarity_score: Optional[float] = None -@dataclass -class Encoder: - type: EncoderType - name: Optional[str] - model: BaseEncoder - - def __init__(self, type: str, name: Optional[str]): - self.type = EncoderType(type) - self.name = name - if self.type == EncoderType.HUGGINGFACE: - raise NotImplementedError - elif self.type == EncoderType.FASTEMBED: - self.model = FastEmbedEncoder(name=name) - elif self.type == EncoderType.OPENAI: - self.model = OpenAIEncoder(name=name) - elif self.type == EncoderType.COHERE: - self.model = CohereEncoder(name=name) - elif self.type == EncoderType.MISTRAL: - self.model = MistralEncoder(name=name) - elif self.type == EncoderType.GOOGLE: - self.model = GoogleEncoder(name=name) - else: - raise ValueError - - def __call__(self, texts: List[str]) -> List[List[float]]: - return self.model(texts) - - class Message(BaseModel): role: str content: str diff --git a/tests/unit/encoders/test_openai.py b/tests/unit/encoders/test_openai.py index 508e9e9..de3594f 100644 --- a/tests/unit/encoders/test_openai.py +++ b/tests/unit/encoders/test_openai.py @@ -5,7 +5,6 @@ from semantic_router.encoders import OpenAIEncoder - @pytest.fixture def openai_encoder(mocker): mocker.patch("openai.Client")