25 changes: 25 additions & 0 deletions areal/api/cli_args.py
@@ -928,6 +928,29 @@ class InferenceEngineConfig:
)


@dataclass
class ScalingConfig:
"""Configuration for dynamic scaling of inference/training servers."""

enable_scaling: bool = field(
default=False,
metadata={"help": "Whether scaling is enabled (True/False)."},
)

mode: str = field(
default="manual",
metadata={
"help": "Scaling mode — can be 'manual' or 'auto'.",
"choices": ["manual", "auto"],
},
)

scaling_controller_port: int = field(
default=8899,
metadata={"help": "HTTP port for the scale-up service endpoint."},
)


@dataclass
class _Timer:
experiment_name: str = MISSING
@@ -1341,6 +1364,8 @@ class BaseExperimentConfig:

scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)

scaling: ScalingConfig = field(default_factory=ScalingConfig)


@dataclass
class SFTConfig(BaseExperimentConfig):
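For context, `ScalingConfig` hangs off `BaseExperimentConfig` as the new `scaling` field (see the `BaseExperimentConfig` hunk above), so it goes through the same structured-config parsing as the other sections. A minimal usage sketch follows; the field names come from this diff, while the CLI override syntax is an assumption based on the launcher's existing conventions:

```python
from areal.api.cli_args import ScalingConfig

# Defaults match the dataclass above: disabled, manual mode, port 8899.
cfg = ScalingConfig()
assert not cfg.enable_scaling and cfg.mode == "manual"

# Programmatic override; from the CLI this would look roughly like
# `scaling.enable_scaling=true scaling.mode=auto` (assumed syntax).
cfg = ScalingConfig(enable_scaling=True, mode="auto", scaling_controller_port=8899)
```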
39 changes: 39 additions & 0 deletions areal/core/remote_inf_engine.py
@@ -253,6 +253,7 @@ def __init__(

self.workflow_executor: WorkflowExecutor
self.local_server_processes: list[LocalInfServerInfo] = []
self.update_servers = False

def _get_or_create_session(self) -> aiohttp.ClientSession:
"""Get or create a ClientSession for the current thread/event loop.
@@ -424,6 +425,27 @@ def get_version(self):
with self.lock:
return self._version

def refresh_addresses(self, new_addresses: list[str]) -> None:
"""
Refresh the list of available servers dynamically.

Args:
new_addresses (list[str]): Updated list of server addresses.
"""
if not new_addresses:
raise RuntimeError("No servers provided when refreshing addresses.")

# Only log if there's an actual change
if new_addresses != self.addresses:
self.logger.info(f"Refreshing server addresses: {new_addresses}")

# Replace with the new set
self.addresses = new_addresses

        # Reset the cursor if it now points past the end of the new list
if self.server_idx >= len(self.addresses):
self.server_idx = 0

def choose_server(self) -> str:
"""Choose a server based on the scheduling policy.

@@ -437,7 +459,16 @@ def choose_server(self) -> str:
NotImplementedError
            If a schedule policy other than round-robin is used
"""

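        # A prior pause() flagged a possible scale event: re-read the current
        # server list from name_resolve before choosing.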
if self.update_servers:
name = names.gen_servers(
self.config.experiment_name, self.config.trial_name
)
vllm_addrs = name_resolve.get_subtree(name)
self.refresh_addresses(vllm_addrs)
self.update_servers = False
if self.config.schedule_policy == "round_robin":
self.server_idx %= len(self.addresses)
server = self.addresses[self.server_idx]
self.server_idx = (self.server_idx + 1) % len(self.addresses)
return server
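Taken together, `refresh_addresses` and the new modulo guard in the round-robin branch keep the cursor valid even when the fleet shrinks between calls. Below is a self-contained sketch of just that logic; `_Picker` is a hypothetical stand-in for the engine, not part of the PR:

```python
class _Picker:
    """Round-robin selection over a server list whose size can change."""

    def __init__(self, addresses):
        self.addresses = addresses
        self.server_idx = 0

    def refresh_addresses(self, new_addresses):
        if not new_addresses:
            raise RuntimeError("No servers provided when refreshing addresses.")
        self.addresses = new_addresses
        if self.server_idx >= len(self.addresses):
            self.server_idx = 0  # the list shrank past our cursor

    def choose_server(self):
        self.server_idx %= len(self.addresses)  # same guard as the diff above
        server = self.addresses[self.server_idx]
        self.server_idx = (self.server_idx + 1) % len(self.addresses)
        return server

p = _Picker(["a:1", "b:2", "c:3"])
assert [p.choose_server() for _ in range(3)] == ["a:1", "b:2", "c:3"]
p.refresh_addresses(["a:1", "b:2"])  # scale down while the cursor is back at 0
assert p.choose_server() == "a:1"
```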
@@ -591,6 +622,12 @@ def init_weights_update_group(self, meta: WeightUpdateMeta) -> Future[None]:
assert meta.type == current_platform.communication_backend
assert not self.distributed_weight_update_initialized

        # Refresh the gen servers in case a scale request changed the server set
name = names.gen_servers(self.config.experiment_name, self.config.trial_name)
vllm_addrs = name_resolve.get_subtree(name)
if vllm_addrs != self.addresses:
self.refresh_addresses(vllm_addrs)

fut = self.executor.submit(
_init_weights_update_group_remote,
self.backend,
@@ -845,6 +882,8 @@ def pause(self):
"""Pause request submission for async rollout.
        Used during evaluation to prevent over-generation of data.
"""
        # When pausing for a weight update, flag update_servers so the next
        # choose_server() call dispatches requests to the refreshed server list.
self.update_servers = True
return self.workflow_executor.pause()

def resume(self):
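Pieced together, the intended rollout-side sequence for a scale event appears to be: new servers register under `names.gen_servers(...)` in `name_resolve`; the trainer calls `pause()`, which flags `update_servers`; `init_weights_update_group` eagerly re-reads the address list so the fresh NCCL group spans the enlarged fleet; and after `resume()`, the next `choose_server()` call re-reads the subtree and round-robins over the updated list. This reading is inferred from the diff; the ordering is not enforced anywhere in the code shown here.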
13 changes: 11 additions & 2 deletions areal/engine/fsdp_engine.py
@@ -83,6 +83,8 @@ def __init__(self, config: TrainEngineConfig):
self.rank: int
self.dp_head: int
self.dp_rank: int
self.scaling_count = 0
self.create_group_count = 0

@property
def data_parallel_group(self) -> dist.ProcessGroup:
@@ -376,14 +378,21 @@ def _init_weight_update_from_distributed(self, meta: WeightUpdateMeta):
)
self.weight_update_group = init_custom_process_group(
backend=current_platform.communication_backend,
world_size=meta.alloc_mode.gen.world_size + 1,
world_size=meta.alloc_mode.gen.world_size + 1 + self.scaling_count,
init_method=f"tcp://{meta.nccl_master_address}:{meta.nccl_master_port}",
rank=0,
group_name=meta.nccl_group_name,
group_name=meta.nccl_group_name + str(self.create_group_count),
timeout=DIST_GROUP_DEFAULT_TIMEOUT,
)

self.create_group_count += 1
fut.result()
self.rollout_engine._engine.backend.create_group_count += 1

def _re_init_weight_update_from_distributed(self, meta: WeightUpdateMeta):
self.weight_update_group_initialized = False
self._init_weight_update_from_distributed(meta)
self.weight_update_group_initialized = True

@trace_perf("fsdp_engine.update_weights_from_distributed", category="comm")
def _update_weights_from_distributed(self, meta: WeightUpdateMeta):
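The two counters are what make re-initialization after a rescale workable: every participant must present the same `world_size` and a process-group name that has not been used before, so the trainer here and the vLLM backend in the next file both track `scaling_count` and `create_group_count`. Bumping `self.rollout_engine._engine.backend.create_group_count` right after the trainer's own increment is what keeps the two sides in lockstep. A worked example of the arithmetic, with made-up values:

```python
# Illustrative values only; nothing below is computed by the PR itself.
base_gen_world_size = 4   # meta.alloc_mode.gen.world_size at launch
scaling_count = 2         # two inference ranks added by a scale-up
create_group_count = 1    # one group already created (and torn down)

world_size = base_gen_world_size + 1 + scaling_count  # +1 for trainer rank 0
group_name = "weight_update_group" + str(create_group_count)  # assumed base name
assert (world_size, group_name) == (7, "weight_update_group1")
# Every rank must derive the same pair, or the NCCL rendezvous hangs.
```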
11 changes: 8 additions & 3 deletions areal/engine/vllm_remote.py
@@ -29,6 +29,10 @@
class VLLMBackend:
"""vLLM-specific backend implementation for remote inference."""

def __init__(self):
self.scaling_count = 0
self.create_group_count = 0

def build_generation_request(
self, req: ModelRequest, with_lora: bool
) -> HttpRequest:
@@ -106,7 +110,8 @@ def build_distributed_weight_update_requests(
"names": [pspec.name for pspec in param_specs],
"dtypes": [pspec.dtype for pspec in param_specs],
"shapes": [pspec.shape for pspec in param_specs],
"group_name": meta.nccl_group_name,
"group_name": meta.nccl_group_name
+ str(self.create_group_count),
},
),
HttpRequest(
@@ -128,9 +133,9 @@ def build_init_weights_group_request(
"master_address": meta.nccl_master_address,
"master_port": str(meta.nccl_master_port),
"rank_offset": rank_offset,
"world_size": meta.alloc_mode.gen.world_size + 1,
"world_size": meta.alloc_mode.gen.world_size + 1 + self.scaling_count,
"backend": current_platform.communication_backend,
"group_name": meta.nccl_group_name,
"group_name": meta.nccl_group_name + str(self.create_group_count),
}
return HttpRequest(endpoint="/areal_init_weights_update_group", payload=payload)

32 changes: 29 additions & 3 deletions areal/launcher/ray.py
@@ -1,6 +1,8 @@
import importlib.util
import os
import pathlib
import re
import subprocess
import sys
import time
from collections.abc import Callable
@@ -18,6 +20,7 @@
ClusterSpecConfig,
LauncherConfig,
RecoverConfig,
ScalingConfig,
SGLangConfig,
parse_cli_args,
to_structured_cfg,
@@ -41,7 +44,23 @@
RAY_WAIT_CHECK_TIME_INTERVAL = 5 # seconds
DEFAULT_MAIN_FUNC_NAME = "main"
RAY_LAUNCHER = None
RECOVER_TIME_INTERVAL = 10 # seconds


def launch_scale_common(config_path: str):
"""Launch scale_common.py as a background subprocess without blocking."""
script_path = str(
pathlib.Path(__file__).resolve().parent.joinpath("scaler/scaling_controller.py")
)
env = os.environ.copy()
env["PYTHONUNBUFFERED"] = "1"
subprocess.Popen(
[sys.executable, script_path, config_path],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
env=env,
start_new_session=True,
)


def run_func(file_path, function_name, *args, **kwargs):
@@ -334,8 +353,15 @@ def wait(


def main():
ray.init()
config, _ = parse_cli_args(sys.argv[1:])
ray.init(address="auto")
config, config_file = parse_cli_args(sys.argv[1:])
config.scaling = to_structured_cfg(config.scaling, ScalingConfig)
    # Launch the scaling controller only when scaling is enabled
if config.scaling.enable_scaling:
try:
launch_scale_common(str(config_file))
except Exception as e:
logger.info(f"[RayLauncher] Warning: Failed to scaler.py: {e}")
ray_main(config, run_id=0)


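`areal/launcher/scaler/scaling_controller.py` itself is not part of this diff. From `ScalingConfig` it evidently serves an HTTP endpoint on `scaling_controller_port` (default 8899), which an operator would hit in `mode="manual"`. The route and payload below are entirely hypothetical, shown only to illustrate how the port setting is meant to be consumed; check `scaling_controller.py` for the real API:

```python
import requests

# Hypothetical request; the real route and payload live in scaling_controller.py.
resp = requests.post(
    "http://localhost:8899/scale_up",  # scaling.scaling_controller_port
    json={"num_new_servers": 1},
    timeout=10,
)
resp.raise_for_status()
```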