diff --git a/contributing/AUTOSCALING.md b/contributing/AUTOSCALING.md index 7fa987aaf..60a33aec1 100644 --- a/contributing/AUTOSCALING.md +++ b/contributing/AUTOSCALING.md @@ -11,6 +11,8 @@ - STEP 7: `scale_run_replicas` terminates or starts replicas. - `SUBMITTED` and `PROVISIONING` replicas get terminated before `RUNNING`. - Replicas are terminated by descending `replica_num` and launched by ascending `replica_num`. + - For services with `replica_groups`, only groups with autoscaling ranges (min != max) participate in scaling. + - Scale operations respect per-group minimum and maximum constraints. ## RPSAutoscaler diff --git a/contributing/RUNS-AND-JOBS.md b/contributing/RUNS-AND-JOBS.md index b2c0430af..18544caa5 100644 --- a/contributing/RUNS-AND-JOBS.md +++ b/contributing/RUNS-AND-JOBS.md @@ -13,7 +13,7 @@ Runs are created from run configurations. There are three types of run configura 2. `task` — runs the user's bash script until completion. 3. `service` — runs the user's bash script and exposes a port through [dstack-proxy](PROXY.md). -A run can spawn one or multiple jobs, depending on the configuration. A task that specifies multiple `nodes` spawns a job for every node (a multi-node task). A service that specifies multiple `replicas` spawns a job for every replica. A job submission is always assigned to one particular instance. If a job fails and the configuration allows retrying, the server creates a new job submission for the job. +A run can spawn one or multiple jobs, depending on the configuration. A task that specifies multiple `nodes` spawns a job for every node (a multi-node task). A service that specifies multiple `replicas` or `replica_groups` spawns a job for every replica. Each job in a replica group is tagged with `replica_group_name` to track which group it belongs to. A job submission is always assigned to one particular instance. If a job fails and the configuration allows retrying, the server creates a new job submission for the job. ## Run's Lifecycle diff --git a/docs/docs/concepts/services.md b/docs/docs/concepts/services.md index cb2649e00..383738538 100644 --- a/docs/docs/concepts/services.md +++ b/docs/docs/concepts/services.md @@ -160,6 +160,66 @@ Setting the minimum number of replicas to `0` allows the service to scale down t > The `scaling` property requires creating a [gateway](gateways.md). +### Replica Groups (Advanced) + +For advanced use cases, you can define multiple **replica groups** with different instance types, resources, and configurations within a single service. This is useful when you want to: + +- Run different GPU types in the same service (e.g., H100 for primary, RTX5090 for overflow) +- Configure different backends or regions per replica type +- Set different autoscaling behavior per group + +
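+The example below defines a fixed `primary` group on an H100 alongside an autoscalable `overflow` group on RTX5090s: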
+ +```yaml +type: service +name: llama31-service + +python: 3.12 +env: + - HF_TOKEN +commands: + - uv pip install vllm + - vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --max-model-len 4096 +port: 8000 + +# Define multiple replica groups with different configurations +replica_groups: + - name: primary + replicas: 1 # Always 1 H100 (fixed) + resources: + gpu: H100:1 + backends: [aws] + regions: [us-west-2] + + - name: overflow + replicas: 0..5 # Autoscales 0-5 RTX5090s + resources: + gpu: RTX5090:1 + backends: [runpod] + +scaling: + metric: rps + target: 10 +``` + +
+ +In this example: + +- The `primary` group always runs 1 H100 replica on AWS (fixed, never scaled) +- The `overflow` group scales 0-5 RTX5090 replicas on RunPod based on load +- Scale operations only affect groups with autoscaling ranges (min != max) + +Each replica group can override any [profile parameter](../reference/profiles.yml.md) including `backends`, `regions`, `instance_types`, `spot_policy`, etc. Group-level settings override service-level settings. + +> **Note:** When using `replica_groups`, you cannot use the simple `replicas` field. They are mutually exclusive. + +**When to use replica groups:** + +- You need different GPU types in the same service +- Different replicas should run in different regions or clouds +- Some replicas should be fixed while others autoscale + ### Model If the service is running a chat model with an OpenAI-compatible interface, diff --git a/docs/docs/reference/dstack.yml/service.md b/docs/docs/reference/dstack.yml/service.md index 8d89b2d57..40509332e 100644 --- a/docs/docs/reference/dstack.yml/service.md +++ b/docs/docs/reference/dstack.yml/service.md @@ -10,6 +10,22 @@ The `service` configuration type allows running [services](../../concepts/servic type: required: true +### `replica_groups` + +Define multiple replica groups with different configurations within a single service. + +> **Note:** Cannot be used together with `replicas`. + +#### `replica_groups[n]` + +#SCHEMA# dstack._internal.core.models.configurations.ReplicaGroup + overrides: + show_root_heading: false + type: + required: true + +Each replica group inherits from [ProfileParams](../profiles.yml.md) and can override any profile parameter including `backends`, `regions`, `instance_types`, `spot_policy`, etc. + ### `model` { data-toc-label="model" } === "OpenAI" diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index 58497c084..b61bf94fb 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -119,9 +119,55 @@ def th(s: str) -> str: if include_run_properties: props.add_row(th("Configuration"), run_spec.configuration_path) props.add_row(th("Type"), run_spec.configuration.type) - props.add_row(th("Resources"), pretty_req) - props.add_row(th("Spot policy"), spot_policy) - props.add_row(th("Max price"), max_price) + + from dstack._internal.core.models.configurations import ServiceConfiguration + + has_replica_groups = ( + include_run_properties + and isinstance(run_spec.configuration, ServiceConfiguration) + and run_spec.configuration.replica_groups + ) + + if has_replica_groups: + groups_info = [] + for group in run_spec.configuration.replica_groups: + group_parts = [f"[cyan]{group.name}[/cyan]"] + + # Replica count + if group.replicas.min == group.replicas.max: + group_parts.append(f"×{group.replicas.max}") + else: + group_parts.append(f"×{group.replicas.min}..{group.replicas.max}") + group_parts.append("[dim](autoscalable)[/dim]") + + # Resources + group_parts.append(f"[dim]({group.resources.pretty_format()})[/dim]") + + # Group-specific overrides + overrides = [] + if group.spot_policy is not None: + overrides.append(f"spot={group.spot_policy.value}") + if group.regions: + regions_str = ",".join(group.regions[:2]) # Show first 2 + if len(group.regions) > 2: + regions_str += f",+{len(group.regions) - 2}" + overrides.append(f"regions={regions_str}") + if group.backends: + backends_str = ",".join([b.value for b in group.backends[:2]]) + if len(group.backends) > 2: + backends_str += f",+{len(group.backends) - 
2}" + overrides.append(f"backends={backends_str}") + + if overrides: + group_parts.append(f"[dim]({'; '.join(overrides)})[/dim]") + + groups_info.append(" ".join(group_parts)) + + props.add_row(th("Replica groups"), "\n".join(groups_info)) + else: + props.add_row(th("Resources"), pretty_req) + props.add_row(th("Spot policy"), spot_policy) + props.add_row(th("Max price"), max_price) if include_run_properties: props.add_row(th("Retry policy"), retry) props.add_row(th("Creation policy"), creation_policy) @@ -139,44 +185,172 @@ def th(s: str) -> str: offers.add_column("PRICE", style="grey58", ratio=1) offers.add_column() - job_plan.offers = job_plan.offers[:max_offers] if max_offers else job_plan.offers - - for i, offer in enumerate(job_plan.offers, start=1): - r = offer.instance.resources - - availability = "" - if offer.availability in { - InstanceAvailability.NOT_AVAILABLE, - InstanceAvailability.NO_QUOTA, - InstanceAvailability.IDLE, - InstanceAvailability.BUSY, - }: - availability = offer.availability.value.replace("_", " ").lower() - instance = offer.instance.name - if offer.total_blocks > 1: - instance += f" ({offer.blocks}/{offer.total_blocks})" - offers.add_row( - f"{i}", - offer.backend.replace("remote", "ssh") + " (" + offer.region + ")", - r.pretty_format(include_spot=True), - instance, - f"${offer.price:.4f}".rstrip("0").rstrip("."), - availability, - style=None if i == 1 or not include_run_properties else "secondary", - ) - if job_plan.total_offers > len(job_plan.offers): - offers.add_row("", "...", style="secondary") + # For replica groups, show offers from all job plans + if len(run_plan.job_plans) > 1: + # Multiple jobs - ensure fair representation of all groups + groups_with_no_offers = [] + groups_with_offers = {} + total_offers_count = 0 + + # Collect offers per group + for jp in run_plan.job_plans: + group_name = jp.job_spec.replica_group_name or "default" + if jp.total_offers == 0: + groups_with_no_offers.append(group_name) + else: + groups_with_offers[group_name] = jp.offers + total_offers_count += jp.total_offers + + # Strategy: Show at least min_per_group offers from each group, then fill with cheapest + num_groups = len(groups_with_offers) + if num_groups > 0 and max_offers: + min_per_group = max( + 1, max_offers // (num_groups * 2) + ) # At least 1, aim for ~half distribution + remaining_slots = max_offers + else: + min_per_group = None + remaining_slots = None + + selected_offers = [] + + # First pass: Take min_per_group from each group (cheapest from each) + if min_per_group: + for group_name, group_offers in groups_with_offers.items(): + sorted_group_offers = sorted(group_offers, key=lambda x: x.price) + take_count = min(min_per_group, len(sorted_group_offers), remaining_slots) + for offer in sorted_group_offers[:take_count]: + selected_offers.append((group_name, offer)) + remaining_slots -= take_count + + # Second pass: Fill remaining slots with cheapest offers globally + if remaining_slots and remaining_slots > 0: + all_remaining = [] + for group_name, group_offers in groups_with_offers.items(): + sorted_group_offers = sorted(group_offers, key=lambda x: x.price) + # Skip offers already selected + for offer in sorted_group_offers[min_per_group:]: + all_remaining.append((group_name, offer)) + + # Sort remaining by price and take the cheapest + all_remaining.sort(key=lambda x: x[1].price) + selected_offers.extend(all_remaining[:remaining_slots]) + + # If no max_offers limit, show all + if not max_offers: + selected_offers = [] + for group_name, group_offers in 
groups_with_offers.items(): + for offer in group_offers: + selected_offers.append((group_name, offer)) + + # Sort final selection by price for display + selected_offers.sort(key=lambda x: x[1].price) + + # Show groups with no offers FIRST + for group_name in groups_with_no_offers: + offers.add_row( + "", + f"[cyan]{group_name}[/cyan]:", + "[red]No matching instance offers available.[/red]\n" + "Possible reasons: https://dstack.ai/docs/guides/troubleshooting/#no-offers", + "", + "", + "", + style="secondary", + ) + + # Then show selected offers + for i, (group_name, offer) in enumerate(selected_offers, start=1): + r = offer.instance.resources + + availability = "" + if offer.availability in { + InstanceAvailability.NOT_AVAILABLE, + InstanceAvailability.NO_QUOTA, + InstanceAvailability.IDLE, + InstanceAvailability.BUSY, + }: + availability = offer.availability.value.replace("_", " ").lower() + instance = offer.instance.name + if offer.total_blocks > 1: + instance += f" ({offer.blocks}/{offer.total_blocks})" + + # Add group name prefix for multi-group display + backend_display = f"[cyan]{group_name}[/cyan]: {offer.backend.replace('remote', 'ssh')} ({offer.region})" + + offers.add_row( + f"{i}", + backend_display, + r.pretty_format(include_spot=True), + instance, + f"${offer.price:.4f}".rstrip("0").rstrip("."), + availability, + style=None if i == 1 or not include_run_properties else "secondary", + ) + + if total_offers_count > len(selected_offers): + offers.add_row("", "...", style="secondary") + else: + # Single job - original logic + job_plan.offers = job_plan.offers[:max_offers] if max_offers else job_plan.offers + + for i, offer in enumerate(job_plan.offers, start=1): + r = offer.instance.resources + + availability = "" + if offer.availability in { + InstanceAvailability.NOT_AVAILABLE, + InstanceAvailability.NO_QUOTA, + InstanceAvailability.IDLE, + InstanceAvailability.BUSY, + }: + availability = offer.availability.value.replace("_", " ").lower() + instance = offer.instance.name + if offer.total_blocks > 1: + instance += f" ({offer.blocks}/{offer.total_blocks})" + offers.add_row( + f"{i}", + offer.backend.replace("remote", "ssh") + " (" + offer.region + ")", + r.pretty_format(include_spot=True), + instance, + f"${offer.price:.4f}".rstrip("0").rstrip("."), + availability, + style=None if i == 1 or not include_run_properties else "secondary", + ) + if job_plan.total_offers > len(job_plan.offers): + offers.add_row("", "...", style="secondary") console.print(props) console.print() - if len(job_plan.offers) > 0: + + # Check if we have offers to display + has_offers = False + if len(run_plan.job_plans) > 1: + has_offers = any(len(jp.offers) > 0 for jp in run_plan.job_plans) + else: + has_offers = len(job_plan.offers) > 0 + + if has_offers: console.print(offers) - if job_plan.total_offers > len(job_plan.offers): - console.print( - f"[secondary] Shown {len(job_plan.offers)} of {job_plan.total_offers} offers, " - f"${job_plan.max_price:3f}".rstrip("0").rstrip(".") - + "max[/]" - ) + # Show summary for multi-job plans + if len(run_plan.job_plans) > 1: + if total_offers_count > len(selected_offers): + max_price_overall = max( + (jp.max_price for jp in run_plan.job_plans if jp.max_price), default=None + ) + if max_price_overall: + console.print( + f"[secondary] Shown {len(selected_offers)} of {total_offers_count} offers, " + f"${max_price_overall:3f}".rstrip("0").rstrip(".") + + " max[/]" + ) + else: + if job_plan.total_offers > len(job_plan.offers): + console.print( + f"[secondary] Shown 
{len(job_plan.offers)} of {job_plan.total_offers} offers, " + f"${job_plan.max_price:3f}".rstrip("0").rstrip(".") + + " max[/]" + ) console.print() else: console.print(NO_OFFERS_WARNING) @@ -233,8 +407,14 @@ def get_runs_table( if verbose and latest_job_submission.inactivity_secs: inactive_for = format_duration_multiunit(latest_job_submission.inactivity_secs) status += f" (inactive for {inactive_for})" + + job_name_parts = [f" replica={job.job_spec.replica_num}"] + if job.job_spec.replica_group_name: + job_name_parts.append(f"[cyan]group={job.job_spec.replica_group_name}[/cyan]") + job_name_parts.append(f"job={job.job_spec.job_num}") + job_row: Dict[Union[str, int], Any] = { - "NAME": f" replica={job.job_spec.replica_num} job={job.job_spec.job_num}" + "NAME": " ".join(job_name_parts) + ( f" deployment={latest_job_submission.deployment_num}" if show_deployment_num diff --git a/src/dstack/_internal/core/compatibility/runs.py b/src/dstack/_internal/core/compatibility/runs.py index a7a9d40a3..3c4b9c5c9 100644 --- a/src/dstack/_internal/core/compatibility/runs.py +++ b/src/dstack/_internal/core/compatibility/runs.py @@ -151,6 +151,9 @@ def get_run_spec_excludes(run_spec: RunSpec) -> IncludeExcludeDictType: configuration_excludes["schedule"] = True if profile is not None and profile.schedule is None: profile_excludes.add("schedule") + # Exclude replica_groups for backward compatibility with older servers + if isinstance(configuration, ServiceConfiguration) and configuration.replica_groups is None: + configuration_excludes["replica_groups"] = True configuration_excludes["repos"] = True if configuration_excludes: diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 6fe8132de..f687aec30 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -685,6 +685,52 @@ class TaskConfiguration( type: Literal["task"] = "task" +class ReplicaGroupConfig(ProfileParamsConfig): + @staticmethod + def schema_extra(schema: Dict[str, Any]): + ProfileParamsConfig.schema_extra(schema) + add_extra_schema_types( + schema["properties"]["replicas"], + extra_types=[{"type": "integer"}, {"type": "string"}], + ) + + +class ReplicaGroup(ProfileParams, generate_dual_core_model(ReplicaGroupConfig)): + """ + A replica group defines a set of service replicas with specific resource requirements + and provisioning parameters. + """ + + name: Annotated[str, Field(description="Group name (must be unique within the service)")] + replicas: Annotated[ + Range[int], + Field( + description="Number of replicas. Can be a fixed number (e.g., `2`) or a range (`1..3`). 
" + "If it's a range, the group can be autoscaled" + ), + ] + resources: Annotated[ + ResourcesSpec, + Field(description="Resource requirements for replicas in this group"), + ] + + @validator("name") + def validate_name(cls, v): + if not v or not v.strip(): + raise ValueError("Group name cannot be empty") + return v + + @validator("replicas") + def convert_replicas(cls, v: Range[int]) -> Range[int]: + if v.max is None: + raise ValueError("The maximum number of replicas is required") + if v.min is None: + v.min = 0 + if v.min < 0: + raise ValueError("The minimum number of replicas must be greater than or equal to 0") + return v + + class ServiceConfigurationParamsConfig(CoreConfig): @staticmethod def schema_extra(schema: Dict[str, Any]): @@ -754,6 +800,13 @@ class ServiceConfigurationParams(CoreModel): list[ProbeConfig], Field(description="List of probes used to determine job health"), ] = [] + replica_groups: Annotated[ + Optional[List[ReplicaGroup]], + Field( + description="Define multiple replica groups with different configurations. " + "Cannot be used together with 'replicas'" + ), + ] = None @validator("port") def convert_port(cls, v) -> PortMapping: @@ -789,14 +842,46 @@ def validate_gateway( ) return v + @root_validator() + def validate_replica_groups_xor_replicas(cls, values): + replica_groups = values.get("replica_groups") + replicas = values.get("replicas") + + # Check if user specified both + has_groups = replica_groups is not None + has_replicas = replicas != Range[int](min=1, max=1) + + if has_groups and has_replicas: + raise ValueError("Cannot specify both 'replicas' and 'replica_groups'") + + if has_groups: + # Validate unique names + names = [g.name for g in replica_groups] + if len(names) != len(set(names)): + raise ValueError("Replica group names must be unique") + + # Validate at least one group + if not replica_groups: + raise ValueError("replica_groups cannot be empty") + + return values + @root_validator() def validate_scaling(cls, values): scaling = values.get("scaling") replicas = values.get("replicas") - if replicas and replicas.min != replicas.max and not scaling: + replica_groups = values.get("replica_groups") + + if replica_groups: + # Check if any group has a range + has_range = any(g.replicas.min != g.replicas.max for g in replica_groups) + if has_range and not scaling: + raise ValueError("When any replica group has a range, 'scaling' must be specified") + elif replicas and replicas.min != replicas.max and not scaling: raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.") - if replicas and replicas.min == replicas.max and scaling: + elif replicas and replicas.min == replicas.max and scaling: raise ValueError("To use `scaling`, `replicas` must be set to a range.") + return values @validator("rate_limits") diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 0a5b174d2..7efa83902 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -1,11 +1,14 @@ from datetime import datetime, timedelta from enum import Enum -from typing import Any, Dict, List, Literal, Optional +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional from urllib.parse import urlparse from pydantic import UUID4, Field, root_validator from typing_extensions import Annotated +if TYPE_CHECKING: + from dstack._internal.core.models.configurations import ReplicaGroup, ServiceConfiguration + from dstack._internal.core.models.backends.base import BackendType 
from dstack._internal.core.models.common import ( ApplyAction, @@ -247,6 +250,7 @@ class ProbeSpec(CoreModel): class JobSpec(CoreModel): replica_num: int = 0 # default value for backward compatibility + replica_group_name: Optional[str] = None job_num: int job_name: str jobs_per_replica: int = 1 # default value for backward compatibility @@ -618,3 +622,37 @@ def get_service_port(job_spec: JobSpec, configuration: ServiceConfiguration) -> if job_spec.service_port is None: return configuration.port.container_port return job_spec.service_port + + +def get_normalized_replica_groups(configuration: "ServiceConfiguration") -> List["ReplicaGroup"]: + """ + Normalize service configuration to replica groups. + Converts legacy replicas field to a single "default" group for backward compatibility. + """ + from dstack._internal.core.models.configurations import ReplicaGroup + + if configuration.replica_groups: + return configuration.replica_groups + + return [ + ReplicaGroup( + name="default", + replicas=configuration.replicas, + resources=configuration.resources, + backends=configuration.backends, + regions=configuration.regions, + availability_zones=configuration.availability_zones, + instance_types=configuration.instance_types, + reservation=configuration.reservation, + spot_policy=configuration.spot_policy, + retry=configuration.retry, + max_duration=configuration.max_duration, + stop_duration=configuration.stop_duration, + max_price=configuration.max_price, + creation_policy=configuration.creation_policy, + idle_duration=configuration.idle_duration, + utilization_policy=configuration.utilization_policy, + startup_order=configuration.startup_order, + stop_criteria=configuration.stop_criteria, + ) + ] diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py index df1cce72f..0be74df83 100644 --- a/src/dstack/_internal/server/background/tasks/process_runs.py +++ b/src/dstack/_internal/server/background/tasks/process_runs.py @@ -156,6 +156,10 @@ async def _process_run(session: AsyncSession, run_model: RunModel): ) run_model = res.unique().scalar_one() logger.debug("%s: processing run", fmt(run_model)) + + # Migrate legacy jobs without replica_group_name (one-time fix) + await _migrate_legacy_job_replica_groups(session, run_model) + try: if run_model.status == RunStatus.PENDING: await _process_pending_run(session, run_model) @@ -176,6 +180,70 @@ async def _process_run(session: AsyncSession, run_model: RunModel): await session.commit() +async def _migrate_legacy_job_replica_groups(session: AsyncSession, run_model: RunModel): + """ + Migrate jobs from old runs that don't have replica_group_name set. + This fixes jobs created before the replica_groups feature was added. 
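+
+    Groups are assigned by walking the normalized groups in order and mapping
+    one replica_num per unit of each group's minimum. For example, with groups
+    primary (min=1) and overflow (min=2), replica_num 0 maps to primary and
+    1-2 map to overflow; replicas beyond the summed minimums stay unassigned.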
+ """ + run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + + # Only migrate service runs with replica_groups + if run_spec.configuration.type != "service": + return + + # Check if run uses replica_groups + if not getattr(run_spec.configuration, "replica_groups", None): + return + + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = get_normalized_replica_groups(run_spec.configuration) + + # Check if any jobs need migration + needs_migration = any(job.replica_group_name is None for job in run_model.jobs) + + if not needs_migration: + return + + logger.info( + "%s: Migrating legacy jobs to assign replica_group_name", + fmt(run_model), + ) + + # Build a map of replica_num -> group_name based on how jobs were originally created + replica_num_to_group = {} + current_replica_num = 0 + + for group in normalized_groups: + group_min = group.replicas.min or 0 + for _ in range(group_min): + replica_num_to_group[current_replica_num] = group.name + current_replica_num += 1 + + # Update jobs + migrated_count = 0 + for job in run_model.jobs: + if job.replica_group_name is None: + expected_group = replica_num_to_group.get(job.replica_num) + if expected_group: + job.replica_group_name = expected_group + migrated_count += 1 + logger.info( + "%s: Migrated job replica_num=%d to group '%s'", + fmt(run_model), + job.replica_num, + expected_group, + ) + + if migrated_count > 0: + await session.commit() + logger.info( + "%s: Migrated %d job(s) to replica groups", + fmt(run_model), + migrated_count, + ) + + async def _process_pending_run(session: AsyncSession, run_model: RunModel): """Jobs are not created yet""" run = run_model_to_run(run_model) @@ -189,7 +257,10 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel): run_model.desired_replica_count = 1 if run.run_spec.configuration.type == "service": - run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0 + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = get_normalized_replica_groups(run.run_spec.configuration) + run_model.desired_replica_count = sum(g.replicas.min or 0 for g in normalized_groups) await update_service_desired_replica_count( session, run_model, @@ -481,6 +552,7 @@ async def _handle_run_replicas( session, run_model, replicas_diff=max_replica_count - non_terminated_replica_count, + allow_exceeding_max=True, # Allow exceeding max for rolling deployments ) replicas_to_stop_count = 0 @@ -507,6 +579,7 @@ async def _handle_run_replicas( session, run_model, replicas_diff=-replicas_to_stop_count, + allow_exceeding_max=True, # Allow terminating out-of-date replicas during rolling deployment ) @@ -516,20 +589,49 @@ async def _update_jobs_to_new_deployment_in_place( """ Bump deployment_num for jobs that do not require redeployment. 
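+
+    For services with replica groups, each job's new spec is built from the
+    group recorded in its replica_group_name; if that group no longer exists
+    in the new run_spec, the job falls back to the base configuration.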
""" + from dstack._internal.core.models.runs import get_normalized_replica_groups + secrets = await get_project_secrets_mapping( session=session, project=run_model.project, ) + + # Get replica groups from the new run_spec for matching + replica_group = None + if run_spec.configuration.type == "service": + normalized_groups = get_normalized_replica_groups(run_spec.configuration) + else: + normalized_groups = [] + for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs): if all(j.status.is_finished() for j in job_models): continue if all(j.deployment_num == run_model.deployment_num for j in job_models): continue + + # Determine which replica group this job belongs to + # Use the old job's replica_group_name to find the matching group in new spec + old_job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data) + if old_job_spec.replica_group_name and normalized_groups: + replica_group = next( + (g for g in normalized_groups if g.name == old_job_spec.replica_group_name), + None, + ) + if replica_group is None: + logger.warning( + "Replica group '%s' from old job not found in new run_spec. " + "Job will use base configuration.", + old_job_spec.replica_group_name, + ) + else: + replica_group = None + # FIXME: Handle getting image configuration errors or skip it. new_job_specs = await get_job_specs_from_run_spec( run_spec=run_spec, secrets=secrets, replica_num=replica_num, + replica_group=replica_group, ) assert len(new_job_specs) == len(job_models), ( "Changing the number of jobs within a replica is not yet supported" diff --git a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py index 2814840b5..ea0422808 100644 --- a/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py +++ b/src/dstack/_internal/server/background/tasks/process_submitted_jobs.py @@ -609,11 +609,15 @@ def _get_nodes_required_num_for_run(run_spec: RunSpec) -> int: nodes_required_num = 1 if run_spec.configuration.type == "task": nodes_required_num = run_spec.configuration.nodes - elif ( - run_spec.configuration.type == "service" - and run_spec.configuration.replicas.min is not None - ): - nodes_required_num = run_spec.configuration.replicas.min + elif run_spec.configuration.type == "service": + # Use groups if present + if run_spec.configuration.replica_groups: + from dstack._internal.core.models.runs import get_normalized_replica_groups + + groups = get_normalized_replica_groups(run_spec.configuration) + nodes_required_num = sum(g.replicas.min or 0 for g in groups) + elif run_spec.configuration.replicas.min is not None: + nodes_required_num = run_spec.configuration.replicas.min return nodes_required_num @@ -728,6 +732,62 @@ async def _assign_job_to_fleet_instance( return instance +def _get_profile_for_job(run_spec: RunSpec, job: Job) -> Profile: + """Get merged profile with group overrides for this job.""" + from dstack._internal.core.models.profiles import Profile, ProfileParams + from dstack._internal.core.models.runs import get_normalized_replica_groups + + base_profile = run_spec.merged_profile + + group_name = job.job_spec.replica_group_name + logger.info( + "Getting profile for job %s: replica_group_name=%s, config_type=%s", + job.job_spec.job_name, + group_name, + run_spec.configuration.type, + ) + + if not group_name or run_spec.configuration.type != "service": + logger.info("Using base profile (no group_name or not a service)") + return base_profile + + # Find the group + 
normalized_groups = get_normalized_replica_groups(run_spec.configuration) + logger.info( + "Normalized groups: %s", + [f"{g.name} (regions={g.regions})" for g in normalized_groups], + ) + group = next((g for g in normalized_groups if g.name == group_name), None) + + if not group: + logger.warning( + "Replica group '%s' not found in run_spec. Available groups: %s", + group_name, + [g.name for g in normalized_groups], + ) + return base_profile + + # Merge: group overrides base + merged = Profile.parse_obj(base_profile.dict()) + for field_name in ProfileParams.__fields__: + group_value = getattr(group, field_name, None) + if group_value is not None: + setattr(merged, field_name, group_value) + + logger.info( + "Profile for group '%s': regions=%s, backends=%s, spot_policy=%s (base had: regions=%s, backends=%s, spot_policy=%s)", + group_name, + merged.regions, + [b.value for b in merged.backends] if merged.backends else None, + merged.spot_policy, + base_profile.regions, + [b.value for b in base_profile.backends] if base_profile.backends else None, + base_profile.spot_policy, + ) + + return merged + + async def _run_job_on_new_instance( project: ProjectModel, job_model: JobModel, @@ -741,8 +801,31 @@ async def _run_job_on_new_instance( ) -> Optional[tuple[JobProvisioningData, InstanceOfferWithAvailability, Profile, Requirements]]: if volumes is None: volumes = [] - profile = run.run_spec.merged_profile - requirements = job.job_spec.requirements + profile = _get_profile_for_job(run.run_spec, job) + requirements = job.job_spec.requirements # Already has group resources baked in + + # Debug logging for replica groups + replica_group_name = job.job_spec.replica_group_name + if replica_group_name: + logger.debug( + "%s: Provisioning replica group '%s' with profile: regions=%s, backends=%s, spot_policy=%s", + fmt(job_model), + replica_group_name, + profile.regions, + [b.value for b in profile.backends] if profile.backends else None, + profile.spot_policy, + ) + gpu_req = ( + requirements.resources.gpu.name + if requirements.resources and requirements.resources.gpu + else None + ) + logger.debug( + "%s: GPU requirements for group '%s': %s", + fmt(job_model), + replica_group_name, + gpu_req, + ) fleet = None if fleet_model is not None: fleet = fleet_model_to_fleet(fleet_model) @@ -761,6 +844,21 @@ async def _run_job_on_new_instance( multinode = job.job_spec.jobs_per_replica > 1 or ( fleet is not None and fleet.spec.configuration.placement == InstanceGroupPlacement.CLUSTER ) + + # Log the requirements and profile being used + gpu_requirement = ( + requirements.resources.gpu.name + if requirements.resources and requirements.resources.gpu + else None + ) + logger.info( + "%s: Fetching offers with GPU=%s, regions=%s, backends=%s", + fmt(job_model), + gpu_requirement, + profile.regions, + [b.value for b in profile.backends] if profile.backends else None, + ) + offers = await get_offers_by_requirements( project=project, profile=profile, @@ -772,6 +870,14 @@ async def _run_job_on_new_instance( privileged=job.job_spec.privileged, instance_mounts=check_run_spec_requires_instance_mounts(run.run_spec), ) + + # Debug logging for offers + logger.info( + "%s: Got %d offers. First 3: %s", + fmt(job_model), + len(offers), + [f"{o.instance.name} ({o.backend.value}/{o.region})" for _, o in offers[:3]], + ) # Limit number of offers tried to prevent long-running processing # in case all offers fail. 
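+    # Note: for replica-group jobs, `offers` was fetched with the group's merged
+    # profile, so group-level backends/regions filters are already applied here.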
for backend, offer in offers[: settings.MAX_OFFERS_TRIED]: @@ -822,7 +928,9 @@ def _get_run_profile_and_requirements_in_fleet( run_spec: RunSpec, fleet: Fleet, ) -> tuple[Profile, Requirements]: - profile = combine_fleet_and_run_profiles(fleet.spec.merged_profile, run_spec.merged_profile) + # Use group-merged profile instead of run-level + job_profile = _get_profile_for_job(run_spec, job) + profile = combine_fleet_and_run_profiles(fleet.spec.merged_profile, job_profile) if profile is None: raise ValueError("Cannot combine fleet profile") fleet_requirements = get_fleet_requirements(fleet.spec) diff --git a/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py b/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py new file mode 100644 index 000000000..a1d9e3eaf --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/a1b2c3d4e5f6_add_jobmodel_replica_group_name.py @@ -0,0 +1,26 @@ +"""Add JobModel.replica_group_name + +Revision ID: a1b2c3d4e5f6 +Revises: ff1d94f65b08 +Create Date: 2025-10-17 00:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a1b2c3d4e5f6" +down_revision = "ff1d94f65b08" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.add_column(sa.Column("replica_group_name", sa.String(), nullable=True)) + + +def downgrade() -> None: + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.drop_column("replica_group_name") diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 31f44d369..d23a99e28 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -437,6 +437,7 @@ class JobModel(BaseModel): instance: Mapped[Optional["InstanceModel"]] = relationship(back_populates="jobs") used_instance_id: Mapped[Optional[uuid.UUID]] = mapped_column(UUIDType(binary=False)) replica_num: Mapped[int] = mapped_column(Integer) + replica_group_name: Mapped[Optional[str]] = mapped_column(String, nullable=True) deployment_num: Mapped[int] = mapped_column(Integer) job_runtime_data: Mapped[Optional[str]] = mapped_column(Text) probes: Mapped[list["ProbeModel"]] = relationship( diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index cbb089b2c..1f9f66580 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -1,9 +1,12 @@ import itertools import json from datetime import timedelta -from typing import Dict, Iterable, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from uuid import UUID +if TYPE_CHECKING: + from dstack._internal.core.models.configurations import ReplicaGroup + import requests from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -67,7 +70,10 @@ async def get_jobs_from_run_spec( - run_spec: RunSpec, secrets: Dict[str, str], replica_num: int + run_spec: RunSpec, + secrets: Dict[str, str], + replica_num: int, + replica_group: Optional["ReplicaGroup"] = None, ) -> List[Job]: return [ Job(job_spec=s, job_submissions=[]) @@ -75,14 +81,19 @@ async def get_jobs_from_run_spec( run_spec=run_spec, secrets=secrets, replica_num=replica_num, + replica_group=replica_group, ) ] async def get_job_specs_from_run_spec( - run_spec: RunSpec, 
secrets: Dict[str, str], replica_num: int + run_spec: RunSpec, + secrets: Dict[str, str], + replica_num: int, + replica_group: Optional["ReplicaGroup"] = None, ) -> List[JobSpec]: job_configurator = _get_job_configurator(run_spec=run_spec, secrets=secrets) + job_configurator.replica_group = replica_group job_specs = await job_configurator.get_job_specs(replica_num=replica_num) return job_specs diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py index 02cdc70b3..b79c2a15e 100644 --- a/src/dstack/_internal/server/services/jobs/configurators/base.py +++ b/src/dstack/_internal/server/services/jobs/configurators/base.py @@ -3,7 +3,11 @@ import threading from abc import ABC, abstractmethod from pathlib import PurePosixPath -from typing import Dict, List, Optional +from typing import TYPE_CHECKING, Dict, List, Optional + +if TYPE_CHECKING: + from dstack._internal.core.models.configurations import ReplicaGroup + from dstack._internal.core.models.profiles import Profile from cachetools import TTLCache, cached @@ -84,6 +88,7 @@ class JobConfigurator(ABC): _image_config: Optional[ImageConfig] = None # JobSSHKey should be shared for all jobs in a replica for inter-node communication. _job_ssh_key: Optional[JobSSHKey] = None + replica_group: Optional["ReplicaGroup"] = None def __init__( self, @@ -146,6 +151,7 @@ async def _get_job_spec( ) -> JobSpec: job_spec = JobSpec( replica_num=replica_num, # TODO(egor-s): add to env variables in the runner + replica_group_name=self.replica_group.name if self.replica_group else None, job_num=job_num, job_name=f"{self.run_spec.run_name}-{job_num}-{replica_num}", jobs_per_replica=jobs_per_replica, @@ -295,13 +301,40 @@ def _utilization_policy(self) -> Optional[UtilizationPolicy]: def _registry_auth(self) -> Optional[RegistryAuth]: return self.run_spec.configuration.registry_auth + def _get_merged_profile(self) -> "Profile": + """Get profile with group overrides applied.""" + from dstack._internal.core.models.profiles import Profile, ProfileParams + + base = self.run_spec.merged_profile + + if not self.replica_group: + return base + + # Clone and apply group overrides + merged = Profile.parse_obj(base.dict()) + for field_name in ProfileParams.__fields__: + group_value = getattr(self.replica_group, field_name, None) + if group_value is not None: + setattr(merged, field_name, group_value) + + return merged + def _requirements(self) -> Requirements: - spot_policy = self._spot_policy() + # Use group resources if available, else fall back to config + if self.replica_group: + resources = self.replica_group.resources + else: + resources = self.run_spec.configuration.resources + + # Get merged profile for spot/price/reservation + profile = self._get_merged_profile() + spot_policy = profile.spot_policy or SpotPolicy.ONDEMAND + return Requirements( - resources=self.run_spec.configuration.resources, - max_price=self.run_spec.merged_profile.max_price, + resources=resources, + max_price=profile.max_price, spot=None if spot_policy == SpotPolicy.AUTO else (spot_policy == SpotPolicy.SPOT), - reservation=self.run_spec.merged_profile.reservation, + reservation=profile.reservation, ) def _retry(self) -> Optional[Retry]: diff --git a/src/dstack/_internal/server/services/runs.py b/src/dstack/_internal/server/services/runs.py index 25ac750aa..ffc7cbc14 100644 --- a/src/dstack/_internal/server/services/runs.py +++ b/src/dstack/_internal/server/services/runs.py @@ -30,6 +30,8 @@ ) from 
dstack._internal.core.models.profiles import ( CreationPolicy, + Profile, + ProfileParams, RetryEvent, ) from dstack._internal.core.models.repos.virtual import DEFAULT_VIRTUAL_REPO_ID, VirtualRunRepoData @@ -302,6 +304,45 @@ async def get_run_by_id( return run_model_to_run(run_model, return_in_api=True) +def _get_job_profile(run_spec: RunSpec, replica_group_name: Optional[str]) -> Profile: + """ + Get the profile for a job, including replica group overrides if applicable. + + Args: + run_spec: The run specification + replica_group_name: Name of the replica group, or None for legacy jobs + + Returns: + Profile with replica group overrides applied + """ + base_profile = run_spec.merged_profile + + # If no replica group, return base profile + if not replica_group_name: + return base_profile + + # Find the replica group + if run_spec.configuration.type != "service": + return base_profile + + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = get_normalized_replica_groups(run_spec.configuration) + replica_group = next((g for g in normalized_groups if g.name == replica_group_name), None) + + if not replica_group: + return base_profile + + # Clone base profile and apply group overrides + merged = Profile.parse_obj(base_profile.dict()) + for field_name in ProfileParams.__fields__: + group_value = getattr(replica_group, field_name, None) + if group_value is not None: + setattr(merged, field_name, group_value) + + return merged + + async def get_plan( session: AsyncSession, project: ProjectModel, @@ -340,11 +381,34 @@ async def get_plan( action = ApplyAction.UPDATE secrets = await get_project_secrets_mapping(session=session, project=project) - jobs = await get_jobs_from_run_spec( - run_spec=effective_run_spec, - secrets=secrets, - replica_num=0, - ) + + # For services with replica groups, create jobs for all groups during planning + jobs = [] + if ( + effective_run_spec.configuration.type == "service" + and effective_run_spec.configuration.replica_groups + ): + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = get_normalized_replica_groups(effective_run_spec.configuration) + replica_num = 0 + for group in normalized_groups: + # Create one job per group for planning (minimum replicas) + group_jobs = await get_jobs_from_run_spec( + run_spec=effective_run_spec, + secrets=secrets, + replica_num=replica_num, + replica_group=group, + ) + jobs.extend(group_jobs) + replica_num += 1 + else: + # Legacy: single job for planning + jobs = await get_jobs_from_run_spec( + run_spec=effective_run_spec, + secrets=secrets, + replica_num=0, + ) volumes = await get_job_configured_volumes( session=session, @@ -353,19 +417,47 @@ async def get_plan( job_num=0, ) - pool_offers = await _get_pool_offers( - session=session, - project=project, - run_spec=effective_run_spec, - job=jobs[0], - volumes=volumes, - ) + # For replica groups, we need pool offers for all GPU types, not just the first job's type + # So we fetch pool offers for each job separately and aggregate them + all_pool_offers = [] + if len(jobs) > 1: + # Multiple jobs (likely replica groups) - get pool offers per job to include all GPU types + for job in jobs: + job_pool_offers = await _get_pool_offers( + session=session, + project=project, + run_spec=effective_run_spec, + job=job, + volumes=volumes, + ) + all_pool_offers.extend(job_pool_offers) + # Deduplicate by (backend, instance_name, region) tuple + seen_offers = set() + pool_offers = [] + for offer in 
all_pool_offers: + offer_key = (offer.backend, offer.instance.name, offer.region) + if offer_key not in seen_offers: + seen_offers.add(offer_key) + pool_offers.append(offer) + else: + pool_offers = await _get_pool_offers( + session=session, + project=project, + run_spec=effective_run_spec, + job=jobs[0], + volumes=volumes, + ) effective_run_spec.run_name = "dry-run" # will regenerate jobs on submission - # Get offers once for all jobs - offers = [] - if creation_policy == CreationPolicy.REUSE_OR_CREATE: - offers = await get_offers_by_requirements( + # Check if all jobs have identical requirements (optimization for single-type jobs) + all_requirements_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + + # Get offers once if all jobs are identical, otherwise get per-job + shared_offers = [] + if creation_policy == CreationPolicy.REUSE_OR_CREATE and all_requirements_identical: + shared_offers = await get_offers_by_requirements( project=project, profile=profile, requirements=jobs[0].job_spec.requirements, @@ -379,8 +471,46 @@ async def get_plan( job_plans = [] for job in jobs: job_offers: List[InstanceOfferWithAvailability] = [] - job_offers.extend(pool_offers) - job_offers.extend(offer for _, offer in offers) + + # Filter pool offers to match this job's GPU requirements + gpu_req = None + if job.job_spec.requirements.resources and job.job_spec.requirements.resources.gpu: + gpu_req = job.job_spec.requirements.resources.gpu.name + + matching_pool_offers = [] + for pool_offer in pool_offers: + offer_gpus = pool_offer.instance.resources.gpus + if offer_gpus and gpu_req: + # Check if offer's GPU matches job's requirement + offer_gpu_names = [gpu.name for gpu in offer_gpus] + if any(req_gpu in offer_gpu_names for req_gpu in gpu_req): + matching_pool_offers.append(pool_offer) + elif not gpu_req: + # No GPU requirement, include all pool offers + matching_pool_offers.append(pool_offer) + + job_offers.extend(matching_pool_offers) + + # Get the correct profile for this job (with replica group overrides if applicable) + job_profile = _get_job_profile(effective_run_spec, job.job_spec.replica_group_name) + + # Use shared offers if all jobs are identical, otherwise fetch per-job + if shared_offers: + job_offers.extend(offer for _, offer in shared_offers) + elif creation_policy == CreationPolicy.REUSE_OR_CREATE: + # Fetch offers specific to this job's requirements with job-specific profile + job_specific_offers = await get_offers_by_requirements( + project=project, + profile=job_profile, + requirements=job.job_spec.requirements, + exclude_not_available=False, + multinode=job.job_spec.jobs_per_replica > 1, + volumes=volumes, + privileged=job.job_spec.privileged, + instance_mounts=check_run_spec_requires_instance_mounts(effective_run_spec), + ) + job_offers.extend(offer for _, offer in job_specific_offers) + job_offers.sort(key=lambda offer: not offer.availability.is_available()) job_spec = job.job_spec @@ -557,19 +687,47 @@ async def submit_run( if run_spec.configuration.type == "service": await services.register_service(session, run_model, run_spec) - for replica_num in range(initial_replicas): - jobs = await get_jobs_from_run_spec( - run_spec=run_spec, - secrets=secrets, - replica_num=replica_num, - ) - for job in jobs: - job_model = create_job_model_for_new_submission( - run_model=run_model, - job=job, - status=JobStatus.SUBMITTED, + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = 
get_normalized_replica_groups(run_spec.configuration) + + # Set initial desired count (sum of all group minimums) + run_model.desired_replica_count = sum(g.replicas.min or 0 for g in normalized_groups) + + # Create jobs by iterating over groups + replica_num = 0 # Global counter across all groups + for group in normalized_groups: + group_min = group.replicas.min or 0 + for _ in range(group_min): + jobs = await get_jobs_from_run_spec( + run_spec=run_spec, + secrets=secrets, + replica_num=replica_num, + replica_group=group, # Pass group context + ) + for job in jobs: + job_model = create_job_model_for_new_submission( + run_model=run_model, + job=job, + status=JobStatus.SUBMITTED, + ) + session.add(job_model) + replica_num += 1 + else: + # Non-service runs (tasks, dev environments) + for replica_num in range(initial_replicas): + jobs = await get_jobs_from_run_spec( + run_spec=run_spec, + secrets=secrets, + replica_num=replica_num, ) - session.add(job_model) + for job in jobs: + job_model = create_job_model_for_new_submission( + run_model=run_model, + job=job, + status=JobStatus.SUBMITTED, + ) + session.add(job_model) await session.commit() await session.refresh(run_model) @@ -591,6 +749,7 @@ def create_job_model_for_new_submission( job_num=job.job_spec.job_num, job_name=f"{job.job_spec.job_name}", replica_num=job.job_spec.replica_num, + replica_group_name=job.job_spec.replica_group_name, deployment_num=run_model.deployment_num, submission_num=len(job.job_submissions), submitted_at=now, @@ -1017,10 +1176,20 @@ def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec): f"Maximum utilization_policy.time_window is {settings.SERVER_METRICS_RUNNING_TTL_SECONDS}s" ) if isinstance(run_spec.configuration, ServiceConfiguration): - if run_spec.merged_profile.schedule and run_spec.configuration.replicas.min == 0: - raise ServerClientError( - "Scheduled services with autoscaling to zero are not supported" - ) + # Check all groups for min=0 with schedule + if run_spec.merged_profile.schedule: + if run_spec.configuration.replica_groups: + from dstack._internal.core.models.runs import get_normalized_replica_groups + + groups = get_normalized_replica_groups(run_spec.configuration) + if any(g.replicas.min == 0 for g in groups): + raise ServerClientError( + "Scheduled services with autoscaling to zero are not supported" + ) + elif run_spec.configuration.replicas.min == 0: + raise ServerClientError( + "Scheduled services with autoscaling to zero are not supported" + ) if len(run_spec.configuration.probes) > settings.MAX_PROBES_PER_JOB: raise ServerClientError( f"Cannot configure more than {settings.MAX_PROBES_PER_JOB} probes" @@ -1058,6 +1227,7 @@ def _validate_run_spec_and_set_defaults(user: UserModel, run_spec: RunSpec): "service": [ # in-place "replicas", + "replica_groups", # Named replica groups (mutually exclusive with replicas) "scaling", # rolling deployment # NOTE: keep this list in sync with the "Rolling deployment" section in services.md @@ -1186,7 +1356,21 @@ async def process_terminating_run(session: AsyncSession, run_model: RunModel): ) -async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replicas_diff: int): +async def scale_run_replicas( + session: AsyncSession, + run_model: RunModel, + replicas_diff: int, + allow_exceeding_max: bool = False, +): + """ + Scale run replicas up or down. 
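+
+    For services with replica groups, scaling is group-aware: scale-down only
+    terminates replicas of autoscalable groups that are above their group
+    minimum, and scale-up fills groups below their minimum first before
+    growing autoscalable groups toward their maximum.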
+ + Args: + session: Database session + run_model: The run to scale + replicas_diff: Number of replicas to add (positive) or remove (negative) + allow_exceeding_max: If True, allow scaling beyond configured max (for rolling deployments) + """ if replicas_diff == 0: # nothing to do return @@ -1226,38 +1410,144 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica active_replicas.sort(key=lambda r: (r[1], -r[0], r[2])) run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = ( + get_normalized_replica_groups(run_spec.configuration) + if run_spec.configuration.type == "service" + else [] + ) + if replicas_diff < 0: - for _, _, _, replica_jobs in reversed(active_replicas[-abs(replicas_diff) :]): - # scale down the less important replicas first + # SCALE DOWN: Only terminate from autoscalable groups while respecting group minimums + autoscalable_groups = { + g.name for g in normalized_groups if g.replicas.min != g.replicas.max + } + + # Count replicas per group + group_counts = {} + for _, _, _, replica_jobs in active_replicas: + if replica_jobs: + group_name = replica_jobs[0].replica_group_name or "default" + group_counts[group_name] = group_counts.get(group_name, 0) + 1 + + # Get group minimums + group_mins = {g.name: g.replicas.min for g in normalized_groups} + + # Terminate from end (reversed) + # For rolling deployments (allow_exceeding_max), prioritize terminating out-of-date replicas + terminated_count = 0 + for _, is_out_of_date, _, replica_jobs in reversed(active_replicas): + if terminated_count >= abs(replicas_diff): + break + + if not replica_jobs: + continue + + group_name = replica_jobs[0].replica_group_name or "default" + + # For rolling deployment, allow terminating any out-of-date replica + if allow_exceeding_max and is_out_of_date: + # Terminate this replica (out-of-date during rolling deployment) + for job in replica_jobs: + if not job.status.is_finished() and job.status != JobStatus.TERMINATING: + job.status = JobStatus.TERMINATING + job.termination_reason = JobTerminationReason.SCALED_DOWN + + group_counts[group_name] -= 1 + terminated_count += 1 + continue + + # For normal scaling, skip if not autoscalable + if normalized_groups and group_name not in autoscalable_groups: + continue + + # Skip if at minimum + current_count = group_counts.get(group_name, 0) + min_count = group_mins.get(group_name, 0) + if current_count <= min_count: + continue + + # Terminate this replica for job in replica_jobs: - if job.status.is_finished() or job.status == JobStatus.TERMINATING: - continue - job.status = JobStatus.TERMINATING - job.termination_reason = JobTerminationReason.SCALED_DOWN - # background task will process the job later + if not job.status.is_finished() and job.status != JobStatus.TERMINATING: + job.status = JobStatus.TERMINATING + job.termination_reason = JobTerminationReason.SCALED_DOWN + + group_counts[group_name] -= 1 + terminated_count += 1 else: + # SCALE UP + # Count current replicas per group + group_counts = {} + for _, _, _, replica_jobs in active_replicas: + if replica_jobs: + group_name = replica_jobs[0].replica_group_name or "default" + group_counts[group_name] = group_counts.get(group_name, 0) + 1 + + # First, identify groups below minimum (need to scale regardless of autoscalability) + below_min_groups = [ + g for g in normalized_groups if group_counts.get(g.name, 0) < (g.replicas.min or 0) + ] + + # Then, identify autoscalable 
groups that can scale beyond minimum + autoscalable_groups = [ + g + for g in normalized_groups + if g.replicas.min != g.replicas.max + and ( + allow_exceeding_max + or group_counts.get(g.name, 0) < (g.replicas.max or float("inf")) + ) + ] + + # Eligible groups are: below-min groups + autoscalable groups + eligible_groups = [] + if below_min_groups: + eligible_groups.extend(below_min_groups) + elif autoscalable_groups: + # Only use autoscalable groups if no groups are below minimum + eligible_groups.extend(autoscalable_groups) + elif allow_exceeding_max and normalized_groups: + # For rolling deployments, allow exceeding max even for fixed groups + eligible_groups.extend(normalized_groups) + + if normalized_groups and not eligible_groups: + # All groups at their limits + logger.info("%s: all replica groups at their limits (min/max)", fmt(run_model)) + return + scheduled_replicas = 0 - # rerun inactive replicas + # Reuse inactive replicas first (existing logic) for _, _, _, replica_jobs in inactive_replicas: if scheduled_replicas == replicas_diff: break - await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False) - scheduled_replicas += 1 + # Only reuse if from eligible group + if replica_jobs: + group_name = replica_jobs[0].replica_group_name or "default" + if not normalized_groups or group_name in {g.name for g in eligible_groups}: + await retry_run_replica_jobs( + session, run_model, replica_jobs, only_failed=False + ) + scheduled_replicas += 1 - secrets = await get_project_secrets_mapping( - session=session, - project=run_model.project, - ) + # Create new replicas for remaining diff + secrets = await get_project_secrets_mapping(session=session, project=run_model.project) + + for _ in range(replicas_diff - scheduled_replicas): + # Pick group for new replica + # v1: Simple heuristic - pick first eligible group (round-robin in future) + selected_group = eligible_groups[0] if eligible_groups else None + + replica_num = len(active_replicas) + scheduled_replicas - for replica_num in range( - len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff - ): # FIXME: Handle getting image configuration errors or skip it. 
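+            # Pass the selected group down so the new replica's jobs inherit
+            # its resources and profile overrides.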
jobs = await get_jobs_from_run_spec( run_spec=run_spec, secrets=secrets, replica_num=replica_num, + replica_group=selected_group, ) for job in jobs: job_model = create_job_model_for_new_submission( @@ -1267,6 +1557,24 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica ) session.add(job_model) + # Update count + if selected_group: + group_counts[selected_group.name] = group_counts.get(selected_group.name, 0) + 1 + scheduled_replicas += 1 + + # Remove from eligible if at max + if group_counts[selected_group.name] >= ( + selected_group.replicas.max or float("inf") + ): + eligible_groups = [g for g in eligible_groups if g.name != selected_group.name] + if not eligible_groups: + logger.info( + "%s: all eligible groups reached maximum capacity", fmt(run_model) + ) + break + else: + scheduled_replicas += 1 + async def retry_run_replica_jobs( session: AsyncSession, run_model: RunModel, latest_jobs: List[JobModel], *, only_failed: bool @@ -1276,10 +1584,29 @@ async def retry_run_replica_jobs( session=session, project=run_model.project, ) + + # Determine which replica group this job belongs to + run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + replica_group = None + if run_spec.configuration.type == "service" and latest_jobs: + from dstack._internal.core.models.runs import get_normalized_replica_groups + + group_name = latest_jobs[0].replica_group_name + if group_name: + normalized_groups = get_normalized_replica_groups(run_spec.configuration) + replica_group = next((g for g in normalized_groups if g.name == group_name), None) + if replica_group: + logger.info( + "%s: retrying job from replica group '%s'", + fmt(run_model), + replica_group.name, + ) + new_jobs = await get_jobs_from_run_spec( - run_spec=RunSpec.__response__.parse_raw(run_model.run_spec), + run_spec=run_spec, secrets=secrets, replica_num=latest_jobs[0].replica_num, + replica_group=replica_group, ) assert len(new_jobs) == len(latest_jobs), ( "Changing the number of jobs within a replica is not yet supported" diff --git a/src/dstack/_internal/server/services/services/autoscalers.py b/src/dstack/_internal/server/services/services/autoscalers.py index cd6d06e58..6573d1f7f 100644 --- a/src/dstack/_internal/server/services/services/autoscalers.py +++ b/src/dstack/_internal/server/services/services/autoscalers.py @@ -120,18 +120,29 @@ def get_desired_count( def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler: - assert conf.replicas.min is not None - assert conf.replicas.max is not None + # Compute bounds from groups if present + if conf.replica_groups: + from dstack._internal.core.models.runs import get_normalized_replica_groups + + groups = get_normalized_replica_groups(conf) + min_replicas = sum(g.replicas.min or 0 for g in groups) + max_replicas = sum(g.replicas.max or 0 for g in groups) + else: + assert conf.replicas.min is not None + assert conf.replicas.max is not None + min_replicas = conf.replicas.min + max_replicas = conf.replicas.max + if conf.scaling is None: return ManualScaler( - min_replicas=conf.replicas.min, - max_replicas=conf.replicas.max, + min_replicas=min_replicas, + max_replicas=max_replicas, ) if conf.scaling.metric == "rps": return RPSAutoscaler( # replicas count validated by configuration model - min_replicas=conf.replicas.min, - max_replicas=conf.replicas.max, + min_replicas=min_replicas, + max_replicas=max_replicas, target=conf.scaling.target, scale_up_delay=conf.scaling.scale_up_delay, scale_down_delay=conf.scaling.scale_down_delay, diff --git 
a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 0a8adaa42..6b1d42a10 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -345,6 +345,7 @@ async def create_job( instance: Optional[InstanceModel] = None, job_num: int = 0, replica_num: int = 0, + replica_group_name: Optional[str] = None, deployment_num: Optional[int] = None, instance_assigned: bool = False, disconnected_at: Optional[datetime] = None, @@ -353,8 +354,19 @@ async def create_job( if deployment_num is None: deployment_num = run.deployment_num run_spec = RunSpec.parse_raw(run.run_spec) + + # Look up replica group if specified + replica_group = None + if replica_group_name and run_spec.configuration.type == "service": + from dstack._internal.core.models.runs import get_normalized_replica_groups + + normalized_groups = get_normalized_replica_groups(run_spec.configuration) + replica_group = next((g for g in normalized_groups if g.name == replica_group_name), None) + job_spec = ( - await get_job_specs_from_run_spec(run_spec=run_spec, secrets={}, replica_num=replica_num) + await get_job_specs_from_run_spec( + run_spec=run_spec, secrets={}, replica_num=replica_num, replica_group=replica_group + ) )[0] job_spec.job_num = job_num job = JobModel( @@ -365,6 +377,7 @@ async def create_job( job_num=job_num, job_name=run.run_name + f"-{job_num}-{replica_num}", replica_num=replica_num, + replica_group_name=replica_group_name, deployment_num=deployment_num, submission_num=submission_num, submitted_at=submitted_at, diff --git a/src/tests/_internal/cli/utils/test_run_plan_display.py b/src/tests/_internal/cli/utils/test_run_plan_display.py new file mode 100644 index 000000000..8ec6bb4fb --- /dev/null +++ b/src/tests/_internal/cli/utils/test_run_plan_display.py @@ -0,0 +1,756 @@ +"""Test CLI display of run plans with replica groups.""" + +from dstack._internal.cli.utils.run import print_run_plan +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.instances import ( + Gpu, + InstanceAvailability, + InstanceType, + Resources, +) +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.repos import LocalRunRepoData +from dstack._internal.core.models.resources import Range, ResourcesSpec +from dstack._internal.core.models.runs import ( + ApplyAction, + InstanceOfferWithAvailability, + JobPlan, + JobSpec, + Requirements, + RunPlan, + RunSpec, +) + + +def create_test_offer( + backend: BackendType, + gpu_name: str, + price: float, + region: str = "us-east", + availability: InstanceAvailability = InstanceAvailability.AVAILABLE, +) -> InstanceOfferWithAvailability: + """Helper to create test offers.""" + return InstanceOfferWithAvailability( + backend=backend, + instance=InstanceType( + name=f"{gpu_name.lower()}-instance", + resources=Resources( + cpus=8, + memory_mib=16384, + gpus=[Gpu(name=gpu_name, memory_mib=40960)], + spot=False, + ), + ), + region=region, + price=price, + availability=availability, + ) + + +class TestReplicaGroupsDisplayInCLI: + """Test that replica groups are properly displayed in CLI output.""" + + def test_multiple_replica_groups_show_group_names(self, capsys): + """CLI should prefix offers with group names when multiple job plans exist.""" + # Create a service with 2 replica groups + config = ServiceConfiguration( + type="service", + port=8000, + 
commands=["echo test"], + replica_groups=[ + { + "name": "l40s-group", + "replicas": "1", + "resources": {"gpu": {"name": "L40S", "count": 1}}, + }, + { + "name": "a100-group", + "replicas": "1", + "resources": {"gpu": {"name": "A100", "count": 1}}, + }, + ], + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.VASTAI]), + ) + + # Create job plans for each group + l40s_offer = create_test_offer(BackendType.VASTAI, "L40S", 0.50) + a100_offer = create_test_offer(BackendType.VASTAI, "A100", 1.20) + + job_plan_l40s = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="l40s-group", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "L40S", "count": 1}) + ), + ), + offers=[l40s_offer], + total_offers=1, + max_price=0.50, + ) + + job_plan_a100 = JobPlan( + job_spec=JobSpec( + replica_num=1, + replica_group_name="a100-group", + job_num=1, + job_name="test-job-1", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "A100", "count": 1}) + ), + ), + offers=[a100_offer], + total_offers=1, + max_price=1.20, + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan_l40s, job_plan_a100], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print the plan + print_run_plan(run_plan, max_offers=10, include_run_properties=True) + + # Capture output + captured = capsys.readouterr() + output = captured.out + + # Verify group names are in the output + assert "l40s-group" in output, "l40s-group name should appear in output" + assert "a100-group" in output, "a100-group name should appear in output" + + # Verify both GPU types are shown + assert "L40S" in output or "l40s" in output.lower() + assert "A100" in output or "a100" in output.lower() + + def test_single_job_plan_no_group_prefix(self, capsys): + """CLI should NOT prefix offers when only one job plan exists (legacy).""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replicas=Range[int](min=1, max=1), + resources=ResourcesSpec(gpu={"name": "V100", "count": 1}), + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.AWS]), + ) + + v100_offer = create_test_offer(BackendType.AWS, "V100", 0.80) + + job_plan = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="default", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "V100", "count": 1}) + ), + ), + offers=[v100_offer], + total_offers=1, + max_price=0.80, + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print the plan + print_run_plan(run_plan, max_offers=10, include_run_properties=True) + + # Capture output + captured = 
capsys.readouterr() + output = captured.out + + # Verify NO group prefix (legacy mode) + assert "default:" not in output, "Legacy mode should not show group prefix" + # But should show backend normally + assert "aws" in output.lower() + + def test_replica_groups_offers_sorted_by_price(self, capsys): + """Offers from multiple groups should be sorted by price across all groups.""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replica_groups=[ + { + "name": "expensive-group", + "replicas": "1", + "resources": {"gpu": {"name": "H100", "count": 1}}, + }, + { + "name": "cheap-group", + "replicas": "1", + "resources": {"gpu": {"name": "T4", "count": 1}}, + }, + ], + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.AWS]), + ) + + # Expensive offer + h100_offer = create_test_offer(BackendType.AWS, "H100", 3.00) + # Cheap offer + t4_offer = create_test_offer(BackendType.AWS, "T4", 0.30) + + job_plan_expensive = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="expensive-group", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "H100", "count": 1}) + ), + ), + offers=[h100_offer], + total_offers=1, + max_price=3.00, + ) + + job_plan_cheap = JobPlan( + job_spec=JobSpec( + replica_num=1, + replica_group_name="cheap-group", + job_num=1, + job_name="test-job-1", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements(resources=ResourcesSpec(gpu={"name": "T4", "count": 1})), + ), + offers=[t4_offer], + total_offers=1, + max_price=0.30, + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan_expensive, job_plan_cheap], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print the plan + print_run_plan(run_plan, max_offers=10, include_run_properties=True) + + # Capture output + captured = capsys.readouterr() + output = captured.out + + # Split output to find the offers table (after the header section) + lines = output.split("\n") + + # Find lines that contain both a number and a group name (these are offer rows) + offer_rows = [ + line + for line in lines + if ("cheap-group:" in line or "expensive-group:" in line) + and line.strip().startswith(("1", "2", "3")) + ] + + # The first offer row should be cheap-group (lower price) + assert len(offer_rows) >= 2, "Should have at least 2 offer rows" + assert "cheap-group:" in offer_rows[0], ( + "First offer should be cheap-group (sorted by price)" + ) + assert "expensive-group:" in offer_rows[1], "Second offer should be expensive-group" + assert "$0.3" in output # Price displayed as $0.3 + assert "$3" in output # Price displayed as $3 + + def test_replica_group_with_no_offers_shows_message(self, capsys): + """Replica groups with no available offers should show a message.""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replica_groups=[ + { + "name": "available-group", + "replicas": "1", + "resources": {"gpu": {"name": "L40S", "count": 1}}, + }, + { + "name": "unavailable-group", + "replicas": "1", + "resources": {"gpu": {"name": "A100", "count": 1}}, + }, + ], + ) + + 
run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.VASTAI]), + ) + + # One group has offers, another doesn't + l40s_offer = create_test_offer(BackendType.VASTAI, "L40S", 0.50) + + job_plan_with_offers = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="available-group", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "L40S", "count": 1}) + ), + ), + offers=[l40s_offer], + total_offers=1, + max_price=0.50, + ) + + job_plan_no_offers = JobPlan( + job_spec=JobSpec( + replica_num=1, + replica_group_name="unavailable-group", + job_num=1, + job_name="test-job-1", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "A100", "count": 1}) + ), + ), + offers=[], # No offers + total_offers=0, + max_price=0.0, + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan_with_offers, job_plan_no_offers], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print the plan + print_run_plan(run_plan, max_offers=10, include_run_properties=True) + + # Capture output + captured = capsys.readouterr() + output = captured.out + + # Verify available group shows offer + assert "available-group:" in output + assert "L40S" in output + + # Verify unavailable group shows the standard "no offers" message + # (Message may be wrapped across lines in table display) + assert "unavailable-group:" in output + assert "No matching instance" in output + assert "offers available" in output + assert "Possible reasons:" in output + assert "dstack.ai/docs" in output # URL may be truncated in table + + # Verify unavailable group appears BEFORE available group (at top) + unavailable_pos = output.find("unavailable-group:") + available_pos = output.find("available-group:") + assert unavailable_pos < available_pos, "Group with no offers should appear first" + + +class TestReplicaGroupsFairOfferDistribution: + """Test that CLI displays offers from all replica groups fairly.""" + + def test_all_groups_represented_in_display(self, capsys): + """Test that offers from all replica groups are shown when max_offers is set.""" + # Create offers for three groups with different price ranges + h100_offers = [ + create_test_offer(BackendType.AWS, "H100", 3.0, region="us-east"), + create_test_offer(BackendType.AWS, "H100", 3.5, region="us-west"), + create_test_offer(BackendType.GCP, "H100", 4.0, region="eu-west"), + ] + + rtx5090_offers = [ + create_test_offer(BackendType.VASTAI, "RTX5090", 0.5, region="us"), + create_test_offer(BackendType.VASTAI, "RTX5090", 0.6, region="eu"), + ] + + a100_offers = [ + create_test_offer(BackendType.AWS, "A100", 2.0, region="us-east"), + create_test_offer(BackendType.GCP, "A100", 2.2, region="eu-west"), + ] + + # Create job plans for each group + job_plan_h100 = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="h100-group", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "H100", "count": 1}) + ), + ), + 
offers=h100_offers, + total_offers=len(h100_offers), + max_price=4.0, + ) + + job_plan_rtx5090 = JobPlan( + job_spec=JobSpec( + replica_num=1, + replica_group_name="rtx5090-group", + job_num=1, + job_name="test-job-1", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "RTX5090", "count": 1}) + ), + ), + offers=rtx5090_offers, + total_offers=len(rtx5090_offers), + max_price=0.6, + ) + + job_plan_a100 = JobPlan( + job_spec=JobSpec( + replica_num=2, + replica_group_name="a100-group", + job_num=2, + job_name="test-job-2", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "A100", "count": 1}) + ), + ), + offers=a100_offers, + total_offers=len(a100_offers), + max_price=2.2, + ) + + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.AWS]), + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan_h100, job_plan_rtx5090, job_plan_a100], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print with max_offers=5 (should show at least 1 from each group) + print_run_plan(run_plan, max_offers=5, include_run_properties=True) + + captured = capsys.readouterr() + output = captured.out + + # Verify all three groups appear in the output + assert "h100-group:" in output, "H100 group should be displayed" + assert "rtx5090-group:" in output, "RTX5090 group should be displayed" + assert "a100-group:" in output, "A100 group should be displayed" + + # Verify GPUs are shown + assert "H100" in output + assert "RTX5090" in output + assert "A100" in output + + def test_fair_distribution_with_limited_slots(self, capsys): + """Test that when max_offers is limited, all groups get fair representation.""" + # Group 1: Many cheap offers + cheap_offers = [ + create_test_offer(BackendType.VASTAI, "RTX5090", 0.4 + i * 0.1, region="us") + for i in range(10) + ] + + # Group 2: Few expensive offers + expensive_offers = [ + create_test_offer(BackendType.AWS, "H100", 3.0 + i * 0.5, region="us-east") + for i in range(3) + ] + + job_plan_cheap = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="cheap-group", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "RTX5090", "count": 1}) + ), + ), + offers=cheap_offers, + total_offers=len(cheap_offers), + max_price=1.4, + ) + + job_plan_expensive = JobPlan( + job_spec=JobSpec( + replica_num=1, + replica_group_name="expensive-group", + job_num=1, + job_name="test-job-1", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "H100", "count": 1}) + ), + ), + offers=expensive_offers, + total_offers=len(expensive_offers), + max_price=4.0, + ) + + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + 
repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.VASTAI]), + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan_cheap, job_plan_expensive], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print with max_offers=4 (should show at least 1 from each group) + print_run_plan(run_plan, max_offers=4, include_run_properties=True) + + captured = capsys.readouterr() + output = captured.out + + # Both groups should be represented + assert "cheap-group:" in output + assert "expensive-group:" in output + + # Count occurrences (rough check - both should appear) + cheap_count = output.count("cheap-group:") + expensive_count = output.count("expensive-group:") + + # Both should have at least one offer shown + assert cheap_count >= 1, "Cheap group should have at least one offer" + assert expensive_count >= 1, "Expensive group should have at least one offer" + + +class TestReplicaGroupsProfileOverridesDisplay: + """Test that CLI correctly displays profile overrides for replica groups.""" + + def test_shows_group_specific_spot_policy_and_regions(self, capsys): + """Test that group-specific spot_policy, regions, backends are displayed.""" + + from dstack._internal.core.models.backends.base import BackendType + + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replica_groups=[ + { + "name": "h100-group", + "replicas": "1", + "resources": {"gpu": {"name": "H100", "count": 1}}, + "spot_policy": "spot", + "regions": ["us-east-1", "us-west-2"], + "backends": ["aws"], + }, + { + "name": "rtx5090-group", + "replicas": "0..5", + "resources": {"gpu": {"name": "RTX5090", "count": 1}}, + "spot_policy": "on-demand", + "regions": ["jp-japan"], + "backends": ["vastai", "runpod"], + }, + ], + scaling={"metric": "rps", "target": 10}, + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + configuration_path=".dstack.yml", + profile=Profile(backends=[BackendType.AWS]), + ) + + # Create job plans + job_plan_h100 = JobPlan( + job_spec=JobSpec( + replica_num=0, + replica_group_name="h100-group", + job_num=0, + job_name="test-job-0", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "H100", "count": 1}) + ), + ), + offers=[create_test_offer(BackendType.AWS, "H100", 3.0)], + total_offers=1, + max_price=3.0, + ) + + job_plan_rtx = JobPlan( + job_spec=JobSpec( + replica_num=1, + replica_group_name="rtx5090-group", + job_num=1, + job_name="test-job-1", + image_name="dstackai/base", + commands=["echo test"], + env={}, + working_dir="/workflow", + requirements=Requirements( + resources=ResourcesSpec(gpu={"name": "RTX5090", "count": 1}) + ), + ), + offers=[create_test_offer(BackendType.VASTAI, "RTX5090", 0.5)], + total_offers=1, + max_price=0.5, + ) + + run_plan = RunPlan( + project_name="test-project", + user="test-user", + run_spec=run_spec, + effective_run_spec=run_spec, + job_plans=[job_plan_h100, job_plan_rtx], + current_resource=None, + action=ApplyAction.CREATE, + ) + + # Print the plan + print_run_plan(run_plan, max_offers=10, include_run_properties=True) + + # Capture output + captured = capsys.readouterr() + output = captured.out + + # Verify group-specific overrides are shown + 
assert "h100-group" in output + assert "spot=spot" in output # H100 group's spot policy + assert "regions=us-east-1,us-west-2" in output # H100 group's regions + assert "backends=aws" in output # H100 group's backend + + assert "rtx5090-group" in output + assert "spot=on-demand" in output # RTX5090 group's spot policy + assert "regions=jp-japan" in output # RTX5090 group's region + assert "backends=vastai,runpod" in output # RTX5090 group's backends + + # Verify service-level "Spot policy" row is NOT shown (misleading with groups) + lines = output.split("\n") + spot_policy_lines = [line for line in lines if line.strip().startswith("Spot policy")] + assert len(spot_policy_lines) == 0, ( + "Service-level 'Spot policy' should not be shown with replica_groups" + ) diff --git a/src/tests/_internal/core/models/test_replica_groups.py b/src/tests/_internal/core/models/test_replica_groups.py new file mode 100644 index 000000000..7c0a35a08 --- /dev/null +++ b/src/tests/_internal/core/models/test_replica_groups.py @@ -0,0 +1,437 @@ +"""Tests for Named Replica Groups functionality""" + +import pytest + +from dstack._internal.core.errors import ConfigurationError +from dstack._internal.core.models.configurations import ( + ServiceConfiguration, + parse_run_configuration, +) +from dstack._internal.core.models.resources import Range +from dstack._internal.core.models.runs import get_normalized_replica_groups + + +class TestReplicaGroupConfiguration: + """Test replica group configuration parsing and validation""" + + def test_basic_replica_groups(self): + """Test basic replica groups configuration""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "h100-group", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + }, + { + "name": "rtx5090-group", + "replicas": 2, + "resources": {"gpu": "RTX5090:1"}, + }, + ], + } + + parsed = parse_run_configuration(conf) + assert isinstance(parsed, ServiceConfiguration) + assert parsed.replica_groups is not None + assert len(parsed.replica_groups) == 2 + + # Check first group + assert parsed.replica_groups[0].name == "h100-group" + assert parsed.replica_groups[0].replicas == Range(min=1, max=1) + assert parsed.replica_groups[0].resources.gpu.name == ["H100"] + + # Check second group + assert parsed.replica_groups[1].name == "rtx5090-group" + assert parsed.replica_groups[1].replicas == Range(min=2, max=2) + assert parsed.replica_groups[1].resources.gpu.name == ["RTX5090"] + + def test_replica_groups_with_ranges(self): + """Test replica groups with autoscaling ranges""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "fixed-group", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + }, + { + "name": "scalable-group", + "replicas": "1..3", # Range + "resources": {"gpu": "RTX5090:1"}, + }, + ], + "scaling": { + "metric": "rps", + "target": 10, + }, + } + + parsed = parse_run_configuration(conf) + assert parsed.replica_groups is not None + assert len(parsed.replica_groups) == 2 + + # Fixed group + assert parsed.replica_groups[0].replicas == Range(min=1, max=1) + + # Scalable group + assert parsed.replica_groups[1].replicas == Range(min=1, max=3) + + def test_replica_groups_with_profile_params(self): + """Test replica groups can override profile parameters""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + # Service-level settings + "backends": ["aws"], + "regions": ["us-west-2"], + "replica_groups": [ 
+ { + "name": "aws-group", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + # Inherits backends/regions from service + }, + { + "name": "runpod-group", + "replicas": 1, + "resources": {"gpu": "RTX5090:1"}, + # Override backends + "backends": ["runpod"], + "regions": ["eu-west-1"], + }, + ], + } + + parsed = parse_run_configuration(conf) + + # First group inherits from service (doesn't specify backends/regions) + assert parsed.replica_groups[0].backends is None + assert parsed.replica_groups[0].regions is None + + # Second group overrides + assert parsed.replica_groups[1].backends == ["runpod"] + assert parsed.replica_groups[1].regions == ["eu-west-1"] + + def test_replica_groups_xor_replicas(self): + """Test that replica_groups and replicas are mutually exclusive""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replicas": 2, # Old format + "replica_groups": [ # New format + { + "name": "group1", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + } + ], + } + + with pytest.raises( + ConfigurationError, + match="Cannot specify both 'replicas' and 'replica_groups'", + ): + parse_run_configuration(conf) + + def test_replica_groups_unique_names(self): + """Test that replica group names must be unique""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "group1", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + }, + { + "name": "group1", # Duplicate! + "replicas": 1, + "resources": {"gpu": "RTX5090:1"}, + }, + ], + } + + with pytest.raises( + ConfigurationError, + match="Replica group names must be unique", + ): + parse_run_configuration(conf) + + def test_replica_groups_empty_name(self): + """Test that replica group names cannot be empty""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "", # Empty name + "replicas": 1, + "resources": {"gpu": "H100:1"}, + } + ], + } + + with pytest.raises( + ConfigurationError, + match="Group name cannot be empty", + ): + parse_run_configuration(conf) + + def test_replica_groups_range_requires_scaling(self): + """Test that replica ranges require scaling configuration""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "scalable-group", + "replicas": "1..3", + "resources": {"gpu": "RTX5090:1"}, + } + ], + # Missing scaling! 
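+            # (a replicas range such as "1..3" is valid only together with a
+            # 'scaling' spec, so parsing this config should raise below)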
+ } + + with pytest.raises( + ConfigurationError, + match="When any replica group has a range, 'scaling' must be specified", + ): + parse_run_configuration(conf) + + def test_replica_groups_cannot_be_empty(self): + """Test that replica_groups list cannot be empty""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [], # Empty list + } + + with pytest.raises( + ConfigurationError, + match="replica_groups cannot be empty", + ): + parse_run_configuration(conf) + + +class TestReplicaGroupNormalization: + """Test get_normalized_replica_groups helper""" + + def test_normalize_new_format(self): + """Test normalization with replica_groups format""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "group1", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + }, + { + "name": "group2", + "replicas": 2, + "resources": {"gpu": "RTX5090:1"}, + }, + ], + } + + parsed = parse_run_configuration(conf) + normalized = get_normalized_replica_groups(parsed) + + assert len(normalized) == 2 + assert normalized[0].name == "group1" + assert normalized[1].name == "group2" + + def test_normalize_legacy_format(self): + """Test normalization converts legacy replicas to default group""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replicas": 3, + "resources": {"gpu": "H100:1"}, + "backends": ["aws"], + "regions": ["us-west-2"], + } + + parsed = parse_run_configuration(conf) + normalized = get_normalized_replica_groups(parsed) + + # Should create single "default" group + assert len(normalized) == 1 + assert normalized[0].name == "default" + assert normalized[0].replicas == Range(min=3, max=3) + assert normalized[0].resources.gpu.name == ["H100"] + + # Should inherit profile params + assert normalized[0].backends == ["aws"] + assert normalized[0].regions == ["us-west-2"] + + def test_normalize_legacy_with_range(self): + """Test normalization with legacy autoscaling""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replicas": "1..5", + "resources": {"gpu": "RTX5090:1"}, + "scaling": { + "metric": "rps", + "target": 10, + }, + } + + parsed = parse_run_configuration(conf) + normalized = get_normalized_replica_groups(parsed) + + assert len(normalized) == 1 + assert normalized[0].name == "default" + assert normalized[0].replicas == Range(min=1, max=5) + + +class TestReplicaGroupAutoscaling: + """Test autoscaling behavior with replica groups""" + + def test_autoscalable_group_detection(self): + """Test identifying which groups are autoscalable""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "fixed", + "replicas": 1, + "resources": {"gpu": "H100:1"}, + }, + { + "name": "scalable", + "replicas": "1..3", + "resources": {"gpu": "RTX5090:1"}, + }, + ], + "scaling": { + "metric": "rps", + "target": 10, + }, + } + + parsed = parse_run_configuration(conf) + + # Fixed group: min == max + assert parsed.replica_groups[0].replicas.min == parsed.replica_groups[0].replicas.max + + # Scalable group: min != max + assert parsed.replica_groups[1].replicas.min != parsed.replica_groups[1].replicas.max + + def test_multiple_autoscalable_groups(self): + """Test multiple groups can be autoscalable""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replica_groups": [ + { + "name": "scalable-1", + "replicas": "1..3", + "resources": {"gpu": 
"H100:1"}, + }, + { + "name": "scalable-2", + "replicas": "2..5", + "resources": {"gpu": "RTX5090:1"}, + }, + ], + "scaling": { + "metric": "rps", + "target": 10, + }, + } + + parsed = parse_run_configuration(conf) + + # Both are autoscalable + assert parsed.replica_groups[0].replicas.min != parsed.replica_groups[0].replicas.max + assert parsed.replica_groups[1].replicas.min != parsed.replica_groups[1].replicas.max + + +class TestBackwardCompatibility: + """Test backward compatibility with existing configurations""" + + def test_legacy_service_config(self): + """Test that legacy service configs still work""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replicas": 2, + "resources": {"gpu": "A100:1"}, + } + + parsed = parse_run_configuration(conf) + + # Should parse successfully + assert isinstance(parsed, ServiceConfiguration) + assert parsed.replicas == Range(min=2, max=2) + assert parsed.replica_groups is None # Not using new format + + def test_legacy_autoscaling_config(self): + """Test legacy autoscaling configurations""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replicas": "0..5", + "resources": {"gpu": "A100:1"}, + "scaling": { + "metric": "rps", + "target": 10, + }, + } + + parsed = parse_run_configuration(conf) + + # Should parse successfully + assert parsed.replicas == Range(min=0, max=5) + assert parsed.scaling is not None + + def test_normalization_preserves_all_profile_params(self): + """Test that normalization copies all ProfileParams fields""" + conf = { + "type": "service", + "commands": ["python3 app.py"], + "port": 8000, + "replicas": 1, + "resources": {"gpu": "H100:1"}, + "backends": ["aws"], + "regions": ["us-east-1"], + "instance_types": ["p4d.24xlarge"], + "spot_policy": "spot", + "max_price": 10.0, + } + + parsed = parse_run_configuration(conf) + normalized = get_normalized_replica_groups(parsed) + + # Check all fields are copied + group = normalized[0] + assert group.backends == ["aws"] + assert group.regions == ["us-east-1"] + assert group.instance_types == ["p4d.24xlarge"] + assert group.spot_policy == "spot" + assert group.max_price == 10.0 diff --git a/src/tests/_internal/core/test_backward_compatibility.py b/src/tests/_internal/core/test_backward_compatibility.py new file mode 100644 index 000000000..ca49f12a3 --- /dev/null +++ b/src/tests/_internal/core/test_backward_compatibility.py @@ -0,0 +1,127 @@ +"""Test backward compatibility for replica_groups with older servers.""" + +from dstack._internal.core.compatibility.runs import get_get_plan_excludes, get_run_spec_excludes +from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.repos import LocalRunRepoData +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.server.schemas.runs import GetRunPlanRequest + + +class TestReplicaGroupsBackwardCompatibility: + """Test that replica_groups field is excluded when None for backward compatibility.""" + + def test_replica_groups_excluded_when_none(self): + """replica_groups should be excluded from JSON when None.""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replicas={"min": 1, "max": 1}, + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + profile=None, + ) + + # Get excludes + excludes = get_run_spec_excludes(run_spec) + + # replica_groups should be in excludes + 
assert "configuration" in excludes + assert "replica_groups" in excludes["configuration"] + assert excludes["configuration"]["replica_groups"] is True + + def test_replica_groups_not_excluded_when_set(self): + """replica_groups should NOT be excluded when set.""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replica_groups=[ + { + "name": "gpu-group", + "replicas": "1", + "resources": {"gpu": {"name": "A100"}}, + } + ], + scaling={"metric": "rps", "target": 10}, + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + profile=None, + ) + + # Get excludes + excludes = get_run_spec_excludes(run_spec) + + # replica_groups should NOT be in excludes (or be False) + if "configuration" in excludes and "replica_groups" in excludes["configuration"]: + assert excludes["configuration"]["replica_groups"] is not True + + def test_get_plan_request_serialization_without_replica_groups(self): + """GetRunPlanRequest should not include replica_groups in JSON when None.""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replicas={"min": 1, "max": 1}, + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + profile=None, + ) + + request = GetRunPlanRequest(run_spec=run_spec, max_offers=None) + excludes = get_get_plan_excludes(request) + + # Serialize with excludes + json_str = request.json(exclude=excludes) + + # replica_groups should not appear in JSON + assert "replica_groups" not in json_str + + def test_get_plan_request_serialization_with_replica_groups(self): + """GetRunPlanRequest should include replica_groups in JSON when set.""" + config = ServiceConfiguration( + type="service", + port=8000, + commands=["echo test"], + replica_groups=[ + { + "name": "gpu-group", + "replicas": "1", + "resources": {"gpu": {"name": "A100"}}, + } + ], + scaling={"metric": "rps", "target": 10}, + ) + + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data=LocalRunRepoData(repo_dir="/tmp"), + configuration=config, + profile=None, + ) + + request = GetRunPlanRequest(run_spec=run_spec, max_offers=None) + excludes = get_get_plan_excludes(request) + + # Serialize with excludes + json_str = request.json(exclude=excludes) + + # replica_groups SHOULD appear in JSON + assert "replica_groups" in json_str + assert "gpu-group" in json_str diff --git a/src/tests/_internal/server/background/tasks/test_migrate_legacy_jobs.py b/src/tests/_internal/server/background/tasks/test_migrate_legacy_jobs.py new file mode 100644 index 000000000..3e9e2e3d5 --- /dev/null +++ b/src/tests/_internal/server/background/tasks/test_migrate_legacy_jobs.py @@ -0,0 +1,242 @@ +""" +Tests for migrating legacy jobs without replica_group_name. 
+""" + +import pytest + +from dstack._internal.core.models.configurations import ( + ServiceConfiguration, +) +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.resources import Range, ResourcesSpec +from dstack._internal.core.models.runs import JobStatus, RunSpec +from dstack._internal.server.background.tasks.process_runs import ( + _migrate_legacy_job_replica_groups, +) +from dstack._internal.server.testing.common import ( + create_job, + create_project, + create_repo, + create_run, + create_user, +) + + +class TestMigrateLegacyJobs: + @pytest.mark.asyncio + async def test_migrates_jobs_without_replica_group_name( + self, test_db, session, socket_enabled + ): + """Test that jobs without replica_group_name get migrated correctly.""" + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + + # Create a run with replica_groups configuration + service_config = ServiceConfiguration( + replica_groups=[ + { + "name": "h100-gpu", + "replicas": Range(min=1, max=1), + "resources": ResourcesSpec(gpu="H100:1"), + }, + { + "name": "rtx5090-gpu", + "replicas": Range(min=1, max=1), + "resources": ResourcesSpec(gpu="RTX5090:1"), + }, + ], + commands=["echo hello"], + port=8000, + ) + + run_spec = RunSpec( + run_name="test-service", + repo_id="test-repo", + configuration=service_config, + merged_profile=Profile(), + ) + + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-service", + run_spec=run_spec, + ) + + # Create jobs WITHOUT replica_group_name (simulating old code) + job1 = await create_job( + session=session, + run=run, + replica_num=0, + replica_group_name=None, # Old job without group + status=JobStatus.RUNNING, + ) + + job2 = await create_job( + session=session, + run=run, + replica_num=1, + replica_group_name=None, # Old job without group + status=JobStatus.RUNNING, + ) + + # Verify jobs have no group + assert job1.replica_group_name is None + assert job2.replica_group_name is None + + # Refresh run to load jobs relationship + await session.refresh(run, ["jobs"]) + + # Run migration + await _migrate_legacy_job_replica_groups(session, run) + await session.refresh(job1) + await session.refresh(job2) + + # Verify jobs now have correct groups + assert job1.replica_group_name == "h100-gpu" + assert job2.replica_group_name == "rtx5090-gpu" + + @pytest.mark.asyncio + async def test_skips_already_migrated_jobs(self, test_db, session): + """Test that jobs with replica_group_name are not re-migrated.""" + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + + service_config = ServiceConfiguration( + replica_groups=[ + { + "name": "gpu-group", + "replicas": Range(min=1, max=1), + "resources": ResourcesSpec(gpu="A100:1"), + }, + ], + commands=["echo hello"], + port=8000, + ) + + run_spec = RunSpec( + run_name="test-service", + repo_id="test-repo", + configuration=service_config, + merged_profile=Profile(), + ) + + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-service", + run_spec=run_spec, + ) + + # Create job WITH replica_group_name (already migrated) + job = await create_job( + session=session, + run=run, + replica_num=0, + replica_group_name="gpu-group", + status=JobStatus.RUNNING, + ) + + original_group = 
job.replica_group_name + + # Run migration (should be a no-op) + await _migrate_legacy_job_replica_groups(session, run) + await session.refresh(job) + + # Verify group unchanged + assert job.replica_group_name == original_group + + @pytest.mark.asyncio + async def test_skips_non_service_runs(self, test_db, session): + """Test that non-service runs are skipped.""" + from dstack._internal.core.models.configurations import TaskConfiguration + + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + + task_config = TaskConfiguration( + commands=["echo hello"], + ) + + run_spec = RunSpec( + run_name="test-task", + repo_id="test-repo", + configuration=task_config, + merged_profile=Profile(), + ) + + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-task", + run_spec=run_spec, + ) + + job = await create_job( + session=session, + run=run, + replica_num=0, + replica_group_name=None, + status=JobStatus.RUNNING, + ) + + # Run migration (should skip task runs) + await _migrate_legacy_job_replica_groups(session, run) + await session.refresh(job) + + # Verify no change + assert job.replica_group_name is None + + @pytest.mark.asyncio + async def test_skips_legacy_replicas_config(self, test_db, session): + """Test that runs using legacy 'replicas' (not replica_groups) are skipped.""" + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + repo = await create_repo(session=session, project_id=project.id) + + # Use legacy replicas configuration + service_config = ServiceConfiguration( + replicas=Range(min=2, max=2), + commands=["echo hello"], + port=8000, + ) + + run_spec = RunSpec( + run_name="test-service", + repo_id="test-repo", + configuration=service_config, + merged_profile=Profile(), + ) + + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-service", + run_spec=run_spec, + ) + + job = await create_job( + session=session, + run=run, + replica_num=0, + replica_group_name=None, + status=JobStatus.RUNNING, + ) + + # Run migration (should skip legacy replicas) + await _migrate_legacy_job_replica_groups(session, run) + await session.refresh(job) + + # Verify no change (legacy replicas don't use replica_group_name) + assert job.replica_group_name is None diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index f4e481f53..42758111c 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -233,6 +233,7 @@ def get_dev_env_run_plan_dict( "privileged": True if docker else privileged, "job_name": f"{run_name}-0-0", "replica_num": 0, + "replica_group_name": None, "job_num": 0, "jobs_per_replica": 1, "single_branch": False, @@ -441,6 +442,7 @@ def get_dev_env_run_dict( "privileged": True if docker else privileged, "job_name": f"{run_name}-0-0", "replica_num": 0, + "replica_group_name": None, "job_num": 0, "jobs_per_replica": 1, "single_branch": False, diff --git a/src/tests/_internal/server/services/test_get_plan_replica_groups.py b/src/tests/_internal/server/services/test_get_plan_replica_groups.py new file mode 100644 index 000000000..84a549656 --- /dev/null +++ b/src/tests/_internal/server/services/test_get_plan_replica_groups.py @@ -0,0 +1,206 @@ +"""Test get_plan() offer fetching logic for replica groups.""" + +from 
dstack._internal.core.models.resources import ResourcesSpec +from dstack._internal.core.models.runs import Requirements + + +class TestGetPlanOfferFetchingLogic: + """Test the logic for determining when to fetch offers per-job vs. shared.""" + + def test_requirements_equality_check(self): + """Test that Requirements objects can be compared for equality.""" + # Identical requirements + req1 = Requirements( + resources=ResourcesSpec(gpu={"name": "A100", "count": 1}), + ) + req2 = Requirements( + resources=ResourcesSpec(gpu={"name": "A100", "count": 1}), + ) + assert req1 == req2 + + # Different GPU names + req3 = Requirements( + resources=ResourcesSpec(gpu={"name": "H100", "count": 1}), + ) + assert req1 != req3 + + # Different GPU counts + req4 = Requirements( + resources=ResourcesSpec(gpu={"name": "A100", "count": 2}), + ) + assert req1 != req4 + + def test_identical_requirements_detection_logic(self): + """Test logic for detecting when all jobs have identical requirements.""" + + # Simulate job specs with requirements + class MockJobSpec: + def __init__(self, gpu_name: str, gpu_count: int = 1): + self.requirements = Requirements( + resources=ResourcesSpec(gpu={"name": gpu_name, "count": gpu_count}), + ) + + class MockJob: + def __init__(self, gpu_name: str, gpu_count: int = 1): + self.job_spec = MockJobSpec(gpu_name, gpu_count) + + # Test 1: All identical + jobs = [MockJob("A100"), MockJob("A100"), MockJob("A100")] + all_requirements_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_requirements_identical is True + + # Test 2: Different GPU types + jobs = [MockJob("A100"), MockJob("H100"), MockJob("A100")] + all_requirements_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_requirements_identical is False + + # Test 3: Different GPU counts + jobs = [MockJob("A100", 1), MockJob("A100", 2)] + all_requirements_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_requirements_identical is False + + # Test 4: Single job (always identical) + jobs = [MockJob("V100")] + all_requirements_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_requirements_identical is True + + def test_offer_fetch_decision_logic(self): + """Test the decision logic for when to use shared vs per-job offer fetching.""" + + class MockJobSpec: + def __init__(self, gpu_name: str): + self.requirements = Requirements( + resources=ResourcesSpec(gpu={"name": gpu_name, "count": 1}), + ) + + class MockJob: + def __init__(self, group_name: str, gpu_name: str): + self.job_spec = MockJobSpec(gpu_name) + self.job_spec.replica_group_name = group_name + + # Scenario 1: Replica groups with different GPUs -> per-job fetch + jobs = [ + MockJob("l40s-group", "L40S"), + MockJob("rtx4080-group", "RTX4080"), + ] + all_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_identical is False, "Different GPU types should trigger per-job offer fetch" + + # Scenario 2: Replica groups with same GPU -> shared fetch (optimization) + jobs = [ + MockJob("group-1", "A100"), + MockJob("group-2", "A100"), + ] + all_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_identical is True, "Identical GPUs should use shared offer fetch" + + # Scenario 3: Legacy replicas (same requirements) -> shared fetch + jobs 
= [ + MockJob("default", "V100"), + MockJob("default", "V100"), + ] + all_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_identical is True, "Legacy replicas with same GPU should use shared fetch" + + # Scenario 4: Mixed groups (2 same + 1 different) -> per-job fetch + jobs = [ + MockJob("a100-group-1", "A100"), + MockJob("h100-group", "H100"), + MockJob("a100-group-2", "A100"), + ] + all_identical = all( + job.job_spec.requirements == jobs[0].job_spec.requirements for job in jobs + ) + assert all_identical is False, "Mix of different GPUs should trigger per-job fetch for all" + + +class TestReplicaGroupOfferSearchIntegration: + """Integration tests for replica group offer search behavior.""" + + def test_different_gpu_types_creates_different_requirements(self): + """Different replica group GPU types should create different Requirements objects.""" + # This tests the data model behavior that get_plan() relies on + req_l40s = Requirements( + resources=ResourcesSpec( + gpu={"name": "L40S", "count": 1}, + ) + ) + + req_rtx4080 = Requirements( + resources=ResourcesSpec( + gpu={"name": "RTX4080", "count": 1}, + ) + ) + + # These should NOT be equal + assert req_l40s != req_rtx4080 + + # Verify the GPU names are different + assert req_l40s.resources.gpu.name != req_rtx4080.resources.gpu.name + + def test_identical_gpu_types_creates_identical_requirements(self): + """Identical replica group GPU types should create equal Requirements objects.""" + req_a = Requirements( + resources=ResourcesSpec( + gpu={"name": "A100", "count": 1}, + ) + ) + + req_b = Requirements( + resources=ResourcesSpec( + gpu={"name": "A100", "count": 1}, + ) + ) + + # These SHOULD be equal (enables optimization) + assert req_a == req_b + + def test_requirements_with_different_memory(self): + """Requirements with different GPU memory should not be equal.""" + req_16gb = Requirements( + resources=ResourcesSpec( + gpu={"name": "A100", "memory": "16GB", "count": 1}, + ) + ) + + req_40gb = Requirements( + resources=ResourcesSpec( + gpu={"name": "A100", "memory": "40GB", "count": 1}, + ) + ) + + # Different memory specifications + assert req_16gb != req_40gb + + def test_requirements_with_different_cpu_specs(self): + """Requirements with different CPU specs should not be equal.""" + req_low_cpu = Requirements( + resources=ResourcesSpec( + cpu={"min": 2}, + gpu={"name": "A100", "count": 1}, + ) + ) + + req_high_cpu = Requirements( + resources=ResourcesSpec( + cpu={"min": 16}, + gpu={"name": "A100", "count": 1}, + ) + ) + + # Different CPU requirements + assert req_low_cpu != req_high_cpu diff --git a/src/tests/_internal/server/services/test_replica_groups_profile_overrides.py b/src/tests/_internal/server/services/test_replica_groups_profile_overrides.py new file mode 100644 index 000000000..65bf86ccf --- /dev/null +++ b/src/tests/_internal/server/services/test_replica_groups_profile_overrides.py @@ -0,0 +1,265 @@ +"""Tests for replica group profile overrides (regions, spot_policy, etc.)""" + +import pytest +from pydantic import parse_obj_as + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.configurations import ( + ReplicaGroup, + ScalingSpec, + ServiceConfiguration, +) +from dstack._internal.core.models.profiles import Profile, SpotPolicy +from dstack._internal.core.models.resources import GPUSpec, Range, ResourcesSpec +from dstack._internal.core.models.runs import RunSpec +from dstack._internal.server.services.runs 
import _get_job_profile + +pytestmark = pytest.mark.usefixtures("image_config_mock") + + +def test_spot_policy_override_per_group(): + """Test that each replica group can have its own spot_policy.""" + # Create a service with different spot policies per group + config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="on-demand-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + spot_policy=SpotPolicy.ONDEMAND, + ), + ReplicaGroup( + name="spot-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "RTX5090:1")), + spot_policy=SpotPolicy.SPOT, + ), + ReplicaGroup( + name="auto-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "A100:1")), + spot_policy=SpotPolicy.AUTO, + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + profile = Profile(name="test-profile", spot_policy=SpotPolicy.AUTO) # base policy + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=config, + profile=profile, + ssh_key_pub="ssh_key", + ) + + # Test on-demand group + on_demand_profile = _get_job_profile(run_spec, "on-demand-group") + assert on_demand_profile.spot_policy == SpotPolicy.ONDEMAND + + # Test spot group + spot_profile = _get_job_profile(run_spec, "spot-group") + assert spot_profile.spot_policy == SpotPolicy.SPOT + + # Test auto group + auto_profile = _get_job_profile(run_spec, "auto-group") + assert auto_profile.spot_policy == SpotPolicy.AUTO + + # Test legacy (no group) uses base profile + legacy_profile = _get_job_profile(run_spec, None) + assert legacy_profile.spot_policy == SpotPolicy.AUTO + + +def test_regions_override_per_group(): + """Test that each replica group can have its own regions.""" + config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="us-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + regions=["us-east-1", "us-west-2"], + ), + ReplicaGroup( + name="eu-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "RTX5090:1")), + regions=["eu-west-1", "eu-central-1"], + ), + ReplicaGroup( + name="asia-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "A100:1")), + regions=["ap-northeast-1"], + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + profile = Profile(name="test-profile", regions=["us-east-1"]) # base regions + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=config, + profile=profile, + ssh_key_pub="ssh_key", + ) + + # Test US group + us_profile = _get_job_profile(run_spec, "us-group") + assert us_profile.regions == ["us-east-1", "us-west-2"] + + # Test EU group + eu_profile = _get_job_profile(run_spec, "eu-group") + assert eu_profile.regions == ["eu-west-1", "eu-central-1"] + + # Test Asia group + asia_profile = _get_job_profile(run_spec, "asia-group") + assert asia_profile.regions == ["ap-northeast-1"] + + # Test legacy (no group) uses base profile + legacy_profile = _get_job_profile(run_spec, None) + assert legacy_profile.regions == ["us-east-1"] + + +def 
test_backends_override_per_group(): + """Test that each replica group can have its own backends.""" + config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="aws-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + backends=[BackendType.AWS], + ), + ReplicaGroup( + name="vastai-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "RTX5090:1")), + backends=[BackendType.VASTAI], + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + profile = Profile( + name="test-profile", + backends=[BackendType.AWS, BackendType.GCP], # base backends + ) + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=config, + profile=profile, + ssh_key_pub="ssh_key", + ) + + # Test AWS group + aws_profile = _get_job_profile(run_spec, "aws-group") + assert aws_profile.backends == [BackendType.AWS] + + # Test VastAI group + vastai_profile = _get_job_profile(run_spec, "vastai-group") + assert vastai_profile.backends == [BackendType.VASTAI] + + # Test legacy (no group) uses base profile + legacy_profile = _get_job_profile(run_spec, None) + assert legacy_profile.backends == [BackendType.AWS, BackendType.GCP] + + +def test_multiple_profile_overrides_per_group(): + """Test that a replica group can override multiple profile parameters at once.""" + config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="specialized-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + regions=["us-west-2"], + backends=[BackendType.AWS], + spot_policy=SpotPolicy.ONDEMAND, + max_price=5.0, + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + profile = Profile( + name="test-profile", + regions=["us-east-1"], + backends=[BackendType.GCP], + spot_policy=SpotPolicy.SPOT, + max_price=1.0, + ) + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=config, + profile=profile, + ssh_key_pub="ssh_key", + ) + + specialized_profile = _get_job_profile(run_spec, "specialized-group") + assert specialized_profile.regions == ["us-west-2"] + assert specialized_profile.backends == [BackendType.AWS] + assert specialized_profile.spot_policy == SpotPolicy.ONDEMAND + assert specialized_profile.max_price == 5.0 + + +def test_partial_profile_override(): + """Test that only specified profile parameters are overridden, others inherit from base.""" + config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="partial-group", + replicas=parse_obj_as(Range[int], 1), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + regions=["us-west-2"], # Only override regions + # spot_policy, backends, max_price should inherit from base + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + profile = Profile( + name="test-profile", + regions=["us-east-1"], + backends=[BackendType.GCP, BackendType.AWS], + spot_policy=SpotPolicy.SPOT, + max_price=2.5, + ) + run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=config, + profile=profile, + 
ssh_key_pub="ssh_key", + ) + + partial_profile = _get_job_profile(run_spec, "partial-group") + # Overridden + assert partial_profile.regions == ["us-west-2"] + # Inherited from base + assert partial_profile.backends == [BackendType.GCP, BackendType.AWS] + assert partial_profile.spot_policy == SpotPolicy.SPOT + assert partial_profile.max_price == 2.5 diff --git a/src/tests/_internal/server/services/test_replica_groups_scaling.py b/src/tests/_internal/server/services/test_replica_groups_scaling.py new file mode 100644 index 000000000..9cea6193b --- /dev/null +++ b/src/tests/_internal/server/services/test_replica_groups_scaling.py @@ -0,0 +1,395 @@ +"""Integration tests for replica groups scaling functionality""" + +from typing import List + +import pytest +from pydantic import parse_obj_as +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from dstack._internal.core.models.configurations import ( + ReplicaGroup, + ScalingSpec, + ServiceConfiguration, +) +from dstack._internal.core.models.profiles import Profile +from dstack._internal.core.models.resources import GPUSpec, Range, ResourcesSpec +from dstack._internal.core.models.runs import JobStatus, JobTerminationReason +from dstack._internal.server.models import RunModel +from dstack._internal.server.services.runs import scale_run_replicas +from dstack._internal.server.testing.common import ( + create_job, + create_project, + create_repo, + create_run, + create_user, + get_run_spec, +) + +pytestmark = pytest.mark.usefixtures("image_config_mock") + + +async def scale_wrapper(session: AsyncSession, run: RunModel, diff: int): + """Wrapper that handles commit and refresh like existing tests""" + await scale_run_replicas(session, run, diff) + await session.commit() + await session.refresh(run) + + +async def make_run_with_groups( + session: AsyncSession, + groups_config: List[dict], # List of {name, replicas_range, gpu, initial_jobs} +) -> RunModel: + """Helper to create a run with replica groups""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + + # Build replica groups + replica_groups = [] + for group_cfg in groups_config: + replica_groups.append( + ReplicaGroup( + name=group_cfg["name"], + replicas=parse_obj_as(Range[int], group_cfg["replicas_range"]), + resources=ResourcesSpec(gpu=GPUSpec(name=[group_cfg["gpu"]], count=1)), + ) + ) + + profile = Profile(name="test-profile") + run_spec = get_run_spec( + repo_id=repo.name, + run_name="test-run", + profile=profile, + configuration=ServiceConfiguration( + commands=["python app.py"], + port=8000, + replica_groups=replica_groups, + scaling=ScalingSpec(metric="rps", target=10), + ), + ) + + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-run", + run_spec=run_spec, + ) + + # Create initial jobs + replica_num = 0 + for group_cfg in groups_config: + for job_status in group_cfg.get("initial_jobs", []): + job = await create_job( + session=session, + run=run, + status=job_status, + replica_num=replica_num, + replica_group_name=group_cfg["name"], + ) + run.jobs.append(job) + replica_num += 1 + + await session.commit() + + # Reload with jobs and project + res = await session.execute( + select(RunModel) + .where(RunModel.id == run.id) + .options(selectinload(RunModel.jobs), selectinload(RunModel.project)) + ) + return res.scalar_one() + + +class 
+
+
+class TestReplicaGroupsScaleDown:
+    """Test scaling down with replica groups"""
+
+    @pytest.mark.asyncio
+    async def test_scale_down_only_from_autoscalable_groups(self, session: AsyncSession):
+        """Test that scale down only affects autoscalable groups"""
+        run = await make_run_with_groups(
+            session,
+            [
+                {
+                    "name": "fixed-h100",
+                    "replicas_range": "1..1",  # Fixed
+                    "gpu": "H100",
+                    "initial_jobs": [JobStatus.RUNNING],
+                },
+                {
+                    "name": "scalable-rtx",
+                    "replicas_range": "1..3",  # Autoscalable
+                    "gpu": "RTX5090",
+                    "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING],
+                },
+            ],
+        )
+
+        # Scale down by 1 (should only affect the scalable group)
+        await scale_wrapper(session, run, -1)
+
+        # Check: fixed group should still have 1 running job
+        fixed_jobs = [j for j in run.jobs if j.replica_group_name == "fixed-h100"]
+        assert len(fixed_jobs) == 1
+        assert fixed_jobs[0].status == JobStatus.RUNNING
+
+        # Check: scalable group should have 1 terminating, 1 running
+        scalable_jobs = [j for j in run.jobs if j.replica_group_name == "scalable-rtx"]
+        assert len(scalable_jobs) == 2
+        terminating = [j for j in scalable_jobs if j.status == JobStatus.TERMINATING]
+        assert len(terminating) == 1
+        assert terminating[0].termination_reason == JobTerminationReason.SCALED_DOWN
+
+    @pytest.mark.asyncio
+    async def test_scale_down_respects_group_minimums(self, session: AsyncSession):
+        """Test that scale down respects each group's minimum"""
+        run = await make_run_with_groups(
+            session,
+            [
+                {
+                    "name": "group-a",
+                    "replicas_range": "1..3",  # Min=1
+                    "gpu": "H100",
+                    "initial_jobs": [JobStatus.RUNNING],  # At minimum
+                },
+                {
+                    "name": "group-b",
+                    "replicas_range": "2..5",  # Min=2
+                    "gpu": "RTX5090",
+                    "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING, JobStatus.RUNNING],
+                },
+            ],
+        )
+
+        # Try to scale down by 2
+        await scale_wrapper(session, run, -2)
+
+        # Group A should still have 1 (at minimum)
+        group_a_jobs = [j for j in run.jobs if j.replica_group_name == "group-a"]
+        assert len([j for j in group_a_jobs if j.status == JobStatus.RUNNING]) == 1
+
+        # Group B should have terminated 1 (3 -> 2, which is the minimum)
+        group_b_jobs = [j for j in run.jobs if j.replica_group_name == "group-b"]
+        terminating = [j for j in group_b_jobs if j.status == JobStatus.TERMINATING]
+        assert len(terminating) == 1
+
+    @pytest.mark.asyncio
+    async def test_scale_down_all_groups_fixed(self, session: AsyncSession):
+        """Test scaling down when all groups are fixed (should not terminate anything)"""
+        run = await make_run_with_groups(
+            session,
+            [
+                {
+                    "name": "fixed-1",
+                    "replicas_range": "1..1",
+                    "gpu": "H100",
+                    "initial_jobs": [JobStatus.RUNNING],
+                },
+                {
+                    "name": "fixed-2",
+                    "replicas_range": "2..2",
+                    "gpu": "RTX5090",
+                    "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING],
+                },
+            ],
+        )
+
+        initial_count = len(run.jobs)
+
+        # Try to scale down
+        await scale_wrapper(session, run, -1)
+
+        # No jobs should be terminated (all groups are fixed)
+        assert len(run.jobs) == initial_count
+        assert all(j.status == JobStatus.RUNNING for j in run.jobs)
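+
+
+# Reference model of the selection rule the tests above exercise. This is an
+# illustrative sketch written for documentation, not the code under test (the
+# real logic lives in scale_run_replicas): given per-group min/max/current
+# counts, only autoscalable groups (min != max) give up replicas, and never
+# below their minimum. Group iteration order here is arbitrary.
+def _expected_scale_down(groups: List[dict], diff: int) -> dict:
+    """Return {group_name: replicas_to_terminate} for a scale-down of `diff`."""
+    to_terminate = {}
+    remaining = diff
+    for group in groups:
+        if remaining <= 0:
+            break
+        if group["min"] == group["max"]:
+            continue  # fixed group: never scaled
+        headroom = group["current"] - group["min"]
+        take = min(headroom, remaining)
+        if take > 0:
+            to_terminate[group["name"]] = take
+            remaining -= take
+    return to_terminate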
"1..3", # Autoscalable + "gpu": "RTX5090", + "initial_jobs": [JobStatus.RUNNING], + }, + ], + ) + + initial_count = len(run.jobs) + + # Scale up by 1 + await scale_wrapper(session, run, 1) + + # Should have one more job + assert len(run.jobs) == initial_count + 1 + + # New job should be in scalable group + new_jobs = [j for j in run.jobs if j.replica_num == initial_count] + assert len(new_jobs) == 1 + assert new_jobs[0].replica_group_name == "scalable-rtx" + assert new_jobs[0].status == JobStatus.SUBMITTED + + @pytest.mark.asyncio + async def test_scale_up_respects_group_maximums(self, session: AsyncSession): + """Test that scale up respects group maximums""" + run = await make_run_with_groups( + session, + [ + { + "name": "small-group", + "replicas_range": "1..2", # Max=2 + "gpu": "H100", + "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING], # At max + }, + { + "name": "large-group", + "replicas_range": "1..5", # Max=5 + "gpu": "RTX5090", + "initial_jobs": [JobStatus.RUNNING], + }, + ], + ) + + # Try to scale up by 2 + await scale_wrapper(session, run, 2) + + # Small group should still have 2 (at max) + small_jobs = [j for j in run.jobs if j.replica_group_name == "small-group"] + assert len(small_jobs) == 2 + + # Large group should have grown by 2 + large_jobs = [j for j in run.jobs if j.replica_group_name == "large-group"] + assert len(large_jobs) == 3 + + @pytest.mark.asyncio + async def test_scale_up_no_autoscalable_groups(self, session: AsyncSession): + """Test scale up does nothing when no autoscalable groups exist""" + run = await make_run_with_groups( + session, + [ + { + "name": "fixed-1", + "replicas_range": "1..1", + "gpu": "H100", + "initial_jobs": [JobStatus.RUNNING], + }, + { + "name": "fixed-2", + "replicas_range": "2..2", + "gpu": "RTX5090", + "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING], + }, + ], + ) + + initial_count = len(run.jobs) + + # Try to scale up + await scale_wrapper(session, run, 2) + + # Should not have added any jobs + assert len(run.jobs) == initial_count + + @pytest.mark.asyncio + async def test_scale_up_all_groups_at_max(self, session: AsyncSession): + """Test scale up when all autoscalable groups are at maximum""" + run = await make_run_with_groups( + session, + [ + { + "name": "group-a", + "replicas_range": "1..2", + "gpu": "H100", + "initial_jobs": [JobStatus.RUNNING, JobStatus.RUNNING], # At max + }, + { + "name": "group-b", + "replicas_range": "1..3", + "gpu": "RTX5090", + "initial_jobs": [ + JobStatus.RUNNING, + JobStatus.RUNNING, + JobStatus.RUNNING, + ], # At max + }, + ], + ) + + initial_count = len(run.jobs) + + # Try to scale up + await scale_wrapper(session, run, 1) + + # Should not have added any jobs (all at max) + assert len(run.jobs) == initial_count + + +class TestReplicaGroupsBackwardCompatibility: + """Test backward compatibility with legacy configs""" + + @pytest.mark.asyncio + async def test_legacy_config_scaling(self, session: AsyncSession): + """Test scaling works with legacy replicas configuration""" + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + + # Use legacy format (no replica_groups) + profile = Profile(name="test-profile") + run_spec = get_run_spec( + repo_id=repo.name, + run_name="test-run", + profile=profile, + configuration=ServiceConfiguration( + commands=["python app.py"], + port=8000, + replicas=parse_obj_as(Range[int], "1..3"), # Legacy format + scaling=ScalingSpec(metric="rps", target=10), + 
+
+
+class TestReplicaGroupsBackwardCompatibility:
+    """Test backward compatibility with legacy configs"""
+
+    @pytest.mark.asyncio
+    async def test_legacy_config_scaling(self, session: AsyncSession):
+        """Test scaling works with legacy replicas configuration"""
+        project = await create_project(session=session)
+        user = await create_user(session=session)
+        repo = await create_repo(session=session, project_id=project.id)
+
+        # Use legacy format (no replica_groups)
+        profile = Profile(name="test-profile")
+        run_spec = get_run_spec(
+            repo_id=repo.name,
+            run_name="test-run",
+            profile=profile,
+            configuration=ServiceConfiguration(
+                commands=["python app.py"],
+                port=8000,
+                replicas=parse_obj_as(Range[int], "1..3"),  # Legacy format
+                scaling=ScalingSpec(metric="rps", target=10),
+            ),
+        )
+
+        run = await create_run(
+            session=session,
+            project=project,
+            repo=repo,
+            user=user,
+            run_name="test-run",
+            run_spec=run_spec,
+        )
+
+        # Add initial job (no group name)
+        job = await create_job(
+            session=session,
+            run=run,
+            status=JobStatus.RUNNING,
+            replica_num=0,
+            replica_group_name=None,  # Legacy jobs have no group
+        )
+        run.jobs.append(job)
+        await session.commit()
+
+        # Scale up should work
+        await scale_wrapper(session, run, 1)
+
+        # Should have 2 jobs now
+        assert len(run.jobs) == 2
+
+        # New job should have "default" group name or None
+        new_job = [j for j in run.jobs if j.replica_num == 1][0]
+        assert new_job.replica_group_name in [None, "default"]
diff --git a/src/tests/_internal/server/services/test_replica_groups_update.py b/src/tests/_internal/server/services/test_replica_groups_update.py
new file mode 100644
index 000000000..1ae7148cc
--- /dev/null
+++ b/src/tests/_internal/server/services/test_replica_groups_update.py
@@ -0,0 +1,172 @@
+"""Tests for updating services with replica groups."""
+
+from pydantic import parse_obj_as
+
+from dstack._internal.core.models.configurations import (
+    ReplicaGroup,
+    ScalingSpec,
+    ServiceConfiguration,
+)
+from dstack._internal.core.models.profiles import Profile, SpotPolicy
+from dstack._internal.core.models.resources import GPUSpec, Range, ResourcesSpec
+from dstack._internal.core.models.runs import RunSpec
+from dstack._internal.server.services.runs import _check_can_update_run_spec
+
+
+def test_can_update_from_replicas_to_replica_groups():
+    """Test that we can update a service from simple replicas to replica_groups."""
+    # Old config with simple replicas
+    old_config = ServiceConfiguration(
+        commands=["echo hello"],
+        port=8000,
+        replicas=2,
+        resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")),
+    )
+
+    old_run_spec = RunSpec(
+        run_name="test-run",
+        repo_id="test-repo",
+        repo_data={"repo_type": "local", "repo_dir": "/repo"},
+        configuration_path="dstack.yaml",
+        configuration=old_config,
+        profile=Profile(name="test-profile"),
+        ssh_key_pub="ssh_key",
+    )
+
+    # New config with replica_groups
+    new_config = ServiceConfiguration(
+        commands=["echo hello"],
+        port=8000,
+        replica_groups=[
+            ReplicaGroup(
+                name="h100-group",
+                replicas=parse_obj_as(Range[int], 1),
+                resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")),
+                regions=["us-east-1"],
+            ),
+            ReplicaGroup(
+                name="rtx5090-group",
+                replicas=parse_obj_as(Range[int], "0..3"),
+                resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "RTX5090:1")),
+                regions=["jp-japan"],
+            ),
+        ],
+        scaling=ScalingSpec(metric="rps", target=10),
+    )
+
+    new_run_spec = RunSpec(
+        run_name="test-run",
+        repo_id="test-repo",
+        repo_data={"repo_type": "local", "repo_dir": "/repo"},
+        configuration_path="dstack.yaml",
+        configuration=new_config,
+        profile=Profile(name="test-profile"),
+        ssh_key_pub="ssh_key",
+    )
+
+    # This should NOT raise an error
+    _check_can_update_run_spec(old_run_spec, new_run_spec)
+
+
+def test_can_update_from_replica_groups_to_replicas():
+    """Test that we can update a service from replica_groups back to simple replicas."""
+    # Old config with replica_groups
+    old_config = ServiceConfiguration(
+        commands=["echo hello"],
+        port=8000,
+        replica_groups=[
+            ReplicaGroup(
+                name="h100-group",
+                replicas=parse_obj_as(Range[int], 1),
+                resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")),
+            ),
+        ],
+    )
+
+    old_run_spec = RunSpec(
+        run_name="test-run",
+        repo_id="test-repo",
+        repo_data={"repo_type": "local", "repo_dir": "/repo"},
configuration_path="dstack.yaml", + configuration=old_config, + profile=Profile(name="test-profile"), + ssh_key_pub="ssh_key", + ) + + # New config with simple replicas + new_config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replicas=2, + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "A100:1")), + ) + + new_run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=new_config, + profile=Profile(name="test-profile"), + ssh_key_pub="ssh_key", + ) + + # This should NOT raise an error + _check_can_update_run_spec(old_run_spec, new_run_spec) + + +def test_can_update_replica_groups(): + """Test that we can update replica_groups in place.""" + # Old config + old_config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="gpu-group", + replicas=parse_obj_as(Range[int], "1..3"), + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "H100:1")), + regions=["us-east-1"], + ), + ], + scaling=ScalingSpec(metric="rps", target=10), + ) + + old_run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=old_config, + profile=Profile(name="test-profile"), + ssh_key_pub="ssh_key", + ) + + # New config with different replica_groups + new_config = ServiceConfiguration( + commands=["echo hello"], + port=8000, + replica_groups=[ + ReplicaGroup( + name="gpu-group", + replicas=parse_obj_as(Range[int], "2..5"), # Changed range + resources=ResourcesSpec(gpu=parse_obj_as(GPUSpec, "A100:1")), # Changed GPU + regions=["us-west-2"], # Changed region + spot_policy=SpotPolicy.SPOT, # Added spot policy + ), + ], + scaling=ScalingSpec(metric="rps", target=20), # Changed target + ) + + new_run_spec = RunSpec( + run_name="test-run", + repo_id="test-repo", + repo_data={"repo_type": "local", "repo_dir": "/repo"}, + configuration_path="dstack.yaml", + configuration=new_config, + profile=Profile(name="test-profile"), + ssh_key_pub="ssh_key", + ) + + # This should NOT raise an error (replica_groups + resources + scaling are all updatable) + _check_can_update_run_spec(old_run_spec, new_run_spec)