Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[maca] support deepseekv2 for maca backend. #2918

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def forward(
is_decoding=is_decoding,
block_size=block_size,
attn_mask=attn_mask,
softmax_scale=self.scale,
is_unpaged_prefill=is_unpaged_prefill,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
Expand Down
27 changes: 26 additions & 1 deletion lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from torch import nn

from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding
from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding, YarnRotaryEmbeddingImpl
from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters, RopeType, RotaryEmbeddingBuilder,
RotaryEmbeddingImpl, YarnParameters)

Expand Down Expand Up @@ -128,6 +128,25 @@ def __init__(
self.register_buffer('inv_freq', inv_freq_llama)


class DlinferYarnRotaryEmbeddingImpl(YarnRotaryEmbeddingImpl):
"""yarn rotary embedding implementation."""

def __init__(self,
dim: int,
base: int = 10000,
scaling_factor: float = 1.0,
original_max_position_embeddings: int = 4096,
yarn_params: YarnParameters = None):
super().__init__(dim, base, scaling_factor, original_max_position_embeddings, yarn_params)

def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
"""forward."""
dtype = x.dtype
if self.inv_freq.device != x.device:
self.inv_freq = self.inv_freq.to(x.device)
return _rotary_embedding_fwd(position_ids, self.inv_freq, scaling_factor=1.0, mscale=self.mscale, dtype=dtype)


class DlinferRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):
"""rotary embedding dlinfer builder."""

Expand All @@ -150,5 +169,11 @@ def build(
elif emb_type == RopeType.Llama3:
return DlinferLlama3RotaryEmbeddingImpl(dim, base, scaling_factor, llama3_params.low_freq_factor,
llama3_params.high_freq_factor, max_position_embeddings)
elif emb_type == RopeType.Yarn:
return DlinferYarnRotaryEmbeddingImpl(dim,
base,
scaling_factor,
max_position_embeddings,
yarn_params=yarn_params)
else:
raise NotImplementedError(f'Unsupported embedding type: {emb_type}')
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def prefill_attention(
max_kv_seq_len: int,
block_size: int,
attn_mask: Sequence[Optional[Tensor]],
softmax_scale: Optional[float],
is_unpaged_prefill: Optional[bool],
kv_scales: Optional[Tensor],
kv_zeros: Optional[Tensor],
Expand All @@ -39,6 +40,7 @@ def prefill_attention(
num_q_heads,
num_kv_heads,
attn_mask,
softmax_scale=softmax_scale,
attn_output=attn_output,
)
else:
Expand All @@ -59,6 +61,7 @@ def prefill_attention(
num_q_heads,
num_kv_heads,
attn_mask,
softmax_scale=softmax_scale,
attn_output=attn_output,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
Expand All @@ -75,6 +78,7 @@ def paged_token_attention(
max_kv_seq_len,
block_offsets,
block_size,
softmax_scale: Optional[float],
kv_scales: Optional[Tensor],
kv_zeros: Optional[Tensor],
quant_bits: Optional[int],
Expand All @@ -91,6 +95,7 @@ def paged_token_attention(
max_kv_seq_len,
num_q_heads,
num_kv_heads,
softmax_scale=softmax_scale,
attn_output=attn_output,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
Expand All @@ -115,6 +120,7 @@ def paged_attention_fwd(
is_decoding: bool,
block_size: int,
attn_mask: Sequence[Optional[Tensor]] = (),
softmax_scale: Optional[float] = None,
is_unpaged_prefill: Optional[bool] = None,
kv_scales: Optional[Tensor] = None,
kv_zeros: Optional[Tensor] = None,
Expand All @@ -137,6 +143,7 @@ def paged_attention_fwd(
max_kv_seq_len,
block_size,
attn_mask,
softmax_scale,
is_unpaged_prefill,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
Expand All @@ -152,6 +159,7 @@ def paged_attention_fwd(
max_kv_seq_len,
block_offsets,
block_size,
softmax_scale=softmax_scale,
kv_scales=kv_scales,
kv_zeros=kv_zeros,
quant_bits=quant_bits,
Expand Down