diff --git a/cpp/serve/sampler/cpu_sampler.cc b/cpp/serve/sampler/cpu_sampler.cc
index cde12ea8da..1542091f32 100644
--- a/cpp/serve/sampler/cpu_sampler.cc
+++ b/cpp/serve/sampler/cpu_sampler.cc
@@ -554,6 +554,10 @@ class CPUSampler : public SamplerObj {
   /*! \brief Copy prob distributions from device to CPU. */
   NDArray CopyProbsToCPU(NDArray probs_on_device) {
     // probs_on_device: (n, v)
+    if (probs_on_device->device.device_type == kDLCPU) {
+      return probs_on_device;
+    }
+
     ICHECK(probs_on_device->device.device_type != kDLCPU);
     if (probs_host_.defined()) {
       ICHECK_EQ(probs_host_->shape[1], probs_on_device->shape[1]);
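
Reviewer sketch (not part of the patch): with this guard, CopyProbsToCPU becomes a
no-op when the engine already runs on CPU, so the ICHECK below it can no longer fire
on the CPU path. A minimal Python illustration of the same early-return logic, using
TVM's public runtime API (the helper name copy_probs_to_cpu is hypothetical):

    import numpy as np
    import tvm

    def copy_probs_to_cpu(probs_on_device: tvm.nd.NDArray) -> tvm.nd.NDArray:
        # Mirror the C++ early return: a CPU-resident array is returned as-is,
        # skipping the redundant device-to-host copy.
        if probs_on_device.device.device_type == tvm.cpu().device_type:
            return probs_on_device
        return probs_on_device.copyto(tvm.cpu())

    probs = tvm.nd.array(np.random.rand(4, 128).astype("float32"), tvm.cpu())
    assert copy_probs_to_cpu(probs) is probs  # no copy on the CPU path
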
diff --git a/python/mlc_llm/compiler_pass/attach_logit_processor.py b/python/mlc_llm/compiler_pass/attach_logit_processor.py
index a6b23a8fa0..f5ba5531a3 100644
--- a/python/mlc_llm/compiler_pass/attach_logit_processor.py
+++ b/python/mlc_llm/compiler_pass/attach_logit_processor.py
@@ -27,12 +27,48 @@ def __init__(self, target: tvm.target.Target):
     def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
         """Entrypoint"""
         mod = mod.clone()
-        mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace(self.target)
-        mod["apply_penalty_inplace"] = _get_apply_penalty_inplace(self.target)
-        mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace(self.target)
+        if str(self.target.kind) == "llvm":
+            mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace_cpu(self.target)
+            mod["apply_penalty_inplace"] = _get_apply_penalty_inplace_cpu(self.target)
+            mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace_cpu(self.target)
+        else:
+            mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace(self.target)
+            mod["apply_penalty_inplace"] = _get_apply_penalty_inplace(self.target)
+            mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace(self.target)
         return mod
 
 
+def _get_apply_logit_bias_inplace_cpu(target: tvm.target.Target):
+    @T.prim_func
+    def _apply_logit_bias_inplace(
+        var_logits: T.handle,
+        var_pos2seq_id: T.handle,
+        var_token_ids: T.handle,
+        var_logit_bias: T.handle,
+    ) -> None:
+        """Function that applies logit bias in place."""
+        T.func_attr(
+            {
+                "global_symbol": "apply_logit_bias_inplace",
+                "tir.noalias": True,
+                "tir.is_scheduled": True,
+            }
+        )
+        batch_size = T.int32(is_size_var=True)
+        vocab_size = T.int32(is_size_var=True)
+        num_token = T.int32(is_size_var=True)
+        logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
+        # pos2seq_id maps each token position to its row (sequence) in logits.
+        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
+        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
+        logit_bias = T.match_buffer(var_logit_bias, (num_token,), "float32")
+
+        for i in range(num_token):
+            logits[pos2seq_id[i], token_ids[i]] += logit_bias[i]
+
+    return _apply_logit_bias_inplace
+
+
 def _get_apply_logit_bias_inplace(target: tvm.target.Target):
     tx = 1024  # default
     max_num_threads_per_block = get_max_num_threads_per_block(target)
@@ -74,6 +110,50 @@ def _apply_logit_bias_inplace(
     return _apply_logit_bias_inplace
 
 
+def _get_apply_penalty_inplace_cpu(target: tvm.target.Target):
+    @T.prim_func
+    def _apply_penalty_inplace(  # pylint: disable=too-many-arguments,too-many-locals
+        var_logits: T.handle,
+        var_seq_ids: T.handle,
+        var_pos2seq_id: T.handle,
+        var_token_ids: T.handle,
+        var_token_cnt: T.handle,
+        var_penalties: T.handle,
+    ) -> None:
+        """Function that applies penalties in place."""
+        T.func_attr(
+            {
+                "global_symbol": "apply_penalty_inplace",
+                "tir.noalias": True,
+                "tir.is_scheduled": True,
+            }
+        )
+        batch_size = T.int32(is_size_var=True)
+        vocab_size = T.int32(is_size_var=True)
+        num_token = T.int32(is_size_var=True)
+        num_seq = T.int32(is_size_var=True)
+        logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
+        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
+        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
+        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
+        token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
+        penalties = T.match_buffer(var_penalties, (num_seq, 3), "float32")
+
+        for token in T.serial(num_token):
+            with T.block("block"):
+                vp = T.axis.spatial(num_token, token)
+                logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= (
+                    penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1]
+                )
+                logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(
+                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < 0,
+                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2],
+                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2],
+                )
+
+    return _apply_penalty_inplace
+
+
 def _get_apply_penalty_inplace(target: tvm.target.Target):
     tx = 1024  # default
     max_num_threads_per_block = get_max_num_threads_per_block(target)
@@ -129,6 +209,42 @@ def _apply_penalty_inplace(  # pylint: disable=too-many-arguments,too-many-locals
     return _apply_penalty_inplace
 
 
+def _get_apply_bitmask_inplace_cpu(target: tvm.target.Target):
+    @T.prim_func
+    def _apply_bitmask_inplace(
+        var_logits: T.handle,
+        var_seq_ids: T.handle,
+        var_bitmask: T.handle,
+    ) -> None:
+        """Function that applies vocabulary masking in place."""
+        T.func_attr(
+            {
+                "global_symbol": "apply_bitmask_inplace",
+                "tir.noalias": True,
+                "tir.is_scheduled": True,
+            }
+        )
+        batch_size = T.int32(is_size_var=True)
+        vocab_size = T.int32(is_size_var=True)
+        num_seq = T.int32(is_size_var=True)
+        logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
+        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
+        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")
+
+        for token in T.serial(num_seq * vocab_size):
+            with T.block("block"):
+                vs = T.axis.spatial(num_seq, token // vocab_size)
+                vv = T.axis.spatial(vocab_size, token % vocab_size)
+
+                logits[seq_ids[vs], vv] = T.if_then_else(
+                    (bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1,
+                    logits[seq_ids[vs], vv],
+                    T.min_value("float32"),
+                )
+
+    return _apply_bitmask_inplace
+
+
 def _get_apply_bitmask_inplace(target: tvm.target.Target):
     tx = 1024  # default
     max_num_threads_per_block = get_max_num_threads_per_block(target)
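
Reviewer sketch (not part of the patch): because the _cpu variants are plain serial
TIR with no thread bindings, they can be compiled and smoke-tested directly on the
LLVM backend. A sketch for the logit-bias kernel, assuming this patch is applied and
mlc_llm is importable; looking the kernel up by its global_symbol is standard TVM usage:

    import numpy as np
    import tvm
    from mlc_llm.compiler_pass.attach_logit_processor import (
        _get_apply_logit_bias_inplace_cpu,
    )

    target = tvm.target.Target("llvm")
    lib = tvm.build(_get_apply_logit_bias_inplace_cpu(target), target=target)
    kernel = lib["apply_logit_bias_inplace"]

    # Two token positions: bias token 3 of sequence 0 and token 5 of sequence 1.
    logits = np.zeros((2, 8), "float32")
    pos2seq_id = np.array([0, 1], "int32")
    token_ids = np.array([3, 5], "int32")
    bias = np.array([1.5, -2.0], "float32")

    dev = tvm.cpu()
    logits_nd = tvm.nd.array(logits, dev)
    kernel(
        logits_nd,
        tvm.nd.array(pos2seq_id, dev),
        tvm.nd.array(token_ids, dev),
        tvm.nd.array(bias, dev),
    )

    expected = logits.copy()
    expected[pos2seq_id, token_ids] += bias
    np.testing.assert_allclose(logits_nd.numpy(), expected)
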
diff --git a/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py b/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py
index 10b2017a08..46c7e2bcb9 100644
--- a/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py
+++ b/python/mlc_llm/compiler_pass/attach_softmax_with_temperature.py
@@ -213,31 +213,38 @@ def softmax_with_chunked_sum(
         )
 
     sch = tvm.tir.Schedule(IRModule({"softmax_with_chunked_sum": softmax_with_chunked_sum}))
 
-    max_threads = get_max_num_threads_per_block(target)
-    TX = 32
-    TY = max_threads // TX
-    unroll_depth = 64
-    # pylint: enable=invalid-name
-
-    sch.work_on("softmax_with_chunked_sum")
-    l0, l1, l2 = sch.get_loops("log_pad")
-    bx = sch.fuse(l0, l1)
-    sch.bind(bx, "blockIdx.x")
-    unroll, ty, tx = sch.split(l2, [None, TY, TX])
-    sch.bind(ty, "threadIdx.y")
-    sch.bind(tx, "threadIdx.x")
-    sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
-    sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1)
-
-    for block_name in ["sum_exp", "max"]:
-        block = sch.get_block(block_name)
-        sch.set_scope(block, buffer_index=0, storage_scope="shared")
-        sch.compute_at(block, bx)
-        r_loop = sch.get_loops(block)[-1]
-        r_loop, tx = sch.split(r_loop, [None, TX])
-        sch.reorder(tx, r_loop)
-        sch.bind(tx, "threadIdx.x")
-        sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
-        sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1)
-    return chunk_lse, sch.mod["softmax_with_chunked_sum"]
+    def apply_gpu_schedule(target, sch):
+        max_threads = get_max_num_threads_per_block(target)
+        TX = 32
+        TY = max_threads // TX
+        unroll_depth = 64
+        # pylint: enable=invalid-name
+
+        sch.work_on("softmax_with_chunked_sum")
+        l0, l1, l2 = sch.get_loops("log_pad")
+        bx = sch.fuse(l0, l1)
+        sch.bind(bx, "blockIdx.x")
+        unroll, ty, tx = sch.split(l2, [None, TY, TX])
+        sch.bind(ty, "threadIdx.y")
+        sch.bind(tx, "threadIdx.x")
+        sch.annotate(unroll, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
+        sch.annotate(unroll, ann_key="pragma_unroll_explicit", ann_val=1)
+
+        for block_name in ["sum_exp", "max"]:
+            block = sch.get_block(block_name)
+            sch.set_scope(block, buffer_index=0, storage_scope="shared")
+            sch.compute_at(block, bx)
+            r_loop = sch.get_loops(block)[-1]
+            r_loop, tx = sch.split(r_loop, [None, TX])
+            sch.reorder(tx, r_loop)
+            sch.bind(tx, "threadIdx.x")
+            sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
+            sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1)
+
+        return chunk_lse, sch.mod["softmax_with_chunked_sum"]
+
+    if target.kind.name == "llvm":
+        return chunk_lse, sch.mod["softmax_with_chunked_sum"]
+    else:
+        return apply_gpu_schedule(target, sch)
diff --git a/python/mlc_llm/compiler_pass/pipeline.py b/python/mlc_llm/compiler_pass/pipeline.py
index af1cf9f0e9..f812f932a9 100644
--- a/python/mlc_llm/compiler_pass/pipeline.py
+++ b/python/mlc_llm/compiler_pass/pipeline.py
@@ -120,7 +120,11 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
                 FuseFTDequantizeEpilogue(),
                 FuseDequantizeTranspose(),
                 BLASDispatch(target) if cublas_gemm else tvm.transform.Sequential([]),
-                FuseAddRMSNorm(target=target),
+                (
+                    FuseAddRMSNorm(target=target)
+                    if target.kind.name != "llvm"
+                    else tvm.transform.Sequential([])
+                ),
                 FuseTransposeMatmul(),
                 _DebugDump("debug-phase1.py", debug_dump, show_meta=False),
                 # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
@@ -152,7 +156,11 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
                 ),
                 _DebugDump("debug-phase4.py", debug_dump, show_meta=False),
                 _LogProgress("Lowering to VM bytecode"),
-                LiftTIRGlobalBufferAlloc(),
+                (
+                    LiftTIRGlobalBufferAlloc()
+                    if target.kind.name != "llvm"
+                    else tvm.transform.Sequential([])
+                ),
                 (
                     tvm.tir.transform.ForceNarrowIndexToInt32()
                     if target.kind.name != "cuda"
diff --git a/python/mlc_llm/support/auto_device.py b/python/mlc_llm/support/auto_device.py
index bddb9954c6..738f86d9c5 100644
--- a/python/mlc_llm/support/auto_device.py
+++ b/python/mlc_llm/support/auto_device.py
@@ -13,7 +13,7 @@
 FOUND = green("Found")
 NOT_FOUND = red("Not found")
-AUTO_DETECT_DEVICES = ["cuda", "rocm", "metal", "vulkan", "opencl"]
+AUTO_DETECT_DEVICES = ["cpu", "cuda", "rocm", "metal", "vulkan", "opencl"]
 _RESULT_CACHE: Dict[str, bool] = {}
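
Reviewer sketch (not part of the patch): prepending "cpu" to AUTO_DETECT_DEVICES lets
--device auto consider the host CPU alongside the GPU backends. A quick probe of the
extended list with TVM's device API; tvm.device("cpu", 0).exist is true on any host
build, so the CPU entry always reports as found:

    import tvm

    AUTO_DETECT_DEVICES = ["cpu", "cuda", "rocm", "metal", "vulkan", "opencl"]

    for name in AUTO_DETECT_DEVICES:
        dev = tvm.device(name, 0)
        print(f"{name:7s}: {'Found' if dev.exist else 'Not found'}")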