diff --git a/docs/model_support.md b/docs/model_support.md index ba9acf5b1..6da5bc964 100644 --- a/docs/model_support.md +++ b/docs/model_support.md @@ -37,6 +37,7 @@ After these steps, the new model should be compatible with most FastChat feature - example: `python3 -m fastchat.serve.cli --model-path meta-llama/Llama-2-7b-chat-hf` - Vicuna, Alpaca, LLaMA, Koala - example: `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5` +- [Alibaba-NLP/gte-large-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-large-en-v1.5) - [allenai/tulu-2-dpo-7b](https://huggingface.co/allenai/tulu-2-dpo-7b) - [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B) - [BAAI/AquilaChat2-7B](https://huggingface.co/BAAI/AquilaChat2-7B) @@ -66,6 +67,7 @@ After these steps, the new model should be compatible with most FastChat feature - [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5) - [meta-math/MetaMath-7B-V1.0](https://huggingface.co/meta-math/MetaMath-7B-V1.0) - [Microsoft/Orca-2-7b](https://huggingface.co/microsoft/Orca-2-7b) +- [moka-ai/m3e-large](https://huggingface.co/moka-ai/m3e-large) - [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat) - example: `python3 -m fastchat.serve.cli --model-path mosaicml/mpt-7b-chat` - [Neutralzz/BiLLa-7B-SFT](https://huggingface.co/Neutralzz/BiLLa-7B-SFT) @@ -81,6 +83,7 @@ After these steps, the new model should be compatible with most FastChat feature - [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat) - [rishiraj/CatPPT](https://huggingface.co/rishiraj/CatPPT) - [Salesforce/codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b) +- [shibing624/text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual) - [StabilityAI/stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b) - [tenyx/TenyxChat-7B-v1](https://huggingface.co/tenyx/TenyxChat-7B-v1) - 
[TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0)
diff --git a/fastchat/model/model_adapter.py b/fastchat/model/model_adapter.py
index 9625df6db..c7afcdabf 100644
--- a/fastchat/model/model_adapter.py
+++ b/fastchat/model/model_adapter.py
@@ -144,6 +144,44 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("one_shot")
 
 
+class BaseEmbeddingModelAdapter(BaseModelAdapter):
+    """The base embedding model adapter.
+
+    Subclasses usually only override ``match``. A subclass whose loaded model
+    must not be tagged with ``use_cls_pooling`` sets the class attribute below
+    to False.
+    """
+
+    use_fast_tokenizer = False
+    # When True, load_model() sets ``model.use_cls_pooling = True``.
+    use_cls_pooling = True
+
+    def match(self, model_path: str):
+        return "embedding" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        """Load the model and tokenizer and record the usable sequence length."""
+        revision = from_pretrained_kwargs.get("revision", "main")
+        # The tokenizer already trusts remote code; do the same for the model so
+        # checkpoints that ship custom modeling code (e.g., gte-large-en-v1.5)
+        # can load. An explicit caller-supplied value still takes precedence.
+        model_kwargs = {"trust_remote_code": True, **from_pretrained_kwargs}
+        model = AutoModel.from_pretrained(model_path, **model_kwargs)
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True, revision=revision
+        )
+        if hasattr(model.config, "max_position_embeddings") and hasattr(
+            tokenizer, "model_max_length"
+        ):
+            model.config.max_sequence_length = min(
+                model.config.max_position_embeddings, tokenizer.model_max_length
+            )
+        if self.use_cls_pooling:
+            model.use_cls_pooling = True
+        model.eval()
+        return model, tokenizer
+
+
 # A global registry for all model adapters
 # TODO (lmzheng): make it a priority queue.
 model_adapters: List[BaseModelAdapter] = []
@@ -1836,7 +1874,7 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
         return get_conv_template("qwen-7b-chat")
 
 
-class BGEAdapter(BaseModelAdapter):
+class BGEAdapter(BaseEmbeddingModelAdapter):
     """The model adapter for BGE (e.g., BAAI/bge-large-en-v1.5)"""
 
     use_fast_tokenizer = False
@@ -1844,56 +1882,43 @@ class BGEAdapter(BaseModelAdapter):
     def match(self, model_path: str):
         return "bge" in model_path.lower()
 
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        revision = from_pretrained_kwargs.get("revision", "main")
-        model = AutoModel.from_pretrained(
-            model_path,
-            **from_pretrained_kwargs,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=True, revision=revision
-        )
-        if hasattr(model.config, "max_position_embeddings") and hasattr(
-            tokenizer, "model_max_length"
-        ):
-            model.config.max_sequence_length = min(
-                model.config.max_position_embeddings, tokenizer.model_max_length
-            )
-        model.use_cls_pooling = True
-        model.eval()
-        return model, tokenizer
-
-    def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("one_shot")
-
 
-class E5Adapter(BaseModelAdapter):
+class E5Adapter(BaseEmbeddingModelAdapter):
     """The model adapter for E5 (e.g., intfloat/e5-large-v2)"""
 
     use_fast_tokenizer = False
+    # Preserve pre-refactor behavior: E5 models were never tagged for CLS pooling.
+    use_cls_pooling = False
 
     def match(self, model_path: str):
         return "e5-" in model_path.lower()
 
-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        revision = from_pretrained_kwargs.get("revision", "main")
-        model = AutoModel.from_pretrained(
-            model_path,
-            **from_pretrained_kwargs,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=True, revision=revision
-        )
-        if hasattr(model.config, "max_position_embeddings") and hasattr(
-            tokenizer, "model_max_length"
-        ):
-            model.config.max_sequence_length = min(
-                model.config.max_position_embeddings, tokenizer.model_max_length
-            )
-        return model, tokenizer
 
-    def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("one_shot")
+class Text2VecAdapter(BaseEmbeddingModelAdapter):
+    """The model adapter for text2vec (e.g., shibing624/text2vec-base-chinese)"""
+
+    use_fast_tokenizer = False
+
+    def match(self, model_path: str):
+        return "text2vec" in model_path.lower()
+
+
+class M3EAdapter(BaseEmbeddingModelAdapter):
+    """The model adapter for m3e (e.g., moka-ai/m3e-large)"""
+
+    use_fast_tokenizer = False
+
+    def match(self, model_path: str):
+        return "m3e-" in model_path.lower()
+
+
+class GTEAdapter(BaseEmbeddingModelAdapter):
+    """The model adapter for gte (e.g., Alibaba-NLP/gte-large-en-v1.5)"""
+
+    use_fast_tokenizer = False
+
+    def match(self, model_path: str):
+        return "gte-" in model_path.lower()
 
 
 class AquilaChatAdapter(BaseModelAdapter):
@@ -2562,6 +2587,9 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(AquilaChatAdapter)
 register_model_adapter(BGEAdapter)
 register_model_adapter(E5Adapter)
+register_model_adapter(Text2VecAdapter)
+register_model_adapter(M3EAdapter)
+register_model_adapter(GTEAdapter)
 register_model_adapter(Lamma2ChineseAdapter)
 register_model_adapter(Lamma2ChineseAlpacaAdapter)
 register_model_adapter(VigogneAdapter)
@@ -2603,5 +2631,6 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
 register_model_adapter(GrokAdapter)
 register_model_adapter(NoSystemAdapter)
 
+register_model_adapter(BaseEmbeddingModelAdapter)
 # After all adapters, try the default base adapter.
 register_model_adapter(BaseModelAdapter)