Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion gateway/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,12 @@ requires-python = ">=3.10"
dynamic = ["version"]
dependencies = [
# release builds of dstack-gateway depend on a PyPI version of dstack instead
"dstack[gateway] @ git+https://github.com/dstackai/dstack.git@master",
"dstack[gateway] @ git+https://github.com/Bihan/dstack.git@add_sglang_router_support",
]

[project.optional-dependencies]
sglang = ["sglang-router==0.2.2"]

[tool.setuptools.package-data]
"dstack.gateway" = [
"resources/systemd/*",
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/backends/aws/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,9 @@ def create_gateway(
image_id=aws_resources.get_gateway_image_id(ec2_client),
instance_type="t3.micro",
iam_instance_profile=None,
user_data=get_gateway_user_data(configuration.ssh_key_pub),
user_data=get_gateway_user_data(
configuration.ssh_key_pub, router=configuration.router
),
tags=tags,
security_group_id=security_group_id,
spot=False,
Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/backends/azure/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,9 @@ def create_gateway(
image_reference=_get_gateway_image_ref(),
vm_size="Standard_B1ms",
instance_name=instance_name,
user_data=get_gateway_user_data(configuration.ssh_key_pub),
user_data=get_gateway_user_data(
configuration.ssh_key_pub, router=configuration.router
),
ssh_pub_keys=[configuration.ssh_key_pub],
spot=False,
disk_size=30,
Expand Down
19 changes: 14 additions & 5 deletions src/dstack/_internal/core/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
SSHKey,
)
from dstack._internal.core.models.placement import PlacementGroup, PlacementGroupProvisioningData
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run
from dstack._internal.core.models.volumes import (
Volume,
Expand Down Expand Up @@ -876,7 +877,7 @@ def get_run_shim_script(
]


def get_gateway_user_data(authorized_key: str) -> str:
def get_gateway_user_data(authorized_key: str, router: Optional[AnyRouterConfig] = None) -> str:
return get_cloud_config(
package_update=True,
packages=[
Expand All @@ -892,7 +893,7 @@ def get_gateway_user_data(authorized_key: str) -> str:
"s/# server_names_hash_bucket_size 64;/server_names_hash_bucket_size 128;/",
"/etc/nginx/nginx.conf",
],
["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands())],
["su", "ubuntu", "-c", " && ".join(get_dstack_gateway_commands(router))],
],
ssh_authorized_keys=[authorized_key],
)
Expand Down Expand Up @@ -1021,16 +1022,24 @@ def get_dstack_gateway_wheel(build: str) -> str:
r.raise_for_status()
build = r.text.strip()
logger.debug("Found the latest gateway build: %s", build)
return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl"
# return f"{base_url}/dstack_gateway-{build}-py3-none-any.whl"
return "https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl"


def get_dstack_gateway_commands() -> List[str]:
def get_dstack_gateway_commands(router: Optional[AnyRouterConfig] = None) -> List[str]:
build = get_dstack_runner_version()
wheel = get_dstack_gateway_wheel(build)
# Use router type directly as pip extra
if router:
gateway_package = f"dstack-gateway[{router.type}]"
else:
gateway_package = "dstack-gateway"
return [
"mkdir -p /home/ubuntu/dstack",
"python3 -m venv /home/ubuntu/dstack/blue",
"python3 -m venv /home/ubuntu/dstack/green",
f"/home/ubuntu/dstack/blue/bin/pip install {get_dstack_gateway_wheel(build)}",
f"/home/ubuntu/dstack/blue/bin/pip install {wheel}",
f"/home/ubuntu/dstack/blue/bin/pip install --upgrade '{gateway_package}'",
Comment on lines +1041 to +1042
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) It should be possible to install the correct extra with one pip command:

$ pip install "dstack-gateway[sglang] @ https://bihan-test-bucket.s3.eu-west-1.amazonaws.com/dstack_gateway-0.0.1-py3-none-any.whl"

This should be slightly faster than calling pip twice.

"sudo /home/ubuntu/dstack/blue/bin/python -m dstack.gateway.systemd install --run",
]

Expand Down
4 changes: 3 additions & 1 deletion src/dstack/_internal/core/backends/gcp/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,7 +599,9 @@ def create_gateway(
machine_type="e2-medium",
accelerators=[],
spot=False,
user_data=get_gateway_user_data(configuration.ssh_key_pub),
user_data=get_gateway_user_data(
configuration.ssh_key_pub, router=configuration.router
),
authorized_keys=[configuration.ssh_key_pub],
labels=labels,
tags=[gcp_resources.DSTACK_GATEWAY_TAG],
Expand Down
6 changes: 6 additions & 0 deletions src/dstack/_internal/core/models/gateways.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import CoreModel
from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.utils.tags import tags_validator


Expand Down Expand Up @@ -50,6 +51,10 @@ class GatewayConfiguration(CoreModel):
default: Annotated[bool, Field(description="Make the gateway default")] = False
backend: Annotated[BackendType, Field(description="The gateway backend")]
region: Annotated[str, Field(description="The gateway region")]
router: Annotated[
Optional[AnyRouterConfig],
Field(description="The router configuration"),
] = None
domain: Annotated[
Optional[str], Field(description="The gateway domain, e.g. `example.com`")
] = None
Expand Down Expand Up @@ -113,6 +118,7 @@ class GatewayComputeConfiguration(CoreModel):
ssh_key_pub: str
certificate: Optional[AnyGatewayCertificate] = None
tags: Optional[Dict[str, str]] = None
router: Optional[AnyRouterConfig] = None


class GatewayProvisioningData(CoreModel):
Expand Down
34 changes: 34 additions & 0 deletions src/dstack/_internal/core/models/routers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from enum import Enum
from typing import Union

from pydantic import Field
from typing_extensions import Annotated, Literal

from dstack._internal.core.models.common import CoreModel


class RouterType(str, Enum):
    """Names of the model-router kinds a gateway can be configured with.

    Members subclass ``str``, so each one compares equal to its plain string value.
    """

    SGLANG = "sglang"
    SGLANG_DEPRECATED = "sglang_deprecated"
    SGLANG_NEW = "sglang_new"
    VLLM = "vllm"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vLLM routers are not supported yet, but the vLLM settings will still be visible to users in places like the automatically generated reference and IDE hints.

I'd suggest removing anything vLLM-related to avoid confusing users.



class SGLangRouterConfig(CoreModel):
    """Router configuration selected by ``type: sglang_deprecated``.

    Despite the class name, the discriminator value is ``"sglang_deprecated"``;
    the plain ``"sglang"`` value belongs to ``SGLangNewRouterConfig``.
    """

    # Discriminator used by the ``AnyRouterConfig`` union.
    type: Literal["sglang_deprecated"] = "sglang_deprecated"
    # NOTE(review): presumably the router's load-balancing policy; not validated here.
    policy: str = "cache_aware"


class SGLangNewRouterConfig(CoreModel):
    """Router configuration selected by ``type: sglang``.

    ``policy`` is limited to the load-balancing policies listed by the SGLang
    router's ``--policy`` option. An unsupported value is therefore rejected when
    the configuration is validated, instead of surfacing later as an opaque
    error when a service is started.
    """

    # Discriminator used by the ``AnyRouterConfig`` union.
    type: Literal["sglang"] = "sglang"
    # Restricted to known policies so a typo fails configuration validation.
    policy: Literal["random", "round_robin", "cache_aware", "power_of_two"] = "cache_aware"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, an invalid policy value, such as `policy: invalid`, is not rejected when the gateway is created. The gateway is provisioned successfully, but starting a service then fails with a hard-to-debug error:

Unexpected error: status code 500 when requesting http://localhost:3000/api/project/ilya/runs/apply. Check the server logs for backend issues, and the CLI
logs at (~/.dstack/logs/cli/latest.log) local CLI output

I suggest making this field a `typing.Literal` that only allows the valid values:

  --policy {random,round_robin,cache_aware,power_of_two}
                        Load balancing policy to use. In PD mode, this is used for both prefill and decode unless overridden (default: cache_aware)

Then the user will see a detailed configuration validation error when trying to create the gateway.



class VLLMRouterConfig(CoreModel):
    """Router configuration selected by ``type: vllm``.

    NOTE(review): ``get_router`` has no branch for ``"vllm"`` and raises
    ``ValueError`` for it, so this configuration validates but cannot be
    turned into a running router.
    """

    # Discriminator used by the ``AnyRouterConfig`` union.
    type: Literal["vllm"] = "vllm"
    # NOTE(review): free-form string; not validated here.
    policy: str = "cache_aware"


# Union of all router configuration models. The `type` field is the
# discriminator, so its literal value decides which model a given
# configuration is validated against.
AnyRouterConfig = Annotated[
    Union[SGLangRouterConfig, SGLangNewRouterConfig, VLLMRouterConfig], Field(discriminator="type")
]
26 changes: 26 additions & 0 deletions src/dstack/_internal/proxy/gateway/model_routers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from typing import Optional

from dstack._internal.core.models.routers import AnyRouterConfig
from dstack._internal.proxy.gateway.model_routers.sglang import SglangRouter
from dstack._internal.proxy.gateway.model_routers.sglang_new import SglangRouterNew

from .base import Replica, Router, RouterContext


def get_router(router: AnyRouterConfig, context: Optional[RouterContext] = None) -> Router:
    """Build the router implementation that matches ``router.type``.

    Args:
        router: Router configuration; its ``type`` selects the implementation.
        context: Optional runtime context passed through to the router.

    Returns:
        ``SglangRouter`` for ``"sglang_deprecated"``, ``SglangRouterNew`` for
        ``"sglang"`` and ``"sglang_new"``.

    Raises:
        ValueError: If ``router.type`` is none of the values above.
    """
    kind = router.type
    if kind == "sglang_deprecated":
        router_cls = SglangRouter
    elif kind in ("sglang", "sglang_new"):
        # Both names map to the same implementation.
        router_cls = SglangRouterNew
    else:
        raise ValueError(f"Router type '{router.type}' is not available")
    return router_cls(router=router, context=context)


# Public API of the package: the base types re-exported from `.base`
# plus the `get_router` factory defined in this module.
__all__ = [
    "Router",
    "RouterContext",
    "Replica",
    "get_router",
]
147 changes: 147 additions & 0 deletions src/dstack/_internal/proxy/gateway/model_routers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Literal, Optional

from pydantic import BaseModel

from dstack._internal.core.models.routers import AnyRouterConfig


class RouterContext(BaseModel):
    """Runtime settings handed to a router: address, port and logging options."""

    class Config:
        # Pydantic model configuration.
        frozen = True

    # Address and port for the router.
    host: str = "127.0.0.1"
    port: int = 3000
    # Logging settings; the default directory is relative to the working directory.
    log_dir: Path = Path("./router_logs")
    log_level: Literal["debug", "info", "warning", "error"] = "info"


class Replica(BaseModel):
    """One replica (worker) endpoint managed by the router.

    ``model`` names the model the replica serves. In SGLang this is the
    model id, e.g. "meta-llama/Meta-Llama-3.1-8B-Instruct".
    """

    # HTTP URL where the replica is accessible, e.g. "http://127.0.0.1:10001".
    url: str
    # Model identifier, e.g. "meta-llama/Meta-Llama-3.1-8B-Instruct".
    model: str


class Router(ABC):
"""Abstract base class for router implementations (e.g., SGLang, vLLM).
A router manages the lifecycle of worker replicas and handles request routing.
Different router implementations may have different mechanisms for managing
replicas.
"""

def __init__(
self,
router: Optional[AnyRouterConfig] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) It seems a bit confusing to me that to construct a Router you need a router.

Maybe `config` would be a better name for this parameter.

context: Optional[RouterContext] = None,
):
"""Initialize router with context.
Args:
router: Optional router configuration (implementation-specific)
context: Runtime context for the router (host, port, logging, etc.)
"""
self.context = context or RouterContext()

@abstractmethod
def start(self) -> None:
    """Launch the router process.

    Raises:
        Exception: When the router cannot be started.
    """
    ...

@abstractmethod
def stop(self) -> None:
    """Shut the router process down.

    Raises:
        Exception: When the router cannot be stopped.
    """
    ...
Comment on lines +63 to +70
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Unused, can delete here and in SglangRouter


@abstractmethod
def is_running(self) -> bool:
    """Report whether the router is up and responding.

    Returns:
        True when the router is running and healthy, False otherwise.
    """
    ...

@abstractmethod
def register_replicas(
self, domain: str, num_replicas: int, model_id: Optional[str] = None
) -> List[Replica]:
"""Register replicas to a domain (allocate ports/URLs for workers).
Args:
domain: The domain name for this service.
num_replicas: The number of replicas to allocate for this domain.
model_id: Optional model identifier (e.g., "meta-llama/Meta-Llama-3.1-8B-Instruct").
Required only for routers that support IGW (Inference Gateway) mode for multi-model serving.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) This method is supposed to return `list[Replica]`, and `Replica.model` is required, so I suggest making the `model_id` parameter required too. We can make it optional later if we actually need that for some other router type.

Returns:
List of Replica objects with allocated URLs and model_id set (if provided).
Raises:
Exception: If allocation fails.
"""
...

@abstractmethod
def unregister_replicas(self, domain: str) -> None:
    """Drop a domain: remove its model and unassign all of its replicas.

    Args:
        domain: Domain name of the service.

    Raises:
        Exception: When removal fails or the domain is not found.
    """
    ...

@abstractmethod
def add_replicas(self, replicas: List[Replica]) -> None:
    """Add workers to the router (the actual API calls that register them).

    Args:
        replicas: Replicas to add to the router.

    Raises:
        Exception: When adding the replicas fails.
    """
    ...

@abstractmethod
def remove_replicas(self, replicas: List[Replica]) -> None:
    """Remove workers from the router (the actual API calls that unregister them).

    Args:
        replicas: Replicas to remove from the router.

    Raises:
        Exception: When removing the replicas fails.
    """
    ...
Comment on lines +113 to +135
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Unused, can delete here and in SglangRouter


@abstractmethod
def update_replicas(self, replicas: List[Replica]) -> None:
    """Replace the service's current set of replicas with ``replicas``.

    Args:
        replicas: The new, complete list of replicas for the service.

    Raises:
        Exception: When updating the replicas fails.
    """
    ...
Loading