diff --git a/batch_invariant_ops/batch_invariant_ops.py b/batch_invariant_ops/batch_invariant_ops.py
index b9021bb..6755c5d 100644
--- a/batch_invariant_ops/batch_invariant_ops.py
+++ b/batch_invariant_ops/batch_invariant_ops.py
@@ -365,7 +365,12 @@ def mean_kernel(
         acc += tl.sum(vals)
 
     # Compute mean and store
-    mean_val = acc / N
+    # Handle empty dimension case (return zeros to avoid NaN)
+    if N == 0:
+        mean_val = 0.0
+    else:
+        mean_val = acc / N
+
     output_idx = m_idx * output_stride0 + k_idx * output_stride1
     tl.store(output_ptr + output_idx, mean_val)
 
@@ -395,6 +400,18 @@ def mean_dim(
     if dim < 0:
         dim = dim + input.ndim
 
+    # Get input shape early for empty dimension check
+    shape = list(input.shape)
+
+    # Handle empty dimension case (return zeros to avoid NaN)
+    if shape[dim] == 0:
+        if keepdim:
+            output_shape = shape.copy()
+            output_shape[dim] = 1
+        else:
+            output_shape = shape[:dim] + shape[dim + 1:]
+        return torch.zeros(output_shape, dtype=dtype, device=input.device)
+
     # Handle dtype
     if dtype is None:
         if input.dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
@@ -468,6 +485,10 @@ def addmm_batch_invariant(bias, a, b):
     return matmul_persistent(a, b, bias=bias)
 
 
+def _softmax_batch_invariant(input, dim, _half_to_float):
+    return softmax(input, dim=dim)
+
+
 def _log_softmax_batch_invariant(input, dim, _half_to_float):
     assert not _half_to_float, "not implemented"
     return log_softmax(input, dim=dim)
@@ -506,6 +527,7 @@ def enable_batch_invariant_mode():
     _batch_invariant_LIB = torch.library.Library("aten", "IMPL")
     _batch_invariant_LIB.impl("aten::mm", mm_batch_invariant, dispatch_key)
     _batch_invariant_LIB.impl("aten::addmm", addmm_batch_invariant, dispatch_key)
+    _batch_invariant_LIB.impl("aten::softmax", _softmax_batch_invariant, dispatch_key)
     _batch_invariant_LIB.impl("aten::_log_softmax", _log_softmax_batch_invariant, dispatch_key)
     _batch_invariant_LIB.impl("aten::mean.dim", mean_batch_invariant, dispatch_key)
 
@@ -537,3 +559,15 @@ def set_batch_invariant_mode(enabled: bool = True):
 
 def get_batch_invariant_attention_block_size() -> AttentionBlockSize:
     return AttentionBlockSize(block_m=16, block_n=16)
+
+
+def softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    """Compute softmax along the last dimension."""
+    if dim != -1 and dim != input.ndim - 1:
+        raise ValueError("Only supports last dimension")
+    original_shape = input.shape
+    input_2d = input.reshape(-1, input.shape[-1]).contiguous()
+    # Simple softmax using existing log_softmax + exp
+    log_out = log_softmax(input_2d, dim=-1)
+    output = torch.exp(log_out)
+    return output.reshape(original_shape)
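
For reference, a minimal usage sketch of the patched behavior (not part of the patch itself). It assumes the package re-exports enable_batch_invariant_mode at the top level and that the kernels are registered for CUDA tensors, as elsewhere in the file; adjust the import path and device to match the actual setup.

import torch
from batch_invariant_ops import enable_batch_invariant_mode  # assumed import path

enable_batch_invariant_mode()

# Mean over an empty dimension: the patched kernel and mean_dim wrapper
# now return zeros of the reduced shape instead of NaN.
x = torch.randn(4, 0, 8, device="cuda")
m = x.mean(dim=1)
assert m.shape == (4, 8) and not torch.isnan(m).any()

# aten::softmax is now registered to use the batch-invariant path,
# computed as exp(log_softmax(x)) along the last dimension.
y = torch.softmax(torch.randn(4, 8, device="cuda"), dim=-1)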