Skip to content

Commit

Permalink
This PR introduces KV Cache support for the CPU runtime:
Browse files Browse the repository at this point in the history
Implementation of KV Cache TIR for CPU-based processing.
Updates to the relevant runtime components to integrate KV Cache.
  • Loading branch information
mengshyu committed Feb 2, 2025
1 parent f9da13f commit aad79c1
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 33 deletions.
4 changes: 4 additions & 0 deletions cpp/serve/sampler/cpu_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,10 @@ class CPUSampler : public SamplerObj {
/*! \brief Copy prob distributions from device to CPU. */
NDArray CopyProbsToCPU(NDArray probs_on_device) {
// probs_on_device: (n, v)
if (probs_on_device->device.device_type == kDLCPU) {
return probs_on_device;
}

ICHECK(probs_on_device->device.device_type != kDLCPU);
if (probs_host_.defined()) {
ICHECK_EQ(probs_host_->shape[1], probs_on_device->shape[1]);
Expand Down
122 changes: 119 additions & 3 deletions python/mlc_llm/compiler_pass/attach_logit_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,48 @@ def __init__(self, target: tvm.target.Target):
def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
"""Entrypoint"""
mod = mod.clone()
mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace(self.target)
mod["apply_penalty_inplace"] = _get_apply_penalty_inplace(self.target)
mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace(self.target)
if str(self.target.kind) == "llvm":
mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace_cpu(self.target)
mod["apply_penalty_inplace"] = _get_apply_penalty_inplace_cpu(self.target)
mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace_cpu(self.target)
else:
mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace(self.target)
mod["apply_penalty_inplace"] = _get_apply_penalty_inplace(self.target)
mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace(self.target)
return mod


def _get_apply_logit_bias_inplace_cpu(target: tvm.target.Target):
@T.prim_func
def _apply_logit_bias_inplace(
var_logits: T.handle,
var_pos2seq_id: T.handle,
var_token_ids: T.handle,
var_logit_bias: T.handle,
) -> None:
"""Function that applies logit bias in place."""
T.func_attr(
{
"global_symbol": "apply_logit_bias_inplace",
"tir.noalias": True,
"tir.is_scheduled": True,
}
)
batch_size = T.int32(is_size_var=True)
vocab_size = T.int32(is_size_var=True)
num_token = T.int32(is_size_var=True)
logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
# seq_ids
pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
logit_bias = T.match_buffer(var_logit_bias, (num_token,), "float32")

for i in range(num_token):
logits[pos2seq_id[i], token_ids[i]] += logit_bias[i]

return _apply_logit_bias_inplace


def _get_apply_logit_bias_inplace(target: tvm.target.Target):
tx = 1024 # default
max_num_threads_per_block = get_max_num_threads_per_block(target)
Expand Down Expand Up @@ -74,6 +110,50 @@ def _apply_logit_bias_inplace(
return _apply_logit_bias_inplace


def _get_apply_penalty_inplace_cpu(target: tvm.target.Target):
@T.prim_func
def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-locals
var_logits: T.handle,
var_seq_ids: T.handle,
var_pos2seq_id: T.handle,
var_token_ids: T.handle,
var_token_cnt: T.handle,
var_penalties: T.handle,
) -> None:
"""Function that applies penalties in place."""
T.func_attr(
{
"global_symbol": "apply_penalty_inplace",
"tir.noalias": True,
"tir.is_scheduled": True,
}
)
batch_size = T.int32(is_size_var=True)
vocab_size = T.int32(is_size_var=True)
num_token = T.int32(is_size_var=True)
num_seq = T.int32(is_size_var=True)
logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
penalties = T.match_buffer(var_penalties, (num_seq, 3), "float32")

for token in T.serial(num_token):
with T.block("block"):
vp = T.axis.spatial(num_token, token)
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= (
penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1]
)
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < 0,
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2],
logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2],
)

return _apply_penalty_inplace


def _get_apply_penalty_inplace(target: tvm.target.Target):
tx = 1024 # default
max_num_threads_per_block = get_max_num_threads_per_block(target)
Expand Down Expand Up @@ -129,6 +209,42 @@ def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-local
return _apply_penalty_inplace


def _get_apply_bitmask_inplace_cpu(target: tvm.target.Target):
@T.prim_func
def _apply_bitmask_inplace(
var_logits: T.handle,
var_seq_ids: T.handle,
var_bitmask: T.handle,
) -> None:
"""Function that applies vocabulary masking in place."""
T.func_attr(
{
"global_symbol": "apply_bitmask_inplace",
"tir.noalias": True,
"tir.is_scheduled": True,
}
)
batch_size = T.int32(is_size_var=True)
vocab_size = T.int32(is_size_var=True)
num_seq = T.int32(is_size_var=True)
logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")

for token in T.serial(num_seq * vocab_size):
with T.block("block"):
vs = T.axis.spatial(num_seq, (token) // vocab_size)
vv = T.axis.spatial(vocab_size, (token) % vocab_size)

logits[seq_ids[vs], vv] = T.if_then_else(
(bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1,
logits[seq_ids[vs], vv],
T.min_value("float32"),
)

return _apply_bitmask_inplace


def _get_apply_bitmask_inplace(target: tvm.target.Target):
tx = 1024 # default
max_num_threads_per_block = get_max_num_threads_per_block(target)
Expand Down
61 changes: 34 additions & 27 deletions python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,31 +213,38 @@ def softmax_with_chunked_sum(
)

sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_sum": softmax_with_chunked_sum}))
max_threads = get_max_num_threads_per_block(target)
TX = 32
TY = max_threads // TX
unroll_depth = 64
# pylint: enable=invalid-name

sch.work_on("softmax_with_chunked_sum")
l0, l1, l2 = sch.get_loops("log_pad")
bx = sch.fuse(l0, l1)
sch.bind(bx, "blockIdx.x")
unroll, ty, tx = sch.split(l2, [None, TY, TX])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1)

for block_name in ["sum_exp", "max"]:
block = sch.get_block(block_name)
sch.set_scope(block, buffer_index=0, storage_scope="shared")
sch.compute_at(block, bx)
r_loop = sch.get_loops(block)[-1]
r_loop, tx = sch.split(r_loop, [None, TX])
sch.reorder(tx, r_loop)
sch.bind(tx, "threadIdx.x")
sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1)

return chunk_lse, sch.mod["softmax_with_chunked_sum"]
def apply_gpu_schedule(target, sch):
max_threads = get_max_num_threads_per_block(target)
TX = 32
TY = max_threads // TX
unroll_depth = 64
# pylint: enable=invalid-name

sch.work_on("softmax_with_chunked_sum")
l0, l1, l2 = sch.get_loops("log_pad")
bx = sch.fuse(l0, l1)
sch.bind(bx, "blockIdx.x")
unroll, ty, tx = sch.split(l2, [None, TY, TX])
sch.bind(ty, "threadIdx.y")
sch.bind(tx, "threadIdx.x")
sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1)

for block_name in ["sum_exp", "max"]:
block = sch.get_block(block_name)
sch.set_scope(block, buffer_index=0, storage_scope="shared")
sch.compute_at(block, bx)
r_loop = sch.get_loops(block)[-1]
r_loop, tx = sch.split(r_loop, [None, TX])
sch.reorder(tx, r_loop)
sch.bind(tx, "threadIdx.x")
sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1)

return chunk_lse, sch.mod["softmax_with_chunked_sum"]

if target.kind.name == "llvm":
return chunk_lse, sch.mod["softmax_with_chunked_sum"]
else:
return apply_gpu_schedule(target, sch)
12 changes: 10 additions & 2 deletions python/mlc_llm/compiler_pass/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
FuseFTDequantizeEpilogue(),
FuseDequantizeTranspose(),
BLASDispatch(target) if cublas_gemm else tvm.transform.Sequential([]),
FuseAddRMSNorm(target=target),
(
FuseAddRMSNorm(target=target)
if target.kind.name != "llvm"
else tvm.transform.Sequential([])
),
FuseTransposeMatmul(),
_DebugDump("debug-phase1.py", debug_dump, show_meta=False),
# Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
Expand Down Expand Up @@ -152,7 +156,11 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
),
_DebugDump("debug-phase4.py", debug_dump, show_meta=False),
_LogProgress("Lowering to VM bytecode"),
LiftTIRGlobalBufferAlloc(),
(
LiftTIRGlobalBufferAlloc()
if target.kind.name != "llvm"
else tvm.transform.Sequential([])
),
(
tvm.tir.transform.ForceNarrowIndexToInt32()
if target.kind.name != "cuda"
Expand Down
2 changes: 1 addition & 1 deletion python/mlc_llm/support/auto_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

FOUND = green("Found")
NOT_FOUND = red("Not found")
AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan", "opencl"]
AUTO_DETECT_DEVICES = ["cpu", "cuda", "rocm", "metal", "vulkan", "opencl"]
_RESULT_CACHE: Dict[str, bool] = {}


Expand Down

0 comments on commit aad79c1

Please sign in to comment.