diff --git a/common/arg.cpp b/common/arg.cpp
index 0d0daa3610105..810faa8262a70 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -2776,6 +2776,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             key_file.close();
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
+    add_opt(common_arg(
+        {"--alias-presets-file"}, "FNAME",
+        "path to file containing alias preset configurations (default: none)",
+        [](common_params & params, const std::string & value) {
+            params.alias_presets_file = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--ssl-key-file"}, "FNAME",
         "path to file a PEM-encoded SSL private key",
diff --git a/common/common.h b/common/common.h
index f26724b6e1495..c35ba2fe9be73 100644
--- a/common/common.h
+++ b/common/common.h
@@ -377,6 +377,7 @@ struct common_params {
     std::string ssl_file_key  = ""; // NOLINT
     std::string ssl_file_cert = ""; // NOLINT
+    std::string alias_presets_file = ""; // NOLINT

     // "advanced" endpoints are disabled by default for better security
     bool webui = true;
diff --git a/tools/server/README.md b/tools/server/README.md
index 06533c172e530..181510aaaaf6f 100644
--- a/tools/server/README.md
+++ b/tools/server/README.md
@@ -490,6 +490,8 @@ These words will not be included in the completion, so make sure to add them to

 `lora`: A list of LoRA adapters to be applied to this specific request. Each object in the list must contain `id` and `scale` fields. For example: `[{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.1}]`. If a LoRA adapter is not specified in the list, its scale will default to `0.0`. Please note that requests with different LoRA configurations will not be batched together, which may result in performance degradation.

+`alias-presets-file`: A JSON file mapping model aliases to parameter presets, e.g. `{ "llama-low": {"temperature": 0.1}, "llama-high": {"temperature": 1.0} }`. If the request specifies a `model` that has a preset, the preset is applied before the completion is handled. If a parameter is set both in the request and in the preset, the request's value takes precedence.
+
 **Response format**

- Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
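The README paragraph above specifies the merge rule: request parameters win over preset parameters. A minimal Python sketch of those semantics, assuming a hypothetical `apply_preset` helper and an in-memory `presets` table; the server implements the same logic in C++ in `server.cpp` below:

```python
import json

# hypothetical preset table, as it might be loaded from --alias-presets-file
presets = {
    "llama-low":  {"temperature": 0.1},
    "llama-high": {"temperature": 1.0},
}

def apply_preset(request: dict) -> dict:
    """Fill in preset values for request["model"]; keys set by the request win."""
    preset = presets.get(request.get("model", ""), {})
    merged = dict(request)
    for key, value in preset.items():
        merged.setdefault(key, value)  # only add keys the request did not set
    return merged

# the request's explicit temperature (0.9) overrides the preset's 0.1
print(json.dumps(apply_preset({"model": "llama-low", "prompt": "Hi", "temperature": 0.9})))
```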
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 2e78dcd7bf1da..1812c5b75205e 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -23,6 +23,7 @@
 #include <condition_variable>
 #include <cstddef>
 #include <deque>
+#include <fstream>
 #include <memory>
 #include <mutex>
 #include <signal.h>
@@ -1886,6 +1887,8 @@ struct server_context {
     common_chat_templates_ptr chat_templates;

     oaicompat_parser_options oai_parser_opt;

+    std::unordered_map<std::string, json> model_alias_presets;
+
     ~server_context() {
         mtmd_free(mctx);
@@ -1906,6 +1909,33 @@ struct server_context {
         llama_batch_free(batch);
     }

+    void load_model_alias_presets(const std::string & alias_presets_file) {
+        try {
+            std::ifstream file(alias_presets_file);
+            if (!file) {
+                SRV_ERR("failed to open alias presets file '%s'\n", alias_presets_file.c_str());
+                return;
+            }
+
+            json presets_json;
+            file >> presets_json;
+            file.close();
+
+            for (const auto & [model_alias_name, preset] : presets_json.items()) {
+                if (preset.is_object()) {
+                    model_alias_presets[model_alias_name] = preset;
+                    SRV_INF("loaded preset for model alias '%s'\n", model_alias_name.c_str());
+                } else {
+                    SRV_WRN("skipping invalid preset for model alias '%s' (not an object)\n", model_alias_name.c_str());
+                }
+            }
+
+            SRV_INF("loaded %zu model alias presets from '%s'\n", model_alias_presets.size(), alias_presets_file.c_str());
+        } catch (const std::exception & e) {
+            SRV_ERR("failed to parse alias presets file '%s': %s\n", alias_presets_file.c_str(), e.what());
+        }
+    }
+
     bool load_model(const common_params & params) {
         SRV_INF("loading model '%s'\n", params.model.path.c_str());
@@ -2023,6 +2053,10 @@ struct server_context {
             }
         }

+        if (!params_base.alias_presets_file.empty()) {
+            load_model_alias_presets(params_base.alias_presets_file);
+        }
+
         return true;
     }
@@ -4181,6 +4215,17 @@ int main(int argc, char ** argv) {
             return;
         }

+        // apply presets if available
+        const std::string model_alias = json_value(data, "model", std::string());
+        if (!model_alias.empty() && ctx_server.model_alias_presets.find(model_alias) != ctx_server.model_alias_presets.end()) {
+            const auto & preset = ctx_server.model_alias_presets.at(model_alias);
+            for (const auto & [key, value] : preset.items()) {
+                if (!data.contains(key)) {
+                    data[key] = value;
+                }
+            }
+        }
+
         auto completion_id = gen_chatcmplid();
         std::unordered_set<int> task_ids;
         try {
@@ -4245,6 +4290,8 @@ int main(int argc, char ** argv) {
             }
         }
+
+        tasks.reserve(inputs.size());
         for (size_t i = 0; i < inputs.size(); i++) {
             server_task task = server_task(type);
diff --git a/tools/server/tests/unit/test_alias_presets.py b/tools/server/tests/unit/test_alias_presets.py
new file mode 100644
index 0000000000000..1ca9d5f7bb0e7
--- /dev/null
+++ b/tools/server/tests/unit/test_alias_presets.py
@@ -0,0 +1,139 @@
+import json
+import os
+import tempfile
+from pathlib import Path
+import sys
+
+import pytest
+
+# ensure grandparent path is in sys.path
+path = Path(__file__).resolve().parents[1]
+sys.path.insert(0, str(path))
+
+from utils import *
+
+server = ServerPreset.stories15m_moe()
+
+LORA_FILE_URL = "https://huggingface.co/ggml-org/stories15M_MOE/resolve/main/moe_shakespeare15M.gguf"
+
+
+@pytest.fixture(scope="module", autouse=True)
+def create_server():
+    global server
+    server = ServerPreset.stories15m_moe()
+    server.lora_files = [download_file(LORA_FILE_URL)]
+
+
+def test_alias_presets_per_request():
+    global server
+    server.n_slots = 4
+
+    preset_data = {
+        "bedtime-stories": {
+            "lora": [{"id": 0, "scale": 0.0}]
+        },
+        "shakespeare-light": {
+            "lora": [{"id": 0, "scale": 0.3}]
+        },
+        "shakespeare-medium": {
+            "lora": [{"id": 0, "scale": 0.7}]
+        },
+        "shakespeare-full": {
+            "lora": [{"id": 0, "scale": 1.0}]
+        }
+    }
+
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+        json.dump(preset_data, f)
+        preset_file_path = f.name
+
+    try:
+        server.alias_presets_file = preset_file_path
+        server.start()
+
+        # run the same prompt with different model aliases, all in parallel;
+        # each prompt will be processed by a different slot
+        prompt = "Look in thy glass"
+        alias_config = [
+            ("bedtime-stories", "(bright|day|many|happy)+"),
+            ("bedtime-stories", "(bright|day|many|happy)+"),
+            ("shakespeare-light", "(special|thing|gifted)+"),
+            ("shakespeare-medium", "(far|from|home|away)+"),
+            ("shakespeare-full", "(eye|love|glass|sun)+"),
+            ("shakespeare-full", "(eye|love|glass|sun)+"),
+        ]
+
+        tasks = [(
+            server.make_request,
+            ("POST", "/completions", {
+                "model": model_alias,
+                "prompt": prompt,
+                "seed": 42,
+                "temperature": 0.0,
+                "cache_prompt": False,
+            })
+        ) for model_alias, _ in alias_config]
+        results = parallel_function_calls(tasks)
+
+        assert all([res.status_code == 200 for res in results])
+        for res, (_, re_test) in zip(results, alias_config):
+            assert match_regex(re_test, res.body["content"])
+
+    finally:
+        server.stop()
+        os.unlink(preset_file_path)
+
+
+def test_alias_override():
+    # test whether we honor the user's override even in case a preset is set
+    global server
+    server.n_slots = 2
+
+    # use the same preset data as test_alias_presets_per_request
+    preset_data = {
+        "bedtime-stories": {
+            "lora": [{"id": 0, "scale": 0.0}]
+        },
+        "shakespeare-light": {
+            "lora": [{"id": 0, "scale": 0.3}]
+        },
+        "shakespeare-medium": {
+            "lora": [{"id": 0, "scale": 0.7}]
+        },
+        "shakespeare-full": {
+            "lora": [{"id": 0, "scale": 1.0}]
+        }
+    }
+
+    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
+        json.dump(preset_data, f)
+        preset_file_path = f.name
+
+    try:
+        server.alias_presets_file = preset_file_path
+        server.start()
+
+        prompt = "Look in thy glass"
+
+        res1 = server.make_request("POST", "/completions", {
+            "model": "bedtime-stories",
+            "prompt": prompt,
+            "cache_prompt": False,
+        })
+
+        # override to shakespeare
+        res2 = server.make_request("POST", "/completions", {
+            "model": "bedtime-stories",
+            "prompt": prompt,
+            "cache_prompt": False,
+            "lora": [{"id": 0, "scale": 1.0}],
+        })
+
+        assert res1.status_code == 200
+        assert res2.status_code == 200
+
+        assert match_regex("(bright|day|many|happy)+", res1.body["content"])
+        assert match_regex("(eye|love|glass|sun)+", res2.body["content"])
+        assert res1.body["content"] != res2.body["content"]
+
+    finally:
+        server.stop()
+        os.unlink(preset_file_path)
diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py
index bc547ca03bf1b..7d6564d6f2578 100644
--- a/tools/server/tests/utils.py
+++ b/tools/server/tests/utils.py
@@ -88,6 +88,7 @@ class ServerProcess:
     reasoning_budget: int | None = None
    chat_template: str | None = None
    chat_template_file: str | None = None
+    alias_presets_file: str | None = None
    server_path: str | None = None
    mmproj_url: str | None = None
@@ -198,6 +199,8 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
             server_args.extend(["--chat-template", self.chat_template])
         if self.chat_template_file:
             server_args.extend(["--chat-template-file", self.chat_template_file])
+        if self.alias_presets_file:
+            server_args.extend(["--alias-presets-file", self.alias_presets_file])
         if self.mmproj_url:
             server_args.extend(["--mmproj-url", self.mmproj_url])