forked from h2oai/h2ogpt
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_openai_server.py
164 lines (133 loc) · 6.65 KB
/
test_openai_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
import time
import pytest
import os
# to avoid copy-paste, only other external reference besides main() (for local_server=True)
from tests.utils import wrap_test_forked
def launch_openai_server():
    """Start the OpenAI-compatible proxy server in-process (blocking call)."""
    from openai_server.server import run as run_server
    run_server()
def test_openai_server():
    """Launch the OpenAI-compatible server in the foreground (blocking).

    To instead run a separate OpenAI server manually against an existing h2oGPT:
      Shell 1: CUDA_VISIBLE_DEVICES=0 python generate.py --verbose=True --score_model=None --pre_load_embedding_model=False --gradio_offline_level=2 --base_model=openchat/openchat-3.5-1210 --inference_server=vllm:ip:port --max_seq_len=4096 --save_dir=duder1 --verbose --openai_server=True --concurrency_count=64 --openai_server=False
      Shell 2: pytest -s -v openai_server/test_openai_server.py::test_openai_server  # once client done, hit CTRL-C, should pass
      Shell 3: pytest -s -v openai_server/test_openai_server.py::test_openai_client_test2  # should pass

    NOTE(review): the Shell 1 command passes both --openai_server=True and
    --openai_server=False — confirm which value is intended.
    """
    launch_openai_server()
# Number of repetitions for the parametrized client test below; raise it
# (e.g. to 100, per the commented-out line) to stress server concurrency.
# repeat0 = 100 # e.g. to test concurrency
repeat0 = 1
@pytest.mark.parametrize("stream_output", [False, True])
@pytest.mark.parametrize("chat", [False, True])
@pytest.mark.parametrize("local_server", [False])
@wrap_test_forked
def test_openai_client_test2(stream_output, chat, local_server):
    """Smoke-test the OpenAI client against an already-running external server."""
    run_openai_client(stream_output, chat, local_server,
                      prompt="Who are you?",
                      api_key='EMPTY',
                      enforce_h2ogpt_api_key=False,
                      repeat=1)
@pytest.mark.parametrize("stream_output", [False, True])
@pytest.mark.parametrize("chat", [False, True])
@pytest.mark.parametrize("local_server", [True])
@pytest.mark.parametrize("prompt", ["Who are you?", "Tell a very long kid's story about birds."])
@pytest.mark.parametrize("api_key", [None, "EMPTY", os.environ.get('H2OGPT_H2OGPT_KEY', 'EMPTY')])
@pytest.mark.parametrize("enforce_h2ogpt_api_key", [False, True])
@pytest.mark.parametrize("repeat", list(range(0, repeat0)))
@wrap_test_forked
def test_openai_client(stream_output, chat, local_server, prompt, api_key, enforce_h2ogpt_api_key, repeat):
    """Full parametrized sweep: local in-process server, keyed/unkeyed, chat/completions, stream/non-stream."""
    run_openai_client(stream_output, chat, local_server,
                      prompt=prompt,
                      api_key=api_key,
                      enforce_h2ogpt_api_key=enforce_h2ogpt_api_key,
                      repeat=repeat)
def run_openai_client(stream_output, chat, local_server, prompt, api_key, enforce_h2ogpt_api_key, repeat):
    """Drive the OpenAI-compatible h2oGPT endpoint with one completion request, then check model endpoints.

    :param stream_output: if True, consume the response as a stream of deltas
    :param chat: if True use the chat.completions API, else legacy completions
    :param local_server: if True, launch an in-process h2oGPT gradio server first
    :param prompt: user prompt to send
    :param api_key: OpenAI-style API key ('EMPTY' when unkeyed; None to exercise enforcement failure)
    :param enforce_h2ogpt_api_key: whether the launched server requires the key
    :param repeat: unused repetition index; only drives pytest parametrization for concurrency testing
    """
    base_model = 'openchat/openchat-3.5-1210'
    if local_server:
        from src.gen import main
        main(base_model=base_model, chat=False,
             stream_output=stream_output, gradio=True,
             num_beams=1, block_gradio_exit=False,
             add_disk_models_to_ui=False,
             enable_tts=False,
             enable_stt=False,
             enforce_h2ogpt_api_key=enforce_h2ogpt_api_key,
             # or use file with h2ogpt_api_keys=h2ogpt_api_keys.json
             h2ogpt_api_keys=[api_key] if api_key else None,
             )
        # give the freshly-launched server time to come up before connecting
        time.sleep(10)
    else:
        # RUN something externally first, e.g.:
        # CUDA_VISIBLE_DEVICES=0 python generate.py --verbose=True --score_model=None --gradio_offline_level=2 --base_model=openchat/openchat-3.5-1210 --inference_server=vllm:IP:port --max_seq_len=4096 --save_dir=duder1 --verbose --openai_server=True --concurrency_count=64
        pass

    # api_key = "EMPTY"  # if gradio/openai server not keyed. Can't pass '' itself, leads to httpcore.LocalProtocolError: Illegal header value b'Bearer '
    # Setting H2OGPT_H2OGPT_KEY does not key h2oGPT, just passes along key to gradio inference server, so empty key is valid test regardless of the H2OGPT_H2OGPT_KEY value
    # api_key = os.environ.get('H2OGPT_H2OGPT_KEY', 'EMPTY')  # if keyed and have this in env with same key
    print('api_key: %s' % api_key)
    # below should be consistent with server prefix, host, and port
    base_url = 'http://localhost:5000/v1'
    verbose = True
    system_prompt = "You are a helpful assistant."
    chat_conversation = []
    add_chat_history_to_context = True

    client_kwargs = dict(model=base_model,
                         max_tokens=200,
                         stream=stream_output)

    from openai import OpenAI, AsyncOpenAI
    client_args = dict(base_url=base_url, api_key=api_key)
    openai_client = OpenAI(**client_args)
    async_client = AsyncOpenAI(**client_args)

    try:
        test_chat(chat, openai_client, async_client, system_prompt, chat_conversation, add_chat_history_to_context,
                  prompt, client_kwargs, stream_output, verbose)
    except AssertionError:
        # When key enforcement is on and no key was sent, failure is the expected outcome.
        if enforce_h2ogpt_api_key and api_key is None:
            print("Expected to fail since no key but enforcing.")
        else:
            raise

    # MODELS
    model_info = openai_client.models.retrieve(base_model)
    # NOTE(review): base_model is a server-supplied extra field on the Model object — confirm server includes it
    assert model_info.base_model == base_model
    model_list = openai_client.models.list()
    # BUG FIX: data[0] is a Model object, not a bare string, so the original
    # `model_list.data[0] == base_model` was always False; compare the model id instead.
    assert model_list.data[0].id == base_model
def test_chat(chat, openai_client, async_client, system_prompt, chat_conversation, add_chat_history_to_context,
              prompt, client_kwargs, stream_output, verbose):
    """Issue one completion (chat or legacy API) and assert on the generated text.

    Mutates client_kwargs in place, adding either 'messages' (chat) or 'prompt'
    (legacy).  Expects 'OpenAI'/'chatbot' in the answer for "Who..." prompts,
    otherwise expects 'birds'.
    """
    # COMPLETION: pick the endpoint and build its payload
    if chat:
        client = openai_client.chat.completions
        async_client = async_client.chat.completions
        history = []
        if system_prompt:
            history.append({"role": "system", "content": system_prompt})
        if chat_conversation and add_chat_history_to_context:
            # each history entry is a (user, assistant) pair; skip malformed entries
            for exchange in chat_conversation:
                if len(exchange) == 2:
                    history.append({'role': 'user',
                                    'content': exchange[0] if exchange[0] is not None else ''})
                    history.append({'role': 'assistant',
                                    'content': exchange[1] if exchange[1] is not None else ''})
        history.append({'role': 'user', 'content': prompt if prompt is not None else ''})
        client_kwargs.update(dict(messages=history))
    else:
        client = openai_client.completions
        async_client = async_client.completions
        client_kwargs.update(dict(prompt=prompt))

    responses = client.create(**client_kwargs)
    if stream_output:
        # accumulate streamed deltas into the full answer
        collected_events = []
        text = ''
        for event in responses:
            collected_events.append(event)  # save the event response
            delta = event.choices[0].delta.content if chat else event.choices[0].text
            text += delta  # append the text
            if verbose:
                print('delta: %s' % delta)
        print(text)
    else:
        text = responses.choices[0].message.content if chat else responses.choices[0].text
        print(text)

    if "Who" in prompt:
        assert 'OpenAI' in text or 'chatbot' in text
    else:
        assert 'birds' in text
if __name__ == '__main__':
    # Running this file directly just brings up the OpenAI-compatible server.
    launch_openai_server()