diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py
index f28688381..77d3ea1a3 100644
--- a/benchmark/test_reduction_perf.py
+++ b/benchmark/test_reduction_perf.py
@@ -240,6 +240,98 @@ def count_nonzero_input_fn(shape, dtype, device):
     bench.run()
 
 
+def max_pool2d_input_fn(shape, dtype, device):
+    inp = generate_tensor_input(shape, dtype, device)
+    yield inp, {
+        "kernel_size": 3,
+        "stride": 2,
+        "padding": 1,
+        "dilation": 1,
+        "ceil_mode": False,
+    }
+    if Config.bench_level == BenchLevel.COMPREHENSIVE:
+        # Non-square kernel/stride/padding
+        if shape[-2] > 5 and shape[-1] > 5:
+            yield inp, {
+                "kernel_size": (3, 5),
+                "stride": (2, 1),
+                "padding": (1, 2),
+                "dilation": 1,
+                "ceil_mode": False,
+            }
+        # With dilation
+        yield inp, {
+            "kernel_size": 3,
+            "stride": 1,
+            "padding": 1,
+            "dilation": 2,
+            "ceil_mode": False,
+        }
+        # With ceil_mode
+        yield inp, {
+            "kernel_size": 3,
+            "stride": 2,
+            "padding": 1,
+            "dilation": 1,
+            "ceil_mode": True,
+        }
+
+
+class MaxPool2dBenchmark(GenericBenchmark):
+    def get_input_iter(self, cur_dtype) -> Generator:
+        shapes_4d = [
+            (4, 3, 224, 224),  # Typical input image size
+            (16, 64, 56, 56),  # Early ResNet layer output
+            (32, 128, 28, 28),  # Mid ResNet layer output
+            (64, 256, 14, 14),  # Later ResNet layer output
+            (128, 512, 7, 7),  # Final ResNet layer output
+        ]
+
+        for shape in shapes_4d:
+            yield from self.input_fn(shape, cur_dtype, self.device)
+
+
+@pytest.mark.max_pool2d
+def test_perf_max_pool2d():
+    bench = MaxPool2dBenchmark(
+        input_fn=max_pool2d_input_fn,
+        op_name="max_pool2d_with_indices",
+        torch_op=torch.nn.functional.max_pool2d_with_indices,
+        dtypes=FLOAT_DTYPES,
+    )
+    bench.set_gems(flag_gems.max_pool2d_with_indices)
+    bench.run()
+
+
+@pytest.mark.max_pool2d_backward
+def test_perf_max_pool2d_backward():
+    def max_pool2d_backward_input_fn(shape, dtype, device):
+        for forward_args in max_pool2d_input_fn(shape, dtype, device):
+            inp, params = forward_args
+            inp.requires_grad_(True)
+            output, indices = torch.nn.functional.max_pool2d_with_indices(inp, **params)
+            grad_output = torch.randn_like(output)
+            yield grad_output, inp, indices, params
+
+    def torch_max_pool2d_backward_wrapper(grad_output, input, indices, **kwargs):
+        output, _ = torch.nn.functional.max_pool2d_with_indices(input, **kwargs)
+        grad_input = torch.autograd.grad(
+            outputs=(output,), inputs=(input,), grad_outputs=(grad_output,)
+        )
+        return grad_input[0]
+
+    bench = MaxPool2dBenchmark(
+        input_fn=max_pool2d_backward_input_fn,
+        op_name="max_pool2d_backward",
+        torch_op=torch_max_pool2d_backward_wrapper,
+        dtypes=FLOAT_DTYPES,
+        is_backward=False,
+    )
+
+    bench.set_gems(flag_gems.max_pool2d_backward)
+    bench.run()
+
+
 @pytest.mark.dot
 def test_perf_dot():
     def dot_input_fn(shape, dtype, device):
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index c7598b319..180ddac2a 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -189,6 +189,8 @@ def enable(
         ("max", max),
         ("max.dim", max_dim),
         ("maximum", maximum),
+        ("max_pool2d_with_indices", max_pool2d_with_indices),
+        ("max_pool2d_backward", max_pool2d_backward),
         ("mean", mean),
         ("mean.dim", mean_dim),
         ("min", min),
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 7bc001697..a18db0e62 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -108,6 +108,10 @@ from flag_gems.ops.masked_fill import masked_fill, masked_fill_
 from flag_gems.ops.masked_select import masked_select
 from flag_gems.ops.max import max, max_dim
+from flag_gems.ops.max_pool2d_with_indices import (
+    max_pool2d_backward,
+    max_pool2d_with_indices,
+)
 from flag_gems.ops.maximum import maximum
 from flag_gems.ops.mean import mean, mean_dim
 from flag_gems.ops.min import min, min_dim
@@ -347,6 +351,8 @@
     "max",
     "max_dim",
     "maximum",
+    "max_pool2d_with_indices",
+    "max_pool2d_backward",
     "mean",
     "mean_dim",
     "min",
diff --git a/src/flag_gems/ops/max_pool2d_with_indices.py b/src/flag_gems/ops/max_pool2d_with_indices.py
new file mode 100644
index 000000000..47418f9e1
--- /dev/null
+++ b/src/flag_gems/ops/max_pool2d_with_indices.py
@@ -0,0 +1,396 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.utils import libentry
+from flag_gems.utils.limits import get_dtype_min
+
+logger = logging.getLogger(__name__)
+
+
+def max_pool2d_output_size(
+    in_size: int,
+    kernel_size: int,
+    stride: int,
+    padding: int,
+    dilation: int,
+    ceil_mode: bool = False,
+) -> int:
+    effective_kernel_size = (kernel_size - 1) * dilation + 1
+    numerator = in_size + 2 * padding - effective_kernel_size
+    if ceil_mode:
+        output_size = (numerator + stride - 1) // stride + 1
+        # PyTorch-compatible adjustment for ceil_mode
+        if (output_size - 1) * stride >= in_size + padding:
+            output_size -= 1
+    else:
+        output_size = numerator // stride + 1
+
+    return output_size
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
+        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
+        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
+        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=3, num_warps=8),
+        triton.Config({"BLOCK_H": 64, "BLOCK_W": 64}, num_stages=2, num_warps=8),
+    ],
+    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
+)
+@triton.jit
+def max_pool2d_forward_kernel(
+    input_ptr,
+    output_ptr,
+    indices_ptr,
+    # Input tensor strides
+    in_stride_n,
+    in_stride_c,
+    in_stride_h,
+    in_stride_w,
+    # Input/Output shapes
+    in_c,
+    in_h,
+    in_w,
+    out_h,
+    out_w,
+    # Pooling parameters
+    kernel_h: tl.constexpr,
+    kernel_w: tl.constexpr,
+    stride_h: tl.constexpr,
+    stride_w: tl.constexpr,
+    padding_h: tl.constexpr,
+    padding_w: tl.constexpr,
+    dilation_h: tl.constexpr,
+    dilation_w: tl.constexpr,
+    # Meta-parameters for tiling
+    BLOCK_H: tl.constexpr,
+    BLOCK_W: tl.constexpr,
+):
+    pid_nc = tl.program_id(0)
+    pid_hw = tl.program_id(1)
+    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
+    h_block_idx = pid_hw // num_w_blocks
+    w_block_idx = pid_hw % num_w_blocks
+    n_idx = pid_nc // in_c
+    c_idx = pid_nc % in_c
+
+    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
+    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
+
+    dtype = input_ptr.type.element_ty
+    min_val = get_dtype_min(dtype)
+    max_val_acc = tl.full((BLOCK_H, BLOCK_W), min_val, dtype=dtype)
+    max_idx_acc = tl.full((BLOCK_H, BLOCK_W), -1, dtype=tl.int64)
+
+    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
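+
+    # Walk the (kernel_h, kernel_w) window for every output pixel in the tile.
+    # For example, with kernel=3, stride=2, padding=1, output pixel (0, 0)
+    # reads input rows/cols -1..1; the out-of-bounds taps are masked and load
+    # the dtype minimum, so they can never win the running max.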
+    for kh in tl.static_range(0, kernel_h):
+        for kw in tl.static_range(0, kernel_w):
+            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
+            w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
+            in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
+            input_offset = h_in * in_stride_h + w_in * in_stride_w
+            current_val = tl.load(
+                input_base_ptr + input_offset, mask=in_mask, other=min_val
+            )
+            current_idx = h_in * in_w + w_in
+
+            is_new_max = current_val > max_val_acc
+            max_val_acc = tl.where(is_new_max, current_val, max_val_acc)
+            max_idx_acc = tl.where(is_new_max & in_mask, current_idx, max_idx_acc)
+
+    out_base_ptr = output_ptr + pid_nc * out_h * out_w
+    indices_base_ptr = indices_ptr + pid_nc * out_h * out_w
+    out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
+    out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
+    output_block_ptr = (
+        out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
+    )
+    indices_block_ptr = (
+        indices_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
+    )
+
+    out_mask = (out_h_offsets[:, None] < out_h) & (out_w_offsets[None, :] < out_w)
+    tl.store(output_block_ptr, max_val_acc, mask=out_mask)
+    tl.store(indices_block_ptr, max_idx_acc, mask=out_mask)
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 16}, num_warps=4),
+        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 8}, num_warps=4),
+        triton.Config({"BLOCK_IN_H": 8, "BLOCK_IN_W": 32}, num_warps=4),
+        triton.Config({"BLOCK_IN_H": 32, "BLOCK_IN_W": 32}, num_warps=8),
+        triton.Config({"BLOCK_IN_H": 16, "BLOCK_IN_W": 64}, num_warps=8),
+        triton.Config({"BLOCK_IN_H": 64, "BLOCK_IN_W": 16}, num_warps=8),
+    ],
+    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
+)
+@triton.jit
+def max_pool2d_backward_kernel(
+    grad_output_ptr,
+    indices_ptr,
+    grad_input_ptr,
+    # Shape info
+    in_h,
+    in_w,
+    out_h,
+    out_w,
+    # Strides for grad_output/indices
+    out_stride_nc,
+    out_stride_h,
+    out_stride_w,
+    # Pooling parameters
+    kernel_h: tl.constexpr,
+    kernel_w: tl.constexpr,
+    stride_h: tl.constexpr,
+    stride_w: tl.constexpr,
+    padding_h: tl.constexpr,
+    padding_w: tl.constexpr,
+    dilation_h: tl.constexpr,
+    dilation_w: tl.constexpr,
+    # Tiling parameters
+    BLOCK_IN_H: tl.constexpr,
+    BLOCK_IN_W: tl.constexpr,
+):
+    nc_idx = tl.program_id(0)
+    pid_hw = tl.program_id(1)
+
+    num_w_blocks = tl.cdiv(in_w, BLOCK_IN_W)
+    h_block_idx = pid_hw // num_w_blocks
+    w_block_idx = pid_hw % num_w_blocks
+
+    h_in_offsets = h_block_idx * BLOCK_IN_H + tl.arange(0, BLOCK_IN_H)
+    w_in_offsets = w_block_idx * BLOCK_IN_W + tl.arange(0, BLOCK_IN_W)
+
+    current_input_flat_idx = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
+    grad_acc = tl.zeros((BLOCK_IN_H, BLOCK_IN_W), dtype=tl.float32)
+
+    indices_base_ptr = indices_ptr + nc_idx * out_stride_nc
+    grad_output_base_ptr = grad_output_ptr + nc_idx * out_stride_nc
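+
+    # Gather rather than scatter: instead of atomically scattering grad_output
+    # to the argmax locations, each input pixel enumerates every output window
+    # that could have selected it (h_out * stride_h - padding_h + kh *
+    # dilation_h == h_in for some kh) and accumulates matching grads in fp32.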
+    for kh in tl.static_range(0, kernel_h):
+        for kw in tl.static_range(0, kernel_w):
+            numerator_h = h_in_offsets[:, None] + padding_h - kh * dilation_h
+            numerator_w = w_in_offsets[None, :] + padding_w - kw * dilation_w
+
+            valid_map_mask = (numerator_h % stride_h == 0) & (
+                numerator_w % stride_w == 0
+            )
+            h_out = numerator_h // stride_h
+            w_out = numerator_w // stride_w
+            out_bounds_mask = (
+                (h_out >= 0) & (h_out < out_h) & (w_out >= 0) & (w_out < out_w)
+            )
+            load_mask = valid_map_mask & out_bounds_mask
+
+            safe_h_out = tl.where(load_mask, h_out, 0)
+            safe_w_out = tl.where(load_mask, w_out, 0)
+            out_offsets = safe_h_out * out_stride_h + safe_w_out
+
+            indices_block = tl.load(
+                indices_base_ptr + out_offsets, mask=load_mask, other=-1
+            )
+            match_mask = indices_block == current_input_flat_idx
+
+            grad_block = tl.load(
+                grad_output_base_ptr + out_offsets, mask=match_mask, other=0.0
+            )
+            grad_acc += grad_block
+
+    grad_input_base_ptr = grad_input_ptr + nc_idx * in_h * in_w
+    grad_input_offsets = h_in_offsets[:, None] * in_w + w_in_offsets[None, :]
+    store_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
+    tl.store(grad_input_base_ptr + grad_input_offsets, grad_acc, mask=store_mask)
+
+
+def _parse_pool_params(kernel_size, stride, padding, dilation):
+    def _parse_param(param, name, default=None):
+        if param is None:
+            return default
+        if isinstance(param, int):
+            return param, param
+        if isinstance(param, (list, tuple)) and len(param) == 2:
+            return param
+        raise ValueError(f"Invalid {name}: {param}")
+
+    kernel_h, kernel_w = _parse_param(kernel_size, "kernel_size")
+    stride_h, stride_w = _parse_param(stride, "stride", default=(kernel_h, kernel_w))
+    padding_h, padding_w = _parse_param(padding, "padding", default=(0, 0))
+    dilation_h, dilation_w = _parse_param(dilation, "dilation", default=(1, 1))
+
+    if stride_h <= 0 or stride_w <= 0:
+        raise ValueError(
+            f"stride must be positive, but got stride=({stride_h}, {stride_w})"
+        )
+    if padding_h < 0 or padding_w < 0:
+        raise ValueError(
+            f"padding must be non-negative, but got padding=({padding_h}, {padding_w})"
+        )
+    if dilation_h <= 0 or dilation_w <= 0:
+        raise ValueError(
+            f"dilation must be positive, but got dilation=({dilation_h}, {dilation_w})"
+        )
+
+    return (
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        padding_h,
+        padding_w,
+        dilation_h,
+        dilation_w,
+    )
+
+
+def max_pool2d_with_indices(
+    input: torch.Tensor,
+    kernel_size,
+    stride=None,
+    padding=0,
+    dilation=1,
+    ceil_mode=False,
+):
+    logger.debug("GEMS MAX_POOL2D_WITH_INDICES FORWARD")
+    input = input.contiguous()
+
+    params = _parse_pool_params(kernel_size, stride, padding, dilation)
+    (
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        padding_h,
+        padding_w,
+        dilation_h,
+        dilation_w,
+    ) = params
+
+    in_n, in_c, in_h, in_w = input.shape
+    out_h = max_pool2d_output_size(
+        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
+    )
+    out_w = max_pool2d_output_size(
+        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
+    )
+
+    output = torch.empty(
+        (in_n, in_c, out_h, out_w), device=input.device, dtype=input.dtype
+    )
+    indices = torch.empty(
+        (in_n, in_c, out_h, out_w), device=input.device, dtype=torch.int64
+    )
+
+    if output.numel() == 0:
+        return output, indices
+
+    grid = lambda meta: (
+        in_n * in_c,
+        triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(out_w, meta["BLOCK_W"]),
+    )
+
+    max_pool2d_forward_kernel[grid](
+        input,
+        output,
+        indices,
+        input.stride(0),
+        input.stride(1),
+        input.stride(2),
+        input.stride(3),
+        in_c,
+        in_h,
+        in_w,
+        out_h,
+        out_w,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        padding_h,
+        padding_w,
+        dilation_h,
+        dilation_w,
+    )
+
+    return output, indices
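+
+
+# Usage sketch (shapes illustrative): the forward op returns the indices that
+# the backward op consumes:
+#   out, idx = max_pool2d_with_indices(inp, kernel_size=2, stride=2)
+#   grad_in = max_pool2d_backward(grad_out, inp, idx, 2, 2, 0, 1, False)
+# For inp of shape (1, 1, 4, 4), out, idx and grad_out all have shape
+# (1, 1, 2, 2); idx is flattened over each (H, W) plane, matching
+# torch.nn.functional.max_pool2d_with_indices.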
+def max_pool2d_backward(
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
+    indices: torch.Tensor,
+    kernel_size,
+    stride,
+    padding,
+    dilation,
+    ceil_mode,
+):
+    logger.debug("GEMS MAX_POOL2D BACKWARD")
+    grad_output = grad_output.contiguous()
+    indices = indices.contiguous()
+
+    params = _parse_pool_params(kernel_size, stride, padding, dilation)
+    (
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        padding_h,
+        padding_w,
+        dilation_h,
+        dilation_w,
+    ) = params
+
+    in_n, in_c, in_h, in_w = input.shape
+    out_h, out_w = grad_output.shape[2], grad_output.shape[3]
+
+    # Accumulate into a freshly allocated fp32 buffer; the kernel assumes a
+    # contiguous NCHW layout, which zeros_like would not guarantee for
+    # non-contiguous inputs.
+    grad_input = torch.zeros(input.shape, device=input.device, dtype=torch.float32)
+
+    if grad_input.numel() == 0:
+        return grad_input.to(grad_output.dtype)
+
+    grid = lambda meta: (
+        in_n * in_c,
+        triton.cdiv(in_h, meta["BLOCK_IN_H"]) * triton.cdiv(in_w, meta["BLOCK_IN_W"]),
+    )
+
+    out_stride_nc = out_h * out_w
+    out_stride_h = out_w
+    out_stride_w = 1
+
+    max_pool2d_backward_kernel[grid](
+        grad_output,
+        indices,
+        grad_input,
+        in_h,
+        in_w,
+        out_h,
+        out_w,
+        out_stride_nc,
+        out_stride_h,
+        out_stride_w,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        padding_h,
+        padding_w,
+        dilation_h,
+        dilation_w,
+    )
+
+    return grad_input.to(grad_output.dtype)
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index a7150e923..7c6369f56 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -1108,6 +1108,99 @@ def test_accuracy_masked_select(shape, dtype, threshold):
     gems_assert_equal(res_out, ref_out)
 
 
+MAXPOOL2D_CONFIGS = [
+    # Classic case: 3x3 kernel, stride 2, padding 1
+    ((4, 3, 32, 32), 3, 2, 1, 1, False),
+    # Non-square kernel and stride
+    ((8, 16, 28, 28), (3, 5), (1, 2), 1, 1, False),
+    # Test ceil_mode
+    ((2, 4, 15, 15), 3, 2, 1, 1, True),
+    # Test dilation
+    ((1, 1, 7, 7), 2, 1, 0, 2, False),
+    # Larger case from ResNet
+    ((1, 64, 56, 56), 3, 2, 1, 1, False),
+    # No padding
+    ((2, 8, 16, 16), 2, 2, 0, 1, False),
+    # Non-square padding
+    ((2, 8, 16, 20), 2, 2, (1, 0), 1, False),
+]
+
+
+@pytest.mark.max_pool2d
+@pytest.mark.parametrize(
+    "shape, kernel_size, stride, padding, dilation, ceil_mode", MAXPOOL2D_CONFIGS
+)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_max_pool2d(
+    shape, kernel_size, stride, padding, dilation, ceil_mode, dtype
+):
+    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
+    ref_inp = to_reference(inp, True)
+
+    ref_out, ref_indices = torch.nn.functional.max_pool2d_with_indices(
+        ref_inp,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        ceil_mode=ceil_mode,
+    )
+
+    res_out, res_indices = flag_gems.max_pool2d_with_indices(
+        inp,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        ceil_mode=ceil_mode,
+    )
+
+    gems_assert_close(res_out, ref_out, dtype)
+    gems_assert_equal(res_indices, ref_indices)
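+
+
+# The backward test chains the gems ops the way autograd would: the indices
+# produced by the gems forward are fed straight into the gems backward.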
+@pytest.mark.max_pool2d_backward
+@pytest.mark.parametrize(
+    "shape, kernel_size, stride, padding, dilation, ceil_mode", MAXPOOL2D_CONFIGS
+)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_max_pool2d_backward(
+    shape, kernel_size, stride, padding, dilation, ceil_mode, dtype
+):
+    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
+    ref_inp = to_reference(inp)
+    ref_out, _ = torch.nn.functional.max_pool2d_with_indices(
+        ref_inp,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        ceil_mode=ceil_mode,
+    )
+    out_grad = torch.randn_like(ref_out, device=flag_gems.device)
+    ref_grad = to_reference(out_grad)
+    (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad)
+    _, res_indices = flag_gems.max_pool2d_with_indices(
+        inp,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        ceil_mode=ceil_mode,
+    )
+    res_in_grad = flag_gems.max_pool2d_backward(
+        out_grad,
+        inp,
+        res_indices,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        dilation=dilation,
+        ceil_mode=ceil_mode,
+    )
+
+    gems_assert_close(res_in_grad, ref_in_grad, dtype)
+
+
 SHAPE_CONV1D = [
     ((32, 2, 4), (17, 2, 2)),
     ((32, 15, 6), (17, 15, 2)),