Skip to content

Commit

Permalink
Lint
Browse files Browse the repository at this point in the history
  • Loading branch information
mengshyu committed Feb 2, 2025
1 parent aad79c1 commit 7ae4087
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion cpp/serve/sampler/cpu_sampler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -555,7 +555,7 @@ class CPUSampler : public SamplerObj {
NDArray CopyProbsToCPU(NDArray probs_on_device) {
// probs_on_device: (n, v)
if (probs_on_device->device.device_type == kDLCPU) {
return probs_on_device;
return probs_on_device;
}

ICHECK(probs_on_device->device.device_type != kDLCPU);
Expand Down
12 changes: 6 additions & 6 deletions python/mlc_llm/compiler_pass/attach_logit_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IR
"""Entrypoint"""
mod = mod.clone()
if str(self.target.kind) == "llvm":
mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace_cpu(self.target)
mod["apply_penalty_inplace"] = _get_apply_penalty_inplace_cpu(self.target)
mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace_cpu(self.target)
mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace_cpu()
mod["apply_penalty_inplace"] = _get_apply_penalty_inplace_cpu()
mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace_cpu()
else:
mod["apply_logit_bias_inplace"] = _get_apply_logit_bias_inplace(self.target)
mod["apply_penalty_inplace"] = _get_apply_penalty_inplace(self.target)
mod["apply_bitmask_inplace"] = _get_apply_bitmask_inplace(self.target)
return mod


def _get_apply_logit_bias_inplace_cpu(target: tvm.target.Target):
def _get_apply_logit_bias_inplace_cpu():
@T.prim_func
def _apply_logit_bias_inplace(
var_logits: T.handle,
Expand Down Expand Up @@ -110,7 +110,7 @@ def _apply_logit_bias_inplace(
return _apply_logit_bias_inplace


def _get_apply_penalty_inplace_cpu(target: tvm.target.Target):
def _get_apply_penalty_inplace_cpu():
@T.prim_func
def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-locals
var_logits: T.handle,
Expand Down Expand Up @@ -209,7 +209,7 @@ def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-local
return _apply_penalty_inplace


def _get_apply_bitmask_inplace_cpu(target: tvm.target.Target):
def _get_apply_bitmask_inplace_cpu():
@T.prim_func
def _apply_bitmask_inplace(
var_logits: T.handle,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,5 +246,4 @@ def apply_gpu_schedule(target, sch):

if target.kind.name == "llvm":
return chunk_lse, sch.mod["softmax_with_chunked_sum"]
else:
return apply_gpu_schedule(target, sch)
return apply_gpu_schedule(target, sch)

0 comments on commit 7ae4087

Please sign in to comment.