
custom triton cache manager #3659

Open · wants to merge 4 commits into base: main
7 changes: 7 additions & 0 deletions lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Tuple

import torch
@@ -179,6 +180,8 @@ def update_step_context(cls, step_context):
def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_config: CacheConfig,
backend_config: BackendConfig, device: torch.device):
"""Build graph runner."""
from lmdeploy.pytorch import envs

from .graph_runner import CUDAGraphRunner
from .warmup_manager import WarmupMeta, get_warmup_manager

@@ -190,6 +193,10 @@ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig, cache_
)
get_warmup_manager().warmup(warmup_meta)

# add custom triton cache manager
if envs.triton_custom_cache_mgr_enable:
os.environ['TRITON_CACHE_MANAGER'] = 'lmdeploy.pytorch.kernels.cuda.triton_utils:MPLockCacheManager'

# make graph runner.
return CUDAGraphRunner(model, model_config, cache_config, backend_config, device)
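
The hook above only points `TRITON_CACHE_MANAGER` at `lmdeploy.pytorch.kernels.cuda.triton_utils:MPLockCacheManager`; the manager class itself is not part of this diff. As a rough, hedged sketch of the general technique (not the PR's actual implementation), a multiprocess-safe cache manager could subclass Triton's `FileCacheManager` and serialize cache writes behind a file lock. The class name, lock-file layout, and the `filelock` dependency below are assumptions:

```python
# Hypothetical sketch only -- the PR's MPLockCacheManager is not shown in this diff.
# Assumes the `filelock` package and FileCacheManager's cache_dir layout.
import os

from filelock import FileLock
from triton.runtime.cache import FileCacheManager


class LockedCacheManager(FileCacheManager):  # hypothetical name
    """Serialize cache writes so concurrent worker processes do not race."""

    def put(self, data, filename, binary=True):
        # One lock file per cached artifact; writers take turns.
        lock_path = os.path.join(self.cache_dir, f'{filename}.lock')
        with FileLock(lock_path):
            return super().put(data, filename, binary=binary)
```

Triton resolves whatever `TRITON_CACHE_MANAGER` names in `module.path:ClassName` form, which is the mechanism the `build_graph_runner` change relies on.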

3 changes: 3 additions & 0 deletions lmdeploy/pytorch/envs.py
@@ -98,6 +98,9 @@ def _patched_get_env(
# logging
log_file = os.getenv('LMDEPLOY_LOG_FILE', None)

# triton
triton_custom_cache_mgr_enable = env_to_bool('LMDEPLOY_TRITON_CUSTOM_CACHE_MGR_ENABLE', False)


def get_all_envs():
"""Get all environment variables."""
2 changes: 0 additions & 2 deletions lmdeploy/pytorch/kernels/__init__.py
@@ -4,7 +4,6 @@
from .apply_rotary_pos_emb import apply_rotary_pos_emb
from .fill_kv_cache import fill_kv_cache
from .fused_moe import fused_moe
from .fused_rotary_emb import fused_rotary_emb
from .multinomial_sampling import multinomial_sampling
from .pagedattention import paged_attention_fwd
from .rms_norm import rms_norm
@@ -14,7 +13,6 @@
__all__ = [
'apply_rotary_pos_emb',
'fused_moe',
'fused_rotary_emb',
'paged_attention_fwd',
'alibi_paged_attention_fwd',
'fill_kv_cache',
2 changes: 0 additions & 2 deletions lmdeploy/pytorch/kernels/cuda/__init__.py
@@ -7,7 +7,6 @@
from .flashattention import flash_attention_fwd
from .flatten_kv_cache import flatten_kv_cache
from .fused_moe import fused_moe
from .fused_rotary_emb import fused_rotary_emb
from .multinomial_sampling import multinomial_sampling
from .pagedattention import paged_attention_fwd
from .rms_norm import rms_norm
@@ -17,7 +16,6 @@
__all__ = [
'apply_rotary_pos_emb',
'fused_moe',
'fused_rotary_emb',
'paged_attention_fwd',
'alibi_paged_attention_fwd',
'fill_kv_cache',
25 changes: 6 additions & 19 deletions lmdeploy/pytorch/kernels/cuda/alibi_pagedattention.py
@@ -8,11 +8,9 @@
import triton.language as tl
from torch import Tensor

from .triton_utils import get_kernel_meta, wrap_jit_func

assert triton.__version__ >= '2.1.0'

LOG2: tl.constexpr = math.log(2)
LOG2 = tl.constexpr(math.log(2))


@triton.jit
@@ -65,7 +63,6 @@ def _load_block_offsets(offset_ptr, block_id, num_sub_blocks: tl.constexpr, BLOC
return tl.load(offset_ptr + block_id) * BLOCK + offs_n


@wrap_jit_func
@triton.jit
def _fwd_split_kernel(
Q,
@@ -200,7 +197,6 @@ def _fwd_split_kernel(
tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i)


@wrap_jit_func
@triton.jit
def _reduce_split_kernel(
Acc,
@@ -244,7 +240,6 @@ def _reduce_split_kernel(
tl.store(Out + out_offs, acc)


@wrap_jit_func
@triton.jit
def _fwd_kernel(
Q,
@@ -375,7 +370,6 @@ def _fwd_kernel(
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)


@wrap_jit_func
@triton.jit
def _fwd_split_kernel_quant(
Q,
@@ -561,7 +555,6 @@ def _fwd_split_kernel_quant(
tl.store(Acc_out + off_meta + 1 + tl.arange(0, 1), l_i)


@wrap_jit_func
@triton.jit
def _fwd_kernel_quant(
Q,
@@ -802,7 +795,6 @@ def alibi_paged_attention_fwd(
grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,

num_warps = 4 if Lq <= 64 else 8
kernel_meta = get_kernel_meta(q)
is_decoding = q.shape[-3] == b_seq_len.size(0)
if not is_decoding:
if quant_policy > 0:
@@ -846,8 +838,7 @@
BLOCK_DMODEL=Lq,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
num_stages=1)
else:
_fwd_kernel[grid](q,
k,
@@ -880,8 +871,7 @@
BLOCK_DMODEL=Lq,
BLOCK_N=BLOCK,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
num_stages=1)
else:
SPLIT_K = 4
grid = (batch, head, SPLIT_K)
@@ -927,8 +917,7 @@
BLOCK_DMODEL=Lq,
BLOCK_N=BLOCK,
num_warps=4,
num_stages=1,
**kernel_meta)
num_stages=1)

else:
_fwd_split_kernel[grid](q,
@@ -961,8 +950,7 @@
BLOCK_DMODEL=Lq,
BLOCK_N=BLOCK,
num_warps=4,
num_stages=1,
**kernel_meta)
num_stages=1)

grid = (batch, head)
_reduce_split_kernel[grid](acc,
@@ -977,5 +965,4 @@
SPLIT_K=SPLIT_K,
BLOCK_DMODEL=Lq,
num_warps=num_warps,
num_stages=1,
**kernel_meta)
num_stages=1)
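
Across this file the diff drops the `wrap_jit_func` decorator, the `get_kernel_meta(q)` call, and the trailing `**kernel_meta` kwargs, and rewrites the module constant as `LOG2 = tl.constexpr(math.log(2))` instead of an annotated assignment. A self-contained toy sketch (not taken from the repo) of that constant pattern, showing a module-level `tl.constexpr` value used inside a `@triton.jit` kernel:

```python
# Toy example of a module-level tl.constexpr constant inside a jitted kernel;
# the kernel, shapes, and block size are illustrative, not from alibi_pagedattention.py.
import math

import torch
import triton
import triton.language as tl

LOG2 = tl.constexpr(math.log(2))


@triton.jit
def _scale_by_log2_kernel(x_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    # LOG2 is baked in as a compile-time constant here.
    tl.store(out_ptr + offs, x * LOG2, mask=mask)


x = torch.randn(4096, device='cuda')
out = torch.empty_like(x)
_scale_by_log2_kernel[(triton.cdiv(x.numel(), 1024), )](x, out, x.numel(), BLOCK=1024)
```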
4 changes: 1 addition & 3 deletions lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py
@@ -233,9 +233,7 @@ def _gemm_fp8_tma_kernel(
'BLOCK_N': 64,
}, num_stages=3, num_warps=4)
],
key=['N', 'K'],
warmup=5,
rep=10)
key=['N', 'K'])
@triton.jit
def _gemm_fp8_kernel(
A,
2 changes: 0 additions & 2 deletions lmdeploy/pytorch/kernels/cuda/fused_lora.py
@@ -38,8 +38,6 @@ def _atomic_store(ptrs, val, mask):
configs=get_autotune_config(),
key=['N', 'K'],
restore_value=['c_ptr'],
warmup=5,
rep=20,
)
@triton.jit
def _fused_lora_kernel(
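
In both `blocked_gemm_fp8.py` and `fused_lora.py`, the `@triton.autotune` decorators now keep only `configs` and `key` (plus `restore_value` for the LoRA kernel) and drop the `warmup`/`rep` benchmarking arguments, presumably for compatibility with Triton versions that no longer accept them. A toy sketch of the trimmed decorator shape (kernel, block sizes, and key are illustrative, not from the repo):

```python
# Toy autotuned kernel: configs + key only, with no warmup/rep arguments.
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({'BLOCK': 256}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK': 1024}, num_stages=2, num_warps=8),
    ],
    key=['n_elements'],
)
@triton.jit
def _add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)


x = torch.randn(8192, device='cuda')
y = torch.randn_like(x)
out = torch.empty_like(x)
# BLOCK comes from the winning config, so the grid is a function of the meta-parameters.
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK']), )
_add_kernel[grid](x, y, out, x.numel())
```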
141 changes: 0 additions & 141 deletions lmdeploy/pytorch/kernels/cuda/fused_rotary_emb.py

This file was deleted.
