Commit

Make plugin work with AzureOAI (openai#221)
* Make this work better with Azure Open AI

* Fix typo

* Update README.md

Co-authored-by: Derek Legenzoff <[email protected]>

---------

Co-authored-by: isafulf <[email protected]>
Co-authored-by: Derek Legenzoff <[email protected]>
3 people authored May 2, 2023
1 parent d619172 commit 0ebb015
Showing 6 changed files with 54 additions and 13 deletions.
18 changes: 18 additions & 0 deletions README.md
@@ -72,6 +72,13 @@ Follow these steps to quickly set up and run the ChatGPT Retrieval Plugin:
 export BEARER_TOKEN=<your_bearer_token>
 export OPENAI_API_KEY=<your_openai_api_key>
+
+# Optional environment variables used when running Azure OpenAI
+export OPENAI_API_BASE=https://<AzureOpenAIName>.openai.azure.com/
+export OPENAI_API_TYPE=azure
+export OPENAI_EMBEDDINGMODEL_DEPLOYMENTID=<Name of text-embedding-ada-002 model deployment>
+export OPENAI_METADATA_EXTRACTIONMODEL_DEPLOYMENTID=<Name of model deployment used for metadata extraction>
+export OPENAI_COMPLETIONMODEL_DEPLOYMENTID=<Name of general model deployment used for completion>
 # Add the environment variables for your chosen vector DB.
 # Some of these are optional; read the provider's setup docs in /docs/providers for more information.
@@ -237,6 +244,17 @@ The API requires the following environment variables to work:
| `BEARER_TOKEN` | Yes | This is a secret token that you need to authenticate your requests to the API. You can generate one using any tool or method you prefer, such as [jwt.io](https://jwt.io/). |
| `OPENAI_API_KEY` | Yes | This is your OpenAI API key that you need to generate embeddings using the `text-embedding-ada-002` model. You can get an API key by creating an account on [OpenAI](https://openai.com/). |
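Since both of these variables are required, it can be useful to fail fast at startup when one is missing. The helper below is a purely illustrative sketch and is not part of the plugin:

```python
import os

# The two variables the table above marks as required.
REQUIRED_VARS = ["BEARER_TOKEN", "OPENAI_API_KEY"]

def check_required_env() -> None:
    """Raise early if any required environment variable is missing or empty."""
    missing = [name for name in REQUIRED_VARS if not os.environ.get(name)]
    if missing:
        raise RuntimeError(f"Missing required environment variables: {missing}")
```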


### Using the plugin with Azure OpenAI

Azure OpenAI uses URLs that are specific to your resource and references models by deployment ID rather than by model name. As a result, you need to set additional environment variables for this case.

In addition to `OPENAI_API_BASE` (your resource-specific URL) and `OPENAI_API_TYPE` (set to `azure`), you should also set `OPENAI_EMBEDDINGMODEL_DEPLOYMENTID`, which specifies the model used to compute embeddings on upsert and query. For this, we recommend deploying the `text-embedding-ada-002` model and using that deployment's name here.

If you wish to use the data preparation scripts, you will also need to set `OPENAI_METADATA_EXTRACTIONMODEL_DEPLOYMENTID` (used for metadata extraction) and `OPENAI_COMPLETIONMODEL_DEPLOYMENTID` (used for PII handling).
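To illustrate how the embedding deployment variable is consumed, here is a minimal sketch of the dispatch logic (the helper name `embedding_kwargs` is hypothetical; the actual logic lives in `services/openai.py`):

```python
import os

def embedding_kwargs(texts, default_model="text-embedding-ada-002"):
    """Build keyword arguments for openai.Embedding.create: when an Azure
    deployment id is configured, address the model by deployment_id;
    otherwise fall back to the public API's model name."""
    deployment = os.environ.get("OPENAI_EMBEDDINGMODEL_DEPLOYMENTID")
    if deployment:
        return {"input": texts, "deployment_id": deployment}
    return {"input": texts, "model": default_model}
```

The same pattern applies to the metadata-extraction and completion deployment variables: if the variable is unset, the plugin behaves exactly as it does against the public OpenAI API.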


### Choosing a Vector Database

The plugin supports several vector database providers, each with different features, performance, and pricing. Depending on which one you choose, you will need to use a different Dockerfile and set different environment variables. The following sections provide brief introductions to each vector database provider.
8 changes: 4 additions & 4 deletions poetry.lock


2 changes: 1 addition & 1 deletion pyproject.toml
@@ -10,7 +10,7 @@ packages = [{include = "server"}]
 python = "^3.10"
 fastapi = "^0.92.0"
 uvicorn = "^0.20.0"
-openai = "^0.27.2"
+openai = "^0.27.5"
 python-dotenv = "^0.21.1"
 pydantic = "^1.10.5"
 tenacity = "^8.2.1"
8 changes: 6 additions & 2 deletions services/extract_metadata.py
@@ -2,7 +2,7 @@
from services.openai import get_chat_completion
import json
 from typing import Dict

+import os

def extract_metadata_from_document(text: str) -> Dict[str, str]:
sources = Source.__members__.keys()
@@ -24,8 +24,12 @@ def extract_metadata_from_document(text: str) -> Dict[str, str]:
         {"role": "user", "content": text},
     ]

+    # NOTE: Azure Open AI requires deployment id
+    # Read environment variable - if not set - not used
     completion = get_chat_completion(
-        messages, "gpt-4"
+        messages,
+        "gpt-4",
+        os.environ.get("OPENAI_METADATA_EXTRACTIONMODEL_DEPLOYMENTID")
     )  # TODO: change to your preferred model name

     print(f"completion: {completion}")
29 changes: 23 additions & 6 deletions services/openai.py
@@ -1,6 +1,6 @@
 from typing import List
 import openai
+import os

 from tenacity import retry, wait_random_exponential, stop_after_attempt

@@ -20,8 +20,15 @@ def get_embeddings(texts: List[str]) -> List[List[float]]:
Exception: If the OpenAI API call fails.
"""
     # Call the OpenAI API to get the embeddings
-    response = openai.Embedding.create(input=texts, model="text-embedding-ada-002")
+    # NOTE: Azure Open AI requires deployment id
+    deployment = os.environ.get("OPENAI_EMBEDDINGMODEL_DEPLOYMENTID")
+
+    response = {}
+    if deployment == None:
+        response = openai.Embedding.create(input=texts, model="text-embedding-ada-002")
+    else:
+        response = openai.Embedding.create(input=texts, deployment_id=deployment)

     # Extract the embedding data from the response
     data = response["data"]  # type: ignore

@@ -33,6 +40,7 @@ def get_embeddings(texts: List[str]) -> List[List[float]]:
 def get_chat_completion(
     messages,
     model="gpt-3.5-turbo",  # use "gpt-4" for better results
+    deployment_id = None
 ):
"""
Generate a chat completion using OpenAI's chat completion API.
@@ -48,10 +56,19 @@ def get_chat_completion(
Exception: If the OpenAI API call fails.
"""
     # call the OpenAI chat completion API with the given messages
-    response = openai.ChatCompletion.create(
-        model=model,
-        messages=messages,
-    )
+    # Note: Azure Open AI requires deployment id
+    response = {}
+    if deployment_id == None:
+        response = openai.ChatCompletion.create(
+            model=model,
+            messages=messages,
+        )
+    else:
+        response = openai.ChatCompletion.create(
+            deployment_id = deployment_id,
+            messages=messages,
+        )

     choices = response["choices"]  # type: ignore
     completion = choices[0].message.content.strip()
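The branching above can be exercised without calling the API by factoring the argument construction into a pure helper. The function name below is hypothetical and shown for illustration only:

```python
def chat_completion_kwargs(messages, model="gpt-3.5-turbo", deployment_id=None):
    """Mirror the dispatch in get_chat_completion: Azure OpenAI resolves the
    model from the deployment id, the public API from the model name."""
    if deployment_id is None:
        return {"model": model, "messages": messages}
    return {"deployment_id": deployment_id, "messages": messages}
```

Keeping the deployment-id parameter optional means existing callers that pass only `messages` and `model` continue to work unchanged against the public OpenAI API.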
2 changes: 2 additions & 0 deletions services/pii_detection.py
@@ -1,3 +1,4 @@
+import os
 from services.openai import get_chat_completion


@@ -22,6 +23,7 @@ def screen_text_for_pii(text: str) -> bool:

     completion = get_chat_completion(
         messages,
+        deployment_id=os.environ.get("OPENAI_COMPLETIONMODEL_DEPLOYMENTID")
     )

     if completion.startswith("True"):
