From aab7106d74d428b8af77ef96b28a8c6215becf48 Mon Sep 17 00:00:00 2001
From: zhoubo567 <781266327@qq.com>
Date: Thu, 4 Sep 2025 19:07:56 +0800
Subject: [PATCH 1/5] optimize_mean

---
 src/flag_gems/ops/mean.py | 272 ++++++++++++++++++++++++++++++++------
 1 file changed, 233 insertions(+), 39 deletions(-)

diff --git a/src/flag_gems/ops/mean.py b/src/flag_gems/ops/mean.py
index bf8a1f898..0c0d3bbc7 100644
--- a/src/flag_gems/ops/mean.py
+++ b/src/flag_gems/ops/mean.py
@@ -1,5 +1,6 @@
 import logging
 import math
+from functools import reduce
 
 import torch
 import triton
@@ -21,12 +22,21 @@ def mean_kernel_1(
     M,
     BLOCK_SIZE: tl.constexpr,
 ):
+    # accumulation dtype: promote fp16/bf16 to fp32 to limit rounding error
+    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
+        inp.dtype.element_ty == tl.bfloat16
+    ):
+        cdtype = tl.float32
+    else:
+        cdtype = inp.dtype.element_ty
+
     pid = tle.program_id(0)
     offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     inp_ptrs = inp + offset
     mask = offset < M
-    inp_val = tl.load(inp_ptrs, mask=mask, other=0.0)
-    sum_val = tl.sum(inp_val, axis=0)
+
+    inp_val = tl.load(inp_ptrs, mask=mask, other=0).to(cdtype)
+    sum_val = tl.sum(inp_val)
     mid_ptr = mid + pid
     tl.store(mid_ptr, sum_val)
 
@@ -34,12 +44,21 @@ def mean_kernel_1(
 @libentry()
 @triton.jit
 def mean_kernel_2(mid, out, M, MID_SIZE, BLOCK_MID: tl.constexpr):
+    if tl.constexpr(mid.dtype.element_ty == tl.float16) or tl.constexpr(
+        mid.dtype.element_ty == tl.bfloat16
+    ):
+        cdtype = tl.float32
+    else:
+        cdtype = mid.dtype.element_ty
+
     offset = tl.arange(0, BLOCK_MID)
     mid_ptrs = mid + offset
     mask = offset < MID_SIZE
-    mid_val = tl.load(mid_ptrs, mask=mask, other=0.0)
-    sum_val = tl.sum(mid_val, axis=0) / M
-    tl.store(out, sum_val)
+    mid_val = tl.load(mid_ptrs, mask=mask, other=0).to(cdtype)
+    sum_val = tl.sum(mid_val)
+    # divide by total element count M to get mean
+    mean_val = sum_val / M
+    tl.store(out, mean_val)
 
 
 def mean(inp, *, dtype=None):
@@ -60,57 +79,232 @@ def mean(inp, *, dtype=None):
     return out
 
 
+@libentry()
+@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
+@triton.jit
+def mean_dim_kernel_non_inner(
+    output_ptr,
+    input_ptr,
+    M,
+    N,
+    K,
+    TILE_N: tl.constexpr,
+    TILE_K: tl.constexpr,
+    ONE_TILE_PER_CTA: tl.constexpr,
+):
+    # accumulation dtype: promote fp16/bf16 to fp32 to limit rounding error
+    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
+        input_ptr.dtype.element_ty == tl.bfloat16
+    ):
+        cdtype = tl.float32
+    else:
+        cdtype = input_ptr.dtype.element_ty
+
+    pid_m = tle.program_id(0)
+    pid_k = tle.program_id(1)
+
+    k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)[None, :]
+
+    if ONE_TILE_PER_CTA:
+        n_offsets = tl.arange(0, TILE_N)[:, None]
+        inp_offset = pid_m * N * K + n_offsets * K + k_offsets
+        mask = (n_offsets < N) & (k_offsets < K)
+        input_ptrs = input_ptr + inp_offset
+        inp = tl.load(input_ptrs, mask=mask, other=0).to(cdtype)
+        # sum along the reduction axis (N); keep_dims leaves a [1, TILE_K] row
+        summed = tl.sum(inp, axis=0, keep_dims=True)
+        # divide by N to get mean
+        out = summed / N
+        out_offset = pid_m * K + k_offsets
+        output_ptrs = output_ptr + out_offset
+        tl.store(output_ptrs, out, mask=k_offsets < K)
+    else:
+        sum_tile = tl.zeros([TILE_N, TILE_K], dtype=cdtype)
+        for start_n in range(0, N, TILE_N):
+            n_offsets = start_n + tl.arange(0, TILE_N)[:, None]
+            inp_offsets = pid_m * N * K + n_offsets * K + k_offsets
+            mask = (n_offsets < N) & (k_offsets < K)
+            inp = tl.load(input_ptr + inp_offsets, mask=mask, other=0).to(cdtype)
+            sum_tile += inp
+        summed = tl.sum(sum_tile, axis=0, keep_dims=True)
+        out = summed / N
+        out_offset = pid_m * K + k_offsets
+        output_ptrs = output_ptr + out_offset
+        tl.store(output_ptrs, out, mask=k_offsets < K)
+
+
+@libentry()
+@triton.heuristics(runtime.get_heuristic_config("softmax_inner"))
+@triton.jit
+def mean_dim_kernel_inner(
+    output_ptr,
+    input_ptr,
+    M,
+    N,
+    TILE_N: tl.constexpr,
+    ONE_TILE_PER_CTA: tl.constexpr,
+):
+    if tl.constexpr(input_ptr.dtype.element_ty == tl.float16) or tl.constexpr(
+        input_ptr.dtype.element_ty == tl.bfloat16
+    ):
+        cdtype = tl.float32
+    else:
+        cdtype = input_ptr.dtype.element_ty
+
+    pid_m = tle.program_id(0)
+    if ONE_TILE_PER_CTA:
+        n_offsets = tl.arange(0, TILE_N)
+        inp_offset = pid_m * N + n_offsets
+        input_ptrs = input_ptr + inp_offset
+        mask = n_offsets < N
+        inp = tl.load(input_ptrs, mask=mask, other=0).to(cdtype)
+        summed = tl.sum(inp, axis=0)
+        out = summed / N
+        out_offset = pid_m
+        output_ptrs = output_ptr + out_offset
+        tl.store(output_ptrs, out)
+    else:
+        sum_vec = tl.zeros(
+            [
+                TILE_N,
+            ],
+            dtype=cdtype,
+        )
+        for start_n in range(0, N, TILE_N):
+            n_offsets = start_n + tl.arange(0, TILE_N)
+            inp_offsets = pid_m * N + n_offsets
+            mask = n_offsets < N
+            inp = tl.load(input_ptr + inp_offsets, mask=mask, other=0).to(cdtype)
+            sum_vec += inp
+        summed = tl.sum(sum_vec, axis=0)
+        out = summed / N
+        out_offset = pid_m
+        output_ptrs = output_ptr + out_offset
+        tl.store(output_ptrs, out)
+
+
 @libentry()
 @libtuner(
     configs=runtime.get_tuned_config("naive_reduction"),
     key=["M", "N"],
 )
 @triton.jit
-def mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
-    # Map the program id to the row of X it should compute.
+def mean_dim_kernel(
+    inp,
+    out,
+    M,
+    N,
+    BLOCK_M: tl.constexpr,
+    BLOCK_N: tl.constexpr,
+):
+    if tl.constexpr(inp.dtype.element_ty == tl.float16) or tl.constexpr(
+        inp.dtype.element_ty == tl.bfloat16
+    ):
+        cdtype = tl.float32
+    else:
+        cdtype = inp.dtype.element_ty
+
+    # Map the program id to the row of inp it should compute.
     pid = tle.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None]
-    X = X + pid * N
-    Mean = Mean + pid
+    inp = inp + pid * N
+    out = out + pid
     row_mask = pid < M
 
-    # Compute mean
-    _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
+    _sum = tl.zeros([BLOCK_M, BLOCK_N], dtype=cdtype)
     for off in range(0, N, BLOCK_N):
         cols = off + tl.arange(0, BLOCK_N)[None, :]
         col_mask = cols < N
         mask = row_mask and col_mask
-        a = tl.load(X + cols, mask, other=0.0).to(tl.float32)
-        _mean += a
-    mean = tl.sum(_mean, axis=1) / N
-    mean = mean[:, None]
-    tl.store(Mean, mean, row_mask)
-
+        a = tl.load(inp + cols, mask, other=0).to(cdtype)
+        _sum += a
+    summed = tl.sum(_sum, axis=1)[:, None]
+    mean = summed / N
+    tl.store(out, mean, row_mask)
 
-def mean_dim(x, dim, keepdim=False, *, dtype=None):
-    logger.debug("GEMS MEAN DIM")
+
+def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
+    logger.debug("GEMS MEAN_DIM")
     if dtype is None:
-        dtype = x.dtype
-    if dim is None:
-        out = mean(x, dtype=dtype)
+        dtype = inp.dtype
+    if dtype is torch.bool:
+        inp = inp.to(torch.int64)
+        dtype = torch.int64
+
+    if dim is None or dim == []:
+        # mean over all elements (dim=None or dim=[])
+        if not keepdim:
+            return mean(inp, dtype=dtype)
+        else:
+            dim_num = inp.ndim
+            return torch.reshape(mean(inp, dtype=dtype), [1] * dim_num)
+
+    shape = list(inp.shape)
+
+    # -------- normalize dim to a list of ints --------
+    if isinstance(dim, int):
+        dim = [dim]
+    else:
+        try:
+            dim = list(dim)
+        except TypeError:
+            raise TypeError(
+                f"dim must be None, an int, or an iterable of ints, got {type(dim)}"
+            )
+
+    dim = [d % inp.ndim for d in dim]
+    # -------------------------------------------------
+
+    if len(dim) == 1:
+        dim0 = dim[0]
+        N = inp.shape[dim0]  # reduction length
+        # product of dims before dim0; use initializer 1 for empty slice
+        M = reduce(lambda x, y: x * y, shape[:dim0], 1)
+        inp = inp.contiguous()
+        K = inp.numel() // M // N
+        shape[dim0] = 1
+        if out is None:
+            out = torch.empty(shape, dtype=dtype, device=inp.device)
+
+        with torch_device_fn.device(inp.device):
+            if K > 1:
+                grid = lambda meta: (M, triton.cdiv(K, meta["TILE_K"]), 1)
+                mean_dim_kernel_non_inner[grid](
+                    out,
+                    inp,
+                    M,
+                    N,
+                    K,
+                )
+            else:
+                grid = (M, 1, 1)
+                mean_dim_kernel_inner[grid](
+                    out,
+                    inp,
+                    M,
+                    N,
+                )
         if not keepdim:
-            out = out.reshape([1] * x.ndim)
+            out = out.squeeze(dim=dim0)
         return out
+    else:
+        inp = dim_compress(inp, dim)
+        N = 1
+        for i in dim:
+            N *= shape[i]
+            shape[i] = 1
+        M = inp.numel() // N
+        if out is None:
+            out = torch.empty(shape, dtype=dtype, device=inp.device)
 
-    shape = list(x.shape)
-    dim = [d % x.ndim for d in dim]
-    x = dim_compress(x, dim)
-    N = 1
-    for i in dim:
-        N *= shape[i]
-        shape[i] = 1
-    M = x.numel() // N
-    out = torch.empty(shape, dtype=dtype, device=x.device)
-    grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]),)
-
-    with torch_device_fn.device(x.device):
-        mean_dim_kernel[grid](x, out, M, N)
-    if not keepdim:
-        out = out.squeeze(dim)
-    return out
+        grid = lambda meta: (triton.cdiv(M, meta["BLOCK_M"]),)
+        with torch_device_fn.device(inp.device):
+            mean_dim_kernel[grid](inp, out, M, N)
+        if not keepdim:
+            out = out.squeeze(dim=dim)
+        return out
+
+
+def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
+    logger.debug("GEMS MEAN_DIM (wrapper)")
+    return mean_dim_comm(inp, dim, keepdim, dtype=dtype)

From c554fe5058c8aa06bfa9468799c5c73e9b510e90 Mon Sep 17 00:00:00 2001
From: henghengxiedaima <1149963331@qq.com>
Date: Wed, 17 Sep 2025 22:42:05 +0800
Subject: [PATCH 2/5] merge local changes with remote

---
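Note (illustration only, not part of the applied diff): mean_dim_comm splits
a single-dim reduction of a contiguous tensor into M rows before the reduced
dim, N reduced elements, and K columns after it, then dispatches to the inner
kernel when K == 1 and the non-inner kernel otherwise. A minimal sketch of
that bookkeeping (hypothetical helper name, plain Python):

    from functools import reduce

    def mnk_for(shape, dim0):
        # M = product of dims before dim0, N = reduced length,
        # K = product of dims after dim0 (i.e. numel // M // N)
        M = reduce(lambda x, y: x * y, shape[:dim0], 1)
        N = shape[dim0]
        numel = reduce(lambda x, y: x * y, shape, 1)
        K = numel // M // N
        return M, N, K

    print(mnk_for((8, 1024, 16), 1))  # (8, 1024, 16) -> K > 1, non-inner kernel
    print(mnk_for((8, 1024, 16), 2))  # (8192, 16, 1) -> K == 1, inner kernel
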
 src/flag_gems/ops/mean.py                          |  4 +-
 .../_nvidia/heuristics_config_utils.py             | 47 +++++++++++++++++++
 2 files changed, 49 insertions(+), 2 deletions(-)

diff --git a/src/flag_gems/ops/mean.py b/src/flag_gems/ops/mean.py
index 0c0d3bbc7..bca5978e9 100644
--- a/src/flag_gems/ops/mean.py
+++ b/src/flag_gems/ops/mean.py
@@ -80,7 +80,7 @@ def mean(inp, *, dtype=None):
 
 
 @libentry()
-@triton.heuristics(runtime.get_heuristic_config("softmax_non_inner"))
+@triton.heuristics(runtime.get_heuristic_config("mean_non_inner"))
 @triton.jit
 def mean_dim_kernel_non_inner(
     output_ptr,
@@ -307,4 +307,4 @@ def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
 
 def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
     logger.debug("GEMS MEAN_DIM (wrapper)")
-    return mean_dim_comm(inp, dim, keepdim, dtype=dtype)
+    return mean_dim_comm(inp, dim, keepdim, dtype=dtype)
\ No newline at end of file
diff --git a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py
index 326c5d3d2..017a863e7 100644
--- a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py
+++ b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py
@@ -2,6 +2,11 @@
 import triton
 
 
+_MIN_TILE_N = 64
+_MAX_TILE_N_PER_ROW = 4096
+_MAX_ONE_TILE_N = 2048
+
+
 def simple_elementwise_blocksize_heur(args):
     return 1024
 
@@ -232,6 +237,42 @@ def vdot_heur_block_size(args):
     return 1024
 
 
+def mean_heur_tile_k(args):
+    MAX_TILE_K = 512
+    NUM_SMS = torch.cuda.get_device_properties(
+        torch.cuda.current_device()
+    ).multi_processor_count
+    tile_k = 1
+    upper_bound = min(args["K"], MAX_TILE_K)
+    max_tile_k_allowed_by_tile_n = max(1, _MAX_TILE_N_PER_ROW // _MIN_TILE_N)
+    upper_bound = min(upper_bound, max_tile_k_allowed_by_tile_n)
+    while tile_k <= upper_bound:
+        num_blocks = args["M"] * triton.cdiv(args["K"], tile_k)
+        num_waves = num_blocks / NUM_SMS
+        if (num_waves > 1) and (tile_k * 2 <= upper_bound):
+            tile_k *= 2
+        else:
+            break
+    return tile_k
+
+
+def mean_heur_tile_n_non_inner(args):
+    tile_k = args.get("TILE_K", 1)
+    limit_by_k = max(1, _MAX_TILE_N_PER_ROW // tile_k)
+    N = args.get("N", 1)
+    desired = min(max(N, _MIN_TILE_N), limit_by_k)
+    desired = min(desired, _MAX_ONE_TILE_N, limit_by_k)
+    tile_n = triton.next_power_of_2(desired)
+    if tile_n > limit_by_k:
+        tile_n = limit_by_k
+    tile_n = max(tile_n, _MIN_TILE_N)
+    return tile_n
+
+
+def mean_heur_one_tile_per_cta(args):
+    return args["TILE_N"] >= args["N"]
+
+
 HEURISTICS_CONFIGS = {
     "argmax": {
         "BLOCK_M": argmax_heur_block_m,
@@ -279,6 +320,12 @@ def vdot_heur_block_size(args):
         "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,
         "num_warps": softmax_heur_num_warps_non_inner,
     },
+    "mean_non_inner": {
+        "TILE_K": mean_heur_tile_k,
+        "TILE_N": mean_heur_tile_n_non_inner,
+        "ONE_TILE_PER_CTA": mean_heur_one_tile_per_cta,
+        "num_warps": softmax_heur_num_warps_non_inner,
+    },
     "softmax_inner": {
         "TILE_N": softmax_heur_tile_n_inner,
         "ONE_TILE_PER_CTA": softmax_heur_one_tile_per_cta,

From 14a1235c96d8a47839d30bb29b67de34fcb2db07 Mon Sep 17 00:00:00 2001
From: you-and-you <1823382186@qq.com>
Date: Tue, 25 Nov 2025 16:13:40 +0800
Subject: [PATCH 3/5] pre-commit

---
 src/flag_gems/ops/mean.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/flag_gems/ops/mean.py b/src/flag_gems/ops/mean.py
index bca5978e9..469e32c8e 100644
--- a/src/flag_gems/ops/mean.py
+++ b/src/flag_gems/ops/mean.py
@@ -307,4 +307,4 @@ def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
 
 def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
     logger.debug("GEMS MEAN_DIM (wrapper)")
-    return mean_dim_comm(inp, dim, keepdim, dtype=dtype)
\ No newline at end of file
+    return mean_dim_comm(inp, dim, keepdim, dtype=dtype)

From 01bf0d0b9bbc0ac8ad339eadcb89713f37fa0b61 Mon Sep 17 00:00:00 2001
From: you-and-you <1823382186@qq.com>
Date: Wed, 26 Nov 2025 10:48:07 +0800
Subject: [PATCH 4/5] update

---
 src/flag_gems/ops/mean.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/flag_gems/ops/mean.py b/src/flag_gems/ops/mean.py
index 469e32c8e..ffbbfd303 100644
--- a/src/flag_gems/ops/mean.py
+++ b/src/flag_gems/ops/mean.py
@@ -307,4 +307,5 @@ def mean_dim_comm(inp, dim=None, keepdim=False, *, dtype=None, out=None):
 
 def mean_dim(inp, dim=None, keepdim=False, *, dtype=None):
     logger.debug("GEMS MEAN_DIM (wrapper)")
+
     return mean_dim_comm(inp, dim, keepdim, dtype=dtype)

From 27973cbf749f01d47caea841248a4dcd0ad4e6ca Mon Sep 17 00:00:00 2001
From: you-and-you <1823382186@qq.com>
Date: Fri, 28 Nov 2025 16:05:50 +0800
Subject: [PATCH 5/5] pre-commit

---
 .../_nvidia/heuristics_config_utils.py             |  1 -
 .../coverage_diff_discard-checkpoint.py            | 79 +++++++++++++++++++
 tools/code_coverage/coverage_diff_discard.py       |  4 +-
 3 files changed, 81 insertions(+), 3 deletions(-)
 create mode 100644 tools/code_coverage/.ipynb_checkpoints/coverage_diff_discard-checkpoint.py

diff --git a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py
index 017a863e7..320ad0aed 100644
--- a/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py
+++ b/src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py
@@ -1,7 +1,6 @@
 import torch
 import triton
 
-
 _MIN_TILE_N = 64
 _MAX_TILE_N_PER_ROW = 4096
 _MAX_ONE_TILE_N = 2048
diff --git a/tools/code_coverage/.ipynb_checkpoints/coverage_diff_discard-checkpoint.py b/tools/code_coverage/.ipynb_checkpoints/coverage_diff_discard-checkpoint.py
new file mode 100644
index 000000000..bbcd09894
--- /dev/null
+++ b/tools/code_coverage/.ipynb_checkpoints/coverage_diff_discard-checkpoint.py
@@ -0,0 +1,79 @@
+#!/usr/bin/env python
+
+import os
+import re
+import sys
+
+
+def get_discard_file_lines(discard_file):
+    flag_gems_root = os.environ.get("FlagGemsROOT")
+    dicard_file_lines = {}
+    with open(discard_file) as f:
+        for line in f:
+            line = line.strip()
+
+            if line.startswith(flag_gems_root + "/"):
+                current_file = line[len(flag_gems_root) + 1 :]
+                dicard_file_lines[current_file] = []
+                continue
+
+            elif line.startswith("--- "):
+                pattern = r"(\d+) : (\d+)"
+                match = re.search(pattern, line)
+                if match:
+                    start, end = map(int, match.groups())
+                    # Note: Use (start + 1) instead of start
+                    # Because we take the definition of the JIT function into account
+                    for i in range(start + 1, end + 1):
+                        dicard_file_lines[current_file].append(i)
+    return dicard_file_lines
+
+
+def get_info_file_lines(info_file, discard_file):
+    discard_file_lines = get_discard_file_lines(discard_file)
+    discard_lines = []
+    num_rm_lines = 0
+    base_path = os.environ.get("FlagGemsROOT") + "/"
+
+    with open(info_file) as f:
+        for line in f:
+            line = line.strip()
+            if line.startswith("SF:"):
+                num_rm_lines = 0
+                current_file = line[3:]
+                if current_file.startswith(base_path):
+                    current_file = current_file[len(base_path) :]
+                discard_lines = discard_file_lines.get(current_file, [])
+            elif line.startswith("DA:"):
+                da = line[3:].split(",")
+                if int(da[0]) in discard_lines:
+                    num_rm_lines -= 1
+                    continue
+                else:
+                    print(line)
+                    continue
+            elif line.startswith("LF:"):
+                lf = line.split(":")
+                print(f"LF:{int(lf[1]) + num_rm_lines}")
+                continue
+            elif line.startswith("LH:"):
+                lh = line.split(":")
+                print(f"LH:{int(lh[1]) + num_rm_lines}")
+                continue
+            print(line)
+
+
+if __name__ == "__main__":
+    if len(sys.argv) < 3:
+        print(
+            "usage: coverage_diff.py info_file discard_file > python-coverage-discard-diff.info"
+        )
+        sys.exit(1)
+
+    info_file, discard_file = sys.argv[1], sys.argv[2]
+
+    if not (os.path.isfile(info_file) or os.path.isfile(discard_file)):
+        print("Both info_file and discard_file must exist.")
+        sys.exit(1)
+
+    get_info_file_lines(info_file, discard_file)
diff --git a/tools/code_coverage/coverage_diff_discard.py b/tools/code_coverage/coverage_diff_discard.py
index 2e915add5..bbcd09894 100644
--- a/tools/code_coverage/coverage_diff_discard.py
+++ b/tools/code_coverage/coverage_diff_discard.py
@@ -54,11 +54,11 @@ def get_info_file_lines(info_file, discard_file):
                     continue
             elif line.startswith("LF:"):
                 lf = line.split(":")
-                print(f"LF:{ int(lf[1])+ num_rm_lines}")
+                print(f"LF:{int(lf[1]) + num_rm_lines}")
                 continue
             elif line.startswith("LH:"):
                 lh = line.split(":")
-                print(f"LH:{ int(lh[1])+ num_rm_lines}")
+                print(f"LH:{int(lh[1]) + num_rm_lines}")
                 continue
             print(line)
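
Usage sketch (illustration only, not part of the series): the paths added
above are reached through the torch API once the patched FlagGems build is
enabled. flag_gems.use_gems() is assumed here to be the public context
manager (substitute your build's enable call if it differs), and the
tolerances are illustrative:

    import torch
    import flag_gems  # assumes a build with this series applied

    x = torch.randn(8, 1024, 16, device="cuda", dtype=torch.float16)

    with flag_gems.use_gems():
        y_all = torch.mean(x)            # two-stage mean_kernel_1 / mean_kernel_2
        y_inner = torch.mean(x, dim=-1)  # K == 1 -> mean_dim_kernel_inner
        y_non = torch.mean(x, dim=1)     # K > 1  -> mean_dim_kernel_non_inner
        y_multi = torch.mean(x, dim=[0, 2], keepdim=True)  # naive_reduction path

    # compare against an fp32 reference, cast back to the output dtype
    ref = x.float().mean(dim=1).half()
    torch.testing.assert_close(y_non, ref, rtol=1e-3, atol=1e-3)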
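The TILE_K choice in mean_heur_tile_k trades tile width against keeping every
SM busy: TILE_K doubles only while the launch still produces more than one
full wave of blocks. A self-contained mock of the same loop (hypothetical
name, SM count passed in rather than queried via torch.cuda):

    import math

    def mock_mean_heur_tile_k(M, K, num_sms, max_tile_k=512,
                              max_tile_n_per_row=4096, min_tile_n=64):
        tile_k = 1
        # cap by K, an absolute maximum, and the per-row TILE_N budget
        upper_bound = min(K, max_tile_k, max(1, max_tile_n_per_row // min_tile_n))
        while tile_k <= upper_bound:
            num_blocks = M * math.ceil(K / tile_k)  # grid is (M, cdiv(K, TILE_K))
            num_waves = num_blocks / num_sms
            if num_waves > 1 and tile_k * 2 <= upper_bound:
                tile_k *= 2
            else:
                break
        return tile_k

    # e.g. M=8, K=4096 on a 108-SM GPU: 32768 blocks at TILE_K=1 is ~303
    # waves, so TILE_K doubles until it hits the 4096 // 64 = 64 cap.
    print(mock_mean_heur_tile_k(8, 4096, 108))  # -> 64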