Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/bake-gcp-image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ concurrency:

env:
GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
DEFAULT_BASE_IMAGE: areal-cicd-test-20260425-409
DEFAULT_BASE_IMAGE: areal-cicd-test-20260506-432

jobs:
bake:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test-areal.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ concurrency:
env:
GCP_PROJECT_ID: ${{ secrets.GCP_PROJECT_ID }}
RUNNER_VERSION: '2.332.0'
GCP_OS_IMAGE: areal-cicd-test-20260425-409
GCP_OS_IMAGE: areal-cicd-test-20260506-432

jobs:
determine-variants:
Expand Down
69 changes: 59 additions & 10 deletions areal/engine/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
_is_multi_modal_payload_key,
extract_vision_from_multi_modal,
packed_context_parallel_forward,
reassemble_cp_packed_logprobs,
split_packed_seqs_for_context_parallel,
)
from areal.engine.megatron_utils.pipeline_parallel import (
Expand Down Expand Up @@ -863,14 +864,11 @@ def _process_output(input_, output_):
cp_labels = split_packed_seqs_for_context_parallel(
rolled_ids, padded_cu_seqlens
)
cp_loss_mask = split_packed_seqs_for_context_parallel(
mb_input.padded_mb["loss_mask"], padded_cu_seqlens
)
cp_cu_seqlens = padded_cu_seqlens // cp_size
cp_inputs = dict(mb_input.orig_mb)
cp_inputs["_cp_local_labels"] = cp_labels
cp_inputs["loss_mask"] = cp_loss_mask
cp_inputs["cu_seqlens"] = cp_cu_seqlens
cp_inputs["_cp_padded_cu_seqlens"] = padded_cu_seqlens
cp_inputs["_cp_padding_length"] = mb_input.padding_length
cp_inputs["_cp_old_cu_seqlens"] = mb_input.old_cu_seqlens
return output, functools.partial(_process_output, cp_inputs)
else:
output = unpad_logits(
Expand Down Expand Up @@ -911,9 +909,15 @@ def train_batch(
# Step 1: Prepare micro-batches
mb_list = self._prepare_mb_list(input_batched).to(self.device)

# Step 2: Compute total loss weight
# Step 2: Compute total loss weight.
# Use DP+CP group: after CP all-gather each rank computes the full-sequence
# loss, so all_gather's backward (reduce_scatter) sums cp_size identical
# gradients, amplifying by cp_size. Including CP in the weight all-reduce
# introduces a matching cp_size factor in the denominator, cancelling out.
total_loss_weight = compute_total_loss_weight(
mb_list, loss_weight_fn, mpu.get_data_parallel_group()
mb_list,
loss_weight_fn,
mpu.get_data_parallel_group(with_context_parallel=True),
)

# Step 3: Forward-backward using Megatron's pipeline function.
Expand Down Expand Up @@ -966,9 +970,11 @@ def eval_batch(
# Step 1: Prepare micro-batches
mb_list = self._prepare_mb_list(input_batched).to(self.device)

# Step 2: Compute total loss weight
# Step 2: Compute total loss weight (DP+CP, see train_batch comment).
total_loss_weight = compute_total_loss_weight(
mb_list, loss_weight_fn, mpu.get_data_parallel_group()
mb_list,
loss_weight_fn,
mpu.get_data_parallel_group(with_context_parallel=True),
)

# Step 3: Forward using Megatron's pipeline function, collecting losses
Expand Down Expand Up @@ -1974,6 +1980,7 @@ def _compute_logprobs_and_loss(
)
else:
cp_local_labels = inputs.get("_cp_local_labels")
cp_padded_cu_seqlens = inputs.get("_cp_padded_cu_seqlens")
if cp_local_labels is not None:
labels = cp_local_labels
else:
Expand All @@ -1988,6 +1995,48 @@ def _compute_logprobs_and_loss(
)
vocab_min_logits = output.detach().min(-1).values.float()
vocab_max_logits = output.detach().max(-1).values.float()
if cp_padded_cu_seqlens is not None:
logprobs = reassemble_cp_packed_logprobs(
logprobs, cp_padded_cu_seqlens
)
entropy = reassemble_cp_packed_logprobs(
entropy, cp_padded_cu_seqlens
)
vocab_min_logits = reassemble_cp_packed_logprobs(
vocab_min_logits, cp_padded_cu_seqlens
)
vocab_max_logits = reassemble_cp_packed_logprobs(
vocab_max_logits, cp_padded_cu_seqlens
)
cp_padding_length = inputs.get("_cp_padding_length", 0)
cp_old_cu_seqlens = inputs.get("_cp_old_cu_seqlens")
logprobs = unpad_logits(
logprobs,
cp_padding_length,
cp_padded_cu_seqlens,
cp_old_cu_seqlens,
)
entropy = unpad_logits(
entropy,
cp_padding_length,
cp_padded_cu_seqlens,
cp_old_cu_seqlens,
)
vocab_min_logits = unpad_logits(
vocab_min_logits,
cp_padding_length,
cp_padded_cu_seqlens,
cp_old_cu_seqlens,
)
vocab_max_logits = unpad_logits(
vocab_max_logits,
cp_padding_length,
cp_padded_cu_seqlens,
cp_old_cu_seqlens,
)
inputs = {
k: v for k, v in inputs.items() if not k.startswith("_cp_")
}
loss = loss_fn(
logprobs,
entropy,
Expand Down
79 changes: 79 additions & 0 deletions areal/engine/megatron_utils/packed_context_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch
import torch.distributed as dist
import torch.distributed.nn.functional as dist_F
from megatron.core import parallel_state as mpu
from megatron.core.packed_seq_params import PackedSeqParams

Expand Down Expand Up @@ -106,6 +107,84 @@ def split_packed_seqs_for_context_parallel(
return splitted


def _build_cp_reassemble_indices(
padded_cu_seqlens: torch.Tensor,
cp_size: int,
) -> torch.Tensor:
"""Build the index mapping from concatenated CP chunks to original order.

Returns a 1D LongTensor of length ``output_len`` where ``indices[dst] = src``
means the token at position ``dst`` in the full sequence comes from position
``src`` in the flattened ``torch.cat(gathered_list)`` tensor.
"""
input_lens = padded_cu_seqlens[1:] - padded_cu_seqlens[:-1]
batch_size = input_lens.shape[0]
output_len = int(padded_cu_seqlens[-1].item())
local_len = output_len // cp_size
device = padded_cu_seqlens.device

indices = torch.empty(output_len, dtype=torch.long, device=device)

for i in range(batch_size):
seq_len = int(input_lens[i].item())
chunk_size = seq_len // cp_size
half_chunk = chunk_size // 2
local_start = int(padded_cu_seqlens[i].item()) // cp_size
full_start = int(padded_cu_seqlens[i].item())

k = torch.arange(half_chunk, device=device)
for j in range(cp_size):
src_offset = j * local_len + local_start
# first half → positions [j*H, (j+1)*H) in full sequence
indices[full_start + j * half_chunk + k] = src_offset + k
# second half → mirror positions [L-(j+1)*H, L-j*H)
indices[full_start + seq_len - (j + 1) * half_chunk + k] = (
src_offset + half_chunk + k
)

return indices


def reassemble_cp_packed_logprobs(
local_tensor: torch.Tensor,
padded_cu_seqlens: torch.Tensor,
) -> torch.Tensor:
"""All-gather CP-local 1D tensors and reassemble in original sequence order.

This is the differentiable inverse of ``split_packed_seqs_for_context_parallel``.
It uses ``torch.distributed.nn.functional.all_gather`` (backward = reduce_scatter)
followed by advanced indexing (differentiable permutation) so that gradients
flow correctly back to each CP rank's local logprobs.

Args:
local_tensor: 1D tensor of shape ``(total_packed_len // cp_size,)`` holding
this rank's CP-local values (e.g. logprobs, entropy, vocab stats).
padded_cu_seqlens: Cumulative sequence lengths in the *padded* (pre-split)
layout, of shape ``(batch_size + 1,)``.

Returns:
Full-sequence 1D tensor of shape ``(total_packed_len,)`` with values placed
back in original token order. Gradients flow back through the all-gather.
"""
cp_size = mpu.get_context_parallel_world_size()
if cp_size <= 1:
return local_tensor

cp_group = mpu.get_context_parallel_group()

# Differentiable all-gather: backward is reduce_scatter(sum).
gathered_list = dist_F.all_gather(local_tensor, group=cp_group)

# Concatenate all gathered chunks into a single flat tensor.
# cat is differentiable (backward splits gradients back to each chunk).
gathered_flat = torch.cat(gathered_list, dim=0)

# Build index mapping and apply via advanced indexing (differentiable).
# indices[dst] = src means output[dst] = gathered_flat[src].
indices = _build_cp_reassemble_indices(padded_cu_seqlens, cp_size)
return gathered_flat[indices]


def postprocess_packed_seqs_context_parallel(
output: torch.Tensor,
cu_seqlens: torch.Tensor | None,
Expand Down
2 changes: 1 addition & 1 deletion areal/trainer/sft/lm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def compute_packed_sft_loss(
n_seqs[i] = True
seqlogp[i] = torch.where(m, logp.detach(), 0.0).sum() / valid_tokens

## Loggin stats
## Logging stats
stats_tracker.denominator(
n_seqs=n_seqs,
n_tokens=torch.ones(logprobs.shape[0], dtype=torch.bool, device=device),
Expand Down
13 changes: 7 additions & 6 deletions assets/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,15 +12,16 @@ For background on how the project is governed and how to participate, please see

## Upcoming Meetings

| Date | Agenda | Slides | Recording |
| ---------- | -------------------------------------------------------------------------------------------------------------- | ------ | --------- |
| 2026/05/01 | [Google Doc](https://docs.google.com/document/d/1w106Eoj2rMX702EXX56Qz8OeD255Fx-36iLarL6Hi0M/edit?usp=sharing) | TBD | TBD |
| Date | Agenda | Slides | Recording |
| ---------- | ------ | ------ | --------- |
| 2026/05/16 | TBD | TBD | TBD |

## Past Meetings

| Date | Agenda | Slides | Recording |
| ---------- | -------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------- |
| 2026/04/18 | [Google Doc](https://docs.google.com/document/d/1t4JSoXuPgtMsjAHio4keh3PM1elDeQAXC_BZGnv9kco/edit?usp=sharing) | [Google Slides](https://docs.google.com/presentation/d/1MaZL2Tq39YPYQYIIWNiKaBo2MonLLbW-/edit?usp=sharing&ouid=102752648406195568586&rtpof=true&sd=true) | [Tencent Meeting (Chinese)](https://meeting.tencent.com/crm/2MW8w0q6ef) |
| Date | Agenda | Slides | Recording |
| ---------- | -------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------- |
| 2026/04/18 | [Google Doc](https://docs.google.com/document/d/1t4JSoXuPgtMsjAHio4keh3PM1elDeQAXC_BZGnv9kco/edit?usp=sharing) | [Google Slides](https://docs.google.com/presentation/d/1MaZL2Tq39YPYQYIIWNiKaBo2MonLLbW-/edit?usp=sharing&ouid=102752648406195568586&rtpof=true&sd=true) | [Tencent Meeting (Chinese)](https://meeting.tencent.com/crm/2MW8w0q6ef) |
| 2026/05/01 | [Google Doc](https://docs.google.com/document/d/1w106Eoj2rMX702EXX56Qz8OeD255Fx-36iLarL6Hi0M/edit?usp=sharing) | [Google Slides](https://docs.google.com/presentation/d/1TXpSInTA4TLWfiOA6Fu5cZlgfM8ysHA3/edit?usp=sharing&ouid=102752648406195568586&rtpof=true&sd=true), [AReaL-DTA Talk](https://drive.google.com/file/d/11F5qdGVOWUB7y8cY763OZiXlt5BeyCPL/view?usp=sharing) | [Tencent Meeting (Chinese)](https://meeting.tencent.com/crm/2BYLvJvdc5) |

## How to Add Materials

Expand Down
Loading
Loading