-
Notifications
You must be signed in to change notification settings - Fork 306
Added deepseek_v3 in models #2681
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| max_steps = 300 | ||
| output_dir = "./output/deepseek_v3" | ||
| loss_impl = "liger_fused" | ||
|
|
||
| [wandb] | ||
| project = "prime-rl" | ||
| name = "deepseek_v3" | ||
| offline = true | ||
|
|
||
| [ckpt] | ||
| interval = 300 | ||
| keep_last = 5 | ||
| resume_step = -1 | ||
|
|
||
| [ckpt.weights] | ||
| save_format = "safetensors" | ||
|
|
||
| [model] | ||
| moe_use_grouped_mm = true | ||
| name = "v2ray/DeepSeek-V3-1B-Test" | ||
| optimization_dtype = "bfloat16" | ||
| reduce_dtype = "bfloat16" | ||
| trust_remote_code = false | ||
| seq_len = 1024 | ||
| ep = 1 | ||
| cp = 1 | ||
| dp_replicate = 1 | ||
| attn = "sdpa" | ||
| fsdp_cpu_offload = false | ||
| reshard_after_forward = true | ||
| impl = "custom" | ||
|
|
||
| [model.ac] | ||
|
|
||
| [model.compile] | ||
|
|
||
| [optim] | ||
| type = "adamw" | ||
| lr = 9e-5 | ||
| weight_decay = 0.01 | ||
|
|
||
| [scheduler] | ||
|
|
||
| [data] | ||
| type = "sft" | ||
| name = "PrimeIntellect/Reverse-Text-SFT" | ||
| batch_size = 32 | ||
| # micro_batch_size = 1 | ||
| seq_len = 1024 | ||
| pack_function = "stack" | ||
| shuffle = true | ||
| seed = 42 | ||
|
|
||
| [deployment] | ||
| type = "single_node" | ||
| num_gpus = 2 | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| from prime_rl.trainer.models.deepseek_v3.configuration_deepseek_v3 import ( | ||
| DeepseekV3Config, | ||
| ) | ||
| from prime_rl.trainer.models.deepseek_v3.modeling_deepseek_v3 import ( | ||
| DeepseekV3ForCausalLM, | ||
| DeepseekV3Model, | ||
| DeepseekV3PreTrainedModel, | ||
| ) | ||
|
|
||
| __all__ = [ | ||
| "DeepseekV3Config", | ||
| "DeepseekV3ForCausalLM", | ||
| "DeepseekV3Model", | ||
| "DeepseekV3PreTrainedModel", | ||
| ] | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. New custom model missing required KL mismatch tableMedium Severity This PR introduces Triggered by project rule: BugBot Instructions Reviewed by Cursor Bugbot for commit b70c306. Configure here. |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| import torch | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file seems to copy most of the current attention impl without any (or few) changes, any reason for it? If there are any changes, let's move them to the shared impl if not breaking?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wasn't been able to reuse them, because 'FlashAttention' and 'SDPAAttention' have their own versions of q,k,v projections which are differ from the ones that are used in DeepSeek. Because of this to reuse original attention I would need to rewrite 'init', 'forward' and 'attn_projections' methods. Which is basically the same as rewriting whole module from scratch. |
||
| from torch import Tensor, nn | ||
| import torch.nn.functional as F | ||
| from prime_rl.trainer.models.deepseek_v3.configuration_deepseek_v3 import ( | ||
| DeepseekV3Config, | ||
| ) | ||
| from typing import Callable | ||
|
|
||
| # Flash attention imports | ||
| try: | ||
| from flash_attn import flash_attn_varlen_func | ||
| from flash_attn import flash_attn_func as fa2_func | ||
| except ImportError: | ||
| flash_attn_varlen_func = None # type: ignore | ||
| fa2_func = None # type: ignore | ||
|
|
||
| try: | ||
| from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func | ||
| except ImportError: | ||
| flash_attn_3_varlen_func = None # type: ignore | ||
|
|
||
| try: | ||
| from flash_attn.cute import flash_attn_varlen_func as flash_attn_4_varlen_func | ||
| except ImportError: | ||
| flash_attn_4_varlen_func = None # type: ignore | ||
|
|
||
|
|
||
| from prime_rl.utils.logger import get_logger | ||
|
|
||
|
|
||
| class DeepSeekAttentionCore: | ||
|
|
||
| _flash_attn_version_mapper = { | ||
| "flash_attention_2": 2, | ||
| "flash_attention_3": 3, | ||
| "fa4": 4, | ||
| } | ||
|
|
||
| _funcs = { | ||
| 2: flash_attn_varlen_func, | ||
| 3: flash_attn_3_varlen_func, | ||
| 4: flash_attn_4_varlen_func, | ||
| } | ||
|
|
||
| def __init__(self, config: DeepseekV3Config): | ||
|
|
||
| self.func: Callable | None = None | ||
| self._flash_attn_version: int = -1 | ||
|
|
||
| self.num_queries_per_kv = ( | ||
| config.num_attention_heads // config.num_key_value_heads | ||
| ) | ||
|
|
||
| attn_impl = config._attn_implementation | ||
| if attn_impl in ("eager", "sdpa"): | ||
| self.attn_impl = "sdpa" | ||
| elif attn_impl in self._flash_attn_version_mapper: | ||
| # flash attention | ||
| self.attn_impl = config._attn_implementation | ||
| self._flash_attn_version = self._flash_attn_version_mapper[attn_impl] | ||
| self.func = self._funcs[self._flash_attn_version] | ||
| self._flash_attn_call = self.func | ||
| if self._flash_attn_version == 4: | ||
| self._flash_attn_call = torch._dynamo.disable(self.func) | ||
| else: | ||
| raise ValueError( | ||
| f"Not supportted attention '{config._attn_implementation}'. " | ||
| ) | ||
|
cursor[bot] marked this conversation as resolved.
|
||
|
|
||
| def _compute_attention( | ||
| self, q, k, v, cu_seqlens, max_seqlen, softmax_scale: float | None = None | ||
| ): | ||
| ### !! MUST BE PATCHED BY RING_ATTN | ||
|
|
||
| args = [q, k, v, cu_seqlens, cu_seqlens] | ||
| if self._flash_attn_version != 4: | ||
| args.extend([max_seqlen, max_seqlen]) | ||
|
|
||
| kwargs: dict = {"causal": True} | ||
|
|
||
| if softmax_scale: | ||
| kwargs["softmax_scale"] = softmax_scale | ||
|
|
||
| out = self._flash_attn_call(*args, **kwargs) | ||
| if isinstance(out, tuple): | ||
| out = out[0] | ||
| return out | ||
|
|
||
| def _attention_core( | ||
| self, | ||
| q: torch.Tensor, | ||
| k: torch.Tensor, | ||
| v: torch.Tensor, | ||
| cu_seqlens: torch.LongTensor | None, | ||
| max_seqlen: int | None, | ||
| scale: float | None = None, | ||
| ) -> torch.Tensor: | ||
| """ | ||
| Inputs: | ||
| # q,k,v = (bs, sl, nkv, d) | ||
| """ | ||
|
|
||
| if cu_seqlens is None: | ||
| # self.attn_impl == 'sdpa' | ||
| # q,k,v: (batch_size, seqlen, nheads, headdim) | ||
| num_queries_per_kv = self.num_queries_per_kv | ||
| if num_queries_per_kv > 1: | ||
| k = k.repeat_interleave(num_queries_per_kv, dim=2) | ||
| v = v.repeat_interleave(num_queries_per_kv, dim=2) | ||
|
|
||
| q = q.transpose(1, 2) | ||
| k = k.transpose(1, 2) | ||
| v = v.transpose(1, 2) | ||
|
|
||
| # q,k,v = (bs, nkv, sl, d) | ||
| attn_output = F.scaled_dot_product_attention( | ||
| q, k, v, is_causal=True, scale=scale | ||
| ) | ||
| attn_output = attn_output.transpose(1, 2) | ||
| elif q.shape[0] > 1: | ||
| # self.attn_impl == 'flash_attention' | ||
| attn_output = fa2_func(q, k, v, causal=True, softmax_scale=scale) | ||
| else: | ||
| # Varlen Attention | ||
| # inputs (bs==1, sl, nkv, d) | ||
| attn_output = self._compute_attention( | ||
| q[0], k[0], v[0], cu_seqlens, max_seqlen, softmax_scale=scale | ||
| ) | ||
| attn_output = attn_output.unsqueeze(0) | ||
|
|
||
| return attn_output.contiguous() | ||
|
cursor[bot] marked this conversation as resolved.
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,147 @@ | ||
| from transformers.configuration_utils import PretrainedConfig | ||
|
|
||
|
|
||
| class DeepseekV3Config(PretrainedConfig): | ||
| model_type = "deepseek_v3" | ||
| keys_to_ignore_at_inference = ["past_key_values"] | ||
|
|
||
| def __init__( | ||
| self, | ||
| vocab_size=129280, | ||
| hidden_size=7168, | ||
| intermediate_size=18432, | ||
| num_hidden_layers=61, | ||
| num_attention_heads=16, | ||
| num_key_value_heads=16, | ||
| max_position_embeddings=163840, | ||
| rope_theta=10000.0, | ||
| rope_scaling=None, | ||
| attention_bias=False, | ||
| attention_dropout=0.0, | ||
| hidden_act="silu", | ||
| initializer_range=0.02, | ||
| rms_norm_eps=1e-6, | ||
| use_cache=True, | ||
| pad_token_id=None, | ||
| bos_token_id=0, | ||
| eos_token_id=1, | ||
| tie_word_embeddings=False, | ||
| first_k_dense_replace=3, | ||
| q_lora_rank=None, | ||
| kv_lora_rank=512, | ||
| qk_nope_head_dim=128, | ||
| qk_rope_head_dim=64, | ||
| v_head_dim=128, | ||
| moe_intermediate_size=2048, | ||
| n_routed_experts=64, | ||
| n_shared_experts=1, | ||
| num_experts_per_tok=8, | ||
| moe_layer_freq=1, | ||
| n_group=1, | ||
| topk_group=1, | ||
| num_cycles=1, | ||
| # num_experts_per_tok_k=0, | ||
| scoring_func="sigmoid", | ||
| norm_topk_prob=True, | ||
| routed_scaling_factor=2.5, | ||
| seq_scope=None, | ||
| long_context_remap=None, | ||
| rope_ver="v1", | ||
| ep_size=1, | ||
| num_nextn_predict_layers=1, | ||
| load_balance_coeff=1, | ||
| use_grouped_mm=True, | ||
| rope_interleave=True, | ||
| rope_parameters: dict | None = None, | ||
| **kwargs, | ||
| ): | ||
| self.vocab_size = vocab_size | ||
| self.hidden_size = hidden_size | ||
| self.intermediate_size = intermediate_size | ||
| self.num_hidden_layers = num_hidden_layers | ||
| self.num_attention_heads = num_attention_heads | ||
| self.num_key_value_heads = num_key_value_heads | ||
| self.max_position_embeddings = max_position_embeddings | ||
| self.rope_theta = rope_theta | ||
| self.rope_scaling = rope_scaling | ||
| self.attention_bias = attention_bias | ||
| self.attention_dropout = attention_dropout | ||
| self.hidden_act = hidden_act | ||
| self.initializer_range = initializer_range | ||
| self.rms_norm_eps = rms_norm_eps | ||
| self.use_cache = use_cache | ||
|
|
||
| self.q_lora_rank = q_lora_rank | ||
| self.kv_lora_rank = kv_lora_rank | ||
| self.qk_nope_head_dim = qk_nope_head_dim | ||
| self.qk_rope_head_dim = qk_rope_head_dim | ||
| self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim | ||
| self.head_dim = qk_rope_head_dim ### !!! Not actual head dim. | ||
| self.v_head_dim = v_head_dim | ||
|
|
||
| self.moe_intermediate_size = moe_intermediate_size | ||
| self.n_routed_experts = n_routed_experts | ||
| self.num_local_experts = n_routed_experts | ||
| self.n_shared_experts = n_shared_experts | ||
| self.num_experts_per_tok = num_experts_per_tok | ||
| self.moe_layer_freq = moe_layer_freq | ||
| self.first_k_dense_replace = first_k_dense_replace | ||
| self.n_group = n_group | ||
| self.topk_group = topk_group | ||
| self.num_cycles = num_cycles | ||
| # self.num_experts_per_tok_k = num_experts_per_tok_k | ||
| self.scoring_func = scoring_func | ||
| self.norm_topk_prob = norm_topk_prob | ||
| self.routed_scaling_factor = routed_scaling_factor | ||
| self.seq_scope = seq_scope | ||
| self.long_context_remap = long_context_remap | ||
| self.rope_ver = rope_ver | ||
| self.ep_size = ep_size | ||
| self.num_nextn_predict_layers = num_nextn_predict_layers | ||
|
|
||
| self.load_balance_coeff = load_balance_coeff | ||
| self.use_grouped_mm = use_grouped_mm | ||
| self.rope_interleave = rope_interleave | ||
|
|
||
| super().__init__( | ||
| pad_token_id=pad_token_id, | ||
| bos_token_id=bos_token_id, | ||
| eos_token_id=eos_token_id, | ||
| tie_word_embeddings=tie_word_embeddings, | ||
| **kwargs, | ||
| ) | ||
|
|
||
| if hasattr(self, "rope_scaling") and isinstance(self.rope_scaling, dict): | ||
| rope_conf = self.rope_scaling | ||
| self.rope_type = rope_conf.get("rope_type", rope_conf.get("type")) | ||
| for key in ["beta_fast", "beta_slow"]: | ||
| # convert to float | ||
| if key in rope_conf: | ||
| rope_conf[key] = 1.0 * rope_conf[key] | ||
| else: | ||
| self.rope_type = "default" | ||
|
|
||
| self.rope_parameters = self.rope_scaling or self.rope_parameters | ||
| self.rope_parameters = ( | ||
| self.rope_parameters if self.rope_parameters is not None else {} | ||
| ) | ||
|
|
||
| self.__validate__() | ||
|
|
||
| def __validate__(self): | ||
|
|
||
| assert self.qk_nope_head_dim + self.qk_rope_head_dim == self.qk_head_dim | ||
| assert self.n_routed_experts % self.n_group == 0 # required for TopK router | ||
|
|
||
| # router expert_bias always used in HF implementation | ||
| assert self.load_balance_coeff > 0 | ||
|
|
||
| # we always use at least top2 experts from each group | ||
| assert self.n_routed_experts // self.n_group >= 2 | ||
|
|
||
| @property | ||
| def rope_total_dim(self): | ||
| return self.num_attention_heads * self.qk_rope_head_dim | ||
|
|
||
|
|
||
| __all__ = ["DeepseekV3Config"] |


Uh oh!
There was an error while loading. Please reload this page.