38 changes: 37 additions & 1 deletion batch_invariant_ops/batch_invariant_ops.py
@@ -365,7 +365,12 @@ def mean_kernel(
        acc += tl.sum(vals)

    # Compute mean and store
    # Handle empty dimension case (return zeros to avoid NaN)
    if N == 0:
        mean_val = 0.0
    else:
        mean_val = acc / N

    output_idx = m_idx * output_stride0 + k_idx * output_stride1
    tl.store(output_ptr + output_idx, mean_val)

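For reference, a quick sketch (not part of the diff) of the failure mode the N == 0 guard above addresses. In stock PyTorch, reducing over an empty dimension produces NaN:

import torch

x = torch.empty(3, 0)
print(x.mean(dim=1))  # tensor([nan, nan, nan]); the guarded kernel stores 0.0 instead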
@@ -395,6 +400,18 @@ def mean_dim(
    if dim < 0:
        dim = dim + input.ndim

    # Get input shape early for the empty dimension check
    shape = list(input.shape)

    # Handle empty dimension case (return zeros to avoid NaN)
    if shape[dim] == 0:
        if keepdim:
            output_shape = shape.copy()
            output_shape[dim] = 1
        else:
            output_shape = shape[:dim] + shape[dim + 1:]
        # Preserve the input dtype when none was requested (e.g. float16 inputs)
        out_dtype = dtype if dtype is not None else input.dtype
        return torch.zeros(output_shape, dtype=out_dtype, device=input.device)

    # Handle dtype
    if dtype is None:
        if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
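A hedged usage sketch of the early return above (mean_dim's keyword names are assumed from the hunk; shapes are illustrative):

# Mean over an empty dimension now yields zeros instead of NaN
x = torch.empty(4, 0, 8, device="cuda")
out = mean_dim(x, dim=1)                     # expected shape (4, 8), all zeros
out_keep = mean_dim(x, dim=1, keepdim=True)  # expected shape (4, 1, 8), all zeros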
@@ -468,6 +485,10 @@ def addmm_batch_invariant(bias, a, b):
    return matmul_persistent(a, b, bias=bias)


def _softmax_batch_invariant(input, dim, _half_to_float):
    return softmax(input, dim=dim)


def _log_softmax_batch_invariant(input, dim, _half_to_float):
    assert not _half_to_float, "not implemented"
    return log_softmax(input, dim=dim)
@@ -506,6 +527,7 @@ def enable_batch_invariant_mode():
    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, dispatch_key)
    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, dispatch_key)
    _batch_invariant_LIB.impl("aten::softmax", _softmax_batch_invariant, dispatch_key)
    _batch_invariant_LIB.impl("aten::_log_softmax", _log_softmax_batch_invariant, dispatch_key)
    _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, dispatch_key)

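A hedged sketch of how the new registration is exercised (assumes dispatch_key targets CUDA, as the surrounding registrations suggest):

enable_batch_invariant_mode()
x = torch.randn(2, 5, device="cuda")
# torch.softmax on CUDA should now route through _softmax_batch_invariant -> softmax()
probs = torch.softmax(x, dim=-1)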
@@ -537,3 +559,17 @@ def set_batch_invariant_mode(enabled: bool = True):

def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
    return AttentionBlockSize(block_m=16, block_n=16)


def softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Compute softmax along the last dimension."""
    if dim != -1 and dim != input.ndim - 1:
        raise ValueError("Only supports last dimension")
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1]).contiguous()
    # Softmax via the existing batch-invariant log_softmax kernel: exp(log_softmax(x))
    output = torch.exp(log_softmax(input_2d, dim=-1))
    return output.reshape(original_shape)
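A hedged sanity check for the helper above (illustrative only; assumes this module's log_softmax closely matches torch's):

x = torch.randn(3, 7, device="cuda")
assert torch.allclose(softmax(x, dim=-1), torch.softmax(x, dim=-1), atol=1e-6)
assert torch.allclose(softmax(x).sum(dim=-1), torch.ones(3, device="cuda"), atol=1e-6)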