support deepseekv2 for maca backend.
Reinerzhou committed Jan 27, 2025
1 parent 97458de commit 5c3f925
Showing 3 changed files with 35 additions and 1 deletion.
1 change: 1 addition & 0 deletions lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -132,6 +132,7 @@ def forward(
             is_decoding=is_decoding,
             block_size=block_size,
             attn_mask=attn_mask,
+            softmax_scale=self.scale,
             is_unpaged_prefill=is_unpaged_prefill,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
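
Note on this change: DeepSeek-V2's MLA attention does not use the kernel-default 1/sqrt(head_dim) softmax scale; the model folds a YaRN mscale correction into it, so the layer's own self.scale has to be forwarded to the paged-attention kernel instead of letting the kernel derive a default. A minimal sketch of the idea, where the dimensions and YaRN constants are illustrative assumptions, not values read from this diff:

import math

# Illustrative DeepSeek-V2-style dimensions (assumptions for this sketch).
qk_nope_head_dim = 128
qk_rope_head_dim = 64
head_dim = qk_nope_head_dim + qk_rope_head_dim  # 192

# A kernel-side default would be derived from head_dim alone:
default_scale = 1.0 / math.sqrt(head_dim)

# DeepSeek-V2 additionally folds a YaRN mscale correction into the scale,
# following yarn_get_mscale as in the public DeepSeek-V2 modeling code:
def yarn_get_mscale(scale: float, mscale: float = 1.0) -> float:
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

rope_scaling_factor, mscale_all_dim = 40.0, 1.0  # assumed config values
mscale = yarn_get_mscale(rope_scaling_factor, mscale_all_dim)
softmax_scale = default_scale * mscale * mscale  # differs from default_scale
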
27 changes: 26 additions & 1 deletion lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
@@ -5,7 +5,7 @@
 import torch
 from torch import nn
 
-from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding
+from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding, YarnRotaryEmbeddingImpl
 from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters, RopeType, RotaryEmbeddingBuilder,
                                 RotaryEmbeddingImpl, YarnParameters)
 
@@ -128,6 +128,25 @@ def __init__(
         self.register_buffer('inv_freq', inv_freq_llama)
 
 
+class DlinferYarnRotaryEmbeddingImpl(YarnRotaryEmbeddingImpl):
+    """yarn rotary embedding implementation."""
+
+    def __init__(self,
+                 dim: int,
+                 base: int = 10000,
+                 scaling_factor: float = 1.0,
+                 original_max_position_embeddings: int = 4096,
+                 yarn_params: YarnParameters = None):
+        super().__init__(dim, base, scaling_factor, original_max_position_embeddings, yarn_params)
+
+    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
+        """forward."""
+        dtype = x.dtype
+        if self.inv_freq.device != x.device:
+            self.inv_freq = self.inv_freq.to(x.device)
+        return _rotary_embedding_fwd(position_ids, self.inv_freq, scaling_factor=1.0, mscale=self.mscale, dtype=dtype)
+
+
 class DlinferRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):
     """rotary embedding dlinfer builder."""
 
@@ -150,5 +169,11 @@ def build(
         elif emb_type == RopeType.Llama3:
             return DlinferLlama3RotaryEmbeddingImpl(dim, base, scaling_factor, llama3_params.low_freq_factor,
                                                     llama3_params.high_freq_factor, max_position_embeddings)
+        elif emb_type == RopeType.Yarn:
+            return DlinferYarnRotaryEmbeddingImpl(dim,
+                                                  base,
+                                                  scaling_factor,
+                                                  max_position_embeddings,
+                                                  yarn_params=yarn_params)
         else:
             raise NotImplementedError(f'Unsupported embedding type: {emb_type}')
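
The subclass inherits all YaRN frequency and mscale setup from YarnRotaryEmbeddingImpl and only overrides forward, so that the inv_freq buffer is moved to the input tensor's device before the shared _rotary_embedding_fwd helper runs. A rough usage sketch, with assumed shapes and assumed YarnParameters defaults:

import torch

# Hypothetical usage; shapes and YarnParameters defaults are assumptions.
rope = DlinferYarnRotaryEmbeddingImpl(dim=64,
                                      base=10000,
                                      scaling_factor=40.0,
                                      original_max_position_embeddings=4096,
                                      yarn_params=YarnParameters())
x = torch.randn(1, 128, 8, 64)             # only dtype and device of x are consumed
position_ids = torch.arange(128)[None, :]  # (batch, seq_len)
cos, sin = rope(x, position_ids)           # cos/sin caches scaled by the YaRN mscale
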
8 changes: 8 additions & 0 deletions lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -20,6 +20,7 @@ def prefill_attention(
     max_kv_seq_len: int,
     block_size: int,
     attn_mask: Sequence[Optional[Tensor]],
+    softmax_scale: Optional[float],
     is_unpaged_prefill: Optional[bool],
     kv_scales: Optional[Tensor],
     kv_zeros: Optional[Tensor],
@@ -39,6 +40,7 @@ def prefill_attention(
             num_q_heads,
             num_kv_heads,
             attn_mask,
+            softmax_scale=softmax_scale,
             attn_output=attn_output,
         )
     else:
@@ -59,6 +61,7 @@ def prefill_attention(
             num_q_heads,
             num_kv_heads,
             attn_mask,
+            softmax_scale=softmax_scale,
             attn_output=attn_output,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
@@ -75,6 +78,7 @@ def paged_token_attention(
     max_kv_seq_len,
     block_offsets,
     block_size,
+    softmax_scale: Optional[float],
     kv_scales: Optional[Tensor],
     kv_zeros: Optional[Tensor],
     quant_bits: Optional[int],
@@ -91,6 +95,7 @@ def paged_token_attention(
         max_kv_seq_len,
         num_q_heads,
         num_kv_heads,
+        softmax_scale=softmax_scale,
         attn_output=attn_output,
         kv_scales=kv_scales,
         kv_zeros=kv_zeros,
@@ -115,6 +120,7 @@ def paged_attention_fwd(
     is_decoding: bool,
     block_size: int,
     attn_mask: Sequence[Optional[Tensor]] = (),
+    softmax_scale: Optional[float] = None,
     is_unpaged_prefill: Optional[bool] = None,
     kv_scales: Optional[Tensor] = None,
     kv_zeros: Optional[Tensor] = None,
@@ -137,6 +143,7 @@ def paged_attention_fwd(
             max_kv_seq_len,
             block_size,
             attn_mask,
+            softmax_scale,
             is_unpaged_prefill,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
@@ -152,6 +159,7 @@ def paged_attention_fwd(
             max_kv_seq_len,
             block_offsets,
             block_size,
+            softmax_scale=softmax_scale,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
             quant_bits=quant_bits,
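
softmax_scale enters paged_attention_fwd as Optional[float] = None and is threaded through both the prefill and the decode branch, so existing callers that pass nothing keep their old behavior and the underlying dlinfer ops apply their own default. A hypothetical helper that captures the usual contract (resolve_softmax_scale is illustrative and not part of this diff):

import math
from typing import Optional

def resolve_softmax_scale(softmax_scale: Optional[float], head_dim: int) -> float:
    """Illustrative fallback: use the caller's scale when given, else the
    conventional 1/sqrt(head_dim). The real fallback lives inside the
    dlinfer ops, so treat this as a sketch of the contract, not the code."""
    if softmax_scale is not None:
        return softmax_scale
    return 1.0 / math.sqrt(head_dim)

# e.g. resolve_softmax_scale(None, 128) == 128 ** -0.5, while a DeepSeek-V2
# layer passes its own mscale-corrected value through unchanged.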
