diff --git a/lmdeploy/pytorch/backends/dlinfer/lora.py b/lmdeploy/pytorch/backends/dlinfer/lora.py
new file mode 100644
index 0000000000..2c36f4dd32
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/lora.py
@@ -0,0 +1,78 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from dataclasses import dataclass
+
+import torch
+
+from lmdeploy.pytorch.kernels.dlinfer.fused_lora import fused_lora
+from lmdeploy.pytorch.model_inputs import StepContextManager
+
+from ..lora import AdapterInfo, LoRABuilder, LoRAImpl
+
+
+@dataclass
+class PackedLoRAInput:
+    """Packed lora input."""
+    x: torch.Tensor
+    q_start_loc: torch.Tensor
+    q_seqlens: torch.Tensor
+    adapter_ids: torch.Tensor
+    max_seq_len: int
+    is_decoding: bool
+
+
+class DlinferLoRAImpl(LoRAImpl):
+    """Dlinfer lora implementation."""
+
+    @staticmethod
+    def _make_packed_lora_input(x, ctx_mgr):
+        """Make PackedLoRAInput."""
+        context = ctx_mgr.current_context()
+
+        # upper bound for the max query sequence length
+        max_q_seq_length = x.numel() // x.size(-1)
+
+        return PackedLoRAInput(x=x.flatten(0, -2).contiguous(),
+                               q_start_loc=context.q_start_loc,
+                               q_seqlens=context.q_seqlens,
+                               adapter_ids=context.local_adapter_ids,
+                               max_seq_len=max_q_seq_length,
+                               is_decoding=context.is_decoding)
+
+    def forward(self,
+                x: torch.Tensor,
+                lora_A: torch.Tensor,
+                lora_B: torch.Tensor,
+                base_output: torch.Tensor,
+                adapter_info: AdapterInfo,
+                ctx_mgr: StepContextManager,
+                colwise: bool,
+                is_tp: bool = True):
+        """forward."""
+        lora_input = self._make_packed_lora_input(x, ctx_mgr)
+
+        return fused_lora(
+            lora_input.x,
+            lora_A,
+            lora_B,
+            scaling=adapter_info.scalings,
+            rank_start=adapter_info.rank_offsets,
+            ranks=adapter_info.ranks,
+            seq_start=lora_input.q_start_loc,
+            seq_lens=lora_input.q_seqlens,
+            adapter_ids=lora_input.adapter_ids,
+            max_rank=adapter_info.max_rank,
+            max_seqlen=lora_input.max_seq_len,
+            slice_start=adapter_info.base_slice.start,
+            slice_stop=adapter_info.base_slice.stop,
+            slice_step=adapter_info.base_slice.step,
+            output=base_output,
+        )
+
+
+class DlinferLoRABuilder(LoRABuilder):
+    """Dlinfer lora layer builder."""
+
+    @staticmethod
+    def build():
+        """build."""
+        return DlinferLoRAImpl()
diff --git a/lmdeploy/pytorch/backends/dlinfer/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/op_backend.py
index 16eb604ccd..70820adb6f 100644
--- a/lmdeploy/pytorch/backends/dlinfer/op_backend.py
+++ b/lmdeploy/pytorch/backends/dlinfer/op_backend.py
@@ -58,6 +58,9 @@ def get_layer_impl_builder(cls, layer_type: OpType):
         elif layer_type == OpType.RotaryEmbedding:
             from .rotary_embedding import DlinferRotaryEmbeddingBuilder
             return DlinferRotaryEmbeddingBuilder
+        elif layer_type == OpType.LoRA:
+            from .lora import DlinferLoRABuilder
+            return DlinferLoRABuilder
         else:
             logger.debug(f'Op {layer_type} fallback to default implementation.')
             return super().get_layer_impl_builder(layer_type)
diff --git a/lmdeploy/pytorch/kernels/dlinfer/fused_lora.py b/lmdeploy/pytorch/kernels/dlinfer/fused_lora.py
new file mode 100644
index 0000000000..c99fb5e229
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/dlinfer/fused_lora.py
@@ -0,0 +1,13 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import dlinfer.ops as ext_ops
+from torch import Tensor
+
+
+def fused_lora(input: Tensor, lora_a: Tensor, lora_b: Tensor, scaling: Tensor, rank_start: Tensor, ranks: Tensor,
+               seq_start: Tensor, seq_lens: Tensor, adapter_ids: Tensor, max_rank: int, max_seqlen: int,
+               slice_start: int, slice_stop: int, slice_step: Optional[int], output: Optional[Tensor]):
+    """Fused lora."""
+    return ext_ops.fused_lora(input, lora_a, lora_b, scaling, rank_start, ranks, seq_start, seq_lens, adapter_ids,
+                              max_rank, max_seqlen, slice_start, slice_stop, slice_step, output)
diff --git a/requirements/runtime_ascend.txt b/requirements/runtime_ascend.txt
index e680ed706a..5546d58b49 100644
--- a/requirements/runtime_ascend.txt
+++ b/requirements/runtime_ascend.txt
@@ -17,8 +17,10 @@ safetensors
 sentencepiece
 shortuuid
 tiktoken
-torch<=2.4.0,>=2.3.1
-torch-npu==2.3.1
+# Supported torch versions: 2.3.1, 2.5.1, 2.6.0, 2.7.1
+# Please install one of the supported versions manually
+torch>=2.3.1,<2.8.0
+torch-npu>=2.3.1,<2.8.0
 torchvision<=0.19.0,>=0.18.1
 transformers
 uvicorn
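For reviewers, a quick sanity-check sketch of the new dispatch path added in `op_backend.py`. The class name `DlinferOpsBackend` and the import location of `OpType` are assumed from the surrounding context, not confirmed by this patch:

```python
# Sanity check for the new OpType.LoRA dispatch branch (assumed module/class names).
from lmdeploy.pytorch.backends.base import OpType  # assumed location of OpType
from lmdeploy.pytorch.backends.dlinfer.op_backend import DlinferOpsBackend  # assumed class name

builder = DlinferOpsBackend.get_layer_impl_builder(OpType.LoRA)
lora_impl = builder.build()  # expected to return a DlinferLoRAImpl instance
print(type(lora_impl).__name__)  # 'DlinferLoRAImpl'
```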
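As a cross-check on the semantics the `ext_ops.fused_lora` kernel is expected to implement, here is a minimal unfused PyTorch sketch of the per-sequence LoRA math. The packed weight layout (adapters stacked along the rank dimension of `lora_a`/`lora_b`, indexed via `rank_start`/`ranks`) and the helper name `fused_lora_reference` are assumptions for illustration, not part of this patch:

```python
import torch


def fused_lora_reference(x, lora_a, lora_b, scaling, rank_start, ranks,
                         seq_start, seq_lens, adapter_ids,
                         slice_start, slice_stop, slice_step, output):
    """Unfused reference: output[rows, slice] += scaling * (x @ A.T) @ B.T.

    Assumes lora_a is packed as (total_rank, in_features) and lora_b as
    (total_rank, out_slice_width), with adapter `a` occupying rows
    [rank_start[a], rank_start[a] + ranks[a]).
    """
    for i in range(seq_lens.numel()):
        start, length = seq_start[i].item(), seq_lens[i].item()
        aid = adapter_ids[i].item()
        r0, rank = rank_start[aid].item(), ranks[aid].item()
        if rank == 0:  # no adapter attached to this sequence
            continue
        xa = x[start:start + length] @ lora_a[r0:r0 + rank].t()   # (length, rank)
        xab = xa @ lora_b[r0:r0 + rank].t()                       # (length, out_slice_width)
        # accumulate into the base output over the adapter's output slice
        output[start:start + length, slice_start:slice_stop:slice_step] += scaling[aid] * xab
    return output
```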