Skip to content

Commit 0809f2c

Browse files
Lazy import models to speed up startup (#1)
* Lazy import models to speed up startup * Bump version to 0.1.4
1 parent 5876019 commit 0809f2c

File tree

2 files changed

+51
-41
lines changed

2 files changed

+51
-41
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "git-ai-summarize"
7-
version = "0.1.3"
7+
version = "0.1.4"
88
authors = [{ name = "Kevin Beaulieu", email = "[email protected]" }]
99
description = "AI-powered git commands for summarizing changes"
1010
readme = "README.md"

src/git_ai_summarize/models.py

Lines changed: 50 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,5 @@
11
import os
22
from langchain_core.language_models.chat_models import BaseChatModel
3-
from langchain_anthropic import ChatAnthropic
4-
from langchain_openai import ChatOpenAI
5-
from langchain_google_genai import ChatGoogleGenerativeAI
6-
from langchain_mistralai import ChatMistralAI
7-
from langchain_fireworks import ChatFireworks
8-
from langchain_together import ChatTogether
9-
from langchain_google_vertexai import ChatVertexAI
10-
from langchain_groq import ChatGroq
11-
from langchain_nvidia_ai_endpoints import ChatNVIDIA
12-
from langchain_ollama import ChatOllama
13-
from langchain_ai21 import ChatAI21
14-
from langchain_upstage import ChatUpstage
15-
from langchain_databricks import ChatDatabricks
16-
from langchain_ibm import ChatWatsonx
17-
from langchain_xai import ChatXAI
183
from typing import List
194

205

@@ -24,12 +9,12 @@ def get_supported_providers() -> List[str]:
249
"anthropic",
2510
"openai",
2611
"google",
27-
"mistralai",
12+
"mistral",
2813
"fireworks",
2914
"together",
30-
"vertexai",
15+
"vertex",
3116
"groq",
32-
"nvidia_ai",
17+
"nvidia",
3318
"ollama",
3419
"ai21",
3520
"upstage",
@@ -41,30 +26,55 @@ def get_supported_providers() -> List[str]:
4126

4227
def get_model(provider_name: str | None, model_name: str | None) -> BaseChatModel:
4328
"""Initialize and configure the LangChain components with specified model."""
44-
providers = {
45-
"anthropic": (ChatAnthropic, "ANTHROPIC_API_KEY", "https://www.anthropic.com", "anthropic_api_key"),
46-
"openai": (ChatOpenAI, "OPENAI_API_KEY", "https://platform.openai.com/account/api-keys", "openai_api_key"),
47-
"google": (ChatGoogleGenerativeAI, "GOOGLE_API_KEY", "https://developers.generativeai.google/", "google_api_key"),
48-
"mistral": (ChatMistralAI, "MISTRAL_API_KEY", "https://console.mistral.ai/api-keys/", "mistral_api_key"),
49-
"fireworks": (ChatFireworks, "FIREWORKS_API_KEY", "https://app.fireworks.ai/", "fireworks_api_key"),
50-
"together": (ChatTogether, "TOGETHER_API_KEY", "https://api.together.xyz/", "together_api_key"),
51-
"vertex": (ChatVertexAI, "GOOGLE_APPLICATION_CREDENTIALS", "https://cloud.google.com/vertex-ai", None),
52-
"groq": (ChatGroq, "GROQ_API_KEY", "https://console.groq.com/", "groq_api_key"),
53-
"nvidia": (ChatNVIDIA, "NVIDIA_API_KEY", "https://api.nvidia.com/", "nvidia_api_key"),
54-
"ollama": (ChatOllama, None, "https://ollama.ai/", None),
55-
"ai21": (ChatAI21, "AI21_API_KEY", "https://www.ai21.com/studio", "ai21_api_key"),
56-
"upstage": (ChatUpstage, "UPSTAGE_API_KEY", "https://upstage.ai/", "upstage_api_key"),
57-
"databricks": (ChatDatabricks, "DATABRICKS_TOKEN", "https://www.databricks.com/", "databricks_token"),
58-
"watsonx": (ChatWatsonx, "WATSONX_API_KEY", "https://www.ibm.com/watsonx", "watsonx_api_key"),
59-
"xai": (ChatXAI, "XAI_API_KEY", "https://xai.com/", "xai_api_key"),
60-
}
61-
62-
max_tokens = 500
63-
64-
if provider_name not in providers:
29+
if provider_name == 'anthropic':
30+
from langchain_anthropic import ChatAnthropic
31+
model_class, api_key_env, api_url, api_key_param = ChatAnthropic, "ANTHROPIC_API_KEY", "https://www.anthropic.com", "anthropic_api_key"
32+
elif provider_name == 'openai':
33+
from langchain_openai import ChatOpenAI
34+
model_class, api_key_env, api_url, api_key_param = ChatOpenAI, "OPENAI_API_KEY", "https://platform.openai.com/account/api-keys", "openai_api_key"
35+
elif provider_name == 'google':
36+
from langchain_google_genai import ChatGoogleGenerativeAI
37+
model_class, api_key_env, api_url, api_key_param = ChatGoogleGenerativeAI, "GOOGLE_API_KEY", "https://developers.generativeai.google/", "google_api_key"
38+
elif provider_name == 'mistral':
39+
from langchain_mistralai import ChatMistralAI
40+
model_class, api_key_env, api_url, api_key_param = ChatMistralAI, "MISTRAL_API_KEY", "https://console.mistral.ai/api-keys/", "mistral_api_key"
41+
elif provider_name == 'fireworks':
42+
from langchain_fireworks import ChatFireworks
43+
model_class, api_key_env, api_url, api_key_param = ChatFireworks, "FIREWORKS_API_KEY", "https://app.fireworks.ai/", "fireworks_api_key"
44+
elif provider_name == 'together':
45+
from langchain_together import ChatTogether
46+
model_class, api_key_env, api_url, api_key_param = ChatTogether, "TOGETHER_API_KEY", "https://api.together.xyz/", "together_api_key"
47+
elif provider_name == 'vertex':
48+
from langchain_google_vertexai import ChatVertexAI
49+
model_class, api_key_env, api_url, api_key_param = ChatVertexAI, "GOOGLE_APPLICATION_CREDENTIALS", "https://cloud.google.com/vertex-ai", None
50+
elif provider_name == 'groq':
51+
from langchain_groq import ChatGroq
52+
model_class, api_key_env, api_url, api_key_param = ChatGroq, "GROQ_API_KEY", "https://console.groq.com/", "groq_api_key"
53+
elif provider_name == 'nvidia':
54+
from langchain_nvidia_ai_endpoints import ChatNVIDIA
55+
model_class, api_key_env, api_url, api_key_param = ChatNVIDIA, "NVIDIA_API_KEY", "https://api.nvidia.com/", "nvidia_api_key"
56+
elif provider_name == 'ollama':
57+
from langchain_ollama import ChatOllama
58+
model_class, api_key_env, api_url, api_key_param = ChatOllama, None, "https://ollama.ai/", None
59+
elif provider_name == 'ai21':
60+
from langchain_ai21 import ChatAI21
61+
model_class, api_key_env, api_url, api_key_param = ChatAI21, "AI21_API_KEY", "https://www.ai21.com/studio", "ai21_api_key"
62+
elif provider_name == 'upstage':
63+
from langchain_upstage import ChatUpstage
64+
model_class, api_key_env, api_url, api_key_param = ChatUpstage, "UPSTAGE_API_KEY", "https://upstage.ai/", "upstage_api_key"
65+
elif provider_name == 'databricks':
66+
from langchain_databricks import ChatDatabricks
67+
model_class, api_key_env, api_url, api_key_param = ChatDatabricks, "DATABRICKS_TOKEN", "https://www.databricks.com/", "databricks_token"
68+
elif provider_name == 'watsonx':
69+
from langchain_watsonx import ChatWatsonx
70+
model_class, api_key_env, api_url, api_key_param = ChatWatsonx, "WATSONX_API_KEY", "https://www.ibm.com/watsonx", "watsonx_api_key"
71+
elif provider_name == 'xai':
72+
from langchain_xai import ChatXAI
73+
model_class, api_key_env, api_url, api_key_param = ChatXAI, "XAI_API_KEY", "https://xai.com/", "xai_api_key"
74+
else:
6575
raise ValueError(f"Unsupported LLM provider: {provider_name}")
6676

67-
model_class, api_key_env, api_url, api_key_param = providers[provider_name]
77+
max_tokens = 500
6878

6979
if api_key_env:
7080
api_key = os.getenv(api_key_env)

0 commit comments

Comments
 (0)