add siliconflow embedding class (run-llama#16753)
nightosong authored Oct 31, 2024
1 parent 9d0db08 commit 35a13b9
Showing 11 changed files with 561 additions and 0 deletions.
@@ -0,0 +1,153 @@
llama_index/_static
.DS_Store
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
bin/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
etc/
include/
lib/
lib64/
parts/
sdist/
share/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
.ruff_cache

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints
notebooks/

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
pyvenv.cfg

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Jetbrains
.idea
modules/
*.swp

# VsCode
.vscode

# pipenv
Pipfile
Pipfile.lock

# pyright
pyrightconfig.json
@@ -0,0 +1,3 @@
poetry_requirements(
    name="poetry",
)
@@ -0,0 +1,17 @@
GIT_ROOT ?= $(shell git rev-parse --show-toplevel)

help: ## Show all Makefile targets.
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[33m%-30s\033[0m %s\n", $$1, $$2}'

format: ## Run code autoformatters (black).
	pre-commit install
	git ls-files | xargs pre-commit run black --files

lint: ## Run linters: pre-commit (black, ruff, codespell) and mypy
	pre-commit install && git ls-files | xargs pre-commit run --show-diff-on-failure --files

test: ## Run tests via pytest.
	pytest tests

watch-docs: ## Build and watch documentation.
	sphinx-autobuild docs/ docs/_build/html --open-browser --watch $(GIT_ROOT)/llama_index/
@@ -0,0 +1,41 @@
# LlamaIndex Embeddings Integration: SiliconFlow

## 1. Product Introduction

SiliconCloud provides cost-effective GenAI services built on excellent open-source foundation models.

Introduction: https://docs.siliconflow.cn/introduction

## 2. Product Features

- As a one-stop cloud service platform that integrates top large models, SiliconCloud is committed to providing developers with faster, cheaper, more comprehensive, and smoother model APIs.

- SiliconCloud hosts a wide range of open-source large language models, image generation models, code generation models, embedding and reranking models, and multimodal models, including Qwen2.5-72B, DeepSeek-V2.5, Qwen2, InternLM2.5-20B-Chat, BCE, BGE, SenseVoice-Small, Llama-3.1, FLUX.1, DeepSeek-Coder-V2, SD3 Medium, GLM-4-9B-Chat, and InstantID.

- APIs for several of these models, such as Qwen2.5 (7B) and Llama 3.1 (8B), are free to use, so developers and product managers do not have to worry about the compute costs of development or large-scale rollout and can enjoy "token freedom".

- SiliconCloud offers out-of-the-box inference acceleration for large models, bringing a more responsive experience to your GenAI applications.

## 3. Installation

```shell
pip install llama-index-embeddings-siliconflow
```

## 4. Usage

```python
import asyncio
import os
from llama_index.embeddings.siliconflow import SiliconFlowEmbedding

embedding = SiliconFlowEmbedding(
    model="BAAI/bge-m3",
    api_key=os.getenv("SILICONFLOW_API_KEY"),
)

response = embedding.get_query_embedding("...")
print(response)

response = asyncio.run(embedding.aget_query_embedding("..."))
print(response)
```
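
Beyond single-query embedding, the class inherits the standard helpers from llama-index core. The sketch below is not part of this package's README; it assumes the `get_text_embedding_batch` method and the `Settings.embed_model` hook provided by `llama_index.core`:

```python
import os
from llama_index.core import Settings
from llama_index.embeddings.siliconflow import SiliconFlowEmbedding

embedding = SiliconFlowEmbedding(
    model="BAAI/bge-m3",
    api_key=os.getenv("SILICONFLOW_API_KEY"),
)

# Embed several texts in one call; the batch helper is inherited from BaseEmbedding.
vectors = embedding.get_text_embedding_batch(["first passage", "second passage"])
print(len(vectors), len(vectors[0]))  # 2 vectors, 1024 dimensions for BAAI/bge-m3

# Optionally register it as the default embedding model for the application.
Settings.embed_model = embedding
```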
@@ -0,0 +1 @@
python_sources()
@@ -0,0 +1,3 @@
from llama_index.embeddings.siliconflow.base import SiliconFlowEmbedding

__all__ = ["SiliconFlowEmbedding"]
@@ -0,0 +1,146 @@
"""SiliconFLow embeddings file."""

import aiohttp
import base64
import requests
import struct
from typing import Any, List, Optional
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding


DEFAULT_SILICONFLOW_API_URL = "https://api.siliconflow.cn/v1/embeddings"

VALID_ENCODING = ["float", "base64"]

AVAILABLE_OPTIONS = [
("Pro/BAAI/bge-m3", 1024), ## 8192 tokens
("BAAI/bge-m3", 1024), ## 8192 tokens
("BAAI/bge-large-zh-v1.5", 1024), ## 512 tokens
("BAAI/bge-large-en-v1.5", 1024), ## 512 tokens
("netease-youdao/bce-embedding-base_v1", 768), ## 512 tokens
]


def base64_to_float_list(encoded_str: str) -> List[float]:
byte_data = base64.b64decode(encoded_str)
float_count = len(byte_data) // 4
float_list = struct.unpack(f"{float_count}f", byte_data)
return list(float_list)


class SiliconFlowEmbedding(BaseEmbedding):
"""SiliconFlow class for embeddings."""

model: str = Field(
default="BAAI/bge-m3",
description="""\
The name of the embedding model to use.
512 tokens for all models input except `bge-m3` which is 8192.
""",
)
api_key: Optional[str] = Field(
default=None,
description="The SiliconFlow API key.",
)
base_url: str = Field(
default=DEFAULT_SILICONFLOW_API_URL,
description="The base URL for the SiliconFlow API.",
)
encoding_format: str = Field(
default="float",
description="The format to return the embeddings in. Can be either float or base64.",
) # TODO: Consider whether to fix the encoding format as float.

_headers: Any = PrivateAttr()

def __init__(
self,
model: str = "BAAI/bge-m3",
api_key: Optional[str] = None,
base_url: str = DEFAULT_SILICONFLOW_API_URL,
encoding_format: Optional[str] = "float",
callback_manager: Optional[CallbackManager] = None,
**kwargs: Any,
) -> None:
super().__init__(
model=model,
api_key=api_key,
base_url=base_url,
encoding_format=encoding_format,
callback_manager=callback_manager,
**kwargs,
)
assert (
self.encoding_format in VALID_ENCODING
), f"""\
Encoding_format parameter {self.encoding_format} not supported.
Please choose one of {VALID_ENCODING}".
"""

self._headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
}

@classmethod
def class_name(cls) -> str:
return "SiliconFlowEmbedding"

def _data_formatting(self, response: list) -> List[List[float]]:
results = sorted(response["data"], key=lambda e: e["index"])
if self.encoding_format == "base64":
return [base64_to_float_list(data["embedding"]) for data in results]
else:
return [data["embedding"] for data in results]

def _get_query_embedding(self, query: str) -> List[float]:
"""Get query embedding."""
return self._get_text_embeddings([query])[0]

async def _aget_query_embedding(self, query: str) -> List[float]:
"""The asynchronous version of _get_query_embedding."""
result = await self._aget_text_embeddings([query])
return result[0]

def _get_text_embedding(self, text: str) -> List[float]:
"""Get text embedding."""
return self._get_text_embeddings([text])[0]

async def _aget_text_embedding(self, text: str) -> List[float]:
"""Asynchronously get text embedding."""
result = await self._aget_text_embeddings([text])
return result[0]

def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
with requests.Session() as session:
input_json = {
"model": self.model,
"input": texts,
"encoding_format": self.encoding_format,
}
response = session.post(
self.base_url, json=input_json, headers=self._headers
).json()
if "data" not in response:
raise RuntimeError(response)
return self._data_formatting(response)

async def _aget_text_embeddings(
self,
texts: List[str],
) -> List[List[float]]:
async with aiohttp.ClientSession() as session:
input_json = {
"input": texts,
"model": self.model,
"encoding_format": self.encoding_format,
}

async with session.post(
self.base_url, json=input_json, headers=self._headers
) as response:
response_json = await response.json()
response.raise_for_status()
return self._data_formatting(response_json)
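
The base64 branch of `_data_formatting` relies on `base64_to_float_list`, which unpacks the payload as contiguous 32-bit floats in native byte order (little-endian on most platforms). A minimal round-trip sketch of that assumed wire format, standalone and not part of the files above:

```python
import base64
import struct

# Pack three float32 values the way the decoder expects the API's base64
# payload to be laid out: contiguous 32-bit floats, native byte order.
floats = [0.25, -1.5, 3.0]
packed = struct.pack(f"{len(floats)}f", *floats)
encoded = base64.b64encode(packed).decode()

# Decoding mirrors base64_to_float_list from the integration above.
byte_data = base64.b64decode(encoded)
decoded = list(struct.unpack(f"{len(byte_data) // 4}f", byte_data))
print(decoded)  # [0.25, -1.5, 3.0]
```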