diff --git a/batch_invariant_ops/__init__.py b/batch_invariant_ops/__init__.py
index d06595b..18d0c32 100644
--- a/batch_invariant_ops/__init__.py
+++ b/batch_invariant_ops/__init__.py
@@ -1,4 +1,5 @@
 from .batch_invariant_ops import (
+    softmax,
     set_batch_invariant_mode,
     is_batch_invariant_mode_enabled,
     disable_batch_invariant_mode,
diff --git a/batch_invariant_ops/batch_invariant_ops.py b/batch_invariant_ops/batch_invariant_ops.py
index b9021bb..1d6b9a1 100644
--- a/batch_invariant_ops/batch_invariant_ops.py
+++ b/batch_invariant_ops/batch_invariant_ops.py
@@ -75,7 +75,7 @@ def matmul_kernel_persistent(
     offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)
     num_pid_in_group = GROUP_SIZE_M * num_pid_n

-    for tile_id in tl.range(start_pid, num_tiles, NUM_SMS, flatten=True):
+    for tile_id in range(start_pid, num_tiles, NUM_SMS):
         pid_m, pid_n = _compute_pid(tile_id, num_pid_in_group, num_pid_m, GROUP_SIZE_M, NUM_SMS)
         start_m = pid_m * BLOCK_SIZE_M
         start_n = pid_n * BLOCK_SIZE_N
@@ -365,7 +365,12 @@ def mean_kernel(
         acc += tl.sum(vals)

     # Compute mean and store
-    mean_val = acc / N
+    # Handle empty dimension case (return zeros to avoid NaN)
+    if N == 0:
+        mean_val = 0.0
+    else:
+        mean_val = acc / N
+
     output_idx = m_idx * output_stride0 + k_idx * output_stride1
     tl.store(output_ptr + output_idx, mean_val)

@@ -395,6 +400,15 @@ def mean_dim(
     if dim < 0:
         dim = dim + input.ndim

+    # Handle empty dimension case (return zeros to avoid NaN)
+    if input.shape[dim] == 0:
+        output_shape = list(input.shape)
+        if keepdim:
+            output_shape[dim] = 1
+        else:
+            output_shape.pop(dim)
+        return torch.zeros(output_shape, dtype=dtype, device=input.device)
+
     # Handle dtype
     if dtype is None:
         if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
@@ -468,6 +482,10 @@ def addmm_batch_invariant(bias, a, b):
     return matmul_persistent(a, b, bias=bias)


+def _softmax_batch_invariant(input, dim, _half_to_float):
+    return softmax(input, dim=dim)
+
+
 def _log_softmax_batch_invariant(input, dim, _half_to_float):
     assert not _half_to_float, "not implemented"
     return log_softmax(input, dim=dim)
@@ -506,6 +524,7 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
     _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, dispatch_key)
     _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, dispatch_key)
+    _batch_invariant_LIB.impl("aten::_softmax", _softmax_batch_invariant, dispatch_key)
     _batch_invariant_LIB.impl("aten::_log_softmax", _log_softmax_batch_invariant, dispatch_key)
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, dispatch_key)

@@ -537,3 +556,14 @@ def set_batch_invariant_mode(enabled: bool = True):

 def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)
+
+
+def softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    """Compute a batch-invariant softmax along the last dimension."""
+    if dim != -1 and dim != input.ndim - 1:
+        raise ValueError("Only supports last dimension")
+    original_shape = input.shape
+    input_2d = input.reshape(-1, input.shape[-1]).contiguous()
+    # Softmax via the existing batch-invariant log_softmax kernel: exp(log_softmax(x))
+    output = torch.exp(log_softmax(input_2d, dim=-1))
+    return output.reshape(original_shape)
diff --git a/tests/test_batch_invariant_ops.py b/tests/test_batch_invariant_ops.py
new file mode 100644
index 0000000..fd4d51f
--- /dev/null
+++ b/tests/test_batch_invariant_ops.py
@@ -0,0 +1,149 @@
+"""Tests for batch-invariant operations."""
+
+import pytest
+import torch
+
+from batch_invariant_ops import (
+    set_batch_invariant_mode,
+    mm_batch_invariant,
+    addmm_batch_invariant,
+    log_softmax,
+    mean_dim,
+)
+
+# The batch-invariant kernels are Triton-based and require a CUDA device.
+pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
+
+
+class TestBatchInvariant:
+    """Test batch-invariant property: op(x[:1], y) == op(x, y)[:1]"""
+
+    @pytest.fixture(autouse=True)
+    def setup(self):
+        """Enable batch-invariant mode for all tests."""
+        with set_batch_invariant_mode(True):
+            yield
+
+    def test_mm_batch_invariant(self):
+        """Test mm preserves batch invariance."""
+        B, D = 16, 32
+        a = torch.randn(B, D, device="cuda")
+        b = torch.randn(D, D, device="cuda")
+
+        # Single batch
+        out1 = mm_batch_invariant(a[:1], b)
+        # Full batch, sliced
+        out2 = mm_batch_invariant(a, b)[:1]
+
+        assert torch.allclose(out1, out2)
+
+    def test_mm_output_shape(self):
+        """Test mm output has correct shape."""
+        a = torch.randn(8, 16, device="cuda")
+        b = torch.randn(16, 32, device="cuda")
+        result = mm_batch_invariant(a, b)
+        assert result.shape == (8, 32)
+
+    def test_addmm_batch_invariant(self):
+        """Test addmm preserves batch invariance."""
+        B, D = 16, 32
+        bias = torch.randn(D, device="cuda")
+        a = torch.randn(B, D, device="cuda")
+        b = torch.randn(D, D, device="cuda")
+
+        # Single batch
+        out1 = addmm_batch_invariant(bias, a[:1], b)
+        # Full batch, sliced
+        out2 = addmm_batch_invariant(bias, a, b)[:1]
+
+        assert torch.allclose(out1, out2)
+
+    def test_addmm_output_shape(self):
+        """Test addmm output has correct shape."""
+        bias = torch.randn(32, device="cuda")
+        a = torch.randn(8, 16, device="cuda")
+        b = torch.randn(16, 32, device="cuda")
+        result = addmm_batch_invariant(bias, a, b)
+        assert result.shape == (8, 32)
+
+    def test_log_softmax_batch_invariant(self):
+        """Test log_softmax preserves batch invariance."""
+        B, D = 16, 32
+        x = torch.randn(B, D, device="cuda")
+
+        # Single batch
+        out1 = log_softmax(x[:1], dim=-1)
+        # Full batch, sliced
+        out2 = log_softmax(x, dim=-1)[:1]
+
+        assert torch.allclose(out1, out2)
+
+    def test_log_softmax_output_shape(self):
+        """Test log_softmax output has correct shape."""
+        x = torch.randn(8, 16, device="cuda")
+        result = log_softmax(x, dim=-1)
+        assert result.shape == x.shape
+
+    def test_log_softmax_values(self):
+        """Test log_softmax sums to 1 in probability space."""
+        x = torch.randn(4, 8, device="cuda")
+        result = log_softmax(x, dim=-1)
+        probs = torch.exp(result)
+        # Sum along the softmax dimension should be ~1
+        sums = probs.sum(dim=-1)
+        assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5)
+
+    def test_mean_dim_batch_invariant(self):
+        """Test mean preserves batch invariance."""
+        B, M, N = 4, 8, 16
+        x = torch.randn(B, M, N, device="cuda")
+
+        # Single batch
+        out1 = mean_dim(x[:1], dim=1)
+        # Full batch, sliced
+        out2 = mean_dim(x, dim=1)[:1]
+
+        assert torch.allclose(out1, out2)
+
+    def test_mean_dim_output_shape(self):
+        """Test mean_dim output has correct shape."""
+        x = torch.randn(4, 8, 16, device="cuda")
+        result = mean_dim(x, dim=1)
+        assert result.shape == (4, 16)
+
+    def test_mean_dim_various_dims(self):
+        """Test mean works along different dimensions."""
+        x = torch.randn(4, 8, 16, device="cuda")
+
+        # Along dim 1
+        result1 = mean_dim(x, dim=1)
+        assert result1.shape == (4, 16)
+
+        # Along dim 2
+        result2 = mean_dim(x, dim=2)
+        assert result2.shape == (4, 8)
+
+    def test_mean_dim_empty_dimension(self):
+        """Test mean_dim handles edge case of empty dimension."""
+        # This tests the fix for issue #15
+        x = torch.randn(1, 0, 8, device="cuda")
+        result = mean_dim(x, dim=1)
+        # Should not produce NaN
+        assert not torch.isnan(result).any()
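+
+    def test_softmax_batch_invariant(self):
+        """Sketch of a test for the softmax helper added in this PR.
+
+        Assumes `softmax` is exported from the package (see the __init__.py
+        change above); adjust the import if it lives elsewhere.
+        """
+        from batch_invariant_ops import softmax
+
+        x = torch.randn(16, 32, device="cuda")
+        # Single batch vs. full batch, sliced
+        out1 = softmax(x[:1], dim=-1)
+        out2 = softmax(x, dim=-1)[:1]
+        assert torch.allclose(out1, out2)
+        # Rows should sum to ~1
+        sums = out2.sum(dim=-1)
+        assert torch.allclose(sums, torch.ones_like(sums), atol=1e-5)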