@@ -83,14 +83,14 @@ def layer_norm_fwd(
 
 # %%
 @helion.kernel
-def layer_norm_bwd_dwdb(
+def layer_norm_bwd(
     grad_out: torch.Tensor,
     x: torch.Tensor,
     mean: torch.Tensor,
     rstd: torch.Tensor,
     weight: torch.Tensor,
     compute_bias_grad: hl.constexpr = True,  # type: ignore[valid-type]
-) -> tuple[torch.Tensor, torch.Tensor | None]:
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
     """
-    Compute gradients for weight (dW) and optionally bias (dB) parameters.
+    Compute gradients for the input (dX), weight (dW), and optionally bias (dB).
 
@@ -106,86 +106,50 @@ def layer_norm_bwd_dwdb(
         compute_bias_grad: Whether to compute bias gradient (default: True)
 
     Returns:
-        (grad_weight, grad_bias): Gradients for weight and bias (if computed), both shape [N]
+        (grad_x, grad_weight, grad_bias): Gradients for input, weight, and bias (if computed)
         grad_bias is None if compute_bias_grad is False
     """
-    m, n = x.shape
-    n = hl.specialize(n)
 
-    dw = torch.empty([n], dtype=weight.dtype, device=weight.device)
-    if compute_bias_grad:
-        db = torch.empty([n], dtype=weight.dtype, device=weight.device)
-    else:
-        db = None
-
-    # Reduce across rows (M) inside the kernel without atomics
-    m = hl.specialize(m)
+    m_block = hl.register_block_size(x.size(0))
+    n = hl.specialize(x.size(1))
 
-    for tile_n in hl.tile(n):
-        rows = hl.arange(0, m)
-        # Load slices for all rows in rdim and this tile of columns
-        x_blk = x[rows, tile_n].to(torch.float32)
-        dy_blk = grad_out[rows, tile_n].to(torch.float32)
-        mean_vec = mean[rows]
-        rstd_vec = rstd[rows]
-
-        x_hat_blk = (x_blk - mean_vec[:, None]) * rstd_vec[:, None]
-        dw_tile = torch.sum(dy_blk * x_hat_blk, dim=0).to(weight.dtype)
+    grad_x = torch.empty_like(x)
+    num_blocks = (x.size(0) + m_block - 1) // m_block
+    grad_weight_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
+    grad_bias_blocks = x.new_empty([num_blocks, n], dtype=torch.float32)
 
-        dw[tile_n] = dw_tile
+    for mb_cta in hl.tile(x.size(0), block_size=m_block):
+        grad_w_acc = weight.new_zeros(n, dtype=torch.float32)
         if compute_bias_grad:
-            db_tile = torch.sum(dy_blk, dim=0).to(weight.dtype)
-            db[tile_n] = db_tile  # type: ignore[index]
+            grad_b_acc = weight.new_zeros(n, dtype=torch.float32)
+        weight_cta = weight[None, :].to(torch.float32)
+        for mb in hl.tile(mb_cta.begin, mb_cta.end):
+            x_mb = x[mb, :].to(torch.float32)
+            dy_mb = grad_out[mb, :].to(torch.float32)
+            mean_mb = mean[mb].to(torch.float32)
+            rstd_mb = rstd[mb].to(torch.float32)
+
+            x_hat = (x_mb - mean_mb[:, None]) * rstd_mb[:, None]
+
+            grad_w_acc += torch.sum(dy_mb * x_hat, dim=0)
+            if compute_bias_grad:
+                grad_b_acc += torch.sum(dy_mb, dim=0)  # pyright: ignore[reportPossiblyUnboundVariable]
+
+            wdy = weight_cta * dy_mb
+            c1 = torch.sum(x_hat * wdy, dim=-1) / n
+            c2 = torch.sum(wdy, dim=-1) / n
+            dx = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_mb[:, None]
+            grad_x[mb, :] = dx.to(x.dtype)
+
+        grad_weight_blocks[mb_cta.id, :] = grad_w_acc
+        if compute_bias_grad:
+            grad_bias_blocks[mb_cta.id, :] = grad_b_acc  # type: ignore[index]
 
+    grad_weight = grad_weight_blocks.sum(0).to(weight.dtype)
     if compute_bias_grad:
-        return dw, db
-    return dw, None
-
-
-@helion.kernel
-def layer_norm_bwd_dx(
-    grad_out: torch.Tensor,
-    x: torch.Tensor,
-    weight: torch.Tensor,
-    mean: torch.Tensor,
-    rstd: torch.Tensor,
-) -> torch.Tensor:
-    """
-    Compute gradient for input tensor (dX).
-
-    This kernel computes per-sample gradients by performing reductions across
-    the feature dimension (N) for each sample in the batch.
-
-    Args:
-        grad_out: Gradient w.r.t layer norm output [M, N]
-        x: Original input tensor [M, N]
-        weight: Weight parameter [N]
-        mean: Per-sample mean computed in forward pass [M]
-        rstd: Per-sample reciprocal standard deviation from forward pass [M]
-
-    Returns:
-        grad_x: Gradient w.r.t input tensor, shape [M, N]
-    """
-    m, n = x.shape
-    n = hl.specialize(n)
-
-    grad_x = torch.empty_like(x)
-
-    for tile_m in hl.tile(m):
-        x_tile = x[tile_m, :].to(torch.float32)
-        dy_tile = grad_out[tile_m, :].to(torch.float32)
-        w = weight[:].to(torch.float32)
-        mean_tile = mean[tile_m]
-        rstd_tile = rstd[tile_m]
-
-        x_hat = (x_tile - mean_tile[:, None]) * rstd_tile[:, None]
-        wdy = w * dy_tile
-        c1 = torch.sum(x_hat * wdy, dim=-1) / n
-        c2 = torch.sum(wdy, dim=-1) / n
-        dx = (wdy - (x_hat * c1[:, None] + c2[:, None])) * rstd_tile[:, None]
-        grad_x[tile_m, :] = dx.to(x.dtype)
-
-    return grad_x
+        grad_bias = grad_bias_blocks.sum(0).to(weight.dtype)
+        return grad_x, grad_weight, grad_bias
+    return grad_x, grad_weight, None
 
 
 # %%
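
For reference, the dX/dW/dB math that the fused kernel tiles over can be written directly in eager PyTorch. The sketch below is illustrative only and is not part of this commit (the name layer_norm_bwd_reference is hypothetical); it mirrors the same formulas: x_hat reconstructed from the saved mean/rstd, parameter gradients reduced over the batch dimension M, and the input gradient built from the two per-row correction terms c1 and c2.

import torch

def layer_norm_bwd_reference(grad_out, x, mean, rstd, weight):
    # Reconstruct the normalized activations from the saved forward statistics.
    x_hat = (x.float() - mean.float()[:, None]) * rstd.float()[:, None]
    dy = grad_out.float()
    n = x.shape[1]

    # dW and dB: reductions over the batch dimension (M).
    grad_weight = (dy * x_hat).sum(0).to(weight.dtype)
    grad_bias = dy.sum(0).to(weight.dtype)

    # dX: per-row reductions over the feature dimension (N).
    wdy = weight.float() * dy
    c1 = (x_hat * wdy).sum(-1, keepdim=True) / n
    c2 = wdy.sum(-1, keepdim=True) / n
    grad_x = ((wdy - (x_hat * c1 + c2)) * rstd.float()[:, None]).to(x.dtype)
    return grad_x, grad_weight, grad_bias
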
@@ -219,14 +183,10 @@ def backward( # type: ignore[override]
         # Check if bias gradient is needed
         compute_bias_grad = bias is not None
 
-        # First kernel: Compute gradients for weight and bias by reducing across batch dimension (M)
-        grad_weight, grad_bias = layer_norm_bwd_dwdb(
+        grad_x, grad_weight, grad_bias = layer_norm_bwd(
             grad_out, x, mean, rstd, weight, compute_bias_grad
         )
 
-        # Second kernel: Compute gradient for input (dx) using per-sample reductions across feature dimension (N)
-        grad_x = layer_norm_bwd_dx(grad_out, x, weight, mean, rstd)
-
         return grad_x, None, grad_weight, grad_bias, None
 
 
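
As a quick sanity check, the fused kernel can be compared against autograd on an eager float32 layer norm. This is a hedged sketch, not part of the commit: it assumes a CUDA device, Helion installed, and that the layer_norm_bwd kernel from the diff above is in scope; shapes, eps, and tolerances are illustrative.

import torch
import torch.nn.functional as F

m, n = 512, 1024
eps = 1e-5
x = torch.randn(m, n, device="cuda", dtype=torch.float16)
weight = torch.randn(n, device="cuda", dtype=torch.float16)
grad_out = torch.randn(m, n, device="cuda", dtype=torch.float16)

# Recompute the forward-pass statistics the backward kernel expects.
xf = x.float()
mean = xf.mean(dim=-1)
rstd = torch.rsqrt(xf.var(dim=-1, unbiased=False) + eps)

grad_x, grad_weight, grad_bias = layer_norm_bwd(grad_out, x, mean, rstd, weight)

# Reference gradients from autograd on an eager float32 layer norm.
x_ref = xf.detach().requires_grad_(True)
w_ref = weight.float().detach().requires_grad_(True)
b_ref = torch.zeros(n, device="cuda", dtype=torch.float32, requires_grad=True)
out = F.layer_norm(x_ref, [n], w_ref, b_ref, eps=eps)
out.backward(grad_out.float())

torch.testing.assert_close(grad_x.float(), x_ref.grad, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(grad_weight.float(), w_ref.grad, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(grad_bias.float(), b_ref.grad, rtol=1e-2, atol=1e-2)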