1 change: 1 addition & 0 deletions batch_invariant_ops/__init__.py
@@ -1,4 +1,4 @@
from .batch_invariant_ops import (
    softmax,
    set_batch_invariant_mode,
    is_batch_invariant_mode_enabled,
    disable_batch_invariant_mode,
37 changes: 35 additions & 2 deletions batch_invariant_ops/batch_invariant_ops.py
@@ -75,7 +75,7 @@ def matmul_kernel_persistent(
    offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n

    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
    for tile_id in range(start_pid, num_tiles, NUM_SMS):
        pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
        start_m = pid_m * BLOCK_SIZE_M
        start_n = pid_n * BLOCK_SIZE_N
@@ -365,7 +365,12 @@ def mean_kernel(
        acc += tl.sum(vals)

    # Compute mean and store
    mean_val = acc / N
    # Handle empty dimension case (return zeros to avoid NaN)
    if N == 0:
        mean_val = 0.0
    else:
        mean_val = acc / N

    output_idx = m_idx * output_stride0 + k_idx * output_stride1
    tl.store(output_ptr + output_idx, mean_val)

@@ -395,6 +400,15 @@ def mean_dim(
    if dim < 0:
        dim = dim + input.ndim

    # Handle empty dimension case (return zeros to avoid NaN)
    if input.shape[dim] == 0:
        shape = list(input.shape)
        if keepdim:
            output_shape = shape.copy()
            output_shape[dim] = 1
        else:
            output_shape = shape[:dim] + shape[dim + 1:]
        return torch.zeros(output_shape, dtype=dtype, device=input.device)

    # Handle dtype
    if dtype is None:
        if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
@@ -468,6 +482,10 @@ def addmm_batch_invariant(bias, a, b):
    return matmul_persistent(a, b, bias=bias)


def _softmax_batch_invariant(input, dim, _half_to_float):
    assert not _half_to_float, "not implemented"
    return softmax(input, dim=dim)


def _log_softmax_batch_invariant(input, dim, _half_to_float):
    assert not _half_to_float, "not implemented"
    return log_softmax(input, dim=dim)
@@ -506,6 +524,7 @@ def enable_batch_invariant_mode():
    _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
    _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, dispatch_key)
    _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, dispatch_key)
    _batch_invariant_LIB.impl("aten::_softmax", _softmax_batch_invariant, dispatch_key)
    _batch_invariant_LIB.impl("aten::_log_softmax", _log_softmax_batch_invariant, dispatch_key)
    _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, dispatch_key)

@@ -537,3 +556,17 @@ def set_batch_invariant_mode(enabled: bool = True):

def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
    return AttentionBlockSize(block_m=16, block_n=16)


def softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Compute softmax along the last dimension."""
    if dim != -1 and dim != input.ndim - 1:
        raise ValueError("Only supports last dimension")
    original_shape = input.shape
    input_2d = input.reshape(-1, input.shape[-1]).contiguous()
    # Simple softmax using the existing batch-invariant log_softmax + exp
    log_out = log_softmax(input_2d, dim=-1)
    output = torch.exp(log_out)
    return output.reshape(original_shape)
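
For reference, here is a minimal usage sketch of the new softmax (not part of the diff). It assumes a CUDA device, since the underlying kernels are Triton-based, and uses the softmax and set_batch_invariant_mode names exported above:

import torch
from batch_invariant_ops import set_batch_invariant_mode, softmax

x = torch.randn(16, 32, device="cuda")
with set_batch_invariant_mode(True):
    # Same result whether one row or the whole batch is processed
    single = softmax(x[:1], dim=-1)
    sliced = softmax(x, dim=-1)[:1]
assert torch.allclose(single, sliced)
# exp(log_softmax) should also agree closely with the reference softmax
assert torch.allclose(softmax(x, dim=-1), torch.softmax(x, dim=-1), atol=1e-5)
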
127 changes: 127 additions & 0 deletions tests/test_batch_invariant_ops.py
@@ -0,0 +1,127 @@
"""Tests for batch-invariant operations."""

import pytest
import torch
from batch_invariant_ops import (
set_batch_invariant_mode,
mm_batch_invariant,
addmm_batch_invariant,
log_softmax,
mean_kernel,
)


class TestBatchInvariant:
"""Test batch-invariant property: op(x[:1], y) == op(x, y)[:1]"""

@pytest.fixture(autouse=True)
def setup(self):
"""Enable batch-invariant mode for all tests."""
with set_batch_invariant_mode(True):
yield

    def test_mm_batch_invariant(self):
        """Test mm preserves batch invariance."""
        B, D = 16, 32
        a = torch.randn(B, D, device="cuda")
        b = torch.randn(D, D, device="cuda")

        # Single batch
        out1 = mm_batch_invariant(a[:1], b)
        # Full batch, sliced
        out2 = mm_batch_invariant(a, b)[:1]

        assert torch.allclose(out1, out2)

    def test_mm_output_shape(self):
        """Test mm output has correct shape."""
        a = torch.randn(8, 16, device="cuda")
        b = torch.randn(16, 32, device="cuda")
        result = mm_batch_invariant(a, b)
        assert result.shape == (8, 32)

    def test_addmm_batch_invariant(self):
        """Test addmm preserves batch invariance."""
        B, D = 16, 32
        bias = torch.randn(D, device="cuda")
        a = torch.randn(B, D, device="cuda")
        b = torch.randn(D, D, device="cuda")

        # Single batch
        out1 = addmm_batch_invariant(bias, a[:1], b)
        # Full batch, sliced
        out2 = addmm_batch_invariant(bias, a, b)[:1]

        assert torch.allclose(out1, out2)

    def test_addmm_output_shape(self):
        """Test addmm output has correct shape."""
        bias = torch.randn(32, device="cuda")
        a = torch.randn(8, 16, device="cuda")
        b = torch.randn(16, 32, device="cuda")
        result = addmm_batch_invariant(bias, a, b)
        assert result.shape == (8, 32)

    def test_log_softmax_batch_invariant(self):
        """Test log_softmax preserves batch invariance."""
        B, D = 16, 32
        x = torch.randn(B, D, device="cuda")

        # Single batch
        out1 = log_softmax(x[:1], dim=-1)
        # Full batch, sliced
        out2 = log_softmax(x, dim=-1)[:1]

        assert torch.allclose(out1, out2)

    def test_log_softmax_output_shape(self):
        """Test log_softmax output has correct shape."""
        x = torch.randn(8, 16, device="cuda")
        result = log_softmax(x, dim=-1)
        assert result.shape == x.shape

    def test_log_softmax_values(self):
        """Test log_softmax sums to 1 in probability space."""
        x = torch.randn(4, 8, device="cuda")
        result = log_softmax(x, dim=-1)
        probs = torch.exp(result)
        # Sum along the softmax dimension should be ~1
        assert torch.allclose(probs.sum(dim=-1), torch.ones(probs.shape[0], device="cuda"), atol=1e-5)

    def test_mean_dim_batch_invariant(self):
        """Test mean preserves batch invariance."""
        B, M, N = 4, 8, 16
        x = torch.randn(B, M, N, device="cuda")

        # Single batch
        out1 = mean_dim(x[:1], dim=1)
        # Full batch, sliced
        out2 = mean_dim(x, dim=1)[:1]

        assert torch.allclose(out1, out2)

    def test_mean_dim_output_shape(self):
        """Test mean_dim output has correct shape."""
        x = torch.randn(4, 8, 16, device="cuda")
        result = mean_dim(x, dim=1)
        assert result.shape == (4, 16)

    def test_mean_dim_various_dims(self):
        """Test mean works along different dimensions."""
        x = torch.randn(4, 8, 16, device="cuda")

        # Along dim 1
        result1 = mean_dim(x, dim=1)
        assert result1.shape == (4, 16)

        # Along dim 2
        result2 = mean_dim(x, dim=2)
        assert result2.shape == (4, 8)

    def test_mean_dim_empty_dimension(self):
        """Test mean_dim handles the edge case of an empty dimension."""
        # This tests the fix for issue #15
        x = torch.randn(1, 0, 8, device="cuda")
        result = mean_dim(x, dim=1)
        # Should not produce NaN
        assert not torch.isnan(result).any()
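
As a quick end-to-end check of the empty-dimension fix exercised by the last test, the intended behaviour looks like this (illustrative sketch; assumes mean_dim is re-exported by the package, as the tests above do, and that a CUDA device is available):

import torch
from batch_invariant_ops import mean_dim

x = torch.randn(1, 0, 8, device="cuda")
out = mean_dim(x, dim=1)

# With the fix, reducing over an empty dimension returns zeros instead of NaN
assert out.shape == (1, 8)
assert not torch.isnan(out).any()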