diff --git a/ci/loadWin.groovy b/ci/loadWin.groovy index 29b93e1df0..213f862c18 100644 --- a/ci/loadWin.groovy +++ b/ci/loadWin.groovy @@ -111,7 +111,7 @@ def install_dependencies() { def clean() { def output1 = bat(returnStdout: true, script: 'windows_clean_build.bat ' + get_short_bazel_path() + ' ' + env.OVMS_CLEAN_EXPUNGE) if(fileExists('dist\\windows\\ovms')){ - def status_del = bat(returnStatus: true, script: 'rmdir /s /q ovms') + def status_del = bat(returnStatus: true, script: 'rmdir /s /q dist\\windows\\ovms') if (status_del != 0) { error "Error: Deleting existing ovms directory failed ${status_del}. Check pipeline.log for details." } else { diff --git a/docs/parameters.md b/docs/parameters.md index 4b4663f9ec..95b4bd5885 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -56,6 +56,7 @@ Configuration options for the server are defined only via command-line options a | `allowed_headers` | `string` (default: *) | Comma-separated list of allowed headers in CORS requests. | | `allowed_methods` | `string` (default: *) | Comma-separated list of allowed methods in CORS requests. | | `allowed_origins` | `string` (default: *) | Comma-separated list of allowed origins in CORS requests. | +| `api_key_file` | `string` | Path to the text file with the API key for generative endpoints `/v3/`. The value of the first line is used. If not specified, the server uses the environment variable API_KEY. If neither is set, requests will not require authorization.| ## Config management mode options diff --git a/docs/security_considerations.md b/docs/security_considerations.md index 6d6525836a..ebad42d54d 100644 --- a/docs/security_considerations.md +++ b/docs/security_considerations.md @@ -5,13 +5,9 @@ By default, the OpenVINO Model Server containers start with the security context of a local account `ovms` with Linux UID 5000. This ensures the Docker container does not have elevated permissions on the host machine. 
This is in line with best practices to use minimal permissions when running containerized applications. You can change the security context by adding the `--user` parameter to the Docker run command. This may be needed for loading mounted models with restricted access. For additional security hardening, you might also consider preventing write operations on the container root filesystem by adding a `--read-only` flag. This prevents undesired modification of the container files. In case the cloud storage used for the model repository (S3, Google Storage, or Azure storage) is restricting the root filesystem, it should be combined with `--tmpfs /tmp` flag. -```bash -mkdir -p models/resnet/1 -wget -P models/resnet/1 https://storage.openvinotoolkit.org/repositories/open_model_zoo/2022.1/models_bin/2/resnet50-binary-0001/FP32-INT1/resnet50-binary-0001.bin -wget -P models/resnet/1 https://storage.openvinotoolkit.org/repositories/open_model_zoo/2022.1/models_bin/2/resnet50-binary-0001/FP32-INT1/resnet50-binary-0001.xml - -docker run --rm -d --user $(id -u):$(id -g) --read-only --tmpfs /tmp -v ${PWD}/models/:/models -p 9178:9178 openvino/model_server:latest \ ---model_path /models/resnet/ --model_name resnet --port 9178 +``` +docker run --rm -d --user $(id -u):$(id -g) --read-only --tmpfs /tmp -p 9000:9000 openvino/model_server:latest \ +--model_path s3://bucket/model --model_name model --port 9000 ``` --- @@ -21,11 +17,16 @@ See also: - [Securing OVMS with NGINX](../extras/nginx-mtls-auth/README.md) - [Securing models with OVSA](https://docs.openvino.ai/2025/about-openvino/openvino-ecosystem/openvino-project/openvino-security-add-on.html) +--- +Generative endpoints starting with `/v3`, might be restricted with authorization and API key. It can be set during the server initialization with a parameter `api_key_file` or environment variable `API_KEY`. +The `api_key_file` should contain a path to the file containing the value of API key. 
The content of the file's first line is used. If the parameter api_key_file and the variable API_KEY are not set, the server will not require any authorization. The client should send the API key inside the `Authorization` header as `Bearer <API_KEY>`. + --- OpenVINO Model Server has a set of mechanisms preventing denial of service attacks from the client applications. They include the following: - setting the number of inference execution streams which can limit the number of parallel inference calls in progress for each model. It can be tuned with `NUM_STREAMS` or `PERFORMANCE_HINT` plugin config. - setting the maximum number of gRPC threads which is, by default, configured to the number 8 * number_of_cores. It can be changed with the parameter `--grpc_max_threads`. +- setting the maximum number of REST workers which is, by default, configured to the number of CPU cores. It can be changed with the parameter `--rest_workers`. - maximum size of REST and GRPC message which is 1GB - bigger messages will be rejected - setting max_concurrent_streams which defines how many concurrent threads can be initiated from a single client - the remaining will be queued. The default is equal to the number of CPU cores. It can be changed with the `--grpc_channel_arguments grpc.max_concurrent_streams=8`. - setting the gRPC memory quota for the requests buffer - the default is 2GB. It can be changed with `--grpc_memory_quota=2147483648`. Value `0` invalidates the quota. 
diff --git a/src/capi_frontend/server_settings.hpp b/src/capi_frontend/server_settings.hpp index aacb394e93..4b0700b1f1 100644 --- a/src/capi_frontend/server_settings.hpp +++ b/src/capi_frontend/server_settings.hpp @@ -178,6 +178,7 @@ struct ServerSettingsImpl { std::string allowedOrigins{"*"}; std::string allowedMethods{"*"}; std::string allowedHeaders{"*"}; + std::string apiKey; #ifdef MTR_ENABLED std::string tracePath; #endif diff --git a/src/cli_parser.cpp b/src/cli_parser.cpp index 9bcf1ba716..b004afb2ca 100644 --- a/src/cli_parser.cpp +++ b/src/cli_parser.cpp @@ -15,6 +15,7 @@ //***************************************************************************** #include "cli_parser.hpp" +#include #include #include #include @@ -35,6 +36,7 @@ namespace ovms { constexpr const char* CONFIG_MANAGEMENT_HELP_GROUP{"config management"}; +constexpr const char* API_KEY_ENV_VAR{"API_KEY"}; std::string getConfigPath(const std::string& configPath) { bool isDir = false; @@ -160,7 +162,11 @@ void CLIParser::parse(int argc, char** argv) { ("allowed_headers", "Comma separated list of headers that are allowed to access the API. Default: *.", cxxopts::value()->default_value("*"), - "ALLOWED_HEADERS"); + "ALLOWED_HEADERS") + ("api_key_file", + "path to the text file containing API key for authentication for generative endpoints. 
If not set, authentication is disabled.", + cxxopts::value()->default_value(""), + "API_KEY"); options->add_options("multi model") ("config_path", @@ -493,6 +499,31 @@ void CLIParser::prepareServer(ServerSettingsImpl& serverSettings) { serverSettings.allowedOrigins = result->operator[]("allowed_origins").as(); serverSettings.allowedMethods = result->operator[]("allowed_methods").as(); serverSettings.allowedHeaders = result->operator[]("allowed_headers").as(); + std::filesystem::path apiKeyFile = result->operator[]("api_key_file").as(); + serverSettings.apiKey = ""; + if (!apiKeyFile.empty()) { + std::ifstream file(apiKeyFile); + if (file.is_open()) { + std::getline(file, serverSettings.apiKey); + // Use first line and trim whitespace characters from both ends + size_t endpos = serverSettings.apiKey.find_last_not_of(" \n\r\t"); + if (endpos != std::string::npos) { + serverSettings.apiKey = serverSettings.apiKey.substr(0, endpos + 1); + } + file.close(); + } else { + std::cerr << "Error reading API key file: Unable to open file " << apiKeyFile << std::endl; + exit(OVMS_EX_USAGE); + } + } else { + const char* envApiKey = std::getenv(API_KEY_ENV_VAR); + if (envApiKey != nullptr) { + serverSettings.apiKey = envApiKey; + } + if (serverSettings.apiKey.empty()) { + std::cout << "Info: API key not provided via --api_key_file or API_KEY environment variable. Authentication will be disabled." 
<< std::endl; + } + } } void CLIParser::prepareModel(ModelsSettingsImpl& modelsSettings, HFSettingsImpl& hfSettings) { diff --git a/src/config.cpp b/src/config.cpp index 35b906bb9c..59d5498117 100644 --- a/src/config.cpp +++ b/src/config.cpp @@ -37,11 +37,7 @@ const uint32_t WIN_MAX_GRPC_WORKERS = 1; const uint32_t MAX_PORT_NUMBER = std::numeric_limits::max(); // For drogon, we need to minimize the number of default workers since this value is set for both: unary and streaming (making it always double) -#if (USE_DROGON == 0) -const uint64_t DEFAULT_REST_WORKERS = AVAILABLE_CORES * 4.0; -#else const uint64_t DEFAULT_REST_WORKERS = AVAILABLE_CORES; -#endif const uint32_t DEFAULT_GRPC_MAX_THREADS = AVAILABLE_CORES * 8.0; const size_t DEFAULT_GRPC_MEMORY_QUOTA = (size_t)2 * 1024 * 1024 * 1024; // 2GB const uint64_t MAX_REST_WORKERS = 10'000; @@ -370,5 +366,6 @@ const std::string& Config::allowedOrigins() const { return this->serverSettings. const std::string& Config::allowedMethods() const { return this->serverSettings.allowedMethods; } const std::string& Config::allowedHeaders() const { return this->serverSettings.allowedHeaders; } const std::string Config::cacheDir() const { return this->serverSettings.cacheDir; } +const std::string& Config::apiKey() const { return this->serverSettings.apiKey; } } // namespace ovms diff --git a/src/config.hpp b/src/config.hpp index 1d98e28163..d05c0fad8e 100644 --- a/src/config.hpp +++ b/src/config.hpp @@ -323,6 +323,7 @@ class Config { const std::string& allowedOrigins() const; const std::string& allowedMethods() const; const std::string& allowedHeaders() const; + const std::string& apiKey() const; /** * @brief Model cache directory diff --git a/src/http_rest_api_handler.cpp b/src/http_rest_api_handler.cpp index b22284c3fb..d16795e9b5 100644 --- a/src/http_rest_api_handler.cpp +++ b/src/http_rest_api_handler.cpp @@ -15,6 +15,7 @@ //***************************************************************************** #include 
"http_rest_api_handler.hpp" +#include #include #include #include @@ -123,7 +124,8 @@ const std::string HttpRestApiHandler::v3_RegexExp = const std::string HttpRestApiHandler::metricsRegexExp = R"((.?)\/metrics(\?(.*))?)"; -HttpRestApiHandler::HttpRestApiHandler(ovms::Server& ovmsServer, int timeout_in_ms) : +HttpRestApiHandler::HttpRestApiHandler(ovms::Server& ovmsServer, int timeout_in_ms, const std::string& apiKey) : + apiKey(apiKey), predictionRegex(predictionRegexExp), modelstatusRegex(modelstatusRegexExp), configReloadRegex(configReloadRegexExp), @@ -668,6 +670,24 @@ Status HttpRestApiHandler::processListModelsRequest(std::string& response) { return StatusCode::OK; } +bool HttpRestApiHandler::isAuthorized(const std::unordered_map& headers, const std::string& apiKey) { + std::unordered_map lowercaseHeaders; + for (const auto& [key, value] : headers) { + std::string lowercaseKey = key; + std::transform(lowercaseKey.begin(), lowercaseKey.end(), lowercaseKey.begin(), ::tolower); + if (lowercaseKey == "authorization") { + if (value == "Bearer " + apiKey) { + return true; + } else { + SPDLOG_DEBUG("Unauthorized request - invalid API key provided."); + return false; + } + } + } + SPDLOG_DEBUG("Unauthorized request - missing API key"); + return false; +} + Status HttpRestApiHandler::processV3(const std::string_view uri, const HttpRequestComponents& request_components, std::string& response, const std::string& request_body, std::shared_ptr serverReaderWriter, std::shared_ptr multiPartParser) { #if (MEDIAPIPE_DISABLE == 0) OVMS_PROFILE_FUNCTION(); @@ -675,7 +695,11 @@ Status HttpRestApiHandler::processV3(const std::string_view uri, const HttpReque HttpPayload request; std::string modelName; bool streamFieldVal = false; - + if (!this->apiKey.empty()) { + if (!isAuthorized(request_components.headers, this->apiKey)) { + return StatusCode::UNAUTHORIZED; + } + } auto status = createV3HttpPayload(uri, request_components, response, request_body, serverReaderWriter, 
std::move(multiPartParser), request, modelName, streamFieldVal); if (!status.ok()) { SPDLOG_DEBUG("Failed to create V3 payload: {}", status.string()); diff --git a/src/http_rest_api_handler.hpp b/src/http_rest_api_handler.hpp index 6589f059f7..6961e678da 100644 --- a/src/http_rest_api_handler.hpp +++ b/src/http_rest_api_handler.hpp @@ -115,7 +115,7 @@ class HttpRestApiHandler { * * @param timeout_in_ms */ - HttpRestApiHandler(ovms::Server& ovmsServer, int timeout_in_ms); + HttpRestApiHandler(ovms::Server& ovmsServer, int timeout_in_ms, const std::string& apiKey = ""); Status parseRequestComponents(HttpRequestComponents& components, const std::string_view http_method, @@ -241,6 +241,8 @@ class HttpRestApiHandler { Status processV3(const std::string_view uri, const HttpRequestComponents& request_components, std::string& response, const std::string& request_body, std::shared_ptr serverReaderWriter, std::shared_ptr multiPartParser); Status processListModelsRequest(std::string& response); Status processRetrieveModelRequest(const std::string& name, std::string& response); + bool isAuthorized(const std::unordered_map& headers, const std::string& apiKey); + const std::string apiKey; private: const std::regex predictionRegex; diff --git a/src/http_server.cpp b/src/http_server.cpp index 5d57eefc90..4cb05cc286 100644 --- a/src/http_server.cpp +++ b/src/http_server.cpp @@ -37,26 +37,10 @@ #include "logging.hpp" #include "status.hpp" -#if (USE_DROGON == 0) -#pragma warning(push) -#pragma warning(disable : 4624 6001 6385 6386 6326 6011 4457 6308 6387 6246) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wall" -#pragma GCC diagnostic ignored "-Wunused-but-set-variable" -#include "tensorflow_serving/util/net_http/public/response_code_enum.h" -#include "tensorflow_serving/util/net_http/server/public/httpserver.h" -#include "tensorflow_serving/util/net_http/server/public/server_request_interface.h" -#include "tensorflow_serving/util/threadpool_executor.h" - -#include 
"net_http_async_writer_impl.hpp" -#pragma GCC diagnostic pop -#pragma warning(pop) -#else #include #include "drogon_http_async_writer_impl.hpp" #include "http_frontend/multi_part_parser_drogon_impl.hpp" // At this point there is no going back to net_http -#endif namespace ovms { @@ -184,10 +168,9 @@ static const ovms::HTTPStatusCode http(const ovms::Status& status) { } } -#if (USE_DROGON == 1) -std::unique_ptr createAndStartDrogonHttpServer(const std::string& address, int port, int num_threads, ovms::Server& ovmsServer, int timeout_in_ms) { +std::unique_ptr createAndStartDrogonHttpServer(const std::string& address, int port, int num_threads, ovms::Server& ovmsServer, const ovms::Config& config, int timeout_in_ms) { auto server = std::make_unique(num_threads, num_threads, port, address); - auto handler = std::make_shared(ovmsServer, timeout_in_ms); + auto handler = std::make_shared(ovmsServer, timeout_in_ms, config.apiKey()); auto& pool = server->getPool(); server->registerRequestDispatcher([handler, &pool](const drogon::HttpRequestPtr& req, std::function drogonResponseInitializeCallback) { SPDLOG_DEBUG("REST request {}", req->getOriginalPath()); @@ -262,128 +245,4 @@ std::unique_ptr createAndStartDrogonHttpServer(const std::stri return server; } -#else - -class RequestExecutor final : public tensorflow::serving::net_http::EventExecutor { -public: - explicit RequestExecutor(int num_threads) : - executor_(tensorflow::Env::Default(), "httprestserver", num_threads) {} - - void Schedule(std::function fn) override { executor_.Schedule(std::move(fn)); } - -private: - tensorflow::serving::ThreadPoolExecutor executor_; -}; - -class RestApiRequestDispatcher { -public: - RestApiRequestDispatcher(ovms::Server& ovmsServer, int timeout_in_ms) { - handler_ = std::make_unique(ovmsServer, timeout_in_ms); - } - - tensorflow::serving::net_http::RequestHandler dispatch(tensorflow::serving::net_http::ServerRequestInterface* req) { - return 
[this](tensorflow::serving::net_http::ServerRequestInterface* req) { - try { - this->processRequest(req); - } catch (...) { - SPDLOG_DEBUG("Exception caught in REST request handler"); - req->ReplyWithStatus(tensorflow::serving::net_http::HTTPStatusCode::ERROR); - } - }; - } - -private: - void parseHeaders(const tensorflow::serving::net_http::ServerRequestInterface* req, std::vector>* headers) { - if (req->GetRequestHeader("Inference-Header-Content-Length").size() > 0) { - std::pair header{"Inference-Header-Content-Length", req->GetRequestHeader("Inference-Header-Content-Length")}; - headers->emplace_back(header); - } - } - void processRequest(tensorflow::serving::net_http::ServerRequestInterface* req) { - SPDLOG_DEBUG("REST request {}", req->uri_path()); - std::string body; - int64_t num_bytes = 0; - auto request_chunk = req->ReadRequestBytes(&num_bytes); - while (request_chunk != nullptr) { - body.append(std::string_view(request_chunk.get(), num_bytes)); - request_chunk = req->ReadRequestBytes(&num_bytes); - } - - std::vector> headers; - parseHeaders(req, &headers); - std::string output; - SPDLOG_DEBUG("Processing HTTP request: {} {} body: {} bytes", - req->http_method(), - req->uri_path(), - body.size()); - HttpResponseComponents responseComponents; - std::shared_ptr writer = std::make_shared(req); - const auto status = handler_->processRequest(req->http_method(), req->uri_path(), body, &headers, &output, responseComponents, writer); - if (status == StatusCode::PARTIAL_END) { - // No further messaging is required. - // Partial responses were delivered via "req" object. 
- return; - } - if (!status.ok() && output.empty()) { - rapidjson::StringBuffer buffer; - rapidjson::Writer writer(buffer); - writer.StartObject(); - writer.String("error"); - writer.String(status.string().c_str()); - writer.EndObject(); - output = buffer.GetString(); - } - const auto http_status = http(status); - if (responseComponents.inferenceHeaderContentLength.has_value()) { - std::pair header{"Inference-Header-Content-Length", std::to_string(responseComponents.inferenceHeaderContentLength.value())}; - headers.emplace_back(header); - } - for (const auto& kv : headers) { - req->OverwriteResponseHeader(kv.first, kv.second); - } - req->WriteResponseString(output); - if (int(http_status) != int(tensorflow::serving::net_http::HTTPStatusCode::OK) && int(http_status) != int(tensorflow::serving::net_http::HTTPStatusCode::CREATED)) { - SPDLOG_DEBUG("Processing HTTP/REST request failed: {} {}. Reason: {}", - req->http_method(), - req->uri_path(), - status.string()); - } - req->ReplyWithStatus(tensorflow::serving::net_http::HTTPStatusCode(http_status)); - } - - std::unique_ptr handler_; -}; - -std::unique_ptr createAndStartNetHttpServer(const std::string& address, int port, int num_threads, ovms::Server& ovmsServer, int timeout_in_ms) { - auto options = std::make_unique(); - options->AddPort(static_cast(port)); - options->SetAddress(address); - options->SetExecutor(std::make_unique(num_threads)); - - auto server = tensorflow::serving::net_http::CreateEvHTTPServer(std::move(options)); - if (server == nullptr) { - SPDLOG_ERROR("Failed to create http server"); - return nullptr; - } - - std::shared_ptr dispatcher = - std::make_shared(ovmsServer, timeout_in_ms); - - tensorflow::serving::net_http::RequestHandlerOptions handler_options; - server->RegisterRequestDispatcher( - [dispatcher](tensorflow::serving::net_http::ServerRequestInterface* req) { - return dispatcher->dispatch(std::move(req)); - }, - handler_options); - - if (server->StartAcceptingRequests()) { - 
SPDLOG_INFO("REST server listening on port {} with {} threads", port, num_threads); - return server; - } - - return nullptr; -} - -#endif - } // namespace ovms diff --git a/src/http_server.hpp b/src/http_server.hpp index 9368e8f35c..a40d7ff9da 100644 --- a/src/http_server.hpp +++ b/src/http_server.hpp @@ -17,30 +17,12 @@ #include #include - -#if (USE_DROGON == 0) -#pragma warning(push) -#pragma warning(disable : 4624 6001 6385 6386 6326 6011 4457 6308 6387 6246) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wall" -#pragma GCC diagnostic ignored "-Wunused-but-set-variable" -#include "tensorflow_serving/util/net_http/public/response_code_enum.h" -#include "tensorflow_serving/util/net_http/server/public/httpserver.h" -#include "tensorflow_serving/util/net_http/server/public/server_request_interface.h" -#include "tensorflow_serving/util/threadpool_executor.h" -#pragma GCC diagnostic pop -#pragma warning(pop) -#else +#include "config.hpp" #include "drogon_http_server.hpp" -#endif namespace ovms { class Server; -#if (USE_DROGON == 0) -std::unique_ptr createAndStartNetHttpServer(const std::string& address, int port, int num_threads, ovms::Server& ovmsServer, int timeout_in_ms = -1); -#else -std::unique_ptr createAndStartDrogonHttpServer(const std::string& address, int port, int num_threads, ovms::Server& ovmsServer, int timeout_in_ms = -1); -#endif +std::unique_ptr createAndStartDrogonHttpServer(const std::string& address, int port, int num_threads, ovms::Server& ovmsServer, const ovms::Config& config, int timeout_in_ms = -1); } // namespace ovms diff --git a/src/httpservermodule.cpp b/src/httpservermodule.cpp index 8e3ffa05ba..663e701a0c 100644 --- a/src/httpservermodule.cpp +++ b/src/httpservermodule.cpp @@ -37,17 +37,8 @@ Status HTTPServerModule::start(const ovms::Config& config) { int workers = config.restWorkers() ? 
config.restWorkers() : 10; SPDLOG_INFO("Will start {} REST workers", workers); -#if (USE_DROGON == 0) - netHttpServer = ovms::createAndStartNetHttpServer(config.restBindAddress(), config.restPort(), workers, this->ovmsServer); - if (netHttpServer == nullptr) { - std::stringstream ss; - ss << "at " << server_address; - auto status = Status(StatusCode::FAILED_TO_START_REST_SERVER, ss.str()); - SPDLOG_ERROR(status.string()); - return status; - } -#else - drogonServer = ovms::createAndStartDrogonHttpServer(config.restBindAddress(), config.restPort(), workers, this->ovmsServer); + + drogonServer = ovms::createAndStartDrogonHttpServer(config.restBindAddress(), config.restPort(), workers, this->ovmsServer, config); if (drogonServer == nullptr) { std::stringstream ss; ss << "at " << server_address; @@ -55,7 +46,7 @@ Status HTTPServerModule::start(const ovms::Config& config) { SPDLOG_ERROR(status.string()); return status; } -#endif + curl_global_init(CURL_GLOBAL_ALL); state = ModuleState::INITIALIZED; SPDLOG_INFO("{} started", HTTP_SERVER_MODULE_NAME); @@ -63,23 +54,12 @@ Status HTTPServerModule::start(const ovms::Config& config) { return StatusCode::OK; } void HTTPServerModule::shutdown() { -#if (USE_DROGON == 0) - if (netHttpServer == nullptr) - return; -#else if (drogonServer == nullptr) return; -#endif SPDLOG_INFO("{} shutting down", HTTP_SERVER_MODULE_NAME); state = ModuleState::STARTED_SHUTDOWN; -#if (USE_DROGON == 0) - netHttpServer->Terminate(); - netHttpServer->WaitForTermination(); - netHttpServer.reset(); -#else drogonServer->terminate(); drogonServer.reset(); -#endif curl_global_cleanup(); SPDLOG_INFO("Shutdown HTTP server"); state = ModuleState::SHUTDOWN; diff --git a/src/httpservermodule.hpp b/src/httpservermodule.hpp index c271a5f5f6..3b6bb89d69 100644 --- a/src/httpservermodule.hpp +++ b/src/httpservermodule.hpp @@ -18,21 +18,7 @@ #include #include -#if (USE_DROGON == 0) -#pragma warning(push) -#pragma warning(disable : 4624 6001 6385 6386 6326 6011 4457 
6308 6387 6246) -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wall" -#pragma GCC diagnostic ignored "-Wunused-but-set-variable" -#include "tensorflow_serving/util/net_http/public/response_code_enum.h" -#include "tensorflow_serving/util/net_http/server/public/httpserver.h" -#include "tensorflow_serving/util/net_http/server/public/server_request_interface.h" -#include "tensorflow_serving/util/threadpool_executor.h" -#pragma GCC diagnostic pop -#pragma warning(pop) -#else #include "drogon_http_server.hpp" -#endif #include "http_server.hpp" #include "module.hpp" @@ -40,11 +26,7 @@ namespace ovms { class Config; class Server; class HTTPServerModule : public Module { -#if (USE_DROGON == 0) - std::unique_ptr netHttpServer; -#else std::unique_ptr drogonServer; -#endif Server& ovmsServer; public: diff --git a/src/server.cpp b/src/server.cpp index e21f12e5fa..6343e658e6 100644 --- a/src/server.cpp +++ b/src/server.cpp @@ -123,6 +123,7 @@ static void logConfig(const Config& config) { SPDLOG_DEBUG("gRPC channel arguments: {}", config.grpcChannelArguments()); SPDLOG_DEBUG("log level: {}", config.logLevel()); SPDLOG_DEBUG("log path: {}", config.logPath()); + SPDLOG_TRACE("API key: {}", config.getServerSettings().apiKey); SPDLOG_DEBUG("file system poll wait milliseconds: {}", config.filesystemPollWaitMilliseconds()); SPDLOG_DEBUG("sequence cleaner poll wait minutes: {}", config.sequenceCleanerPollWaitMinutes()); SPDLOG_DEBUG("model_repository_path: {}", config.getServerSettings().hfSettings.downloadPath); diff --git a/src/status.cpp b/src/status.cpp index bb32a1ccec..6dcfae8e34 100644 --- a/src/status.cpp +++ b/src/status.cpp @@ -133,6 +133,7 @@ const std::unordered_map Status::statusMessageMap = { {StatusCode::UNKNOWN_REQUEST_COMPONENTS_TYPE, "Request components type not recognized"}, {StatusCode::FAILED_TO_PARSE_MULTIPART_CONTENT_TYPE, "Request of multipart type but failed to parse"}, {StatusCode::FAILED_TO_DEDUCE_MODEL_NAME_FROM_URI, "Failed to deduce model 
name from all possible ways"}, + {StatusCode::UNAUTHORIZED, "Unauthorized request due to invalid or missing api-key"}, // Rest parser failure {StatusCode::REST_BODY_IS_NOT_AN_OBJECT, "Request body should be JSON object"}, diff --git a/src/status.hpp b/src/status.hpp index 270a230bb5..39c31c6bfe 100644 --- a/src/status.hpp +++ b/src/status.hpp @@ -176,6 +176,7 @@ enum class StatusCode { UNKNOWN_REQUEST_COMPONENTS_TYPE, /*!< Components type not recognized */ FAILED_TO_PARSE_MULTIPART_CONTENT_TYPE, /*!< Request of multipart type but failed to parse */ FAILED_TO_DEDUCE_MODEL_NAME_FROM_URI, /*!< Failed to deduce model name from all possible ways */ + UNAUTHORIZED, /*!< Unauthorized request due to invalid or missing api-key*/ // REST Parse REST_BODY_IS_NOT_AN_OBJECT, /*!< REST body should be JSON object */ diff --git a/src/test/c_api_stress_tests.cpp b/src/test/c_api_stress_tests.cpp index c7acaf0fe9..1ce77097c7 100644 --- a/src/test/c_api_stress_tests.cpp +++ b/src/test/c_api_stress_tests.cpp @@ -160,11 +160,8 @@ TEST_F(StressCapiConfigChanges, KFSAddNewVersionDuringPredictLoad) { requiredLoadResults, allowedLoadResults); } -#if (USE_DROGON == 0) -TEST_F(StressCapiConfigChanges, GetMetricsDuringLoad) { -#else + TEST_F(StressCapiConfigChanges, DISABLED_GetMetricsDuringLoad) { -#endif bool performWholeConfigReload = false; // we just need to have all model versions rechecked std::set requiredLoadResults = {StatusCode::OK}; // we expect full continuity of operation std::set allowedLoadResults = {}; diff --git a/src/test/ensemble_config_change_stress.cpp b/src/test/ensemble_config_change_stress.cpp index b4031a9fb4..d79125ca43 100644 --- a/src/test/ensemble_config_change_stress.cpp +++ b/src/test/ensemble_config_change_stress.cpp @@ -183,11 +183,8 @@ TEST_F(StressPipelineConfigChanges, KFSAddNewVersionDuringPredictLoad) { allowedLoadResults); } // Disabled because we cannot start http server multiple times https://github.com/drogonframework/drogon/issues/2210 -#if 
(USE_DROGON == 0) -TEST_F(ConfigChangeStressTest, GetMetricsDuringLoad) { -#else + TEST_F(ConfigChangeStressTest, DISABLED_GetMetricsDuringLoad) { -#endif bool performWholeConfigReload = false; // we just need to have all model versions rechecked std::set requiredLoadResults = {StatusCode::OK}; // we expect full continuity of operation std::set allowedLoadResults = {}; diff --git a/src/test/http_openai_handler_test.cpp b/src/test/http_openai_handler_test.cpp index f5fddf052a..53971a7016 100644 --- a/src/test/http_openai_handler_test.cpp +++ b/src/test/http_openai_handler_test.cpp @@ -67,6 +67,125 @@ class HttpOpenAIHandlerTest : public ::testing::Test { } }; +class HttpOpenAIHandlerAuthorizationTest : public ::testing::Test { +protected: + ovms::Server& server = ovms::Server::instance(); + std::unique_ptr handler; + + std::unique_ptr t; + std::string port = "9173"; + + std::unordered_map headers{{"content-type", "application/json"}}; + ovms::HttpRequestComponents comp; + const std::string endpoint = "/v3/chat/completions"; + std::shared_ptr writer; + std::shared_ptr multiPartParser; + std::string response; + ovms::HttpResponseComponents responseComponents; + + void SetUpServer(const char* configPath) { + // create temp file with api key + std::string apiKeyFile = getGenericFullPathForSrcTest("test_api_key.txt"); + std::ofstream ofs(apiKeyFile); + std::string absoluteApiKeyPath = std::filesystem::absolute(apiKeyFile).string(); + ofs << "1234"; + ofs.close(); + randomizeAndEnsureFree(this->port); + ::SetUpServer(this->t, this->server, this->port, configPath, 10, absoluteApiKeyPath); + EnsureServerStartedWithTimeout(this->server, 20); + handler = std::make_unique(server, 5, "1234"); + // remove temp file with api key + std::filesystem::remove(absoluteApiKeyPath); + } + + void SetUp() { + writer = std::make_shared(); + multiPartParser = std::make_shared(); + 
SetUpServer(getGenericFullPathForSrcTest("/ovms/src/test/mediapipe/config_mediapipe_openai_chat_completions_mock.json").c_str()); + ASSERT_EQ(handler->parseRequestComponents(comp, "POST", endpoint, headers), ovms::StatusCode::OK); + } + + void TearDown() { + handler.reset(); + server.setShutdownRequest(1); + t->join(); + server.setShutdownRequest(0); + } +}; + +TEST_F(HttpOpenAIHandlerAuthorizationTest, CorrectApiKey) { + std::string requestBody = R"( + { + "model": "gpt", + "messages": [] + } + )"; + const std::string URI = "/v3/chat/completions"; + comp.headers["authorization"] = "Bearer 1234"; + std::cout << "URI" << URI << std::endl; + std::cout << "BODY" << requestBody << std::endl; + std::cout << "KEY" << comp.headers["authorization"] << std::endl; + std::shared_ptr stream = std::make_shared(); + std::shared_ptr multiPartParser = std::make_shared(); + auto streamPtr = std::static_pointer_cast(stream); + std::string response; + auto status = handler->processV3("/v3/completions", comp, response, requestBody, streamPtr, multiPartParser); + ASSERT_EQ(status, ovms::StatusCode::OK) << status.string(); +} + +TEST_F(HttpOpenAIHandlerAuthorizationTest, CorrectApiKeyMissingModel) { + std::string requestBody = R"( + { + "model": "gpt-missing", + "messages": [] + } + )"; + const std::string URI = "/v3/chat/completions"; + comp.headers["authorization"] = "Bearer 1234"; + std::cout << "URI" << URI << std::endl; + std::cout << "BODY" << requestBody << std::endl; + std::cout << "KEY" << comp.headers["authorization"] << std::endl; + std::shared_ptr stream = std::make_shared(); + std::shared_ptr multiPartParser = std::make_shared(); + auto streamPtr = std::static_pointer_cast(stream); + std::string response; + auto status = handler->processV3("/v3/completions", comp, response, requestBody, streamPtr, multiPartParser); + ASSERT_EQ(status, ovms::StatusCode::MEDIAPIPE_DEFINITION_NAME_MISSING) << status.string(); +} + +TEST_F(HttpOpenAIHandlerAuthorizationTest, IncorrectApiKey) { 
+ std::string requestBody = R"( + { + "model": "gpt", + "messages": [] + } + )"; + const std::string URI = "/v3/chat/completions"; + comp.headers["authorization"] = "Bearer ABCD"; + std::shared_ptr stream = std::make_shared(); + std::shared_ptr multiPartParser = std::make_shared(); + auto streamPtr = std::static_pointer_cast(stream); + std::string response; + auto status = handler->processV3("/v3/completions", comp, response, requestBody, streamPtr, multiPartParser); + ASSERT_EQ(status, ovms::StatusCode::UNAUTHORIZED) << status.string(); +} + +TEST_F(HttpOpenAIHandlerAuthorizationTest, MissingApiKey) { + std::string requestBody = R"( + { + "model": "gpt", + "messages": [] + } + )"; + const std::string URI = "/v3/chat/completions"; + std::shared_ptr stream = std::make_shared(); + std::shared_ptr multiPartParser = std::make_shared(); + auto streamPtr = std::static_pointer_cast(stream); + std::string response; + auto status = handler->processV3("/v3/completions", comp, response, requestBody, streamPtr, multiPartParser); + ASSERT_EQ(status, ovms::StatusCode::UNAUTHORIZED) << status.string(); +} + TEST_F(HttpOpenAIHandlerTest, Unary) { std::string requestBody = R"( { diff --git a/src/test/http_rest_api_handler_test.cpp b/src/test/http_rest_api_handler_test.cpp index 755857daf2..19f128ed2b 100644 --- a/src/test/http_rest_api_handler_test.cpp +++ b/src/test/http_rest_api_handler_test.cpp @@ -1310,3 +1310,30 @@ TEST_F(ConfigStatus, url_decode) { EXPECT_EQ("model%", ovms::urlDecode("model%")); EXPECT_EQ("model%2", ovms::urlDecode("model%2")); } + +TEST_F(ConfigStatus, isAuthorized) { + ovms::Server& ovmsServer = ovms::Server::instance(); + std::string contents; + auto fs = std::make_shared(); + fs->readTextFile(getGenericFullPathForSrcTest("/ovms/src/test/mediapipe/config_mediapipe_add_adapter_full.json"), &contents); + TestHelper1 t(*this, contents.c_str()); + auto handler = ovms::HttpRestApiHandler(ovmsServer, 10); + std::unordered_map headers = {{"X-Api-Key", "12345"}, + 
{"Content-Type", "application/json"}, + {"Authorization", "ABC"}}; + EXPECT_FALSE(handler.isAuthorized(headers, "wrong_key")); + headers = {{"X-Api-Key", "12345"}, + {"Content-Type", "application/json"}, + {"Authorization", "Bearer ABC"}}; + EXPECT_TRUE(handler.isAuthorized(headers, "ABC")); + + headers = {{"x-api-key", "12345"}, + {"content-type", "application/json"}, + {"authoriZation", "Bearer ABC123"}}; + EXPECT_TRUE(handler.isAuthorized(headers, "ABC123")); + + headers = {}; + EXPECT_FALSE(handler.isAuthorized(headers, "any_key")); + headers = {{"X-CustomHeader", "12345"}}; + EXPECT_FALSE(handler.isAuthorized(headers, "any_key")); +} diff --git a/src/test/metrics_flow_test.cpp b/src/test/metrics_flow_test.cpp index 70b3ebff65..4ea5dff857 100644 --- a/src/test/metrics_flow_test.cpp +++ b/src/test/metrics_flow_test.cpp @@ -828,7 +828,7 @@ TEST_F(MetricFlowTest, ModelReady) { #if (MEDIAPIPE_DISABLE == 0) TEST_F(MetricFlowTest, RestV3Unary) { - HttpRestApiHandler handler(server, 0); + HttpRestApiHandler handler(server, 0, ""); std::shared_ptr stream = std::make_shared(); std::shared_ptr multiPartParser = std::make_shared(); diff --git a/src/test/ovmsconfig_test.cpp b/src/test/ovmsconfig_test.cpp index ebe1c05f7a..de98c1815c 100644 --- a/src/test/ovmsconfig_test.cpp +++ b/src/test/ovmsconfig_test.cpp @@ -24,6 +24,7 @@ #include "spdlog/spdlog.h" #include "../capi_frontend/server_settings.hpp" +#include "./env_guard.hpp" #include "../config.hpp" #include "../filesystem.hpp" #include "../ovms_exit_codes.hpp" @@ -325,6 +326,12 @@ TEST_F(OvmsConfigDeathTest, NegativeListModelsWithoutModelRepositoryPath) { EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "Use --list_models with --model_repository_path"); } +TEST_F(OvmsConfigDeathTest, NegativeInvalidAPIKeyFile) { + char* n_argv[] = {"ovms", "--config_path", "/path1", "--api_key_file", "/wrong/dir", "--port", "44"}; + int arg_count = 7; + 
EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "Error reading API key file: Unable to open file \"/wrong/dir\""); +} + TEST_F(OvmsConfigDeathTest, negativeMissingDashes) { char* n_argv[] = { "ovms", @@ -1902,6 +1909,60 @@ TEST(OvmsGraphConfigTest, negativeEmbeddingsInvalidNormalize) { EXPECT_EXIT(ovms::Config::instance().parse(arg_count, n_argv), ::testing::ExitedWithCode(OVMS_EX_USAGE), "normalize: INVALID is not allowed. Supported values: true, false"); } +TEST(OvmsAPIKeyConfig, positiveAPIKeyFile) { + // Create a temporary API key file + std::ofstream apiKeyFileTmp("api_key.txt"); + apiKeyFileTmp << "1234"; + apiKeyFileTmp.close(); + std::string modelName = "test_name"; + std::string modelPath = "model_path"; + std::string apiKeyFile = "api_key.txt"; + std::string rest_port = "8080"; + char* n_argv[] = { + (char*)"ovms", + (char*)"--model_path", + (char*)modelPath.c_str(), + (char*)"--model_name", + (char*)modelName.c_str(), + (char*)"--api_key_file", + (char*)apiKeyFile.c_str(), + (char*)"--rest_port", + (char*)rest_port.c_str(), + }; + + int arg_count = 9; + ConstructorEnabledConfig config; + config.parse(arg_count, n_argv); + + ASSERT_EQ(config.getServerSettings().apiKey, "1234"); + // Clean up the temporary file + std::remove("api_key.txt"); +} + +TEST(OvmsAPIKeyConfig, positiveAPIKeyEnv) { + EnvGuard envGuard; + envGuard.set("API_KEY", "ABCD"); + std::string modelName = "test_name"; + std::string modelPath = "model_path"; + std::string apiKeyFile = "api_key.txt"; + std::string rest_port = "8080"; + char* n_argv[] = { + (char*)"ovms", + (char*)"--model_path", + (char*)modelPath.c_str(), + (char*)"--model_name", + (char*)modelName.c_str(), + (char*)"--rest_port", + (char*)rest_port.c_str(), + }; + + int arg_count = 7; + ConstructorEnabledConfig config; + config.parse(arg_count, n_argv); + + ASSERT_EQ(config.getServerSettings().apiKey, "ABCD"); +} + class OvmsParamsTest : public ::testing::Test { }; 
diff --git a/src/test/standalone_http_server_test.cpp b/src/test/standalone_http_server_test.cpp index 46f7062f85..43cfbc3167 100644 --- a/src/test/standalone_http_server_test.cpp +++ b/src/test/standalone_http_server_test.cpp @@ -17,15 +17,10 @@ #include #include -#if (USE_DROGON == 1) #include -#endif - #include #include -#if (USE_DROGON == 1) - // Disabled due to drogon issue https://github.com/drogonframework/drogon/issues/2210 TEST(Drogon, DISABLED_basic) { for (int i = 0; i < 2; i++) { @@ -41,14 +36,3 @@ TEST(Drogon, DISABLED_basic) { k.join(); } } - -#endif - -// Make sure we have drogon enabled as default in production -TEST(Drogon, EnabledInProduction) { -#if (USE_DROGON == 1) - ASSERT_EQ(1, 1); -#else - ASSERT_EQ(1, 0); -#endif -} diff --git a/src/test/stress_test_utils.hpp b/src/test/stress_test_utils.hpp index fca269b0c5..9c1ec90252 100644 --- a/src/test/stress_test_utils.hpp +++ b/src/test/stress_test_utils.hpp @@ -1157,16 +1157,8 @@ class ConfigChangeStressTest : public TestWithTempDir { ASSERT_CAPI_STATUS_NULL(OVMS_ModelsSettingsNew(&modelsSettings)); randomizeAndEnsureFrees(port, restPort); ASSERT_CAPI_STATUS_NULL(OVMS_ServerSettingsSetGrpcPort(serverSettings, std::stoi(port))); -#if (USE_DROGON == 0) // when jusing drogon we cannot start rest server multiple times within the same process - ASSERT_CAPI_STATUS_NULL(OVMS_ServerSettingsSetRestPort(serverSettings, std::stoi(restPort))); // required for metrics - but disabled because drogon http server cannot be restarted -#endif - // ideally we would want to have emptyConfigWithMetrics -#if (USE_DROGON == 0) - ASSERT_CAPI_STATUS_NULL(OVMS_ModelsSettingsSetConfigPath(modelsSettings, getGenericFullPathForSrcTest("/ovms/src/test/configs/emptyConfigWithMetrics.json").c_str())); // the content of config json is irrelevant - we just need server to be ready for C-API use in mediapipe -#else ASSERT_CAPI_STATUS_NULL(OVMS_ModelsSettingsSetConfigPath(modelsSettings, 
getGenericFullPathForSrcTest("/ovms/src/test/configs/emptyConfig.json").c_str())); // the content of config json is irrelevant - we just need server to be ready for C-API use in mediapipe -#endif - ASSERT_CAPI_STATUS_NULL(OVMS_ServerSettingsSetFileSystemPollWaitSeconds(serverSettings, 0)); // set to 0 to reload only through test and avoid races + ASSERT_CAPI_STATUS_NULL(OVMS_ServerSettingsSetFileSystemPollWaitSeconds(serverSettings, 0)); // set to 0 to reload only through test and avoid races ASSERT_CAPI_STATUS_NULL(OVMS_ServerNew(&cserver)); ASSERT_CAPI_STATUS_NULL(OVMS_ServerStartFromConfigurationFile(cserver, serverSettings, modelsSettings)); OVMS_ModelsSettingsDelete(modelsSettings); diff --git a/src/test/test_utils.cpp b/src/test/test_utils.cpp index 6dc8c33d63..45b2ea563b 100644 --- a/src/test/test_utils.cpp +++ b/src/test/test_utils.cpp @@ -794,20 +794,36 @@ void SetUpServerForDownloadAndStartGGUF(std::unique_ptr& t, ovms::S EnsureServerStartedWithTimeout(server, timeoutSeconds); } -void SetUpServer(std::unique_ptr& t, ovms::Server& server, std::string& port, const char* configPath, int timeoutSeconds) { +void SetUpServer(std::unique_ptr& t, ovms::Server& server, std::string& port, const char* configPath, int timeoutSeconds, std::string api_key) { server.setShutdownRequest(0); randomizeAndEnsureFree(port); - char* argv[] = {(char*)"ovms", - (char*)"--config_path", - (char*)configPath, - (char*)"--port", - (char*)port.c_str()}; - int argc = 5; - t.reset(new std::thread([&argc, &argv, &server]() { - EXPECT_EQ(EXIT_SUCCESS, server.start(argc, argv)); - })); - EnsureServerStartedWithTimeout(server, timeoutSeconds); + if (!api_key.empty()) { + char* argv[] = {(char*)"ovms", + (char*)"--config_path", + (char*)configPath, + (char*)"--port", + (char*)port.c_str(), + (char*)"--api_key_file", + (char*)api_key.c_str()}; + int argc = 7; + t.reset(new std::thread([&argc, &argv, &server]() { + EXPECT_EQ(EXIT_SUCCESS, server.start(argc, argv)); + })); + 
EnsureServerStartedWithTimeout(server, timeoutSeconds); + } else { + char* argv[] = {(char*)"ovms", + (char*)"--config_path", + (char*)configPath, + (char*)"--port", + (char*)port.c_str()}; + int argc = 5; + t.reset(new std::thread([&argc, &argv, &server]() { + EXPECT_EQ(EXIT_SUCCESS, server.start(argc, argv)); + })); + EnsureServerStartedWithTimeout(server, timeoutSeconds); + } } + void SetUpServer(std::unique_ptr& t, ovms::Server& server, std::string& port, const char* modelPath, const char* modelName, int timeoutSeconds) { server.setShutdownRequest(0); randomizeAndEnsureFree(port); diff --git a/src/test/test_utils.hpp b/src/test/test_utils.hpp index 615a82cc14..c307fc255e 100644 --- a/src/test/test_utils.hpp +++ b/src/test/test_utils.hpp @@ -792,7 +792,7 @@ void SetUpServerForDownloadAndStart(std::unique_ptr& t, ovms::Serve /* * starts loading OVMS on separate thread but waits until it is ready */ -void SetUpServer(std::unique_ptr& t, ovms::Server& server, std::string& port, const char* configPath, int timeoutSeconds = SERVER_START_FROM_CONFIG_TIMEOUT_SECONDS); +void SetUpServer(std::unique_ptr& t, ovms::Server& server, std::string& port, const char* configPath, int timeoutSeconds = SERVER_START_FROM_CONFIG_TIMEOUT_SECONDS, std::string apiKeyFile = ""); void SetUpServer(std::unique_ptr& t, ovms::Server& server, std::string& port, const char* modelPath, const char* modelName, int timeoutSeconds = SERVER_START_FROM_CONFIG_TIMEOUT_SECONDS); class ConstructorEnabledConfig : public ovms::Config {