Commit ebb9f34

Faster KL Div (#822)

1 parent 432c653 commit ebb9f34

3 files changed: +28 −29 lines

examples/kl_div.py

Lines changed: 8 additions & 8 deletions
@@ -72,34 +72,34 @@ def kl_div_forward(
     else:
         loss = torch.zeros((BT,), dtype=torch.float32, device=y_pred.device)
 
-    kl_loss = torch.zeros_like(y_pred)
-
     # Call register_block_size to know block_size_n outside of the reduction loop.
     block_size_n = hl.register_block_size(V)
+    block_size_m = hl.register_block_size(BT)
 
-    BT_SIZE = helion.cdiv(BT, BT)  # Process all at once for simplicity
-    for tile_bt in hl.tile(BT, block_size=BT_SIZE):
+    for tile_bt in hl.tile(BT, block_size=block_size_m):
         loss_sum = hl.zeros([tile_bt, block_size_n], dtype=torch.float32)
 
         for tile_v in hl.tile(V, block_size=block_size_n):
+            kl_loss = hl.zeros([block_size_m, block_size_n], dtype=torch.float32)
+
             y_pred_val = y_pred[tile_bt, tile_v]
             y_true_val = y_true[tile_bt, tile_v]
 
             if log_target:
                 # KL(P || Q) = exp(y_true) * (y_true - y_pred) when both in log-space
                 prob_true = torch.exp(y_true_val)
-                kl_loss[tile_bt, tile_v] = prob_true * (y_true_val - y_pred_val)
+                kl_loss += prob_true * (y_true_val - y_pred_val)
 
             else:
                 # KL(P || Q) = y_true * (log(y_true) - y_pred) when y_pred in log-space
                 log_true = torch.log(torch.clamp(y_true_val, min=eps))
-                kl_loss[tile_bt, tile_v] = y_true_val * (log_true - y_pred_val)
+                kl_loss += y_true_val * (log_true - y_pred_val)
 
             if reduction == "none":
-                loss[tile_bt, tile_v] = kl_loss[tile_bt, tile_v]
+                loss[tile_bt, tile_v] = kl_loss
             else:
                 # Sum over vocabulary dimension
-                loss_sum += kl_loss[tile_bt, tile_v]
+                loss_sum += kl_loss
 
         if reduction != "none":
             loss[tile_bt] = loss_sum.sum(dim=-1)
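This change drops the full [BT, V] scratch tensor kl_loss (previously allocated in global memory and indexed per tile) in favor of a small per-tile accumulator created inside the reduction loop, and it tiles the BT dimension with a tunable block_size_m instead of forcing a single block. A quick numerical sanity check against eager PyTorch, as a minimal sketch: it assumes a CUDA device, that examples/kl_div.py is importable at this path, and that kl_div_forward accepts the reduction keyword shown in the diff.

    import torch
    import torch.nn.functional as F

    from examples.kl_div import kl_div_forward  # import path assumed from the repo layout

    BT, V = 32, 4096
    y_pred = torch.randn(BT, V, device="cuda").log_softmax(dim=-1)  # log-probabilities
    y_true = torch.randn(BT, V, device="cuda").softmax(dim=-1)      # probabilities

    # Compare the Helion kernel against torch.nn.functional.kl_div, which
    # computes y_true * (log(y_true) - y_pred) for log-space y_pred.
    out = kl_div_forward(y_pred, y_true, log_target=False, reduction="batchmean")
    ref = F.kl_div(y_pred, y_true, reduction="batchmean", log_target=False)
    torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)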

test/test_examples.expected

Lines changed: 19 additions & 20 deletions
@@ -2266,7 +2266,6 @@ def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None
 from __future__ import annotations
 
 import torch
-import helion
 import triton
 import triton.language as tl
 from torch._inductor.runtime import triton_helpers
@@ -2275,41 +2274,44 @@ from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_kl_div_forward(y_pred, y_true, kl_loss, loss, kl_loss_stride_0, kl_loss_stride_1, loss_stride_0, y_pred_stride_0, y_pred_stride_1, y_true_stride_0, y_true_stride_1, BT, V, log_target, eps, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
+def _helion_kl_div_forward(y_pred, y_true, loss, loss_stride_0, y_pred_stride_0, y_pred_stride_1, y_true_stride_0, y_true_stride_1, BT, V, log_target, eps, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_1 = pid_0 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     mask_1 = indices_1 < BT
     loss_sum = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
-    for offset_0 in tl.range(0, V.to(tl.int32), _BLOCK_SIZE_0):
-        indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
-        mask_0 = indices_0 < V
+    for offset_0 in tl.range(0, V.to(tl.int32)):
+        indices_0 = offset_0 + tl.arange(0, 1).to(tl.int32)
         loss_sum_copy = loss_sum
         loss_sum_copy_0 = loss_sum_copy
-        y_pred_val = tl.load(y_pred + (indices_1[:, None] * y_pred_stride_0 + indices_0[None, :] * y_pred_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
-        y_true_val = tl.load(y_true + (indices_1[:, None] * y_true_stride_0 + indices_0[None, :] * y_true_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
+        kl_loss = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
+        y_pred_val = tl.load(y_pred + (indices_1[:, None] * y_pred_stride_0 + indices_0[None, :] * y_pred_stride_1), mask_1[:, None], other=0)
+        y_true_val = tl.load(y_true + (indices_1[:, None] * y_true_stride_0 + indices_0[None, :] * y_true_stride_1), mask_1[:, None], other=0)
         if log_target:
             y_true_val_copy = y_true_val
             y_pred_val_copy = y_pred_val
+            kl_loss_copy = kl_loss
             y_true_val_copy_0 = y_true_val_copy
             y_pred_val_copy_0 = y_pred_val_copy
+            kl_loss_copy_0 = kl_loss_copy
             v_0 = libdevice.exp(y_true_val_copy_0)
             v_1 = y_true_val_copy_0 - y_pred_val_copy_0
             v_2 = v_0 * v_1
-            tl.store(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), v_2, mask_1[:, None] & mask_0[None, :])
+            kl_loss = kl_loss_copy_0 + v_2
        _not = not log_target
         if _not:
             y_true_val_copy_1 = y_true_val
             y_pred_val_copy_1 = y_pred_val
+            kl_loss_copy_1 = kl_loss
             y_true_val_copy_1_0 = y_true_val_copy_1
             y_pred_val_copy_1_0 = y_pred_val_copy_1
-            v_3 = triton_helpers.maximum(y_true_val_copy_1_0, eps)
-            v_4 = tl_math.log(v_3)
-            v_5 = v_4 - y_pred_val_copy_1_0
-            v_6 = y_true_val_copy_1_0 * v_5
-            tl.store(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), v_6, mask_1[:, None] & mask_0[None, :])
-        load_2 = tl.load(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
-        loss_sum = loss_sum_copy_0 + load_2
+            kl_loss_copy_1_0 = kl_loss_copy_1
+            v_4 = triton_helpers.maximum(y_true_val_copy_1_0, eps)
+            v_5 = tl_math.log(v_4)
+            v_6 = v_5 - y_pred_val_copy_1_0
+            v_7 = y_true_val_copy_1_0 * v_6
+            kl_loss = kl_loss_copy_1_0 + v_7
+        loss_sum = loss_sum_copy_0 + kl_loss
     sum_1 = tl.cast(tl.sum(loss_sum, 1), tl.float32)
     tl.store(loss + indices_1 * loss_stride_0, sum_1, mask_1)
 
@@ -2333,11 +2335,8 @@ def kl_div_forward(y_pred: Tensor, y_true: Tensor, log_target: bool=False, reduc
         loss = torch.zeros_like(y_pred)
     else:
         loss = torch.zeros((BT,), dtype=torch.float32, device=y_pred.device)
-    kl_loss = torch.zeros_like(y_pred)
-    BT_SIZE = helion.cdiv(BT, BT)
-    _BLOCK_SIZE_1 = BT_SIZE
-    _BLOCK_SIZE_0 = 4096
-    _launcher(_helion_kl_div_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), y_pred, y_true, kl_loss, loss, kl_loss.stride(0), kl_loss.stride(1), loss.stride(0), y_pred.stride(0), y_pred.stride(1), y_true.stride(0), y_true.stride(1), BT, V, log_target, eps, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _BLOCK_SIZE_1 = 4096
+    _launcher(_helion_kl_div_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), y_pred, y_true, loss, loss.stride(0), y_pred.stride(0), y_pred.stride(1), y_true.stride(0), y_true.stride(1), BT, V, log_target, eps, _BLOCK_SIZE_1, 1, num_warps=4, num_stages=3)
     if reduction == 'batchmean':
         final_loss = torch.sum(loss) / BT
     elif reduction == 'sum':
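The regenerated kernel shows where the speedup comes from: kl_loss is no longer a [BT, V] buffer threaded through the launcher and hit with a tl.store/tl.load round trip on every reduction step, but a register-resident tl.full accumulator, so the kl_loss pointer and stride arguments drop out of the kernel signature and launch. Distilled to the essential before/after pattern (paraphrased from the generated code above, not runnable on its own):

    # Before: per-tile results bounced through a scratch tensor in global memory.
    #     tl.store(kl_loss_ptr + offsets, v, mask)
    #     partial = tl.load(kl_loss_ptr + offsets, mask, other=0)
    #     loss_sum += partial
    # After: the per-tile values never leave registers.
    #     kl_loss = tl.full([BLOCK_M, BLOCK_N], 0.0, tl.float32)
    #     kl_loss += v
    #     loss_sum += kl_loss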

test/test_examples.py

Lines changed: 1 addition & 1 deletion
@@ -1140,7 +1140,7 @@ def test_kl_div(self):
             args,
             torch_kl_div(*args),
             fn_name="kl_div_forward",
-            block_sizes=[4096],
+            block_sizes=[1, 4096],
             num_warps=4,
             num_stages=3,
         )
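The pinned config grows from one block size to two because the kernel now calls hl.register_block_size twice. Assuming Helion assigns block_sizes entries in registration order (an inference from this diff, not a documented guarantee), the mapping is:

    block_size_n = hl.register_block_size(V)   # vocab tile -> block_sizes[0] == 1
    block_size_m = hl.register_block_size(BT)  # batch tile -> block_sizes[1] == 4096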
