Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions areal/experimental/openai/proxy/proxy_rollout_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,37 @@ def _setup_openai_client():
engine_max_tokens=agent_cfg.engine_max_tokens,
chat_template_type=agent_cfg.chat_template_type,
)
# Set session timeout from config
_session_timeout_seconds = agent_cfg.session_timeout_seconds
with _lock:
_admin_api_key = agent_cfg.admin_api_key
if _admin_api_key == DEFAULT_ADMIN_API_KEY:
# Validate admin API key BEFORE assigning it to the global, so a
# failed validation cannot leave the default key live on the server.
requested_admin_key = agent_cfg.admin_api_key
if requested_admin_key == DEFAULT_ADMIN_API_KEY:
# The default admin key is publicly known. Refuse to use it when
# the server is reachable from outside the local host, otherwise
# any attacker who can reach this port can call admin endpoints
# (grant_capacity, start_session, export_trajectories, ...).
loopback_hosts = {"127.0.0.1", "::1", "localhost"}
allow_override = (
os.environ.get("AREAL_ALLOW_DEFAULT_ADMIN_KEY", "0") == "1"
)
if _server_host in loopback_hosts or allow_override:
logger.warning(
"Using default admin API key. Change 'admin_api_key' in "
"AgentConfig for non-local deployments."
"AgentConfig before exposing this server on a network."
)
else:
raise RuntimeError(
"Refusing to start proxy rollout server on non-loopback "
f"host {_server_host!r} with the default admin API key "
f"({DEFAULT_ADMIN_API_KEY!r}). Set 'admin_api_key' in "
"OpenAIProxyConfig to a unique secret, or set "
"AREAL_ALLOW_DEFAULT_ADMIN_KEY=1 to acknowledge the risk "
"in a trusted environment."
)
# Only commit the key to the global after validation has passed.
with _lock:
_admin_api_key = requested_admin_key


@app.post("/configure")
Expand Down Expand Up @@ -1017,7 +1040,7 @@ def main():
# Run uvicorn directly (blocking)
uvicorn.run(
app,
host="0.0.0.0",
host=_server_host,
port=_server_port,
log_level="warning",
timeout_keep_alive=300,
Expand Down
Loading