Skip to content

Commit 59d9061

Browse files
committed
Experimental support for Anthropic Claude API
1 parent a08b470 commit 59d9061

File tree

3 files changed

+62
-3
lines changed

3 files changed

+62
-3
lines changed

pilot/main.py

+2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def init():
6666

6767
if '--api-key' in args:
6868
os.environ["OPENAI_API_KEY"] = args['--api-key']
69+
if '--model-name' in args:
70+
os.environ['MODEL_NAME'] = args['--model-name']
6971
if '--api-endpoint' in args:
7072
os.environ["OPENAI_ENDPOINT"] = args['--api-endpoint']
7173

pilot/utils/llm_connection.py

+59-3
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,9 @@ def create_gpt_chat_completion(messages: List[dict], req_type, project,
109109
{'function_calls': {'name': str, arguments: {...}}}
110110
"""
111111

112+
model_name = os.getenv('MODEL_NAME', 'gpt-4')
112113
gpt_data = {
113-
'model': os.getenv('MODEL_NAME', 'gpt-4'),
114+
'model': model_name,
114115
'n': 1,
115116
'temperature': temperature,
116117
'top_p': 1,
@@ -133,8 +134,18 @@ def create_gpt_chat_completion(messages: List[dict], req_type, project,
133134
if prompt_data is not None and function_call_message is not None:
134135
prompt_data['function_call_message'] = function_call_message
135136

137+
if '/' in model_name:
138+
model_provider, model_name = model_name.split('/', 1)
139+
else:
140+
model_provider = 'openai'
141+
136142
try:
137-
response = stream_gpt_completion(gpt_data, req_type, project)
143+
if model_provider == 'anthropic':
144+
if not os.getenv('ANTHROPIC_API_KEY'):
145+
os.environ['ANTHROPIC_API_KEY'] = os.getenv('OPENAI_API_KEY')
146+
response = stream_anthropic(messages, function_call_message, gpt_data, model_name)
147+
else:
148+
response = stream_gpt_completion(gpt_data, req_type, project)
138149

139150
# Remove JSON schema and any added retry messages
140151
while len(messages) > messages_length:
@@ -143,7 +154,7 @@ def create_gpt_chat_completion(messages: List[dict], req_type, project,
143154
except TokenLimitError as e:
144155
raise e
145156
except Exception as e:
146-
logger.error(f'The request to {os.getenv("ENDPOINT")} API failed: %s', e)
157+
logger.error(f'The request to {os.getenv("ENDPOINT")} API for {model_provider}/{model_name} failed: %s', e, exc_info=True)
147158
print(color_red(f'The request to {os.getenv("ENDPOINT")} API failed with error: {e}. Please try again later.'))
148159
if isinstance(e, ApiError):
149160
raise e
@@ -588,3 +599,48 @@ def postprocessing(gpt_response: str, req_type) -> str:
588599

589600
def load_data_to_json(string):
590601
return json.loads(fix_json(string))
602+
603+
604+
605+
def stream_anthropic(messages, function_call_message, gpt_data, model_name = "claude-3-sonnet-20240229"):
606+
try:
607+
import anthropic
608+
except ImportError as err:
609+
raise RuntimeError("The 'anthropic' package is required to use the Anthropic Claude LLM.") from err
610+
611+
client = anthropic.Anthropic(
612+
base_url=os.getenv('ANTHROPIC_ENDPOINT'),
613+
)
614+
615+
claude_system = "You are a software development AI assistant."
616+
claude_messages = messages
617+
if messages[0]["role"] == "system":
618+
claude_system = messages[0]["content"]
619+
claude_messages = messages[1:]
620+
621+
if len(claude_messages):
622+
cm2 = [claude_messages[0]]
623+
for i in range(1, len(claude_messages)):
624+
if cm2[-1]["role"] == claude_messages[i]["role"]:
625+
cm2[-1]["content"] += "\n\n" + claude_messages[i]["content"]
626+
else:
627+
cm2.append(claude_messages[i])
628+
claude_messages = cm2
629+
630+
response = ""
631+
with client.messages.stream(
632+
model=model_name,
633+
max_tokens=4096,
634+
temperature=0.5,
635+
system=claude_system,
636+
messages=claude_messages,
637+
) as stream:
638+
for chunk in stream.text_stream:
639+
print(chunk, type='stream', end='', flush=True)
640+
response += chunk
641+
642+
if function_call_message is not None:
643+
response = clean_json_response(response)
644+
assert_json_schema(response, gpt_data["functions"])
645+
646+
return {"text": response}

requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@ tiktoken==0.5.2
2525
urllib3==1.26.7
2626
wcwidth==0.2.8
2727
yaspin==2.5.0
28+
anthropic==0.19.1

0 commit comments

Comments
 (0)