@@ -109,8 +109,9 @@ def create_gpt_chat_completion(messages: List[dict], req_type, project,
109
109
{'function_calls': {'name': str, arguments: {...}}}
110
110
"""
111
111
112
+ model_name = os .getenv ('MODEL_NAME' , 'gpt-4' )
112
113
gpt_data = {
113
- 'model' : os . getenv ( 'MODEL_NAME' , 'gpt-4' ) ,
114
+ 'model' : model_name ,
114
115
'n' : 1 ,
115
116
'temperature' : temperature ,
116
117
'top_p' : 1 ,
@@ -133,8 +134,18 @@ def create_gpt_chat_completion(messages: List[dict], req_type, project,
133
134
if prompt_data is not None and function_call_message is not None :
134
135
prompt_data ['function_call_message' ] = function_call_message
135
136
137
+ if '/' in model_name :
138
+ model_provider , model_name = model_name .split ('/' , 1 )
139
+ else :
140
+ model_provider = 'openai'
141
+
136
142
try :
137
- response = stream_gpt_completion (gpt_data , req_type , project )
143
+ if model_provider == 'anthropic' :
144
+ if not os .getenv ('ANTHROPIC_API_KEY' ):
145
+ os .environ ['ANTHROPIC_API_KEY' ] = os .getenv ('OPENAI_API_KEY' )
146
+ response = stream_anthropic (messages , function_call_message , gpt_data , model_name )
147
+ else :
148
+ response = stream_gpt_completion (gpt_data , req_type , project )
138
149
139
150
# Remove JSON schema and any added retry messages
140
151
while len (messages ) > messages_length :
@@ -143,7 +154,7 @@ def create_gpt_chat_completion(messages: List[dict], req_type, project,
143
154
except TokenLimitError as e :
144
155
raise e
145
156
except Exception as e :
146
- logger .error (f'The request to { os .getenv ("ENDPOINT" )} API failed: %s' , e )
157
+ logger .error (f'The request to { os .getenv ("ENDPOINT" )} API for { model_provider } / { model_name } failed: %s' , e , exc_info = True )
147
158
print (color_red (f'The request to { os .getenv ("ENDPOINT" )} API failed with error: { e } . Please try again later.' ))
148
159
if isinstance (e , ApiError ):
149
160
raise e
@@ -588,3 +599,48 @@ def postprocessing(gpt_response: str, req_type) -> str:
588
599
589
600
def load_data_to_json (string ):
590
601
return json .loads (fix_json (string ))
602
+
603
+
604
+
605
def stream_anthropic(messages, function_call_message, gpt_data, model_name="claude-3-sonnet-20240229"):
    """
    Stream a chat completion from the Anthropic Messages API and return its text.

    A leading message with role "system" is sent as the request's `system`
    prompt instead of as a conversation turn; when there is none, a generic
    default system prompt is used. Consecutive messages that share a role are
    merged into a single message (contents joined by a blank line). The
    caller's `messages` list and the dicts inside it are never modified.

    :param messages: non-empty list of chat messages, each a dict with
        "role" and "content" keys
    :param function_call_message: when not None, the streamed response is
        cleaned up as JSON and validated against `gpt_data["functions"]`
    :param gpt_data: request data; only `gpt_data["functions"]` is read here,
        and only when `function_call_message` is not None
    :param model_name: Anthropic model identifier
    :return: `{"text": <full response text>}`
    :raises RuntimeError: if the `anthropic` package is not installed
    :raises ValueError: if `messages` is empty
    """
    # Imported lazily so the dependency is only needed when Anthropic is used.
    try:
        import anthropic
    except ImportError as err:
        raise RuntimeError("The 'anthropic' package is required to use the Anthropic Claude LLM.") from err

    if not messages:
        # Fail with a clear message instead of an IndexError on messages[0].
        raise ValueError("stream_anthropic requires at least one message.")

    # base_url is None when ANTHROPIC_ENDPOINT is unset.
    client = anthropic.Anthropic(
        base_url=os.getenv('ANTHROPIC_ENDPOINT'),
    )

    claude_system = "You are a software development AI assistant."
    conversation = messages
    if messages[0]["role"] == "system":
        claude_system = messages[0]["content"]
        conversation = messages[1:]

    # Collapse runs of same-role messages into one message each.
    # NOTE(review): presumably needed because the Messages API expects
    # alternating user/assistant turns - confirm against the API docs.
    # Each kept message is a shallow copy, so appending merged content never
    # rewrites the caller's conversation history.
    claude_messages = []
    for message in conversation:
        if claude_messages and claude_messages[-1]["role"] == message["role"]:
            previous = claude_messages[-1]
            previous["content"] = previous["content"] + "\n\n" + message["content"]
        else:
            claude_messages.append(dict(message))

    chunks = []
    with client.messages.stream(
        model=model_name,
        max_tokens=4096,
        temperature=0.5,
        system=claude_system,
        messages=claude_messages,
    ) as stream:
        for chunk in stream.text_stream:
            # NOTE(review): the builtin print() does not accept `type`;
            # presumably the project installs its own print - confirm.
            print(chunk, type='stream', end='', flush=True)
            chunks.append(chunk)
    response = "".join(chunks)

    if function_call_message is not None:
        response = clean_json_response(response)
        assert_json_schema(response, gpt_data["functions"])

    return {"text": response}
0 commit comments