Skip to content

Commit fe73ac4

Browse files
authored
Update: Add Azure OpenAI Support (#117)
1 parent 880e26a commit fe73ac4

File tree

4 files changed

+151
-3
lines changed

4 files changed

+151
-3
lines changed

backend/README.md

+8
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,14 @@ Set your OpenAI key (if you use OpenAI API):
127127
```bash
128128
export OPENAI_API_KEY=<OPENAI_API_KEY>
129129
```
130+
**Note** if you are using [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview) Service, you should choose:
131+
```bash
132+
export OPENAI_API_TYPE=azure
133+
export OPENAI_API_BASE=<AZURE_API_BASE>
134+
export OPENAI_API_VERSION=<AZURE_API_VERSION>
135+
export OPENAI_API_KEY=<AZURE_API_KEY>
136+
```
137+
If you are starting your backend in docker, you should add these environment variables in `docker-compose.yml` as well.
130138

131139
Set your Anthropic key (if you use Anthropic API):
132140
```bash

backend/api/language_model.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22

33
from backend.app import app
4-
from real_agents.adapters.models import ChatOpenAI, ChatAnthropic
4+
from real_agents.adapters.models import ChatOpenAI, ChatAnthropic, AzureChatOpenAI
55
from real_agents.adapters.llm import BaseLanguageModel
66

77
LLAMA_DIR = "PATH_TO_LLAMA_DIR"
@@ -24,8 +24,14 @@ def get_llm_list():
2424
def get_llm(llm_name: str, **kwargs) -> BaseLanguageModel:
2525
"""Gets the llm model by its name."""
2626
if llm_name in ["gpt-3.5-turbo-16k", "gpt-4"]:
27-
return ChatOpenAI(
28-
model_name=llm_name,
27+
openai_api_type = os.getenv("OPENAI_API_TYPE", "open_ai")
28+
if openai_api_type == "open_ai":
29+
chat_openai = ChatOpenAI
30+
kwargs.update({"model_name": llm_name})
31+
elif openai_api_type == "azure":
32+
chat_openai = AzureChatOpenAI
33+
kwargs.update({"deployment_name": llm_name})
34+
return chat_openai(
2935
streaming=True,
3036
verbose=True,
3137
**kwargs

real_agents/adapters/models/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22

33
from real_agents.adapters.models.anthropic import ChatAnthropic
44
from real_agents.adapters.models.openai import ChatOpenAI
5+
from real_agents.adapters.models.azure_openai import AzureChatOpenAI
56

67
__all__ = [
78
"ChatOpenAI",
89
"ChatAnthropic",
910
"ChatGooglePalm",
11+
"AzureChatOpenAI",
1012
]
1113

1214
type_to_cls_dict = {
1315
"chat_anthropic": ChatAnthropic,
1416
"chat_google_palm": ChatGooglePalm,
1517
"chat_openai": ChatOpenAI,
18+
"azure_chat_openai": AzureChatOpenAI,
1619
}
+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""Azure OpenAI chat wrapper."""
2+
from __future__ import annotations
3+
4+
import logging
5+
from typing import Any, Dict, Mapping
6+
7+
from pydantic import root_validator
8+
9+
from real_agents.adapters.models.openai import ChatOpenAI
10+
from langchain.schema import ChatResult
11+
from langchain.utils import get_from_dict_or_env
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
class AzureChatOpenAI(ChatOpenAI):
17+
"""Wrapper around Azure OpenAI Chat Completion API. To use this class you
18+
must have a deployed model on Azure OpenAI. Use `deployment_name` in the
19+
constructor to refer to the "Model deployment name" in the Azure portal.
20+
21+
In addition, you should have the ``openai`` python package installed, and the
22+
following environment variables set or passed in constructor in lower case:
23+
- ``OPENAI_API_TYPE`` (default: ``azure``)
24+
- ``OPENAI_API_KEY``
25+
- ``OPENAI_API_BASE``
26+
- ``OPENAI_API_VERSION``
27+
28+
For exmaple, if you have `gpt-35-turbo` deployed, with the deployment name
29+
`35-turbo-dev`, the constructor should look like:
30+
31+
.. code-block:: python
32+
AzureChatOpenAI(
33+
deployment_name="35-turbo-dev",
34+
openai_api_version="2023-03-15-preview",
35+
)
36+
37+
Be aware the API version may change.
38+
39+
Any parameters that are valid to be passed to the openai.create call can be passed
40+
in, even if not explicitly saved on this class.
41+
"""
42+
43+
deployment_name: str = ""
44+
openai_api_type: str = "azure"
45+
openai_api_base: str = ""
46+
openai_api_version: str = ""
47+
openai_api_key: str = ""
48+
openai_organization: str = ""
49+
50+
@root_validator()
51+
def validate_environment(cls, values: Dict) -> Dict:
52+
"""Validate that api key and python package exists in environment."""
53+
openai_api_key = get_from_dict_or_env(
54+
values,
55+
"openai_api_key",
56+
"OPENAI_API_KEY",
57+
)
58+
openai_api_base = get_from_dict_or_env(
59+
values,
60+
"openai_api_base",
61+
"OPENAI_API_BASE",
62+
)
63+
openai_api_version = get_from_dict_or_env(
64+
values,
65+
"openai_api_version",
66+
"OPENAI_API_VERSION",
67+
)
68+
openai_api_type = get_from_dict_or_env(
69+
values,
70+
"openai_api_type",
71+
"OPENAI_API_TYPE",
72+
)
73+
openai_organization = get_from_dict_or_env(
74+
values,
75+
"openai_organization",
76+
"OPENAI_ORGANIZATION",
77+
default="",
78+
)
79+
try:
80+
import openai
81+
82+
openai.api_type = openai_api_type
83+
openai.api_base = openai_api_base
84+
openai.api_version = openai_api_version
85+
openai.api_key = openai_api_key
86+
if openai_organization:
87+
openai.organization = openai_organization
88+
except ImportError:
89+
raise ValueError(
90+
"Could not import openai python package. "
91+
"Please install it with `pip install openai`."
92+
)
93+
try:
94+
values["client"] = openai.ChatCompletion
95+
except AttributeError:
96+
raise ValueError(
97+
"`openai` has no `ChatCompletion` attribute, this is likely "
98+
"due to an old version of the openai package. Try upgrading it "
99+
"with `pip install --upgrade openai`."
100+
)
101+
if values["n"] < 1:
102+
raise ValueError("n must be at least 1.")
103+
if values["n"] > 1 and values["streaming"]:
104+
raise ValueError("n must be 1 when streaming.")
105+
return values
106+
107+
@property
108+
def _default_params(self) -> Dict[str, Any]:
109+
"""Get the default parameters for calling OpenAI API."""
110+
return {
111+
**super()._default_params,
112+
"engine": self.deployment_name,
113+
}
114+
115+
@property
116+
def _identifying_params(self) -> Mapping[str, Any]:
117+
"""Get the identifying parameters."""
118+
return {**self._default_params}
119+
120+
@property
121+
def _llm_type(self) -> str:
122+
return "azure-openai-chat"
123+
124+
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
125+
for res in response["choices"]:
126+
if res.get("finish_reason", None) == "content_filter":
127+
raise ValueError(
128+
"Azure has not provided the response due to a content"
129+
" filter being triggered"
130+
)
131+
return super()._create_chat_result(response)

0 commit comments

Comments
 (0)