Skip to content

Commit afc3256

Browse files
authored
refactor llm service architecture with factory pattern (#70)
* refactor llm service architecture with factory pattern * refactor llm_chat_completion function to improve clarity * remove docstring for LLMService and chat_completion * remove unused imports, simplify LLMServiceFactory
1 parent fca1ff6 commit afc3256

File tree

6 files changed

+40
-18
lines changed

6 files changed

+40
-18
lines changed
Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,15 @@
1-
from ai_commit_msg.services.anthropic_service import AnthropicService
21
from ai_commit_msg.services.config_service import ConfigService
3-
from ai_commit_msg.services.ollama_service import OLlamaService
4-
from ai_commit_msg.services.openai_service import OpenAiService
2+
from ai_commit_msg.services.llm_service_factory import LLMServiceFactory
53
from ai_commit_msg.utils.logger import Logger
6-
from ai_commit_msg.utils.models import ANTHROPIC_MODEL_LIST, OPEN_AI_MODEL_LIST
74

85

96
def llm_chat_completion(prompt):
107
select_model = ConfigService.get_model()
118

12-
# TODO - create a factory with a shared interface for calling the LLM models, this will make it easier to add new models
13-
ai_gen_commit_msg = None
14-
if str(select_model) in OPEN_AI_MODEL_LIST:
15-
ai_gen_commit_msg = OpenAiService().chat_with_openai(prompt)
16-
elif select_model.startswith("ollama"):
17-
ai_gen_commit_msg = OLlamaService().chat_completion(prompt)
18-
elif select_model in ANTHROPIC_MODEL_LIST:
19-
ai_gen_commit_msg = AnthropicService().chat_completion(prompt)
9+
service = LLMServiceFactory.create_service(select_model)
2010

21-
if ai_gen_commit_msg is None:
11+
if service is None:
2212
Logger().log("Unsupported model: " + select_model)
2313
return ""
2414

25-
return ai_gen_commit_msg
15+
return service.chat_completion(prompt)

ai_commit_msg/services/anthropic_service.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,10 @@
44
from ai_commit_msg.utils.models import ANTHROPIC_MODEL_LIST
55
from ai_commit_msg.services.config_service import ConfigService
66
from ai_commit_msg.utils.error import map_error
7+
from ai_commit_msg.services.llm_service import LLMService
78

89

9-
class AnthropicService:
10+
class AnthropicService(LLMService):
1011
api_key = ""
1112

1213
def __init__(self):
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from abc import ABC, abstractmethod
2+
3+
class LLMService(ABC):
4+
@abstractmethod
5+
def chat_completion(self, messages):
6+
pass
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from ai_commit_msg.services.openai_service import OpenAiService
2+
from ai_commit_msg.services.anthropic_service import AnthropicService
3+
from ai_commit_msg.services.ollama_service import OLlamaService
4+
from ai_commit_msg.utils.models import OPEN_AI_MODEL_LIST, ANTHROPIC_MODEL_LIST
5+
from ai_commit_msg.utils.logger import Logger
6+
7+
8+
class LLMServiceFactory:
9+
10+
@staticmethod
11+
def create_service(model_name):
12+
if model_name in OPEN_AI_MODEL_LIST:
13+
return OpenAiService()
14+
elif model_name.startswith("ollama"):
15+
return OLlamaService()
16+
elif model_name in ANTHROPIC_MODEL_LIST:
17+
return AnthropicService()
18+
else:
19+
Logger().log(f"Unsupported model: {model_name}")
20+
return None

ai_commit_msg/services/ollama_service.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22

33
from ai_commit_msg.services.config_service import ConfigService
44
from ai_commit_msg.utils.logger import Logger
5+
from ai_commit_msg.services.llm_service import LLMService
56

67

7-
class OLlamaService:
8+
class OLlamaService(LLMService):
89
def __init__(self):
910
self.url = ConfigService.get_ollama_url()
1011

ai_commit_msg/services/openai_service.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,10 @@
88
CONFIG_COLLECTION_KEY,
99
)
1010
from ai_commit_msg.utils.models import OPEN_AI_MODEL_LIST
11+
from ai_commit_msg.services.llm_service import LLMService
1112

1213

13-
class OpenAiService:
14+
class OpenAiService(LLMService):
1415
client = None
1516

1617
def __init__(self):
@@ -26,7 +27,7 @@ def __init__(self):
2627
)
2728
self.client = OpenAI(api_key=api_key)
2829

29-
def chat_with_openai(self, messages):
30+
def chat_completion(self, messages):
3031
model_name = ConfigService.get_model()
3132

3233
if model_name not in OPEN_AI_MODEL_LIST:
@@ -41,6 +42,9 @@ def chat_with_openai(self, messages):
4142
except Exception as e:
4243
raise map_error("OPENAI", getattr(e, "code", str(e)), e)
4344

45+
def chat_with_openai(self, messages):
46+
return self.chat_completion(messages)
47+
4448
@staticmethod
4549
def get_openai_api_key():
4650
raw_json_db = LocalDbService().get_db()[CONFIG_COLLECTION_KEY]

0 commit comments

Comments
 (0)