diff --git a/src/flag_gems/ops/mean.py b/src/flag_gems/ops/mean.py index bf8a1f898..bca5978e9 100644 --- a/src/flag_gems/ops/mean.py +++ b/src/flag_gems/ops/mean.py @@ -1,5 +1,6 @@ import logging import math +from functools import reduce import torch import triton @@ -21,12 +22,21 @@ def mean_kernel_1( M, BLOCK_SIZE: tl.constexpr, ): + # accumulation dtype + if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr( + inp.dtype.element_ty == tl.bfloat16 + ): + cdtype = tl.float32 + else: + cdtype = inp.dtype.element_ty + pid = tle.program_id(0) offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) inp_ptrs = inp + offset mask = offset < M - inp_val = tl.load(inp_ptrs, mask=mask, other=0.0) - sum_val = tl.sum(inp_val, axis=0) + + inp_val = tl.load(inp_ptrs, mask=mask, other=0).to(cdtype) + sum_val = tl.sum(inp_val) mid_ptr = mid + pid tl.store(mid_ptr, sum_val) @@ -34,12 +44,21 @@ def mean_kernel_1( @libentry() @triton.jit def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr): + if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr( + mid.dtype.element_ty == tl.bfloat16 + ): + cdtype = tl.float32 + else: + cdtype = mid.dtype.element_ty + offset = tl.arange(0, BLOCK_MID) mid_ptrs = mid + offset mask = offset < MID_SIZE - mid_val = tl.load(mid_ptrs, mask=mask, other=0.0) - sum_val = tl.sum(mid_val, axis=0) / M - tl.store(out, sum_val) + mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(cdtype) + sum_val = tl.sum(mid_val) + # divide by total element count M to get mean + mean_val = sum_val / M + tl.store(out, mean_val) def mean(inp, *, dtype=None): @@ -60,57 +79,232 @@ def mean(inp, *, dtype=None): return out +@libentry() +@triton.heuristics(runtime.get_heuristic_config("mean_non_inner")) +@triton.jit +def mean_dim_kernel_non_inner( + output_ptr, + input_ptr, + M, + N, + K, + TILE_N: tl.constexpr, + TILE_K: tl.constexpr, + ONE_TILE_PER_CTA: tl.constexpr, +): + # accumulation dtype + if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr( + input_ptr.dtype.element_ty == tl.bfloat16 + ): + cdtype = tl.float32 + else: + cdtype = input_ptr.dtype.element_ty + + pid_m = tle.program_id(0) + pid_k = tle.program_id(1) + + k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :] + + if ONE_TILE_PER_CTA: + n_offsets = tl.arange(0, TILE_N)[:, None] + inp_offset = pid_m * N * K + n_offsets * K + k_offsets + mask = (n_offsets < N) & (k_offsets < K) + input_ptrs = input_ptr + inp_offset + inp = tl.load(input_ptrs, mask=mask, other=0).to(cdtype) + # sum along reduction axis (N) -> keep dims so axis 0 corresponds to TILE_K + summed = tl.sum(inp, axis=0, keep_dims=True) + # divide by N to get mean + out = summed / N + out_offset = pid_m * K + k_offsets + output_ptrs = output_ptr + out_offset + tl.store(output_ptrs, out, mask=k_offsets < K) + else: + sum_tile = tl.zeros([TILE_N, TILE_K], dtype=cdtype) + for start_n in range(0, N, TILE_N): + n_offsets = start_n + tl.arange(0, TILE_N)[:, None] + inp_offsets = pid_m * N * K + n_offsets * K + k_offsets + mask = (n_offsets < N) & (k_offsets < K) + inp = tl.load(input_ptr + inp_offsets, mask=mask, other=0).to(cdtype) + sum_tile += inp + summed = tl.sum(sum_tile, axis=0, keep_dims=True) + out = summed / N + out_offset = pid_m * K + k_offsets + output_ptrs = output_ptr + out_offset + tl.store(output_ptrs, out, mask=k_offsets < K) + + +@libentry() +@triton.heuristics(runtime.get_heuristic_config("softmax_inner")) +@triton.jit +def mean_dim_kernel_inner( + output_ptr, + input_ptr, + M, + N, + TILE_N: tl.constexpr, + ONE_TILE_PER_CTA: tl.constexpr, +): + if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr( + input_ptr.dtype.element_ty == tl.bfloat16 + ): + cdtype = tl.float32 + else: + cdtype = input_ptr.dtype.element_ty + + pid_m = tle.program_id(0) + if ONE_TILE_PER_CTA: + n_offsets = tl.arange(0, TILE_N) + inp_offset = pid_m * N + n_offsets + input_ptrs = input_ptr + inp_offset + mask = n_offsets < N + inp = tl.load(input_ptrs, mask=mask, other=0).to(cdtype) + summed = tl.sum(inp, axis=0) + out = summed / N + out_offset = pid_m + output_ptrs = output_ptr + out_offset + tl.store(output_ptrs, out) + else: + sum_vec = tl.zeros( + [ + TILE_N, + ], + dtype=cdtype, + ) + for start_n in range(0, N, TILE_N): + n_offsets = start_n + tl.arange(0, TILE_N) + inp_offsets = pid_m * N + n_offsets + mask = n_offsets < N + inp = tl.load(input_ptr + inp_offsets, mask=mask, other=0).to(cdtype) + sum_vec += inp + summed = tl.sum(sum_vec, axis=0) + out = summed / N + out_offset = pid_m + output_ptrs = output_ptr + out_offset + tl.store(output_ptrs, out) + + @libentry() @libtuner( configs=runtime.get_tuned_config("naive_reduction"), key=["M", "N"], ) @triton.jit -def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): - # Map the program id to the row of X it should compute. +def mean_dim_kernel( + inp, + out, + M, + N, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr( + inp.dtype.element_ty == tl.bfloat16 + ): + cdtype = tl.float32 + else: + cdtype = inp.dtype.element_ty + + # Map the program id to the row of inp it should compute. pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] - X = X + pid * N - Mean = Mean + pid + inp = inp + pid * N + out = out + pid row_mask = pid < M - # Compute mean - _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype) for off in range(0, N, BLOCK_N): cols = off + tl.arange(0, BLOCK_N)[None, :] col_mask = cols < N mask = row_mask and col_mask - a = tl.load(X + cols, mask, other=0.0).to(tl.float32) - _mean += a - mean = tl.sum(_mean, axis=1) / N - mean = mean[:, None] - tl.store(Mean, mean, row_mask) - + a = tl.load(inp + cols, mask, other=0).to(cdtype) + _sum += a + summed = tl.sum(_sum, axis=1)[:, None] + mean = summed / N + tl.store(out, mean, row_mask) -def mean_dim(x, dim, keepdim=False, *, dtype=None): - logger.debug("GEMS MEAN DIM") +def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None): + logger.debug("GEMS MEAN_DIM") if dtype is None: - dtype = x.dtype - if dim is None: - out = mean(x, dtype=dtype) + dtype = inp.dtype + if dtype is torch.bool: + inp = inp.to(torch.int64) + dtype = torch.int64 + + if dim == []: + # mean over all elements + if not keepdim: + return mean(inp, dtype=dtype) + else: + dim_num = inp.ndim + return torch.reshape(mean(inp, dtype=dtype), [1] * dim_num) + + shape = list(inp.shape) + + # -------- normalize dim to a list of ints -------- + if isinstance(dim, int): + dim = [dim] + else: + try: + dim = list(dim) + except TypeError: + raise TypeError( + f"dim must be an int, iterable of ints, or [], got {type(dim)}" + ) + + dim = [d % inp.ndim for d in dim] + # ------------------------------------------------- + + if len(dim) == 1: + dim0 = dim[0] + N = inp.shape[dim0] # reduction length + # product of dims before dim0; use initializer 1 for empty slice + M = reduce(lambda x, y: x * y, shape[:dim0], 1) + inp = inp.contiguous() + K = inp.numel() // M // N + shape[dim0] = 1 + if out is None: + out = torch.empty(shape, dtype=dtype, device=inp.device) + + with torch_device_fn.device(inp.device): + if K > 1: + grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1) + mean_dim_kernel_non_inner[grid]( + out, + inp, + M, + N, + K, + ) + else: + grid = (M, 1, 1) + mean_dim_kernel_inner[grid]( + out, + inp, + M, + N, + ) if not keepdim: - out = out.reshape([1] * x.ndim) + out = out.squeeze(dim=dim0) return out + else: + inp = dim_compress(inp, dim) + N = 1 + for i in dim: + N *= shape[i] + shape[i] = 1 + M = inp.numel() // N + if out is None: + out = torch.empty(shape, dtype=dtype, device=inp.device) - shape = list(x.shape) - dim = [d % x.ndim for d in dim] - x = dim_compress(x, dim) - N = 1 - for i in dim: - N *= shape[i] - shape[i] = 1 - M = x.numel() // N - out = torch.empty(shape, dtype=dtype, device=x.device) - grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),) - - with torch_device_fn.device(x.device): - mean_dim_kernel[grid](x, out, M, N) - if not keepdim: - out = out.squeeze(dim) - return out + grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),) + with torch_device_fn.device(inp.device): + mean_dim_kernel[grid](inp, out, M, N) + if not keepdim: + out = out.squeeze(dim=dim) + return out + + +def mean_dim(inp, dim=None, keepdim=False, *, dtype=None): + logger.debug("GEMS MEAN_DIM (wrapper)") + return mean_dim_comm(inp, dim, keepdim, dtype=dtype) \ No newline at end of file diff --git a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py index 326c5d3d2..017a863e7 100644 --- a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py +++ b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py @@ -2,6 +2,11 @@ import triton +_MIN_TILE_N = 64 +_MAX_TILE_N_PER_ROW = 4096 +_MAX_ONE_TILE_N = 2048 + + def simple_elementwise_blocksize_heur(args): return 1024 @@ -232,6 +237,42 @@ def vdot_heur_block_size(args): return 1024 +def mean_heur_tile_k(args): + MAX_TILE_K = 512 + NUM_SMS = torch.cuda.get_device_properties( + torch.cuda.current_device() + ).multi_processor_count + tile_k = 1 + upper_bound = min(args["K"], MAX_TILE_K) + max_tile_k_allowed_by_tile_n = max(1, _MAX_TILE_N_PER_ROW // _MIN_TILE_N) + upper_bound = min(upper_bound, max_tile_k_allowed_by_tile_n) + while tile_k <= upper_bound: + num_blocks = args["M"] * triton.cdiv(args["K"], tile_k) + num_waves = num_blocks / NUM_SMS + if (num_waves > 1) and (tile_k * 2 <= upper_bound): + tile_k *= 2 + else: + break + return tile_k + + +def mean_heur_tile_n_non_inner(args): + tile_k = args.get("TILE_K", 1) + limit_by_k = max(1, _MAX_TILE_N_PER_ROW // tile_k) + N = args.get("N", 1) + desired = min(max(N, _MIN_TILE_N), limit_by_k) + desired = min(desired, _MAX_ONE_TILE_N, limit_by_k) + tile_n = triton.next_power_of_2(desired) + if tile_n > limit_by_k: + tile_n = limit_by_k + tile_n = max(tile_n, _MIN_TILE_N) + return tile_n + + +def mean_heur_one_tile_per_cta(args): + return args["TILE_N"] >= args["N"] + + HEURISTICS_CONFIGS = { "argmax": { "BLOCK_M": argmax_heur_block_m, @@ -279,6 +320,12 @@ def vdot_heur_block_size(args): "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta, "num_warps": softmax_heur_num_warps_non_inner, }, + "mean_non_inner": { + "TILE_K": mean_heur_tile_k, + "TILE_N": mean_heur_tile_n_non_inner, + "ONE_TILE_PER_CTA": mean_heur_one_tile_per_cta, + "num_warps": softmax_heur_num_warps_non_inner, + }, "softmax_inner": { "TILE_N": softmax_heur_tile_n_inner, "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,