Commit d7e69f9

Layer Norm bwd kernel to support large B*M case used by internal (#973)
1 parent 2eb9154 commit d7e69f9

File tree: 3 files changed, +214 -342 lines

examples/layer_norm.py

Lines changed: 39 additions & 79 deletions
@@ -83,14 +83,14 @@ def layer_norm_fwd(
 
 
 # %%
 @helion.kernel
-def layer_norm_bwd_dwdb(
+def layer_norm_bwd(
     grad_out: torch.Tensor,
     x: torch.Tensor,
     mean: torch.Tensor,
     rstd: torch.Tensor,
     weight: torch.Tensor,
     compute_bias_grad: hl.constexpr = True,  # type: ignore[valid-type]
-) -> tuple[torch.Tensor, torch.Tensor | None]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
     """
     Compute gradients for weight (dW) and optionally bias (dB) parameters.
 
@@ -106,86 +106,50 @@ def layer_norm_bwd_dwdb(
         compute_bias_grad: Whether to compute bias gradient (default: True)
 
     Returns:
-        (grad_weight, grad_bias): Gradients for weight and bias (if computed), both shape [N]
+        (grad_x, grad_weight, grad_bias): Gradients for input, weight, and bias (if computed)
         grad_bias is None if compute_bias_grad is False
     """
-    m, n = x.shape
-    n = hl.specialize(n)
 
-    dw = torch.empty([n], dtype=weight.dtype, device=weight.device)
-    if compute_bias_grad:
-        db = torch.empty([n], dtype=weight.dtype, device=weight.device)
-    else:
-        db = None
-
-    # Reduce across rows (M) inside the kernel without atomics
-    m = hl.specialize(m)
+    m_block = hl.register_block_size(x.size(0))
+    n = hl.specialize(x.size(1))
 
-    for tile_n in hl.tile(n):
-        rows = hl.arange(0, m)
-        # Load slices for all rows in rdim and this tile of columns
-        x_blk = x[rows, tile_n].to(torch.float32)
-        dy_blk = grad_out[rows, tile_n].to(torch.float32)
-        mean_vec = mean[rows]
-        rstd_vec = rstd[rows]
-
-        x_hat_blk = (x_blk - mean_vec[:, None]) * rstd_vec[:, None]
-        dw_tile = torch.sum(dy_blk * x_hat_blk, dim=0).to(weight.dtype)
+    grad_x = torch.empty_like(x)
+    num_blocks = (x.size(0) + m_block - 1) // m_block
+    grad_weight_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
+    grad_bias_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
 
-        dw[tile_n] = dw_tile
+    for mb_cta in hl.tile(x.size(0), block_size=m_block):
+        grad_w_acc = weight.new_zeros(n, dtype=torch.float32)
         if compute_bias_grad:
-            db_tile = torch.sum(dy_blk, dim=0).to(weight.dtype)
-            db[tile_n] = db_tile  # type: ignore[index]
+            grad_b_acc = weight.new_zeros(n, dtype=torch.float32)
+        weight_cta = weight[None, :].to(torch.float32)
+        for mb in hl.tile(mb_cta.begin, mb_cta.end):
+            x_mb = x[mb, :].to(torch.float32)
+            dy_mb = grad_out[mb, :].to(torch.float32)
+            mean_mb = mean[mb].to(torch.float32)
+            rstd_mb = rstd[mb].to(torch.float32)
+
+            x_hat = (x_mb - mean_mb[:, None]) * rstd_mb[:, None]
+
+            grad_w_acc += torch.sum(dy_mb * x_hat, dim=0)
+            if compute_bias_grad:
+                grad_b_acc += torch.sum(dy_mb, dim=0)  # pyright: ignore[reportPossiblyUnboundVariable]
+
+            wdy = weight_cta * dy_mb
+            c1 = torch.sum(x_hat * wdy, dim=-1) / n
+            c2 = torch.sum(wdy, dim=-1) / n
+            dx = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_mb[:, None]
+            grad_x[mb, :] = dx.to(x.dtype)
+
+        grad_weight_blocks[mb_cta.id, :] = grad_w_acc
+        if compute_bias_grad:
+            grad_bias_blocks[mb_cta.id, :] = grad_b_acc  # type: ignore[index]
 
+    grad_weight = grad_weight_blocks.sum(0).to(weight.dtype)
     if compute_bias_grad:
-        return dw, db
-    return dw, None
-
-
-@helion.kernel
-def layer_norm_bwd_dx(
-    grad_out: torch.Tensor,
-    x: torch.Tensor,
-    weight: torch.Tensor,
-    mean: torch.Tensor,
-    rstd: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Compute gradient for input tensor (dX).
-
-    This kernel computes per-sample gradients by performing reductions across
-    the feature dimension (N) for each sample in the batch.
-
-    Args:
-        grad_out: Gradient w.r.t layer norm output [M, N]
-        x: Original input tensor [M, N]
-        weight: Weight parameter [N]
-        mean: Per-sample mean computed in forward pass [M]
-        rstd: Per-sample reciprocal standard deviation from forward pass [M]
-
-    Returns:
-        grad_x: Gradient w.r.t input tensor, shape [M, N]
-    """
-    m, n = x.shape
-    n = hl.specialize(n)
-
-    grad_x = torch.empty_like(x)
-
-    for tile_m in hl.tile(m):
-        x_tile = x[tile_m, :].to(torch.float32)
-        dy_tile = grad_out[tile_m, :].to(torch.float32)
-        w = weight[:].to(torch.float32)
-        mean_tile = mean[tile_m]
-        rstd_tile = rstd[tile_m]
-
-        x_hat = (x_tile - mean_tile[:, None]) * rstd_tile[:, None]
-        wdy = w * dy_tile
-        c1 = torch.sum(x_hat * wdy, dim=-1) / n
-        c2 = torch.sum(wdy, dim=-1) / n
-        dx = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_tile[:, None]
-        grad_x[tile_m, :] = dx.to(x.dtype)
-
-    return grad_x
+        grad_bias = grad_bias_blocks.sum(0).to(weight.dtype)
+        return grad_x, grad_weight, grad_bias
+    return grad_x, grad_weight, None
 
 
 # %%
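
The merged kernel handles the large B*M case by tiling the row dimension: each outer tile writes its grad_x rows directly and accumulates partial grad_weight/grad_bias sums into one row of grad_weight_blocks/grad_bias_blocks, and the per-block rows are combined afterwards with a single sum(0), so no cross-block atomics are needed. A minimal eager-PyTorch sketch of that two-level column-sum pattern (illustrative only, not part of the commit; the helper name blockwise_colsum and the block size are made up):

import torch

def blockwise_colsum(dy: torch.Tensor, x_hat: torch.Tensor, m_block: int = 64) -> torch.Tensor:
    # Partial reduction per row block, mirroring grad_weight_blocks in the kernel:
    # each block of rows produces one [N] partial sum, combined at the end.
    m, n = dy.shape
    num_blocks = (m + m_block - 1) // m_block
    partial = dy.new_zeros(num_blocks, n)
    for b in range(num_blocks):
        rows = slice(b * m_block, min((b + 1) * m_block, m))
        partial[b] = (dy[rows] * x_hat[rows]).sum(dim=0)
    # Final combine over blocks, equivalent to (dy * x_hat).sum(dim=0).
    return partial.sum(dim=0)

# Quick check of the pattern against the direct reduction.
dy = torch.randn(1000, 32, dtype=torch.float64)
x_hat = torch.randn(1000, 32, dtype=torch.float64)
torch.testing.assert_close(blockwise_colsum(dy, x_hat), (dy * x_hat).sum(dim=0))
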
@@ -219,14 +183,10 @@ def backward(  # type: ignore[override]
         # Check if bias gradient is needed
         compute_bias_grad = bias is not None
 
-        # First kernel: Compute gradients for weight and bias by reducing across batch dimension (M)
-        grad_weight, grad_bias = layer_norm_bwd_dwdb(
+        grad_x, grad_weight, grad_bias = layer_norm_bwd(
             grad_out, x, mean, rstd, weight, compute_bias_grad
         )
 
-        # Second kernel: Compute gradient for input (dx) using per-sample reductions across feature dimension (N)
-        grad_x = layer_norm_bwd_dx(grad_out, x, weight, mean, rstd)
-
         return grad_x, None, grad_weight, grad_bias, None
 
 
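Independently of the tiling, the per-row math in layer_norm_bwd (x_hat, wdy, c1, c2, dx, plus the column sums for dW/dB) matches the standard closed-form layer-norm backward and can be sanity-checked against PyTorch autograd without Helion. A small sketch, assuming the default eps of 1e-5 used by F.layer_norm (the example's forward pass is not shown in this diff):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
M, N = 64, 128
x = torch.randn(M, N, dtype=torch.float64, requires_grad=True)
w = torch.randn(N, dtype=torch.float64, requires_grad=True)
b = torch.randn(N, dtype=torch.float64, requires_grad=True)
dy = torch.randn(M, N, dtype=torch.float64)

# Reference gradients from autograd.
F.layer_norm(x, (N,), w, b).backward(dy)

# Closed-form gradients, mirroring the math in layer_norm_bwd above.
with torch.no_grad():
    mean = x.mean(dim=-1, keepdim=True)
    var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)
    rstd = (var + 1e-5).rsqrt()
    x_hat = (x - mean) * rstd
    wdy = w * dy
    c1 = (x_hat * wdy).sum(dim=-1, keepdim=True) / N
    c2 = wdy.sum(dim=-1, keepdim=True) / N
    dx = (wdy - (x_hat * c1 + c2)) * rstd
    dw = (dy * x_hat).sum(dim=0)
    db = dy.sum(dim=0)

torch.testing.assert_close(dx, x.grad)
torch.testing.assert_close(dw, w.grad)
torch.testing.assert_close(db, b.grad)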
