From cf731c1def04e15ff9735e906585f0a6d1d1c8ae Mon Sep 17 00:00:00 2001 From: Ben Barclay Date: Fri, 27 Feb 2026 10:43:58 +1100 Subject: [PATCH 1/2] Add bearer token auth --- .../inference-node/src/bin/gateway-node.rs | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/architectures/inference-only/inference-node/src/bin/gateway-node.rs b/architectures/inference-only/inference-node/src/bin/gateway-node.rs index 4a06e6683..36a380938 100644 --- a/architectures/inference-only/inference-node/src/bin/gateway-node.rs +++ b/architectures/inference-only/inference-node/src/bin/gateway-node.rs @@ -12,7 +12,7 @@ use anyhow::{Context, Result}; use axum::{ Json, Router, extract::State, - http::StatusCode, + http::{HeaderMap, StatusCode}, response::{IntoResponse, Response}, routing::post, }; @@ -50,6 +50,10 @@ struct Args { #[arg(long)] write_endpoint_file: Option, + + /// Bearer token secret required for API authentication (optional) + #[arg(long, env = "GATEWAY_API_SECRET")] + api_secret: Option, } #[derive(Clone, Debug)] @@ -66,6 +70,7 @@ struct GatewayState { available_nodes: RwLock>, pending_requests: RwLock>>, network_tx: mpsc::Sender<(EndpointId, InferenceMessage)>, + api_secret: Option, } #[derive(serde::Deserialize, serde::Serialize, Clone, Debug)] @@ -118,8 +123,20 @@ struct ChatCompletionResponse { #[axum::debug_handler] async fn handle_inference( State(state): State>, + headers: HeaderMap, Json(req): Json, ) -> Result, AppError> { + if let Some(ref secret) = state.api_secret { + let authorized = headers + .get("authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")) + .is_some_and(|token| token == secret); + if !authorized { + return Err(AppError::Unauthorized); + } + } + let nodes = state.available_nodes.read().await; let node = nodes.values().next().ok_or(AppError::NoNodesAvailable)?; @@ -201,6 +218,7 @@ enum AppError { NoNodesAvailable, Timeout, InternalError, + Unauthorized, } impl IntoResponse for AppError { @@ -212,6 +230,7 @@ impl IntoResponse for AppError { ), AppError::Timeout => (StatusCode::GATEWAY_TIMEOUT, "Inference request timed out"), AppError::InternalError => (StatusCode::INTERNAL_SERVER_ERROR, "Internal server error"), + AppError::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized"), }; (status, message).into_response() } @@ -357,6 +376,7 @@ async fn run_gateway() -> Result<()> { available_nodes: RwLock::new(HashMap::new()), pending_requests: RwLock::new(HashMap::new()), network_tx, + api_secret: args.api_secret.clone(), }); info!("Gateway ready! Listening on http://{}", args.listen_addr); From ac40c9437c56d4b008e4726bd2445f8f9e56946c Mon Sep 17 00:00:00 2001 From: Ben Barclay Date: Fri, 27 Feb 2026 11:01:28 +1100 Subject: [PATCH 2/2] Add unit tests for Gateway auth --- .../inference-only/inference-node/Cargo.toml | 2 +- .../inference-node/src/bin/gateway-node.rs | 123 ++++++++++++++++++ 2 files changed, 124 insertions(+), 1 deletion(-) diff --git a/architectures/inference-only/inference-node/Cargo.toml b/architectures/inference-only/inference-node/Cargo.toml index 0e99cbc20..844ce7b31 100644 --- a/architectures/inference-only/inference-node/Cargo.toml +++ b/architectures/inference-only/inference-node/Cargo.toml @@ -42,6 +42,6 @@ uuid = { version = "1", features = ["v4"] } pyo3.workspace = true postcard.workspace = true axum = { version = "0.7", features = ["macros"] } -tower = { version = "0.4" } +tower = { version = "0.4", features = ["util"] } tower-http = { version = "0.5", features = ["cors"] } tikv-jemallocator.workspace = true diff --git a/architectures/inference-only/inference-node/src/bin/gateway-node.rs b/architectures/inference-only/inference-node/src/bin/gateway-node.rs index 36a380938..595498744 100644 --- a/architectures/inference-only/inference-node/src/bin/gateway-node.rs +++ b/architectures/inference-only/inference-node/src/bin/gateway-node.rs @@ -514,3 +514,126 @@ async fn run_gateway() -> Result<()> { info!("Shutdown complete"); Ok(()) } + +#[cfg(test)] +mod tests { + use axum::{ + body::Body, + http::{Request, StatusCode}, + }; + use std::sync::Arc; + use tokio::sync::{RwLock, mpsc}; + use tower::ServiceExt; + + use super::*; + + fn make_app(secret: Option<&str>) -> axum::Router { + let (network_tx, _network_rx) = mpsc::channel(1); + let state = Arc::new(GatewayState { + available_nodes: RwLock::new(Default::default()), + pending_requests: RwLock::new(Default::default()), + network_tx, + api_secret: secret.map(|s| s.to_string()), + }); + Router::new() + .route( + "/v1/chat/completions", + axum::routing::post(handle_inference), + ) + .with_state(state) + } + + fn chat_request_body() -> &'static str { + r#"{"messages":[{"role":"user","content":"hello"}]}"# + } + + #[tokio::test] + async fn no_secret_configured_allows_unauthenticated_requests() { + let app = make_app(None); + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(chat_request_body())) + .unwrap(), + ) + .await + .unwrap(); + // No nodes available, but auth passed — expect 503 not 401 + assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE); + } + + #[tokio::test] + async fn correct_bearer_token_is_accepted() { + let app = make_app(Some("supersecret")); + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .header("authorization", "Bearer supersecret") + .body(Body::from(chat_request_body())) + .unwrap(), + ) + .await + .unwrap(); + // Auth passed, no nodes available — expect 503 not 401 + assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE); + } + + #[tokio::test] + async fn missing_auth_header_is_rejected() { + let app = make_app(Some("supersecret")); + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .body(Body::from(chat_request_body())) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn wrong_token_is_rejected() { + let app = make_app(Some("supersecret")); + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .header("authorization", "Bearer wrongtoken") + .body(Body::from(chat_request_body())) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + } + + #[tokio::test] + async fn non_bearer_scheme_is_rejected() { + let app = make_app(Some("supersecret")); + let response = app + .oneshot( + Request::builder() + .method("POST") + .uri("/v1/chat/completions") + .header("content-type", "application/json") + .header("authorization", "Basic supersecret") + .body(Body::from(chat_request_body())) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + } +}