158 changes: 144 additions & 14 deletions vllm_ascend/ops/layernorm.py
@@ -15,14 +15,146 @@
# This file is a part of the vllm-ascend project.
#

from typing import Optional, Tuple, Union, cast
from typing import Optional, Tuple, Union, cast, Dict, Any

import torch
from vllm.config import get_current_vllm_config
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
from vllm.triton_utils import tl, triton
from functools import cache


@cache
def get_device_properties() -> Tuple[int, int]:
device = torch.npu.current_device()
device_properties: Dict[str, Any] = (
triton.runtime.driver.active.utils.get_device_properties(device)
)

num_aicore = device_properties.get("num_aicore", -1)
num_vectorcore = device_properties.get("num_vectorcore", -1)

assert num_aicore > 0 and num_vectorcore > 0, "Failed to detect device properties."
return num_aicore, num_vectorcore


@triton.heuristics({
    "HAS_BIAS": lambda args: args["B"] is not None,
    "HAS_Z": lambda args: args["Z"] is not None,
})
@triton.jit
def rms_norm_fwd_kernel(
X, # pointer to the input
Y, # pointer to the output
W, # pointer to the weights
B, # pointer to the biases
Z, # pointer to the residual
Z_Out, # pointer to the residual output
stride_x_row, # how much to increase the pointer when moving by 1 row
stride_y_row,
stride_z_row,
stride_z_out_row,
    n_rows,  # number of rows in X
    n_cols,  # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_N: tl.constexpr,
HAS_BIAS: tl.constexpr,
HAS_Z: tl.constexpr,
):
    # Map the program id to the rows of X and Y it should compute.
    # Each program walks rows of X with a grid-stride loop and stores to Y.
row_start = tl.program_id(0)
for row_idx in tl.range(row_start, n_rows, tl.num_programs(0)):
start_x = X + row_idx * stride_x_row
start_y = Y + row_idx * stride_y_row
if HAS_Z:
start_z = Z + row_idx * stride_z_row
start_z_out = Z_Out + row_idx * stride_z_out_row
offsets = tl.arange(0, BLOCK_N)
mask = offsets < n_cols
x = tl.load(start_x + offsets, mask=mask, other=0.0)
original_dtype = x.dtype
x = x.to(tl.float32)
if HAS_Z:
z = tl.load(start_z + offsets, mask=mask, other=0.0).to(tl.float32)
x = x + z
tl.store(start_z_out + offsets, x, mask=mask)
var = tl.sum(x * x, axis=0) / n_cols
rstd = 1 / tl.sqrt(var + eps)
w = tl.load(W + offsets, mask=mask).to(tl.float32)
if HAS_BIAS:
bias = tl.load(B + offsets, mask=mask).to(tl.float32)

x_hat = x * rstd
# Cast normalized x back to original data type to preserve precision contract
x_hat = x_hat.to(original_dtype)
y = x_hat * w
if HAS_BIAS:
y = y + bias
tl.store(start_y + offsets, y, mask=mask)


def _rms_norm_fwd_triton(
x,
weight,
eps,
residual=None,
bias=None,
out=None,
residual_out=None,
):
M, N = x.shape
assert x.stride(-1) == 1
assert weight.shape == (N,)
assert weight.stride(-1) == 1
if bias is not None:
assert bias.stride(-1) == 1
assert bias.shape == (N,)
if residual is not None:
assert residual.shape == x.shape
assert residual.stride(-1) == 1
if residual_out is None:
residual_out = torch.empty_like(x)
# allocate output
if out is not None:
assert out.shape == x.shape
else:
out = torch.empty_like(x)
assert out.stride(-1) == 1
# Less than 64KB per feature: enqueue fused kernel
MAX_FUSED_SIZE = 65536 // x.element_size()
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
if N > BLOCK_N:
raise RuntimeError(
"This rms norm doesn't support feature dim >= 64KB.")
# heuristics for number of warps
num_warps = min(max(BLOCK_N // 256, 1), 8)
# _, num_vectorcore = get_device_properties()
num_vectorcore = 40
Comment on lines +134 to +135 — Contributor review (critical):

The num_vectorcore is hardcoded to 40. The code should use the get_device_properties() function to dynamically fetch this value, as intended by the commented-out line. Hardcoding device properties makes the code less portable and may lead to suboptimal performance on different hardware.

Suggested change:
-    # _, num_vectorcore = get_device_properties()
-    num_vectorcore = 40
+    _, num_vectorcore = get_device_properties()
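A middle ground, if the driver query is unreliable on some CANN/Triton builds (one plausible reason for the temporary hardcode), is to fall back only when the query fails. A sketch, assuming the get_device_properties() defined above; the fallback value 40 is an assumption carried over from this diff, not a documented device property:

    # Sketch: prefer the dynamic query; fall back to the diff's hardcoded
    # 40 vector cores only if the driver query raises.
    def _get_num_vectorcore(default: int = 40) -> int:
        try:
            _, num_vectorcore = get_device_properties()
            return num_vectorcore
        except Exception:
            return default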

grid = (M if M < num_vectorcore else num_vectorcore,)
# with torch.npu.device(x.device.index):
rms_norm_fwd_kernel[grid](
x,
out,
weight,
bias,
residual,
residual_out,
x.stride(0),
out.stride(0),
residual.stride(0) if residual is not None else None,
residual_out.stride(0) if residual is not None else None,
Comment on lines +147 to +148 — Contributor review (critical):

Passing None for stride arguments to a Triton kernel can lead to a TypeError at runtime. When residual is None, None is passed for stride_z_row and stride_z_out_row. A dummy integer value, such as 0, should be passed instead.

Suggested change:
-        residual.stride(0) if residual is not None else None,
-        residual_out.stride(0) if residual is not None else None,
+        residual.stride(0) if residual is not None else 0,
+        residual_out.stride(0) if residual_out is not None else 0,
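For context on why a dummy 0 is safe: when residual is None, HAS_Z evaluates to False and the kernel never dereferences Z, Z_Out, or their strides, but Triton still specializes the launch on every scalar argument, so each must be a concrete integer. Note the suggestion also corrects the second guard, which the diff currently tests against residual instead of residual_out. A minimal sketch of the guard pattern (hypothetical helper, not part of this PR):

    # Hypothetical helper: optional tensors always contribute a concrete
    # integer stride to the kernel launch, 0 when the tensor is absent.
    def _stride0_or_dummy(t):
        return t.stride(0) if t is not None else 0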

M,
N,
eps,
BLOCK_N=BLOCK_N,
num_warps=num_warps,
# multibuffer=True,
)
return out, residual_out
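For readers skimming the diff, a minimal usage sketch of the new helper (illustrative only; assumes an Ascend NPU with Triton available, and the stride fix suggested above for the residual-free path):

    # 2-D activations with a contiguous last dim, as the asserts require.
    x = torch.randn(8, 4096, dtype=torch.float16, device="npu")
    w = torch.ones(4096, dtype=torch.float16, device="npu")
    res = torch.randn_like(x)

    # Plain RMSNorm: residual and bias omitted; second return value is None.
    y, _ = _rms_norm_fwd_triton(x, w, eps=1e-6)

    # Fused add + RMSNorm: accumulates x + res in float32 and returns the
    # normalized output plus the updated residual in x's dtype.
    y, new_res = _rms_norm_fwd_triton(x, w, eps=1e-6, residual=res)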

def _addrmsnorm_forward_oot(
self,
x: torch.Tensor,
@@ -30,12 +162,11 @@
layer: Optional[torch.nn.Module] = None,
bias: Optional[torch.nn.Parameter] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu

Check failure on line 165 in vllm_ascend/ops/layernorm.py — GitHub Actions / lint / pre-commit (reported 5 times):
Module "vllm_ascend.utils" has no attribute "is_310p" [attr-defined]

from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
from vllm_ascend.utils import is_310p

if layer is not None and get_ascend_device_type(
) != AscendDeviceType._310P:
if layer is not None and not is_310p():
layer_cls_name = layer.__class__.__name__
try:
weight_prefetch_method = get_forward_context(
@@ -68,17 +199,14 @@
)

else:
if get_ascend_device_type() == AscendDeviceType._310P:
if is_310p():
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)
residual = x.to(orig_dtype)
x, _ = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
else:
x, _, residual = torch_npu.npu_add_rms_norm(
x, residual, self.weight, self.variance_epsilon)
if bias is not None:
x.add_(bias)
x, residual = _rms_norm_fwd_triton(x, self.weight, self.variance_epsilon, residual, bias)
torch.ops.vllm.maybe_wait_prefetch_done(x)
return x, residual

Expand Down Expand Up @@ -115,10 +243,12 @@
self, x, residual, self.next_need_quant_fusion_linear,
self.bias)
return x, residual
x, residual = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)

if self.bias is not None:
x.add_(self.bias)
x, _ = _rms_norm_fwd_triton(x, self.weight, self.variance_epsilon, None, self.bias)
else:
x, _ = torch_npu.npu_rms_norm(x, self.weight,
self.variance_epsilon)
return x

@property
@@ -196,9 +326,9 @@
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
import torch_npu

from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
from vllm_ascend.utils import is_310p
if residual is not None:

Check failure on line 330 in vllm_ascend/ops/layernorm.py — GitHub Actions / lint / pre-commit (reported 5 times):
Module "vllm_ascend.utils" has no attribute "is_310p" [attr-defined]
if get_ascend_device_type() == AscendDeviceType._310P:
if is_310p():
orig_dtype = residual.dtype
x = x + residual.to(x.dtype)
residual = x.to(orig_dtype)