
Commit aad79c1

This PR introduces KV cache support for the CPU runtime: a TIR implementation of the KV cache for CPU-based processing, plus updates to the relevant runtime components to integrate it.

1 parent f9da13f · commit aad79c1


5 files changed: +168 −33 lines


cpp/serve/sampler/cpu_sampler.cc

Lines changed: 4 additions & 0 deletions
@@ -554,6 +554,10 @@ class CPUSampler : public SamplerObj {
   /*! \brief Copy prob distributions from device to CPU. */
   NDArray CopyProbsToCPU(NDArray probs_on_device) {
     // probs_on_device: (n, v)
+    if (probs_on_device->device.device_type == kDLCPU) {
+      return probs_on_device;
+    }
+
     ICHECK(probs_on_device->device.device_type != kDLCPU);
     if (probs_host_.defined()) {
       ICHECK_EQ(probs_host_->shape[1], probs_on_device->shape[1]);
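
The new early return skips the device-to-host copy when the probability tensor already lives in CPU memory, which is the common case for a pure-CPU runtime. A minimal Python sketch of the same idea using TVM's NDArray API (the wrapper function is illustrative, not the runtime's actual code):

import tvm

def copy_probs_to_cpu(probs: tvm.nd.NDArray) -> tvm.nd.NDArray:
    """Return probs resident on the CPU, copying only when necessary."""
    cpu = tvm.cpu()
    # kDLCPU is device_type 1 in the DLPack enum; tvm.cpu() carries the same type.
    if probs.device.device_type == cpu.device_type:
        return probs  # already on the host, no copy needed
    return probs.copyto(cpu)  # device -> host copy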

python/mlc_llm/compiler_pass/attach_logit_processor.py

Lines changed: 119 additions & 3 deletions
@@ -27,12 +27,48 @@ def __init__(self, target: tvm.target.Target):
     def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
         """Entrypoint"""
         mod = mod.clone()
-        mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace(self.target)
-        mod["apply_penalty_inplace"] = _get_apply_penalty_inplace(self.target)
-        mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace(self.target)
+        if str(self.target.kind) == "llvm":
+            mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace_cpu(self.target)
+            mod["apply_penalty_inplace"] = _get_apply_penalty_inplace_cpu(self.target)
+            mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace_cpu(self.target)
+        else:
+            mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace(self.target)
+            mod["apply_penalty_inplace"] = _get_apply_penalty_inplace(self.target)
+            mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace(self.target)
         return mod


+def _get_apply_logit_bias_inplace_cpu(target: tvm.target.Target):
+    @T.prim_func
+    def _apply_logit_bias_inplace(
+        var_logits: T.handle,
+        var_pos2seq_id: T.handle,
+        var_token_ids: T.handle,
+        var_logit_bias: T.handle,
+    ) -> None:
+        """Function that applies logit bias in place."""
+        T.func_attr(
+            {
+                "global_symbol": "apply_logit_bias_inplace",
+                "tir.noalias": True,
+                "tir.is_scheduled": True,
+            }
+        )
+        batch_size = T.int32(is_size_var=True)
+        vocab_size = T.int32(is_size_var=True)
+        num_token = T.int32(is_size_var=True)
+        logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
+        # seq_ids
+        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
+        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
+        logit_bias = T.match_buffer(var_logit_bias, (num_token,), "float32")
+
+        for i in range(num_token):
+            logits[pos2seq_id[i], token_ids[i]] += logit_bias[i]
+
+    return _apply_logit_bias_inplace
+
+
 def _get_apply_logit_bias_inplace(target: tvm.target.Target):
     tx = 1024  # default
     max_num_threads_per_block = get_max_num_threads_per_block(target)
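
With the dispatch above, llvm targets get serial TIR kernels marked tir.is_scheduled, while other targets keep the GPU versions. For intuition, the CPU logit-bias kernel is a scatter-add over the requested (position, token) pairs; a NumPy sketch of the same semantics (illustrative, not the compiled TIR):

import numpy as np

def apply_logit_bias_inplace(logits, pos2seq_id, token_ids, logit_bias):
    """NumPy reference: logits is (batch_size, vocab_size); the rest are 1-D."""
    for i in range(len(token_ids)):
        # add the i-th bias to the logit of token_ids[i] in row pos2seq_id[i]
        logits[pos2seq_id[i], token_ids[i]] += logit_bias[i]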
@@ -74,6 +110,50 @@ def _apply_logit_bias_inplace(
     return _apply_logit_bias_inplace


+def _get_apply_penalty_inplace_cpu(target: tvm.target.Target):
+    @T.prim_func
+    def _apply_penalty_inplace(  # pylint: disable=too-many-arguments,too-many-locals
+        var_logits: T.handle,
+        var_seq_ids: T.handle,
+        var_pos2seq_id: T.handle,
+        var_token_ids: T.handle,
+        var_token_cnt: T.handle,
+        var_penalties: T.handle,
+    ) -> None:
+        """Function that applies penalties in place."""
+        T.func_attr(
+            {
+                "global_symbol": "apply_penalty_inplace",
+                "tir.noalias": True,
+                "tir.is_scheduled": True,
+            }
+        )
+        batch_size = T.int32(is_size_var=True)
+        vocab_size = T.int32(is_size_var=True)
+        num_token = T.int32(is_size_var=True)
+        num_seq = T.int32(is_size_var=True)
+        logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
+        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
+        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
+        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
+        token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
+        penalties = T.match_buffer(var_penalties, (num_seq, 3), "float32")
+
+        for token in T.serial(num_token):
+            with T.block("block"):
+                vp = T.axis.spatial(num_token, token)
+                logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= (
+                    penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1]
+                )
+                logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(
+                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < 0,
+                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2],
+                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2],
+                )
+
+    return _apply_penalty_inplace
+
+
 def _get_apply_penalty_inplace(target: tvm.target.Target):
     tx = 1024  # default
     max_num_threads_per_block = get_max_num_threads_per_block(target)
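
The penalty kernel reads three values per sequence from penalties[:, 0..2]; judging from the arithmetic these act as presence, frequency, and repetition penalties (that naming is an inference, not stated in the diff). A NumPy sketch mirroring the update:

import numpy as np

def apply_penalty_inplace(logits, seq_ids, pos2seq_id, token_ids, token_cnt, penalties):
    """NumPy reference of the TIR above; logits is (batch_size, vocab_size)."""
    for i in range(len(token_ids)):
        s = pos2seq_id[i]                    # index into seq_ids / penalties
        row, col = seq_ids[s], token_ids[i]  # position in the logits matrix
        # subtract presence penalty plus count-scaled frequency penalty
        logits[row, col] -= penalties[s, 0] + token_cnt[i] * penalties[s, 1]
        # repetition penalty: multiply when the logit is negative, divide when positive
        if logits[row, col] < 0:
            logits[row, col] *= penalties[s, 2]
        else:
            logits[row, col] /= penalties[s, 2]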
@@ -129,6 +209,42 @@ def _apply_penalty_inplace(  # pylint: disable=too-many-arguments,too-many-locals
     return _apply_penalty_inplace


+def _get_apply_bitmask_inplace_cpu(target: tvm.target.Target):
+    @T.prim_func
+    def _apply_bitmask_inplace(
+        var_logits: T.handle,
+        var_seq_ids: T.handle,
+        var_bitmask: T.handle,
+    ) -> None:
+        """Function that applies vocabulary masking in place."""
+        T.func_attr(
+            {
+                "global_symbol": "apply_bitmask_inplace",
+                "tir.noalias": True,
+                "tir.is_scheduled": True,
+            }
+        )
+        batch_size = T.int32(is_size_var=True)
+        vocab_size = T.int32(is_size_var=True)
+        num_seq = T.int32(is_size_var=True)
+        logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
+        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
+        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")
+
+        for token in T.serial(num_seq * vocab_size):
+            with T.block("block"):
+                vs = T.axis.spatial(num_seq, (token) // vocab_size)
+                vv = T.axis.spatial(vocab_size, (token) % vocab_size)
+
+                logits[seq_ids[vs], vv] = T.if_then_else(
+                    (bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1,
+                    logits[seq_ids[vs], vv],
+                    T.min_value("float32"),
+                )
+
+    return _apply_bitmask_inplace
+
+
 def _get_apply_bitmask_inplace(target: tvm.target.Target):
     tx = 1024  # default
     max_num_threads_per_block = get_max_num_threads_per_block(target)
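
The bitmask packs one bit per vocabulary entry into 32-bit words; a cleared bit drops the logit to the float32 minimum so the token can never be sampled. A vectorized NumPy sketch of the same masking (illustrative reference, not the compiled TIR):

import numpy as np

def apply_bitmask_inplace(logits, seq_ids, bitmask):
    """NumPy reference: bitmask is (batch_size, ceil(vocab_size / 32)) int32."""
    vocab_size = logits.shape[1]
    v = np.arange(vocab_size)
    for s in seq_ids:
        keep = (bitmask[s, v // 32] >> (v % 32)) & 1  # 1 = keep, 0 = mask out
        logits[s] = np.where(keep == 1, logits[s], np.finfo(np.float32).min)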

python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py

Lines changed: 34 additions & 27 deletions
@@ -213,31 +213,38 @@ def softmax_with_chunked_sum(
     )

     sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_sum": softmax_with_chunked_sum}))
-    max_threads = get_max_num_threads_per_block(target)
-    TX = 32
-    TY = max_threads // TX
-    unroll_depth = 64
-    # pylint: enable=invalid-name
-
-    sch.work_on("softmax_with_chunked_sum")
-    l0, l1, l2 = sch.get_loops("log_pad")
-    bx = sch.fuse(l0, l1)
-    sch.bind(bx, "blockIdx.x")
-    unroll, ty, tx = sch.split(l2, [None, TY, TX])
-    sch.bind(ty, "threadIdx.y")
-    sch.bind(tx, "threadIdx.x")
-    sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
-    sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1)
-
-    for block_name in ["sum_exp", "max"]:
-        block = sch.get_block(block_name)
-        sch.set_scope(block, buffer_index=0, storage_scope="shared")
-        sch.compute_at(block, bx)
-        r_loop = sch.get_loops(block)[-1]
-        r_loop, tx = sch.split(r_loop, [None, TX])
-        sch.reorder(tx, r_loop)
-        sch.bind(tx, "threadIdx.x")
-        sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
-        sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1)

-    return chunk_lse, sch.mod["softmax_with_chunked_sum"]
+    def apply_gpu_schedule(target, sch):
+        max_threads = get_max_num_threads_per_block(target)
+        TX = 32
+        TY = max_threads // TX
+        unroll_depth = 64
+        # pylint: enable=invalid-name
+
+        sch.work_on("softmax_with_chunked_sum")
+        l0, l1, l2 = sch.get_loops("log_pad")
+        bx = sch.fuse(l0, l1)
+        sch.bind(bx, "blockIdx.x")
+        unroll, ty, tx = sch.split(l2, [None, TY, TX])
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+        sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
+        sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1)
+
+        for block_name in ["sum_exp", "max"]:
+            block = sch.get_block(block_name)
+            sch.set_scope(block, buffer_index=0, storage_scope="shared")
+            sch.compute_at(block, bx)
+            r_loop = sch.get_loops(block)[-1]
+            r_loop, tx = sch.split(r_loop, [None, TX])
+            sch.reorder(tx, r_loop)
+            sch.bind(tx, "threadIdx.x")
+            sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
+            sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1)
+
+        return chunk_lse, sch.mod["softmax_with_chunked_sum"]
+
+    if target.kind.name == "llvm":
+        return chunk_lse, sch.mod["softmax_with_chunked_sum"]
+    else:
+        return apply_gpu_schedule(target, sch)
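
On llvm (CPU) targets the pass now returns the unscheduled softmax_with_chunked_sum TIR directly; GPU targets still go through apply_gpu_schedule, which binds the loops to GPU threads as before. For intuition about the two-kernel design (chunk_lse followed by softmax_with_chunked_sum), here is a NumPy sketch of softmax assembled from per-chunk log-sum-exp values. It illustrates the usual chunked-LSE formulation only and does not reproduce the exact TIR (its log_pad padding, layouts, and chunk size will differ):

import numpy as np

def softmax_via_chunked_lse(x, chunk=4096):
    """Numerically stable softmax computed from per-chunk log-sum-exp values."""
    n = x.shape[-1]
    assert n % chunk == 0, "sketch assumes the vocab splits evenly into chunks"
    chunks = x.reshape(*x.shape[:-1], n // chunk, chunk)
    cmax = chunks.max(axis=-1)                                      # per-chunk max
    clse = cmax + np.log(np.exp(chunks - cmax[..., None]).sum(-1))  # per-chunk LSE
    gmax = clse.max(axis=-1)
    lse = gmax + np.log(np.exp(clse - gmax[..., None]).sum(-1))     # global LSE
    return np.exp(x - lse[..., None])

Applied to temperature sampling, this corresponds to probs = softmax_via_chunked_lse(logits / temperature).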

python/mlc_llm/compiler_pass/pipeline.py

Lines changed: 10 additions & 2 deletions
@@ -120,7 +120,11 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
             FuseFTDequantizeEpilogue(),
             FuseDequantizeTranspose(),
             BLASDispatch(target) if cublas_gemm else tvm.transform.Sequential([]),
-            FuseAddRMSNorm(target=target),
+            (
+                FuseAddRMSNorm(target=target)
+                if target.kind.name != "llvm"
+                else tvm.transform.Sequential([])
+            ),
             FuseTransposeMatmul(),
             _DebugDump("debug-phase1.py", debug_dump, show_meta=False),
             # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
@@ -152,7 +156,11 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
             ),
             _DebugDump("debug-phase4.py", debug_dump, show_meta=False),
             _LogProgress("Lowering to VM bytecode"),
-            LiftTIRGlobalBufferAlloc(),
+            (
+                LiftTIRGlobalBufferAlloc()
+                if target.kind.name != "llvm"
+                else tvm.transform.Sequential([])
+            ),
             (
                 tvm.tir.transform.ForceNarrowIndexToInt32()
                 if target.kind.name != "cuda"
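
Both pipeline edits reuse the pattern already applied to BLASDispatch a few lines earlier: when a pass should be skipped (here, on llvm targets), substitute an empty tvm.transform.Sequential([]) as a no-op. A small sketch of that pattern with a hypothetical helper (pass_or_noop does not exist in the codebase):

import tvm

def pass_or_noop(pass_obj, enabled):
    """Return the pass when enabled; otherwise an empty, no-op Sequential."""
    return pass_obj if enabled else tvm.transform.Sequential([])

# Hypothetical usage inside the pass list:
#   pass_or_noop(FuseAddRMSNorm(target=target), target.kind.name != "llvm"),
#   pass_or_noop(LiftTIRGlobalBufferAlloc(), target.kind.name != "llvm"),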

python/mlc_llm/support/auto_device.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@

 FOUND = green("Found")
 NOT_FOUND = red("Not found")
-AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan", "opencl"]
+AUTO_DETECT_DEVICES = ["cpu", "cuda", "rocm", "metal", "vulkan", "opencl"]
 _RESULT_CACHE: Dict[str, bool] = {}
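
Adding "cpu" to AUTO_DETECT_DEVICES lets auto-detection fall back to the host CPU when no accelerator is found. A rough sketch of how such probing can be done with TVM's device API (illustrative, not this module's actual implementation):

import tvm

def detect_available_devices(candidates=("cpu", "cuda", "rocm", "metal", "vulkan", "opencl")):
    """Probe each candidate backend and report whether device 0 exists."""
    found = {}
    for name in candidates:
        dev = tvm.device(name, 0)
        found[name] = bool(dev.exist)  # queries the runtime; the CPU device always exists
    return found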
