e2e_function_calling.py
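
# End-to-end smoke test for function calling in prem_utils: for each configured
# connector, chat_completion is called with tools, without tools, and with
# streaming, and the raw responses are printed.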
import asyncio
import os

from dotenv import load_dotenv

from prem_utils.connectors.anthropic import AnthropicConnector
from prem_utils.connectors.anyscale import AnyscaleEndpointsConnector
from prem_utils.connectors.azure import AzureOpenAIConnector
from prem_utils.connectors.cohere import CohereConnector
from prem_utils.connectors.groq import GroqConnector
from prem_utils.connectors.mistral import MistralAzureConnector, MistralConnector
from prem_utils.connectors.openai import OpenAIConnector

load_dotenv()
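# Provider API keys and endpoints are read from os.environ below; load_dotenv()
# lets them be supplied via a local .env file instead of the shell environment.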
messages = [{"role": "user", "content": "What is the weather like in San Francisco?"}]

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}
                },
                "required": ["location"],
            },
        },
    }
]
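
# The tool definition above follows the OpenAI-style function-calling schema
# (a JSON Schema object under "parameters"); the connectors are expected to
# translate it into each provider's native tool format.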
connectors = [
    {"connector": AnthropicConnector(api_key=os.environ["ANTHROPIC_API_KEY"]), "model": "claude-3-haiku-20240307"},
    {"connector": OpenAIConnector(api_key=os.environ["OPENAI_API_KEY"]), "model": "gpt-4o"},
    {
        "connector": AzureOpenAIConnector(
            api_key=os.environ["AZURE_OPENAI_API_KEY"], base_url=os.environ["AZURE_OPENAI_BASE_URL"]
        ),
        "model": "gpt-4-32k-azure",
    },
    {"connector": MistralConnector(api_key=os.environ["MISTRAL_AI_API_KEY"]), "model": "mistral-small-latest"},
    {
        "connector": MistralAzureConnector(
            api_key=os.environ["MISTRAL_AZURE_API_KEY"], endpoint=os.environ["MISTRAL_AZURE_ENDPOINT"]
        ),
        "model": "mistral-large",
    },
    {"connector": GroqConnector(api_key=os.environ["GROQ_API_KEY"]), "model": "groq/gemma-7b-it"},
    {"connector": CohereConnector(api_key=os.environ["COHERE_API_KEY"]), "model": "command-r-plus"},
    {
        "connector": AnyscaleEndpointsConnector(api_key=os.environ["ANYSCALE_API_KEY"]),
        "model": "anyscale/mistralai/Mixtral-8x7B-Instruct-v0.1",
    },
]
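
# Exercise every connector three ways: with tools, without tools, and streaming.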
for connector_dict in connectors:
    connector = connector_dict["connector"]
    model = connector_dict["model"]
    print(f"Connector: {connector} and model: {model}")

    print("With tools")
    response = asyncio.run(connector.chat_completion(model=model, messages=messages, tools=tools))
    print(response)

    print("NO tools")
    response = asyncio.run(connector.chat_completion(model=model, messages=messages))
    print(response)

    print("Stream")
    if not isinstance(connector, CohereConnector):
        response = asyncio.run(connector.chat_completion(model=model, messages=messages, stream=True))
    else:
        # Cohere's streaming call is made directly rather than through asyncio.run,
        # presumably because this connector's streaming path is not a coroutine.
        response = connector.chat_completion(model=model, messages=messages, stream=True)
    print(response)

    print("\n", "-" * 50, "\n")