@@ -2266,7 +2266,6 @@ def jsd_forward(_input: Tensor, target: Tensor, shift_labels: Tensor | None=None
 from __future__ import annotations
 
 import torch
-import helion
 import triton
 import triton.language as tl
 from torch._inductor.runtime import triton_helpers
@@ -2275,41 +2274,44 @@ from torch._inductor.runtime.triton_compat import libdevice
 from helion.runtime import default_launcher as _default_launcher
 
 @triton.jit
-def _helion_kl_div_forward(y_pred, y_true, kl_loss, loss, kl_loss_stride_0, kl_loss_stride_1, loss_stride_0, y_pred_stride_0, y_pred_stride_1, y_true_stride_0, y_true_stride_1, BT, V, log_target, eps, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
+def _helion_kl_div_forward(y_pred, y_true, loss, loss_stride_0, y_pred_stride_0, y_pred_stride_1, y_true_stride_0, y_true_stride_1, BT, V, log_target, eps, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_0: tl.constexpr):
     pid_0 = tl.program_id(0)
     offset_1 = pid_0 * _BLOCK_SIZE_1
     indices_1 = (offset_1 + tl.arange(0, _BLOCK_SIZE_1)).to(tl.int32)
     mask_1 = indices_1 < BT
     loss_sum = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
-    for offset_0 in tl.range(0, V.to(tl.int32), _BLOCK_SIZE_0):
-        indices_0 = offset_0 + tl.arange(0, _BLOCK_SIZE_0).to(tl.int32)
-        mask_0 = indices_0 < V
+    for offset_0 in tl.range(0, V.to(tl.int32)):
+        indices_0 = offset_0 + tl.arange(0, 1).to(tl.int32)
         loss_sum_copy = loss_sum
         loss_sum_copy_0 = loss_sum_copy
-        y_pred_val = tl.load(y_pred + (indices_1[:, None] * y_pred_stride_0 + indices_0[None, :] * y_pred_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
-        y_true_val = tl.load(y_true + (indices_1[:, None] * y_true_stride_0 + indices_0[None, :] * y_true_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
+        kl_loss = tl.full([_BLOCK_SIZE_1, _BLOCK_SIZE_0], 0.0, tl.float32)
+        y_pred_val = tl.load(y_pred + (indices_1[:, None] * y_pred_stride_0 + indices_0[None, :] * y_pred_stride_1), mask_1[:, None], other=0)
+        y_true_val = tl.load(y_true + (indices_1[:, None] * y_true_stride_0 + indices_0[None, :] * y_true_stride_1), mask_1[:, None], other=0)
         if log_target:
             y_true_val_copy = y_true_val
             y_pred_val_copy = y_pred_val
+            kl_loss_copy = kl_loss
             y_true_val_copy_0 = y_true_val_copy
             y_pred_val_copy_0 = y_pred_val_copy
+            kl_loss_copy_0 = kl_loss_copy
             v_0 = libdevice.exp(y_true_val_copy_0)
             v_1 = y_true_val_copy_0 - y_pred_val_copy_0
             v_2 = v_0 * v_1
-            tl.store(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), v_2, mask_1[:, None] & mask_0[None, :])
+            kl_loss = kl_loss_copy_0 + v_2
         _not = not log_target
         if _not:
             y_true_val_copy_1 = y_true_val
             y_pred_val_copy_1 = y_pred_val
+            kl_loss_copy_1 = kl_loss
             y_true_val_copy_1_0 = y_true_val_copy_1
             y_pred_val_copy_1_0 = y_pred_val_copy_1
-            v_3 = triton_helpers.maximum(y_true_val_copy_1_0, eps)
-            v_4 = tl_math.log(v_3)
-            v_5 = v_4 - y_pred_val_copy_1_0
-            v_6 = y_true_val_copy_1_0 * v_5
-            tl.store(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), v_6, mask_1[:, None] & mask_0[None, :])
-        load_2 = tl.load(kl_loss + (indices_1[:, None] * kl_loss_stride_0 + indices_0[None, :] * kl_loss_stride_1), mask_1[:, None] & mask_0[None, :], other=0)
-        loss_sum = loss_sum_copy_0 + load_2
+            kl_loss_copy_1_0 = kl_loss_copy_1
+            v_4 = triton_helpers.maximum(y_true_val_copy_1_0, eps)
+            v_5 = tl_math.log(v_4)
+            v_6 = v_5 - y_pred_val_copy_1_0
+            v_7 = y_true_val_copy_1_0 * v_6
+            kl_loss = kl_loss_copy_1_0 + v_7
+        loss_sum = loss_sum_copy_0 + kl_loss
     sum_1 = tl.cast(tl.sum(loss_sum, 1), tl.float32)
     tl.store(loss + indices_1 * loss_stride_0, sum_1, mask_1)
 
@@ -2333,11 +2335,8 @@ def kl_div_forward(y_pred: Tensor, y_true: Tensor, log_target: bool=False, reduc
         loss = torch.zeros_like(y_pred)
     else:
         loss = torch.zeros((BT,), dtype=torch.float32, device=y_pred.device)
-    kl_loss = torch.zeros_like(y_pred)
-    BT_SIZE = helion.cdiv(BT, BT)
-    _BLOCK_SIZE_1 = BT_SIZE
-    _BLOCK_SIZE_0 = 4096
-    _launcher(_helion_kl_div_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), y_pred, y_true, kl_loss, loss, kl_loss.stride(0), kl_loss.stride(1), loss.stride(0), y_pred.stride(0), y_pred.stride(1), y_true.stride(0), y_true.stride(1), BT, V, log_target, eps, _BLOCK_SIZE_1, _BLOCK_SIZE_0, num_warps=4, num_stages=3)
+    _BLOCK_SIZE_1 = 4096
+    _launcher(_helion_kl_div_forward, (triton.cdiv(BT, _BLOCK_SIZE_1),), y_pred, y_true, loss, loss.stride(0), y_pred.stride(0), y_pred.stride(1), y_true.stride(0), y_true.stride(1), BT, V, log_target, eps, _BLOCK_SIZE_1, 1, num_warps=4, num_stages=3)
     if reduction == 'batchmean':
         final_loss = torch.sum(loss) / BT
     elif reduction == 'sum':
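For reference, the kernel computes a per-row KL-divergence sum over the vocabulary dimension: with `log_target=True` each term is `exp(y_true) * (y_true - y_pred)`, otherwise `y_true * (log(max(y_true, eps)) - y_pred)`, summed over `V` into `loss[row]`. Below is a minimal PyTorch sketch of that math, assuming `y_pred` holds log-probabilities of shape `(BT, V)`; the helper name `kl_div_reference` and the default `eps` value are illustrative, not part of the generated code.

```python
import torch
from torch import Tensor

def kl_div_reference(y_pred: Tensor, y_true: Tensor, log_target: bool = False,
                     reduction: str = 'batchmean', eps: float = 1e-10) -> Tensor:
    # y_pred: log-probabilities, shape (BT, V).
    # y_true: probabilities, or log-probabilities when log_target=True.
    if log_target:
        # Matches the libdevice.exp branch: exp(y_true) * (y_true - y_pred).
        kl = torch.exp(y_true) * (y_true - y_pred)
    else:
        # Matches the maximum/log branch: y_true * (log(max(y_true, eps)) - y_pred).
        kl = y_true * (torch.log(torch.clamp(y_true, min=eps)) - y_pred)
    loss = kl.sum(dim=1)  # the kernel's per-row loss_sum reduction over V
    if reduction == 'batchmean':
        return loss.sum() / y_pred.shape[0]
    if reduction == 'sum':
        return loss.sum()
    return loss  # per-row losses for the remaining reduction modes
```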