Skip to content

Commit 9bb8f46

Browse files
committed
feat: Support JSON mode response format
1 parent f2b698f commit 9bb8f46

File tree

4 files changed

+117
-2
lines changed

4 files changed

+117
-2
lines changed

src/api/routers/chat.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from typing import Annotated
1+
from typing import Annotated, Final
22

33
from fastapi import APIRouter, Depends, Body
44
from fastapi.responses import StreamingResponse
55

66
from api.auth import api_key_auth
77
from api.models.bedrock import BedrockModel
8-
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse
8+
from api.schema import ChatRequest, ChatResponse, ChatStreamResponse, SystemMessage
99
from api.setting import DEFAULT_MODEL
1010
import logging
1111

@@ -15,6 +15,12 @@
1515
# responses={404: {"description": "Not found"}},
1616
)
1717

18+
JSON_MODE_SYSTEM_PROMPT: Final = SystemMessage(
19+
name="json_mode",
20+
role="system",
21+
content="You are a helpful assistant designed to output JSON without extra text",
22+
)
23+
1824

1925
@router.post(
2026
"/completions",
@@ -43,6 +49,16 @@ async def chat_completions(
4349
else:
4450
chat_request.model = deployment
4551

52+
# JSON mode handling
53+
# Ref: https://platform.openai.com/docs/guides/structured-outputs/json-mode
54+
response_format = chat_request.response_format
55+
if (
56+
chat_request.stream
57+
or response_format
58+
and response_format["type"] == "json_object"
59+
):
60+
chat_request.messages.append(JSON_MODE_SYSTEM_PROMPT)
61+
4662
logging.debug(f"chat_request: {chat_request}")
4763

4864
# Exception will be raised if model not supported.

src/api/schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class ChatRequest(BaseModel):
9898
n: int | None = 1 # Not used
9999
tools: list[Tool] | None = None
100100
tool_choice: str | object = "auto"
101+
response_format: dict | None = None
101102

102103

103104
class Usage(BaseModel):
File renamed without changes.

tests/test_json.ipynb

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {
6+
"isInteractiveWindowMessageCell": true
7+
},
8+
"source": [
9+
"Connected to openai2 (Python 3.11.4)"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": 1,
15+
"metadata": {},
16+
"outputs": [],
17+
"source": [
18+
"import os\n",
19+
"from openai import AzureOpenAI\n",
20+
"\n",
21+
"client = AzureOpenAI(\n",
22+
" azure_endpoint=os.getenv(\"AZURE_OPENAI_ENDPOINT\", \"http://localhost:8000\"),\n",
23+
" api_key=os.getenv(\"AZURE_OPENAI_API_KEY\", \"bedrock\"),\n",
24+
" azure_deployment=\"gpt-4o\",\n",
25+
" api_version=\"2024-06-01\",\n",
26+
")\n",
27+
"\n",
28+
"response = client.chat.completions.create(\n",
29+
" model=\"gpt-4o\",\n",
30+
" response_format={\"type\": \"json_object\"},\n",
31+
" messages=[\n",
32+
" # {\n",
33+
" # \"role\": \"system\",\n",
34+
" # \"content\": \"You are a helpful assistant designed to output JSON without extra text\",\n",
35+
" # },\n",
36+
" {\"role\": \"user\", \"content\": \"Who won the world series in 2020?\"},\n",
37+
" ],\n",
38+
")\n",
39+
"\n",
40+
"content = response.choices[0].message.content"
41+
]
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": 2,
46+
"metadata": {},
47+
"outputs": [
48+
{
49+
"name": "stdout",
50+
"output_type": "stream",
51+
"text": [
52+
"<class 'str'>\n",
53+
"{\n",
54+
" \"winner\": \"Los Angeles Dodgers\"\n",
55+
"}\n"
56+
]
57+
}
58+
],
59+
"source": [
60+
"print(type(content)) # The type of content is a JSON String\n",
61+
"print(content)"
62+
]
63+
},
64+
{
65+
"cell_type": "code",
66+
"execution_count": 3,
67+
"metadata": {},
68+
"outputs": [
69+
{
70+
"name": "stdout",
71+
"output_type": "stream",
72+
"text": [
73+
"<class 'dict'>\n"
74+
]
75+
}
76+
],
77+
"source": [
78+
"import json\n",
79+
"\n",
80+
"json_data = json.loads(content)\n",
81+
"print(type(json_data)) # Customer can later parse the string to a JSON format."
82+
]
83+
}
84+
],
85+
"metadata": {
86+
"kernelspec": {
87+
"display_name": "openai2",
88+
"language": "python",
89+
"name": "python3"
90+
},
91+
"language_info": {
92+
"name": "python",
93+
"version": "3.11.4"
94+
}
95+
},
96+
"nbformat": 4,
97+
"nbformat_minor": 2
98+
}

0 commit comments

Comments
 (0)