diff --git a/common/arg.cpp b/common/arg.cpp
index 8266a16c261c5..1ae55b22c32f7 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -1879,6 +1879,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.slot_prompt_similarity = std::stof(value);
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(llama_arg(
+        {"--testing-sampler-delay-millis"}, "N",
+        format("for tests: delay in milliseconds to add to each sampling (default: %d)", params.testing_sampler_delay_millis),
+        [](gpt_params & params, int value) {
+            params.testing_sampler_delay_millis = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(llama_arg(
         {"--lora-init-without-apply"},
         format("load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: %s)", params.lora_init_without_apply ? "enabled" : "disabled"),
diff --git a/common/common.h b/common/common.h
index 8b84cf9ad45ee..154d59846be62 100644
--- a/common/common.h
+++ b/common/common.h
@@ -299,6 +299,8 @@ struct gpt_params {
 
     float slot_prompt_similarity = 0.5f;
 
+    int testing_sampler_delay_millis = 0;
+
     // batched-bench params
     bool is_pp_shared = false;
 
diff --git a/examples/server/httplib.h b/examples/server/httplib.h
index f360bd93ea098..05ee81a088ed7 100644
--- a/examples/server/httplib.h
+++ b/examples/server/httplib.h
@@ -590,6 +590,7 @@ struct Response {
   Headers headers;
   std::string body;
   std::string location; // Redirect location
+  std::function<bool()> is_alive;
 
   bool has_header(const std::string &key) const;
   std::string get_header_value(const std::string &key, size_t id = 0) const;
@@ -639,6 +640,7 @@ class Stream {
 
   virtual bool is_readable() const = 0;
   virtual bool is_writable() const = 0;
+  virtual bool is_alive() const = 0;
 
   virtual ssize_t read(char *ptr, size_t size) = 0;
   virtual ssize_t write(const char *ptr, size_t size) = 0;
@@ -2135,6 +2137,7 @@ class BufferStream final : public Stream {
 
   bool is_readable() const override;
   bool is_writable() const override;
+  bool is_alive() const override;
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
   void get_remote_ip_and_port(std::string &ip, int &port) const override;
@@ -2945,6 +2948,7 @@ class SocketStream final : public Stream {
 
   bool is_readable() const override;
   bool is_writable() const override;
+  bool is_alive() const override;
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
   void get_remote_ip_and_port(std::string &ip, int &port) const override;
@@ -2975,6 +2979,7 @@ class SSLSocketStream final : public Stream {
 
   bool is_readable() const override;
   bool is_writable() const override;
+  bool is_alive() const override;
   ssize_t read(char *ptr, size_t size) override;
   ssize_t write(const char *ptr, size_t size) override;
   void get_remote_ip_and_port(std::string &ip, int &port) const override;
@@ -4279,6 +4284,7 @@ inline bool redirect(T &cli, Request &req, Response &res,
   }
 
   Response new_res;
+  new_res.is_alive = res.is_alive;
 
   auto ret = cli.send(new_req, new_res, error);
   if (ret) {
@@ -5484,6 +5490,10 @@ inline bool SocketStream::is_writable() const {
          is_socket_alive(sock_);
 }
 
+inline bool SocketStream::is_alive() const {
+  return is_socket_alive(sock_);
+}
+
 inline ssize_t SocketStream::read(char *ptr, size_t size) {
 #ifdef _WIN32
   size =
@@ -5558,6 +5568,8 @@ inline bool BufferStream::is_readable() const { return true; }
 
 inline bool BufferStream::is_writable() const { return true; }
 
+inline bool BufferStream::is_alive() const { return true; }
+
 inline ssize_t BufferStream::read(char *ptr, size_t size) {
 #if defined(_MSC_VER) && _MSC_VER < 1910
   auto len_read = buffer._Copy_s(ptr, size, size, position);
@@ -6634,6 +6646,7 @@ Server::process_request(Stream &strm, bool close_connection,
   Request req;
 
   Response res;
+  res.is_alive = [&strm]() { return strm.is_alive(); };
   res.version = "HTTP/1.1";
   res.headers = default_headers_;
 
@@ -8348,6 +8361,10 @@ inline bool SSLSocketStream::is_writable() const {
          is_socket_alive(sock_);
 }
 
+inline bool SSLSocketStream::is_alive() const {
+  return is_socket_alive(sock_);
+}
+
 inline ssize_t SSLSocketStream::read(char *ptr, size_t size) {
   if (SSL_pending(ssl_) > 0) {
     return SSL_read(ssl_, ptr, static_cast<int>(size));
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index f343cc252f89a..01998eabe61a0 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -104,6 +104,7 @@ struct server_task {
     json data;
 
     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
+    std::function<bool()> is_alive;
 
     // utility function
     static std::unordered_set<int> get_list_id(const std::vector<server_task> & tasks) {
@@ -173,7 +174,7 @@ struct server_slot {
     std::vector<completion_token_output> generated_token_probs;
 
     server_task_cmpl_type cmpl_type = SERVER_TASK_CMPL_TYPE_NORMAL;
-
+    std::function<bool()> is_alive;
     bool has_next_token = true;
     bool truncated = false;
     bool stopped_eos = false;
@@ -876,6 +877,7 @@ struct server_context {
        // Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
        auto default_sparams = params.sparams;
        const auto & data = task.data;
+       slot.is_alive = task.is_alive;
 
        if (data.count("__oaicompat") != 0) {
            slot.oaicompat = true;
@@ -1117,6 +1119,13 @@ struct server_context {
     }
 
     bool process_token(completion_token_output & result, server_slot & slot) {
+        if (slot.is_alive && !slot.is_alive()) {
+            slot.truncated = false;
+            slot.has_next_token = false;
+
+            SLT_DBG(slot, "stopped by client disconnection, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.params.n_predict);
+            return slot.has_next_token;
+        }
         // remember which tokens were sampled - used for repetition penalties during sampling
         const std::string token_str = llama_token_to_piece(ctx, result.tok, params.special);
         slot.sampled = result.tok;
@@ -1461,13 +1470,14 @@ struct server_context {
     // Functions to create new task(s) and receive result(s)
     //
 
-    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type) {
+    std::vector<server_task> create_tasks_cmpl(json data, server_task_cmpl_type cmpl_type, const std::function<bool()> & is_alive) {
         std::vector<server_task> tasks;
         auto create_task = [&](json & task_data, bool replace_prompt, json prompt) {
             server_task task;
             task.id = queue_tasks.get_new_id();
             task.cmpl_type = cmpl_type;
             task.type = SERVER_TASK_TYPE_COMPLETION;
+            task.is_alive = is_alive;
             if (replace_prompt) {
                 task.data = task_data;
                 task.data["prompt"] = std::move(prompt);
@@ -1866,6 +1876,13 @@ struct server_context {
             system_prompt_update();
         }
 
+        for (auto & slot : slots) {
+            if (slot.is_processing() && slot.is_alive && !slot.is_alive()) {
+                SLT_WRN(slot, "%s", "slot connection died\n");
+                slot.release();
+            }
+        }
+
         // check if all slots are idle
         {
             bool all_idle = true;
@@ -2337,6 +2354,10 @@ struct server_context {
                 }
 
                 completion_token_output result;
+                if (params.testing_sampler_delay_millis > 0) {
+                    SRV_DBG("sleeping for %dms before sampling (for tests!)\n", params.testing_sampler_delay_millis);
+                    std::this_thread::sleep_for(std::chrono::milliseconds(params.testing_sampler_delay_millis));
+                }
                 const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
 
                 gpt_sampler_accept(slot.smpl, id, true);
@@ -2893,7 +2914,7 @@ int main(int argc, char ** argv) {
             return;
         }
 
-        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type);
+        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, cmpl_type, res.is_alive);
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);
 
@@ -2956,7 +2977,7 @@ int main(int argc, char ** argv) {
 
         json data = oaicompat_completion_params_parse(ctx_server.model, json::parse(req.body), params.chat_template);
 
-        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL);
+        std::vector<server_task> tasks = ctx_server.create_tasks_cmpl(data, SERVER_TASK_CMPL_TYPE_NORMAL, res.is_alive);
         ctx_server.queue_results.add_waiting_tasks(tasks);
         ctx_server.queue_tasks.post(tasks);
 
@@ -3099,7 +3120,7 @@ int main(int argc, char ** argv) {
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING);
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_EMBEDDING, res.is_alive);
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);
 
@@ -3176,7 +3197,7 @@ int main(int argc, char ** argv) {
         json responses = json::array();
         bool error = false;
         {
-            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK);
+            std::vector<server_task> tasks = ctx_server.create_tasks_cmpl({{"prompt", prompt}}, SERVER_TASK_CMPL_TYPE_RERANK, res.is_alive);
             ctx_server.queue_results.add_waiting_tasks(tasks);
             ctx_server.queue_tasks.post(tasks);
 
diff --git a/examples/server/tests/features/cancel.feature b/examples/server/tests/features/cancel.feature
new file mode 100644
index 0000000000000..7112367808451
--- /dev/null
+++ b/examples/server/tests/features/cancel.feature
@@ -0,0 +1,57 @@
+@llama.cpp
+@server
+Feature: Cancellation of llama.cpp server requests
+
+  Background: Server startup
+    Given a server listening on localhost:8080
+    And   a model file tinyllamas/stories260K.gguf from HF repo ggml-org/models
+    And   a model file test-model.gguf
+    And   a model alias tinyllama-2
+    And   BOS token is 1
+    And   42 as server seed
+    # KV Cache corresponds to the total amount of tokens
+    # that can be stored across all independent sequences: #4130
+    # see --ctx-size and #5568
+    And   256 KV cache size
+    And   32 as batch size
+    And   2 slots
+    And   64 server max tokens to predict
+    And   prometheus compatible metrics exposed
+    And   300 milliseconds delay in sampler for testing
+    Then  the server is starting
+    Then  the server is healthy
+
+
+  Scenario Outline: Cancelling an OAI chat completion request frees up slot (streaming <enable_streaming>)
+    Given a model llama-2
+    And   a user prompt Once upon a time
+    And   a system prompt You tell lengthy stories
+    And   256 max tokens to predict
+    And   256 server max tokens to predict
+    And   streaming is <enable_streaming>
+    And   disconnect after 100 milliseconds
+    Given concurrent OAI completions requests
+    And   wait for 700 milliseconds
+    Then  all slots are idle
+
+    Examples: Prompts
+      | enable_streaming |
+      | disabled         |
+      | enabled          |
+
+
+  Scenario Outline: Cancelling a completion request frees up slot (streaming <enable_streaming>)
+    Given a model llama-2
+    Given a prompt Once upon a time
+    And   256 max tokens to predict
+    And   256 server max tokens to predict
+    And   streaming is <enable_streaming>
+    And   disconnect after 100 milliseconds
+    Given a completion request with no api error
+    And   wait for 700 milliseconds
+    Then  all slots are idle
+
+    Examples: Prompts
+      | enable_streaming |
+      | disabled         |
+      | enabled          |
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index 2611614ba3633..cc3107b2b773f 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -78,7 +78,9 @@ def step_server_config(context, server_fqdn: str, server_port: str):
     context.response_format = None
     context.temperature = None
     context.lora_file = None
+    context.testing_sampler_delay_millis = None
     context.disable_ctx_shift = False
+    context.disconnect_after_millis = None
 
     context.tasks_result = []
     context.concurrent_tasks = []
@@ -278,6 +280,7 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
                                           n_predict=context.n_predict,
                                           cache_prompt=context.cache_prompt,
                                           id_slot=context.id_slot,
+                                          disconnect_after_millis=context.disconnect_after_millis,
                                           expect_api_error=expect_api_error,
                                           user_api_key=context.user_api_key,
                                           temperature=context.temperature)
@@ -290,6 +293,17 @@ async def step_request_completion(context, api_error: Literal['raised'] | str):
         api_error_code = int(api_error)
         assert completion == api_error_code, f"completion must be an {api_error_code} status code: {completion}"
 
+@step('wait for {millis:d} milliseconds')
+@async_run_until_complete
+async def step_request_completion(context, millis: int):
+    await asyncio.sleep(millis / 1000.0)
+
+
+@step('disconnect after {disconnect_after_millis:d} milliseconds')
+@async_run_until_complete
+async def step_disconnect_after(context, disconnect_after_millis: int):
+    context.disconnect_after_millis = disconnect_after_millis
+
 
 @step('{predicted_n:d} tokens are predicted matching {re_content}')
 def step_n_tokens_predicted_with_content(context, predicted_n, re_content):
@@ -455,6 +469,9 @@ def step_impl(context, n_ga):
 def step_impl(context, n_ga_w):
     context.n_ga_w = n_ga_w
 
+@step('{testing_sampler_delay_millis:d} milliseconds delay in sampler for testing')
+def step_testing_sampler_delay_millis(context, testing_sampler_delay_millis):
+    context.testing_sampler_delay_millis = testing_sampler_delay_millis
 
 @step('a passkey prompt template')
 def step_prompt_passkey(context):
@@ -495,7 +512,7 @@ async def step_oai_chat_completions(context, api_error):
     if context.debug:
         print(f"Submitting OAI compatible completions request...")
     expect_api_error = api_error == 'raised'
-    seeds = await completions_seed(context, num_seeds=1),
+    seeds = await completions_seed(context, num_seeds=1)
     completion = await oai_chat_completions(context.prompts.pop(),
                                             seeds[0] if seeds is not None else seeds,
                                             context.system_prompt,
@@ -516,6 +533,8 @@ async def step_oai_chat_completions(context, api_error):
                                             user_api_key=context.user_api_key
                                             if hasattr(context, 'user_api_key') else None,
 
+                                            disconnect_after_millis=context.disconnect_after_millis,
+
                                             expect_api_error=expect_api_error)
     context.tasks_result.append(completion)
     if context.debug:
@@ -583,6 +602,7 @@ async def step_oai_chat_completions(context):
                                                 if hasattr(context, 'enable_streaming') else None,
                                                 response_format=context.response_format
                                                 if hasattr(context, 'response_format') else None,
+                                                disconnect_after_millis=context.disconnect_after_millis,
                                                 user_api_key=context.user_api_key
                                                 if hasattr(context, 'user_api_key') else None)
 
@@ -978,9 +998,10 @@ async def request_completion(prompt,
                              id_slot=None,
                              expect_api_error=None,
                              user_api_key=None,
+                             disconnect_after_millis=None,
                              temperature=None) -> int | dict[str, Any]:
     if debug:
-        print(f"Sending completion request: {prompt}")
{prompt}") + print(f"Sending completion request: {prompt} with n_predict={n_predict}") origin = "my.super.domain" headers = { 'Origin': origin @@ -991,6 +1012,9 @@ async def request_completion(prompt, headers['Authorization'] = f'Bearer {user_api_key}' async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: + if disconnect_after_millis is not None: + await asyncio.sleep(disconnect_after_millis / 1000.0) + return 0 async with session.post(f'{base_url}/completion', json={ "input_prefix": prompt_prefix, @@ -1022,6 +1046,7 @@ async def oai_chat_completions(user_prompt, temperature=None, model=None, n_predict=None, + disconnect_after_millis=None, enable_streaming=None, response_format=None, user_api_key=None, @@ -1062,6 +1087,9 @@ async def oai_chat_completions(user_prompt, origin = 'llama.cpp' headers = {'Authorization': f'Bearer {user_api_key}', 'Origin': origin} async with aiohttp.ClientSession(timeout=DEFAULT_TIMEOUT_SECONDS) as session: + if disconnect_after_millis is not None: + await asyncio.sleep(disconnect_after_millis / 1000.0) + return 0 async with session.post(f'{base_url}{base_path}', json=payload, headers=headers) as response: @@ -1105,6 +1133,7 @@ async def oai_chat_completions(user_prompt, else: return response.status else: + assert disconnect_after_millis is None, "disconnect_after_millis is not supported with sync client" try: openai.api_key = user_api_key openai.base_url = f'{base_url}{base_path.removesuffix("chat")}' @@ -1348,7 +1377,7 @@ async def request_slots_status(context, expected_slots): def assert_slots_status(slots, expected_slots): - assert len(slots) == len(expected_slots) + assert len(slots) == len(expected_slots), f'invalid number of slots: {len(slots)} (actual) != {len(expected_slots)} (expected)' for slot_id, (expected, slot) in enumerate(zip(expected_slots, slots)): for key in expected: assert expected[key] == slot[key], (f"invalid slot {slot_id}" @@ -1436,6 +1465,8 @@ def start_server_background(context): server_args.append('--verbose') if context.lora_file: server_args.extend(['--lora', context.lora_file]) + if context.testing_sampler_delay_millis: + server_args.extend(['--testing-sampler-delay-millis', context.testing_sampler_delay_millis]) if context.disable_ctx_shift: server_args.extend(['--no-context-shift'])