diff --git a/contributing/AUTOSCALING.md b/contributing/AUTOSCALING.md
index 7fa987aaf..60a33aec1 100644
--- a/contributing/AUTOSCALING.md
+++ b/contributing/AUTOSCALING.md
@@ -11,6 +11,8 @@
- STEP 7: `scale_run_replicas` terminates or starts replicas.
- `SUBMITTED` and `PROVISIONING` replicas get terminated before `RUNNING`.
- Replicas are terminated by descending `replica_num` and launched by ascending `replica_num`.
+ - For services with `replica_groups`, only groups with an autoscaling range (min != max) participate in scaling; fixed groups keep their replica count.
+ - Scale operations respect per-group minimum and maximum constraints (see the sketch below).
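+
+   A minimal sketch of a mixed configuration (illustrative values):
+
+   ```yaml
+   replica_groups:
+     - name: primary
+       replicas: 1      # fixed, never scaled
+       resources:
+         gpu: H100:1
+     - name: overflow
+       replicas: 0..5   # autoscaled between 0 and 5
+       resources:
+         gpu: RTX5090:1
+   scaling:
+     metric: rps
+     target: 10
+   ```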
## RPSAutoscaler
diff --git a/contributing/RUNS-AND-JOBS.md b/contributing/RUNS-AND-JOBS.md
index b2c0430af..18544caa5 100644
--- a/contributing/RUNS-AND-JOBS.md
+++ b/contributing/RUNS-AND-JOBS.md
@@ -13,7 +13,7 @@ Runs are created from run configurations. There are three types of run configura
2. `task` — runs the user's bash script until completion.
3. `service` — runs the user's bash script and exposes a port through [dstack-proxy](PROXY.md).
-A run can spawn one or multiple jobs, depending on the configuration. A task that specifies multiple `nodes` spawns a job for every node (a multi-node task). A service that specifies multiple `replicas` spawns a job for every replica. A job submission is always assigned to one particular instance. If a job fails and the configuration allows retrying, the server creates a new job submission for the job.
+A run can spawn one or multiple jobs, depending on the configuration. A task that specifies multiple `nodes` spawns a job for every node (a multi-node task). A service that specifies multiple `replicas` or `replica_groups` spawns a job for every replica. Each job in a replica group is tagged with `replica_group_name` to track which group it belongs to. A job submission is always assigned to one particular instance. If a job fails and the configuration allows retrying, the server creates a new job submission for the job.
## Run's Lifecycle
diff --git a/docs/docs/concepts/services.md b/docs/docs/concepts/services.md
index cb2649e00..383738538 100644
--- a/docs/docs/concepts/services.md
+++ b/docs/docs/concepts/services.md
@@ -160,6 +160,66 @@ Setting the minimum number of replicas to `0` allows the service to scale down t
> The `scaling` property requires creating a [gateway](gateways.md).
+### Replica Groups (Advanced)
+
+For advanced use cases, you can define multiple **replica groups** with different instance types, resources, and configurations within a single service. This is useful when you want to:
+
+- Run different GPU types in the same service (e.g., H100 for primary, RTX5090 for overflow)
+- Configure different backends or regions per replica type
+- Set different autoscaling behavior per group
+
+
+```yaml
+type: service
+name: llama31-service
+
+python: 3.12
+env:
+ - HF_TOKEN
+commands:
+ - uv pip install vllm
+ - vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --max-model-len 4096
+port: 8000
+
+# Define multiple replica groups with different configurations
+replica_groups:
+ - name: primary
+ replicas: 1 # Always 1 H100 (fixed)
+ resources:
+ gpu: H100:1
+ backends: [aws]
+ regions: [us-west-2]
+
+ - name: overflow
+ replicas: 0..5 # Autoscales 0-5 RTX5090s
+ resources:
+ gpu: RTX5090:1
+ backends: [runpod]
+
+scaling:
+ metric: rps
+ target: 10
+```
+
+In this example:
+
+- The `primary` group always runs 1 H100 replica on AWS (fixed, never scaled)
+- The `overflow` group scales 0-5 RTX5090 replicas on RunPod based on load
+- Scale operations only affect groups with autoscaling ranges (min != max)
+
+Each replica group can override any [profile parameter](../reference/profiles.yml.md) including `backends`, `regions`, `instance_types`, `spot_policy`, etc. Group-level settings override service-level settings.
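+
+For example, a single group can allow spot instances and cap its price while the rest of the service keeps the service-level defaults (illustrative values):
+
+```yaml
+replica_groups:
+  - name: overflow
+    replicas: 0..5
+    resources:
+      gpu: RTX5090:1
+    spot_policy: auto
+    max_price: 1.5
+```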
+
+> **Note:** `replica_groups` and the top-level `replicas` field are mutually exclusive; specify one or the other.
+
+**When to use replica groups:**
+
+- You need different GPU types in the same service
+- Different replicas should run in different regions or clouds
+- Some replicas should be fixed while others autoscale
+
### Model
If the service is running a chat model with an OpenAI-compatible interface,
diff --git a/docs/docs/reference/dstack.yml/service.md b/docs/docs/reference/dstack.yml/service.md
index 8d89b2d57..40509332e 100644
--- a/docs/docs/reference/dstack.yml/service.md
+++ b/docs/docs/reference/dstack.yml/service.md
@@ -10,6 +10,22 @@ The `service` configuration type allows running [services](../../concepts/servic
type:
required: true
+### `replica_groups`
+
+Define multiple replica groups with different configurations within a single service.
+
+> **Note:** Cannot be used together with `replicas`.
+
+#### `replica_groups[n]`
+
+#SCHEMA# dstack._internal.core.models.configurations.ReplicaGroup
+ overrides:
+ show_root_heading: false
+ type:
+ required: true
+
+Each replica group inherits from [ProfileParams](../profiles.yml.md) and can override any profile parameter including `backends`, `regions`, `instance_types`, `spot_policy`, etc.
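+
+For example (illustrative values):
+
+```yaml
+replica_groups:
+  - name: primary
+    replicas: 1
+    resources:
+      gpu: H100:1
+  - name: overflow
+    replicas: 0..5
+    resources:
+      gpu: RTX5090:1
+    backends: [runpod]
+    spot_policy: auto
+```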
+
### `model` { data-toc-label="model" }
=== "OpenAI"
diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py
index 58497c084..b61bf94fb 100644
--- a/src/dstack/_internal/cli/utils/run.py
+++ b/src/dstack/_internal/cli/utils/run.py
@@ -119,9 +119,55 @@ def th(s: str) -> str:
if include_run_properties:
props.add_row(th("Configuration"), run_spec.configuration_path)
props.add_row(th("Type"), run_spec.configuration.type)
- props.add_row(th("Resources"), pretty_req)
- props.add_row(th("Spot policy"), spot_policy)
- props.add_row(th("Max price"), max_price)
+
+ from dstack._internal.core.models.configurations import ServiceConfiguration
+
+ has_replica_groups = (
+ include_run_properties
+ and isinstance(run_spec.configuration, ServiceConfiguration)
+ and run_spec.configuration.replica_groups
+ )
+
+ if has_replica_groups:
+ groups_info = []
+ for group in run_spec.configuration.replica_groups:
+ group_parts = [f"[cyan]{group.name}[/cyan]"]
+
+ # Replica count
+ if group.replicas.min == group.replicas.max:
+ group_parts.append(f"×{group.replicas.max}")
+ else:
+ group_parts.append(f"×{group.replicas.min}..{group.replicas.max}")
+ group_parts.append("[dim](autoscalable)[/dim]")
+
+ # Resources
+ group_parts.append(f"[dim]({group.resources.pretty_format()})[/dim]")
+
+ # Group-specific overrides
+ overrides = []
+ if group.spot_policy is not None:
+ overrides.append(f"spot={group.spot_policy.value}")
+ if group.regions:
+ regions_str = ",".join(group.regions[:2]) # Show first 2
+ if len(group.regions) > 2:
+ regions_str += f",+{len(group.regions) - 2}"
+ overrides.append(f"regions={regions_str}")
+ if group.backends:
+ backends_str = ",".join([b.value for b in group.backends[:2]])
+ if len(group.backends) > 2:
+ backends_str += f",+{len(group.backends) - 2}"
+ overrides.append(f"backends={backends_str}")
+
+ if overrides:
+ group_parts.append(f"[dim]({'; '.join(overrides)})[/dim]")
+
+ groups_info.append(" ".join(group_parts))
+
+ props.add_row(th("Replica groups"), "\n".join(groups_info))
+ else:
+ props.add_row(th("Resources"), pretty_req)
+ props.add_row(th("Spot policy"), spot_policy)
+ props.add_row(th("Max price"), max_price)
if include_run_properties:
props.add_row(th("Retry policy"), retry)
props.add_row(th("Creation policy"), creation_policy)
@@ -139,44 +185,172 @@ def th(s: str) -> str:
offers.add_column("PRICE", style="grey58", ratio=1)
offers.add_column()
- job_plan.offers = job_plan.offers[:max_offers] if max_offers else job_plan.offers
-
- for i, offer in enumerate(job_plan.offers, start=1):
- r = offer.instance.resources
-
- availability = ""
- if offer.availability in {
- InstanceAvailability.NOT_AVAILABLE,
- InstanceAvailability.NO_QUOTA,
- InstanceAvailability.IDLE,
- InstanceAvailability.BUSY,
- }:
- availability = offer.availability.value.replace("_", " ").lower()
- instance = offer.instance.name
- if offer.total_blocks > 1:
- instance += f" ({offer.blocks}/{offer.total_blocks})"
- offers.add_row(
- f"{i}",
- offer.backend.replace("remote", "ssh") + " (" + offer.region + ")",
- r.pretty_format(include_spot=True),
- instance,
- f"${offer.price:.4f}".rstrip("0").rstrip("."),
- availability,
- style=None if i == 1 or not include_run_properties else "secondary",
- )
- if job_plan.total_offers > len(job_plan.offers):
- offers.add_row("", "...", style="secondary")
+ # For replica groups, show offers from all job plans
+ if len(run_plan.job_plans) > 1:
+ # Multiple jobs - ensure fair representation of all groups
+ groups_with_no_offers = []
+ groups_with_offers = {}
+ total_offers_count = 0
+
+ # Collect offers per group
+ for jp in run_plan.job_plans:
+ group_name = jp.job_spec.replica_group_name or "default"
+ if jp.total_offers == 0:
+ groups_with_no_offers.append(group_name)
+ else:
+ groups_with_offers[group_name] = jp.offers
+ total_offers_count += jp.total_offers
+
+ # Strategy: Show at least min_per_group offers from each group, then fill with cheapest
+ num_groups = len(groups_with_offers)
+ if num_groups > 0 and max_offers:
+ min_per_group = max(
+ 1, max_offers // (num_groups * 2)
+            )  # at least 1; per-group minimums use about half of max_offers
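+            # e.g. max_offers=10 with 2 groups: min_per_group = 10 // 4 = 2, so each
+            # group contributes at least two offers before the rest are filled by price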
+ remaining_slots = max_offers
+ else:
+ min_per_group = None
+ remaining_slots = None
+
+ selected_offers = []
+
+ # First pass: Take min_per_group from each group (cheapest from each)
+ if min_per_group:
+ for group_name, group_offers in groups_with_offers.items():
+ sorted_group_offers = sorted(group_offers, key=lambda x: x.price)
+ take_count = min(min_per_group, len(sorted_group_offers), remaining_slots)
+ for offer in sorted_group_offers[:take_count]:
+ selected_offers.append((group_name, offer))
+ remaining_slots -= take_count
+
+ # Second pass: Fill remaining slots with cheapest offers globally
+ if remaining_slots and remaining_slots > 0:
+ all_remaining = []
+ for group_name, group_offers in groups_with_offers.items():
+ sorted_group_offers = sorted(group_offers, key=lambda x: x.price)
+ # Skip offers already selected
+ for offer in sorted_group_offers[min_per_group:]:
+ all_remaining.append((group_name, offer))
+
+ # Sort remaining by price and take the cheapest
+ all_remaining.sort(key=lambda x: x[1].price)
+ selected_offers.extend(all_remaining[:remaining_slots])
+
+ # If no max_offers limit, show all
+ if not max_offers:
+ selected_offers = []
+ for group_name, group_offers in groups_with_offers.items():
+ for offer in group_offers:
+ selected_offers.append((group_name, offer))
+
+ # Sort final selection by price for display
+ selected_offers.sort(key=lambda x: x[1].price)
+
+ # Show groups with no offers FIRST
+ for group_name in groups_with_no_offers:
+ offers.add_row(
+ "",
+ f"[cyan]{group_name}[/cyan]:",
+ "[red]No matching instance offers available.[/red]\n"
+ "Possible reasons: https://dstack.ai/docs/guides/troubleshooting/#no-offers",
+ "",
+ "",
+ "",
+ style="secondary",
+ )
+
+ # Then show selected offers
+ for i, (group_name, offer) in enumerate(selected_offers, start=1):
+ r = offer.instance.resources
+
+ availability = ""
+ if offer.availability in {
+ InstanceAvailability.NOT_AVAILABLE,
+ InstanceAvailability.NO_QUOTA,
+ InstanceAvailability.IDLE,
+ InstanceAvailability.BUSY,
+ }:
+ availability = offer.availability.value.replace("_", " ").lower()
+ instance = offer.instance.name
+ if offer.total_blocks > 1:
+ instance += f" ({offer.blocks}/{offer.total_blocks})"
+
+ # Add group name prefix for multi-group display
+ backend_display = f"[cyan]{group_name}[/cyan]: {offer.backend.replace('remote', 'ssh')} ({offer.region})"
+
+ offers.add_row(
+ f"{i}",
+ backend_display,
+ r.pretty_format(include_spot=True),
+ instance,
+ f"${offer.price:.4f}".rstrip("0").rstrip("."),
+ availability,
+ style=None if i == 1 or not include_run_properties else "secondary",
+ )
+
+ if total_offers_count > len(selected_offers):
+ offers.add_row("", "...", style="secondary")
+ else:
+ # Single job - original logic
+ job_plan.offers = job_plan.offers[:max_offers] if max_offers else job_plan.offers
+
+ for i, offer in enumerate(job_plan.offers, start=1):
+ r = offer.instance.resources
+
+ availability = ""
+ if offer.availability in {
+ InstanceAvailability.NOT_AVAILABLE,
+ InstanceAvailability.NO_QUOTA,
+ InstanceAvailability.IDLE,
+ InstanceAvailability.BUSY,
+ }:
+ availability = offer.availability.value.replace("_", " ").lower()
+ instance = offer.instance.name
+ if offer.total_blocks > 1:
+ instance += f" ({offer.blocks}/{offer.total_blocks})"
+ offers.add_row(
+ f"{i}",
+ offer.backend.replace("remote", "ssh") + " (" + offer.region + ")",
+ r.pretty_format(include_spot=True),
+ instance,
+ f"${offer.price:.4f}".rstrip("0").rstrip("."),
+ availability,
+ style=None if i == 1 or not include_run_properties else "secondary",
+ )
+ if job_plan.total_offers > len(job_plan.offers):
+ offers.add_row("", "...", style="secondary")
console.print(props)
console.print()
- if len(job_plan.offers) > 0:
+
+ # Check if we have offers to display
+ has_offers = False
+ if len(run_plan.job_plans) > 1:
+ has_offers = any(len(jp.offers) > 0 for jp in run_plan.job_plans)
+ else:
+ has_offers = len(job_plan.offers) > 0
+
+ if has_offers:
console.print(offers)
- if job_plan.total_offers > len(job_plan.offers):
- console.print(
- f"[secondary] Shown {len(job_plan.offers)} of {job_plan.total_offers} offers, "
- f"${job_plan.max_price:3f}".rstrip("0").rstrip(".")
- + "max[/]"
- )
+ # Show summary for multi-job plans
+ if len(run_plan.job_plans) > 1:
+ if total_offers_count > len(selected_offers):
+ max_price_overall = max(
+ (jp.max_price for jp in run_plan.job_plans if jp.max_price), default=None
+ )
+ if max_price_overall:
+ console.print(
+ f"[secondary] Shown {len(selected_offers)} of {total_offers_count} offers, "
+ f"${max_price_overall:3f}".rstrip("0").rstrip(".")
+ + " max[/]"
+ )
+ else:
+ if job_plan.total_offers > len(job_plan.offers):
+ console.print(
+ f"[secondary] Shown {len(job_plan.offers)} of {job_plan.total_offers} offers, "
+ f"${job_plan.max_price:3f}".rstrip("0").rstrip(".")
+ + " max[/]"
+ )
console.print()
else:
console.print(NO_OFFERS_WARNING)
@@ -233,8 +407,14 @@ def get_runs_table(
if verbose and latest_job_submission.inactivity_secs:
inactive_for = format_duration_multiunit(latest_job_submission.inactivity_secs)
status += f" (inactive for {inactive_for})"
+
+ job_name_parts = [f" replica={job.job_spec.replica_num}"]
+ if job.job_spec.replica_group_name:
+ job_name_parts.append(f"[cyan]group={job.job_spec.replica_group_name}[/cyan]")
+ job_name_parts.append(f"job={job.job_spec.job_num}")
+
job_row: Dict[Union[str, int], Any] = {
- "NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}"
+ "NAME": " ".join(job_name_parts)
+ (
f" deployment={latest_job_submission.deployment_num}"
if show_deployment_num
diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py
index a7a9d40a3..3c4b9c5c9 100644
--- a/src/dstack/_internal/core/compatibility/runs.py
+++ b/src/dstack/_internal/core/compatibility/runs.py
@@ -151,6 +151,9 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType:
configuration_excludes["schedule"] = True
if profile is not None and profile.schedule is None:
profile_excludes.add("schedule")
+ # Exclude replica_groups for backward compatibility with older servers
+ if isinstance(configuration, ServiceConfiguration) and configuration.replica_groups is None:
+ configuration_excludes["replica_groups"] = True
configuration_excludes["repos"] = True
if configuration_excludes:
diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py
index 6fe8132de..f687aec30 100644
--- a/src/dstack/_internal/core/models/configurations.py
+++ b/src/dstack/_internal/core/models/configurations.py
@@ -685,6 +685,52 @@ class TaskConfiguration(
type: Literal["task"] = "task"
+class ReplicaGroupConfig(ProfileParamsConfig):
+ @staticmethod
+ def schema_extra(schema: Dict[str, Any]):
+ ProfileParamsConfig.schema_extra(schema)
+ add_extra_schema_types(
+ schema["properties"]["replicas"],
+ extra_types=[{"type": "integer"}, {"type": "string"}],
+ )
+
+
+class ReplicaGroup(ProfileParams, generate_dual_core_model(ReplicaGroupConfig)):
+ """
+ A replica group defines a set of service replicas with specific resource requirements
+ and provisioning parameters.
+ """
+
+ name: Annotated[str, Field(description="Group name (must be unique within the service)")]
+ replicas: Annotated[
+ Range[int],
+ Field(
+ description="Number of replicas. Can be a fixed number (e.g., `2`) or a range (`1..3`). "
+ "If it's a range, the group can be autoscaled"
+ ),
+ ]
+ resources: Annotated[
+ ResourcesSpec,
+ Field(description="Resource requirements for replicas in this group"),
+ ]
+
+ @validator("name")
+ def validate_name(cls, v):
+ if not v or not v.strip():
+ raise ValueError("Group name cannot be empty")
+ return v
+
+ @validator("replicas")
+ def convert_replicas(cls, v: Range[int]) -> Range[int]:
+ if v.max is None:
+ raise ValueError("The maximum number of replicas is required")
+ if v.min is None:
+ v.min = 0
+ if v.min < 0:
+ raise ValueError("The minimum number of replicas must be greater than or equal to 0")
+ return v
+
+
class ServiceConfigurationParamsConfig(CoreConfig):
@staticmethod
def schema_extra(schema: Dict[str, Any]):
@@ -754,6 +800,13 @@ class ServiceConfigurationParams(CoreModel):
list[ProbeConfig],
Field(description="List of probes used to determine job health"),
] = []
+ replica_groups: Annotated[
+ Optional[List[ReplicaGroup]],
+ Field(
+ description="Define multiple replica groups with different configurations. "
+ "Cannot be used together with 'replicas'"
+ ),
+ ] = None
@validator("port")
def convert_port(cls, v) -> PortMapping:
@@ -789,14 +842,46 @@ def validate_gateway(
)
return v
+ @root_validator()
+ def validate_replica_groups_xor_replicas(cls, values):
+ replica_groups = values.get("replica_groups")
+ replicas = values.get("replicas")
+
+        # `replicas` defaults to 1..1, so any other value means it was set explicitly
+ has_groups = replica_groups is not None
+ has_replicas = replicas != Range[int](min=1, max=1)
+
+ if has_groups and has_replicas:
+ raise ValueError("Cannot specify both 'replicas' and 'replica_groups'")
+
+ if has_groups:
+            # Validate at least one group
+            if not replica_groups:
+                raise ValueError("replica_groups cannot be empty")
+
+            # Validate unique names
+            names = [g.name for g in replica_groups]
+            if len(names) != len(set(names)):
+                raise ValueError("Replica group names must be unique")
+
+ return values
+
@root_validator()
def validate_scaling(cls, values):
scaling = values.get("scaling")
replicas = values.get("replicas")
- if replicas and replicas.min != replicas.max and not scaling:
+ replica_groups = values.get("replica_groups")
+
+ if replica_groups:
+ # Check if any group has a range
+ has_range = any(g.replicas.min != g.replicas.max for g in replica_groups)
+        if has_range and not scaling:
+            raise ValueError(
+                "When a replica group sets `replicas` to a range, ensure to specify `scaling`."
+            )
+        if not has_range and scaling:
+            raise ValueError(
+                "To use `scaling`, at least one replica group must set `replicas` to a range."
+            )
+ elif replicas and replicas.min != replicas.max and not scaling:
raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
- if replicas and replicas.min == replicas.max and scaling:
+ elif replicas and replicas.min == replicas.max and scaling:
raise ValueError("To use `scaling`, `replicas` must be set to a range.")
+
return values
@validator("rate_limits")
diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py
index 0a5b174d2..7efa83902 100644
--- a/src/dstack/_internal/core/models/runs.py
+++ b/src/dstack/_internal/core/models/runs.py
@@ -1,11 +1,14 @@
from datetime import datetime, timedelta
from enum import Enum
-from typing import Any, Dict, List, Literal, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
from urllib.parse import urlparse
from pydantic import UUID4, Field, root_validator
from typing_extensions import Annotated
+if TYPE_CHECKING:
+ from dstack._internal.core.models.configurations import ReplicaGroup, ServiceConfiguration
+
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.common import (
ApplyAction,
@@ -247,6 +250,7 @@ class ProbeSpec(CoreModel):
class JobSpec(CoreModel):
replica_num: int = 0 # default value for backward compatibility
+ replica_group_name: Optional[str] = None
job_num: int
job_name: str
jobs_per_replica: int = 1 # default value for backward compatibility
@@ -618,3 +622,37 @@ def get_service_port(job_spec: JobSpec, configuration: ServiceConfiguration) ->
if job_spec.service_port is None:
return configuration.port.container_port
return job_spec.service_port
+
+
+def get_normalized_replica_groups(configuration: "ServiceConfiguration") -> List["ReplicaGroup"]:
+ """
+ Normalize service configuration to replica groups.
+ Converts legacy replicas field to a single "default" group for backward compatibility.
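+
+    For example, a legacy service with `replicas: 2` normalizes to a single group:
+    `[ReplicaGroup(name="default", replicas=Range(min=2, max=2), ...)]`.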
+ """
+ from dstack._internal.core.models.configurations import ReplicaGroup
+
+ if configuration.replica_groups:
+ return configuration.replica_groups
+
+ return [
+ ReplicaGroup(
+ name="default",
+ replicas=configuration.replicas,
+ resources=configuration.resources,
+ backends=configuration.backends,
+ regions=configuration.regions,
+ availability_zones=configuration.availability_zones,
+ instance_types=configuration.instance_types,
+ reservation=configuration.reservation,
+ spot_policy=configuration.spot_policy,
+ retry=configuration.retry,
+ max_duration=configuration.max_duration,
+ stop_duration=configuration.stop_duration,
+ max_price=configuration.max_price,
+ creation_policy=configuration.creation_policy,
+ idle_duration=configuration.idle_duration,
+ utilization_policy=configuration.utilization_policy,
+ startup_order=configuration.startup_order,
+ stop_criteria=configuration.stop_criteria,
+ )
+ ]
diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py
index df1cce72f..0be74df83 100644
--- a/src/dstack/_internal/server/background/tasks/process_runs.py
+++ b/src/dstack/_internal/server/background/tasks/process_runs.py
@@ -156,6 +156,10 @@ async def _process_run(session: AsyncSession, run_model: RunModel):
)
run_model = res.unique().scalar_one()
logger.debug("%s: processing run", fmt(run_model))
+
+ # Migrate legacy jobs without replica_group_name (one-time fix)
+ await _migrate_legacy_job_replica_groups(session, run_model)
+
try:
if run_model.status == RunStatus.PENDING:
await _process_pending_run(session, run_model)
@@ -176,6 +180,70 @@ async def _process_run(session: AsyncSession, run_model: RunModel):
await session.commit()
+async def _migrate_legacy_job_replica_groups(session: AsyncSession, run_model: RunModel):
+ """
+ Migrate jobs from old runs that don't have replica_group_name set.
+ This fixes jobs created before the replica_groups feature was added.
+ """
+ run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
+
+ # Only migrate service runs with replica_groups
+ if run_spec.configuration.type != "service":
+ return
+
+ # Check if run uses replica_groups
+ if not getattr(run_spec.configuration, "replica_groups", None):
+ return
+
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+ normalized_groups = get_normalized_replica_groups(run_spec.configuration)
+
+ # Check if any jobs need migration
+ needs_migration = any(job.replica_group_name is None for job in run_model.jobs)
+
+ if not needs_migration:
+ return
+
+ logger.info(
+ "%s: Migrating legacy jobs to assign replica_group_name",
+ fmt(run_model),
+ )
+
+ # Build a map of replica_num -> group_name based on how jobs were originally created
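+    # e.g. groups [primary(min=1), overflow(min=2)] yield
+    # {0: "primary", 1: "overflow", 2: "overflow"}; replicas beyond the group
+    # minimums are not mapped and keep replica_group_name=None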
+ replica_num_to_group = {}
+ current_replica_num = 0
+
+ for group in normalized_groups:
+ group_min = group.replicas.min or 0
+ for _ in range(group_min):
+ replica_num_to_group[current_replica_num] = group.name
+ current_replica_num += 1
+
+ # Update jobs
+ migrated_count = 0
+ for job in run_model.jobs:
+ if job.replica_group_name is None:
+ expected_group = replica_num_to_group.get(job.replica_num)
+ if expected_group:
+ job.replica_group_name = expected_group
+ migrated_count += 1
+ logger.info(
+ "%s: Migrated job replica_num=%d to group '%s'",
+ fmt(run_model),
+ job.replica_num,
+ expected_group,
+ )
+
+ if migrated_count > 0:
+ await session.commit()
+ logger.info(
+ "%s: Migrated %d job(s) to replica groups",
+ fmt(run_model),
+ migrated_count,
+ )
+
+
async def _process_pending_run(session: AsyncSession, run_model: RunModel):
"""Jobs are not created yet"""
run = run_model_to_run(run_model)
@@ -189,7 +257,10 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
run_model.desired_replica_count = 1
if run.run_spec.configuration.type == "service":
- run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+ normalized_groups = get_normalized_replica_groups(run.run_spec.configuration)
+ run_model.desired_replica_count = sum(g.replicas.min or 0 for g in normalized_groups)
await update_service_desired_replica_count(
session,
run_model,
@@ -481,6 +552,7 @@ async def _handle_run_replicas(
session,
run_model,
replicas_diff=max_replica_count - non_terminated_replica_count,
+ allow_exceeding_max=True, # Allow exceeding max for rolling deployments
)
replicas_to_stop_count = 0
@@ -507,6 +579,7 @@ async def _handle_run_replicas(
session,
run_model,
replicas_diff=-replicas_to_stop_count,
+ allow_exceeding_max=True, # Allow terminating out-of-date replicas during rolling deployment
)
@@ -516,20 +589,49 @@ async def _update_jobs_to_new_deployment_in_place(
"""
Bump deployment_num for jobs that do not require redeployment.
"""
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
secrets = await get_project_secrets_mapping(
session=session,
project=run_model.project,
)
+
+    # Get replica groups from the new run_spec for matching; replica_group is
+    # assigned per replica inside the loop below
+ if run_spec.configuration.type == "service":
+ normalized_groups = get_normalized_replica_groups(run_spec.configuration)
+ else:
+ normalized_groups = []
+
for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
if all(j.status.is_finished() for j in job_models):
continue
if all(j.deployment_num == run_model.deployment_num for j in job_models):
continue
+
+ # Determine which replica group this job belongs to
+ # Use the old job's replica_group_name to find the matching group in new spec
+ old_job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data)
+ if old_job_spec.replica_group_name and normalized_groups:
+ replica_group = next(
+ (g for g in normalized_groups if g.name == old_job_spec.replica_group_name),
+ None,
+ )
+ if replica_group is None:
+ logger.warning(
+ "Replica group '%s' from old job not found in new run_spec. "
+ "Job will use base configuration.",
+ old_job_spec.replica_group_name,
+ )
+ else:
+ replica_group = None
+
# FIXME: Handle getting image configuration errors or skip it.
new_job_specs = await get_job_specs_from_run_spec(
run_spec=run_spec,
secrets=secrets,
replica_num=replica_num,
+ replica_group=replica_group,
)
assert len(new_job_specs) == len(job_models), (
"Changing the number of jobs within a replica is not yet supported"
diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
index 2814840b5..ea0422808 100644
--- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
+++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py
@@ -609,11 +609,15 @@ def _get_nodes_required_num_for_run(run_spec: RunSpec) -> int:
nodes_required_num = 1
if run_spec.configuration.type == "task":
nodes_required_num = run_spec.configuration.nodes
- elif (
- run_spec.configuration.type == "service"
- and run_spec.configuration.replicas.min is not None
- ):
- nodes_required_num = run_spec.configuration.replicas.min
+ elif run_spec.configuration.type == "service":
+ # Use groups if present
+ if run_spec.configuration.replica_groups:
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+ groups = get_normalized_replica_groups(run_spec.configuration)
+ nodes_required_num = sum(g.replicas.min or 0 for g in groups)
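+            # e.g. primary(min=1) + overflow(min=0) -> 1 node required up front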
+ elif run_spec.configuration.replicas.min is not None:
+ nodes_required_num = run_spec.configuration.replicas.min
return nodes_required_num
@@ -728,6 +732,62 @@ async def _assign_job_to_fleet_instance(
return instance
+def _get_profile_for_job(run_spec: RunSpec, job: Job) -> Profile:
+ """Get merged profile with group overrides for this job."""
+ from dstack._internal.core.models.profiles import Profile, ProfileParams
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+ base_profile = run_spec.merged_profile
+
+ group_name = job.job_spec.replica_group_name
+    logger.debug(
+ "Getting profile for job %s: replica_group_name=%s, config_type=%s",
+ job.job_spec.job_name,
+ group_name,
+ run_spec.configuration.type,
+ )
+
+ if not group_name or run_spec.configuration.type != "service":
+ logger.info("Using base profile (no group_name or not a service)")
+ return base_profile
+
+ # Find the group
+ normalized_groups = get_normalized_replica_groups(run_spec.configuration)
+    logger.debug(
+ "Normalized groups: %s",
+ [f"{g.name} (regions={g.regions})" for g in normalized_groups],
+ )
+ group = next((g for g in normalized_groups if g.name == group_name), None)
+
+ if not group:
+ logger.warning(
+ "Replica group '%s' not found in run_spec. Available groups: %s",
+ group_name,
+ [g.name for g in normalized_groups],
+ )
+ return base_profile
+
+ # Merge: group overrides base
+ merged = Profile.parse_obj(base_profile.dict())
+ for field_name in ProfileParams.__fields__:
+ group_value = getattr(group, field_name, None)
+ if group_value is not None:
+ setattr(merged, field_name, group_value)
+
+    logger.debug(
+ "Profile for group '%s': regions=%s, backends=%s, spot_policy=%s (base had: regions=%s, backends=%s, spot_policy=%s)",
+ group_name,
+ merged.regions,
+ [b.value for b in merged.backends] if merged.backends else None,
+ merged.spot_policy,
+ base_profile.regions,
+ [b.value for b in base_profile.backends] if base_profile.backends else None,
+ base_profile.spot_policy,
+ )
+
+ return merged
+
+
async def _run_job_on_new_instance(
project: ProjectModel,
job_model: JobModel,
@@ -741,8 +801,31 @@ async def _run_job_on_new_instance(
) -> Optional[tuple[JobProvisioningData, InstanceOfferWithAvailability, Profile, Requirements]]:
if volumes is None:
volumes = []
- profile = run.run_spec.merged_profile
- requirements = job.job_spec.requirements
+ profile = _get_profile_for_job(run.run_spec, job)
+ requirements = job.job_spec.requirements # Already has group resources baked in
+
+ # Debug logging for replica groups
+ replica_group_name = job.job_spec.replica_group_name
+ if replica_group_name:
+ logger.debug(
+ "%s: Provisioning replica group '%s' with profile: regions=%s, backends=%s, spot_policy=%s",
+ fmt(job_model),
+ replica_group_name,
+ profile.regions,
+ [b.value for b in profile.backends] if profile.backends else None,
+ profile.spot_policy,
+ )
+ gpu_req = (
+ requirements.resources.gpu.name
+ if requirements.resources and requirements.resources.gpu
+ else None
+ )
+ logger.debug(
+ "%s: GPU requirements for group '%s': %s",
+ fmt(job_model),
+ replica_group_name,
+ gpu_req,
+ )
fleet = None
if fleet_model is not None:
fleet = fleet_model_to_fleet(fleet_model)
@@ -761,6 +844,21 @@ async def _run_job_on_new_instance(
multinode = job.job_spec.jobs_per_replica > 1 or (
fleet is not None and fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER
)
+
+ # Log the requirements and profile being used
+ gpu_requirement = (
+ requirements.resources.gpu.name
+ if requirements.resources and requirements.resources.gpu
+ else None
+ )
+    logger.debug(
+ "%s: Fetching offers with GPU=%s, regions=%s, backends=%s",
+ fmt(job_model),
+ gpu_requirement,
+ profile.regions,
+ [b.value for b in profile.backends] if profile.backends else None,
+ )
+
offers = await get_offers_by_requirements(
project=project,
profile=profile,
@@ -772,6 +870,14 @@ async def _run_job_on_new_instance(
privileged=job.job_spec.privileged,
instance_mounts=check_run_spec_requires_instance_mounts(run.run_spec),
)
+
+ # Debug logging for offers
+    logger.debug(
+ "%s: Got %d offers. First 3: %s",
+ fmt(job_model),
+ len(offers),
+ [f"{o.instance.name} ({o.backend.value}/{o.region})" for _, o in offers[:3]],
+ )
# Limit number of offers tried to prevent long-running processing
# in case all offers fail.
for backend, offer in offers[: settings.MAX_OFFERS_TRIED]:
@@ -822,7 +928,9 @@ def _get_run_profile_and_requirements_in_fleet(
run_spec: RunSpec,
fleet: Fleet,
) -> tuple[Profile, Requirements]:
- profile = combine_fleet_and_run_profiles(fleet.spec.merged_profile, run_spec.merged_profile)
+ # Use group-merged profile instead of run-level
+ job_profile = _get_profile_for_job(run_spec, job)
+ profile = combine_fleet_and_run_profiles(fleet.spec.merged_profile, job_profile)
if profile is None:
raise ValueError("Cannot combine fleet profile")
fleet_requirements = get_fleet_requirements(fleet.spec)
diff --git a/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py b/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py
new file mode 100644
index 000000000..a1d9e3eaf
--- /dev/null
+++ b/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py
@@ -0,0 +1,26 @@
+"""Add JobModel.replica_group_name
+
+Revision ID: a1b2c3d4e5f6
+Revises: ff1d94f65b08
+Create Date: 2025-10-17 00:00:00.000000
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "a1b2c3d4e5f6"
+down_revision = "ff1d94f65b08"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+ with op.batch_alter_table("jobs", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("replica_group_name", sa.String(), nullable=True))
+
+
+def downgrade() -> None:
+ with op.batch_alter_table("jobs", schema=None) as batch_op:
+ batch_op.drop_column("replica_group_name")
diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py
index 31f44d369..d23a99e28 100644
--- a/src/dstack/_internal/server/models.py
+++ b/src/dstack/_internal/server/models.py
@@ -437,6 +437,7 @@ class JobModel(BaseModel):
instance: Mapped[Optional["InstanceModel"]] = relationship(back_populates="jobs")
used_instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUIDType(binary=False))
replica_num: Mapped[int] = mapped_column(Integer)
+ replica_group_name: Mapped[Optional[str]] = mapped_column(String, nullable=True)
deployment_num: Mapped[int] = mapped_column(Integer)
job_runtime_data: Mapped[Optional[str]] = mapped_column(Text)
probes: Mapped[list["ProbeModel"]] = relationship(
diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py
index cbb089b2c..1f9f66580 100644
--- a/src/dstack/_internal/server/services/jobs/__init__.py
+++ b/src/dstack/_internal/server/services/jobs/__init__.py
@@ -1,9 +1,12 @@
import itertools
import json
from datetime import timedelta
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from uuid import UUID
+if TYPE_CHECKING:
+ from dstack._internal.core.models.configurations import ReplicaGroup
+
import requests
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
@@ -67,7 +70,10 @@
async def get_jobs_from_run_spec(
- run_spec: RunSpec, secrets: Dict[str, str], replica_num: int
+ run_spec: RunSpec,
+ secrets: Dict[str, str],
+ replica_num: int,
+ replica_group: Optional["ReplicaGroup"] = None,
) -> List[Job]:
return [
Job(job_spec=s, job_submissions=[])
@@ -75,14 +81,19 @@ async def get_jobs_from_run_spec(
run_spec=run_spec,
secrets=secrets,
replica_num=replica_num,
+ replica_group=replica_group,
)
]
async def get_job_specs_from_run_spec(
- run_spec: RunSpec, secrets: Dict[str, str], replica_num: int
+ run_spec: RunSpec,
+ secrets: Dict[str, str],
+ replica_num: int,
+ replica_group: Optional["ReplicaGroup"] = None,
) -> List[JobSpec]:
job_configurator = _get_job_configurator(run_spec=run_spec, secrets=secrets)
+ job_configurator.replica_group = replica_group
job_specs = await job_configurator.get_job_specs(replica_num=replica_num)
return job_specs
diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py
index 02cdc70b3..b79c2a15e 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/base.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/base.py
@@ -3,7 +3,11 @@
import threading
from abc import ABC, abstractmethod
from pathlib import PurePosixPath
-from typing import Dict, List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional
+
+if TYPE_CHECKING:
+ from dstack._internal.core.models.configurations import ReplicaGroup
+ from dstack._internal.core.models.profiles import Profile
from cachetools import TTLCache, cached
@@ -84,6 +88,7 @@ class JobConfigurator(ABC):
_image_config: Optional[ImageConfig] = None
# JobSSHKey should be shared for all jobs in a replica for inter-node communication.
_job_ssh_key: Optional[JobSSHKey] = None
+ replica_group: Optional["ReplicaGroup"] = None
def __init__(
self,
@@ -146,6 +151,7 @@ async def _get_job_spec(
) -> JobSpec:
job_spec = JobSpec(
replica_num=replica_num, # TODO(egor-s): add to env variables in the runner
+ replica_group_name=self.replica_group.name if self.replica_group else None,
job_num=job_num,
job_name=f"{self.run_spec.run_name}-{job_num}-{replica_num}",
jobs_per_replica=jobs_per_replica,
@@ -295,13 +301,40 @@ def _utilization_policy(self) -> Optional[UtilizationPolicy]:
def _registry_auth(self) -> Optional[RegistryAuth]:
return self.run_spec.configuration.registry_auth
+ def _get_merged_profile(self) -> "Profile":
+ """Get profile with group overrides applied."""
+ from dstack._internal.core.models.profiles import Profile, ProfileParams
+
+ base = self.run_spec.merged_profile
+
+ if not self.replica_group:
+ return base
+
+ # Clone and apply group overrides
+ merged = Profile.parse_obj(base.dict())
+ for field_name in ProfileParams.__fields__:
+ group_value = getattr(self.replica_group, field_name, None)
+ if group_value is not None:
+ setattr(merged, field_name, group_value)
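+        # e.g. base regions=["us-east-1"] with group regions=["us-west-2"] results
+        # in merged.regions == ["us-west-2"]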
+
+ return merged
+
def _requirements(self) -> Requirements:
- spot_policy = self._spot_policy()
+ # Use group resources if available, else fall back to config
+ if self.replica_group:
+ resources = self.replica_group.resources
+ else:
+ resources = self.run_spec.configuration.resources
+
+        # Get merged profile for spot/price/reservation; fall back to the
+        # configurator's default spot policy when neither the group nor the
+        # profile sets one
+        profile = self._get_merged_profile()
+        spot_policy = profile.spot_policy or self._spot_policy()
+
return Requirements(
- resources=self.run_spec.configuration.resources,
- max_price=self.run_spec.merged_profile.max_price,
+ resources=resources,
+ max_price=profile.max_price,
spot=None if spot_policy == SpotPolicy.AUTO else (spot_policy == SpotPolicy.SPOT),
- reservation=self.run_spec.merged_profile.reservation,
+ reservation=profile.reservation,
)
def _retry(self) -> Optional[Retry]:
diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py
index 25ac750aa..ffc7cbc14 100644
--- a/src/dstack/_internal/server/services/runs.py
+++ b/src/dstack/_internal/server/services/runs.py
@@ -30,6 +30,8 @@
)
from dstack._internal.core.models.profiles import (
CreationPolicy,
+ Profile,
+ ProfileParams,
RetryEvent,
)
from dstack._internal.core.models.repos.virtual import DEFAULT_VIRTUAL_REPO_ID, VirtualRunRepoData
@@ -302,6 +304,45 @@ async def get_run_by_id(
return run_model_to_run(run_model, return_in_api=True)
+def _get_job_profile(run_spec: RunSpec, replica_group_name: Optional[str]) -> Profile:
+ """
+ Get the profile for a job, including replica group overrides if applicable.
+
+ Args:
+ run_spec: The run specification
+ replica_group_name: Name of the replica group, or None for legacy jobs
+
+ Returns:
+ Profile with replica group overrides applied
+ """
+ base_profile = run_spec.merged_profile
+
+ # If no replica group, return base profile
+ if not replica_group_name:
+ return base_profile
+
+ # Find the replica group
+ if run_spec.configuration.type != "service":
+ return base_profile
+
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+ normalized_groups = get_normalized_replica_groups(run_spec.configuration)
+ replica_group = next((g for g in normalized_groups if g.name == replica_group_name), None)
+
+ if not replica_group:
+ return base_profile
+
+ # Clone base profile and apply group overrides
+ merged = Profile.parse_obj(base_profile.dict())
+ for field_name in ProfileParams.__fields__:
+ group_value = getattr(replica_group, field_name, None)
+ if group_value is not None:
+ setattr(merged, field_name, group_value)
+
+ return merged
+
+
async def get_plan(
session: AsyncSession,
project: ProjectModel,
@@ -340,11 +381,34 @@ async def get_plan(
action = ApplyAction.UPDATE
secrets = await get_project_secrets_mapping(session=session, project=project)
- jobs = await get_jobs_from_run_spec(
- run_spec=effective_run_spec,
- secrets=secrets,
- replica_num=0,
- )
+
+ # For services with replica groups, create jobs for all groups during planning
+ jobs = []
+ if (
+ effective_run_spec.configuration.type == "service"
+ and effective_run_spec.configuration.replica_groups
+ ):
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+ normalized_groups = get_normalized_replica_groups(effective_run_spec.configuration)
+ replica_num = 0
+ for group in normalized_groups:
+                # Create one representative job per group for planning
+ group_jobs = await get_jobs_from_run_spec(
+ run_spec=effective_run_spec,
+ secrets=secrets,
+ replica_num=replica_num,
+ replica_group=group,
+ )
+ jobs.extend(group_jobs)
+ replica_num += 1
+ else:
+ # Legacy: single job for planning
+ jobs = await get_jobs_from_run_spec(
+ run_spec=effective_run_spec,
+ secrets=secrets,
+ replica_num=0,
+ )
volumes = await get_job_configured_volumes(
session=session,
@@ -353,19 +417,47 @@ async def get_plan(
job_num=0,
)
- pool_offers = await _get_pool_offers(
- session=session,
- project=project,
- run_spec=effective_run_spec,
- job=jobs[0],
- volumes=volumes,
- )
+ # For replica groups, we need pool offers for all GPU types, not just the first job's type
+ # So we fetch pool offers for each job separately and aggregate them
+ all_pool_offers = []
+ if len(jobs) > 1:
+ # Multiple jobs (likely replica groups) - get pool offers per job to include all GPU types
+ for job in jobs:
+ job_pool_offers = await _get_pool_offers(
+ session=session,
+ project=project,
+ run_spec=effective_run_spec,
+ job=job,
+ volumes=volumes,
+ )
+ all_pool_offers.extend(job_pool_offers)
+ # Deduplicate by (backend, instance_name, region) tuple
+ seen_offers = set()
+ pool_offers = []
+ for offer in all_pool_offers:
+ offer_key = (offer.backend, offer.instance.name, offer.region)
+ if offer_key not in seen_offers:
+ seen_offers.add(offer_key)
+ pool_offers.append(offer)
+ else:
+ pool_offers = await _get_pool_offers(
+ session=session,
+ project=project,
+ run_spec=effective_run_spec,
+ job=jobs[0],
+ volumes=volumes,
+ )
effective_run_spec.run_name = "dry-run" # will regenerate jobs on submission
- # Get offers once for all jobs
- offers = []
- if creation_policy == CreationPolicy.REUSE_OR_CREATE:
- offers = await get_offers_by_requirements(
+    # Offers can be fetched once and shared only when no replica groups are in play
+    # (groups may override regions/backends even when resources match) and all jobs
+    # have identical requirements
+    all_requirements_identical = all(
+        job.job_spec.replica_group_name is None
+        and job.job_spec.requirements == jobs[0].job_spec.requirements
+        for job in jobs
+    )
+
+ # Get offers once if all jobs are identical, otherwise get per-job
+ shared_offers = []
+ if creation_policy == CreationPolicy.REUSE_OR_CREATE and all_requirements_identical:
+ shared_offers = await get_offers_by_requirements(
project=project,
profile=profile,
requirements=jobs[0].job_spec.requirements,
@@ -379,8 +471,46 @@ async def get_plan(
job_plans = []
for job in jobs:
job_offers: List[InstanceOfferWithAvailability] = []
- job_offers.extend(pool_offers)
- job_offers.extend(offer for _, offer in offers)
+
+ # Filter pool offers to match this job's GPU requirements
+ gpu_req = None
+ if job.job_spec.requirements.resources and job.job_spec.requirements.resources.gpu:
+ gpu_req = job.job_spec.requirements.resources.gpu.name
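+            # gpu.name is a list of acceptable GPU names, e.g. ["H100"]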
+
+ matching_pool_offers = []
+ for pool_offer in pool_offers:
+ offer_gpus = pool_offer.instance.resources.gpus
+ if offer_gpus and gpu_req:
+ # Check if offer's GPU matches job's requirement
+ offer_gpu_names = [gpu.name for gpu in offer_gpus]
+ if any(req_gpu in offer_gpu_names for req_gpu in gpu_req):
+ matching_pool_offers.append(pool_offer)
+ elif not gpu_req:
+ # No GPU requirement, include all pool offers
+ matching_pool_offers.append(pool_offer)
+
+ job_offers.extend(matching_pool_offers)
+
+ # Get the correct profile for this job (with replica group overrides if applicable)
+ job_profile = _get_job_profile(effective_run_spec, job.job_spec.replica_group_name)
+
+ # Use shared offers if all jobs are identical, otherwise fetch per-job
+ if shared_offers:
+ job_offers.extend(offer for _, offer in shared_offers)
+ elif creation_policy == CreationPolicy.REUSE_OR_CREATE:
+ # Fetch offers specific to this job's requirements with job-specific profile
+ job_specific_offers = await get_offers_by_requirements(
+ project=project,
+ profile=job_profile,
+ requirements=job.job_spec.requirements,
+ exclude_not_available=False,
+ multinode=job.job_spec.jobs_per_replica > 1,
+ volumes=volumes,
+ privileged=job.job_spec.privileged,
+ instance_mounts=check_run_spec_requires_instance_mounts(effective_run_spec),
+ )
+ job_offers.extend(offer for _, offer in job_specific_offers)
+
job_offers.sort(key=lambda offer: not offer.availability.is_available())
job_spec = job.job_spec
@@ -557,19 +687,47 @@ async def submit_run(
if run_spec.configuration.type == "service":
await services.register_service(session, run_model, run_spec)
- for replica_num in range(initial_replicas):
- jobs = await get_jobs_from_run_spec(
- run_spec=run_spec,
- secrets=secrets,
- replica_num=replica_num,
- )
- for job in jobs:
- job_model = create_job_model_for_new_submission(
- run_model=run_model,
- job=job,
- status=JobStatus.SUBMITTED,
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+ normalized_groups = get_normalized_replica_groups(run_spec.configuration)
+
+ # Set initial desired count (sum of all group minimums)
+ run_model.desired_replica_count = sum(g.replicas.min or 0 for g in normalized_groups)
+
+ # Create jobs by iterating over groups
+ replica_num = 0 # Global counter across all groups
+ for group in normalized_groups:
+ group_min = group.replicas.min or 0
+ for _ in range(group_min):
+ jobs = await get_jobs_from_run_spec(
+ run_spec=run_spec,
+ secrets=secrets,
+ replica_num=replica_num,
+ replica_group=group, # Pass group context
+ )
+ for job in jobs:
+ job_model = create_job_model_for_new_submission(
+ run_model=run_model,
+ job=job,
+ status=JobStatus.SUBMITTED,
+ )
+ session.add(job_model)
+ replica_num += 1
+ else:
+ # Non-service runs (tasks, dev environments)
+ for replica_num in range(initial_replicas):
+ jobs = await get_jobs_from_run_spec(
+ run_spec=run_spec,
+ secrets=secrets,
+ replica_num=replica_num,
)
- session.add(job_model)
+ for job in jobs:
+ job_model = create_job_model_for_new_submission(
+ run_model=run_model,
+ job=job,
+ status=JobStatus.SUBMITTED,
+ )
+ session.add(job_model)
await session.commit()
await session.refresh(run_model)
@@ -591,6 +749,7 @@ def create_job_model_for_new_submission(
job_num=job.job_spec.job_num,
job_name=f"{job.job_spec.job_name}",
replica_num=job.job_spec.replica_num,
+ replica_group_name=job.job_spec.replica_group_name,
deployment_num=run_model.deployment_num,
submission_num=len(job.job_submissions),
submitted_at=now,
@@ -1017,10 +1176,20 @@ def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec):
f"Maximum utilization_policy.time_window is {settings.SERVER_METRICS_RUNNING_TTL_SECONDS}s"
)
if isinstance(run_spec.configuration, ServiceConfiguration):
- if run_spec.merged_profile.schedule and run_spec.configuration.replicas.min == 0:
- raise ServerClientError(
- "Scheduled services with autoscaling to zero are not supported"
- )
+ # Check all groups for min=0 with schedule
+ if run_spec.merged_profile.schedule:
+ if run_spec.configuration.replica_groups:
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+ groups = get_normalized_replica_groups(run_spec.configuration)
+ if any(g.replicas.min == 0 for g in groups):
+ raise ServerClientError(
+ "Scheduled services with autoscaling to zero are not supported"
+ )
+ elif run_spec.configuration.replicas.min == 0:
+ raise ServerClientError(
+ "Scheduled services with autoscaling to zero are not supported"
+ )
if len(run_spec.configuration.probes) > settings.MAX_PROBES_PER_JOB:
raise ServerClientError(
f"Cannot configure more than {settings.MAX_PROBES_PER_JOB} probes"
@@ -1058,6 +1227,7 @@ def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec):
"service": [
# in-place
"replicas",
+ "replica_groups", # Named replica groups (mutually exclusive with replicas)
"scaling",
# rolling deployment
# NOTE: keep this list in sync with the "Rolling deployment" section in services.md
@@ -1186,7 +1356,21 @@ async def process_terminating_run(session: AsyncSession, run_model: RunModel):
)
-async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replicas_diff: int):
+async def scale_run_replicas(
+ session: AsyncSession,
+ run_model: RunModel,
+ replicas_diff: int,
+ allow_exceeding_max: bool = False,
+):
+ """
+ Scale run replicas up or down.
+
+ Args:
+ session: Database session
+ run_model: The run to scale
+ replicas_diff: Number of replicas to add (positive) or remove (negative)
+ allow_exceeding_max: If True, allow scaling beyond configured max (for rolling deployments)
+ """
if replicas_diff == 0:
# nothing to do
return
@@ -1226,38 +1410,144 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
active_replicas.sort(key=lambda r: (r[1], -r[0], r[2]))
run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+ normalized_groups = (
+ get_normalized_replica_groups(run_spec.configuration)
+ if run_spec.configuration.type == "service"
+ else []
+ )
+
if replicas_diff < 0:
- for _, _, _, replica_jobs in reversed(active_replicas[-abs(replicas_diff) :]):
- # scale down the less important replicas first
+ # SCALE DOWN: Only terminate from autoscalable groups while respecting group minimums
+ autoscalable_groups = {
+ g.name for g in normalized_groups if g.replicas.min != g.replicas.max
+ }
+
+ # Count replicas per group
+ group_counts = {}
+ for _, _, _, replica_jobs in active_replicas:
+ if replica_jobs:
+ group_name = replica_jobs[0].replica_group_name or "default"
+ group_counts[group_name] = group_counts.get(group_name, 0) + 1
+
+ # Get group minimums
+        group_mins = {g.name: g.replicas.min or 0 for g in normalized_groups}
+
+ # Terminate from end (reversed)
+ # For rolling deployments (allow_exceeding_max), prioritize terminating out-of-date replicas
+ terminated_count = 0
+ for _, is_out_of_date, _, replica_jobs in reversed(active_replicas):
+ if terminated_count >= abs(replicas_diff):
+ break
+
+ if not replica_jobs:
+ continue
+
+ group_name = replica_jobs[0].replica_group_name or "default"
+
+ # For rolling deployment, allow terminating any out-of-date replica
+ if allow_exceeding_max and is_out_of_date:
+ # Terminate this replica (out-of-date during rolling deployment)
+ for job in replica_jobs:
+ if not job.status.is_finished() and job.status != JobStatus.TERMINATING:
+ job.status = JobStatus.TERMINATING
+ job.termination_reason = JobTerminationReason.SCALED_DOWN
+
+ group_counts[group_name] -= 1
+ terminated_count += 1
+ continue
+
+ # For normal scaling, skip if not autoscalable
+ if normalized_groups and group_name not in autoscalable_groups:
+ continue
+
+ # Skip if at minimum
+ current_count = group_counts.get(group_name, 0)
+ min_count = group_mins.get(group_name, 0)
+ if current_count <= min_count:
+ continue
+
+ # Terminate this replica
for job in replica_jobs:
- if job.status.is_finished() or job.status == JobStatus.TERMINATING:
- continue
- job.status = JobStatus.TERMINATING
- job.termination_reason = JobTerminationReason.SCALED_DOWN
- # background task will process the job later
+ if not job.status.is_finished() and job.status != JobStatus.TERMINATING:
+ job.status = JobStatus.TERMINATING
+ job.termination_reason = JobTerminationReason.SCALED_DOWN
+
+ group_counts[group_name] -= 1
+ terminated_count += 1
else:
+ # SCALE UP
+ # Count current replicas per group
+ group_counts = {}
+ for _, _, _, replica_jobs in active_replicas:
+ if replica_jobs:
+ group_name = replica_jobs[0].replica_group_name or "default"
+ group_counts[group_name] = group_counts.get(group_name, 0) + 1
+
+ # First, identify groups below minimum (need to scale regardless of autoscalability)
+ below_min_groups = [
+ g for g in normalized_groups if group_counts.get(g.name, 0) < (g.replicas.min or 0)
+ ]
+
+ # Then, identify autoscalable groups that can scale beyond minimum
+ autoscalable_groups = [
+ g
+ for g in normalized_groups
+ if g.replicas.min != g.replicas.max
+ and (
+ allow_exceeding_max
+ or group_counts.get(g.name, 0) < (g.replicas.max or float("inf"))
+ )
+ ]
+
+ # Eligible groups are: below-min groups + autoscalable groups
+ eligible_groups = []
+ if below_min_groups:
+ eligible_groups.extend(below_min_groups)
+ elif autoscalable_groups:
+ # Only use autoscalable groups if no groups are below minimum
+ eligible_groups.extend(autoscalable_groups)
+ elif allow_exceeding_max and normalized_groups:
+ # For rolling deployments, allow exceeding max even for fixed groups
+ eligible_groups.extend(normalized_groups)
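+        # e.g. primary(1..1, count=0) and overflow(0..5, count=2): primary is below
+        # its minimum, so only primary is eligible until it reaches count=1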
+
+ if normalized_groups and not eligible_groups:
+ # All groups at their limits
+ logger.info("%s: all replica groups at their limits (min/max)", fmt(run_model))
+ return
+
scheduled_replicas = 0
- # rerun inactive replicas
+ # Reuse inactive replicas first (existing logic)
for _, _, _, replica_jobs in inactive_replicas:
if scheduled_replicas == replicas_diff:
break
- await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False)
- scheduled_replicas += 1
+            # Only reuse a replica if it comes from an eligible group
+            if replica_jobs:
+                group_name = replica_jobs[0].replica_group_name or "default"
+                if not normalized_groups or group_name in {g.name for g in eligible_groups}:
+                    await retry_run_replica_jobs(
+                        session, run_model, replica_jobs, only_failed=False
+                    )
+                    # Count the reused replica towards its group so the max check
+                    # in the creation loop below stays accurate
+                    group_counts[group_name] = group_counts.get(group_name, 0) + 1
+                    scheduled_replicas += 1
- secrets = await get_project_secrets_mapping(
- session=session,
- project=run_model.project,
- )
+ # Create new replicas for remaining diff
+ secrets = await get_project_secrets_mapping(session=session, project=run_model.project)
+
+ for _ in range(replicas_diff - scheduled_replicas):
+ # Pick group for new replica
+ # v1: Simple heuristic - pick first eligible group (round-robin in future)
+ selected_group = eligible_groups[0] if eligible_groups else None
+
+ replica_num = len(active_replicas) + scheduled_replicas
- for replica_num in range(
- len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff
- ):
# FIXME: Handle getting image configuration errors or skip it.
jobs = await get_jobs_from_run_spec(
run_spec=run_spec,
secrets=secrets,
replica_num=replica_num,
+ replica_group=selected_group,
)
for job in jobs:
job_model = create_job_model_for_new_submission(
@@ -1267,6 +1557,24 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
)
session.add(job_model)
+            # Update the per-group replica count
+            if selected_group:
+                group_counts[selected_group.name] = group_counts.get(selected_group.name, 0) + 1
+                scheduled_replicas += 1
+
+                # Remove the group from the eligible list once it reaches its maximum
+                if group_counts[selected_group.name] >= (
+                    selected_group.replicas.max or float("inf")
+                ):
+                    eligible_groups = [
+                        g for g in eligible_groups if g.name != selected_group.name
+                    ]
+                    if not eligible_groups:
+                        logger.info(
+                            "%s: all eligible groups reached maximum capacity", fmt(run_model)
+                        )
+                        break
+ else:
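+                # Legacy path: no replica groups configured, so no per-group bookkeeping needed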
+ scheduled_replicas += 1
+
async def retry_run_replica_jobs(
session: AsyncSession, run_model: RunModel, latest_jobs: List[JobModel], *, only_failed: bool
@@ -1276,10 +1584,29 @@ async def retry_run_replica_jobs(
session=session,
project=run_model.project,
)
+
+ # Determine which replica group this job belongs to
+ run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
+ replica_group = None
+ if run_spec.configuration.type == "service" and latest_jobs:
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+ group_name = latest_jobs[0].replica_group_name
+ if group_name:
+ normalized_groups = get_normalized_replica_groups(run_spec.configuration)
+ replica_group = next((g for g in normalized_groups if g.name == group_name), None)
+ if replica_group:
+ logger.info(
+ "%s: retrying job from replica group '%s'",
+ fmt(run_model),
+ replica_group.name,
+ )
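+    # If the group no longer exists in the configuration, replica_group stays None
+    # and the jobs are rebuilt without group-specific overrides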
+
new_jobs = await get_jobs_from_run_spec(
- run_spec=RunSpec.__response__.parse_raw(run_model.run_spec),
+ run_spec=run_spec,
secrets=secrets,
replica_num=latest_jobs[0].replica_num,
+ replica_group=replica_group,
)
assert len(new_jobs) == len(latest_jobs), (
"Changing the number of jobs within a replica is not yet supported"
diff --git a/src/dstack/_internal/server/services/services/autoscalers.py b/src/dstack/_internal/server/services/services/autoscalers.py
index cd6d06e58..6573d1f7f 100644
--- a/src/dstack/_internal/server/services/services/autoscalers.py
+++ b/src/dstack/_internal/server/services/services/autoscalers.py
@@ -120,18 +120,29 @@ def get_desired_count(
def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler:
- assert conf.replicas.min is not None
- assert conf.replicas.max is not None
+ # Compute bounds from groups if present
+ if conf.replica_groups:
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+ groups = get_normalized_replica_groups(conf)
+ min_replicas = sum(g.replicas.min or 0 for g in groups)
+ max_replicas = sum(g.replicas.max or 0 for g in groups)
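+        # Example: groups with ranges 1..1 and 0..5 yield min_replicas=1, max_replicas=6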
+ else:
+ assert conf.replicas.min is not None
+ assert conf.replicas.max is not None
+ min_replicas = conf.replicas.min
+ max_replicas = conf.replicas.max
+
if conf.scaling is None:
return ManualScaler(
- min_replicas=conf.replicas.min,
- max_replicas=conf.replicas.max,
+ min_replicas=min_replicas,
+ max_replicas=max_replicas,
)
if conf.scaling.metric == "rps":
return RPSAutoscaler(
# replicas count validated by configuration model
- min_replicas=conf.replicas.min,
- max_replicas=conf.replicas.max,
+ min_replicas=min_replicas,
+ max_replicas=max_replicas,
target=conf.scaling.target,
scale_up_delay=conf.scaling.scale_up_delay,
scale_down_delay=conf.scaling.scale_down_delay,
diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py
index 0a8adaa42..6b1d42a10 100644
--- a/src/dstack/_internal/server/testing/common.py
+++ b/src/dstack/_internal/server/testing/common.py
@@ -345,6 +345,7 @@ async def create_job(
instance: Optional[InstanceModel] = None,
job_num: int = 0,
replica_num: int = 0,
+ replica_group_name: Optional[str] = None,
deployment_num: Optional[int] = None,
instance_assigned: bool = False,
disconnected_at: Optional[datetime] = None,
@@ -353,8 +354,19 @@ async def create_job(
if deployment_num is None:
deployment_num = run.deployment_num
run_spec = RunSpec.parse_raw(run.run_spec)
+
+ # Look up replica group if specified
+ replica_group = None
+ if replica_group_name and run_spec.configuration.type == "service":
+ from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+ normalized_groups = get_normalized_replica_groups(run_spec.configuration)
+ replica_group = next((g for g in normalized_groups if g.name == replica_group_name), None)
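+        # A name that does not match any configured group leaves replica_group as None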
+
job_spec = (
- await get_job_specs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=replica_num)
+ await get_job_specs_from_run_spec(
+ run_spec=run_spec, secrets={}, replica_num=replica_num, replica_group=replica_group
+ )
)[0]
job_spec.job_num = job_num
job = JobModel(
@@ -365,6 +377,7 @@ async def create_job(
job_num=job_num,
job_name=run.run_name + f"-{job_num}-{replica_num}",
replica_num=replica_num,
+ replica_group_name=replica_group_name,
deployment_num=deployment_num,
submission_num=submission_num,
submitted_at=submitted_at,
diff --git a/src/tests/_internal/cli/utils/test_run_plan_display.py b/src/tests/_internal/cli/utils/test_run_plan_display.py
new file mode 100644
index 000000000..8ec6bb4fb
--- /dev/null
+++ b/src/tests/_internal/cli/utils/test_run_plan_display.py
@@ -0,0 +1,756 @@
+"""Test CLI display of run plans with replica groups."""
+
+from dstack._internal.cli.utils.run import print_run_plan
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.configurations import ServiceConfiguration
+from dstack._internal.core.models.instances import (
+ Gpu,
+ InstanceAvailability,
+ InstanceType,
+ Resources,
+)
+from dstack._internal.core.models.profiles import Profile
+from dstack._internal.core.models.repos import LocalRunRepoData
+from dstack._internal.core.models.resources import Range, ResourcesSpec
+from dstack._internal.core.models.runs import (
+ ApplyAction,
+ InstanceOfferWithAvailability,
+ JobPlan,
+ JobSpec,
+ Requirements,
+ RunPlan,
+ RunSpec,
+)
+
+
+def create_test_offer(
+ backend: BackendType,
+ gpu_name: str,
+ price: float,
+ region: str = "us-east",
+ availability: InstanceAvailability = InstanceAvailability.AVAILABLE,
+) -> InstanceOfferWithAvailability:
+ """Helper to create test offers."""
+ return InstanceOfferWithAvailability(
+ backend=backend,
+ instance=InstanceType(
+ name=f"{gpu_name.lower()}-instance",
+ resources=Resources(
+ cpus=8,
+ memory_mib=16384,
+ gpus=[Gpu(name=gpu_name, memory_mib=40960)],
+ spot=False,
+ ),
+ ),
+ region=region,
+ price=price,
+ availability=availability,
+ )
+
+
+class TestReplicaGroupsDisplayInCLI:
+ """Test that replica groups are properly displayed in CLI output."""
+
+ def test_multiple_replica_groups_show_group_names(self, capsys):
+ """CLI should prefix offers with group names when multiple job plans exist."""
+ # Create a service with 2 replica groups
+ config = ServiceConfiguration(
+ type="service",
+ port=8000,
+ commands=["echo test"],
+ replica_groups=[
+ {
+ "name": "l40s-group",
+ "replicas": "1",
+ "resources": {"gpu": {"name": "L40S", "count": 1}},
+ },
+ {
+ "name": "a100-group",
+ "replicas": "1",
+ "resources": {"gpu": {"name": "A100", "count": 1}},
+ },
+ ],
+ )
+
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data=LocalRunRepoData(repo_dir="/tmp"),
+ configuration=config,
+ configuration_path=".dstack.yml",
+ profile=Profile(backends=[BackendType.VASTAI]),
+ )
+
+ # Create job plans for each group
+ l40s_offer = create_test_offer(BackendType.VASTAI, "L40S", 0.50)
+ a100_offer = create_test_offer(BackendType.VASTAI, "A100", 1.20)
+
+ job_plan_l40s = JobPlan(
+ job_spec=JobSpec(
+ replica_num=0,
+ replica_group_name="l40s-group",
+ job_num=0,
+ job_name="test-job-0",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "L40S", "count": 1})
+ ),
+ ),
+ offers=[l40s_offer],
+ total_offers=1,
+ max_price=0.50,
+ )
+
+ job_plan_a100 = JobPlan(
+ job_spec=JobSpec(
+ replica_num=1,
+ replica_group_name="a100-group",
+ job_num=1,
+ job_name="test-job-1",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "A100", "count": 1})
+ ),
+ ),
+ offers=[a100_offer],
+ total_offers=1,
+ max_price=1.20,
+ )
+
+ run_plan = RunPlan(
+ project_name="test-project",
+ user="test-user",
+ run_spec=run_spec,
+ effective_run_spec=run_spec,
+ job_plans=[job_plan_l40s, job_plan_a100],
+ current_resource=None,
+ action=ApplyAction.CREATE,
+ )
+
+ # Print the plan
+ print_run_plan(run_plan, max_offers=10, include_run_properties=True)
+
+ # Capture output
+ captured = capsys.readouterr()
+ output = captured.out
+
+ # Verify group names are in the output
+ assert "l40s-group" in output, "l40s-group name should appear in output"
+ assert "a100-group" in output, "a100-group name should appear in output"
+
+ # Verify both GPU types are shown
+ assert "L40S" in output or "l40s" in output.lower()
+ assert "A100" in output or "a100" in output.lower()
+
+ def test_single_job_plan_no_group_prefix(self, capsys):
+ """CLI should NOT prefix offers when only one job plan exists (legacy)."""
+ config = ServiceConfiguration(
+ type="service",
+ port=8000,
+ commands=["echo test"],
+ replicas=Range[int](min=1, max=1),
+ resources=ResourcesSpec(gpu={"name": "V100", "count": 1}),
+ )
+
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data=LocalRunRepoData(repo_dir="/tmp"),
+ configuration=config,
+ configuration_path=".dstack.yml",
+ profile=Profile(backends=[BackendType.AWS]),
+ )
+
+ v100_offer = create_test_offer(BackendType.AWS, "V100", 0.80)
+
+ job_plan = JobPlan(
+ job_spec=JobSpec(
+ replica_num=0,
+ replica_group_name="default",
+ job_num=0,
+ job_name="test-job-0",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "V100", "count": 1})
+ ),
+ ),
+ offers=[v100_offer],
+ total_offers=1,
+ max_price=0.80,
+ )
+
+ run_plan = RunPlan(
+ project_name="test-project",
+ user="test-user",
+ run_spec=run_spec,
+ effective_run_spec=run_spec,
+ job_plans=[job_plan],
+ current_resource=None,
+ action=ApplyAction.CREATE,
+ )
+
+ # Print the plan
+ print_run_plan(run_plan, max_offers=10, include_run_properties=True)
+
+ # Capture output
+ captured = capsys.readouterr()
+ output = captured.out
+
+ # Verify NO group prefix (legacy mode)
+ assert "default:" not in output, "Legacy mode should not show group prefix"
+ # But should show backend normally
+ assert "aws" in output.lower()
+
+ def test_replica_groups_offers_sorted_by_price(self, capsys):
+ """Offers from multiple groups should be sorted by price across all groups."""
+ config = ServiceConfiguration(
+ type="service",
+ port=8000,
+ commands=["echo test"],
+ replica_groups=[
+ {
+ "name": "expensive-group",
+ "replicas": "1",
+ "resources": {"gpu": {"name": "H100", "count": 1}},
+ },
+ {
+ "name": "cheap-group",
+ "replicas": "1",
+ "resources": {"gpu": {"name": "T4", "count": 1}},
+ },
+ ],
+ )
+
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data=LocalRunRepoData(repo_dir="/tmp"),
+ configuration=config,
+ configuration_path=".dstack.yml",
+ profile=Profile(backends=[BackendType.AWS]),
+ )
+
+ # Expensive offer
+ h100_offer = create_test_offer(BackendType.AWS, "H100", 3.00)
+ # Cheap offer
+ t4_offer = create_test_offer(BackendType.AWS, "T4", 0.30)
+
+ job_plan_expensive = JobPlan(
+ job_spec=JobSpec(
+ replica_num=0,
+ replica_group_name="expensive-group",
+ job_num=0,
+ job_name="test-job-0",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "H100", "count": 1})
+ ),
+ ),
+ offers=[h100_offer],
+ total_offers=1,
+ max_price=3.00,
+ )
+
+ job_plan_cheap = JobPlan(
+ job_spec=JobSpec(
+ replica_num=1,
+ replica_group_name="cheap-group",
+ job_num=1,
+ job_name="test-job-1",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(resources=ResourcesSpec(gpu={"name": "T4", "count": 1})),
+ ),
+ offers=[t4_offer],
+ total_offers=1,
+ max_price=0.30,
+ )
+
+ run_plan = RunPlan(
+ project_name="test-project",
+ user="test-user",
+ run_spec=run_spec,
+ effective_run_spec=run_spec,
+ job_plans=[job_plan_expensive, job_plan_cheap],
+ current_resource=None,
+ action=ApplyAction.CREATE,
+ )
+
+ # Print the plan
+ print_run_plan(run_plan, max_offers=10, include_run_properties=True)
+
+ # Capture output
+ captured = capsys.readouterr()
+ output = captured.out
+
+ # Split output to find the offers table (after the header section)
+ lines = output.split("\n")
+
+ # Find lines that contain both a number and a group name (these are offer rows)
+ offer_rows = [
+ line
+ for line in lines
+ if ("cheap-group:" in line or "expensive-group:" in line)
+ and line.strip().startswith(("1", "2", "3"))
+ ]
+
+ # The first offer row should be cheap-group (lower price)
+ assert len(offer_rows) >= 2, "Should have at least 2 offer rows"
+ assert "cheap-group:" in offer_rows[0], (
+ "First offer should be cheap-group (sorted by price)"
+ )
+ assert "expensive-group:" in offer_rows[1], "Second offer should be expensive-group"
+ assert "$0.3" in output # Price displayed as $0.3
+ assert "$3" in output # Price displayed as $3
+
+ def test_replica_group_with_no_offers_shows_message(self, capsys):
+ """Replica groups with no available offers should show a message."""
+ config = ServiceConfiguration(
+ type="service",
+ port=8000,
+ commands=["echo test"],
+ replica_groups=[
+ {
+ "name": "available-group",
+ "replicas": "1",
+ "resources": {"gpu": {"name": "L40S", "count": 1}},
+ },
+ {
+ "name": "unavailable-group",
+ "replicas": "1",
+ "resources": {"gpu": {"name": "A100", "count": 1}},
+ },
+ ],
+ )
+
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data=LocalRunRepoData(repo_dir="/tmp"),
+ configuration=config,
+ configuration_path=".dstack.yml",
+ profile=Profile(backends=[BackendType.VASTAI]),
+ )
+
+ # One group has offers, another doesn't
+ l40s_offer = create_test_offer(BackendType.VASTAI, "L40S", 0.50)
+
+ job_plan_with_offers = JobPlan(
+ job_spec=JobSpec(
+ replica_num=0,
+ replica_group_name="available-group",
+ job_num=0,
+ job_name="test-job-0",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "L40S", "count": 1})
+ ),
+ ),
+ offers=[l40s_offer],
+ total_offers=1,
+ max_price=0.50,
+ )
+
+ job_plan_no_offers = JobPlan(
+ job_spec=JobSpec(
+ replica_num=1,
+ replica_group_name="unavailable-group",
+ job_num=1,
+ job_name="test-job-1",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "A100", "count": 1})
+ ),
+ ),
+ offers=[], # No offers
+ total_offers=0,
+ max_price=0.0,
+ )
+
+ run_plan = RunPlan(
+ project_name="test-project",
+ user="test-user",
+ run_spec=run_spec,
+ effective_run_spec=run_spec,
+ job_plans=[job_plan_with_offers, job_plan_no_offers],
+ current_resource=None,
+ action=ApplyAction.CREATE,
+ )
+
+ # Print the plan
+ print_run_plan(run_plan, max_offers=10, include_run_properties=True)
+
+ # Capture output
+ captured = capsys.readouterr()
+ output = captured.out
+
+ # Verify available group shows offer
+ assert "available-group:" in output
+ assert "L40S" in output
+
+ # Verify unavailable group shows the standard "no offers" message
+ # (Message may be wrapped across lines in table display)
+ assert "unavailable-group:" in output
+ assert "No matching instance" in output
+ assert "offers available" in output
+ assert "Possible reasons:" in output
+ assert "dstack.ai/docs" in output # URL may be truncated in table
+
+ # Verify unavailable group appears BEFORE available group (at top)
+ unavailable_pos = output.find("unavailable-group:")
+ available_pos = output.find("available-group:")
+ assert unavailable_pos < available_pos, "Group with no offers should appear first"
+
+
+class TestReplicaGroupsFairOfferDistribution:
+ """Test that CLI displays offers from all replica groups fairly."""
+
+ def test_all_groups_represented_in_display(self, capsys):
+ """Test that offers from all replica groups are shown when max_offers is set."""
+ # Create offers for three groups with different price ranges
+ h100_offers = [
+ create_test_offer(BackendType.AWS, "H100", 3.0, region="us-east"),
+ create_test_offer(BackendType.AWS, "H100", 3.5, region="us-west"),
+ create_test_offer(BackendType.GCP, "H100", 4.0, region="eu-west"),
+ ]
+
+ rtx5090_offers = [
+ create_test_offer(BackendType.VASTAI, "RTX5090", 0.5, region="us"),
+ create_test_offer(BackendType.VASTAI, "RTX5090", 0.6, region="eu"),
+ ]
+
+ a100_offers = [
+ create_test_offer(BackendType.AWS, "A100", 2.0, region="us-east"),
+ create_test_offer(BackendType.GCP, "A100", 2.2, region="eu-west"),
+ ]
+
+ # Create job plans for each group
+ job_plan_h100 = JobPlan(
+ job_spec=JobSpec(
+ replica_num=0,
+ replica_group_name="h100-group",
+ job_num=0,
+ job_name="test-job-0",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "H100", "count": 1})
+ ),
+ ),
+ offers=h100_offers,
+ total_offers=len(h100_offers),
+ max_price=4.0,
+ )
+
+ job_plan_rtx5090 = JobPlan(
+ job_spec=JobSpec(
+ replica_num=1,
+ replica_group_name="rtx5090-group",
+ job_num=1,
+ job_name="test-job-1",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "RTX5090", "count": 1})
+ ),
+ ),
+ offers=rtx5090_offers,
+ total_offers=len(rtx5090_offers),
+ max_price=0.6,
+ )
+
+ job_plan_a100 = JobPlan(
+ job_spec=JobSpec(
+ replica_num=2,
+ replica_group_name="a100-group",
+ job_num=2,
+ job_name="test-job-2",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "A100", "count": 1})
+ ),
+ ),
+ offers=a100_offers,
+ total_offers=len(a100_offers),
+ max_price=2.2,
+ )
+
+ config = ServiceConfiguration(
+ type="service",
+ port=8000,
+ commands=["echo test"],
+ )
+
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data=LocalRunRepoData(repo_dir="/tmp"),
+ configuration=config,
+ configuration_path=".dstack.yml",
+ profile=Profile(backends=[BackendType.AWS]),
+ )
+
+ run_plan = RunPlan(
+ project_name="test-project",
+ user="test-user",
+ run_spec=run_spec,
+ effective_run_spec=run_spec,
+ job_plans=[job_plan_h100, job_plan_rtx5090, job_plan_a100],
+ current_resource=None,
+ action=ApplyAction.CREATE,
+ )
+
+ # Print with max_offers=5 (should show at least 1 from each group)
+ print_run_plan(run_plan, max_offers=5, include_run_properties=True)
+
+ captured = capsys.readouterr()
+ output = captured.out
+
+ # Verify all three groups appear in the output
+ assert "h100-group:" in output, "H100 group should be displayed"
+ assert "rtx5090-group:" in output, "RTX5090 group should be displayed"
+ assert "a100-group:" in output, "A100 group should be displayed"
+
+ # Verify GPUs are shown
+ assert "H100" in output
+ assert "RTX5090" in output
+ assert "A100" in output
+
+ def test_fair_distribution_with_limited_slots(self, capsys):
+ """Test that when max_offers is limited, all groups get fair representation."""
+ # Group 1: Many cheap offers
+ cheap_offers = [
+ create_test_offer(BackendType.VASTAI, "RTX5090", 0.4 + i * 0.1, region="us")
+ for i in range(10)
+ ]
+
+ # Group 2: Few expensive offers
+ expensive_offers = [
+ create_test_offer(BackendType.AWS, "H100", 3.0 + i * 0.5, region="us-east")
+ for i in range(3)
+ ]
+
+ job_plan_cheap = JobPlan(
+ job_spec=JobSpec(
+ replica_num=0,
+ replica_group_name="cheap-group",
+ job_num=0,
+ job_name="test-job-0",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "RTX5090", "count": 1})
+ ),
+ ),
+ offers=cheap_offers,
+ total_offers=len(cheap_offers),
+ max_price=1.4,
+ )
+
+ job_plan_expensive = JobPlan(
+ job_spec=JobSpec(
+ replica_num=1,
+ replica_group_name="expensive-group",
+ job_num=1,
+ job_name="test-job-1",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "H100", "count": 1})
+ ),
+ ),
+ offers=expensive_offers,
+ total_offers=len(expensive_offers),
+ max_price=4.0,
+ )
+
+ config = ServiceConfiguration(
+ type="service",
+ port=8000,
+ commands=["echo test"],
+ )
+
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data=LocalRunRepoData(repo_dir="/tmp"),
+ configuration=config,
+ configuration_path=".dstack.yml",
+ profile=Profile(backends=[BackendType.VASTAI]),
+ )
+
+ run_plan = RunPlan(
+ project_name="test-project",
+ user="test-user",
+ run_spec=run_spec,
+ effective_run_spec=run_spec,
+ job_plans=[job_plan_cheap, job_plan_expensive],
+ current_resource=None,
+ action=ApplyAction.CREATE,
+ )
+
+ # Print with max_offers=4 (should show at least 1 from each group)
+ print_run_plan(run_plan, max_offers=4, include_run_properties=True)
+
+ captured = capsys.readouterr()
+ output = captured.out
+
+ # Both groups should be represented
+ assert "cheap-group:" in output
+ assert "expensive-group:" in output
+
+ # Count occurrences (rough check - both should appear)
+ cheap_count = output.count("cheap-group:")
+ expensive_count = output.count("expensive-group:")
+
+ # Both should have at least one offer shown
+ assert cheap_count >= 1, "Cheap group should have at least one offer"
+ assert expensive_count >= 1, "Expensive group should have at least one offer"
+
+
+class TestReplicaGroupsProfileOverridesDisplay:
+ """Test that CLI correctly displays profile overrides for replica groups."""
+
+ def test_shows_group_specific_spot_policy_and_regions(self, capsys):
+ """Test that group-specific spot_policy, regions, backends are displayed."""
+
+ config = ServiceConfiguration(
+ type="service",
+ port=8000,
+ commands=["echo test"],
+ replica_groups=[
+ {
+ "name": "h100-group",
+ "replicas": "1",
+ "resources": {"gpu": {"name": "H100", "count": 1}},
+ "spot_policy": "spot",
+ "regions": ["us-east-1", "us-west-2"],
+ "backends": ["aws"],
+ },
+ {
+ "name": "rtx5090-group",
+ "replicas": "0..5",
+ "resources": {"gpu": {"name": "RTX5090", "count": 1}},
+ "spot_policy": "on-demand",
+ "regions": ["jp-japan"],
+ "backends": ["vastai", "runpod"],
+ },
+ ],
+ scaling={"metric": "rps", "target": 10},
+ )
+
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data=LocalRunRepoData(repo_dir="/tmp"),
+ configuration=config,
+ configuration_path=".dstack.yml",
+ profile=Profile(backends=[BackendType.AWS]),
+ )
+
+ # Create job plans
+ job_plan_h100 = JobPlan(
+ job_spec=JobSpec(
+ replica_num=0,
+ replica_group_name="h100-group",
+ job_num=0,
+ job_name="test-job-0",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "H100", "count": 1})
+ ),
+ ),
+ offers=[create_test_offer(BackendType.AWS, "H100", 3.0)],
+ total_offers=1,
+ max_price=3.0,
+ )
+
+ job_plan_rtx = JobPlan(
+ job_spec=JobSpec(
+ replica_num=1,
+ replica_group_name="rtx5090-group",
+ job_num=1,
+ job_name="test-job-1",
+ image_name="dstackai/base",
+ commands=["echo test"],
+ env={},
+ working_dir="/workflow",
+ requirements=Requirements(
+ resources=ResourcesSpec(gpu={"name": "RTX5090", "count": 1})
+ ),
+ ),
+ offers=[create_test_offer(BackendType.VASTAI, "RTX5090", 0.5)],
+ total_offers=1,
+ max_price=0.5,
+ )
+
+ run_plan = RunPlan(
+ project_name="test-project",
+ user="test-user",
+ run_spec=run_spec,
+ effective_run_spec=run_spec,
+ job_plans=[job_plan_h100, job_plan_rtx],
+ current_resource=None,
+ action=ApplyAction.CREATE,
+ )
+
+ # Print the plan
+ print_run_plan(run_plan, max_offers=10, include_run_properties=True)
+
+ # Capture output
+ captured = capsys.readouterr()
+ output = captured.out
+
+ # Verify group-specific overrides are shown
+ assert "h100-group" in output
+ assert "spot=spot" in output # H100 group's spot policy
+ assert "regions=us-east-1,us-west-2" in output # H100 group's regions
+ assert "backends=aws" in output # H100 group's backend
+
+ assert "rtx5090-group" in output
+ assert "spot=on-demand" in output # RTX5090 group's spot policy
+ assert "regions=jp-japan" in output # RTX5090 group's region
+ assert "backends=vastai,runpod" in output # RTX5090 group's backends
+
+ # Verify service-level "Spot policy" row is NOT shown (misleading with groups)
+ lines = output.split("\n")
+ spot_policy_lines = [line for line in lines if line.strip().startswith("Spot policy")]
+ assert len(spot_policy_lines) == 0, (
+ "Service-level 'Spot policy' should not be shown with replica_groups"
+ )
diff --git a/src/tests/_internal/core/models/test_replica_groups.py b/src/tests/_internal/core/models/test_replica_groups.py
new file mode 100644
index 000000000..7c0a35a08
--- /dev/null
+++ b/src/tests/_internal/core/models/test_replica_groups.py
@@ -0,0 +1,437 @@
+"""Tests for Named Replica Groups functionality"""
+
+import pytest
+
+from dstack._internal.core.errors import ConfigurationError
+from dstack._internal.core.models.configurations import (
+ ServiceConfiguration,
+ parse_run_configuration,
+)
+from dstack._internal.core.models.resources import Range
+from dstack._internal.core.models.runs import get_normalized_replica_groups
+
+
+class TestReplicaGroupConfiguration:
+ """Test replica group configuration parsing and validation"""
+
+ def test_basic_replica_groups(self):
+ """Test basic replica groups configuration"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replica_groups": [
+ {
+ "name": "h100-group",
+ "replicas": 1,
+ "resources": {"gpu": "H100:1"},
+ },
+ {
+ "name": "rtx5090-group",
+ "replicas": 2,
+ "resources": {"gpu": "RTX5090:1"},
+ },
+ ],
+ }
+
+ parsed = parse_run_configuration(conf)
+ assert isinstance(parsed, ServiceConfiguration)
+ assert parsed.replica_groups is not None
+ assert len(parsed.replica_groups) == 2
+
+ # Check first group
+ assert parsed.replica_groups[0].name == "h100-group"
+ assert parsed.replica_groups[0].replicas == Range(min=1, max=1)
+ assert parsed.replica_groups[0].resources.gpu.name == ["H100"]
+
+ # Check second group
+ assert parsed.replica_groups[1].name == "rtx5090-group"
+ assert parsed.replica_groups[1].replicas == Range(min=2, max=2)
+ assert parsed.replica_groups[1].resources.gpu.name == ["RTX5090"]
+
+ def test_replica_groups_with_ranges(self):
+ """Test replica groups with autoscaling ranges"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replica_groups": [
+ {
+ "name": "fixed-group",
+ "replicas": 1,
+ "resources": {"gpu": "H100:1"},
+ },
+ {
+ "name": "scalable-group",
+ "replicas": "1..3", # Range
+ "resources": {"gpu": "RTX5090:1"},
+ },
+ ],
+ "scaling": {
+ "metric": "rps",
+ "target": 10,
+ },
+ }
+
+ parsed = parse_run_configuration(conf)
+ assert parsed.replica_groups is not None
+ assert len(parsed.replica_groups) == 2
+
+ # Fixed group
+ assert parsed.replica_groups[0].replicas == Range(min=1, max=1)
+
+ # Scalable group
+ assert parsed.replica_groups[1].replicas == Range(min=1, max=3)
+
+ def test_replica_groups_with_profile_params(self):
+ """Test replica groups can override profile parameters"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ # Service-level settings
+ "backends": ["aws"],
+ "regions": ["us-west-2"],
+ "replica_groups": [
+ {
+ "name": "aws-group",
+ "replicas": 1,
+ "resources": {"gpu": "H100:1"},
+ # Inherits backends/regions from service
+ },
+ {
+ "name": "runpod-group",
+ "replicas": 1,
+ "resources": {"gpu": "RTX5090:1"},
+ # Override backends
+ "backends": ["runpod"],
+ "regions": ["eu-west-1"],
+ },
+ ],
+ }
+
+ parsed = parse_run_configuration(conf)
+
+ # First group inherits from service (doesn't specify backends/regions)
+ assert parsed.replica_groups[0].backends is None
+ assert parsed.replica_groups[0].regions is None
+
+ # Second group overrides
+ assert parsed.replica_groups[1].backends == ["runpod"]
+ assert parsed.replica_groups[1].regions == ["eu-west-1"]
+
+ def test_replica_groups_xor_replicas(self):
+ """Test that replica_groups and replicas are mutually exclusive"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replicas": 2, # Old format
+ "replica_groups": [ # New format
+ {
+ "name": "group1",
+ "replicas": 1,
+ "resources": {"gpu": "H100:1"},
+ }
+ ],
+ }
+
+ with pytest.raises(
+ ConfigurationError,
+ match="Cannot specify both 'replicas' and 'replica_groups'",
+ ):
+ parse_run_configuration(conf)
+
+ def test_replica_groups_unique_names(self):
+ """Test that replica group names must be unique"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replica_groups": [
+ {
+ "name": "group1",
+ "replicas": 1,
+ "resources": {"gpu": "H100:1"},
+ },
+ {
+ "name": "group1", # Duplicate!
+ "replicas": 1,
+ "resources": {"gpu": "RTX5090:1"},
+ },
+ ],
+ }
+
+ with pytest.raises(
+ ConfigurationError,
+ match="Replica group names must be unique",
+ ):
+ parse_run_configuration(conf)
+
+ def test_replica_groups_empty_name(self):
+ """Test that replica group names cannot be empty"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replica_groups": [
+ {
+ "name": "", # Empty name
+ "replicas": 1,
+ "resources": {"gpu": "H100:1"},
+ }
+ ],
+ }
+
+ with pytest.raises(
+ ConfigurationError,
+ match="Group name cannot be empty",
+ ):
+ parse_run_configuration(conf)
+
+ def test_replica_groups_range_requires_scaling(self):
+ """Test that replica ranges require scaling configuration"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replica_groups": [
+ {
+ "name": "scalable-group",
+ "replicas": "1..3",
+ "resources": {"gpu": "RTX5090:1"},
+ }
+ ],
+ # Missing scaling!
+ }
+
+ with pytest.raises(
+ ConfigurationError,
+ match="When any replica group has a range, 'scaling' must be specified",
+ ):
+ parse_run_configuration(conf)
+
+ def test_replica_groups_cannot_be_empty(self):
+ """Test that replica_groups list cannot be empty"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replica_groups": [], # Empty list
+ }
+
+ with pytest.raises(
+ ConfigurationError,
+ match="replica_groups cannot be empty",
+ ):
+ parse_run_configuration(conf)
+
+
+class TestReplicaGroupNormalization:
+ """Test get_normalized_replica_groups helper"""
+
+ def test_normalize_new_format(self):
+ """Test normalization with replica_groups format"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replica_groups": [
+ {
+ "name": "group1",
+ "replicas": 1,
+ "resources": {"gpu": "H100:1"},
+ },
+ {
+ "name": "group2",
+ "replicas": 2,
+ "resources": {"gpu": "RTX5090:1"},
+ },
+ ],
+ }
+
+ parsed = parse_run_configuration(conf)
+ normalized = get_normalized_replica_groups(parsed)
+
+ assert len(normalized) == 2
+ assert normalized[0].name == "group1"
+ assert normalized[1].name == "group2"
+
+ def test_normalize_legacy_format(self):
+ """Test normalization converts legacy replicas to default group"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replicas": 3,
+ "resources": {"gpu": "H100:1"},
+ "backends": ["aws"],
+ "regions": ["us-west-2"],
+ }
+
+ parsed = parse_run_configuration(conf)
+ normalized = get_normalized_replica_groups(parsed)
+
+        # Should create a single "default" group
+ assert len(normalized) == 1
+ assert normalized[0].name == "default"
+ assert normalized[0].replicas == Range(min=3, max=3)
+ assert normalized[0].resources.gpu.name == ["H100"]
+
+ # Should inherit profile params
+ assert normalized[0].backends == ["aws"]
+ assert normalized[0].regions == ["us-west-2"]
+
+ def test_normalize_legacy_with_range(self):
+ """Test normalization with legacy autoscaling"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replicas": "1..5",
+ "resources": {"gpu": "RTX5090:1"},
+ "scaling": {
+ "metric": "rps",
+ "target": 10,
+ },
+ }
+
+ parsed = parse_run_configuration(conf)
+ normalized = get_normalized_replica_groups(parsed)
+
+ assert len(normalized) == 1
+ assert normalized[0].name == "default"
+ assert normalized[0].replicas == Range(min=1, max=5)
+
+
+class TestReplicaGroupAutoscaling:
+ """Test autoscaling behavior with replica groups"""
+
+ def test_autoscalable_group_detection(self):
+ """Test identifying which groups are autoscalable"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replica_groups": [
+ {
+ "name": "fixed",
+ "replicas": 1,
+ "resources": {"gpu": "H100:1"},
+ },
+ {
+ "name": "scalable",
+ "replicas": "1..3",
+ "resources": {"gpu": "RTX5090:1"},
+ },
+ ],
+ "scaling": {
+ "metric": "rps",
+ "target": 10,
+ },
+ }
+
+ parsed = parse_run_configuration(conf)
+
+ # Fixed group: min == max
+ assert parsed.replica_groups[0].replicas.min == parsed.replica_groups[0].replicas.max
+
+ # Scalable group: min != max
+ assert parsed.replica_groups[1].replicas.min != parsed.replica_groups[1].replicas.max
+
+ def test_multiple_autoscalable_groups(self):
+ """Test multiple groups can be autoscalable"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replica_groups": [
+ {
+ "name": "scalable-1",
+ "replicas": "1..3",
+ "resources": {"gpu": "H100:1"},
+ },
+ {
+ "name": "scalable-2",
+ "replicas": "2..5",
+ "resources": {"gpu": "RTX5090:1"},
+ },
+ ],
+ "scaling": {
+ "metric": "rps",
+ "target": 10,
+ },
+ }
+
+ parsed = parse_run_configuration(conf)
+
+ # Both are autoscalable
+ assert parsed.replica_groups[0].replicas.min != parsed.replica_groups[0].replicas.max
+ assert parsed.replica_groups[1].replicas.min != parsed.replica_groups[1].replicas.max
+
+
+class TestBackwardCompatibility:
+ """Test backward compatibility with existing configurations"""
+
+ def test_legacy_service_config(self):
+ """Test that legacy service configs still work"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replicas": 2,
+ "resources": {"gpu": "A100:1"},
+ }
+
+ parsed = parse_run_configuration(conf)
+
+ # Should parse successfully
+ assert isinstance(parsed, ServiceConfiguration)
+ assert parsed.replicas == Range(min=2, max=2)
+ assert parsed.replica_groups is None # Not using new format
+
+ def test_legacy_autoscaling_config(self):
+ """Test legacy autoscaling configurations"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replicas": "0..5",
+ "resources": {"gpu": "A100:1"},
+ "scaling": {
+ "metric": "rps",
+ "target": 10,
+ },
+ }
+
+ parsed = parse_run_configuration(conf)
+
+ # Should parse successfully
+ assert parsed.replicas == Range(min=0, max=5)
+ assert parsed.scaling is not None
+
+ def test_normalization_preserves_all_profile_params(self):
+ """Test that normalization copies all ProfileParams fields"""
+ conf = {
+ "type": "service",
+ "commands": ["python3 app.py"],
+ "port": 8000,
+ "replicas": 1,
+ "resources": {"gpu": "H100:1"},
+ "backends": ["aws"],
+ "regions": ["us-east-1"],
+ "instance_types": ["p4d.24xlarge"],
+ "spot_policy": "spot",
+ "max_price": 10.0,
+ }
+
+ parsed = parse_run_configuration(conf)
+ normalized = get_normalized_replica_groups(parsed)
+
+ # Check all fields are copied
+ group = normalized[0]
+ assert group.backends == ["aws"]
+ assert group.regions == ["us-east-1"]
+ assert group.instance_types == ["p4d.24xlarge"]
+ assert group.spot_policy == "spot"
+ assert group.max_price == 10.0
diff --git a/src/tests/_internal/core/test_backward_compatibility.py b/src/tests/_internal/core/test_backward_compatibility.py
new file mode 100644
index 000000000..ca49f12a3
--- /dev/null
+++ b/src/tests/_internal/core/test_backward_compatibility.py
@@ -0,0 +1,127 @@
+"""Test backward compatibility for replica_groups with older servers."""
+
+from dstack._internal.core.compatibility.runs import get_get_plan_excludes, get_run_spec_excludes
+from dstack._internal.core.models.configurations import ServiceConfiguration
+from dstack._internal.core.models.repos import LocalRunRepoData
+from dstack._internal.core.models.runs import RunSpec
+from dstack._internal.server.schemas.runs import GetRunPlanRequest
+
+
+class TestReplicaGroupsBackwardCompatibility:
+ """Test that replica_groups field is excluded when None for backward compatibility."""
+
+ def test_replica_groups_excluded_when_none(self):
+ """replica_groups should be excluded from JSON when None."""
+ config = ServiceConfiguration(
+ type="service",
+ port=8000,
+ commands=["echo test"],
+ replicas={"min": 1, "max": 1},
+ )
+
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data=LocalRunRepoData(repo_dir="/tmp"),
+ configuration=config,
+ profile=None,
+ )
+
+ # Get excludes
+ excludes = get_run_spec_excludes(run_spec)
+
+ # replica_groups should be in excludes
+ assert "configuration" in excludes
+ assert "replica_groups" in excludes["configuration"]
+ assert excludes["configuration"]["replica_groups"] is True
+
+ def test_replica_groups_not_excluded_when_set(self):
+ """replica_groups should NOT be excluded when set."""
+ config = ServiceConfiguration(
+ type="service",
+ port=8000,
+ commands=["echo test"],
+ replica_groups=[
+ {
+ "name": "gpu-group",
+ "replicas": "1",
+ "resources": {"gpu": {"name": "A100"}},
+ }
+ ],
+ scaling={"metric": "rps", "target": 10},
+ )
+
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data=LocalRunRepoData(repo_dir="/tmp"),
+ configuration=config,
+ profile=None,
+ )
+
+ # Get excludes
+ excludes = get_run_spec_excludes(run_spec)
+
+ # replica_groups should NOT be in excludes (or be False)
+ if "configuration" in excludes and "replica_groups" in excludes["configuration"]:
+ assert excludes["configuration"]["replica_groups"] is not True
+
+ def test_get_plan_request_serialization_without_replica_groups(self):
+ """GetRunPlanRequest should not include replica_groups in JSON when None."""
+ config = ServiceConfiguration(
+ type="service",
+ port=8000,
+ commands=["echo test"],
+ replicas={"min": 1, "max": 1},
+ )
+
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data=LocalRunRepoData(repo_dir="/tmp"),
+ configuration=config,
+ profile=None,
+ )
+
+ request = GetRunPlanRequest(run_spec=run_spec, max_offers=None)
+ excludes = get_get_plan_excludes(request)
+
+ # Serialize with excludes
+ json_str = request.json(exclude=excludes)
+
+ # replica_groups should not appear in JSON
+ assert "replica_groups" not in json_str
+
+ def test_get_plan_request_serialization_with_replica_groups(self):
+ """GetRunPlanRequest should include replica_groups in JSON when set."""
+ config = ServiceConfiguration(
+ type="service",
+ port=8000,
+ commands=["echo test"],
+ replica_groups=[
+ {
+ "name": "gpu-group",
+ "replicas": "1",
+ "resources": {"gpu": {"name": "A100"}},
+ }
+ ],
+ scaling={"metric": "rps", "target": 10},
+ )
+
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data=LocalRunRepoData(repo_dir="/tmp"),
+ configuration=config,
+ profile=None,
+ )
+
+ request = GetRunPlanRequest(run_spec=run_spec, max_offers=None)
+ excludes = get_get_plan_excludes(request)
+
+ # Serialize with excludes
+ json_str = request.json(exclude=excludes)
+
+ # replica_groups SHOULD appear in JSON
+ assert "replica_groups" in json_str
+ assert "gpu-group" in json_str
diff --git a/src/tests/_internal/server/background/tasks/test_migrate_legacy_jobs.py b/src/tests/_internal/server/background/tasks/test_migrate_legacy_jobs.py
new file mode 100644
index 000000000..3e9e2e3d5
--- /dev/null
+++ b/src/tests/_internal/server/background/tasks/test_migrate_legacy_jobs.py
@@ -0,0 +1,242 @@
+"""
+Tests for migrating legacy jobs without replica_group_name.
+"""
+
+import pytest
+
+from dstack._internal.core.models.configurations import (
+ ServiceConfiguration,
+)
+from dstack._internal.core.models.profiles import Profile
+from dstack._internal.core.models.resources import Range, ResourcesSpec
+from dstack._internal.core.models.runs import JobStatus, RunSpec
+from dstack._internal.server.background.tasks.process_runs import (
+ _migrate_legacy_job_replica_groups,
+)
+from dstack._internal.server.testing.common import (
+ create_job,
+ create_project,
+ create_repo,
+ create_run,
+ create_user,
+)
+
+
+class TestMigrateLegacyJobs:
+ @pytest.mark.asyncio
+ async def test_migrates_jobs_without_replica_group_name(
+ self, test_db, session, socket_enabled
+ ):
+ """Test that jobs without replica_group_name get migrated correctly."""
+ user = await create_user(session=session)
+ project = await create_project(session=session, owner=user)
+ repo = await create_repo(session=session, project_id=project.id)
+
+ # Create a run with replica_groups configuration
+ service_config = ServiceConfiguration(
+ replica_groups=[
+ {
+ "name": "h100-gpu",
+ "replicas": Range(min=1, max=1),
+ "resources": ResourcesSpec(gpu="H100:1"),
+ },
+ {
+ "name": "rtx5090-gpu",
+ "replicas": Range(min=1, max=1),
+ "resources": ResourcesSpec(gpu="RTX5090:1"),
+ },
+ ],
+ commands=["echo hello"],
+ port=8000,
+ )
+
+ run_spec = RunSpec(
+ run_name="test-service",
+ repo_id="test-repo",
+ configuration=service_config,
+ merged_profile=Profile(),
+ )
+
+ run = await create_run(
+ session=session,
+ project=project,
+ repo=repo,
+ user=user,
+ run_name="test-service",
+ run_spec=run_spec,
+ )
+
+ # Create jobs WITHOUT replica_group_name (simulating old code)
+ job1 = await create_job(
+ session=session,
+ run=run,
+ replica_num=0,
+ replica_group_name=None, # Old job without group
+ status=JobStatus.RUNNING,
+ )
+
+ job2 = await create_job(
+ session=session,
+ run=run,
+ replica_num=1,
+ replica_group_name=None, # Old job without group
+ status=JobStatus.RUNNING,
+ )
+
+ # Verify jobs have no group
+ assert job1.replica_group_name is None
+ assert job2.replica_group_name is None
+
+ # Refresh run to load jobs relationship
+ await session.refresh(run, ["jobs"])
+
+ # Run migration
+ await _migrate_legacy_job_replica_groups(session, run)
+ await session.refresh(job1)
+ await session.refresh(job2)
+
+ # Verify jobs now have correct groups
+ assert job1.replica_group_name == "h100-gpu"
+ assert job2.replica_group_name == "rtx5090-gpu"
+
+ @pytest.mark.asyncio
+ async def test_skips_already_migrated_jobs(self, test_db, session):
+ """Test that jobs with replica_group_name are not re-migrated."""
+ user = await create_user(session=session)
+ project = await create_project(session=session, owner=user)
+ repo = await create_repo(session=session, project_id=project.id)
+
+ service_config = ServiceConfiguration(
+ replica_groups=[
+ {
+ "name": "gpu-group",
+ "replicas": Range(min=1, max=1),
+ "resources": ResourcesSpec(gpu="A100:1"),
+ },
+ ],
+ commands=["echo hello"],
+ port=8000,
+ )
+
+ run_spec = RunSpec(
+ run_name="test-service",
+ repo_id="test-repo",
+ configuration=service_config,
+ merged_profile=Profile(),
+ )
+
+ run = await create_run(
+ session=session,
+ project=project,
+ repo=repo,
+ user=user,
+ run_name="test-service",
+ run_spec=run_spec,
+ )
+
+ # Create job WITH replica_group_name (already migrated)
+ job = await create_job(
+ session=session,
+ run=run,
+ replica_num=0,
+ replica_group_name="gpu-group",
+ status=JobStatus.RUNNING,
+ )
+
+ original_group = job.replica_group_name
+
+ # Run migration (should be a no-op)
+ await _migrate_legacy_job_replica_groups(session, run)
+ await session.refresh(job)
+
+ # Verify group unchanged
+ assert job.replica_group_name == original_group
+
+ @pytest.mark.asyncio
+ async def test_skips_non_service_runs(self, test_db, session):
+ """Test that non-service runs are skipped."""
+ from dstack._internal.core.models.configurations import TaskConfiguration
+
+ user = await create_user(session=session)
+ project = await create_project(session=session, owner=user)
+ repo = await create_repo(session=session, project_id=project.id)
+
+ task_config = TaskConfiguration(
+ commands=["echo hello"],
+ )
+
+ run_spec = RunSpec(
+ run_name="test-task",
+ repo_id="test-repo",
+ configuration=task_config,
+ merged_profile=Profile(),
+ )
+
+ run = await create_run(
+ session=session,
+ project=project,
+ repo=repo,
+ user=user,
+ run_name="test-task",
+ run_spec=run_spec,
+ )
+
+ job = await create_job(
+ session=session,
+ run=run,
+ replica_num=0,
+ replica_group_name=None,
+ status=JobStatus.RUNNING,
+ )
+
+ # Run migration (should skip task runs)
+ await _migrate_legacy_job_replica_groups(session, run)
+ await session.refresh(job)
+
+ # Verify no change
+ assert job.replica_group_name is None
+
+ @pytest.mark.asyncio
+ async def test_skips_legacy_replicas_config(self, test_db, session):
+ """Test that runs using legacy 'replicas' (not replica_groups) are skipped."""
+ user = await create_user(session=session)
+ project = await create_project(session=session, owner=user)
+ repo = await create_repo(session=session, project_id=project.id)
+
+ # Use legacy replicas configuration
+ service_config = ServiceConfiguration(
+ replicas=Range(min=2, max=2),
+ commands=["echo hello"],
+ port=8000,
+ )
+
+ run_spec = RunSpec(
+ run_name="test-service",
+ repo_id="test-repo",
+ configuration=service_config,
+ merged_profile=Profile(),
+ )
+
+ run = await create_run(
+ session=session,
+ project=project,
+ repo=repo,
+ user=user,
+ run_name="test-service",
+ run_spec=run_spec,
+ )
+
+ job = await create_job(
+ session=session,
+ run=run,
+ replica_num=0,
+ replica_group_name=None,
+ status=JobStatus.RUNNING,
+ )
+
+ # Run migration (should skip legacy replicas)
+ await _migrate_legacy_job_replica_groups(session, run)
+ await session.refresh(job)
+
+ # Verify no change (legacy replicas don't use replica_group_name)
+ assert job.replica_group_name is None
diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py
index f4e481f53..42758111c 100644
--- a/src/tests/_internal/server/routers/test_runs.py
+++ b/src/tests/_internal/server/routers/test_runs.py
@@ -233,6 +233,7 @@ def get_dev_env_run_plan_dict(
"privileged": True if docker else privileged,
"job_name": f"{run_name}-0-0",
"replica_num": 0,
+ "replica_group_name": None,
"job_num": 0,
"jobs_per_replica": 1,
"single_branch": False,
@@ -441,6 +442,7 @@ def get_dev_env_run_dict(
"privileged": True if docker else privileged,
"job_name": f"{run_name}-0-0",
"replica_num": 0,
+ "replica_group_name": None,
"job_num": 0,
"jobs_per_replica": 1,
"single_branch": False,
diff --git a/src/tests/_internal/server/services/test_get_plan_replica_groups.py b/src/tests/_internal/server/services/test_get_plan_replica_groups.py
new file mode 100644
index 000000000..84a549656
--- /dev/null
+++ b/src/tests/_internal/server/services/test_get_plan_replica_groups.py
@@ -0,0 +1,206 @@
+"""Test get_plan() offer fetching logic for replica groups."""
+
+from dstack._internal.core.models.resources import ResourcesSpec
+from dstack._internal.core.models.runs import Requirements
+
+
+class TestGetPlanOfferFetchingLogic:
+ """Test the logic for determining when to fetch offers per-job vs. shared."""
+
+ def test_requirements_equality_check(self):
+ """Test that Requirements objects can be compared for equality."""
+ # Identical requirements
+ req1 = Requirements(
+ resources=ResourcesSpec(gpu={"name": "A100", "count": 1}),
+ )
+ req2 = Requirements(
+ resources=ResourcesSpec(gpu={"name": "A100", "count": 1}),
+ )
+ assert req1 == req2
+
+ # Different GPU names
+ req3 = Requirements(
+ resources=ResourcesSpec(gpu={"name": "H100", "count": 1}),
+ )
+ assert req1 != req3
+
+ # Different GPU counts
+ req4 = Requirements(
+ resources=ResourcesSpec(gpu={"name": "A100", "count": 2}),
+ )
+ assert req1 != req4
+
+ def test_identical_requirements_detection_logic(self):
+ """Test logic for detecting when all jobs have identical requirements."""
+
+ # Simulate job specs with requirements
+ class MockJobSpec:
+ def __init__(self, gpu_name: str, gpu_count: int = 1):
+ self.requirements = Requirements(
+ resources=ResourcesSpec(gpu={"name": gpu_name, "count": gpu_count}),
+ )
+
+ class MockJob:
+ def __init__(self, gpu_name: str, gpu_count: int = 1):
+ self.job_spec = MockJobSpec(gpu_name, gpu_count)
+
+ # Test 1: All identical
+ jobs = [MockJob("A100"), MockJob("A100"), MockJob("A100")]
+ all_requirements_identical = all(
+ job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs
+ )
+ assert all_requirements_identical is True
+
+ # Test 2: Different GPU types
+ jobs = [MockJob("A100"), MockJob("H100"), MockJob("A100")]
+ all_requirements_identical = all(
+ job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs
+ )
+ assert all_requirements_identical is False
+
+ # Test 3: Different GPU counts
+ jobs = [MockJob("A100", 1), MockJob("A100", 2)]
+ all_requirements_identical = all(
+ job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs
+ )
+ assert all_requirements_identical is False
+
+ # Test 4: Single job (always identical)
+ jobs = [MockJob("V100")]
+ all_requirements_identical = all(
+ job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs
+ )
+ assert all_requirements_identical is True
+
+ def test_offer_fetch_decision_logic(self):
+ """Test the decision logic for when to use shared vs per-job offer fetching."""
+
+ class MockJobSpec:
+ def __init__(self, gpu_name: str):
+ self.requirements = Requirements(
+ resources=ResourcesSpec(gpu={"name": gpu_name, "count": 1}),
+ )
+
+ class MockJob:
+ def __init__(self, group_name: str, gpu_name: str):
+ self.job_spec = MockJobSpec(gpu_name)
+ self.job_spec.replica_group_name = group_name
+
+ # Scenario 1: Replica groups with different GPUs -> per-job fetch
+ jobs = [
+ MockJob("l40s-group", "L40S"),
+ MockJob("rtx4080-group", "RTX4080"),
+ ]
+ all_identical = all(
+ job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs
+ )
+ assert all_identical is False, "Different GPU types should trigger per-job offer fetch"
+
+ # Scenario 2: Replica groups with same GPU -> shared fetch (optimization)
+ jobs = [
+ MockJob("group-1", "A100"),
+ MockJob("group-2", "A100"),
+ ]
+ all_identical = all(
+ job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs
+ )
+ assert all_identical is True, "Identical GPUs should use shared offer fetch"
+
+ # Scenario 3: Legacy replicas (same requirements) -> shared fetch
+ jobs = [
+ MockJob("default", "V100"),
+ MockJob("default", "V100"),
+ ]
+ all_identical = all(
+ job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs
+ )
+ assert all_identical is True, "Legacy replicas with same GPU should use shared fetch"
+
+ # Scenario 4: Mixed groups (2 same + 1 different) -> per-job fetch
+ jobs = [
+ MockJob("a100-group-1", "A100"),
+ MockJob("h100-group", "H100"),
+ MockJob("a100-group-2", "A100"),
+ ]
+ all_identical = all(
+ job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs
+ )
+ assert all_identical is False, "Mix of different GPUs should trigger per-job fetch for all"
+
+
+class TestReplicaGroupOfferSearchIntegration:
+ """Integration tests for replica group offer search behavior."""
+
+ def test_different_gpu_types_creates_different_requirements(self):
+ """Different replica group GPU types should create different Requirements objects."""
+ # This tests the data model behavior that get_plan() relies on
+ req_l40s = Requirements(
+ resources=ResourcesSpec(
+ gpu={"name": "L40S", "count": 1},
+ )
+ )
+
+ req_rtx4080 = Requirements(
+ resources=ResourcesSpec(
+ gpu={"name": "RTX4080", "count": 1},
+ )
+ )
+
+ # These should NOT be equal
+ assert req_l40s != req_rtx4080
+
+ # Verify the GPU names are different
+ assert req_l40s.resources.gpu.name != req_rtx4080.resources.gpu.name
+
+ def test_identical_gpu_types_creates_identical_requirements(self):
+ """Identical replica group GPU types should create equal Requirements objects."""
+ req_a = Requirements(
+ resources=ResourcesSpec(
+ gpu={"name": "A100", "count": 1},
+ )
+ )
+
+ req_b = Requirements(
+ resources=ResourcesSpec(
+ gpu={"name": "A100", "count": 1},
+ )
+ )
+
+ # These SHOULD be equal (enables optimization)
+ assert req_a == req_b
+
+ def test_requirements_with_different_memory(self):
+ """Requirements with different GPU memory should not be equal."""
+ req_16gb = Requirements(
+ resources=ResourcesSpec(
+ gpu={"name": "A100", "memory": "16GB", "count": 1},
+ )
+ )
+
+ req_40gb = Requirements(
+ resources=ResourcesSpec(
+ gpu={"name": "A100", "memory": "40GB", "count": 1},
+ )
+ )
+
+ # Different memory specifications
+ assert req_16gb != req_40gb
+
+ def test_requirements_with_different_cpu_specs(self):
+ """Requirements with different CPU specs should not be equal."""
+ req_low_cpu = Requirements(
+ resources=ResourcesSpec(
+ cpu={"min": 2},
+ gpu={"name": "A100", "count": 1},
+ )
+ )
+
+ req_high_cpu = Requirements(
+ resources=ResourcesSpec(
+ cpu={"min": 16},
+ gpu={"name": "A100", "count": 1},
+ )
+ )
+
+ # Different CPU requirements
+ assert req_low_cpu != req_high_cpu
diff --git a/src/tests/_internal/server/services/test_replica_groups_profile_overrides.py b/src/tests/_internal/server/services/test_replica_groups_profile_overrides.py
new file mode 100644
index 000000000..65bf86ccf
--- /dev/null
+++ b/src/tests/_internal/server/services/test_replica_groups_profile_overrides.py
@@ -0,0 +1,265 @@
+"""Tests for replica group profile overrides (regions, spot_policy, etc.)"""
+
+import pytest
+from pydantic import parse_obj_as
+
+from dstack._internal.core.models.backends.base import BackendType
+from dstack._internal.core.models.configurations import (
+ ReplicaGroup,
+ ScalingSpec,
+ ServiceConfiguration,
+)
+from dstack._internal.core.models.profiles import Profile, SpotPolicy
+from dstack._internal.core.models.resources import GPUSpec, Range, ResourcesSpec
+from dstack._internal.core.models.runs import RunSpec
+from dstack._internal.server.services.runs import _get_job_profile
+
+pytestmark = pytest.mark.usefixtures("image_config_mock")
+
+
+def test_spot_policy_override_per_group():
+ """Test that each replica group can have its own spot_policy."""
+ # Create a service with different spot policies per group
+ config = ServiceConfiguration(
+ commands=["echo hello"],
+ port=8000,
+ replica_groups=[
+ ReplicaGroup(
+ name="on-demand-group",
+ replicas=parse_obj_as(Range[int], 1),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")),
+ spot_policy=SpotPolicy.ONDEMAND,
+ ),
+ ReplicaGroup(
+ name="spot-group",
+ replicas=parse_obj_as(Range[int], 1),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "RTX5090:1")),
+ spot_policy=SpotPolicy.SPOT,
+ ),
+ ReplicaGroup(
+ name="auto-group",
+ replicas=parse_obj_as(Range[int], 1),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "A100:1")),
+ spot_policy=SpotPolicy.AUTO,
+ ),
+ ],
+ scaling=ScalingSpec(metric="rps", target=10),
+ )
+
+ profile = Profile(name="test-profile", spot_policy=SpotPolicy.AUTO) # base policy
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data={"repo_type": "local", "repo_dir": "/repo"},
+ configuration_path="dstack.yaml",
+ configuration=config,
+ profile=profile,
+ ssh_key_pub="ssh_key",
+ )
+
+ # Test on-demand group
+ on_demand_profile = _get_job_profile(run_spec, "on-demand-group")
+ assert on_demand_profile.spot_policy == SpotPolicy.ONDEMAND
+
+ # Test spot group
+ spot_profile = _get_job_profile(run_spec, "spot-group")
+ assert spot_profile.spot_policy == SpotPolicy.SPOT
+
+ # Test auto group
+ auto_profile = _get_job_profile(run_spec, "auto-group")
+ assert auto_profile.spot_policy == SpotPolicy.AUTO
+
+ # Test legacy (no group) uses base profile
+ legacy_profile = _get_job_profile(run_spec, None)
+ assert legacy_profile.spot_policy == SpotPolicy.AUTO
+
+
+def test_regions_override_per_group():
+ """Test that each replica group can have its own regions."""
+ config = ServiceConfiguration(
+ commands=["echo hello"],
+ port=8000,
+ replica_groups=[
+ ReplicaGroup(
+ name="us-group",
+ replicas=parse_obj_as(Range[int], 1),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")),
+ regions=["us-east-1", "us-west-2"],
+ ),
+ ReplicaGroup(
+ name="eu-group",
+ replicas=parse_obj_as(Range[int], 1),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "RTX5090:1")),
+ regions=["eu-west-1", "eu-central-1"],
+ ),
+ ReplicaGroup(
+ name="asia-group",
+ replicas=parse_obj_as(Range[int], 1),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "A100:1")),
+ regions=["ap-northeast-1"],
+ ),
+ ],
+ scaling=ScalingSpec(metric="rps", target=10),
+ )
+
+ profile = Profile(name="test-profile", regions=["us-east-1"]) # base regions
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data={"repo_type": "local", "repo_dir": "/repo"},
+ configuration_path="dstack.yaml",
+ configuration=config,
+ profile=profile,
+ ssh_key_pub="ssh_key",
+ )
+
+ # Test US group
+ us_profile = _get_job_profile(run_spec, "us-group")
+ assert us_profile.regions == ["us-east-1", "us-west-2"]
+
+ # Test EU group
+ eu_profile = _get_job_profile(run_spec, "eu-group")
+ assert eu_profile.regions == ["eu-west-1", "eu-central-1"]
+
+ # Test Asia group
+ asia_profile = _get_job_profile(run_spec, "asia-group")
+ assert asia_profile.regions == ["ap-northeast-1"]
+
+ # Test legacy (no group) uses base profile
+ legacy_profile = _get_job_profile(run_spec, None)
+ assert legacy_profile.regions == ["us-east-1"]
+
+
+def test_backends_override_per_group():
+ """Test that each replica group can have its own backends."""
+ config = ServiceConfiguration(
+ commands=["echo hello"],
+ port=8000,
+ replica_groups=[
+ ReplicaGroup(
+ name="aws-group",
+ replicas=parse_obj_as(Range[int], 1),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")),
+ backends=[BackendType.AWS],
+ ),
+ ReplicaGroup(
+ name="vastai-group",
+ replicas=parse_obj_as(Range[int], 1),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "RTX5090:1")),
+ backends=[BackendType.VASTAI],
+ ),
+ ],
+ scaling=ScalingSpec(metric="rps", target=10),
+ )
+
+ profile = Profile(
+ name="test-profile",
+ backends=[BackendType.AWS, BackendType.GCP], # base backends
+ )
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data={"repo_type": "local", "repo_dir": "/repo"},
+ configuration_path="dstack.yaml",
+ configuration=config,
+ profile=profile,
+ ssh_key_pub="ssh_key",
+ )
+
+ # Test AWS group
+ aws_profile = _get_job_profile(run_spec, "aws-group")
+ assert aws_profile.backends == [BackendType.AWS]
+
+ # Test VastAI group
+ vastai_profile = _get_job_profile(run_spec, "vastai-group")
+ assert vastai_profile.backends == [BackendType.VASTAI]
+
+ # Test legacy (no group) uses base profile
+ legacy_profile = _get_job_profile(run_spec, None)
+ assert legacy_profile.backends == [BackendType.AWS, BackendType.GCP]
+
+
+def test_multiple_profile_overrides_per_group():
+ """Test that a replica group can override multiple profile parameters at once."""
+ config = ServiceConfiguration(
+ commands=["echo hello"],
+ port=8000,
+ replica_groups=[
+ ReplicaGroup(
+ name="specialized-group",
+ replicas=parse_obj_as(Range[int], 1),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")),
+ regions=["us-west-2"],
+ backends=[BackendType.AWS],
+ spot_policy=SpotPolicy.ONDEMAND,
+ max_price=5.0,
+ ),
+ ],
+ scaling=ScalingSpec(metric="rps", target=10),
+ )
+
+ profile = Profile(
+ name="test-profile",
+ regions=["us-east-1"],
+ backends=[BackendType.GCP],
+ spot_policy=SpotPolicy.SPOT,
+ max_price=1.0,
+ )
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data={"repo_type": "local", "repo_dir": "/repo"},
+ configuration_path="dstack.yaml",
+ configuration=config,
+ profile=profile,
+ ssh_key_pub="ssh_key",
+ )
+
+ specialized_profile = _get_job_profile(run_spec, "specialized-group")
+ assert specialized_profile.regions == ["us-west-2"]
+ assert specialized_profile.backends == [BackendType.AWS]
+ assert specialized_profile.spot_policy == SpotPolicy.ONDEMAND
+ assert specialized_profile.max_price == 5.0
+
+
+def test_partial_profile_override():
+ """Test that only specified profile parameters are overridden, others inherit from base."""
+ config = ServiceConfiguration(
+ commands=["echo hello"],
+ port=8000,
+ replica_groups=[
+ ReplicaGroup(
+ name="partial-group",
+ replicas=parse_obj_as(Range[int], 1),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")),
+ regions=["us-west-2"], # Only override regions
+ # spot_policy, backends, max_price should inherit from base
+ ),
+ ],
+ scaling=ScalingSpec(metric="rps", target=10),
+ )
+
+ profile = Profile(
+ name="test-profile",
+ regions=["us-east-1"],
+ backends=[BackendType.GCP, BackendType.AWS],
+ spot_policy=SpotPolicy.SPOT,
+ max_price=2.5,
+ )
+ run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data={"repo_type": "local", "repo_dir": "/repo"},
+ configuration_path="dstack.yaml",
+ configuration=config,
+ profile=profile,
+ ssh_key_pub="ssh_key",
+ )
+
+ partial_profile = _get_job_profile(run_spec, "partial-group")
+ # Overridden
+ assert partial_profile.regions == ["us-west-2"]
+ # Inherited from base
+ assert partial_profile.backends == [BackendType.GCP, BackendType.AWS]
+ assert partial_profile.spot_policy == SpotPolicy.SPOT
+ assert partial_profile.max_price == 2.5
diff --git a/src/tests/_internal/server/services/test_replica_groups_scaling.py b/src/tests/_internal/server/services/test_replica_groups_scaling.py
new file mode 100644
index 000000000..9cea6193b
--- /dev/null
+++ b/src/tests/_internal/server/services/test_replica_groups_scaling.py
@@ -0,0 +1,398 @@
+"""Integration tests for scaling services that use replica groups"""
+
+from typing import List
+
+import pytest
+from pydantic import parse_obj_as
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import selectinload
+
+from dstack._internal.core.models.configurations import (
+ ReplicaGroup,
+ ScalingSpec,
+ ServiceConfiguration,
+)
+from dstack._internal.core.models.profiles import Profile
+from dstack._internal.core.models.resources import GPUSpec, Range, ResourcesSpec
+from dstack._internal.core.models.runs import JobStatus, JobTerminationReason
+from dstack._internal.server.models import RunModel
+from dstack._internal.server.services.runs import scale_run_replicas
+from dstack._internal.server.testing.common import (
+ create_job,
+ create_project,
+ create_repo,
+ create_run,
+ create_user,
+ get_run_spec,
+)
+
+pytestmark = pytest.mark.usefixtures("image_config_mock")
+
+
+async def scale_wrapper(session: AsyncSession, run: RunModel, diff: int):
+    """Call scale_run_replicas, then commit the session and refresh the run, as the existing scaling tests do"""
+ await scale_run_replicas(session, run, diff)
+ await session.commit()
+ await session.refresh(run)
+
+
+async def make_run_with_groups(
+ session: AsyncSession,
+ groups_config: List[dict], # List of {name, replicas_range, gpu, initial_jobs}
+) -> RunModel:
+ """Helper to create a run with replica groups"""
+ project = await create_project(session=session)
+ user = await create_user(session=session)
+ repo = await create_repo(session=session, project_id=project.id)
+
+ # Build replica groups
+ replica_groups = []
+ for group_cfg in groups_config:
+ replica_groups.append(
+ ReplicaGroup(
+ name=group_cfg["name"],
+ replicas=parse_obj_as(Range[int], group_cfg["replicas_range"]),
+ resources=ResourcesSpec(gpu=GPUSpec(name=[group_cfg["gpu"]], count=1)),
+ )
+ )
+
+ profile = Profile(name="test-profile")
+ run_spec = get_run_spec(
+ repo_id=repo.name,
+ run_name="test-run",
+ profile=profile,
+ configuration=ServiceConfiguration(
+ commands=["python app.py"],
+ port=8000,
+ replica_groups=replica_groups,
+ scaling=ScalingSpec(metric="rps", target=10),
+ ),
+ )
+
+ run = await create_run(
+ session=session,
+ project=project,
+ repo=repo,
+ user=user,
+ run_name="test-run",
+ run_spec=run_spec,
+ )
+
+ # Create initial jobs
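+    # (replica_num increases across all groups, so job numbering is global to the run)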
+ replica_num = 0
+ for group_cfg in groups_config:
+ for job_status in group_cfg.get("initial_jobs", []):
+ job = await create_job(
+ session=session,
+ run=run,
+ status=job_status,
+ replica_num=replica_num,
+ replica_group_name=group_cfg["name"],
+ )
+ run.jobs.append(job)
+ replica_num += 1
+
+ await session.commit()
+
+ # Reload with jobs and project
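+    # (selectinload eager-loads jobs and project, since lazy loading is unavailable in an async session)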
+ res = await session.execute(
+ select(RunModel)
+ .where(RunModel.id == run.id)
+ .options(selectinload(RunModel.jobs), selectinload(RunModel.project))
+ )
+ return res.scalar_one()
+
+
+class TestReplicaGroupsScaleDown:
+ """Test scaling down with replica groups"""
+
+ @pytest.mark.asyncio
+ async def test_scale_down_only_from_autoscalable_groups(self, session: AsyncSession):
+ """Test that scale down only affects autoscalable groups"""
+ run = await make_run_with_groups(
+ session,
+ [
+ {
+ "name": "fixed-h100",
+ "replicas_range": "1..1", # Fixed
+ "gpu": "H100",
+ "initial_jobs": [JobStatus.RUNNING],
+ },
+ {
+ "name": "scalable-rtx",
+ "replicas_range": "1..3", # Autoscalable
+ "gpu": "RTX5090",
+ "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING],
+ },
+ ],
+ )
+
+ # Scale down by 1 (should only affect scalable group)
+ await scale_wrapper(session, run, -1)
+
+ # Check: fixed group should still have 1 running job
+ fixed_jobs = [j for j in run.jobs if j.replica_group_name == "fixed-h100"]
+ assert len(fixed_jobs) == 1
+ assert fixed_jobs[0].status == JobStatus.RUNNING
+
+        # Check: scalable group should have 1 terminating, 1 still running
+ scalable_jobs = [j for j in run.jobs if j.replica_group_name == "scalable-rtx"]
+ assert len(scalable_jobs) == 2
+ terminating = [j for j in scalable_jobs if j.status == JobStatus.TERMINATING]
+ assert len(terminating) == 1
+ assert terminating[0].termination_reason == JobTerminationReason.SCALED_DOWN
+
+ @pytest.mark.asyncio
+ async def test_scale_down_respects_group_minimums(self, session: AsyncSession):
+ """Test that scale down respects each group's minimum"""
+ run = await make_run_with_groups(
+ session,
+ [
+ {
+ "name": "group-a",
+ "replicas_range": "1..3", # Min=1
+ "gpu": "H100",
+ "initial_jobs": [JobStatus.RUNNING], # At minimum
+ },
+ {
+ "name": "group-b",
+ "replicas_range": "2..5", # Min=2
+ "gpu": "RTX5090",
+ "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING, JobStatus.RUNNING],
+ },
+ ],
+ )
+
+ # Try to scale down by 2
+ await scale_wrapper(session, run, -2)
+
+ # Group A should still have 1 (at minimum)
+ group_a_jobs = [j for j in run.jobs if j.replica_group_name == "group-a"]
+ assert len([j for j in group_a_jobs if j.status == JobStatus.RUNNING]) == 1
+
+        # Group B should have 1 terminating (3 -> 2, its minimum)
+ group_b_jobs = [j for j in run.jobs if j.replica_group_name == "group-b"]
+ terminating = [j for j in group_b_jobs if j.status == JobStatus.TERMINATING]
+ assert len(terminating) == 1
+
+ @pytest.mark.asyncio
+ async def test_scale_down_all_groups_fixed(self, session: AsyncSession):
+ """Test scaling down when all groups are fixed (should not terminate anything)"""
+ run = await make_run_with_groups(
+ session,
+ [
+ {
+ "name": "fixed-1",
+ "replicas_range": "1..1",
+ "gpu": "H100",
+ "initial_jobs": [JobStatus.RUNNING],
+ },
+ {
+ "name": "fixed-2",
+ "replicas_range": "2..2",
+ "gpu": "RTX5090",
+ "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING],
+ },
+ ],
+ )
+
+ initial_count = len(run.jobs)
+
+ # Try to scale down
+ await scale_wrapper(session, run, -1)
+
+ # No jobs should be terminated (all groups are fixed)
+ assert len(run.jobs) == initial_count
+ assert all(j.status == JobStatus.RUNNING for j in run.jobs)
+
+
+class TestReplicaGroupsScaleUp:
+ """Test scaling up with replica groups"""
+
+ @pytest.mark.asyncio
+ async def test_scale_up_selects_autoscalable_group(self, session: AsyncSession):
+ """Test that scale up only creates jobs in autoscalable groups"""
+ run = await make_run_with_groups(
+ session,
+ [
+ {
+ "name": "fixed-h100",
+ "replicas_range": "1..1", # Fixed
+ "gpu": "H100",
+ "initial_jobs": [JobStatus.RUNNING],
+ },
+ {
+ "name": "scalable-rtx",
+ "replicas_range": "1..3", # Autoscalable
+ "gpu": "RTX5090",
+ "initial_jobs": [JobStatus.RUNNING],
+ },
+ ],
+ )
+
+ initial_count = len(run.jobs)
+
+ # Scale up by 1
+ await scale_wrapper(session, run, 1)
+
+ # Should have one more job
+ assert len(run.jobs) == initial_count + 1
+
+ # New job should be in scalable group
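+        # (replica numbering is sequential, so the new job's replica_num equals the previous job count)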
+ new_jobs = [j for j in run.jobs if j.replica_num == initial_count]
+ assert len(new_jobs) == 1
+ assert new_jobs[0].replica_group_name == "scalable-rtx"
+ assert new_jobs[0].status == JobStatus.SUBMITTED
+
+ @pytest.mark.asyncio
+ async def test_scale_up_respects_group_maximums(self, session: AsyncSession):
+ """Test that scale up respects group maximums"""
+ run = await make_run_with_groups(
+ session,
+ [
+ {
+ "name": "small-group",
+ "replicas_range": "1..2", # Max=2
+ "gpu": "H100",
+ "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING], # At max
+ },
+ {
+ "name": "large-group",
+ "replicas_range": "1..5", # Max=5
+ "gpu": "RTX5090",
+ "initial_jobs": [JobStatus.RUNNING],
+ },
+ ],
+ )
+
+ # Try to scale up by 2
+ await scale_wrapper(session, run, 2)
+
+ # Small group should still have 2 (at max)
+ small_jobs = [j for j in run.jobs if j.replica_group_name == "small-group"]
+ assert len(small_jobs) == 2
+
+ # Large group should have grown by 2
+ large_jobs = [j for j in run.jobs if j.replica_group_name == "large-group"]
+ assert len(large_jobs) == 3
+
+ @pytest.mark.asyncio
+ async def test_scale_up_no_autoscalable_groups(self, session: AsyncSession):
+ """Test scale up does nothing when no autoscalable groups exist"""
+ run = await make_run_with_groups(
+ session,
+ [
+ {
+ "name": "fixed-1",
+ "replicas_range": "1..1",
+ "gpu": "H100",
+ "initial_jobs": [JobStatus.RUNNING],
+ },
+ {
+ "name": "fixed-2",
+ "replicas_range": "2..2",
+ "gpu": "RTX5090",
+ "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING],
+ },
+ ],
+ )
+
+ initial_count = len(run.jobs)
+
+ # Try to scale up
+ await scale_wrapper(session, run, 2)
+
+ # Should not have added any jobs
+ assert len(run.jobs) == initial_count
+
+ @pytest.mark.asyncio
+ async def test_scale_up_all_groups_at_max(self, session: AsyncSession):
+ """Test scale up when all autoscalable groups are at maximum"""
+ run = await make_run_with_groups(
+ session,
+ [
+ {
+ "name": "group-a",
+ "replicas_range": "1..2",
+ "gpu": "H100",
+ "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING], # At max
+ },
+ {
+ "name": "group-b",
+ "replicas_range": "1..3",
+ "gpu": "RTX5090",
+ "initial_jobs": [
+ JobStatus.RUNNING,
+ JobStatus.RUNNING,
+ JobStatus.RUNNING,
+ ], # At max
+ },
+ ],
+ )
+
+ initial_count = len(run.jobs)
+
+ # Try to scale up
+ await scale_wrapper(session, run, 1)
+
+ # Should not have added any jobs (all at max)
+ assert len(run.jobs) == initial_count
+
+
+class TestReplicaGroupsBackwardCompatibility:
+ """Test backward compatibility with legacy configs"""
+
+ @pytest.mark.asyncio
+ async def test_legacy_config_scaling(self, session: AsyncSession):
+        """Test that scaling works with the legacy replicas configuration"""
+ project = await create_project(session=session)
+ user = await create_user(session=session)
+ repo = await create_repo(session=session, project_id=project.id)
+
+ # Use legacy format (no replica_groups)
+ profile = Profile(name="test-profile")
+ run_spec = get_run_spec(
+ repo_id=repo.name,
+ run_name="test-run",
+ profile=profile,
+ configuration=ServiceConfiguration(
+ commands=["python app.py"],
+ port=8000,
+ replicas=parse_obj_as(Range[int], "1..3"), # Legacy format
+ scaling=ScalingSpec(metric="rps", target=10),
+ ),
+ )
+
+ run = await create_run(
+ session=session,
+ project=project,
+ repo=repo,
+ user=user,
+ run_name="test-run",
+ run_spec=run_spec,
+ )
+
+ # Add initial job (no group name)
+ job = await create_job(
+ session=session,
+ run=run,
+ status=JobStatus.RUNNING,
+ replica_num=0,
+ replica_group_name=None, # Legacy jobs have no group
+ )
+ run.jobs.append(job)
+ await session.commit()
+
+ # Scale up should work
+ await scale_wrapper(session, run, 1)
+
+ # Should have 2 jobs now
+ assert len(run.jobs) == 2
+
+        # For legacy configs, the new job's group name may be None or "default"
+ new_job = [j for j in run.jobs if j.replica_num == 1][0]
+ assert new_job.replica_group_name in [None, "default"]
diff --git a/src/tests/_internal/server/services/test_replica_groups_update.py b/src/tests/_internal/server/services/test_replica_groups_update.py
new file mode 100644
index 000000000..1ae7148cc
--- /dev/null
+++ b/src/tests/_internal/server/services/test_replica_groups_update.py
@@ -0,0 +1,174 @@
+"""Tests for updating services with replica groups."""
+
+from pydantic import parse_obj_as
+
+from dstack._internal.core.models.configurations import (
+ ReplicaGroup,
+ ScalingSpec,
+ ServiceConfiguration,
+)
+from dstack._internal.core.models.profiles import Profile, SpotPolicy
+from dstack._internal.core.models.resources import GPUSpec, Range, ResourcesSpec
+from dstack._internal.core.models.runs import RunSpec
+from dstack._internal.server.services.runs import _check_can_update_run_spec
+
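+# _check_can_update_run_spec raises if the new spec changes fields that cannot
+# be updated in place; these tests pass as long as no exception is raised.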
+
+def test_can_update_from_replicas_to_replica_groups():
+ """Test that we can update a service from simple replicas to replica_groups."""
+ # Old config with simple replicas
+ old_config = ServiceConfiguration(
+ commands=["echo hello"],
+ port=8000,
+ replicas=2,
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")),
+ )
+
+ old_run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data={"repo_type": "local", "repo_dir": "/repo"},
+ configuration_path="dstack.yaml",
+ configuration=old_config,
+ profile=Profile(name="test-profile"),
+ ssh_key_pub="ssh_key",
+ )
+
+ # New config with replica_groups
+ new_config = ServiceConfiguration(
+ commands=["echo hello"],
+ port=8000,
+ replica_groups=[
+ ReplicaGroup(
+ name="h100-group",
+ replicas=parse_obj_as(Range[int], 1),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")),
+ regions=["us-east-1"],
+ ),
+ ReplicaGroup(
+ name="rtx5090-group",
+ replicas=parse_obj_as(Range[int], "0..3"),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "RTX5090:1")),
+ regions=["jp-japan"],
+ ),
+ ],
+ scaling=ScalingSpec(metric="rps", target=10),
+ )
+
+ new_run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data={"repo_type": "local", "repo_dir": "/repo"},
+ configuration_path="dstack.yaml",
+ configuration=new_config,
+ profile=Profile(name="test-profile"),
+ ssh_key_pub="ssh_key",
+ )
+
+ # This should NOT raise an error
+ _check_can_update_run_spec(old_run_spec, new_run_spec)
+
+
+def test_can_update_from_replica_groups_to_replicas():
+ """Test that we can update a service from replica_groups back to simple replicas."""
+ # Old config with replica_groups
+ old_config = ServiceConfiguration(
+ commands=["echo hello"],
+ port=8000,
+ replica_groups=[
+ ReplicaGroup(
+ name="h100-group",
+ replicas=parse_obj_as(Range[int], 1),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")),
+ ),
+ ],
+ )
+
+ old_run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data={"repo_type": "local", "repo_dir": "/repo"},
+ configuration_path="dstack.yaml",
+ configuration=old_config,
+ profile=Profile(name="test-profile"),
+ ssh_key_pub="ssh_key",
+ )
+
+ # New config with simple replicas
+ new_config = ServiceConfiguration(
+ commands=["echo hello"],
+ port=8000,
+ replicas=2,
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "A100:1")),
+ )
+
+ new_run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data={"repo_type": "local", "repo_dir": "/repo"},
+ configuration_path="dstack.yaml",
+ configuration=new_config,
+ profile=Profile(name="test-profile"),
+ ssh_key_pub="ssh_key",
+ )
+
+ # This should NOT raise an error
+ _check_can_update_run_spec(old_run_spec, new_run_spec)
+
+
+def test_can_update_replica_groups():
+ """Test that we can update replica_groups in place."""
+ # Old config
+ old_config = ServiceConfiguration(
+ commands=["echo hello"],
+ port=8000,
+ replica_groups=[
+ ReplicaGroup(
+ name="gpu-group",
+ replicas=parse_obj_as(Range[int], "1..3"),
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")),
+ regions=["us-east-1"],
+ ),
+ ],
+ scaling=ScalingSpec(metric="rps", target=10),
+ )
+
+ old_run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data={"repo_type": "local", "repo_dir": "/repo"},
+ configuration_path="dstack.yaml",
+ configuration=old_config,
+ profile=Profile(name="test-profile"),
+ ssh_key_pub="ssh_key",
+ )
+
+ # New config with different replica_groups
+ new_config = ServiceConfiguration(
+ commands=["echo hello"],
+ port=8000,
+ replica_groups=[
+ ReplicaGroup(
+ name="gpu-group",
+ replicas=parse_obj_as(Range[int], "2..5"), # Changed range
+ resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "A100:1")), # Changed GPU
+ regions=["us-west-2"], # Changed region
+ spot_policy=SpotPolicy.SPOT, # Added spot policy
+ ),
+ ],
+ scaling=ScalingSpec(metric="rps", target=20), # Changed target
+ )
+
+ new_run_spec = RunSpec(
+ run_name="test-run",
+ repo_id="test-repo",
+ repo_data={"repo_type": "local", "repo_dir": "/repo"},
+ configuration_path="dstack.yaml",
+ configuration=new_config,
+ profile=Profile(name="test-profile"),
+ ssh_key_pub="ssh_key",
+ )
+
+ # This should NOT raise an error (replica_groups + resources + scaling are all updatable)
+ _check_can_update_run_spec(old_run_spec, new_run_spec)