diff --git a/src/dstack/_internal/core/backends/base/compute.py b/src/dstack/_internal/core/backends/base/compute.py index 6a1f6af4a..c670b980e 100644 --- a/src/dstack/_internal/core/backends/base/compute.py +++ b/src/dstack/_internal/core/backends/base/compute.py @@ -840,6 +840,7 @@ def get_gateway_user_data(authorized_key: str) -> str: packages=[ "nginx", "python3.10-venv", + "python3-pip", # Add pip for sglang-router installation ], snap={"commands": [["install", "--classic", "certbot"]]}, runcmd=[ @@ -850,6 +851,8 @@ def get_gateway_user_data(authorized_key: str) -> str: "s/# server_names_hash_bucket_size 64;/server_names_hash_bucket_size 128;/", "/etc/nginx/nginx.conf", ], + # Install sglang-router system-wide. Can be conditionally installed in the future. + ["pip", "install", "sglang-router"], ["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands())], ], ssh_authorized_keys=[authorized_key], @@ -979,7 +982,8 @@ def get_dstack_gateway_wheel(build: str) -> str: r.raise_for_status() build = r.text.strip() logger.debug("Found the latest gateway build: %s", build) - return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + # return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl" + return "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.0-py3-none-any.whl" def get_dstack_gateway_commands() -> List[str]: diff --git a/src/dstack/_internal/core/models/gateways.py b/src/dstack/_internal/core/models/gateways.py index 6a480b580..159ada65c 100644 --- a/src/dstack/_internal/core/models/gateways.py +++ b/src/dstack/_internal/core/models/gateways.py @@ -50,6 +50,7 @@ class GatewayConfiguration(CoreModel): default: Annotated[bool, Field(description="Make the gateway default")] = False backend: Annotated[BackendType, Field(description="The gateway backend")] region: Annotated[str, Field(description="The gateway region")] + router: Annotated[Optional[str], Field(description="The router type, e.g. `sglang`")] = None domain: Annotated[ Optional[str], Field(description="The gateway domain, e.g. `example.com`") ] = None diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 index b096fa80e..8c50db11e 100644 --- a/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/service.jinja2 @@ -4,9 +4,13 @@ limit_req_zone {{ zone.key }} zone={{ zone.name }}:10m rate={{ zone.rpm }}r/m; {% if replicas %} upstream {{ domain }}.upstream { + {% if router == "sglang" %} + server 127.0.0.1:3000; # SGLang router on the gateway + {% else %} {% for replica in replicas %} server unix:{{ replica.socket }}; # replica {{ replica.id }} {% endfor %} + {% endif %} } {% else %} diff --git a/src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 b/src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 new file mode 100644 index 000000000..e5ffe12ee --- /dev/null +++ b/src/dstack/_internal/proxy/gateway/resources/nginx/sglang_workers.jinja2 @@ -0,0 +1,23 @@ +{% for replica in replicas %} +# Worker {{ loop.index }} +upstream sglang_worker_{{ ports[loop.index0] }}_upstream { + server unix:{{ replica.socket }}; +} + +server { + listen 127.0.0.1:{{ ports[loop.index0] }}; + access_log off; # disable access logs for this internal endpoint + + proxy_read_timeout 300s; + proxy_send_timeout 300s; + + location / { + proxy_pass http://sglang_worker_{{ ports[loop.index0] }}_upstream; + proxy_http_version 1.1; + proxy_set_header Host $host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header Connection ""; + proxy_set_header Upgrade $http_upgrade; + } +} +{% endfor %} diff --git a/src/dstack/_internal/proxy/gateway/routers/registry.py b/src/dstack/_internal/proxy/gateway/routers/registry.py index e1bfa4ff2..dd4f63f32 100644 --- a/src/dstack/_internal/proxy/gateway/routers/registry.py +++ b/src/dstack/_internal/proxy/gateway/routers/registry.py @@ -36,6 +36,7 @@ async def register_service( model=body.options.openai.model if body.options.openai is not None else None, ssh_private_key=body.ssh_private_key, repo=repo, + router=body.router, nginx=nginx, service_conn_pool=service_conn_pool, ) diff --git a/src/dstack/_internal/proxy/gateway/schemas/registry.py b/src/dstack/_internal/proxy/gateway/schemas/registry.py index 8ab69b6af..117152a95 100644 --- a/src/dstack/_internal/proxy/gateway/schemas/registry.py +++ b/src/dstack/_internal/proxy/gateway/schemas/registry.py @@ -44,6 +44,7 @@ class RegisterServiceRequest(BaseModel): options: Options ssh_private_key: str rate_limits: tuple[RateLimit, ...] = () + router: Optional[str] = None class RegisterReplicaRequest(BaseModel): diff --git a/src/dstack/_internal/proxy/gateway/services/nginx.py b/src/dstack/_internal/proxy/gateway/services/nginx.py index 2d3e755ac..50b8be01d 100644 --- a/src/dstack/_internal/proxy/gateway/services/nginx.py +++ b/src/dstack/_internal/proxy/gateway/services/nginx.py @@ -1,6 +1,8 @@ import importlib.resources +import json import subprocess import tempfile +import urllib.parse from asyncio import Lock from pathlib import Path from typing import Optional @@ -64,6 +66,8 @@ class ServiceConfig(SiteConfig): limit_req_zones: list[LimitReqZoneConfig] locations: list[LocationConfig] replicas: list[ReplicaConfig] + router: Optional[str] = None + model_id: Optional[str] = None class ModelEntrypointConfig(SiteConfig): @@ -77,15 +81,41 @@ class Nginx: def __init__(self, conf_dir: Path = Path("/etc/nginx/sites-enabled")) -> None: self._conf_dir = conf_dir self._lock: Lock = Lock() + self._domain_to_model_id: dict[str, str] = {} + # Track next available port for worker allocation + self._next_worker_port: int = 10001 # Start from 10001 + # Track which ports are used by which domain + self._domain_to_ports: dict[str, list[int]] = {} # domain -> [port1, port2, ...] async def register(self, conf: SiteConfig, acme: ACMESettings) -> None: logger.debug("Registering %s domain %s", conf.type, conf.domain) conf_name = self.get_config_name(conf.domain) - async with self._lock: if conf.https: await run_async(self.run_certbot, conf.domain, acme) await run_async(self.write_conf, conf.render(), conf_name) + if hasattr(conf, "router") and conf.router == "sglang": + replicas = len(conf.replicas) + model_id = conf.model_id + logger.info( + "Registering sglang router with model %s with %d replicas in domain %s", + model_id, + replicas, + conf.domain, + ) + self._domain_to_model_id[conf.domain] = model_id + # Allocate unique ports for this service + allocated_ports = list( + range(self._next_worker_port, self._next_worker_port + replicas) + ) + self._domain_to_ports[conf.domain] = allocated_ports + self._next_worker_port += replicas # Reserve ports for next service + + # Pass allocated ports to worker config generation + await run_async(self.write_sglang_workers_conf, conf, allocated_ports) + await run_async( + self.start_or_update_sglang_router, replicas, model_id, allocated_ports + ) logger.info("Registered %s domain %s", conf.type, conf.domain) @@ -96,6 +126,19 @@ async def unregister(self, domain: str) -> None: return async with self._lock: await run_async(sudo_rm, conf_path) + workers_conf_path = self._conf_dir / f"sglang-workers.{domain}.conf" + if workers_conf_path.exists(): + await run_async(sudo_rm, workers_conf_path) + # Get model_id for this domain before removing workers + model_id = self._domain_to_model_id.get(domain) + # This allows other services with different models to continue running + await run_async(self.remove_sglang_workers_for_model, model_id) + # Clean up the mapping + del self._domain_to_model_id[domain] + # Free up ports + if domain in self._domain_to_ports: + del self._domain_to_ports[domain] + # await run_async(self.stop_sglang_router) await run_async(self.reload) logger.info("Unregistered domain %s", domain) @@ -106,6 +149,226 @@ def reload() -> None: if r.returncode != 0: raise UnexpectedProxyError("Failed to reload nginx") + @staticmethod + def start_or_update_sglang_router( + replicas: int, model_id: str, allocated_ports: list[int] + ) -> None: + if not Nginx.is_sglang_router_running(): + Nginx.start_sglang_router() + # Pass allocated ports to worker update + Nginx.update_sglang_router_workers(replicas, model_id, allocated_ports) + + @staticmethod + def is_sglang_router_running() -> bool: + """Check if sglang router is running and responding to HTTP requests.""" + try: + result = subprocess.run( + ["curl", "-s", "http://localhost:3000/workers"], capture_output=True, timeout=5 + ) + return result.returncode == 0 + except Exception as e: + logger.error(f"Error checking sglang router status: {e}") + return False + + @staticmethod + def start_sglang_router() -> None: + try: + logger.info("Starting sglang-router...") + cmd = [ + "python3", + "-m", + "sglang_router.launch_router", + "--host", + "0.0.0.0", + "--port", + "3000", + "--enable-igw", + "--log-level", + "debug", + "--log-dir", + "./router_logs", + ] + subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) + + import time + + time.sleep(2) + + # Verify router is running + if not Nginx.is_sglang_router_running(): + raise Exception("Failed to start sglang router") + + logger.info("Sglang router started successfully") + + except Exception as e: + logger.error(f"Failed to start sglang-router: {e}") + raise + + @staticmethod + def get_sglang_router_workers(model_id: str) -> list[dict]: + try: + result = subprocess.run( + ["curl", "-s", "http://localhost:3000/workers"], capture_output=True, timeout=5 + ) + if result.returncode == 0: + response = json.loads(result.stdout.decode()) + workers = response.get("workers", []) + workers = [w for w in workers if w.get("model_id") == model_id] + return workers + return [] + except Exception as e: + logger.error(f"Error getting sglang router workers: {e}") + return [] + + @staticmethod + def update_sglang_router_workers( + replicas: int, model_id: str, allocated_ports: list[int] + ) -> None: + """Update sglang router workers via HTTP API""" + try: + # Get current workers + current_workers = Nginx.get_sglang_router_workers(model_id) + current_worker_urls = {worker["url"] for worker in current_workers} + + # Calculate target worker URLs + # target_worker_urls = {f"http://127.0.0.1:{10000 + i}" for i in range(1, replicas + 1)} + # Use explicitly assigned ports instead of calculation + target_worker_urls = {f"http://127.0.0.1:{port}" for port in allocated_ports} + + # Workers to add + workers_to_add = target_worker_urls - current_worker_urls + # Workers to remove + workers_to_remove = current_worker_urls - target_worker_urls + + if workers_to_add: + logger.info("Sglang router update: adding %d workers", len(workers_to_add)) + if workers_to_remove: + logger.info("Sglang router update: removing %d workers", len(workers_to_remove)) + + # Add workers + for worker_url in sorted(workers_to_add): + success = Nginx.add_sglang_router_worker(worker_url, model_id) + if not success: + logger.warning("Failed to add worker %s, continuing with others", worker_url) + + # Remove workers + for worker_url in sorted(workers_to_remove): + success = Nginx.remove_sglang_router_worker(worker_url, model_id) + if not success: + logger.warning( + "Failed to remove worker %s, continuing with others", worker_url + ) + + except Exception as e: + logger.error(f"Error updating sglang router workers for model {model_id}: {e}") + raise + + @staticmethod + def add_sglang_router_worker(worker_url: str, model_id: str) -> bool: + try: + payload = {"url": worker_url, "worker_type": "regular", "model_id": model_id} + result = subprocess.run( + [ + "curl", + "-X", + "POST", + "http://localhost:3000/workers", + "-H", + "Content-Type: application/json", + "-d", + json.dumps(payload), + ], + capture_output=True, + timeout=5, + ) + + if result.returncode == 0: + response = json.loads(result.stdout.decode()) + if response.get("status") == "accepted": + logger.info("Added worker %s to sglang router", worker_url) + return True + else: + logger.error("Failed to add worker %s: %s", worker_url, response) + return False + else: + logger.error("Failed to add worker %s: %s", worker_url, result.stderr.decode()) + return False + except Exception as e: + logger.error(f"Error adding worker {worker_url}: {e}") + return False + + @staticmethod + def remove_sglang_router_worker(worker_url: str, model_id: str) -> bool: + """Remove a single worker from sglang router""" + try: + # URL encode the worker URL for the DELETE request + encoded_url = urllib.parse.quote(worker_url, safe="") + + result = subprocess.run( + ["curl", "-X", "DELETE", f"http://localhost:3000/workers/{encoded_url}"], + capture_output=True, + timeout=5, + ) + + if result.returncode == 0: + response = json.loads(result.stdout.decode()) + if response.get("status") == "accepted": + logger.info( + "Removed worker %s from sglang router model %s", worker_url, model_id + ) + return True + else: + logger.error( + "Failed to remove worker %s model %s: %s", worker_url, model_id, response + ) + return False + else: + logger.error( + "Failed to remove worker %s model %s: %s", + worker_url, + model_id, + result.stderr.decode(), + ) + return False + except Exception as e: + logger.error(f"Error removing worker {worker_url} model {model_id}: {e}") + return False + + @staticmethod + def remove_sglang_workers_for_model(model_id: str) -> None: + try: + workers = Nginx.get_sglang_router_workers(model_id) + for worker in workers: + Nginx.remove_sglang_router_worker(worker["url"], model_id) + except Exception as e: + logger.error(f"Error removing sglang router workers for model {model_id}: {e}") + raise + + @staticmethod + def stop_sglang_router() -> None: + try: + result = subprocess.run( + ["pgrep", "-f", "sglang::router"], capture_output=True, timeout=5 + ) + if result.returncode == 0: + logger.info("Stopping sglang-router process...") + subprocess.run(["pkill", "-f", "sglang::router"], timeout=5) + else: + logger.debug("No sglang-router process found to stop") + + log_dir = Path("./router_logs") + if log_dir.exists(): + logger.debug("Cleaning up router logs...") + import shutil + + shutil.rmtree(log_dir, ignore_errors=True) + else: + logger.debug("No router logs directory found to clean up") + + except Exception as e: + logger.error(f"Failed to stop sglang-router: {e}") + raise + def write_conf(self, conf: str, conf_name: str) -> None: """Update config and reload nginx. Rollback changes on error.""" conf_path = self._conf_dir / conf_name @@ -168,6 +431,23 @@ def write_global_conf(self) -> None: conf = read_package_resource("00-log-format.conf") self.write_conf(conf, "00-log-format.conf") + def write_sglang_workers_conf(self, conf: SiteConfig, allocated_ports: list[int]) -> None: + # Pass ports to template + workers_config = generate_sglang_workers_config(conf, allocated_ports) + workers_conf_name = f"sglang-workers.{conf.domain}.conf" + workers_conf_path = self._conf_dir / workers_conf_name + sudo_write(workers_conf_path, workers_config) + self.reload() + + +def generate_sglang_workers_config(conf: SiteConfig, allocated_ports: list[int]) -> str: + template = read_package_resource("sglang_workers.jinja2") + return jinja2.Template(template).render( + replicas=conf.replicas, + ports=allocated_ports, + proxy_port=PROXY_PORT_ON_GATEWAY, + ) + def read_package_resource(file: str) -> str: return ( diff --git a/src/dstack/_internal/proxy/gateway/services/registry.py b/src/dstack/_internal/proxy/gateway/services/registry.py index 3ea412d79..2fcadfedd 100644 --- a/src/dstack/_internal/proxy/gateway/services/registry.py +++ b/src/dstack/_internal/proxy/gateway/services/registry.py @@ -44,6 +44,7 @@ async def register_service( repo: GatewayProxyRepo, nginx: Nginx, service_conn_pool: ServiceConnectionPool, + router: Optional[str] = None, ) -> None: service = models.Service( project_name=project_name, @@ -54,6 +55,8 @@ async def register_service( auth=auth, client_max_body_size=client_max_body_size, replicas=(), + router=router, + model_id=model.name if model is not None else None, ) async with lock: @@ -335,6 +338,8 @@ async def get_nginx_service_config( limit_req_zones=limit_req_zones, locations=locations, replicas=sorted(replicas, key=lambda r: r.id), # sort for reproducible configs + router=service.router, + model_id=service.model_id, ) diff --git a/src/dstack/_internal/proxy/lib/models.py b/src/dstack/_internal/proxy/lib/models.py index 5cb5471d8..c24ce0fe2 100644 --- a/src/dstack/_internal/proxy/lib/models.py +++ b/src/dstack/_internal/proxy/lib/models.py @@ -57,6 +57,8 @@ class Service(ImmutableModel): client_max_body_size: int # only enforced on gateways strip_prefix: bool = True # only used in-server replicas: tuple[Replica, ...] + router: Optional[str] = None + model_id: Optional[str] = None @property def domain_safe(self) -> str: diff --git a/src/dstack/_internal/server/services/gateways/client.py b/src/dstack/_internal/server/services/gateways/client.py index f8c090079..33bfee3f5 100644 --- a/src/dstack/_internal/server/services/gateways/client.py +++ b/src/dstack/_internal/server/services/gateways/client.py @@ -45,6 +45,7 @@ async def register_service( options: dict, rate_limits: list[RateLimit], ssh_private_key: str, + router: Optional[str] = None, ): if "openai" in options: entrypoint = f"gateway.{domain.split('.', maxsplit=1)[1]}" @@ -59,6 +60,7 @@ async def register_service( "options": options, "rate_limits": [limit.dict() for limit in rate_limits], "ssh_private_key": ssh_private_key, + "router": router, } resp = await self._client.post( self._url(f"/api/registry/{project}/services/register"), json=payload diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py index a8089a93a..05c1fa909 100644 --- a/src/dstack/_internal/server/services/services/__init__.py +++ b/src/dstack/_internal/server/services/services/__init__.py @@ -82,6 +82,7 @@ async def _register_service_in_gateway( gateway_configuration = get_gateway_configuration(gateway) service_https = _get_service_https(run_spec, gateway_configuration) + router = gateway_configuration.router service_protocol = "https" if service_https else "http" if service_https and gateway_configuration.certificate is None: @@ -119,6 +120,7 @@ async def _register_service_in_gateway( options=service_spec.options, rate_limits=run_spec.configuration.rate_limits, ssh_private_key=run_model.project.ssh_private_key, + router=router, ) logger.info("%s: service is registered as %s", fmt(run_model), service_spec.url) except SSHError: diff --git a/src/tests/_internal/server/routers/test_gateways.py b/src/tests/_internal/server/routers/test_gateways.py index 996157350..6d06e6969 100644 --- a/src/tests/_internal/server/routers/test_gateways.py +++ b/src/tests/_internal/server/routers/test_gateways.py @@ -70,6 +70,7 @@ async def test_list(self, test_db, session: AsyncSession, client: AsyncClient): "name": gateway.name, "backend": backend.type.value, "region": gateway.region, + "router": None, "domain": gateway.wildcard_domain, "default": False, "public_ip": True, @@ -121,6 +122,7 @@ async def test_get(self, test_db, session: AsyncSession, client: AsyncClient): "name": gateway.name, "backend": backend.type.value, "region": gateway.region, + "router": None, "domain": gateway.wildcard_domain, "default": False, "public_ip": True, @@ -201,6 +203,7 @@ async def test_create_gateway(self, test_db, session: AsyncSession, client: Asyn "name": "test", "backend": backend.type.value, "region": "us", + "router": None, "domain": None, "default": True, "public_ip": True, @@ -253,6 +256,7 @@ async def test_create_gateway_without_name( "name": "random-name", "backend": backend.type.value, "region": "us", + "router": None, "domain": None, "default": True, "public_ip": True, @@ -355,6 +359,7 @@ async def test_set_default_gateway(self, test_db, session: AsyncSession, client: "name": gateway.name, "backend": backend.type.value, "region": gateway.region, + "router": None, "domain": gateway.wildcard_domain, "default": True, "public_ip": True, @@ -477,6 +482,7 @@ def get_backend(project, backend_type): "name": gateway_gcp.name, "backend": backend_gcp.type.value, "region": gateway_gcp.region, + "router": None, "domain": gateway_gcp.wildcard_domain, "default": False, "public_ip": True, @@ -546,6 +552,7 @@ async def test_set_wildcard_domain(self, test_db, session: AsyncSession, client: "name": gateway.name, "backend": backend.type.value, "region": gateway.region, + "router": None, "domain": "test.com", "default": False, "public_ip": True,