5 changes: 1 addition & 4 deletions .github/workflows/docs.yml
@@ -31,10 +31,7 @@ jobs:
         run: python -m pip install --upgrade pip
       - name: Install torchforge
         shell: bash -l {0}
-        run: ./scripts/install.sh
-      - name: Install docs dependencies
-        shell: bash -l {0}
-        run: python -m pip install -r docs/requirements.txt
+        run: pip install uv && uv pip install . && uv pip install .[docs]
       - name: Build docs
         shell: bash -l {0}
         working-directory: docs
2 changes: 1 addition & 1 deletion .github/workflows/gpu_test.yaml
@@ -38,7 +38,7 @@ jobs:
       - name: Update pip
         run: python -m pip install --upgrade pip
       - name: Install torchforge
-        run: ./scripts/install.sh
+        run: pip install uv && uv pip install . && uv pip install .[dev]
       - name: Run unit tests with coverage
         # TODO add all tests
         run: pytest tests/unit_tests --cov=. --cov-report=xml --durations=20 -vv
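For reference, the new CI sequence can be reproduced locally with the same two commands (a sketch, assuming a CUDA-capable machine, since this workflow runs on GPU runners):

```bash
# Mirror the workflow's install and test steps.
pip install uv && uv pip install . && uv pip install .[dev]
pytest tests/unit_tests --cov=. --cov-report=xml --durations=20 -vv
```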
22 changes: 21 additions & 1 deletion README.md
@@ -32,7 +32,20 @@ You can also find our notebook tutorials (coming soon)
 
 ### Basic
 
-torchforge requires PyTorch 2.9.0 with [Monarch](https://github.com/meta-pytorch/monarch), [vLLM](https://github.com/vllm-project/vllm), and [torchtitan](https://github.com/pytorch/torchtitan). (Note that the basic install script
+torchforge requires PyTorch 2.9.0 with [Monarch](https://github.com/meta-pytorch/monarch), [vLLM](https://github.com/vllm-project/vllm), and [torchtitan](https://github.com/pytorch/torchtitan).
+
+You can install Forge with:
+```
+$ conda create -n forge python=3.10
+$ conda activate forge
+$ uv pip install .
+```
+
+(conda-less uv install is a wip)
+
+For your reference, we also include a basic install script that installs other system dependencies
+along with torchforge:
+(note that this basic install script
 uses [DNF](https://docs.fedoraproject.org/en-US/quick-docs/dnf/), but could be easily extended to other Linux OS.)
 
 ```bash
@@ -45,6 +58,13 @@ Optional: By default, the packages installation uses conda. If user wants to ins
 
 After install, you can run the following command and should see output confirming GRPO training is running (you need a minimum 3 GPU devices):
 
+
+```
+uv run apps/grpo/main.py --config apps/grpo/qwen3_1_7b.yaml
+```
+
+or if not using uv:
+
 ```
 python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
 ```
2 changes: 1 addition & 1 deletion apps/grpo/main.py
@@ -465,7 +465,7 @@ async def continuous_training():
     except KeyboardInterrupt:
         print("Training interrupted by user")
     finally:
-        print("Shutting down...")
+        print("Shutting down... (this may take a few seconds)")
         shutdown_event.set()
 
         try:
Binary file not shown.
2 changes: 1 addition & 1 deletion assets/versions.sh
@@ -12,7 +12,7 @@ PYTORCH_VERSION="2.9.0"
 VLLM_VERSION="v0.10.0"
 MONARCH_VERSION="0.1.2"
 TORCHTITAN_VERSION="0.2.0"
-TORCHSTORE_VERSION="0.1.1"
+TORCHSTORE_VERSION="0.1.2"
 
 # Torchtitan commit hash for launching on MAST
 TORCHTITAN_COMMIT_MAST="d0e25450bcac2332359b13fbda430dc701f073d4"
9 changes: 0 additions & 9 deletions docs/requirements.txt
This file was deleted.
43 changes: 24 additions & 19 deletions pyproject.toml
@@ -11,18 +11,22 @@ authors = [
 keywords = ["pytorch", "training", "llm"]
 dependencies = [
     # PyTorch
     "torch==2.9.0",
     "torchdata>=0.8.0",
-    "torchtitan",
+    "torchtitan==0.2.0",
+    "torchmonarch==0.1.2",
+    "torchstore==0.1.2",
     # vLLM
-    # TODO: pin specific vllm version
     #"vllm==0.10.0",
     "vllm",
     # Hugging Face integrations
     "datasets>=2.21.0",
     "tokenizers",
     # Miscellaneous
     "omegaconf",
     "wandb",
+    "hf_transfer",
+    "six",
+    "setuptools<80",
 ]
 dynamic = ["version"]
 
@@ -44,10 +48,16 @@ dev = [
     "pytest-asyncio",
     "multiprocess",
 ]
-oss = [
-    "torch",
-    "torchmonarch-nightly==2025.8.1",
-    "torchstore",
+docs = [
+    "sphinx==7.2.6",
+    "pytorch-sphinx-theme2==0.1.0",
+    "docutils>=0.18.1,<0.21",
+    "sphinx-design==0.6.1",
+    "sphinxcontrib-mermaid==1.0.0",
+    "sphinx-gallery==0.19.0",
+    "myst-parser",
+    "sphinx-sitemap==2.7.1",
+    "sphinx-autodoc-typehints==1.25.3",
 ]
 
 # ---- Explicit project build information ---- #
@@ -69,23 +79,18 @@ members = [
 ]
 
 # pytorch
-# TODO: get auto backend to work
 [[tool.uv.index]]
-name = "pytorch-nightly-cu129"
-url = "https://download.pytorch.org/whl/nightly/cu129"
-#explicit = true
+name = "pytorch-cu128"
+url = "https://download.pytorch.org/whl/cu128"
 
 # vllm
-# [[tool.uv.index]]
-# name = "vllm-nightly"
-# url = "https://wheels.vllm.ai/nightly"
-# explicit = true
+[[tool.uv.index]]
+name = "vllm-forge"
+url = "https://download.pytorch.org/whl/preview/forge"
 
 [tool.uv.sources]
-torchtitan = { index = "pytorch-nightly-cu129" }
-torch = { index = "pytorch-nightly-cu129" }
-torchstore = { git = "ssh://[email protected]/meta-pytorch/torchstore.git" }
-#vllm = { index = "vllm-nightly" }
+torch = { index = "pytorch-cu128" }
+vllm = { index = "vllm-forge" }
 
 [tool.uv]
 # TODO: revert to stricter default uv strategy
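Taken together, the `dev` and `docs` extras above are exactly what the two workflow changes in this PR install. A quick way to exercise the new resolution locally (a sketch, assuming only `pip` and a Python 3.10 environment as in the README):

```bash
# Install uv, then the core package plus each extra used by CI.
pip install uv
uv pip install .          # torch resolves from pytorch-cu128, vllm from vllm-forge
uv pip install ".[dev]"   # test dependencies used by gpu_test.yaml
uv pip install ".[docs]"  # sphinx toolchain used by docs.yml
```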
80 changes: 55 additions & 25 deletions src/forge/controller/provisioner.py
@@ -6,7 +6,6 @@
 
 """Remote and local resource manager for allocation and provisioning."""
 import asyncio
-import functools
 import logging
 
 import os
@@ -19,7 +18,6 @@
 from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host
 
 from monarch.tools import commands
-
 from monarch.utils import setup_env_for_distributed
 
 from forge.controller.launcher import BaseLauncher, get_launcher
@@ -46,6 +44,39 @@ def get_info(self) -> tuple[str, str]:
         return socket.gethostname(), _get_port()
 
 
+class EnvSetter(Actor):
+    """Actor to set environment variables on each proc in a mesh.
+
+    Ideally, this is handled in spawn_procs's bootstrap call, which
+    essentially does the same thing as we're doing here.
+
+    However, Monarch's SetupActor currently fails to stop on shutdown,
+    which leads to zombie messages sent to the SetupActor. This is a
+    known issue, and we will move back to bootstrap once it's fixed.
+
+    We are able to avoid this here by properly awaiting the spawning
+    of the actor.
+
+    """
+
+    @endpoint
+    def set_env(self, env_vars: dict[str, str]):
+        """Set environment variables on this proc.
+
+        Args:
+            env_vars: Dictionary of environment variables to set
+        """
+        import os
+        import socket
+
+        # Set VLLM_HOST_IP (required for vLLM on multiple nodes)
+        os.environ["VLLM_HOST_IP"] = socket.gethostbyname(socket.getfqdn())
+
+        # Set user-provided environment variables
+        for k, v in env_vars.items():
+            os.environ[k] = v
+
+
 async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
     """Returns the host name and port of the host mesh."""
     throwaway_procs = host_mesh.spawn_procs(per_host={"procs": 1})
@@ -64,6 +95,20 @@ async def get_remote_info(host_mesh: HostMesh) -> tuple[str, str]:
     return host, port
 
 
+async def set_environment(proc_mesh: ProcMesh, env_vars: dict[str, str]):
+    """Set environment variables on a proc mesh using the EnvSetter actor.
+
+    This replaces the old bootstrap approach to avoid Monarch's SetupActor
+    mesh failures on shutdown.
+
+    Args:
+        proc_mesh: The proc mesh to set environment variables on
+        env_vars: Dictionary of environment variables to set
+    """
+    env_setter = proc_mesh.spawn("_env_setter", EnvSetter)
+    await env_setter.set_env.call(env_vars)
+
+
 class GpuManager:
     """Tracks and assigns GPU devices on a host.
 
@@ -244,26 +289,6 @@ async def get_proc_mesh(
             gpu_manager = self._host_gpu_map[self._this_host_id]
             host_mesh._host_id = self._this_host_id
 
-        def bootstrap(env: dict[str, str]):
-            """Runs on process startup.
-
-            We use this to set environment variables like CUDA, etc.
-            We prefer to pass in environment variables to bootstrap,
-            but there are occasionally host-specific environments that can
-            only be set once the process is alive on remote hosts.
-
-            """
-            # bootstrap is run on all processes. We use this
-            # to set environment variables like CUDA etc.
-            import os
-
-            # vLLM requires this environment variable when spawning model servers
-            # across multiple nodes.
-            os.environ["VLLM_HOST_IP"] = socket.gethostbyname(socket.getfqdn())
-
-            for k, v in env.items():
-                os.environ[k] = v
-
         if with_gpus:
             if not addr or not port:
                 addr, port = await get_remote_info(host_mesh)
@@ -281,17 +306,22 @@
         for env_var in all_env_vars():
             env_vars[env_var.name] = str(env_var.get_value())
 
+        # Spawn procs without bootstrap to avoid SetupActor mesh failures
        procs = host_mesh.spawn_procs(
            per_host={"procs": num_procs},
-           bootstrap=functools.partial(bootstrap, env=env_vars),
            name=mesh_name,
        )
+
+        # Set up environment variables (replaces old bootstrap)
+        if env_vars:
+            await set_environment(procs, env_vars)
 
+        # Set up PyTorch distributed environment if using GPUs
        if with_gpus:
-            # Set up environment variables for PyTorch distributed...
            await setup_env_for_distributed(
                procs,
                master_addr=addr,
-               master_port=port,
+               master_port=int(port),
            )
 
        if is_remote: