Add pp_size support to evaluation harness and local_test examples #59
shivam-MBZUAI wants to merge 26 commits into
Walkthrough
Add pipeline parallelism (`pp_size`) across the evaluator and local trainers, implement a new pipeline-parallel trainer, shift pre-warmup trainability checks with cross-rank sync, update multi-GPU topology/diagnostics/MFU logic, and revise docs, Docker NCCL envs, and tokenizer handling.
Sequence Diagram(s)
sequenceDiagram
participant Stage0 as Stage 0 (pp rank 0)
participant Stage1 as Stage 1 (pp rank 1)
participant DP as DP subgroup
participant Optimizer as Optimizer
Note over Stage0,Stage1: Preload batches & compute total_tokens
Stage0->>Stage1: send(activations)
Stage1->>Stage1: forward -> logits -> loss
Stage1->>Stage0: send(hidden.grad)
Stage0->>Stage0: backward using received grad
Stage0->>DP: all_reduce(grads)
DP->>Optimizer: averaged grads
Optimizer->>Stage0: step & zero_grad
Optimizer->>Stage1: step & zero_grad
Stage1->>Stage0: send(final_loss, final_logits)
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
environments/templar/env.py (1)
1854-1877: ⚠️ Potential issue | 🔴 Critical
`pp_size` is accepted here, but verification is still rank-0-centric.
Lines 1854-1877 update GPU accounting and data sharding for PP, but `evaluate()` later still validates only the `InnerStepsResult` coming from global rank 0. In the PP topology introduced by this PR, rank 0 is stage 0 and never owns the terminal logits/loss, so PP submissions are accepted here and then fail warmup/output verification downstream. The harness needs a stage-aware verification rank or an out-of-band gather of the last-stage result before `pp_size > 1` can work end-to-end.
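A stage-aware verification rank could be chosen along the following lines. This is only a sketch: it assumes the layout implied elsewhere in this review (stage = `rank % pp_size`), and `stage_of` and `verification_rank` are hypothetical helper names, not functions from `env.py`.

```python
def stage_of(rank: int, pp_size: int) -> int:
    """Pipeline stage of a global rank, assuming stage = rank % pp_size."""
    return rank % pp_size


def verification_rank(num_gpus: int, pp_size: int) -> int:
    """Lowest global rank that sits on the last pipeline stage."""
    if pp_size < 1 or num_gpus < 1 or num_gpus % pp_size != 0:
        raise ValueError("num_gpus must be a positive multiple of pp_size")
    last_stage = pp_size - 1
    for rank in range(num_gpus):
        if stage_of(rank, pp_size) == last_stage:
            return rank
    raise ValueError("no rank found on the last stage")
```

With `pp_size == 1` this falls back to rank 0, so non-PP submissions would be verified exactly as before.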
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@local_test/train_pp.py`:
- Around line 202-227: _gather_pp_state() uses world-wide collectives with
differing src ranks across data-parallel replicas causing deadlock; fix by
creating and using a per-pipeline process group so the broadcasts/receives run
only among pipeline peers. In _gather_pp_state(), compute the pipeline peer
ranks (based on dp/pp decomposition using pp_size/dp_size or existing
pp_peer/global_rank logic), create a pipeline_group via
torch.distributed.new_group(pipeline_ranks), and call
dist.broadcast_object_list(..., group=pipeline_group, src=...) for the
keys_and_shapes exchanges and the peer_obj exchange so every rank in that
subgroup uses the same src; keep dist.recv for the actual tensor data but ensure
the matching sends/receives occur inside the same pipeline subgroup (or
alternatively restrict the assembly to a single replica like global_rank==0 if
you prefer that simpler approach).
- Line 57: The code uses the forbidden built-in __import__ when setting
local_rank; replace that usage by adding a standard top-level import os and then
use os.environ to read LOCAL_RANK (i.e., change the local_rank assignment that
references __import__("os") to use os.environ.get("LOCAL_RANK", "0")). Ensure
you add the import os at module scope so only os.environ is used in the
local_rank initialization.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 12d7976d-d9e8-46a0-9d36-eb0fa84d5e7b
📒 Files selected for processing (7)
- environments/templar/env.py
- local_test/simulate_validator.py
- local_test/train_ddp.py
- local_test/train_fsdp.py
- local_test/train_mixed.py
- local_test/train_pp.py
- local_test/train_tp.py
Actionable comments posted: 1
🧹 Nitpick comments (2)
local_test/train_pp.py (2)
107-107: Unused loop variable `step`.
Per static analysis, `step` is not used within the loop body.
Rename to underscore prefix
- for step in range(num_steps):
+ for _step in range(num_steps):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` at line 107, The loop variable "step" in the for loop "for step in range(num_steps):" is unused; rename it to "_" to indicate an intentionally ignored loop variable (i.e., change the loop header in train_pp.py to use "_" instead of "step") so static analysis no longer flags an unused variable and behavior remains unchanged.
243-244: Unused loop variable `k`.
Per static analysis, `k` is not used in the loop body; only `v` is sent.
Rename to underscore prefix
- for k, v in my_state.items():
+ for _k, v in my_state.items():
      dist.send(v.contiguous(), dst=pp_peer)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 243 - 244, The loop over my_state currently binds an unused key variable k; update the iteration to avoid the unused binding by iterating over values only or prefixing the key as unused: change the loop that reads "for k, v in my_state.items():" to either iterate values with my_state.values() or use "for _, v in my_state.items():" so that only v is used when calling dist.send(v.contiguous(), dst=pp_peer); keep the call to dist.send, v.contiguous(), and pp_peer unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@local_test/train_pp.py`:
- Around line 129-132: The hidden tensor is being created with a hardcoded
dtype=torch.bfloat16 which can mismatch the model embeddings; change the
creation of hidden to use the model/embedding dtype instead (e.g., infer dtype
from the embedding layer or model parameters via something like
embedding.weight.dtype or next(model.parameters()).dtype) and keep device
consistent before calling dist.recv and subsequent
.detach().requires_grad_(True) on hidden; update the torch.zeros call that
initializes hidden to use the inferred dtype rather than torch.bfloat16 so recv
and subsequent ops match the model's actual dtype.
📒 Files selected for processing (2)
- environments/templar/env.py
- local_test/train_pp.py
🚧 Files skipped from review as they are similar to previous changes (1)
- environments/templar/env.py
hidden = torch.zeros(bs, seq_len, hidden_size, device=device, dtype=torch.bfloat16)
hidden.requires_grad_(True)
dist.recv(hidden, src=pp_peer)
hidden = hidden.detach().requires_grad_(True)
Hardcoded bfloat16 dtype may cause mismatch.
The `hidden` tensor is created with `dtype=torch.bfloat16`, but the model's embedding layer may output a different dtype (e.g., float16 or float32). Consider inferring the dtype from the model's config or embedding output.
Suggested fix
+ model_dtype = next(model.parameters()).dtype
- hidden = torch.zeros(bs, seq_len, hidden_size, device=device, dtype=torch.bfloat16)
+ hidden = torch.zeros(bs, seq_len, hidden_size, device=device, dtype=model_dtype)
38568ea to ab6942e
Actionable comments posted: 4
♻️ Duplicate comments (2)
local_test/train_pp.py (2)
216-244: ⚠️ Potential issue | 🔴 Critical
`_gather_pp_state()` is invalid for the evaluator's NCCL topology.
`environments/templar/env.py` initializes NCCL at Line 1996. This helper then (1) issues world-group `broadcast_object_list()` calls with different `src` ranks per replica and (2) sends CPU tensors created via `.cpu().clone()` through `dist.send`/`dist.recv`. The first can deadlock, and the second is not supported on an NCCL process group. Restrict state assembly to one pipeline replica or a dedicated pipeline subgroup, and keep tensor payloads on CUDA until after transfer.
With the NCCL backend, are CPU tensors supported by `torch.distributed.send/recv`, and do all ranks in the same process group need to call `broadcast_object_list()` in the same order with the same source rank?
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 216 - 244, _gather_pp_state() currently uses world-group broadcast_object_list() with varying src and sends CPU tensors via dist.send/recv causing deadlocks and NCCL errors; change it to perform object-list exchanges only within a single pipeline replica or a dedicated pipeline process_group (create a subgroup process group for the ranks participating in this pipeline) so all ranks in that group call dist.broadcast_object_list() with the same src and ordering, and when transferring tensors keep them on CUDA (create CUDA buffers with torch.empty(..., device='cuda', dtype=...) and use contiguous CUDA tensors) before calling dist.send/dist.recv or use dist.send/recv on the subgroup process group; update references in _gather_pp_state(), my_state population, the is_first_stage/else branches, and use pp_peer only inside the subgroup to avoid world-group mismatches.
24-28: ⚠️ Potential issue | 🔴 Critical
Replace `__import__("os")`; the validator rejects it.
Line 60 uses a bare `__import__`, and `environments/templar/env.py::_scan_for_dangerous_patterns()` rejects that name during AST validation. This file never reaches `inner_steps()` as written.
Suggested fix
+import os
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
@@
- local_rank = int(__import__("os").environ.get("LOCAL_RANK", "0"))
+ local_rank = int(os.environ.get("LOCAL_RANK", "0"))
Also applies to: 59-60
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 24 - 28, The code uses a bare dynamic import __import__("os") (referenced near inner_steps() and lines 59-60), which fails AST validation; replace that dynamic import with a static import (add import os at the top of the file) or, if dynamic behavior is required, use importlib.import_module("os") with an explicit importlib import; update any references that expect the __import__ return accordingly so inner_steps() and related functions use the module via the static name (os) or the importlib result.
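To see why a bare `__import__` trips an AST-level scan while `import os` does not, a minimal sketch may help. The deny-list `EXAMPLE_FORBIDDEN` and the function `find_forbidden_names` are illustrative assumptions; the validator's actual `_scan_for_dangerous_patterns()` and its list of names are not shown in this review.

```python
import ast

# Assumption: an illustrative deny-list, not the validator's real one.
EXAMPLE_FORBIDDEN = {"__import__", "eval", "exec"}


def find_forbidden_names(source: str) -> list:
    """Return sorted (line, name) pairs for every bare reference to a forbidden name."""
    hits = []
    for node in ast.walk(ast.parse(source)):
        # A call such as __import__("os") contains an ast.Name node with id "__import__".
        if isinstance(node, ast.Name) and node.id in EXAMPLE_FORBIDDEN:
            hits.append((node.lineno, node.id))
    return sorted(hits)


bad = 'local_rank = int(__import__("os").environ.get("LOCAL_RANK", "0"))'
good = 'import os\nlocal_rank = int(os.environ.get("LOCAL_RANK", "0"))'
```

Under these assumptions, `find_forbidden_names(bad)` reports one hit on line 1 and `find_forbidden_names(good)` reports none, which matches the behavior the comment describes.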
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@environments/templar/Dockerfile`:
- Around line 27-34: The Dockerfile currently always calls
AutoTokenizer.from_pretrained(model) even when an override exists; change the
logic to select a single tokenizer source (e.g., tok_source = tokenizer_name or
model) and call AutoTokenizer.from_pretrained(tok_source,
trust_remote_code=True) only once so the override (tokenizer_name from cfg) is
used when provided and the unconditional download of
AutoTokenizer.from_pretrained(model) is removed; update the block around
tokenizer_name, AutoTokenizer.from_pretrained, and the print statements to use
tok_source and avoid double downloads.
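The single-source selection described for the Dockerfile can be sketched as a small helper. `pick_tokenizer_source` is a hypothetical name; the returned value would then be passed once to `AutoTokenizer.from_pretrained(...)`.

```python
from typing import Optional


def pick_tokenizer_source(model: str, tokenizer_name: Optional[str]) -> str:
    """Prefer the configured tokenizer override; fall back to the model id."""
    override = (tokenizer_name or "").strip()
    return override or model
```

Treating a blank or whitespace-only override as "not set" is a design assumption here; it avoids passing an empty string to the tokenizer loader.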
In `@environments/templar/env.py`:
- Around line 374-385: The code around the try/except that reads hparams.json
(setting _CACHE["_hparams_tokenizer_name_resolved"] and variable name) is
silently swallowing all exceptions; change the except block to catch Exception
as e and log a clear error including the file path and exception (e) and stack
(use logger.exception or logger.error with exc_info=True) so failures are
visible, then either re-raise the exception to fail fast or set a clear sentinel
in _CACHE and return/continue accordingly; update the block that returns name
(and the _CACHE entry) to only do so when the file was parsed successfully.
- Around line 1352-1397: The code that reloads models directly via _load_model()
(before calling load_state_dict(initial_state)) must use the same
tokenizer-aware resizing logic as _get_cached_model to avoid shape mismatches;
replace those direct _load_model(...) calls with a call to
_get_cached_model(model_path, use_random_init=...) or, after _load_model(...),
invoke the tokenizer-resize steps identical to _get_cached_model (check
tokenizer_name from _get_hparams_tokenizer_name(), compare vocab sizes, and call
model.resize_token_embeddings) before saving/using initial_state and calling
load_state_dict(initial_state).
In `@local_test/train_pp.py`:
- Around line 87-90: The current logic constructs dp_group by calling
dist.new_group() only for the local rank's DP subset (using dp_ranks computed
from num_gpus, pp_size, pp_rank), causing different ranks to call new_group() in
different orders and hang; fix by deterministically creating all DP subgroups in
the same order on every rank (e.g., loop r0 in range(pp_size) to build each
group's rank list from num_gpus/pp_size and call dist.new_group(...) for every
subgroup), store each created subgroup in a list, and then select the subgroup
for this rank into dp_group (still handle the dp_size==1 case by using the
world_group or None as before); update code around dp_size, dp_ranks,
dist.new_group, and dp_group to follow this deterministic subgroup creation
pattern.
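The deterministic ordering asked for in the last item can be checked without any GPUs by computing the rank lists up front. In this sketch, `dp_rank_lists` is a hypothetical helper; each rank would loop over its result, call `dist.new_group(ranks)` for every entry, and keep the group at index `pp_rank`.

```python
def dp_rank_lists(num_gpus: int, pp_size: int) -> list:
    """One data-parallel rank list per pipeline stage, in stage order."""
    if pp_size < 1 or num_gpus < 1 or num_gpus % pp_size != 0:
        raise ValueError("num_gpus must be a positive multiple of pp_size")
    return [
        [r for r in range(num_gpus) if r % pp_size == stage]
        for stage in range(pp_size)
    ]
```

Because the result depends only on `num_gpus` and `pp_size`, every rank derives the same sequence of lists and therefore issues the same sequence of group-creation calls.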
📒 Files selected for processing (10)
- environments/templar/Dockerfile
- environments/templar/env.py
- local_test/setup_benchmark.py
- local_test/simulate_validator.py
- local_test/train_ddp.py
- local_test/train_fsdp.py
- local_test/train_mixed.py
- local_test/train_pp.py
- local_test/train_tp.py
- src/crusades/config.py
🚧 Files skipped from review as they are similar to previous changes (5)
- local_test/simulate_validator.py
- local_test/train_mixed.py
- local_test/train_fsdp.py
- local_test/train_ddp.py
- local_test/train_tp.py
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
environments/templar/env.py (1)
243-258: ⚠️ Potential issue | 🟠 Major
Reject invalid explicit `pp_size` values instead of defaulting them to `1`.
At Line 256, any explicit `pp_size` that is not a literal `int` gets dropped and treated as "omitted". `{"dp_size": 2, "tp_size": 1, "pp_size": -1}` or `{"dp_size": 2, "tp_size": 1, "pp_size": some_var}` therefore parses as `(2, 1, 1)`, so the validator can shard and verify under the wrong topology instead of failing fast.
🔧 Proposed fix
  if isinstance(stmt.value, ast.Dict):
      d: dict[str, int] = {}
+     known_keys = {"dp_size", "tp_size", "pp_size"}
      for k, v in zip(stmt.value.keys, stmt.value.values):
-         if (
-             isinstance(k, ast.Constant)
-             and isinstance(k.value, str)
-             and isinstance(v, ast.Constant)
-             and isinstance(v.value, int)
-         ):
-             d[k.value] = v.value
+         if not (isinstance(k, ast.Constant) and isinstance(k.value, str)):
+             continue
+         if k.value in known_keys and not (
+             isinstance(v, ast.Constant) and type(v.value) is int
+         ):
+             return None
+         if isinstance(v, ast.Constant) and type(v.value) is int:
+             d[k.value] = v.value
      if "dp_size" in d and "tp_size" in d:
          dp = d["dp_size"]
          tp = d["tp_size"]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@environments/templar/env.py` around lines 243 - 258, The current AST dict parsing in environments/templar/env.py wrongly treats an explicitly-present but non-literal or invalid "pp_size" as if it were omitted (defaulting to 1); update the handling in the block that constructs ParallelismConfig so that you detect whether "pp_size" key exists in the original AST (i.e., check presence before defaulting), require that if present it must be an ast.Constant with an int value >= 1, and if it is present but invalid (non-constant or <1) abort/skip returning a ParallelismConfig rather than silently using pp_size=1; reference the dict-parsing code that currently sets dp, tp, pp and the ParallelismConfig(...) return.
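The fail-fast behavior requested here can be exercised in isolation with a small standalone parser. This is a sketch under stated assumptions: `parse_parallelism` is a hypothetical function that takes a dict literal as source text and returns a plain tuple, whereas the real code walks an assignment statement and builds a `ParallelismConfig`.

```python
import ast
from typing import Optional, Tuple

KNOWN_KEYS = ("dp_size", "tp_size", "pp_size")


def parse_parallelism(source: str) -> Optional[Tuple[int, int, int]]:
    """Parse a dict literal; return (dp, tp, pp), or None if any known key is invalid."""
    try:
        node = ast.parse(source, mode="eval").body
    except SyntaxError:
        return None
    if not isinstance(node, ast.Dict):
        return None
    found = {}
    for k, v in zip(node.keys, node.values):
        if not (isinstance(k, ast.Constant) and k.value in KNOWN_KEYS):
            continue  # unknown keys are ignored
        # `type(...) is int` excludes bool; a negative literal parses as
        # ast.UnaryOp rather than ast.Constant, so it is rejected here too.
        if not (isinstance(v, ast.Constant) and type(v.value) is int and v.value >= 1):
            return None
        found[k.value] = v.value
    if "dp_size" not in found or "tp_size" not in found:
        return None
    # pp_size defaults to 1 only when the key is genuinely absent.
    return found["dp_size"], found["tp_size"], found.get("pp_size", 1)
```

The key point is the distinction between "key absent" (default applies) and "key present but invalid" (the whole declaration is rejected).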
♻️ Duplicate comments (1)
environments/templar/env.py (1)
374-387: ⚠️ Potential issue | 🟡 Minor
Don't hide `hparams.json` read/parse failures.
If `benchmark_tokenizer_name` cannot be read here, the code silently falls back to the model tokenizer and then caches that fallback at Line 387. That turns a broken benchmark config into a hard-to-debug tokenizer/model mismatch for the rest of the worker lifetime. Log the exception and avoid caching the empty sentinel on parse failures.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@environments/templar/env.py` around lines 374 - 387, The current try/except around reading hparams.json swallows all exceptions and the code later writes an empty sentinel to _CACHE["_hparams_tokenizer_name_resolved"], hiding parse/read failures; change the except block to log the failure (use logger.exception or logger.error(..., exc_info=True) and include the file path p) so you capture the error details, and do NOT set the empty sentinel when the exception occurs; only set _CACHE["_hparams_tokenizer_name_resolved"] = "" when the file was successfully read and no benchmark_tokenizer_name was present (i.e., keep the successful-path assignments in place but remove or guard the final empty-cache write on parse/read exceptions).
📒 Files selected for processing (3)
- environments/templar/Dockerfile
- environments/templar/env.py
- local_test/simulate_validator.py
✅ Files skipped from review due to trivial changes (1)
- local_test/simulate_validator.py
🚧 Files skipped from review as they are similar to previous changes (1)
- environments/templar/Dockerfile
ea50043 to 8b5f134
Actionable comments posted: 1
♻️ Duplicate comments (3)
local_test/train_pp.py (3)
60-60: ⚠️ Potential issue | 🔴 Critical
Replace `__import__("os")` with standard `import os`.
Line 60 uses `__import__("os")`, which is in `FORBIDDEN_NAMES` and will cause the validator's security scan to reject this code.
+import os
 import torch
 import torch.distributed as dist
 import torch.nn.functional as F
Then change line 60:
- local_rank = int(__import__("os").environ.get("LOCAL_RANK", "0"))
+ local_rank = int(os.environ.get("LOCAL_RANK", "0"))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` at line 60, Replace the dynamic import usage at the local_rank assignment: remove __import__("os") and add a top-level standard import os, then use os.environ.get("LOCAL_RANK", "0") when computing local_rank (the variable assignment referencing local_rank should be updated); ensure the file-level import is added instead of using the forbidden __import__ call.
216-241: ⚠️ Potential issue | 🔴 Critical
`_gather_pp_state()` broadcasts with mismatched `src` values and will deadlock when `dp_size > 1`.
Lines 221, 224, 235, and 241 use `broadcast_object_list()` on the default world group with `src` values that differ across data-parallel replicas. With the declared topology (`dp_size=2, pp_size=2`), this creates invalid collective sequences that will hang.
Fix by creating a per-pipeline process group for state gathering, or restrict state assembly to a single DP replica:
+ # Create pipeline-local groups for state gathering. Every rank must call
+ # new_group() for every pipeline, in the same order.
+ pipeline_id = global_rank // pp_size  # 0 for ranks 0,1; 1 for ranks 2,3
+ pp_group = None
+ for pid in range(num_gpus // pp_size):
+     group = dist.new_group([pid * pp_size + s for s in range(pp_size)])
+     if pid == pipeline_id:
+         pp_group = group
+
  if is_first_stage:
      keys_and_shapes = []
      for k, v in my_state.items():
          keys_and_shapes.append((k, v.shape, v.dtype))
      obj_list = [keys_and_shapes]
-     dist.broadcast_object_list(obj_list, src=dist.get_rank())
+     dist.broadcast_object_list(obj_list, src=global_rank, group=pp_group)
  # ... similar changes for other broadcasts
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 216 - 241, The _gather_pp_state() implementation uses dist.broadcast_object_list() with different src ranks across data-parallel replicas and will deadlock when dp_size>1; fix by either performing the gather only on one DP replica (e.g., guard the whole gather/send/recv block with a check like "if global_rank // pp_size == 0" so only one replica participates) or create a per-pipeline process group and use that group for the collectives (call torch.distributed.new_group(...) with the ranks for each pipeline slice, in the same order on every rank, and pass the matching group to broadcast_object_list/recv via the group argument), and ensure both the sending side (is_first_stage) and receiving side use the same group and src rank (pp_peer within that group) when calling dist.broadcast_object_list and dist.recv so collectives are matched across all processes.
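The per-pipeline grouping above relies on every member of a pipeline agreeing on the same source rank. A pure-Python sketch of that bookkeeping, assuming the contiguous layout from the suggested fix (replica = `rank // pp_size`); `pipeline_rank_lists` and `first_stage_rank` are hypothetical helper names:

```python
def pipeline_rank_lists(num_gpus: int, pp_size: int) -> list:
    """One rank list per pipeline replica, in replica order."""
    if pp_size < 1 or num_gpus < 1 or num_gpus % pp_size != 0:
        raise ValueError("num_gpus must be a positive multiple of pp_size")
    return [
        [pid * pp_size + stage for stage in range(pp_size)]
        for pid in range(num_gpus // pp_size)
    ]


def first_stage_rank(rank: int, pp_size: int) -> int:
    """Global rank of stage 0 in the same pipeline replica as `rank`."""
    return (rank // pp_size) * pp_size
```

Using `first_stage_rank(global_rank, pp_size)` as `src` on every member of a pipeline group would give all of them the same value, which is the property the world-group version lacks.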
87-91: ⚠️ Potential issue | 🔴 Critical
Create all DP groups in the same order on every rank.
PyTorch requires all ranks to participate in `new_group()` calls in the same order. Currently, different ranks call `new_group()` with different rank lists (`[0,2]` vs `[1,3]`), which will hang.
- if dp_size > 1:
-     dp_ranks = [r for r in range(num_gpus) if (r % pp_size) == pp_rank]
-     dp_group = dist.new_group(dp_ranks)
- else:
-     dp_group = None
+ dp_group = None
+ if dp_size > 1:
+     for stage in range(pp_size):
+         ranks = [r for r in range(num_gpus) if (r % pp_size) == stage]
+         group = dist.new_group(ranks)
+         if stage == pp_rank:
+             dp_group = group
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 87 - 91, The current creation of dp_group with dp_ranks = [r for r in range(num_gpus) if (r % pp_size) == pp_rank] causes different ranks to call dist.new_group() with different rank lists and will hang; instead, build all DP groups in the same order on every rank by iterating over all possible pp partitions (e.g., for each i in range(pp_size) compute ranks_i = [r for r in range(num_gpus) if (r % pp_size) == i]), call dist.new_group(ranks_i) for each i so every rank participates in the same sequence of new_group() calls, collect those groups, and then assign dp_group = groups[pp_rank] (or None when dp_size <= 1); ensure the symbols referenced are dp_size, dp_ranks, dp_group, dist.new_group, pp_size, pp_rank, num_gpus.
🧹 Nitpick comments (2)
local_test/train_pp.py (2)
107-107: Rename unused loop variable `step` to `_step`.
Per static analysis, the loop control variable is unused.
- for step in range(num_steps):
+ for _step in range(num_steps):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` at line 107, The loop control variable "step" in the for loop "for step in range(num_steps):" is unused; rename it to "_step" to satisfy the linter and make intent clear (change "for step in range(num_steps):" to "for _step in range(num_steps):"), and verify there are no usages of "step" inside the loop that need to be updated.
243-244: Iterate over values only since `k` is unused.
- for k, v in my_state.items():
-     dist.send(v.contiguous(), dst=pp_peer)
+ for v in my_state.values():
+     dist.send(v.contiguous(), dst=pp_peer)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 243 - 244, The loop iterates over my_state items but never uses the key `k`; change the iteration to loop over values only to avoid unused variable warnings and clarify intent; replace the for k, v in my_state.items() loop with a values-only iteration (e.g., for v in my_state.values()) and keep the body calling dist.send(v.contiguous(), dst=pp_peer) so the send logic (dist.send, v.contiguous(), pp_peer) remains unchanged.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@local_test/train_pp.py`:
- Around line 167-172: The fallback creates a 1D tensor when final_logits is
None which fails the validator's 3D-shape check; update the None case in the
is_last_stage block so final_logits is a properly-shaped 3D zero tensor (match
the expected logits shape used elsewhere, e.g., batch/seq/feature dims) and set
its device and dtype consistently before calling dist.send; adjust the creation
of final_logits in the block that references final_loss/device/pp_peer/dist.send
so you produce a contiguous 3D torch.zeros tensor (or use a stored logits shape
variable if available) instead of torch.zeros(1).
---
Duplicate comments:
In `@local_test/train_pp.py`:
- Line 60: Replace the dynamic import usage at the local_rank assignment: remove
__import__("os") and add a top-level standard import os, then use
os.environ.get("LOCAL_RANK", "0") when computing local_rank (the variable
assignment referencing local_rank should be updated); ensure the file-level
import is added instead of using the forbidden __import__ call.
- Around line 216-241: The _gather_pp_state() implementation uses
dist.broadcast_object_list() with different src ranks across data-parallel
replicas and will deadlock when dp_size>1; fix by either performing the gather
only on one DP replica (e.g., guard the whole gather/send/recv block with a
check like "if global_rank % dp_size == 0" so only one replica participates) or
create a per-pipeline process group and use that group for the collectives (call
torch.distributed.new_group(...) with the ranks for that pipeline slice and pass
the group to broadcast_object_list/recv via the group argument), and ensure both
the sending side (is_first_stage) and receiving side use the same group and src
rank (pp_peer within that group) when calling dist.broadcast_object_list and
dist.recv so collectives are matched across all processes.
- Around line 87-91: The current creation of dp_group with dp_ranks = [r for r
in range(num_gpus) if (r % pp_size) == pp_rank] causes different ranks to call
dist.new_group() with different rank lists and will hang; instead, build all DP
groups in the same order on every rank by iterating over all possible pp
partitions (e.g., for each i in range(pp_size) compute ranks_i = [r for r in
range(num_gpus) if (r % pp_size) == i]), call dist.new_group(ranks_i) for each i
so every rank participates in the same sequence of new_group() calls, collect
those groups, and then assign dp_group = groups[pp_rank] (or None when dp_size
<= 1); ensure the symbols referenced are dp_size, dp_ranks, dp_group,
dist.new_group, pp_size, pp_rank, num_gpus.
---
Nitpick comments:
In `@local_test/train_pp.py`:
- Line 107: The loop control variable "step" in the for loop "for step in
range(num_steps):" is unused; rename it to "_step" to satisfy the linter and
make intent clear (change "for step in range(num_steps):" to "for _step in
range(num_steps):"), and verify there are no usages of "step" inside the loop
that need to be updated.
- Around line 243-244: The loop iterates over my_state items but never uses the
key `k`; change the iteration to loop over values only to avoid unused variable
warnings and clarify intent—replace the for k, v in my_state.items() loop with a
values-only iteration (e.g., for v in my_state.values()) and keep the body
calling dist.send(v.contiguous(), dst=pp_peer) so the send logic (dist.send,
v.contiguous(), pp_peer) remains unchanged.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: c932e0fe-7a0e-43db-86b4-7bc9cf918a80
📒 Files selected for processing (9)
- environments/templar/env.py
- local_test/setup_benchmark.py
- local_test/simulate_validator.py
- local_test/train_ddp.py
- local_test/train_fsdp.py
- local_test/train_mixed.py
- local_test/train_pp.py
- local_test/train_tp.py
- src/crusades/tui/app.py
✅ Files skipped from review due to trivial changes (1)
- local_test/simulate_validator.py
🚧 Files skipped from review as they are similar to previous changes (5)
- local_test/train_fsdp.py
- local_test/train_mixed.py
- local_test/train_tp.py
- local_test/train_ddp.py
- local_test/setup_benchmark.py
5424b49 to
5b60689
* Add LLM-based exploit detection for miner submissions
  Uses OpenRouter (Claude) to analyze miner code for subtle exploits that bypass the static AST scanner. Runs after code download and hash verification but before GPU evaluation. Fail-open policy: if the LLM call fails for any reason (timeout, API error, missing key, unparseable response), the submission is allowed through; the static scanner remains the hard gate.
  Made-with: Cursor
* Add configurable tokenizer support with local Gemma 3 tokenizer
  Bakes the google/gemma-3-27b-it tokenizer into the repo and Docker image so evaluation containers load it locally without network access or HF tokens. Adds benchmark_tokenizer_name to hparams.json and env.py helpers to resolve, cache, and resize model embeddings when a separate tokenizer is configured.
  Made-with: Cursor
* Remove unnecessary tokenizer download from Dockerfile
  The tokenizer files are already COPY'd from tokenizers/ into the image, so the build-time download/verification step was redundant.
  Made-with: Cursor
* Update train_fsdp.py for larger tokenizer vocab and fix TUI client query
  - Switch to FULL_SHARD, gradient checkpointing, chunked CE loss to handle 262K vocab from Gemma tokenizer without OOM
  - Fix TUI client to show oldest evaluating submission (the one actively being processed) and include miner_uid in current_evaluation
  Made-with: Cursor
* Fix reference eval OOM: use per-layer FSDP wrapping and gradient checkpointing
  Single-unit FSDP wrap tried to all-gather the entire 8.4B model (~15.6GB) on each GPU at once, causing OOM with the 262K-vocab tokenizer. Now wraps per transformer layer and enables gradient checkpointing. Also reduces reference micro-batch size from 4 to 2.
  Made-with: Cursor
* Log tokenizer config read failures instead of silently swallowing them
  Made-with: Cursor
* Guard _load_env_file() with try/except for fail-open consistency
  Made-with: Cursor
* Align LLM exploit detection prompt with actual evaluator contract
  - get_strategy() accepts dp_size/tp_size only (pp_size is a separate PR)
  - Fix evaluation order: reference run happens before miner code loads
  - Remove duplicate "Module load" step, correct step numbering
  Made-with: Cursor
* Fix backward_prefetch type and chunked cross-entropy edge case in train_fsdp
  - backward_prefetch expects BackwardPrefetch enum, not bool
  - Initialize n_valid as tensor to avoid AttributeError when all labels are -100
  Made-with: Cursor
* Send head + tail of large files to LLM exploit detector
  Instead of only the first 50K chars (exploit code at the end would be missed), send 75% from the head and 25% from the tail so both ends are covered.
  Made-with: Cursor
* Send full miner code to LLM exploit detector without truncation
  Made-with: Cursor
* Fix all local_test train files for 262K-vocab tokenizer (OOM prevention)
  All files now use:
  - Chunked cross-entropy to avoid materializing full [batch*seq, 262K] logits
  - Gradient checkpointing to reduce activation memory
  - Explicit bf16 model casting
  - Proper tensor initialization for edge cases (all labels == -100)
  Per file:
  - train.py: simplified single-GPU reference (only for num_gpus=1 setups)
  - train_ddp.py: DDP with micro-batch=1 gradient accumulation
  - train_tp.py: tensor parallel with DTensor state gathering
  - train_mixed.py: 2D DP+TP with manual gradient all-reduce
  - train_fsdp.py: already fixed in prior commit
  Made-with: Cursor
* Add Docker container cleanup to prevent GPU leaks between evaluations
  Name eval containers (crusades-eval-*), force-remove on timeout/shutdown, and clean up stale containers + networks on validator startup.
  Made-with: Cursor
* Fix OOM in TP/mixed train files by chunking lm_head loss computation
  Replace _chunked_cross_entropy (which still materialized full [batch, seq, vocab] logits) with _chunked_lm_loss that processes lm_head in chunks over the sequence dimension. Peak memory drops from O(batch*seq*vocab) to O(chunk_tokens*vocab), saving ~4-8 GB VRAM with 262K vocab.
  Made-with: Cursor
* Use datetime.UTC alias in backup_db.py (ruff UP017)
  Made-with: Cursor
* Add .env.example with placeholder credentials
  Made-with: Cursor
* Fix final_logits to use full sequence length for validator check
  The validator expects final_logits with shape [batch, seq_len, vocab], not [batch, 1, vocab]. Computing under torch.no_grad() avoids OOM since no gradients are stored for the lm_head output.
  Made-with: Cursor
* Update hparams.json
---------
Co-authored-by: Shivam <shivam@mbzuai.ac.ae>
Extend the miner parallelism contract to support pipeline parallelism (pp_size) alongside existing dp_size and tp_size. Update env.py strategy detection, validation (dp*tp*pp == num_gpus), and data sharding. Add train_pp.py reference implementation and update all local_test files with explicit pp_size in get_strategy() and documentation. Made-with: Cursor
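The `dp*tp*pp == num_gpus` rule mentioned above can be sketched as a small check. This is a minimal illustration, not the code in env.py: the function name `validate_strategy` is hypothetical, and treating a missing key as size 1 is an assumption.

```python
def validate_strategy(strategy: dict, num_gpus: int) -> tuple[int, int, int]:
    """Check that the declared parallelism degrees multiply to the GPU count."""
    # Assumption: a missing key means that dimension is unused (size 1).
    dp = int(strategy.get("dp_size", 1))
    tp = int(strategy.get("tp_size", 1))
    pp = int(strategy.get("pp_size", 1))

    if min(dp, tp, pp) < 1:
        raise ValueError(f"parallelism sizes must be >= 1, got dp={dp}, tp={tp}, pp={pp}")

    product = dp * tp * pp
    if product != num_gpus:
        raise ValueError(
            f"dp*tp*pp = {dp}*{tp}*{pp} = {product}, but num_gpus = {num_gpus}"
        )
    return dp, tp, pp


# A 2-way data-parallel, 2-stage pipeline layout on 4 GPUs passes the check.
print(validate_strategy({"dp_size": 2, "pp_size": 2}, num_gpus=4))
```

Under these assumptions, a strategy such as `{"dp_size": 4, "tp_size": 1, "pp_size": 2}` on 4 GPUs would be rejected, because the product is 8.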
The validator only verifies rank 0's InnerStepsResult, but rank 0 is the first pipeline stage which never computes loss or logits. Add inter-stage communication so the last stage sends final_logits and final_loss back to rank 0 after training. Also fix stale comments in env.py and update MFU docstring to document pp_size behavior. Made-with: Cursor
Read benchmark_tokenizer_name from hparams.json to decouple the tokenizer from the model. When configured, env.py loads the separate tokenizer for data tokenization and resizes the model's embedding and lm_head layers to match the new vocab size (262k for Gemma 3 vs 152k for Qwen). The Dockerfile caches the Gemma tokenizer alongside the model at build time. No API surface changes needed — hparams.json is already mounted into the container. Made-with: Cursor
…before warmup
- Add _resize_model_for_tokenizer() helper to centralize embedding resize
- Call resize after _load_model in post-reference and post-warmup reload paths to prevent state_dict shape mismatch with different tokenizer
- Move _verify_trainable_params before warmup so PP miners that freeze layers in inner_steps don't fail the 100% trainable check
- Use all_reduce to broadcast trainable check failure across ranks
- Switch Dockerfile to BuildKit secrets for HF_TOKEN (no leak in history)
- Update simulate_validator.py docker build command to match
Made-with: Cursor
The Gemma tokenizer is baked in via COPY tokenizers/ and the Qwen model is not gated, so no HF_TOKEN secret mount is needed during build. Made-with: Cursor
Uses statistics.median_low to match the validator's scoring logic. Previously computed average but the website labels it as "MEDIAN MFU". Made-with: Cursor
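To illustrate why switching from an average to `statistics.median_low` changes the displayed number, here is a minimal sketch. The MFU values are made up for illustration and do not come from any evaluation.

```python
import statistics

# Hypothetical MFU percentages from four evaluation runs (illustrative values only).
mfu_scores = [41.0, 43.5, 44.0, 60.0]

# The mean is pulled upward by the single 60.0 outlier.
mean_mfu = statistics.mean(mfu_scores)

# With an even number of points, median() averages the two middle values.
median_mfu = statistics.median(mfu_scores)

# median_low() returns the lower of the two middle values, so the result is
# always one of the recorded scores.
median_low_mfu = statistics.median_low(mfu_scores)

print(f"mean={mean_mfu:.3f} median={median_mfu:.2f} median_low={median_low_mfu:.2f}")
```

With an odd number of scores `median` and `median_low` agree; they differ only when the count is even.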
a14696e to
0e22a27
… hparams Made-with: Cursor
… support
- docs/Validator.md: fix all thresholds, GPU counts, eval params, add pp_size
- arbos/README.md: fix GPU counts in examples and FAQ
- arbos/test_agent.py: fix stale assertions (model, gpus, thresholds)
- arbos/tester.py: default num_gpus 2→4
- arbos/llm_client.py: default num_gpus 2→4
- local_test/setup_benchmark.py: fallback model Qwen2.5-3B→7B
- Dockerfile: add comments about tokenizer setup and HF_TOKEN
Made-with: Cursor
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/crusades/tui/app.py (1)
469-492: ⚠️ Potential issue | 🟡 Minor
Use explicit success/None checks instead of truthiness for MFU.
At Line 472 and Line 480, `if mfu` treats `0.0` as absent. That can hide valid zero values and makes the title fallback at Line 491 ambiguous (`0.00%` can mean “no valid MFU” or “actual zero median”).
Suggested patch
@@
- for idx, eval_data in enumerate(detail.evaluations, 1):
-     mfu = eval_data.get("mfu", 0)
+ for idx, eval_data in enumerate(detail.evaluations, 1):
+     mfu = eval_data.get("mfu")
      tps = eval_data.get("tokens_per_second", 0)
-     if mfu:
-         mfu_scores.append(mfu)
-
      success = eval_data.get("success", False)
+     if success and mfu is not None:
+         mfu_scores.append(float(mfu))
      status_display = "[green]✓[/green]" if success else "[red]✗[/red]"
@@
-     f"{mfu:.2f}%" if mfu else "-",
+     f"{float(mfu):.2f}%" if (success and mfu is not None) else "-",
@@
- median_mfu = statistics.median_low(mfu_scores) if mfu_scores else 0
- title = f"[bold]Evaluations[/bold] (Median MFU: [green]{median_mfu:.2f}%[/green])"
+ median_mfu = statistics.median_low(mfu_scores) if mfu_scores else None
+ median_label = f"{median_mfu:.2f}%" if median_mfu is not None else "N/A"
+ title = f"[bold]Evaluations[/bold] (Median MFU: [green]{median_label}[/green])"
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/crusades/tui/app.py` around lines 469 - 492, The loop currently treats mfu via truthiness which hides valid 0.0 values; change checks to explicit None checks: use mfu = eval_data.get("mfu") (no default 0), append to mfu_scores when mfu is not None (so zeros are included), and when rendering use "-" if mfu is None else f"{mfu:.2f}%". Also compute median_mfu = statistics.median_low(mfu_scores) if mfu_scores else None and update the title construction to show "-" (or a clear placeholder) when median_mfu is None instead of showing "0.00%". Apply these edits around detail.evaluations, mfu_scores, median_mfu and the title string building.
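The difference between a truthiness check and an explicit `None` check can be shown outside the TUI. The sketch below is a standalone illustration, not the code in app.py: `summarize_mfu` is a hypothetical helper, and it follows the suggested patch in counting only successful evaluations.

```python
import statistics


def summarize_mfu(evaluations):
    """Collect MFU from successful evaluations and build a median label."""
    scores = []
    for eval_data in evaluations:
        mfu = eval_data.get("mfu")  # no default, so a missing value stays None
        if eval_data.get("success", False) and mfu is not None:
            scores.append(float(mfu))  # a genuine 0.0 is kept
    median = statistics.median_low(scores) if scores else None
    label = f"{median:.2f}%" if median is not None else "N/A"
    return scores, label


evaluations = [
    {"mfu": 0.0, "success": True},    # valid zero: kept
    {"mfu": None, "success": True},   # nothing reported: skipped
    {"mfu": 50.0, "success": False},  # failed run: skipped
]
print(summarize_mfu(evaluations))
print(summarize_mfu([]))
```

Here "no valid MFU" and "median of exactly zero" produce different labels, which is the ambiguity the comment asks to remove.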
♻️ Duplicate comments (5)
local_test/train_pp.py (5)
60-60: ⚠️ Potential issue | 🔴 Critical
Replace `__import__("os")` with standard `import os`.
The `__import__` builtin is in `FORBIDDEN_NAMES` and will be rejected by the validator's security scanner before `inner_steps()` runs.
Fix: Use standard import
Add at the top of the file:
+import os
 import torch
Then change line 60:
- local_rank = int(__import__("os").environ.get("LOCAL_RANK", "0"))
+ local_rank = int(os.environ.get("LOCAL_RANK", "0"))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` at line 60, Replace the dynamic __import__("os") call in the local_rank assignment with a normal module import: add "import os" at the top of the file (if not already present) and update the expression using os.environ.get(...) instead of __import__("os").environ.get(...); target the local_rank assignment to locate and update the code.
129-132: ⚠️ Potential issue | 🟡 Minor
Hardcoded `bfloat16` dtype may mismatch model's actual dtype.
The hidden tensor assumes `bfloat16`, but the model may use a different dtype. Infer from model parameters.
Fix: Infer dtype from model
+ model_dtype = next(model.parameters()).dtype
- hidden = torch.zeros(bs, seq_len, hidden_size, device=device, dtype=torch.bfloat16)
+ hidden = torch.zeros(bs, seq_len, hidden_size, device=device, dtype=model_dtype)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 129 - 132, The hidden tensor is created with a hardcoded dtype=torch.bfloat16 which can mismatch the model; instead infer the dtype from the model parameters (e.g. param_dtype = next(model.parameters()).dtype) and use that when creating hidden (dtype=param_dtype or hidden = torch.zeros(..., device=device, dtype=param_dtype)), then keep the rest of the sequence (hidden.requires_grad_(True), dist.recv(hidden, src=pp_peer), hidden = hidden.detach().requires_grad_(True)) unchanged; reference the hidden variable, bs, seq_len, hidden_size, device and model/next(model.parameters()).dtype to locate where to change.
216-241: ⚠️ Potential issue | 🔴 Critical
`_gather_pp_state()` uses world-group collectives with mismatched `src` values.
Lines 221 and 224/235 call `broadcast_object_list()` with `src` values that differ across ranks. With 4 ranks (dp_size=2, pp_size=2), ranks call broadcasts with different sources, causing deadlock on the default world group.
Conceptual fix
Either:
- Create a per-pipeline process group and use it for all broadcasts within the pipeline
- Or restrict state assembly to a single replica (e.g., only pipeline replica 0)
Example using pipeline groups:
# Create pipeline groups (all ranks create all groups in same order)
pipeline_group = None
for dp_idx in range(dp_size):
    ranks = [dp_idx * pp_size + stage for stage in range(pp_size)]
    group = dist.new_group(ranks)
    if global_rank in ranks:
        pipeline_group = group
# Then use group=pipeline_group in broadcast_object_list calls
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 216 - 241, In _gather_pp_state() the calls to dist.broadcast_object_list use the global world group with different src ranks across pipeline replicas which deadlocks; fix by creating a per-pipeline process group (e.g., build pipeline_group using dp_size/pp_size and dist.new_group with the ranks for this pipeline) and pass group=pipeline_group into all dist.broadcast_object_list and dist.recv calls in _gather_pp_state (ensure pipeline_group is created on all ranks in the same order and selected into a local variable), or alternatively gate the assembly so only a single replica (pipeline replica 0) performs the broadcasts/receives; update usages of dist.broadcast_object_list, dist.recv, pp_peer and any src arguments to use the pipeline_group or run only on the designated replica.
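The rank layout implied by the pipeline-group example can be checked without any GPUs. The following is a plain-Python sketch that only computes rank lists; it does not call `torch.distributed`. The helper names are hypothetical, and the layout assumes each pipeline replica occupies `pp_size` consecutive global ranks, as in the example.

```python
def pipeline_rank_lists(dp_size, pp_size):
    """Return one list of global ranks per pipeline replica, in creation order."""
    return [
        [dp_idx * pp_size + stage for stage in range(pp_size)]
        for dp_idx in range(dp_size)
    ]


def pipeline_of(global_rank, dp_size, pp_size):
    """Return the rank list of the pipeline replica that contains global_rank."""
    for ranks in pipeline_rank_lists(dp_size, pp_size):
        if global_rank in ranks:
            return ranks
    raise ValueError(f"rank {global_rank} is outside a {dp_size}x{pp_size} layout")


# With dp_size=2 and pp_size=2 there are two pipelines: ranks [0, 1] and [2, 3].
groups = pipeline_rank_lists(dp_size=2, pp_size=2)
print(groups)

# Every member of a pipeline agrees on the same first-stage rank, so a broadcast
# scoped to that pipeline has a single, consistent source.
print(pipeline_of(3, dp_size=2, pp_size=2)[0])
```

Each rank would still need to create every group in this same order and keep only the one it belongs to, as the review comment describes.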
167-172: ⚠️ Potential issue | 🟡 Minor
Fallback logits tensor has wrong shape for validator's 3D check.
If `final_logits` remains `None`, the fallback creates a 1D tensor which fails the validator's shape validation.
Fix: Create properly-shaped 3D tensor
  if is_last_stage:
      if final_logits is None:
-         final_logits = torch.zeros(1, device=device)
+         vocab_size = model.config.vocab_size
+         final_logits = torch.zeros(1, 1, vocab_size, device=device, dtype=torch.bfloat16)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 167 - 172, The fallback creates a 1D final_logits which fails the validator expecting a 3D tensor; change the fallback to allocate a 3D tensor matching the model/validator shape (e.g., batch x seq_len x vocab). Obtain the shape from an existing logits/template tensor if available (use something like template_logits.shape) or define the correct dims explicitly, then replace final_logits = torch.zeros(1, device=device) with final_logits = torch.zeros((batch, seq_len, vocab), device=device, dtype=appropriate_dtype) so the downstream validator accepts it before the dist.send to pp_peer.
87-91: ⚠️ Potential issue | 🔴 Critical
DP group creation may hang when `dp_size > 1` due to inconsistent `new_group()` order.
All ranks must call `dist.new_group()` in the same order. Currently, each rank only creates its own DP group, which can cause hangs. Create all DP groups on all ranks.
Fix: Create all DP groups deterministically
- if dp_size > 1:
-     dp_ranks = [r for r in range(num_gpus) if (r % pp_size) == pp_rank]
-     dp_group = dist.new_group(dp_ranks)
- else:
-     dp_group = None
+ dp_group = None
+ if dp_size > 1:
+     # All ranks must create all groups in the same order
+     for stage in range(pp_size):
+         ranks = [r for r in range(num_gpus) if (r % pp_size) == stage]
+         group = dist.new_group(ranks)
+         if stage == pp_rank:
+             dp_group = group
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 87 - 91, DP group creation hangs because only the local rank calls dist.new_group() for its dp_ranks, violating the requirement that all ranks call dist.new_group() in the same order; to fix, when dp_size > 1 compute the full list of dp rank lists for every pipeline partition i in range(pp_size) (e.g., dp_ranks_for_i = [r for r in range(num_gpus) if (r % pp_size) == i]) and then loop over these lists and call dist.new_group(dp_ranks_for_i) on every process in the same deterministic order, storing the returned groups in a mapping and finally selecting dp_group = groups[pp_rank]; keep dp_group = None when dp_size == 1.
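The ordering problem can be made concrete by listing the sequence of `new_group()` rank lists each process would request. This is a simulation in plain Python, not a distributed program; the function names are hypothetical, and `pp_rank = rank % pp_size` is an assumption consistent with the `dp_ranks` expression in the comment.

```python
def original_call_sequence(rank, num_gpus, pp_size):
    """Rank lists requested when each rank creates only its own DP group."""
    pp_rank = rank % pp_size  # assumed mapping from global rank to pipeline stage
    return (tuple(r for r in range(num_gpus) if r % pp_size == pp_rank),)


def fixed_call_sequence(rank, num_gpus, pp_size):
    """Rank lists requested when every rank creates all DP groups in stage order."""
    return tuple(
        tuple(r for r in range(num_gpus) if r % pp_size == stage)
        for stage in range(pp_size)
    )


num_gpus, pp_size = 4, 2
original = {original_call_sequence(r, num_gpus, pp_size) for r in range(num_gpus)}
fixed = {fixed_call_sequence(r, num_gpus, pp_size) for r in range(num_gpus)}

# Original: ranks disagree about which group is created first.
print(sorted(original))
# Fixed: all ranks issue one identical sequence of group creations.
print(sorted(fixed))
```

A single shared sequence across all ranks is what the fix establishes; each rank then keeps only the group whose stage matches its own `pp_rank`.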
🧹 Nitpick comments (2)
local_test/train_pp.py (1)
107-107: Unused loop variable `step`.
Rename to `_step` to indicate intentional non-use.
- for step in range(num_steps):
+ for _step in range(num_steps):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` at line 107, The loop variable `step` in the `for step in range(num_steps):` loop is unused; rename it to `_step` to signal intentional non-use and silence linters. Update the loop header in local_test/train_pp.py from `for step in range(num_steps):` to `for _step in range(num_steps):` and adjust any references (there should be none) accordingly so this change is purely a rename.
local_test/setup_benchmark.py (1)
27-29: Consider adding `HF_HUB_TOKEN` as a fallback.
The function checks `HF_TOKEN` and `HUGGING_FACE_HUB_TOKEN`, but some tools also use `HF_HUB_TOKEN`. This is a minor consistency improvement.
Optional: Add HF_HUB_TOKEN fallback
  def _hf_token() -> str | None:
      """Return HF token from env (needed for gated models like Gemma)."""
-     return os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or None
+     return (
+         os.environ.get("HF_TOKEN")
+         or os.environ.get("HF_HUB_TOKEN")
+         or os.environ.get("HUGGING_FACE_HUB_TOKEN")
+         or None
+     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/setup_benchmark.py` around lines 27 - 29, Update the _hf_token function to also check HF_HUB_TOKEN alongside HF_TOKEN and HUGGING_FACE_HUB_TOKEN so the function returns the token from any of those variables. Do not add HF_HOME as a token fallback: it names the Hugging Face cache directory, not a token.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@local_test/train.py`:
- Around line 103-104: The strategy dict returned by get_strategy incorrectly
declares pp_size=2 while the training uses FullyShardedDataParallel (FSDP);
update get_strategy to return the correct FSDP topology (e.g., {"dp_size": 4,
"tp_size": 1}) and remove or set pp_size to 1, and also update the file header
comment (the topology description near the top) to match the corrected FSDP
strategy; if you intended pipeline parallelism instead, move to train_pp.py
rather than changing FSDP usage.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: e6b9c1dd-6fc7-49f2-a18c-e8312c3ba120
📒 Files selected for processing (17)
- README.md
- arbos/README.md
- arbos/llm_client.py
- arbos/test_agent.py
- arbos/tester.py
- docs/Validator.md
- environments/templar/Dockerfile
- environments/templar/env.py
- local_test/setup_benchmark.py
- local_test/simulate_validator.py
- local_test/train.py
- local_test/train_ddp.py
- local_test/train_fsdp.py
- local_test/train_mixed.py
- local_test/train_pp.py
- local_test/train_tp.py
- src/crusades/tui/app.py
✅ Files skipped from review due to trivial changes (2)
- environments/templar/Dockerfile
- arbos/test_agent.py
🚧 Files skipped from review as they are similar to previous changes (6)
- local_test/simulate_validator.py
- local_test/train_mixed.py
- local_test/train_fsdp.py
- local_test/train_ddp.py
- local_test/train_tp.py
- environments/templar/env.py
Disable NVLink P2P (NCCL_P2P_DISABLE=1) and NVLink SHARP (NCCL_NVLS_ENABLE=0) to force GPU communication through host memory (~32 GB/s) instead of NVLink (~600 GB/s). This makes collective-heavy strategies (FSDP full-shard, TP) much slower while pipeline parallelism — which only sends activations between stages — becomes the dominant strategy for high MFU. Applied in: runner.py (Docker mode), Dockerfile (Basilica mode), simulate_validator.py, README.md, arbos/tester.py. Made-with: Cursor
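A minimal sketch of how these two variables might be applied to a child process, together with the bandwidth ratio implied by the figures in the commit message. The helper name `nccl_env_without_nvlink` is hypothetical; the actual changes live in runner.py, the Dockerfile, and the other files listed above.

```python
import os


def nccl_env_without_nvlink(base_env=None):
    """Return a copy of the environment with NVLink P2P and NVLS turned off."""
    env = dict(os.environ if base_env is None else base_env)
    env["NCCL_P2P_DISABLE"] = "1"  # no direct GPU-to-GPU peer transfers
    env["NCCL_NVLS_ENABLE"] = "0"  # no NVLink SHARP collectives
    return env


# Approximate figures quoted in the commit message, in GB/s.
nvlink_gbps = 600
host_memory_gbps = 32
slowdown = nvlink_gbps / host_memory_gbps

env = nccl_env_without_nvlink({"PATH": "/usr/bin"})
print(env["NCCL_P2P_DISABLE"], env["NCCL_NVLS_ENABLE"], f"{slowdown:.2f}x")
```

The returned dictionary could be passed as `env=` to `subprocess.run` when launching a training process. Under the quoted figures, every byte moved between GPUs costs roughly 19 times more, which is why strategies that communicate less per step come out ahead.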
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@environments/templar/Dockerfile`:
- Around line 2-7: The Dockerfile in environments/templar expects COPY of
repo-root paths (hparams/hparams.json and tokenizers/), but
AffinetesRunner.build_validator_image() currently runs docker build with
environments/templar as the context; change build_validator_image() so the
docker build context is the repository root (.) and pass the Dockerfile path via
the --file/ -f flag (pointing to environments/templar/Dockerfile) so the COPY
instructions can access hparams/ and tokenizers/ from the root; update any tests
or callers that assume the old context if needed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Organization UI
Review profile: CHILL
Plan: Pro
Run ID: 581a249a-04e5-4e11-8c33-af59efa9ba0c
📒 Files selected for processing (5)
- README.md
- arbos/tester.py
- environments/templar/Dockerfile
- local_test/simulate_validator.py
- src/crusades/affinetes/runner.py
🚧 Files skipped from review as they are similar to previous changes (3)
- arbos/tester.py
- local_test/simulate_validator.py
- README.md
# Run from repo root so hparams.json and tokenizers/ are accessible.
#
# The tokenizers/ directory must exist locally (it's in .gitignore).
# Run `uv run local_test/setup_benchmark.py` first; it downloads the
# gated Gemma tokenizer (requires HF_TOKEN env var) and saves it to
# tokenizers/google/gemma-3-27b-it/ for this COPY step.
The automated image builder still uses the wrong Docker context.
These instructions are correct for manual builds, but AffinetesRunner.build_validator_image() still runs docker build with environments/templar as the context. Because this Dockerfile COPYs hparams/hparams.json and tokenizers/ from the repo root, that code path will keep failing.
🔧 Align build_validator_image() with this Dockerfile
- cmd = [
- "docker",
- "build",
- "-t",
- self.validator_image,
- str(env_path),
- ]
+ repo_root = env_path.parent.parent
+ cmd = [
+ "docker",
+ "build",
+ "-f",
+ str(env_path / "Dockerfile"),
+ "-t",
+ self.validator_image,
+ str(repo_root),
+ ]
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@environments/templar/Dockerfile` around lines 2 - 7, The Dockerfile in
environments/templar expects COPY of repo-root paths (hparams/hparams.json and
tokenizers/), but AffinetesRunner.build_validator_image() currently runs docker
build with environments/templar as the context; change build_validator_image()
so the docker build context is the repository root (.) and pass the Dockerfile
path via the --file/ -f flag (pointing to environments/templar/Dockerfile) so
the COPY instructions can access hparams/ and tokenizers/ from the root; update
any tests or callers that assume the old context if needed.
Apply high-MFU optimizations from train.py to all strategy variants:
- torch.compile + inductor/dynamo tuning (FSDP, DDP, PP)
- flash_attn CrossEntropyLoss (all files)
- Selective Activation Checkpointing with matmul-save policy (all files)
- bf16 logits via lm_head graph break, no fp32 upcast (all files)
- Pre-loaded batches before training loop (all files)
- TF32 matmul, expandable segments allocator (all files)
- Fused AdamW where compatible (FSDP, DDP, PP)
PP-specific: per-stage torch.compile, manual SAC wrapping per layer, _UNCHECKPOINT_LAST_N_PER_STAGE=4 for half-model stages.
TP/Mixed: no torch.compile (DTensor incompatible with inductor).
Made-with: Cursor
Actionable comments posted: 1
♻️ Duplicate comments (2)
local_test/train_pp.py (2)
283-330: ⚠️ Potential issue | 🔴 Critical
`_gather_pp_state()` can’t complete under the NCCL evaluator as written.
The first-stage ranks are issuing world-group `broadcast_object_list()` calls with different `src` values, and the tensor leg is sending/receiving `.cpu()` clones even though the evaluator initializes NCCL. That will deadlock or raise before rank 0 can assemble `final_state`.
Fix direction
- Create a per-pipeline process group for each replica instead of using the world group.
- Run the metadata exchange inside that subgroup with a constant subgroup-local source.
- Keep the actual `send`/`recv` tensors on CUDA, then move them to CPU only after rank 0 has received them.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 283 - 330, _gather_pp_state is deadlocking under NCCL because it uses the global world group with varying src ranks and moves tensors to CPU before send/recv; fix by creating a per-pipeline process group (subgroup) for the pipeline replica and use that subgroup for metadata exchange and point all broadcast_object_list() calls to a constant subgroup-local src (e.g., the first rank in the subgroup) instead of dist.get_rank()/pp_peer on the world group, then perform the tensor transfer using CUDA tensors (keep v on device and call dist.send/dist.recv on CUDA buffers) and only .cpu().clone() the received tensors on rank 0 after recv completes; update the branches in _gather_pp_state that call dist.broadcast_object_list, dist.send, and dist.recv to use the new subgroup and device tensors and return the assembled my_state on global_rank == 0.
188-192: ⚠️ Potential issue | 🔴 Critical
Create every DP subgroup in the same global order.
With `pp_size=2`, ranks 0/2 call `new_group([0, 2])` while ranks 1/3 call `new_group([1, 3])` as their first subgroup. PyTorch expects subgroup creation order to match across the world, so this can hang before the first DP gradient sync.
Suggested fix
- dp_group = None
- if dp_size > 1:
-     dp_ranks = [r for r in range(num_gpus) if (r % pp_size) == pp_rank]
-     dp_group = dist.new_group(dp_ranks)
+ dp_group = None
+ if dp_size > 1:
+     for stage in range(pp_size):
+         ranks = [r for r in range(num_gpus) if (r % pp_size) == stage]
+         group = dist.new_group(ranks)
+         if stage == pp_rank:
+             dp_group = group
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 188 - 192, The DP subgroup creation order differs across pipeline ranks causing a hang; fix by having every process create all DP subgroups in the same global order and pick the one for its pp_rank. Concretely, replace the single dist.new_group(dp_ranks) call with a loop over dp_start in range(pp_size) that builds dp_ranks = [r for r in range(num_gpus) if (r % pp_size) == dp_start], calls dist.new_group(dp_ranks) for each dp_start (so new_group is invoked in identical order everywhere), and set dp_group when dp_start == pp_rank; keep references to dp_size, pp_size, pp_rank, dp_ranks, dp_group, and dist.new_group to locate the change.
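The ordering requirement can be illustrated without a GPU. The following is a minimal sketch, assuming the `rank % pp_size` stage mapping used in the snippet above; `build_dp_group` and `record_calls` are hypothetical helpers, and the group factory is injected so that a recording stub can stand in for `dist.new_group`.

```python
from typing import Callable, List, Tuple


def build_dp_group(rank: int, num_gpus: int, pp_size: int,
                   new_group: Callable[[List[int]], object]):
    """Create one DP group per pipeline stage, in stage order, on every rank."""
    pp_rank = rank % pp_size  # assumed stage mapping, as in the review snippet
    my_group = None
    for stage in range(pp_size):
        ranks = [r for r in range(num_gpus) if r % pp_size == stage]
        group = new_group(ranks)  # every rank makes this call, in the same order
        if stage == pp_rank:
            my_group = group      # keep only the group this rank belongs to
    return my_group


def record_calls(rank: int, num_gpus: int, pp_size: int) -> Tuple[list, object]:
    """Run build_dp_group with a stub factory and return the call sequence."""
    calls = []

    def fake_new_group(ranks):
        calls.append(list(ranks))
        return tuple(ranks)

    mine = build_dp_group(rank, num_gpus, pp_size, fake_new_group)
    return calls, mine


if __name__ == "__main__":
    for rank in range(4):
        print(rank, record_calls(rank, num_gpus=4, pp_size=2))
```

In a real run the stub would be replaced by `dist.new_group`; the point of the sketch is only that the recorded call sequence is identical on every rank.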
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/Validator.md`:
- Around line 78-85: The documented validator thresholds differ from the
evaluator defaults; update the evaluator defaults to match the docs by changing
the default values for max_loss_difference to 0.4, weight_relative_error_max to
0.015, and max_plausible_mfu to 85.0 where they are declared/used (specifically
in the Actor.evaluate(...) signature/implementation and the EvaluateRequest
default values) so that callers that rely on defaults enforce the same contract
described in the Validator.md table.
📒 Files selected for processing (18)
- README.md
- arbos/README.md
- arbos/llm_client.py
- arbos/test_agent.py
- arbos/tester.py
- docs/Validator.md
- environments/templar/Dockerfile
- environments/templar/env.py
- local_test/setup_benchmark.py
- local_test/simulate_validator.py
- local_test/train.py
- local_test/train_ddp.py
- local_test/train_fsdp.py
- local_test/train_mixed.py
- local_test/train_pp.py
- local_test/train_tp.py
- src/crusades/affinetes/runner.py
- src/crusades/tui/app.py
🚧 Files skipped from review as they are similar to previous changes (2)
- local_test/setup_benchmark.py
- local_test/simulate_validator.py
| **Loss Validity** | Loss must be positive, not NaN, close to reference | `max_loss_difference: 0.4` | Active |
| **Final Weight Verification** | Model weights after training must match reference | `weight_relative_error_max: 0.015` | Active |
| **Trainable Params** | All params must be trainable | `100%` | Active |
| **Params Changed** | Most param elements must change during training | `min: 70%` | Active |
| **Params Changed** | Most param elements must change during training | `min: 75%` | Active |
| **Timer Integrity** | Multiple timer sources must agree | `timer_divergence_threshold: 0.5%` | Active |
| **Min MFU** | Floor threshold -- submissions below are rejected | `min_mfu: 35%` | Active |
| **Max Plausible MFU** | Ceiling cap -- no legitimate code exceeds this | `max_plausible_mfu: 75%` | Active |
| **Success Rate** | Majority of runs must pass | `min_success_rate: 0.5` | Active |
| **Min MFU** | Floor threshold -- submissions below are rejected | `min_mfu: 50%` | Active |
| **Max Plausible MFU** | Ceiling cap -- no legitimate code exceeds this | `max_plausible_mfu: 85%` | Active |
| **Success Rate** | Minimum passing runs to accept | `min_success_rate: 0.3` | Active |
These published thresholds no longer match the evaluator defaults.
environments/templar/env.py still defaults max_loss_difference=0.3, weight_relative_error_max=0.008, and max_plausible_mfu=75.0 in both Actor.evaluate() and EvaluateRequest. Unless every caller overrides them explicitly, this table describes a different validator contract than the one the evaluator actually enforces.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/Validator.md` around lines 78 - 85, The documented validator thresholds
differ from the evaluator defaults; update the evaluator defaults to match the
docs by changing the default values for max_loss_difference to 0.4,
weight_relative_error_max to 0.015, and max_plausible_mfu to 85.0 where they are
declared/used (specifically in the Actor.evaluate(...) signature/implementation
and the EvaluateRequest default values) so that callers that rely on defaults
enforce the same contract described in the Validator.md table.
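A small consistency check is one way to catch this kind of drift early. The sketch below is illustrative only: both dictionaries are hand-copied from the values quoted in this comment, and `find_threshold_drift` is a hypothetical helper that does not exist in the repository.

```python
# Values quoted from the Validator.md table in this review.
DOCUMENTED = {
    "max_loss_difference": 0.4,
    "weight_relative_error_max": 0.015,
    "max_plausible_mfu": 85.0,
}

# Values this comment reports as the current evaluator defaults.
EVALUATOR_DEFAULTS = {
    "max_loss_difference": 0.3,
    "weight_relative_error_max": 0.008,
    "max_plausible_mfu": 75.0,
}


def find_threshold_drift(documented: dict, defaults: dict) -> dict:
    """Return {name: (documented, default)} for every threshold that differs."""
    drift = {}
    for name, doc_value in documented.items():
        default_value = defaults.get(name)
        if default_value != doc_value:
            drift[name] = (doc_value, default_value)
    return drift


if __name__ == "__main__":
    for name, (doc, default) in find_threshold_drift(DOCUMENTED, EVALUATOR_DEFAULTS).items():
        print(f"{name}: docs say {doc}, evaluator defaults to {default}")
```

A check of this shape could run in CI against the real defaults, assuming they can be imported or parsed from `environments/templar/env.py`.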
train_ddp.py: Replace vanilla DDP with FSDP SHARD_GRAD_OP (ZeRO-2). Vanilla DDP + AdamW on 8.4B params needs ~101 GB (model 17 + adam 67 + grads 17), exceeding A100-80GB. SHARD_GRAD_OP shards optimizer states and gradients across GPUs: ~38 GB per GPU.
train_mixed.py: Use checkpointed chunked lm_head+CE to avoid materializing full [B*S, 262K] logits (~8.6 GB) during backward. With dp=2 + tp=2, peak was ~75 GB leaving only 5 GB headroom. Checkpointing per chunk saves ~8 GB (stores 29 MB chunk_h instead of 2 GB chunk_logits), bringing peak to ~67 GB.
Made-with: Cursor
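The train_ddp.py figures can be reproduced with back-of-the-envelope arithmetic. The sketch below assumes bf16 weights and gradients (2 bytes per parameter), two fp32 AdamW moments (8 bytes per parameter), and 4 data-parallel ranks. These byte counts and the rank count are assumptions chosen because they match the quoted numbers; the commit message does not state them. Activations and temporary buffers are ignored.

```python
PARAMS = 8.4e9        # parameter count quoted in the commit message
BYTES_WEIGHT = 2      # assumed: bf16 weights
BYTES_GRAD = 2        # assumed: bf16 gradients
BYTES_ADAM = 8        # assumed: two fp32 AdamW moments per parameter
GB = 1e9


def ddp_gb(params: float) -> float:
    """Vanilla DDP: every rank holds weights, gradients, and optimizer state."""
    return params * (BYTES_WEIGHT + BYTES_GRAD + BYTES_ADAM) / GB


def shard_grad_op_gb(params: float, dp_size: int) -> float:
    """ZeRO-2 style: full weights per rank, gradients and optimizer state sharded."""
    replicated = params * BYTES_WEIGHT
    sharded = params * (BYTES_GRAD + BYTES_ADAM) / dp_size
    return (replicated + sharded) / GB


if __name__ == "__main__":
    print(f"DDP:           {ddp_gb(PARAMS):.1f} GB per GPU")
    print(f"SHARD_GRAD_OP: {shard_grad_op_gb(PARAMS, dp_size=4):.1f} GB per GPU")
```

Under these assumptions the two results round to the ~101 GB and ~38 GB stated above.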
Actionable comments posted: 2
🧹 Nitpick comments (5)
local_test/train_mixed.py (2)
153-155: Unused loop variable `idx` in first enumeration.
The loop at line 153 uses `enumerate()` but never uses `idx`. A plain iteration would be clearer.
♻️ Proposed fix

    - for idx, layer in enumerate(model.model.layers):
    + for layer in model.model.layers:
          if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "layer_idx"):
              layer.self_attn.layer_idx = 0

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_mixed.py` around lines 153 - 155, The enumeration variable idx is unused in the loop over model.model.layers; replace the enumerate-based loop with a plain iteration to improve clarity: iterate directly over model.model.layers (for layer in model.model.layers) and keep the existing hasattr checks and assignment to layer.self_attn.layer_idx = 0 so behavior of the block (finding layers with self_attn and setting layer.self_attn.layer_idx) remains unchanged.
64-72: Unused fields in `_AllSAC` class.
`num_ckpt_layers` and `_count` are stored/incremented but never read. If these are vestigial from refactoring, consider removing them for clarity.
♻️ Proposed simplification

      class _AllSAC:
    -     def __init__(self, num_ckpt_layers):
    -         self.num_ckpt_layers = num_ckpt_layers
    -         self._count = 0
    +     def __init__(self):
    +         pass

          def __call__(self, fn, *args, **kwargs):
    -         self._count += 1
              ctx_fn = functools.partial(create_selective_checkpoint_contexts, _sac_policy)
              return ckpt.checkpoint(fn, *args, use_reentrant=False, context_fn=ctx_fn, **kwargs)

Note: If you simplify here, also update line 168 to `_AllSAC()`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_mixed.py` around lines 64 - 72, Remove the unused fields from the _AllSAC class: delete the num_ckpt_layers attribute and the _count counter and any increments (they are neither read nor used); update the constructor to take no args (i.e., def __init__(self): pass or remove it entirely) so the class only defines __call__ which creates ctx_fn via functools.partial(create_selective_checkpoint_contexts, _sac_policy) and returns ckpt.checkpoint(fn, *args, use_reentrant=False, context_fn=ctx_fn, **kwargs); also update its instantiation to _AllSAC() where it’s created.
local_test/train_ddp.py (3)
68-70: Unguarded import of `flash_attn` will fail if the package is not installed.
Other optional features (SAC, inductor config, dynamo config) are guarded with try/except, but `flash_attn` is imported unconditionally. If this is a hard requirement, consider documenting it in the module docstring or raising a clearer error.
♻️ Option A: Guard the import with a fallback

    +try:
     from flash_attn.losses.cross_entropy import CrossEntropyLoss as _FlashCELoss
     _flash_ce_inst = _FlashCELoss(ignore_index=-100)
    +except ImportError:
    +    raise ImportError("flash_attn is required for train_ddp.py. Install with: pip install flash-attn")

♻️ Option B: Document in module docstring

     # High-MFU Data-Parallel strategy for 262K vocab (google/gemma-3-27b-it tokenizer)
     #
    +# Requirements: flash_attn (for efficient cross-entropy loss)
    +#
     # Topology: dp_size=4, tp_size=1, pp_size=1

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_ddp.py` around lines 68 - 70, The file imports flash_attn unguarded via "from flash_attn.losses.cross_entropy import CrossEntropyLoss as _FlashCELoss" and instantiates _flash_ce_inst, which will crash if flash_attn is not installed; wrap this import/instantiation in a try/except ImportError and provide a safe fallback (e.g., set _FlashCELoss or _flash_ce_inst to None or to a local torch.nn.CrossEntropyLoss alternative) or re-raise a clearer error explaining the optional dependency; update any code that uses _flash_ce_inst to handle the None/fallback case.
122-124: Unused loop variable `idx`.
The loop uses `enumerate()` but `idx` is not referenced. A plain iteration would be clearer.
♻️ Proposed fix

    - for idx, layer in enumerate(model.model.layers):
    + for layer in model.model.layers:
          if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "layer_idx"):
              layer.self_attn.layer_idx = 0

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_ddp.py` around lines 122 - 124, The loop in train_ddp.py unnecessarily uses enumerate (for idx, layer in enumerate(model.model.layers)) while never using idx; change it to a simple iteration (for layer in model.model.layers) to remove the unused variable and clarify intent, leaving the inner check and assignment to layer.self_attn.layer_idx unchanged.
102-103: Missing GPU count validation unlike sibling files.
`train_mixed.py` and other `train_*.py` files validate that `num_gpus == dp_size * tp_size * pp_size`. This file's `inner_steps()` lacks that check, which could lead to confusing runtime errors if misconfigured.
♻️ Proposed fix: add validation at the start of inner_steps()

     def inner_steps(model, data_iterator, optimizer, num_steps, device, num_gpus=1):
         _prepare_model(model)
    +
    +    strategy = get_strategy()
    +    expected_gpus = strategy["dp_size"] * strategy["tp_size"] * strategy.get("pp_size", 1)
    +    if num_gpus != expected_gpus:
    +        raise ValueError(
    +            f"get_strategy() requires {expected_gpus} GPUs "
    +            f"(dp={strategy['dp_size']}*tp={strategy['tp_size']}*pp={strategy.get('pp_size', 1)}), "
    +            f"but num_gpus={num_gpus}"
    +        )

         bf16_policy = MixedPrecision(

Also applies to: 160-178
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_ddp.py` around lines 102 - 103, inner_steps() is missing the GPU-count validation present in sibling scripts; add a check at the start of inner_steps() that computes expected_gpus = dp_size * tp_size * pp_size (using the values returned by get_strategy() or the strategy variable) and validate it against the actual num_gpus, raising a clear error (ValueError or assert with a descriptive message) if they differ so misconfiguration fails fast; reference get_strategy() and inner_steps() to locate where to insert the check and mirror the validation behavior from train_mixed.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@local_test/setup_benchmark.py`:
- Around line 74-80: The current check uses docker_tok_dir.exists() as a cache
hit which is unsafe; update the Docker export logic around docker_tok_dir and
the tokenizer-aware setup in main() to validate the tokenizer cache by checking
for required files (e.g., presence of tokenizer files created by
tokenizer.save_pretrained such as config/tokenizer.json or vocab files and the
train.pt artifact and/or a small manifest file) rather than only exists(), and
write a durable manifest atomically after a successful save; if validation fails
(missing or stale files or benchmark_tokenizer_name mismatch) remove or
overwrite the directory and call tokenizer.save_pretrained(docker_tok_dir)
before printing the success message. Ensure you reference docker_tok_dir,
tokenizer.save_pretrained, tok_source and benchmark_tokenizer_name when locating
and validating the cache.
In `@local_test/train_mixed.py`:
- Around line 223-243: The loop references an undefined name `_CHUNK`, causing a
NameError; fix by using the defined `chunk_size` (or define `_CHUNK =
chunk_size`) so the chunking logic is consistent: update the slices and n_tok
calculation to use `chunk_size` instead of `_CHUNK` (look for the loop and calls
to ckpt.checkpoint(_chunk_ce, _head, ch, cl, ...) where ch/cl are created, and
ensure _chunk_ce, _backbone, _head remain unchanged).
📒 Files selected for processing (3)
- local_test/setup_benchmark.py
- local_test/train_ddp.py
- local_test/train_mixed.py
    # Also save to tokenizers/<name> for Docker build (Dockerfile COPYs this directory)
    project_root = Path(__file__).parent.parent
    docker_tok_dir = project_root / "tokenizers" / tok_source
    if not docker_tok_dir.exists():
        docker_tok_dir.mkdir(parents=True, exist_ok=True)
        tokenizer.save_pretrained(docker_tok_dir)
        print(f" Tokenizer also saved to: {docker_tok_dir} (for Docker build)")
The new tokenizer cache isn't safe to reuse on exists() alone.
Both the Docker export here and the tokenizer-aware setup calls in main() treat “path exists” as a complete cache hit. That makes reruns fragile: an older benchmark/model prevents tokenizers/<tok_source>/ from ever being created for the Docker build, an interrupted save can leave an empty tokenizers/<tok_source>/ directory that future runs trust, and train.pt stays stale after benchmark_tokenizer_name changes. Please validate expected files or a small setup manifest instead of only checking whether the path exists.
Also applies to: 205-205, 213-221
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@local_test/setup_benchmark.py` around lines 74 - 80, The current check uses
docker_tok_dir.exists() as a cache hit which is unsafe; update the Docker export
logic around docker_tok_dir and the tokenizer-aware setup in main() to validate
the tokenizer cache by checking for required files (e.g., presence of tokenizer
files created by tokenizer.save_pretrained such as config/tokenizer.json or
vocab files and the train.pt artifact and/or a small manifest file) rather than
only exists(), and write a durable manifest atomically after a successful save;
if validation fails (missing or stale files or benchmark_tokenizer_name
mismatch) remove or overwrite the directory and call
tokenizer.save_pretrained(docker_tok_dir) before printing the success message.
Ensure you reference docker_tok_dir, tokenizer.save_pretrained, tok_source and
benchmark_tokenizer_name when locating and validating the cache.
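The manifest idea above could look roughly like the sketch below. It is not the repository's implementation: `MANIFEST_NAME`, `REQUIRED_FILES`, `write_manifest`, and `cache_is_valid` are hypothetical names, and the assumption that `save_pretrained` leaves a `tokenizer_config.json` behind should be checked against the tokenizer actually in use.

```python
import json
import os
import tempfile
from pathlib import Path

MANIFEST_NAME = "setup_manifest.json"        # hypothetical manifest file name
REQUIRED_FILES = ("tokenizer_config.json",)  # assumed output of save_pretrained


def write_manifest(cache_dir: Path, tokenizer_name: str) -> None:
    """Write the manifest via temp file + rename so a crash cannot leave a partial one."""
    payload = json.dumps({"tokenizer_name": tokenizer_name})
    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, suffix=".tmp")
    with os.fdopen(fd, "w") as handle:
        handle.write(payload)
    os.replace(tmp_path, cache_dir / MANIFEST_NAME)  # atomic within one filesystem


def cache_is_valid(cache_dir: Path, tokenizer_name: str) -> bool:
    """Trust the cache only if the manifest matches and the expected files exist."""
    manifest = cache_dir / MANIFEST_NAME
    if not manifest.is_file():
        return False
    if any(not (cache_dir / name).is_file() for name in REQUIRED_FILES):
        return False
    try:
        recorded = json.loads(manifest.read_text())
    except json.JSONDecodeError:
        return False
    return isinstance(recorded, dict) and recorded.get("tokenizer_name") == tokenizer_name
```

The caller would replace the bare `exists()` check with `cache_is_valid(...)`, re-run `tokenizer.save_pretrained(docker_tok_dir)` when it returns `False`, and call `write_manifest(...)` only after the save has finished.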
Use device.index instead of os.environ.get("LOCAL_RANK") to
determine pipeline stage assignment, avoiding security violation.
Made-with: Cursor
PyTorch 2.9 rejects assigning meta tensors to CUDA parameters via set_data(). Use torch.empty(0) on the same CUDA device instead to free memory for non-owned pipeline layers. Made-with: Cursor
Actionable comments posted: 1
♻️ Duplicate comments (3)
local_test/train_pp.py (3)
187-190: ⚠️ Potential issue | 🔴 Critical
Create all DP groups in the same global order.
Ranks with `pp_rank == 0` call `new_group([0, 2])`, while ranks with `pp_rank == 1` call `new_group([1, 3])` as their first and only subgroup. `torch.distributed.new_group()` requires every process to enter `new_group()` and says groups should be created in the same order on all processes, so this branchy setup can hang before the first gradient sync. (docs.pytorch.org)
Suggested fix

      dp_group = None
      if dp_size > 1:
    -     dp_ranks = [r for r in range(num_gpus) if (r % pp_size) == pp_rank]
    -     dp_group = dist.new_group(dp_ranks)
    +     for stage in range(pp_size):
    +         ranks = [r for r in range(num_gpus) if (r % pp_size) == stage]
    +         group = dist.new_group(ranks)
    +         if stage == pp_rank:
    +             dp_group = group

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 187 - 190, The current conditional creates dp groups only on some ranks in different orders (dp_group, dp_size, pp_size, pp_rank, dp_ranks, dist.new_group), which can hang because new_group must be called in the same order on all processes; fix by deterministically iterating over all pipeline ranks (for pp in range(pp_size)) and on every process compute the corresponding dp_ranks list and call dist.new_group(dp_ranks) in that same global order, storing the group whose pp equals the local pp_rank into dp_group so every process enters new_group calls in identical order while only keeping its own group.
222-223: ⚠️ Potential issue | 🟠 Major
`num_steps == 0` still breaks the post-loop exchange.
If the loop never runs, the last stage falls back to a 1D `final_logits`, and the first stage reaches Line 267 with `bs` and `seq_len` undefined. That gives either an `UnboundLocalError` or a `recv()` shape mismatch. Guard the zero-step case before the training loop, or carry an explicit fallback logits shape into this block.
Also applies to: 256-268
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 222 - 223, The loop over "for step in range(num_steps)" can be skipped when num_steps == 0 which leaves bs and seq_len undefined and yields an invalid post-loop exchange with "final_logits"; add an explicit zero-step guard before the loop (or immediately after getting all_inputs) that handles num_steps == 0 by setting bs and seq_len from a known source (e.g., all_inputs shape[0] or a provided batch/seq metadata) and producing a correctly-shaped fallback final_logits (or short-circuiting the exchange logic) so the downstream recv()/send() logic and the first stage that references bs/seq_len do not see uninitialized variables.
305-326: ⚠️ Potential issue | 🔴 Critical
`broadcast_object_list()` still deadlocks across DP replicas.
These calls omit `group=`, so they execute on the default world group. PyTorch's distributed docs say default-group collectives require all processes to enter the same collective sequence; here ranks 0/2 and 1/3 use different `src` ranks based on their local pipeline peer, so `_gather_pp_state()` can hang as soon as `dp_size > 1`. Use a per-pipeline subgroup for the metadata exchange, or restrict full-state assembly to a single replica. (docs.pytorch.org)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 305 - 326, The metadata exchange uses dist.broadcast_object_list/recv on the default group which deadlocks across data-parallel replicas; update _gather_pp_state so it forms a per-pipeline subgroup (via dist.new_group or torch.distributed.new_group) for the pipeline-peer exchanges and pass group=pipeline_group into dist.broadcast_object_list and dist.recv (references: keys_and_shapes, peer_obj, dist.broadcast_object_list, dist.recv, pp_peer, my_state, global_rank) so only ranks in the same pipeline replica participate in the collectives; alternatively, restrict the full-state assembly to a single replica (e.g. only run the broadcast/recv sequence when local_dp_rank == 0) and then broadcast the assembled state within the DP replica to avoid using the default world group.
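To make the mismatch concrete, the rank layout can be worked out on paper. The sketch below assumes the contiguous layout `rank = pipe_id * pp_size + stage`, which is consistent with the `r % pp_size` stage mapping discussed in this review; `pipeline_replicas` and `first_stage_rank` are hypothetical helper names.

```python
from typing import List


def pipeline_replicas(dp_size: int, pp_size: int) -> List[List[int]]:
    """List the global ranks of each pipeline replica, one list per DP replica."""
    return [
        [pipe_id * pp_size + stage for stage in range(pp_size)]
        for pipe_id in range(dp_size)
    ]


def first_stage_rank(rank: int, pp_size: int) -> int:
    """Global rank of stage 0 in the pipeline replica that `rank` belongs to."""
    return (rank // pp_size) * pp_size


if __name__ == "__main__":
    dp_size, pp_size = 2, 2
    print("pipeline replicas:", pipeline_replicas(dp_size, pp_size))
    for rank in range(dp_size * pp_size):
        print(f"rank {rank}: broadcast source would be rank {first_stage_rank(rank, pp_size)}")
```

With `dp_size=2` and `pp_size=2` the source rank differs between the two replicas, which is why one world-group collective cannot serve both. Inside a per-pipeline subgroup each replica runs its own collective; note that the PyTorch documentation describes `src` as a global rank even when `group=` is passed, so the value from `first_stage_rank` would still be the one to use.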
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@local_test/train_pp.py`:
- Around line 288-300: The saved state tensors are being cloned to CPU (.cpu())
before NCCL send/recv, which will fail since NCCL requires GPU tensors; update
the logic around is_first_stage and the loops that populate my_state
(references: model.model.embed_tokens.state_dict(), all_layers,
model.model.norm.state_dict(), model.lm_head.state_dict()) so that tensors
remain on CUDA for the point-to-point dist.send()/dist.recv() calls (or only
call .cpu() after a recv completes). Concretely, remove or postpone .cpu() in
those assignments (or explicitly call .cuda() on tensors before using
dist.send()/dist.recv()), and ensure any final .cpu() conversion happens after
the corresponding dist.recv() has returned; keep the send/recv usage unchanged
but operate on GPU tensors to satisfy the NCCL backend.
📒 Files selected for processing (1)
local_test/train_pp.py
Newer transformers moved rotary embedding to model level; layers now require position_embeddings=(cos, sin) explicitly. Without it, 'cos, sin = None' causes TypeError during forward pass. Also replaced forbidden getattr() with hasattr + direct access. Made-with: Cursor
Newer transformers returns a raw tensor (not tuple) from decoder layers when use_cache=False and output_attentions=False. Using [0] on a tensor slices the batch dim instead of extracting hidden_states. Now checks isinstance(out, tuple) before indexing.
Also passes rotary position_embeddings to each layer (required since rotary emb moved to model level) and uses hasattr instead of forbidden getattr().
Verified locally:
- AST security scan: PASSED
- Syntax/symbols check: PASSED
- Forward pass with Qwen2.5-7B config: PASSED
Made-with: Cursor
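The tuple check described in this commit can be sketched in isolation. The helper name `unwrap_hidden_states` is hypothetical, and nested lists stand in for tensors so the example runs without PyTorch; the commit's actual code is not shown here.

```python
def unwrap_hidden_states(layer_output):
    """Return hidden states whether the layer returned a tuple or a bare value."""
    if isinstance(layer_output, tuple):
        return layer_output[0]  # older style: (hidden_states, ...)
    return layer_output         # newer style: hidden_states directly


if __name__ == "__main__":
    # A stand-in "tensor" with batch size 2; real code would use torch tensors.
    hidden = [[1.0, 2.0], [3.0, 4.0]]

    print(unwrap_hidden_states((hidden, None)))  # tuple output: first element
    print(unwrap_hidden_states(hidden))          # bare output: returned unchanged
    print(hidden[0])                             # naive [0]: only the first batch row
```

The last line shows the failure mode the commit describes: indexing a bare output with `[0]` silently drops everything but the first batch element instead of raising.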
♻️ Duplicate comments (5)
local_test/train_pp.py (5)
244-247: ⚠️ Potential issue | 🟡 Minor
Hardcoded `bfloat16` dtype may cause silent mismatch.
While line 124 casts the model to `bfloat16`, hardcoding the dtype here is fragile. If the model dtype changes, the recv buffer would silently mismatch, causing incorrect results or errors.
Proposed fix: infer dtype from model

    + model_dtype = next(model.parameters()).dtype
    - recv_buf = torch.zeros(bs, seq_len, hidden_size, device=device, dtype=torch.bfloat16)
    + recv_buf = torch.zeros(bs, seq_len, hidden_size, device=device, dtype=model_dtype)

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 244 - 247, The recv buffer is hardcoded to torch.bfloat16 which can mismatch the model dtype; instead, derive the dtype dynamically (e.g., from the model or an existing tensor) and use that when allocating recv_buf so it matches the model’s parameters; update the allocation of recv_buf (the variable created before dist.recv) to use the inferred dtype and then keep the subsequent detach().requires_grad_(True) for hidden.
299-311: ⚠️ Potential issue | 🔴 Critical
CPU tensors will fail with NCCL `send`/`recv` operations.
Lines 301, 305, 309, 311 clone tensors to CPU, then lines 327 and 340 use these CPU tensors with `dist.recv()` and `dist.send()`. NCCL backend only supports GPU point-to-point operations, so this will raise a runtime error.
Keep tensors on CUDA for the communication, then move to CPU afterward.
Proposed fix

      if is_first_stage:
          for k, v in model.model.embed_tokens.state_dict().items():
    -         my_state[f"model.embed_tokens.{k}"] = v.detach().cpu().clone()
    +         my_state[f"model.embed_tokens.{k}"] = v.detach().clone()
      for idx in my_layer_indices:
          for k, v in all_layers[idx].state_dict().items():
    -         my_state[f"model.layers.{idx}.{k}"] = v.detach().cpu().clone()
    +         my_state[f"model.layers.{idx}.{k}"] = v.detach().clone()
      if not is_first_stage:
          for k, v in model.model.norm.state_dict().items():
    -         my_state[f"model.norm.{k}"] = v.detach().cpu().clone()
    +         my_state[f"model.norm.{k}"] = v.detach().clone()
          for k, v in model.lm_head.state_dict().items():
    -         my_state[f"lm_head.{k}"] = v.detach().cpu().clone()
    +         my_state[f"lm_head.{k}"] = v.detach().clone()

Then move to CPU after gathering is complete (at line 330):

    - return my_state if global_rank == 0 else None
    + if global_rank == 0:
    +     return {k: v.cpu() for k, v in my_state.items()}
    + return None

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 299 - 311, The code is copying tensors to CPU before distributed send/recv which will fail with the NCCL backend; instead keep tensors on CUDA during communication and only move them to CPU after all dist.send()/dist.recv() operations complete. Specifically, in the block that populates my_state from model.model.embed_tokens, all_layers (iterating my_layer_indices), model.model.norm, and model.lm_head, remove the .cpu()/.detach().cpu().clone() so tensors remain on CUDA (keep .detach().clone() if needed), perform the dist.send()/dist.recv() with those GPU tensors, and after the gather/communication steps convert stored tensors in my_state to CPU (call .cpu()/.clone() once) before any CPU-only usage or saving.
267-272: ⚠️ Potential issue | 🟡 Minor
Fallback logits tensor has wrong shape for validation.
When num_steps=0, final_logits remains None and the fallback creates a 1D tensor zeros(1). This will fail any downstream shape validation expecting 3D (batch, seq, vocab).
Proposed fix
 if is_last_stage:
     if final_logits is None:
-        final_logits = torch.zeros(1, device=device)
+        vocab_size = model.config.vocab_size
+        final_logits = torch.zeros(1, 1, vocab_size, device=device, dtype=torch.bfloat16)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 267 - 272, final_logits fallback uses torch.zeros(1) which is 1D and breaks downstream validation; when final_logits is None (e.g., num_steps==0) allocate a zeros tensor with the expected 3D shape (batch, seq, vocab) using the same device and dtype as the model outputs and ensure it is contiguous before dist.send. Locate the final_logits check in the is_last_stage block and replace the torch.zeros(1, device=device) with torch.zeros(batch_size, seq_len, vocab_size, device=device, dtype=final_logits_dtype_or_model_dtype) (or derive shape/dtype from saved model/output shape variables), then call final_logits.contiguous() before sending to pp_peer.
316-340: ⚠️ Potential issue | 🔴 Critical
_gather_pp_state() uses mismatched world-group broadcasts that will deadlock when dp_size > 1.
With dp_size=2 and pp_size=2:
- Ranks 0 and 2 (first stage) call broadcast_object_list(src=0) and broadcast_object_list(src=2) respectively
- Ranks 1 and 3 call broadcast_object_list(src=1) and broadcast_object_list(src=3) respectively
Using world-group broadcasts with different src values across replicas causes a collective mismatch deadlock. Use per-pipeline subgroups or restrict state gathering to a single replica.
Proposed fix: use per-pipeline process group
Create pipeline groups at the start of inner_steps (similar to DP groups), then pass to _gather_pp_state:
+    # Create per-pipeline groups for state gathering
+    pp_group = None
+    if pp_size > 1:
+        for pipe_id in range(dp_size):
+            pipe_ranks = [pipe_id * pp_size + s for s in range(pp_size)]
+            group = dist.new_group(pipe_ranks)
+            if rank in pipe_ranks:
+                pp_group = group
Then in _gather_pp_state, use group=pp_group for all broadcasts and adjust src to be relative to the pipeline group.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 316 - 340, _gather_pp_state() currently uses world-wide dist.broadcast_object_list calls with differing src ranks per pipeline replica which deadlocks when dp_size>1; change it to use a per-pipeline process group and use group-scoped collectives with src values relative to that group. Specifically: create pipeline groups (e.g., pp_group) alongside DP groups in inner_steps, pass pp_group into _gather_pp_state, and replace broadcast_object_list(..., src=...) and any send/recv that assume world ranks with broadcast_object_list(..., src=local_src, group=pp_group) and send/recv using ranks local to pp_group so all replicas in the same pipeline use the same src and group. Ensure pp_peer is computed relative to the pipeline group membership used.
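The rank layout behind this finding can be sketched without a distributed runtime. The helpers below are hypothetical illustrations (the names `pipeline_groups` and `my_pipeline` do not come from the PR); they assume the layout used in the proposed fix, where pipeline replica `pipe_id` owns world ranks `pipe_id * pp_size + stage`.

```python
def pipeline_groups(dp_size: int, pp_size: int) -> list[list[int]]:
    """World ranks of each pipeline replica, listed in a fixed order."""
    return [
        [pipe_id * pp_size + stage for stage in range(pp_size)]
        for pipe_id in range(dp_size)
    ]


def my_pipeline(rank: int, dp_size: int, pp_size: int) -> list[int]:
    """Return the pipeline replica (list of world ranks) that contains `rank`."""
    for ranks in pipeline_groups(dp_size, pp_size):
        if rank in ranks:
            return ranks
    raise ValueError(f"rank {rank} is outside the {dp_size}x{pp_size} layout")


if __name__ == "__main__":
    # With dp_size=2 and pp_size=2, ranks 0-1 and ranks 2-3 form separate pipelines.
    for r in range(4):
        print(r, my_pipeline(r, dp_size=2, pp_size=2))
```

Under this assumed layout, a collective scoped to one of these lists never involves ranks from the other replica, which is the property the suggested per-pipeline groups rely on.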
198-201: ⚠️ Potential issue | 🔴 Critical
DP group creation will deadlock: all ranks must create groups in the same order.
PyTorch's new_group() is a collective operation requiring all ranks to participate in the same order. Currently:
- Ranks with pp_rank=0 call new_group([0,2])
- Ranks with pp_rank=1 call new_group([1,3])
This divergence causes a hang before the first gradient sync.
Proposed fix: create all DP groups deterministically on every rank
 dp_group = None
 if dp_size > 1:
-    dp_ranks = [r for r in range(num_gpus) if (r % pp_size) == pp_rank]
-    dp_group = dist.new_group(dp_ranks)
+    for stage in range(pp_size):
+        ranks = [r for r in range(num_gpus) if (r % pp_size) == stage]
+        group = dist.new_group(ranks)
+        if stage == pp_rank:
+            dp_group = group
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 198 - 201, The DP group creation currently calls dist.new_group() only on ranks sharing the same pp_rank which diverges across ranks and deadlocks; instead, deterministically create all DP groups in the same order on every rank (e.g., iterate k in range(pp_size) sorted, compute dp_ranks_k = [r for r in range(num_gpus) if (r % pp_size) == k], call dist.new_group(dp_ranks_k) for each k, store handles in a list like dp_groups) and then set dp_group = dp_groups[pp_rank] (or None if dp_size <= 1); use the existing symbols dp_group, dp_size, dp_ranks, pp_size, pp_rank, num_gpus, and dist.new_group to locate and change the code.
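A minimal sketch of the ordering argument, in plain Python with no `torch.distributed` calls. The function `dp_group_plan` is a hypothetical helper, and it assumes `pp_rank = rank % pp_size`, which matches the rank filter in the proposed fix but is not confirmed elsewhere in the PR.

```python
def dp_group_plan(rank: int, num_gpus: int, pp_size: int):
    """Return (creation_order, own_group) for one rank.

    creation_order is the sequence of rank lists that would be passed to
    new_group(); it does not depend on `rank`, so every process would issue
    the same calls in the same order.
    """
    creation_order = [
        [r for r in range(num_gpus) if r % pp_size == stage]
        for stage in range(pp_size)
    ]
    own_group = creation_order[rank % pp_size]  # assumed pp_rank = rank % pp_size
    return creation_order, own_group


if __name__ == "__main__":
    for r in range(4):
        order, own = dp_group_plan(r, num_gpus=4, pp_size=2)
        print(f"rank {r}: creates {order}, keeps {own}")
```

Each rank walks every stage and keeps only the handle for its own stage, so no rank skips a group creation that another rank is waiting on.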
🧹 Nitpick comments (2)
local_test/train_pp.py (2)
339-340: Unused loop variable k.
The loop variable k is not used in the loop body.
Proposed fix
-    for k, v in my_state.items():
+    for _k, v in my_state.items():
         dist.send(v.contiguous(), dst=pp_peer)
Or simply:
-    for k, v in my_state.items():
+    for v in my_state.values():
         dist.send(v.contiguous(), dst=pp_peer)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 339 - 340, The loop over my_state uses an unused variable k; update the loop to ignore it (e.g., for _, v in my_state.items()) or iterate only over values (for v in my_state.values()) and then call dist.send(v.contiguous(), dst=pp_peer) so the unused variable is removed and intent is clear; locate the loop that references my_state, k, v and dist.send to apply this change.
40-65: Consider logging exceptions in configuration blocks.
The try-except-pass pattern silently swallows all exceptions, making debugging difficult if these configurations fail unexpectedly. While non-critical for optional features, logging at debug level would aid troubleshooting.
Example for one block
+import logging
+_logger = logging.getLogger(__name__)
+
 try:
     torch.cuda.memory._set_allocator_settings("expandable_segments:True")
-except Exception:
-    pass
+except Exception as e:
+    _logger.debug("CUDA allocator settings not applied: %s", e)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@local_test/train_pp.py` around lines 40 - 65, The three silent try/except blocks (the torch.cuda.memory._set_allocator_settings call and the torch._inductor.config and torch._dynamo.config configuration blocks referencing _ind_cfg and _dyn_cfg) should log exceptions instead of swallowing them; add a module-level logger (import logging; logger = logging.getLogger(__name__)) and in each except block catch Exception as e and call logger.debug or logger.exception with a short context message (e.g., "failed to set CUDA allocator settings", "failed to apply inductor config", "failed to apply dynamo config") and include the exception info (exc_info=True) so failures are visible during debugging.
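One way to apply this to all three blocks without repeating the try/except is a small wrapper. This is only a sketch: `try_configure` is a hypothetical helper, not something present in `train_pp.py`, and it uses only the standard `logging` module.

```python
import logging

logger = logging.getLogger(__name__)


def try_configure(description: str, apply) -> bool:
    """Run an optional configuration callable; log failures instead of hiding them."""
    try:
        apply()
    except Exception:
        # exc_info=True attaches the traceback to the debug record.
        logger.debug("failed to %s", description, exc_info=True)
        return False
    return True


if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)

    def broken_setting():
        raise RuntimeError("setting not supported on this build")

    print(try_configure("set CUDA allocator settings", broken_setting))
    print(try_configure("apply a no-op setting", lambda: None))
```

The return value lets a caller decide whether a missing optional setting matters, while the default behaviour stays as quiet as the original `pass`.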
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Duplicate comments:
In `@local_test/train_pp.py`:
- Around line 244-247: The recv buffer is hardcoded to torch.bfloat16 which can
mismatch the model dtype; instead, derive the dtype dynamically (e.g., from the
model or an existing tensor) and use that when allocating recv_buf so it matches
the model’s parameters; update the allocation of recv_buf (the variable created
before dist.recv) to use the inferred dtype and then keep the subsequent
detach().requires_grad_(True) for hidden.
📒 Files selected for processing (1)
local_test/train_pp.py
- Keep state tensors on CUDA during P2P transfer (NCCL rejects CPU)
- Create per-pipeline sub-groups for broadcast_object_list
- Move to CPU only in final return
- All ranks participate in all new_group() calls (collective requirement)
Made-with: Cursor
- Add NCCL_MAX_NCHANNELS=1 to Dockerfile, runner.py, tester.py, and simulate_validator.py to limit NCCL to 1 channel (~3-6 GB/s effective). Combined with P2P_DISABLE, this makes collective-heavy strategies (FSDP/DDP/TP) communication-bound while PP sends only tiny activations.
- Rewrite train_pp.py: dp=1, pp=4 (4-stage pipeline, zero collectives). Uses standard 1F1B schedule with M=16 microbatches (84.2% utilization). isend for forward activation transfers avoids cross-stage deadlocks. Expected MFU: 55-65% under throttling vs <30% for FSDP.
- Fix train.py: restore get_strategy to dp=4 matching its FSDP code. Previously declared dp=2,pp=2 but ran FSDP, halving MFU credit.
- Fix train_mixed.py: undefined _CHUNK variable → chunk_size.
Made-with: Cursor
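The 84.2% figure quoted for the 1F1B schedule is consistent with the usual pipeline-bubble estimate M / (M + P - 1) for P stages and M microbatches. The sketch below assumes that formula and equal per-microbatch time on every stage, and it ignores communication cost; `pipeline_utilization` is a hypothetical helper name.

```python
def pipeline_utilization(num_stages: int, num_microbatches: int) -> float:
    """Fraction of time a stage is busy under the bubble estimate M / (M + P - 1)."""
    if num_stages < 1 or num_microbatches < 1:
        raise ValueError("num_stages and num_microbatches must be >= 1")
    return num_microbatches / (num_microbatches + num_stages - 1)


if __name__ == "__main__":
    # pp=4 stages and M=16 microbatches, as in the commit message: 16 / 19.
    print(f"{pipeline_utilization(4, 16):.1%}")
```

With P=4 and M=16 this gives 16/19, which rounds to the 84.2% stated in the commit message; raising M shrinks the bubble further at the cost of smaller microbatches.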
Actionable comments posted: 8
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@local_test/train_mixed.py`:
- Around line 171-172: The inner_steps function must reject num_steps == 0
because the preload/final-logit path expects at least one batch; add an early
check in inner_steps (e.g., at the top of the function where
_prepare_model(model) is called) that raises a clear ValueError or
AssertionError if num_steps <= 0, so loss is always assigned and subsequent code
that does all_inputs[-1] / accesses the last batch is safe.
- Around line 207-217: The loop currently overwrites tokens_per_batch each
iteration and uses batch.numel() (which counts the dropped token) to compute
total_tokens; instead, accumulate the actual tokens used for forward passes by
summing the sizes of the preloaded slices (use batch[:, :-1].numel() or
equivalent) across iterations and assign that running sum to total_tokens (or a
tokens_accum variable) so total_tokens = sum_of_slices over the num_steps;
update references to tokens_per_batch/total_tokens accordingly and keep using
all_inputs/all_labels as before (look for variables all_inputs, all_labels,
tokens_per_batch, total_tokens, data_iterator, num_steps).
- Around line 153-155: The loop currently hardcodes every attention layer's
cache index to 0 which corrupts the KV cache; in the for loop that iterates over
model.model.layers (the enumerate loop with variables idx and layer), set each
attention module's layer_idx to the loop index instead of a literal 0—i.e.,
assign layer.self_attn.layer_idx using idx so each layer uses a unique cache
slot during autoregressive generation.
In `@local_test/train_pp.py`:
- Line 243: Clamp num_warmup to the actual microbatch count: replace the current
calculation of num_warmup (num_warmup = pp_size - 1 - pp_rank) with a guarded
clamp using the length of micro_inputs (e.g., num_micro = len(micro_inputs)) so
num_warmup = max(0, min(pp_size - 1 - pp_rank, num_micro - 1)). Apply the same
clamp in the other occurrence around lines 341-356 so Phase 1/3 never index past
micro_inputs and backprop slots remain aligned.
- Around line 129-136: The code currently calls model.to(device, dtype=...)
before pruning, which materializes the full model on every rank; move the
pruning of non-owned endpoint modules (embed_tokens, model.norm, lm_head and any
middle-stage modules not in my_layer_indices) to run before model.to(device) and
also apply the same stage-pruning logic in the earlier blocks (around the
all_layers loop and before building trainable_params) so ranks that never
execute those endpoints set p.requires_grad_(False) and replace p.data/buffers
with empty tensors (or delete them) on CPU first; ensure the optimizer
construction that builds trainable_params (the code around where
trainable_params is assembled) only sees the retained parameters after this
pruning so you avoid peak full-model allocation and exclude endpoint weights on
non-owning ranks.
- Around line 214-236: The total_tokens computation is wrong because
tokens_per_batch is overwritten each loop and then multiplied by num_steps;
instead accumulate the actual token count per batch as you iterate over
data_iterator. Inside the for _ in range(num_steps) loop (which uses batch =
next(data_iterator) and sets tokens_per_batch = batch.numel()), add the
batch.numel() to a running counter (e.g., total_tokens) rather than relying on a
single tokens_per_batch value, and remove the final total_tokens = num_steps *
tokens_per_batch so total_tokens reflects the true sum of batch.numel() across
iterations.
- Around line 453-460: When reconstructing the full model on rank 0, don't keep
tensors on CUDA: create full_state as CPU copies of my_state (e.g., full_state =
{k: v.cpu() for k, v in my_state.items()}) and allocate recv buffers on CPU (buf
= torch.empty(shape, dtype=dtype, device='cpu')) before calling
dist.recv(src=peer); store the received CPU buffer into full_state[k] so no
per-peer tensors are ever assembled on the GPU and you avoid the transient
full-model GPU memory spike.
- Around line 116-122: The code currently uses device.index to compute pp_rank
which is brittle; instead derive pp_rank from the distributed rank (use rank =
dist.get_rank() when dist.is_initialized(), else 0) and assign pp_rank = rank;
keep local_rank as device.index or torch.cuda.current_device() for local device
mapping if needed, then compute is_first_stage, is_last_stage, pp_prev and
pp_next from pp_rank (e.g., is_first_stage = pp_rank == 0, is_last_stage =
pp_rank == pp_size - 1, pp_prev = pp_rank - 1 if not is_first_stage else -1,
pp_next = pp_rank + 1 if not is_last_stage else -1) so P2P routing and layer
selection (all_stage_layers[pp_rank]) use the distributed rank rather than
device.index.
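The `num_warmup` clamp requested in the list above (the item for line 243) can be checked in isolation. This sketch copies the expression from the review comment verbatim; `warmup_count` is a hypothetical wrapper, and whether `num_micro - 1` or `num_micro` is the right upper bound depends on how the schedule in `train_pp.py` indexes its phases, which is not visible here.

```python
def warmup_count(pp_size: int, pp_rank: int, num_micro: int) -> int:
    """Warmup forward passes for one stage, clamped to the microbatch count."""
    return max(0, min(pp_size - 1 - pp_rank, num_micro - 1))


if __name__ == "__main__":
    # Plenty of microbatches: the unclamped value pp_size - 1 - pp_rank is used.
    print([warmup_count(4, r, 16) for r in range(4)])
    # Only two microbatches: early stages are clamped so they cannot index past the list.
    print([warmup_count(4, r, 2) for r in range(4)])
```

The outer `max(0, ...)` also covers the degenerate case of an empty microbatch list, where the inner `min` would otherwise go negative.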
📒 Files selected for processing (7)
- arbos/tester.py
- environments/templar/Dockerfile
- local_test/simulate_validator.py
- local_test/train.py
- local_test/train_mixed.py
- local_test/train_pp.py
- src/crusades/affinetes/runner.py
✅ Files skipped from review due to trivial changes (2)
- environments/templar/Dockerfile
- local_test/simulate_validator.py
🚧 Files skipped from review as they are similar to previous changes (3)
- src/crusades/affinetes/runner.py
- local_test/train.py
- arbos/tester.py
for idx, layer in enumerate(model.model.layers):
    if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "layer_idx"):
        layer.self_attn.layer_idx = 0
🧩 Analysis chain
🌐 Web query:
In the Hugging Face transformer attention implementation used by this trainer, what is self_attn.layer_idx used for, and is it safe to set every layer's self_attn.layer_idx = 0 instead of preserving each layer's original index?
💡 Result:
self_attn.layer_idx is used to identify the specific layer when updating the key-value (KV) cache during autoregressive generation. In the attention forward pass (e.g., LlamaAttention), when past_key_values (a Cache object) is provided, it calls past_key_values.update(key_states, value_states, self.layer_idx, ...), where layer_idx specifies which layer's cache slice to append the new KV states to. The Cache is structured as a list of per-layer caches, indexed by this layer_idx to ensure KV states are stored and retrieved correctly for each layer.
It is NOT safe to set every layer's self_attn.layer_idx = 0. All layers would then update the same cache slot (layer 0), causing:
- Overwriting of KV states from other layers.
- Retrieval of incorrect past KV states during subsequent forward passes.
- Index errors if the cache has fewer layers than expected, or silent corruption otherwise.
This is evidenced by GitHub issues where layer removal or reordering without updating layer_idx causes index errors in cache.update (e.g., list index out of range when layer_idx exceeds cache length). Preserving original indices ensures correct per-layer KV caching, critical for the Trainer's use of the model during evaluation or generation.
Citations:
- 1: Removing model layers throws an index error. huggingface/transformers#30508
- 2: Move layer_idx from a layer property to function argument. huggingface/transformers#28462
- 3: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
- 4: Cache updating when use_cache = False huggingface/transformers#32843
- 5: https://github.com/huggingface/transformers/blob/main/src/transformers/cache_utils.py
- 6: https://huggingface.co/docs/transformers/v4.55.4/cache_explanation
Do not hardcode layer_idx = 0 for all attention layers; use the loop index instead.
Setting layer.self_attn.layer_idx = 0 for every layer breaks the KV cache mechanism. The layer_idx attribute is used by Hugging Face transformers to index into the per-layer cache during autoregressive generation. When all layers have layer_idx = 0, they all write to and read from the same cache slot, causing KV state corruption, incorrect inference, and potential index errors. Change line 155 to:
layer.self_attn.layer_idx = idx
This ensures each layer updates its own cache slot.
🧰 Tools
🪛 Ruff (0.15.9)
[warning] 153-153: Loop control variable idx not used within loop body
(B007)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@local_test/train_mixed.py` around lines 153 - 155, The loop currently
hardcodes every attention layer's cache index to 0 which corrupts the KV cache;
in the for loop that iterates over model.model.layers (the enumerate loop with
variables idx and layer), set each attention module's layer_idx to the loop
index instead of a literal 0—i.e., assign layer.self_attn.layer_idx using idx so
each layer uses a unique cache slot during autoregressive generation.
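The effect of a shared `layer_idx` can be seen with a toy cache. `ToyLayerCache` below is a deliberately simplified stand-in written for this illustration; it is not the Hugging Face `Cache` class and only mimics the idea of per-layer slots keyed by `layer_idx`.

```python
class ToyLayerCache:
    """Toy per-layer store keyed by layer_idx (not the Hugging Face Cache API)."""

    def __init__(self) -> None:
        self.slots: dict[int, list[str]] = {}

    def update(self, kv: str, layer_idx: int) -> list[str]:
        # Append this layer's new state to the slot selected by layer_idx.
        self.slots.setdefault(layer_idx, []).append(kv)
        return self.slots[layer_idx]


def run_layers(layer_indices: list[int]) -> dict[int, list[str]]:
    """Simulate one forward pass where layer n writes into slot layer_indices[n]."""
    cache = ToyLayerCache()
    for n, layer_idx in enumerate(layer_indices):
        cache.update(f"kv-from-layer-{n}", layer_idx)
    return cache.slots


if __name__ == "__main__":
    print(run_layers([0, 1, 2]))  # one slot per layer
    print(run_layers([0, 0, 0]))  # every layer lands in slot 0
```

With distinct indices each layer keeps its own slot; with every index set to 0, states from all layers pile into a single slot, which is the mixing the review comment warns about.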
def inner_steps(model, data_iterator, optimizer, num_steps, device, num_gpus=1):
    _prepare_model(model)
Reject zero-step calls up front.
The new preload/final-logit path assumes at least one batch. If num_steps == 0, loss is never assigned and Lines 257-258 index all_inputs[-1].
💡 Suggested change
 def inner_steps(model, data_iterator, optimizer, num_steps, device, num_gpus=1):
+    if num_steps <= 0:
+        raise ValueError("num_steps must be >= 1")
     _prepare_model(model)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@local_test/train_mixed.py` around lines 171 - 172, The inner_steps function
must reject num_steps == 0 because the preload/final-logit path expects at least
one batch; add an early check in inner_steps (e.g., at the top of the function
where _prepare_model(model) is called) that raises a clear ValueError or
AssertionError if num_steps <= 0, so loss is always assigned and subsequent code
that does all_inputs[-1] / accesses the last batch is safe.
all_inputs = []
all_labels = []
tokens_per_batch = 0
for _ in range(num_steps):
    batch = next(data_iterator).to(device, dtype=torch.long, non_blocking=True)
    all_inputs.append(batch[:, :-1].contiguous())
    all_labels.append(batch[:, 1:].contiguous())
    tokens_per_batch = batch.numel()

for step in range(num_steps):
    batch = next(data_iterator).to(device, dtype=torch.long)
    input_ids = batch[:, :-1]
    labels = batch[:, 1:]
torch.cuda.synchronize(device)
total_tokens = num_steps * tokens_per_batch
Accumulate total_tokens from the actual preloaded slices.
Line 214 overwrites tokens_per_batch on each iteration, and Line 217 multiplies only the last value by num_steps. That miscounts heterogeneous batches, and batch.numel() also includes the token you drop before the forward pass. Since the harness uses total_tokens for TPS/MFU, the downstream metrics drift.
💡 Suggested change
 all_inputs = []
 all_labels = []
-tokens_per_batch = 0
+total_tokens = 0
 for _ in range(num_steps):
     batch = next(data_iterator).to(device, dtype=torch.long, non_blocking=True)
-    all_inputs.append(batch[:, :-1].contiguous())
-    all_labels.append(batch[:, 1:].contiguous())
-    tokens_per_batch = batch.numel()
+    inputs = batch[:, :-1].contiguous()
+    labels = batch[:, 1:].contiguous()
+    all_inputs.append(inputs)
+    all_labels.append(labels)
+    total_tokens += inputs.numel()
 torch.cuda.synchronize(device)
-total_tokens = num_steps * tokens_per_batch
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@local_test/train_mixed.py` around lines 207 - 217, The loop currently
overwrites tokens_per_batch each iteration and uses batch.numel() (which counts
the dropped token) to compute total_tokens; instead, accumulate the actual
tokens used for forward passes by summing the sizes of the preloaded slices (use
batch[:, :-1].numel() or equivalent) across iterations and assign that running
sum to total_tokens (or a tokens_accum variable) so total_tokens = sum_of_slices
over the num_steps; update references to tokens_per_batch/total_tokens
accordingly and keep using all_inputs/all_labels as before (look for variables
all_inputs, all_labels, tokens_per_batch, total_tokens, data_iterator,
num_steps).
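The gap between the two counting strategies shows up with a small example. The sketch below uses nested Python lists in place of tensors so it runs without PyTorch; both function names are hypothetical, and each inner list stands for one row of token ids.

```python
def count_input_tokens(batches: list[list[list[int]]]) -> int:
    """Sum the tokens actually fed forward: every row drops its final token."""
    total = 0
    for batch in batches:
        for row in batch:
            total += max(0, len(row) - 1)
    return total


def last_batch_estimate(batches: list[list[list[int]]]) -> int:
    """The flagged pattern: size of the last batch multiplied by the step count."""
    if not batches:
        return 0
    return len(batches) * sum(len(row) for row in batches[-1])


if __name__ == "__main__":
    # Two steps: a 2x4 batch followed by a shorter 1x3 batch.
    batches = [[[1, 2, 3, 4], [5, 6, 7, 8]], [[1, 2, 3]]]
    print(count_input_tokens(batches), last_batch_estimate(batches))
```

When every batch has the same shape, the two counts differ only by the dropped token per row; with a short final batch, the estimate is also skewed by using that batch as the representative size.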
rank = dist.get_rank() if dist.is_initialized() else 0
local_rank = device.index if device.index is not None else 0
pp_rank = local_rank
is_first_stage = pp_rank == 0
is_last_stage = pp_rank == pp_size - 1
pp_prev = pp_rank - 1 if not is_first_stage else -1
pp_next = pp_rank + 1 if not is_last_stage else -1
🧩 Analysis chain
🏁 Script executed:
find . -name "train_pp.py" -type f
Repository: one-covenant/crusades
Length of output: 89
🏁 Script executed:
wc -l local_test/train_pp.py
Repository: one-covenant/crusades
Length of output: 91
🏁 Script executed:
sed -n '100,135p' local_test/train_pp.py
Repository: one-covenant/crusades
Length of output: 1382
🏁 Script executed:
sed -n '1,50p' local_test/train_pp.py
Repository: one-covenant/crusades
Length of output: 1716
🏁 Script executed:
sed -n '110,200p' local_test/train_pp.py
Repository: one-covenant/crusades
Length of output: 2996
🏁 Script executed:
rg "dist\.(send|recv|irecv|isend)" local_test/train_pp.py -A 2 -B 2
Repository: one-covenant/crusades
Length of output: 2594
🏁 Script executed:
rg "def get_strategy" local_test/train_pp.py -A 10
Repository: one-covenant/crusades
Length of output: 360
Route PP by distributed rank, not device.index.
device.index is brittle: torch.device("cuda") yields index=None, so every process falls back to pp_rank == 0. This breaks layer selection (all_stage_layers[pp_rank]), stage detection (is_first_stage, is_last_stage), and P2P routing (pp_prev, pp_next), since all ranks will attempt to communicate as rank 0. For this fixed dp=1,tp=1,pp=4 topology where the PP group is the default world group, pp_rank should come from dist.get_rank() instead.
Suggested change
 rank = dist.get_rank() if dist.is_initialized() else 0
-local_rank = device.index if device.index is not None else 0
-pp_rank = local_rank
+pp_rank = rank if dist.is_initialized() else 0
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@local_test/train_pp.py` around lines 116 - 122, The code currently uses
device.index to compute pp_rank which is brittle; instead derive pp_rank from
the distributed rank (use rank = dist.get_rank() when dist.is_initialized(),
else 0) and assign pp_rank = rank; keep local_rank as device.index or
torch.cuda.current_device() for local device mapping if needed, then compute
is_first_stage, is_last_stage, pp_prev and pp_next from pp_rank (e.g.,
is_first_stage = pp_rank == 0, is_last_stage = pp_rank == pp_size - 1, pp_prev =
pp_rank - 1 if not is_first_stage else -1, pp_next = pp_rank + 1 if not
is_last_stage else -1) so P2P routing and layer selection
(all_stage_layers[pp_rank]) use the distributed rank rather than device.index.
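The stage bookkeeping described here reduces to a small pure function of the distributed rank. The sketch below assumes the fixed dp=1 topology named in the comment, so that the pipeline rank equals the world rank; `StageInfo` and `stage_info` are hypothetical names introduced for illustration.

```python
from typing import NamedTuple


class StageInfo(NamedTuple):
    pp_rank: int
    is_first_stage: bool
    is_last_stage: bool
    pp_prev: int  # -1 when there is no previous stage
    pp_next: int  # -1 when there is no next stage


def stage_info(rank: int, pp_size: int) -> StageInfo:
    """Derive stage flags and P2P neighbours from the distributed rank (dp=1 assumed)."""
    if not 0 <= rank < pp_size:
        raise ValueError(f"rank {rank} is outside a pipeline of size {pp_size}")
    is_first = rank == 0
    is_last = rank == pp_size - 1
    return StageInfo(
        pp_rank=rank,
        is_first_stage=is_first,
        is_last_stage=is_last,
        pp_prev=rank - 1 if not is_first else -1,
        pp_next=rank + 1 if not is_last else -1,
    )


if __name__ == "__main__":
    for r in range(4):
        print(stage_info(r, pp_size=4))
```

Because the result depends only on the rank reported by the process group, it stays correct when the device is created as a bare `torch.device("cuda")` with no index.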
model = model.to(device=device, dtype=torch.bfloat16)

for i, layer in enumerate(all_layers):
    if i not in my_layer_indices:
        for p in layer.parameters():
            p.requires_grad_(False)
            p.data = torch.empty(0, dtype=p.dtype, device=device)
This still forces each rank to pay almost the full-model memory cost.
model.to(device) runs before pruning, so every process briefly materializes the whole model. After that, only model.model.layers are trimmed, leaving the middle stages still holding embed_tokens, model.norm, and lm_head; when this function builds its own optimizer, Line 201 also includes those unused weights in trainable_params. On a 262K-vocab model, those endpoint modules dominate memory, so this can wipe out most of the PP headroom.
Please stage-prune the non-owned endpoint modules too on the ranks that never execute them, ideally before the .to(device) so you avoid the peak allocation.
Also applies to: 176-195, 201-210
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@local_test/train_pp.py` around lines 129 - 136, The code currently calls
model.to(device, dtype=...) before pruning, which materializes the full model on
every rank; move the pruning of non-owned endpoint modules (embed_tokens,
model.norm, lm_head and any middle-stage modules not in my_layer_indices) to run
before model.to(device) and also apply the same stage-pruning logic in the
earlier blocks (around the all_layers loop and before building trainable_params)
so ranks that never execute those endpoints set p.requires_grad_(False) and
replace p.data/buffers with empty tensors (or delete them) on CPU first; ensure
the optimizer construction that builds trainable_params (the code around where
trainable_params is assembled) only sees the retained parameters after this
pruning so you avoid peak full-model allocation and exclude endpoint weights on
non-owning ranks.
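One way to decide which endpoint modules a stage may drop is sketched below. The helper names are hypothetical. The ownership rule (first stage owns `embed_tokens`, last stage owns `norm` and `lm_head`) is an assumption based on the comment above, and it only holds when `lm_head` does not share weights with `embed_tokens`.

```python
ENDPOINT_MODULES = ("embed_tokens", "norm", "lm_head")


def owned_endpoints(pp_rank: int, pp_size: int) -> set:
    """Endpoint modules a stage actually executes (assumes untied weights)."""
    owned = set()
    if pp_rank == 0:
        owned.add("embed_tokens")          # only the first stage embeds tokens
    if pp_rank == pp_size - 1:
        owned.update({"norm", "lm_head"})  # only the last stage produces logits
    return owned


def endpoints_to_prune(pp_rank: int, pp_size: int) -> set:
    """Endpoint modules that could be emptied on this stage before .to(device)."""
    return set(ENDPOINT_MODULES) - owned_endpoints(pp_rank, pp_size)
```

With `pp_size = 4`, a middle stage prunes all three endpoint modules, while stage 0 keeps only `embed_tokens` and the last stage keeps `norm` and `lm_head`.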
all_micro_inputs = []
all_micro_labels = []
tokens_per_batch = 0
hidden_size = model.config.hidden_size
n_microbatches = _NUM_MICROBATCHES

for _ in range(num_steps):
    batch = next(data_iterator).to(device, dtype=torch.long, non_blocking=True)
    inp = batch[:, :-1].contiguous()
    lab = batch[:, 1:].contiguous()
    tokens_per_batch = batch.numel()
    bs = inp.shape[0]
    micro_bs = max(1, bs // n_microbatches)
    step_inputs = []
    step_labels = []
    for m in range(0, bs, micro_bs):
        step_inputs.append(inp[m : m + micro_bs])
        step_labels.append(lab[m : m + micro_bs])
    all_micro_inputs.append(step_inputs)
    all_micro_labels.append(step_labels)

torch.cuda.synchronize(device)
total_tokens = num_steps * tokens_per_batch
Accumulate total_tokens instead of assuming fixed batch shapes.
tokens_per_batch is overwritten on each iteration and then multiplied by num_steps, so the returned count is wrong as soon as the iterator yields a shorter last batch or variable sequence lengths.
Suggested change
- tokens_per_batch = 0
+ total_tokens = 0
hidden_size = model.config.hidden_size
n_microbatches = _NUM_MICROBATCHES
for _ in range(num_steps):
batch = next(data_iterator).to(device, dtype=torch.long, non_blocking=True)
@@
- tokens_per_batch = batch.numel()
+ total_tokens += batch.numel()
bs = inp.shape[0]
micro_bs = max(1, bs // n_microbatches)
@@
- total_tokens = num_steps * tokens_per_batch
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@local_test/train_pp.py` around lines 214 - 236, The total_tokens computation
is wrong because tokens_per_batch is overwritten each loop and then multiplied
by num_steps; instead accumulate the actual token count per batch as you iterate
over data_iterator. Inside the for _ in range(num_steps) loop (which uses batch
= next(data_iterator) and sets tokens_per_batch = batch.numel()), add the
batch.numel() to a running counter (e.g., total_tokens) rather than relying on a
single tokens_per_batch value, and remove the final total_tokens = num_steps *
tokens_per_batch so total_tokens reflects the true sum of batch.numel() across
iterations.
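The difference between the two counting strategies can be seen in a small torch-free sketch. The function names and batch sizes are made up for illustration; each integer stands in for one batch's `batch.numel()`.

```python
def count_tokens_overwrite(batch_numels):
    """Mimics the flagged pattern: keep only the last batch size, then multiply."""
    tokens_per_batch = 0
    for n in batch_numels:
        tokens_per_batch = n  # overwritten on every iteration
    return len(batch_numels) * tokens_per_batch


def count_tokens_accumulate(batch_numels):
    """Mimics the suggested pattern: sum the real size of every batch."""
    total_tokens = 0
    for n in batch_numels:
        total_tokens += n
    return total_tokens


# Hypothetical run: two full batches of 8 x 1024 tokens and a short final batch.
numels = [8 * 1024, 8 * 1024, 3 * 1024]
print(count_tokens_overwrite(numels))   # 9216, undercounts
print(count_tokens_accumulate(numels))  # 19456, the true total
```

The two functions agree only when every batch has the same number of elements.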
final_loss = 0.0

# 1F1B schedule parameters
num_warmup = pp_size - 1 - pp_rank
Clamp warmup to the actual microbatch count.
With pp_size = 4, ranks 0/1/2 get num_warmup = 3/2/1. If a step produces fewer microbatches than that, Phase 1 indexes past the end of micro_inputs, and Phase 3 can backpropagate the wrong slots. This surfaces quickly on small-batch or debug runs.
Suggested change
micro_inputs = all_micro_inputs[step]
micro_labels = all_micro_labels[step]
n_micro = len(micro_inputs)
total_micro_tokens = sum(ml.numel() for ml in micro_labels)
+ warmup = min(num_warmup, n_micro)
@@
- for i in range(num_warmup):
+ for i in range(warmup):
_do_forward(i)
# ---- Phase 2: Steady state (1 backward + 1 forward) ----
- for i in range(num_warmup, n_micro):
- bwd_idx = i - num_warmup
+ for i in range(warmup, n_micro):
+ bwd_idx = i - warmup
if not is_last_stage:
_do_backward(bwd_idx)
_do_forward(i)
@@
- for i in range(n_micro - num_warmup, n_micro):
+ for i in range(max(0, n_micro - warmup), n_micro):
_do_backward(i)
Also applies to: 341-356
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@local_test/train_pp.py` at line 243, Clamp num_warmup to the actual
microbatch count: replace the current calculation of num_warmup (num_warmup =
pp_size - 1 - pp_rank) with a guarded clamp using the length of micro_inputs
(e.g., num_micro = len(micro_inputs)) so num_warmup = max(0, min(pp_size - 1 -
pp_rank, num_micro - 1)). Apply the same clamp in the other occurrence around
lines 341-356 so Phase 1/3 never index past micro_inputs and backprop slots
remain aligned.
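To see why the clamp matters, the sketch below lists the order of forward and backward operations one stage would run under a textbook 1F1B schedule. `one_f_one_b` is a hypothetical helper, and its steady-state ordering (forward first, then backward) is the conventional one; it may differ from the ordering in `train_pp.py`. The clamp used here is `min(num_warmup, n_micro)`, as in the suggested change.

```python
def one_f_one_b(pp_rank: int, pp_size: int, n_micro: int) -> list:
    """Return the ("F"|"B", microbatch_index) sequence for one pipeline stage."""
    # Clamp so warmup never exceeds the number of available microbatches.
    warmup = max(0, min(pp_size - 1 - pp_rank, n_micro))
    ops = []
    # Phase 1: warmup forwards only.
    for i in range(warmup):
        ops.append(("F", i))
    # Phase 2: steady state, one forward paired with one backward.
    for i in range(warmup, n_micro):
        ops.append(("F", i))
        ops.append(("B", i - warmup))
    # Phase 3: cooldown, drain the backwards that are still pending.
    for i in range(n_micro - warmup, n_micro):
        ops.append(("B", i))
    return ops


# Rank 0 of a 4-stage pipeline with only 2 microbatches: without the clamp,
# warmup would be 3 and Phase 1 would try to run forward on microbatch 2.
print(one_f_one_b(pp_rank=0, pp_size=4, n_micro=2))
```

Every microbatch index appears exactly once as a forward and once as a backward, and each backward follows its forward, even when `n_micro` is smaller than the unclamped warmup count.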
if global_rank == 0:
    full_state = dict(my_state)
    for peer in range(1, pp_size):
        for k, shape, dtype in all_keys[peer]:
            buf = torch.empty(shape, dtype=dtype, device=device)
            dist.recv(buf, src=peer)
            full_state[k] = buf
    return {k: v.cpu() for k, v in full_state.items()}
Don't reassemble the full state on CUDA.
Rank 0 stores every received tensor in full_state on GPU and only calls .cpu() after the entire model has been reconstructed. That transiently requires one GPU to hold the full model again, which is exactly the memory spike PP is supposed to avoid.
Suggested change
if global_rank == 0:
- full_state = dict(my_state)
+ full_state = {k: v.cpu() for k, v in my_state.items()}
+ my_state.clear()
for peer in range(1, pp_size):
for k, shape, dtype in all_keys[peer]:
buf = torch.empty(shape, dtype=dtype, device=device)
dist.recv(buf, src=peer)
- full_state[k] = buf
- return {k: v.cpu() for k, v in full_state.items()}
+ full_state[k] = buf.cpu()
+ del buf
+ return full_state
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@local_test/train_pp.py` around lines 453 - 460, When reconstructing the full
model on rank 0, don't keep tensors on CUDA: create full_state as CPU copies of
my_state (e.g., full_state = {k: v.cpu() for k, v in my_state.items()}) and move
each received tensor to CPU as soon as dist.recv(buf, src=peer) completes
(full_state[k] = buf.cpu()), so at most one receive buffer lives on the GPU at a
time. Allocate the receive buffer directly on CPU (device='cpu') only if the
process group's backend supports CPU tensors; NCCL point-to-point calls require
CUDA tensors. Either way, no per-peer tensors accumulate on the GPU and you
avoid the transient full-model GPU memory spike.
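A back-of-the-envelope comparison of the two gathering strategies is sketched below. The helper names and the shard sizes are hypothetical, and the arithmetic counts only state-dict bytes on rank 0, ignoring optimizer state, activations, and allocator overhead.

```python
GIB = 1024 ** 3


def peak_gather_on_gpu(stage_bytes):
    """Rank 0 keeps every received tensor on the GPU until the final .cpu()."""
    return sum(stage_bytes)


def peak_stream_to_cpu(stage_bytes, largest_tensor_bytes):
    """Rank 0 keeps its own shard plus one in-flight receive buffer on the GPU."""
    return stage_bytes[0] + largest_tensor_bytes


# Hypothetical 4-stage split: 4 GiB of weights per stage, largest tensor 0.5 GiB.
stages = [4 * GIB] * 4
print(peak_gather_on_gpu(stages) / GIB)                        # 16.0
print(peak_stream_to_cpu(stages, largest_tensor_bytes=GIB // 2) / GIB)  # 4.5
```

Under these assumed sizes, streaming each tensor to CPU keeps rank 0's peak close to its own shard instead of the full model.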
Extend the miner parallelism contract to support pipeline parallelism (pp_size) alongside existing dp_size and tp_size. Update env.py strategy detection, validation (dp * tp * pp == num_gpus), and data sharding. Add train_pp.py reference implementation and update all local_test files with explicit pp_size in get_strategy() and documentation.
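The validation rule mentioned above can be sketched as a standalone check. `validate_parallelism` is a hypothetical helper written for illustration; the actual check in env.py and the return shape of `get_strategy()` are not shown here and may differ.

```python
def validate_parallelism(dp_size: int, tp_size: int, pp_size: int, num_gpus: int) -> None:
    """Raise ValueError unless dp * tp * pp exactly covers the available GPUs."""
    sizes = {"dp_size": dp_size, "tp_size": tp_size, "pp_size": pp_size}
    for name, value in sizes.items():
        if not isinstance(value, int) or value < 1:
            raise ValueError(f"{name} must be a positive integer, got {value!r}")
    product = dp_size * tp_size * pp_size
    if product != num_gpus:
        raise ValueError(
            f"dp * tp * pp = {product} does not match num_gpus = {num_gpus}"
        )


# Example: 8 GPUs split as 2-way data parallel x 1-way tensor parallel x 4 stages.
validate_parallelism(dp_size=2, tp_size=1, pp_size=4, num_gpus=8)
```

A strategy that omits pp_size would be treated as `pp_size = 1` under this sketch's assumptions, which keeps the existing dp/tp-only configurations valid.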
Made-with: Cursor