From 5c3f9256b934f4b8814a745b29f4bd919b2c5575 Mon Sep 17 00:00:00 2001
From: zhoushenglong
Date: Mon, 27 Jan 2025 03:15:55 +0000
Subject: [PATCH] support deepseekv2 for maca backend.

---
 .../pytorch/backends/dlinfer/attention.py     |  1 +
 .../backends/dlinfer/rotary_embedding.py      | 27 ++++++++++++++++++-
 .../pytorch/kernels/dlinfer/pagedattention.py |  8 ++++++
 3 files changed, 35 insertions(+), 1 deletion(-)

diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py
index 16f332267..094bf5c57 100644
--- a/lmdeploy/pytorch/backends/dlinfer/attention.py
+++ b/lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -132,6 +132,7 @@ def forward(
             is_decoding=is_decoding,
             block_size=block_size,
             attn_mask=attn_mask,
+            softmax_scale=self.scale,
             is_unpaged_prefill=is_unpaged_prefill,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
diff --git a/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py b/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
index d3a10a9cd..cb31e3162 100644
--- a/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
+++ b/lmdeploy/pytorch/backends/dlinfer/rotary_embedding.py
@@ -5,7 +5,7 @@
 import torch
 from torch import nn
 
-from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding
+from ..default.rotary_embedding import LlamaDynamicNTKScalingRotaryEmbedding, YarnRotaryEmbeddingImpl
 from ..rotary_embedding import (Llama3Parameters, LongRoPEScalingParameters, RopeType, RotaryEmbeddingBuilder,
                                 RotaryEmbeddingImpl, YarnParameters)
 
@@ -128,6 +128,25 @@ def __init__(
         self.register_buffer('inv_freq', inv_freq_llama)
 
 
+class DlinferYarnRotaryEmbeddingImpl(YarnRotaryEmbeddingImpl):
+    """yarn rotary embedding implementation."""
+
+    def __init__(self,
+                 dim: int,
+                 base: int = 10000,
+                 scaling_factor: float = 1.0,
+                 original_max_position_embeddings: int = 4096,
+                 yarn_params: YarnParameters = None):
+        super().__init__(dim, base, scaling_factor, original_max_position_embeddings, yarn_params)
+
+    def forward(self, x: torch.Tensor, position_ids: torch.Tensor):
+        """forward."""
+        dtype = x.dtype
+        if self.inv_freq.device != x.device:
+            self.inv_freq = self.inv_freq.to(x.device)
+        return _rotary_embedding_fwd(position_ids, self.inv_freq, scaling_factor=1.0, mscale=self.mscale, dtype=dtype)
+
+
 class DlinferRotaryEmbeddingBuilder(RotaryEmbeddingBuilder):
     """rotary embedding dlinfer builder."""
 
@@ -150,5 +169,11 @@ def build(
         elif emb_type == RopeType.Llama3:
             return DlinferLlama3RotaryEmbeddingImpl(dim, base, scaling_factor, llama3_params.low_freq_factor,
                                                     llama3_params.high_freq_factor, max_position_embeddings)
+        elif emb_type == RopeType.Yarn:
+            return DlinferYarnRotaryEmbeddingImpl(dim,
+                                                  base,
+                                                  scaling_factor,
+                                                  max_position_embeddings,
+                                                  yarn_params=yarn_params)
         else:
             raise NotImplementedError(f'Unsupported embedding type: {emb_type}')
diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
index 584945bb2..0346bd63a 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -20,6 +20,7 @@ def prefill_attention(
     max_kv_seq_len: int,
     block_size: int,
     attn_mask: Sequence[Optional[Tensor]],
+    softmax_scale: Optional[float],
     is_unpaged_prefill: Optional[bool],
     kv_scales: Optional[Tensor],
     kv_zeros: Optional[Tensor],
@@ -39,6 +40,7 @@ def prefill_attention(
             num_q_heads,
             num_kv_heads,
             attn_mask,
+            softmax_scale=softmax_scale,
             attn_output=attn_output,
         )
     else:
@@ -59,6 +61,7 @@ def prefill_attention(
             num_q_heads,
             num_kv_heads,
             attn_mask,
+            softmax_scale=softmax_scale,
             attn_output=attn_output,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
@@ -75,6 +78,7 @@ def paged_token_attention(
     max_kv_seq_len,
     block_offsets,
     block_size,
+    softmax_scale: Optional[float],
     kv_scales: Optional[Tensor],
     kv_zeros: Optional[Tensor],
     quant_bits: Optional[int],
@@ -91,6 +95,7 @@ def paged_token_attention(
         max_kv_seq_len,
         num_q_heads,
         num_kv_heads,
+        softmax_scale=softmax_scale,
         attn_output=attn_output,
         kv_scales=kv_scales,
         kv_zeros=kv_zeros,
@@ -115,6 +120,7 @@ def paged_attention_fwd(
     is_decoding: bool,
     block_size: int,
     attn_mask: Sequence[Optional[Tensor]] = (),
+    softmax_scale: Optional[float] = None,
     is_unpaged_prefill: Optional[bool] = None,
     kv_scales: Optional[Tensor] = None,
     kv_zeros: Optional[Tensor] = None,
@@ -137,6 +143,7 @@ def paged_attention_fwd(
             max_kv_seq_len,
             block_size,
             attn_mask,
+            softmax_scale,
             is_unpaged_prefill,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
@@ -152,6 +159,7 @@ def paged_attention_fwd(
             max_kv_seq_len,
             block_offsets,
             block_size,
+            softmax_scale=softmax_scale,
             kv_scales=kv_scales,
             kv_zeros=kv_zeros,
             quant_bits=quant_bits,
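
DeepSeek-V2's MLA attention does not use the default 1/sqrt(head_dim) softmax scale, which is why the hunks above thread an explicit softmax_scale from the attention backend (self.scale) down through the dlinfer prefill and decode kernels instead of letting the kernel infer it. The sketch below is not part of the patch: it only illustrates the YaRN-style mscale correction that DeepSeek-V2-family configs apply on top of 1/sqrt(qk_head_dim); the helper name yarn_get_mscale and the concrete numbers are illustrative, not lmdeploy code.

import math


def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    # Standard YaRN attention-temperature term; 1.0 when no context extension is applied.
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


# Hypothetical DeepSeek-V2-like settings, for illustration only.
qk_head_dim = 192            # qk_nope_head_dim (128) + qk_rope_head_dim (64)
rope_scaling_factor = 40.0   # YaRN context-extension factor
mscale_all_dim = 1.0

softmax_scale = qk_head_dim**-0.5
m = yarn_get_mscale(rope_scaling_factor, mscale_all_dim)
softmax_scale = softmax_scale * m * m   # the non-default scale the model layer hands to the backend

print(softmax_scale)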
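
The new DlinferYarnRotaryEmbeddingImpl can be smoke-tested in isolation using the constructor signature visible in the hunk above. This is a minimal sketch, not part of the patch: it assumes the patch is applied to an lmdeploy checkout, that YarnParameters constructs with usable defaults, and that the forward pass returns cos/sin tables like the other dlinfer rope implementations; the dims and values are DeepSeek-V2-like but purely illustrative.

import torch

from lmdeploy.pytorch.backends.dlinfer.rotary_embedding import DlinferYarnRotaryEmbeddingImpl
from lmdeploy.pytorch.backends.rotary_embedding import YarnParameters

# Constructor arguments follow the signature added in the patch;
# the concrete values below are illustrative only.
rotary_emb = DlinferYarnRotaryEmbeddingImpl(
    dim=64,                                  # rope head dim (qk_rope_head_dim)
    base=10000,
    scaling_factor=40.0,                     # YaRN context-extension factor
    original_max_position_embeddings=4096,
    yarn_params=YarnParameters(),            # assumed to provide usable defaults
)

x = torch.randn(1, 32, 8, 64)                # dtype/device carrier for the embedding
position_ids = torch.arange(32)[None]        # [batch, seq_len]
cos, sin = rotary_emb(x, position_ids)       # expected: cos/sin tables with the YaRN mscale folded in
print(cos.shape, sin.shape, cos.dtype)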