Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions configs/deepseek_v3/sft.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# SFT run configuration for the DeepSeek-V3 test checkpoint named in [model].name.
max_steps = 300
output_dir = "./output/deepseek_v3"
loss_impl = "liger_fused"

[wandb]
project = "prime-rl"
name = "deepseek_v3"
offline = true

[ckpt]
# interval equals max_steps, so presumably only the final step is checkpointed — confirm.
interval = 300
keep_last = 5
# NOTE(review): -1 presumably means "resume from the latest checkpoint" — confirm against the trainer.
resume_step = -1

[ckpt.weights]
save_format = "safetensors"

[model]
moe_use_grouped_mm = true
name = "v2ray/DeepSeek-V3-1B-Test"
optimization_dtype = "bfloat16"
reduce_dtype = "bfloat16"
trust_remote_code = false
# Same value as [data].seq_len further down in this file.
seq_len = 1024
# NOTE(review): ep / cp / dp_replicate look like parallelism degrees; 1 presumably disables each — confirm.
ep = 1
cp = 1
dp_replicate = 1
attn = "sdpa"
fsdp_cpu_offload = false
reshard_after_forward = true
# NOTE(review): "custom" presumably selects the in-repo DeepseekV3ForCausalLM rather than the HF class — confirm.
impl = "custom"
Comment thread
cursor[bot] marked this conversation as resolved.

# Empty table: no keys are overridden here (presumably the section's defaults apply — confirm).
[model.ac]

# Empty table: no keys are overridden here (presumably the section's defaults apply — confirm).
[model.compile]

[optim]
type = "adamw"
lr = 9e-5
weight_decay = 0.01

# Empty table: no scheduler keys are overridden here.
[scheduler]

[data]
type = "sft"
name = "PrimeIntellect/Reverse-Text-SFT"
batch_size = 32
# micro_batch_size = 1
# Same value as [model].seq_len earlier in this file.
seq_len = 1024
pack_function = "stack"
shuffle = true
seed = 42

[deployment]
type = "single_node"
num_gpus = 2
4 changes: 4 additions & 0 deletions src/prime_rl/trainer/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from prime_rl.trainer.models.qwen3 import Qwen3ForCausalLM
from prime_rl.trainer.models.qwen3_5_moe import Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM
from prime_rl.trainer.models.qwen3_moe import Qwen3MoeConfig, Qwen3MoeForCausalLM
from prime_rl.trainer.models.deepseek_v3 import DeepseekV3Config, DeepseekV3ForCausalLM

# Make custom config discoverable by AutoConfig
AutoConfig.register("afmoe", AfmoeConfig, exist_ok=True)
Expand All @@ -32,6 +33,8 @@
AutoConfig.register("nemotron_h", NemotronHConfig, exist_ok=True)
AutoConfig.register("qwen3_moe", Qwen3MoeConfig, exist_ok=True)
AutoConfig.register("qwen3_5_moe_text", Qwen3_5MoeConfig, exist_ok=True)
AutoConfig.register("deepseek_v3", DeepseekV3Config, exist_ok=True)

# GptOssConfig is just HF's class - already registered by transformers, no override needed.

_CUSTOM_CAUSAL_LM_MAPPING = _LazyAutoMapping(CONFIG_MAPPING_NAMES, OrderedDict())
Expand All @@ -46,6 +49,7 @@
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3MoeConfig, Qwen3MoeForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(Qwen3_5MoeConfig, Qwen3_5MoeForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(GptOssConfig, GptOssForCausalLM, exist_ok=True)
_CUSTOM_CAUSAL_LM_MAPPING.register(DeepseekV3Config, DeepseekV3ForCausalLM, exist_ok=True)


class AutoModelForCausalLMPrimeRL(_BaseAutoModelClass):
Expand Down
15 changes: 15 additions & 0 deletions src/prime_rl/trainer/models/deepseek_v3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from prime_rl.trainer.models.deepseek_v3.configuration_deepseek_v3 import (
DeepseekV3Config,
)
from prime_rl.trainer.models.deepseek_v3.modeling_deepseek_v3 import (
DeepseekV3ForCausalLM,
DeepseekV3Model,
DeepseekV3PreTrainedModel,
)

# Names re-exported as the public API of the deepseek_v3 package.
__all__ = [
"DeepseekV3Config",
"DeepseekV3ForCausalLM",
"DeepseekV3Model",
"DeepseekV3PreTrainedModel",
]

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

New custom model missing required KL mismatch table

Medium Severity

This PR introduces deepseek_v3 as a new custom model but does not include the required table showing mean KL mismatch across 20 steps on a math environment with batch_size=64. Per project rules, all entries in such a table must be lower than 0.015 before the PR can be accepted.

Fix in Cursor Fix in Web

Triggered by project rule: BugBot Instructions

Reviewed by Cursor Bugbot for commit b70c306. Configure here.

131 changes: 131 additions & 0 deletions src/prime_rl/trainer/models/deepseek_v3/attention_deepseek_v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
import torch

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file seems to copy most of the current attention implementation with few or no changes. Is there a reason for that? If there are changes, let's move them into the shared implementation, provided that doesn't break anything.

@Sirorezka Sirorezka Jun 2, 2026

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't able to reuse them because 'FlashAttention' and 'SDPAAttention' have their own versions of the q, k, v projections, which differ from the ones used in DeepSeek. To reuse the original attention I would therefore need to rewrite the 'init', 'forward' and 'attn_projections' methods, which is basically the same as rewriting the whole module from scratch.

from torch import Tensor, nn
import torch.nn.functional as F
from prime_rl.trainer.models.deepseek_v3.configuration_deepseek_v3 import (
DeepseekV3Config,
)
from typing import Callable

# Flash attention imports
try:
from flash_attn import flash_attn_varlen_func
from flash_attn import flash_attn_func as fa2_func
except ImportError:
flash_attn_varlen_func = None # type: ignore
fa2_func = None # type: ignore

try:
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
except ImportError:
flash_attn_3_varlen_func = None # type: ignore

try:
from flash_attn.cute import flash_attn_varlen_func as flash_attn_4_varlen_func
except ImportError:
flash_attn_4_varlen_func = None # type: ignore


from prime_rl.utils.logger import get_logger


class DeepSeekAttentionCore:
    """Select and run the attention kernel for the DeepSeek-V3 attention layers.

    The kernel family is chosen from ``config._attn_implementation``:
    ``"eager"`` / ``"sdpa"`` map to PyTorch's ``scaled_dot_product_attention``;
    ``"flash_attention_2"``, ``"flash_attention_3"`` and ``"fa4"`` map to the
    corresponding flash-attention varlen function imported at module top.

    The flash-attention imports are optional. When one is missing, the
    instance is still constructed and a descriptive error is raised only when
    that kernel is actually invoked.
    """

    # ``config._attn_implementation`` value -> flash-attention major version.
    _flash_attn_version_mapper = {
        "flash_attention_2": 2,
        "flash_attention_3": 3,
        "fa4": 4,
    }

    # Varlen kernel per major version; an entry is None when the optional
    # import at the top of the module failed.
    _funcs = {
        2: flash_attn_varlen_func,
        3: flash_attn_3_varlen_func,
        4: flash_attn_4_varlen_func,
    }

    def __init__(self, config: DeepseekV3Config):
        """Read head counts and the attention implementation from ``config``.

        Raises:
            ValueError: if ``config._attn_implementation`` is not one of the
                supported names.
        """
        self.func: Callable | None = None
        self._flash_attn_version: int = -1
        # Always defined, so the varlen path can fail with a clear message
        # instead of an AttributeError when no flash kernel was selected.
        self._flash_attn_call: Callable | None = None

        self.num_queries_per_kv = (
            config.num_attention_heads // config.num_key_value_heads
        )

        attn_impl = config._attn_implementation
        if attn_impl in ("eager", "sdpa"):
            self.attn_impl = "sdpa"
        elif attn_impl in self._flash_attn_version_mapper:
            self.attn_impl = attn_impl
            self._flash_attn_version = self._flash_attn_version_mapper[attn_impl]
            self.func = self._funcs[self._flash_attn_version]
            self._flash_attn_call = self.func
            # Only wrap a real kernel: torch._dynamo.disable(None) returns a
            # decorator, which would later be called with tensors by mistake.
            if self._flash_attn_version == 4 and self.func is not None:
                self._flash_attn_call = torch._dynamo.disable(self.func)
        else:
            raise ValueError(f"Not supported attention '{attn_impl}'. ")

    def _compute_attention(
        self, q, k, v, cu_seqlens, max_seqlen, softmax_scale: float | None = None
    ):
        """Run the selected flash-attention varlen kernel with causal masking.

        NOTE: this method must be patched by ring attention (per the original
        author's marker), so keep its name and signature stable.

        ``cu_seqlens`` is passed for both the query and the key side. For
        versions other than 4, ``max_seqlen`` is also passed for both sides.
        If the kernel returns a tuple, only its first element is returned.

        Raises:
            RuntimeError: if no flash-attention kernel is available on this
                instance (sdpa was configured, or the optional import failed).
        """
        if self._flash_attn_call is None:
            raise RuntimeError(
                "Variable-length attention was requested but no flash-attention "
                f"kernel is available (attn_impl='{self.attn_impl}'). "
                "Install the matching flash-attention package or do not pass cu_seqlens."
            )

        args = [q, k, v, cu_seqlens, cu_seqlens]
        if self._flash_attn_version != 4:
            args.extend([max_seqlen, max_seqlen])

        kwargs: dict = {"causal": True}
        # Compare against None so an explicit scale of 0.0 is still forwarded.
        if softmax_scale is not None:
            kwargs["softmax_scale"] = softmax_scale

        out = self._flash_attn_call(*args, **kwargs)
        if isinstance(out, tuple):
            out = out[0]
        return out

    def _attention_core(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        cu_seqlens: torch.LongTensor | None,
        max_seqlen: int | None,
        scale: float | None = None,
    ) -> torch.Tensor:
        """Compute causal attention, dispatching on ``cu_seqlens`` and batch size.

        Inputs are laid out as (batch, seq_len, heads, head_dim), following the
        original author's shape comments. Dispatch:

        * ``cu_seqlens is None``: PyTorch SDPA. K/V heads are repeated
          ``num_queries_per_kv`` times along the head axis when that is > 1.
        * ``cu_seqlens`` given and batch > 1: dense ``flash_attn_func``;
          ``cu_seqlens`` and ``max_seqlen`` are not forwarded in this branch.
        * ``cu_seqlens`` given and batch == 1: the varlen kernel on the single
          sequence, with the batch axis restored afterwards.

        Returns:
            A contiguous tensor in the same (batch, seq_len, heads, head_dim)
            layout.

        Raises:
            RuntimeError: if a flash-attention kernel is required but missing.
        """
        if cu_seqlens is None:
            if self.num_queries_per_kv > 1:
                k = k.repeat_interleave(self.num_queries_per_kv, dim=2)
                v = v.repeat_interleave(self.num_queries_per_kv, dim=2)

            # SDPA expects (batch, heads, seq_len, head_dim).
            attn_output = F.scaled_dot_product_attention(
                q.transpose(1, 2),
                k.transpose(1, 2),
                v.transpose(1, 2),
                is_causal=True,
                scale=scale,
            )
            attn_output = attn_output.transpose(1, 2)
        elif q.shape[0] > 1:
            if fa2_func is None:
                raise RuntimeError(
                    "Batched flash attention was requested but 'flash_attn' is not installed."
                )
            attn_output = fa2_func(q, k, v, causal=True, softmax_scale=scale)
        else:
            # Varlen kernels take un-batched (seq_len, heads, head_dim) inputs.
            attn_output = self._compute_attention(
                q[0], k[0], v[0], cu_seqlens, max_seqlen, softmax_scale=scale
            )
            attn_output = attn_output.unsqueeze(0)

        return attn_output.contiguous()
Comment thread
cursor[bot] marked this conversation as resolved.
147 changes: 147 additions & 0 deletions src/prime_rl/trainer/models/deepseek_v3/configuration_deepseek_v3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
from transformers.configuration_utils import PretrainedConfig


class DeepseekV3Config(PretrainedConfig):
    """Configuration for the custom DeepSeek-V3 model implementation.

    Every constructor argument is stored as an attribute of the same name,
    except ``rope_parameters``, which is resolved as described below. Derived
    attributes set here:

    * ``qk_head_dim``: ``qk_nope_head_dim + qk_rope_head_dim``.
    * ``head_dim``: set to ``qk_rope_head_dim``; this is NOT the actual
      per-head width (see ``qk_head_dim``).
    * ``num_local_experts``: copy of ``n_routed_experts``.
    * ``rope_type``: taken from ``rope_scaling["rope_type"]`` (falling back to
      ``rope_scaling["type"]``) when ``rope_scaling`` is a dict, otherwise
      ``"default"``.
    * ``rope_parameters``: ``rope_scaling`` if truthy, else the
      ``rope_parameters`` argument, else whatever the base class left on the
      instance, else ``{}``.

    Note: when ``rope_scaling`` is a dict, its ``beta_fast`` / ``beta_slow``
    entries are converted to ``float`` in place.

    Raises:
        AssertionError: if the attention head dimensions or the MoE grouping
            parameters are inconsistent (see ``__validate__``).
    """

    model_type = "deepseek_v3"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=129280,
        hidden_size=7168,
        intermediate_size=18432,
        num_hidden_layers=61,
        num_attention_heads=16,
        num_key_value_heads=16,
        max_position_embeddings=163840,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        hidden_act="silu",
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=0,
        eos_token_id=1,
        tie_word_embeddings=False,
        first_k_dense_replace=3,
        q_lora_rank=None,
        kv_lora_rank=512,
        qk_nope_head_dim=128,
        qk_rope_head_dim=64,
        v_head_dim=128,
        moe_intermediate_size=2048,
        n_routed_experts=64,
        n_shared_experts=1,
        num_experts_per_tok=8,
        moe_layer_freq=1,
        n_group=1,
        topk_group=1,
        num_cycles=1,
        scoring_func="sigmoid",
        norm_topk_prob=True,
        routed_scaling_factor=2.5,
        seq_scope=None,
        long_context_remap=None,
        rope_ver="v1",
        ep_size=1,
        num_nextn_predict_layers=1,
        load_balance_coeff=1,
        use_grouped_mm=True,
        rope_interleave=True,
        rope_parameters: dict | None = None,
        **kwargs,
    ):
        # Core transformer dimensions.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache

        # Attention projection ranks and per-head dimensions.
        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        # NOTE: this is the rotary dimension, not the actual head dim
        # (the full query/key width is qk_head_dim).
        self.head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim

        # Mixture-of-experts settings.
        self.moe_intermediate_size = moe_intermediate_size
        self.n_routed_experts = n_routed_experts
        self.num_local_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_cycles = num_cycles
        self.scoring_func = scoring_func
        self.norm_topk_prob = norm_topk_prob
        self.routed_scaling_factor = routed_scaling_factor
        self.seq_scope = seq_scope
        self.long_context_remap = long_context_remap
        self.rope_ver = rope_ver
        self.ep_size = ep_size
        self.num_nextn_predict_layers = num_nextn_predict_layers

        self.load_balance_coeff = load_balance_coeff
        self.use_grouped_mm = use_grouped_mm
        self.rope_interleave = rope_interleave

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )

        # Re-read after the base constructor ran; getattr keeps this safe
        # even if the attribute is not present afterwards.
        rope_conf = getattr(self, "rope_scaling", None)
        if isinstance(rope_conf, dict):
            self.rope_type = rope_conf.get("rope_type", rope_conf.get("type"))
            for key in ("beta_fast", "beta_slow"):
                # These may arrive as ints; store them as floats.
                if key in rope_conf:
                    rope_conf[key] = float(rope_conf[key])
        else:
            self.rope_type = "default"

        # Previously the ``rope_parameters`` argument was accepted but never
        # read. Honour it after ``rope_scaling``, then fall back to whatever
        # the base class may have left on the instance.
        resolved_rope = (
            rope_conf
            or rope_parameters
            or getattr(self, "rope_parameters", None)
        )
        self.rope_parameters = resolved_rope if resolved_rope is not None else {}

        self.__validate__()

    def __validate__(self):
        """Check internal consistency of the attention and MoE settings.

        The checks raise ``AssertionError`` explicitly rather than using
        ``assert``: callers see the same exception type as before, but the
        checks are no longer stripped when Python runs with ``-O``.
        """
        if self.qk_nope_head_dim + self.qk_rope_head_dim != self.qk_head_dim:
            raise AssertionError(
                "qk_head_dim must equal qk_nope_head_dim + qk_rope_head_dim "
                f"(got {self.qk_head_dim} vs "
                f"{self.qk_nope_head_dim} + {self.qk_rope_head_dim})"
            )

        # Required for the TopK router.
        if self.n_routed_experts % self.n_group != 0:
            raise AssertionError(
                f"n_routed_experts ({self.n_routed_experts}) must be divisible "
                f"by n_group ({self.n_group})"
            )

        # The router expert_bias is always used in the HF implementation.
        if not self.load_balance_coeff > 0:
            raise AssertionError(
                f"load_balance_coeff must be > 0 (got {self.load_balance_coeff})"
            )

        # At least the top-2 experts are always taken from each group.
        if self.n_routed_experts // self.n_group < 2:
            raise AssertionError(
                "each expert group needs at least 2 experts "
                f"(n_routed_experts={self.n_routed_experts}, n_group={self.n_group})"
            )

    @property
    def rope_total_dim(self):
        """Return ``num_attention_heads * qk_rope_head_dim``."""
        return self.num_attention_heads * self.qk_rope_head_dim


__all__ = ["DeepseekV3Config"]
Loading