Skip to content

Commit ddef995

Browse files
authored
server : fix assistant prefilling when content is an array (#14360)
1 parent 6681688 commit ddef995

File tree

2 files changed: +30 -1 lines changed (30 additions, 1 deletion)

2 files changed

+30
-1
lines changed

tools/server/tests/unit/test_chat_completion.py

Lines changed: 23 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,28 @@ def test_chat_template():
132132
assert res.body["__verbose"]["prompt"] == "<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
133133

134134

135+
@pytest.mark.parametrize("prefill,re_prefill", [
136+
("Whill", "Whill"),
137+
([{"type": "text", "text": "Wh"}, {"type": "text", "text": "ill"}], "Whill"),
138+
])
139+
def test_chat_template_assistant_prefill(prefill, re_prefill):
140+
global server
141+
server.chat_template = "llama3"
142+
server.debug = True # to get the "__verbose" object in the response
143+
server.start()
144+
res = server.make_request("POST", "/chat/completions", data={
145+
"max_tokens": 8,
146+
"messages": [
147+
{"role": "system", "content": "Book"},
148+
{"role": "user", "content": "What is the best book"},
149+
{"role": "assistant", "content": prefill},
150+
]
151+
})
152+
assert res.status_code == 200
153+
assert "__verbose" in res.body
154+
assert res.body["__verbose"]["prompt"] == f"<s> <|start_header_id|>system<|end_header_id|>\n\nBook<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nWhat is the best book<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n{re_prefill}"
155+
156+
135157
def test_apply_chat_template():
136158
global server
137159
server.chat_template = "command-r"
@@ -228,6 +250,7 @@ def test_completion_with_grammar(jinja: bool, grammar: str, n_predicted: int, re
228250
[{"role": "system", "content": 123}],
229251
# [{"content": "hello"}], # TODO: should not be a valid case
230252
[{"role": "system", "content": "test"}, {}],
253+
[{"role": "user", "content": "test"}, {"role": "assistant", "content": "test"}, {"role": "assistant", "content": "test"}],
231254
])
232255
def test_invalid_chat_completion_req(messages):
233256
global server

tools/server/utils.hpp

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -792,7 +792,13 @@ static json oaicompat_chat_params_parse(
792792

793793
/* Append assistant prefilled message */
794794
if (prefill_assistant_message) {
795-
chat_params.prompt += last_message.content;
795+
if (!last_message.content_parts.empty()) {
796+
for (auto & p : last_message.content_parts) {
797+
chat_params.prompt += p.text;
798+
}
799+
} else {
800+
chat_params.prompt += last_message.content;
801+
}
796802
}
797803

798804
llama_params["chat_format"] = static_cast<int>(chat_params.format);

0 commit comments

Comments
 (0)