
Commit baa53d9

Fused residual add rmsnorm (#173)
Signed-off-by: Mayank Mishra <[email protected]>
1 parent 9b52cfa commit baa53d9

20 files changed: +1,066 -28 lines

Makefile

+1 -1

@@ -14,4 +14,4 @@ style:
 	pre-commit run --all-files
 
 cutotune-cache:
-	DEBUG_CUTOTUNE=1 LOAD_CUTOTUNE_CACHE=0 TORCH_CUDA_ARCH_LIST=9.0 python tools/build_cutotune_cache.py
+	DEBUG_CUTOTUNE=1 LOAD_CUTOTUNE_CACHE=1 TORCH_CUDA_ARCH_LIST=9.0 python tools/build_cutotune_cache.py

cute_kernels/__init__.py

+2
@@ -23,6 +23,8 @@
     embedding_torch,
     fused_linear_cross_entropy_cute,
     fused_linear_cross_entropy_torch,
+    fused_residual_add_rmsnorm_cute,
+    fused_residual_add_rmsnorm_torch,
     gemm_cute,
     gemm_torch,
     linear_cute,

cute_kernels/cache.yml

+424
Large diffs are not rendered by default.

cute_kernels/cutotune/tuner.py

+1 -1

@@ -75,7 +75,7 @@ def __call__(self, *args, **kwargs) -> Any:
         if _DEBUG_CUTOTUNE and (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0):
             print(
                 f"config {best_config} achieved the best time ({best_time} sec) for {lookup_key} for "
-                "function {self.function.__name__}"
+                f"function {self.function.__name__}"
             )
 
         output = self.function(
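
The hunk above is a one-character bug fix: the second half of the debug message was a plain string, so the literal text {self.function.__name__} was printed; adding the f prefix interpolates the actual function name into the message.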

cute_kernels/kernels/__init__.py

+1
@@ -3,6 +3,7 @@
 from .cross_entropy import cross_entropy_cute, cross_entropy_torch
 from .embedding import embedding_cute, embedding_torch
 from .fused_linear_cross_entropy import fused_linear_cross_entropy_cute, fused_linear_cross_entropy_torch
+from .fused_residual_add_rmsnorm import fused_residual_add_rmsnorm_cute, fused_residual_add_rmsnorm_torch
 from .gemm import gemm_cute, gemm_torch
 from .linear import linear_cute, linear_torch
 from .rmsnorm import rmsnorm_cute, rmsnorm_torch

cute_kernels/kernels/fused_residual_add_rmsnorm/__init__.py

+119

@@ -0,0 +1,119 @@
import torch

from ...cutotune import CutoTuneParameter
from ...utils import ensure_contiguous
from .backward import _backward
from .forward import _forward
from .torch_implementation import fused_residual_add_rmsnorm_torch


class _FusedResidualAddRMSNorm_Cute(torch.autograd.Function):
    @staticmethod
    @ensure_contiguous
    def forward(
        ctx,
        x: torch.Tensor,
        residual: torch.Tensor,
        weight: torch.Tensor | None,
        eps: float | None,
        multiplier: float | None,
        memory_efficient: bool,
        kernel_backend_forward: str,
        kernel_backend_backward: str,
        BLOCK_SIZE_B_forward: int,
        BLOCK_SIZE_B_backward: int,
        BLOCK_SIZE_H_forward: int,
        BLOCK_SIZE_H_backward: int,
    ) -> tuple[torch.Tensor]:
        if weight is not None:
            assert weight.dim() == 1, "weight should be 1D"
            assert weight.size(-1) == x.size(-1), "hidden size for x and weight tensor is different"
            assert weight.type() == x.type(), "tensors weight and x should have same dtype"

        is_x_1d = x.dim() == 1
        if is_x_1d:
            x = x.unsqueeze(0)

        if eps is None:
            eps = torch.finfo(x.dtype).eps

        output, added_x_residual, rmsnorm_denominator = _forward(
            x=x,
            residual=residual,
            weight=weight,
            eps=eps,
            multiplier=multiplier,
            memory_efficient=memory_efficient,
            kernel_backend=kernel_backend_forward,
            BLOCK_SIZE_B=BLOCK_SIZE_B_forward,
            BLOCK_SIZE_H=BLOCK_SIZE_H_forward,
        )

        ctx.save_for_backward(added_x_residual, weight, rmsnorm_denominator)

        if is_x_1d:
            output = output.squeeze(0)
            added_x_residual = added_x_residual.squeeze(0)

        ctx.is_x_1d = is_x_1d
        ctx.kernel_backend_backward = kernel_backend_backward
        ctx.eps = eps
        ctx.multiplier = multiplier
        ctx.BLOCK_SIZE_B_backward = BLOCK_SIZE_B_backward
        ctx.BLOCK_SIZE_H_backward = BLOCK_SIZE_H_backward

        return output, added_x_residual

    @staticmethod
    @ensure_contiguous
    def backward(ctx, output_grad: torch.Tensor, added_x_residual_grad: torch.Tensor) -> tuple[torch.Tensor | None]:
        added_x_residual, weight, rmsnorm_denominator = ctx.saved_tensors

        x_grad, residual_grad, weight_grad = _backward(
            added_x_residual=added_x_residual,
            weight=weight,
            eps=ctx.eps,
            multiplier=ctx.multiplier,
            rmsnorm_denominator=rmsnorm_denominator,
            output_grad=output_grad,
            added_x_residual_grad=added_x_residual_grad,
            kernel_backend=ctx.kernel_backend_backward,
            BLOCK_SIZE_B=ctx.BLOCK_SIZE_B_backward,
            BLOCK_SIZE_H=ctx.BLOCK_SIZE_H_backward,
        )

        if ctx.is_x_1d:
            x_grad = x_grad.squeeze(0)
            residual_grad = residual_grad.squeeze(0)

        return x_grad, residual_grad, weight_grad, *[None] * 9


def fused_residual_add_rmsnorm_cute(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor | None,
    eps: float | None,
    multiplier: float | None = None,
    memory_efficient: bool = False,
    kernel_backend_forward: str = CutoTuneParameter(),
    kernel_backend_backward: str = CutoTuneParameter(),
    BLOCK_SIZE_B_forward: int = CutoTuneParameter(),
    BLOCK_SIZE_B_backward: int = CutoTuneParameter(),
    BLOCK_SIZE_H_forward: int = CutoTuneParameter(),
    BLOCK_SIZE_H_backward: int = CutoTuneParameter(),
) -> tuple[torch.Tensor]:
    return _FusedResidualAddRMSNorm_Cute.apply(
        x,
        residual,
        weight,
        eps,
        multiplier,
        memory_efficient,
        kernel_backend_forward,
        kernel_backend_backward,
        BLOCK_SIZE_B_forward,
        BLOCK_SIZE_B_backward,
        BLOCK_SIZE_H_forward,
        BLOCK_SIZE_H_backward,
    )
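
A rough usage sketch of the new entry point (assumes a CUDA device with Triton available; the shapes, dtypes, eps and multiplier values below are illustrative, not taken from the commit):

import torch

from cute_kernels import fused_residual_add_rmsnorm_cute

# illustrative shapes; requires a CUDA device with Triton available
x = torch.randn(4, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
residual = torch.randn(4, 4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)
weight = torch.ones(4096, device="cuda", dtype=torch.bfloat16, requires_grad=True)

# returns the normalized output and the pre-norm sum (multiplier * x + residual);
# kernel backends and block sizes default to CutoTuneParameter() and are auto-tuned
output, added_x_residual = fused_residual_add_rmsnorm_cute(
    x=x, residual=residual, weight=weight, eps=1e-6, multiplier=1.0
)

# both outputs feed autograd; backward produces gradients for x, residual and weight
(output.float().sum() + added_x_residual.float().sum()).backward()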

cute_kernels/kernels/fused_residual_add_rmsnorm/backward.py

+53

@@ -0,0 +1,53 @@
import torch

from ...constants import MAX_TRITON_BLOCK_SIZE
from ...cutotune import cutotune
from ...math import get_next_power_of_2
from ..rmsnorm.parameters import get_cutotune_parameters
from .triton_implementation import fused_residual_add_rmsnorm_backward_triton


@cutotune(**get_cutotune_parameters(triggers={"added_x_residual.dtype"}))
def _backward(
    added_x_residual: torch.Tensor,
    weight: torch.Tensor | None,
    eps: float,
    multiplier: float | None,
    rmsnorm_denominator: torch.Tensor,
    output_grad: torch.Tensor,
    added_x_residual_grad: torch.Tensor,
    kernel_backend: str,
    BLOCK_SIZE_B: int,
    BLOCK_SIZE_H: int,
) -> tuple[torch.Tensor | None]:
    hidden_size = added_x_residual.size(-1)

    x_grad = torch.empty_like(added_x_residual)
    residual_grad = torch.empty_like(added_x_residual)
    weight_grad = None if weight is None else torch.zeros_like(weight, dtype=torch.float32)

    if kernel_backend == "triton":
        BLOCK_SIZE_H = get_next_power_of_2(hidden_size)
        assert BLOCK_SIZE_H <= MAX_TRITON_BLOCK_SIZE

        fused_residual_add_rmsnorm_backward_triton(
            added_x_residual=added_x_residual,
            weight=weight,
            output_grad=output_grad,
            added_x_residual_grad=added_x_residual_grad,
            rmsnorm_denominator=rmsnorm_denominator,
            x_grad=x_grad,
            residual_grad=residual_grad,
            weight_grad=weight_grad,
            eps=eps,
            multiplier=multiplier,
            BLOCK_SIZE_B=BLOCK_SIZE_B,
            BLOCK_SIZE_H=BLOCK_SIZE_H,
        )
    else:
        raise ValueError(f"unexpected kernel_backend ({kernel_backend})")

    if weight_grad is not None:
        weight_grad = weight_grad.type_as(weight)

    return x_grad, residual_grad, weight_grad
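
A minimal parity-check sketch of the backward path against the PyTorch reference; the shapes, dtypes and tolerances are assumptions for illustration, not values from the commit or its tests:

import torch

from cute_kernels import fused_residual_add_rmsnorm_cute, fused_residual_add_rmsnorm_torch

# assumed setup: CUDA + Triton, float32 for tight tolerances
x = torch.randn(8, 1024, device="cuda", requires_grad=True)
residual = torch.randn(8, 1024, device="cuda", requires_grad=True)
weight = torch.randn(1024, device="cuda", requires_grad=True)
x_ref, residual_ref, weight_ref = (t.detach().clone().requires_grad_() for t in (x, residual, weight))

out, added = fused_residual_add_rmsnorm_cute(x=x, residual=residual, weight=weight, eps=1e-6)
out_ref, added_ref = fused_residual_add_rmsnorm_torch(x=x_ref, residual=residual_ref, weight=weight_ref, eps=1e-6)

(out.sum() + added.sum()).backward()
(out_ref.sum() + added_ref.sum()).backward()

# x_grad, residual_grad and weight_grad should match the reference up to kernel rounding
for got, ref in ((x.grad, x_ref.grad), (residual.grad, residual_ref.grad), (weight.grad, weight_ref.grad)):
    torch.testing.assert_close(got, ref, atol=1e-4, rtol=1e-4)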

cute_kernels/kernels/fused_residual_add_rmsnorm/forward.py

+48

@@ -0,0 +1,48 @@
import torch

from ...constants import MAX_TRITON_BLOCK_SIZE
from ...cutotune import cutotune
from ...math import get_next_power_of_2
from ...utils import get_num_elements_and_hidden_size
from ..rmsnorm.parameters import get_cutotune_parameters
from .triton_implementation import fused_residual_add_rmsnorm_forward_triton


@cutotune(**get_cutotune_parameters())
def _forward(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor | None,
    eps: float,
    multiplier: float | None,
    memory_efficient: bool,
    kernel_backend: str,
    BLOCK_SIZE_B: int,
    BLOCK_SIZE_H: int,
) -> tuple[torch.Tensor | None]:
    num_elements, hidden_size = get_num_elements_and_hidden_size(x)

    output = torch.empty_like(x)
    added_x_residual = torch.empty_like(x)
    rmsnorm_denominator = None if memory_efficient else torch.empty(num_elements, device=x.device, dtype=torch.float32)

    if kernel_backend == "triton":
        BLOCK_SIZE_H = get_next_power_of_2(hidden_size)
        assert BLOCK_SIZE_H <= MAX_TRITON_BLOCK_SIZE

        fused_residual_add_rmsnorm_forward_triton(
            x=x,
            residual=residual,
            weight=weight,
            output=output,
            eps=eps,
            multiplier=multiplier,
            added_x_residual=added_x_residual,
            rmsnorm_denominator=rmsnorm_denominator,
            BLOCK_SIZE_B=BLOCK_SIZE_B,
            BLOCK_SIZE_H=BLOCK_SIZE_H,
        )
    else:
        raise ValueError(f"unexpected kernel_backend ({kernel_backend})")

    return output, added_x_residual, rmsnorm_denominator
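
For reference, the math the fused forward implements can be sketched in plain PyTorch as below; this mirrors the torch reference in the next file, and the interpretation of rmsnorm_denominator as the cached per-row inverse RMS is an assumption based on standard RMSNorm, not something read out of the Triton kernel:

import torch

def reference_forward(x, residual, weight, eps, multiplier=None):
    # scale the branch output, then add it onto the residual stream
    if multiplier is not None:
        x = x * multiplier
    added_x_residual = x + residual

    # per-row inverse RMS in float32; with memory_efficient=False the kernel caches
    # this statistic so the backward pass does not have to recompute it
    inverse_rms = torch.rsqrt(added_x_residual.float().pow(2).mean(dim=-1, keepdim=True) + eps)

    output = (added_x_residual.float() * inverse_rms).type_as(x)
    if weight is not None:
        output = output * weight

    return output, added_x_residual, inverse_rms.squeeze(-1)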

cute_kernels/kernels/fused_residual_add_rmsnorm/torch_implementation.py

+15

@@ -0,0 +1,15 @@
import torch
import torch.nn.functional as F


def fused_residual_add_rmsnorm_torch(
    x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor | None, eps: float, multiplier: float | None = None
) -> tuple[torch.Tensor]:
    if multiplier is not None:
        x = x * multiplier

    x = x + residual
    residual = x
    x = F.rms_norm(x, (x.size(-1),), weight=weight, eps=eps)

    return x, residual
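
A quick CPU example of the reference implementation (shapes and values are made up); the second return value is the updated residual stream, i.e. multiplier * x + residual, which is what the fused kernel returns as added_x_residual:

import torch

from cute_kernels import fused_residual_add_rmsnorm_torch

x = torch.randn(2, 8)
residual = torch.randn(2, 8)
weight = torch.ones(8)

out, new_residual = fused_residual_add_rmsnorm_torch(x, residual, weight, eps=1e-6, multiplier=0.5)

# the returned residual is the pre-norm sum, ready to feed the next layer's residual stream
torch.testing.assert_close(new_residual, 0.5 * x + residual)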

cute_kernels/kernels/fused_residual_add_rmsnorm/triton_implementation/__init__.py

+2

@@ -0,0 +1,2 @@
from .backward import fused_residual_add_rmsnorm_backward_triton
from .forward import fused_residual_add_rmsnorm_forward_triton
