diff --git a/.github/workflows/pr-perfbench-bot.yml b/.github/workflows/pr-perfbench-bot.yml
index e6954bcc4..868b5b51c 100644
--- a/.github/workflows/pr-perfbench-bot.yml
+++ b/.github/workflows/pr-perfbench-bot.yml
@@ -6,7 +6,7 @@ on:
     - created

 permissions:
-  contents: read
+  contents: write

 concurrency:
   group: "${{ github.workflow }}-${{ github.ref }}"
@@ -53,8 +53,12 @@ jobs:
      run: |
        python -m venv tll
        source tll/bin/activate
-       pip install -r requirements-test.txt
-       pip install .
+       export PIP_CONFIG_FILE=/dev/null
+       export PYTHONUSERBASE=""
+       pip config unset global.user
+       pip config unset user.user
+       pip install --no-user -r requirements-test.txt
+       pip install --no-user .

    - name: Install original version
      run: |
@@ -64,25 +68,70 @@ jobs:
        git checkout main
        python -m venv tl
        source tl/bin/activate
-       pip install -r requirements-test.txt
-       pip install .
+       export PIP_CONFIG_FILE=/dev/null
+       export PYTHONUSERBASE=""
+       pip config unset global.user || true
+       pip config unset user.user || true
+       pip install --no-user -r requirements-test.txt
+       pip install --no-user .

    - name: Run performance test
      id: perfbench
      run: |
        source tl/bin/activate
        python maint/scripts/ci_performance.py
+   - name: Read markdown table
+     id: read_md
+     run: |
+       echo "content<<EOF" >> $GITHUB_OUTPUT
+       cat bench.md >> $GITHUB_OUTPUT
+       echo "EOF" >> $GITHUB_OUTPUT
+   - name: Upload PNG to GitHub and get URL
+     id: upload_png
+     uses: actions/github-script@v8
+     with:
+       github-token: ${{ secrets.GITHUB_TOKEN }}
+       script: |
+         const fs = require('fs');
+         const content = fs.readFileSync('bench.png').toString('base64');
+         // Create blob in the repo
+         const blob = await github.rest.git.createBlob({
+           owner: context.repo.owner,
+           repo: context.repo.repo,
+           content: content,
+           encoding: "base64",
+         });
+         // Attach blob as a tree item
+         const tree = await github.rest.git.createTree({
+           owner: context.repo.owner,
+           repo: context.repo.repo,
+           tree: [{
+             path: `bench_${context.runId}.png`,
+             mode: '100644',
+             type: 'blob',
+             sha: blob.data.sha
+           }]
+         });
+         // Raw file URL (works for embedding image)
+         const url = `https://raw.githubusercontent.com/${context.repo.owner}/${context.repo.repo}/${tree.data.sha}/bench_${context.runId}.png`
+         core.setOutput("url", url);

    - name: Post test results as PR comment
      uses: actions/github-script@v8
      with:
        github-token: ${{ secrets.GITHUB_TOKEN }}
        script: |
+         const md = `${{ steps.read_md.outputs.content }}`;
+         const img = `${{ steps.upload_png.outputs.url }}`;
          github.rest.issues.createComment({
            owner: context.repo.owner,
            repo: context.repo.repo,
            issue_number: context.issue.number,
-           body: 'šŸ“Š **Performance Test Results** (triggered by @' + context.payload.comment.user.login + '):\n\n' +
-                 'Run listed here: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}\n\n' +
-                 "${{ steps.perfbench.outputs.stdout }}"
+           body:
+             'šŸ“Š **Performance Test Results** (triggered by @' +
+             context.payload.comment.user.login + ')\n\n' +
+             'Run: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}\n\n' +
+             md +
+             '\n\nšŸ“ˆ **Speedup Plot:**\n\n' +
+             `![Speedup Plot](${img})`
          })
diff --git a/examples/attention_sink/bench_example_attention_sink.py b/examples/attention_sink/bench_example_attention_sink.py
new file mode 100644
index 000000000..c38afd3ed
--- /dev/null
+++ b/examples/attention_sink/bench_example_attention_sink.py
@@ -0,0 +1,65 @@
+import tilelang.tools.bench
+import example_mha_sink_fwd_bhsd
+import example_mha_sink_fwd_bhsd_wgmma_pipelined
+import example_mha_sink_bwd_bhsd
+import example_gqa_sink_bwd_bhsd
+import example_gqa_sink_fwd_bhsd_wgmma_pipelined
+
+
+def bench_example_mha_sink_fwd_bhsd():
+    tilelang.tools.bench.process_func(example_mha_sink_fwd_bhsd.benchmark)
+
+
+def bench_example_mha_sink_fwd_bhsd_sliding_window():
+    tilelang.tools.bench.process_func(
+        example_mha_sink_fwd_bhsd.benchmark,
+        name="example_mha_sink_fwd_bhsd_sliding_window",
+        window_size=128)
+
+
+def bench_example_mha_sink_fwd_bhsd_wgmma_pipelined():
+    tilelang.tools.bench.process_func(example_mha_sink_fwd_bhsd_wgmma_pipelined.benchmark)
+
+
+def bench_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window():
+    tilelang.tools.bench.process_func(
+        example_mha_sink_fwd_bhsd_wgmma_pipelined.benchmark,
+        name="example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window",
+        window_size=128)
+
+
+def bench_example_gqa_sink_fwd_bhsd_wgmma_pipelined():
+    tilelang.tools.bench.process_func(example_gqa_sink_fwd_bhsd_wgmma_pipelined.benchmark)
+
+
+def bench_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window():
+    tilelang.tools.bench.process_func(
+        example_gqa_sink_fwd_bhsd_wgmma_pipelined.benchmark,
+        name="example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window",
+        window_size=128)
+
+
+def bench_example_mha_sink_bwd_bhsd():
+    tilelang.tools.bench.process_func(example_mha_sink_bwd_bhsd.benchmark)
+
+
+def bench_example_mha_sink_bwd_bhsd_sliding_window():
+    tilelang.tools.bench.process_func(
+        example_mha_sink_bwd_bhsd.benchmark,
+        name="example_mha_sink_bwd_bhsd_sliding_window",
+        window_size=128)
+
+
+def bench_example_gqa_sink_bwd_bhsd():
+    tilelang.tools.bench.process_func(example_gqa_sink_bwd_bhsd.benchmark)
+
+
+def bench_example_gqa_sink_bwd_bhsd_sliding_window():
+    tilelang.tools.bench.process_func(
+        example_gqa_sink_bwd_bhsd.benchmark,
+        name="example_gqa_sink_bwd_bhsd_sliding_window",
+        window_size=128)
+
+
+if globals().get("__name__") == "__main__":
+    tilelang.tools.bench.main()
diff --git a/examples/attention_sink/benchmark_gqa_sink_fwd.py b/examples/attention_sink/benchmark_gqa_sink_fwd.py
index 1b7de6b6f..bdc8a9a3a 100644
--- a/examples/attention_sink/benchmark_gqa_sink_fwd.py
+++ b/examples/attention_sink/benchmark_gqa_sink_fwd.py
@@ -196,6 +196,41 @@ def main(
     print("Speedup: {:.2f}x".format(latency_triton / latency_tilelang))


+def benchmark(
+        batch: int = 1,
+        heads: int = 32,
+        seq_q: int = 256,
+        seq_kv: int = 256,
+        dim: int = 128,
+        groups: int = 8,
+        window_size: Optional[int] = None,
+        dtype: str = "float16",
+        tune: bool = False,
+):
+    torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype]
+    block_M = 128
+    block_N = 128
+    num_stages = 2
+    threads = 256
+    kernel = flashattn(
+        batch,
+        heads,
+        seq_q,
+        seq_kv,
+        dim,
+        groups,
+        window_size,
+        block_M=block_M,
+        block_N=block_N,
+        num_stages=num_stages,
+        threads=threads,
+        dtype=dtype)
+
+    Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype)
+    latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500, rep=10000)
+    return latency
+
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--batch', type=int, default=1, help='batch size')
diff --git a/examples/attention_sink/benchmark_mha_sink_fwd.py b/examples/attention_sink/benchmark_mha_sink_fwd.py
index f50b94535..f43761ea2 100644
--- a/examples/attention_sink/benchmark_mha_sink_fwd.py
+++ b/examples/attention_sink/benchmark_mha_sink_fwd.py
@@ -182,6 +182,35 @@ def main(batch: int = 1,
     print("Tilelang: {:.2f} TFlops".format(total_flops / latency *
1e-9)) +def benchmark(batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + return do_bench(lambda: kernel(Q, K, V, sinks), warmup=500, rep=10000) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='batch size') diff --git a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py index b442505fc..58dc0a627 100644 --- a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py @@ -507,6 +507,49 @@ def tl_bwd(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def benchmark( + BATCH: int = 1, + H: int = 8, + N_CTX: int = 512, + D_HEAD: int = 64, + groups: int = 2, + window_size: Optional[int] = None, + dtype: str = "float16", +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + with torch.no_grad(): + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda") + K = torch.randn(BATCH, H // groups, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda") + V = torch.randn_like(K) + sinks = torch.randn(H, dtype=torch_dtype, device="cuda") + dO = torch.randn_like(Q) + fwd = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) + O, lse = fwd(Q, K, V, sinks) + + def maybe_contiguous(x): + return x if x.stride(-1) == 1 else x.contiguous() + + do, q, k, v, sinks_c, o = [maybe_contiguous(x) for x in (dO, Q, K, V, sinks, O)] + k_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + Delta = k_prep(o, do) + k_bwd = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, groups, window_size, dtype=dtype) + k_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype) + q_shape = (BATCH, H, N_CTX, D_HEAD) + head_kv = H // groups + kv_shape = (BATCH, head_kv, N_CTX, D_HEAD) + dq = torch.zeros(q_shape, dtype=torch.float32, device="cuda") + dk = torch.zeros(kv_shape, dtype=torch.float32, device="cuda") + dv = torch.zeros(kv_shape, dtype=torch.float32, device="cuda") + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + _ = k_dsink(sinks_c, Delta, lse).sum(0).sum(1) + + def run_kernel_only(): + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + + latency_ms = do_bench(run_kernel_only, warmup=500, rep=10000) + return latency_ms + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=1, help='Batch size') diff --git a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py index 8d1817267..5a06896b7 100644 --- a/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py @@ -327,6 +327,40 @@ def main( print("Tilelang: {:.2f} TFlops".format(total_flops / latency_tilelang * 1e-9)) +def benchmark( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + groups: int = 8, + window_size: 
Optional[int] = None, + dtype: str = "float16", + tune: bool = False, +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + groups, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, groups, dtype=torch_dtype) + latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500, rep=10000) + return latency + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=1, help='batch size') diff --git a/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/examples/attention_sink/example_mha_sink_bwd_bhsd.py index b9fa0fd97..9a433b3ce 100644 --- a/examples/attention_sink/example_mha_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_bwd_bhsd.py @@ -501,6 +501,46 @@ def tl_bwd(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def benchmark( + BATCH: int = 1, + H: int = 1, + N_CTX: int = 512, + D_HEAD: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", +): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + with torch.no_grad(): + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, dtype=torch_dtype, device="cuda") + K = torch.randn_like(Q) + V = torch.randn_like(Q) + sinks = torch.randn(H, dtype=torch_dtype, device=Q.device) + dO = torch.randn_like(Q) + fwd = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, window_size=window_size, dtype=dtype) + O, lse = fwd(Q, K, V, sinks) + + def maybe_contiguous(x): + return x if x.stride(-1) == 1 else x.contiguous() + + do, q, k, v, sinks_c, o = [maybe_contiguous(x) for x in (dO, Q, K, V, sinks, O)] + k_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD, dtype=dtype) + Delta = k_prep(o, do) + k_bwd = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, window_size, dtype=dtype) + k_dsink = flashattn_bwd_dsink(BATCH, H, N_CTX, dtype=dtype) + shape = (BATCH, H, N_CTX, D_HEAD) + dq = torch.zeros(shape, dtype=torch.float32, device=Q.device) + dk = torch.empty(shape, dtype=torch_dtype, device=Q.device) + dv = torch.empty(shape, dtype=torch_dtype, device=Q.device) + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + _ = k_dsink(sinks_c, Delta, lse).sum(0).sum(1) + + def run_kernel_only(): + k_bwd(q, k, v, do, lse, Delta, dq, dk, dv) + + latency_ms = do_bench(run_kernel_only, warmup=500, rep=10000) + return latency_ms + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=1, help='Batch size') diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd.py b/examples/attention_sink/example_mha_sink_fwd_bhsd.py index 0ccb69588..8e3e242e8 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd.py @@ -309,6 +309,35 @@ def main(batch: int = 1, print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def benchmark(batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16"): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size, + block_M=block_M, + block_N=block_N, + 
num_stages=num_stages, + threads=threads, + dtype=dtype) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500, rep=10000) + return latency + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='batch size') diff --git a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py index 64d6ec698..d832a91e4 100644 --- a/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py +++ b/examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py @@ -315,6 +315,36 @@ def main(batch: int = 1, print("Tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def benchmark(batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + window_size: Optional[int] = None, + dtype: str = "float16", + tune: bool = False): + torch_dtype = {"float16": torch.float16, "bfloat16": torch.bfloat16}[dtype] + block_M = 128 + block_N = 128 + num_stages = 2 + threads = 256 + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + window_size, + block_M=block_M, + block_N=block_N, + num_stages=num_stages, + threads=threads, + dtype=dtype) + Q, K, V, sinks = gen_inputs(batch, heads, seq_q, seq_kv, dim, dtype=torch_dtype) + latency = do_bench(lambda: kernel(Q, K, V, sinks), warmup=500, rep=10000) + return latency + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='batch size') diff --git a/examples/blocksparse_attention/bench_example_blocksparse_attention.py b/examples/blocksparse_attention/bench_example_blocksparse_attention.py new file mode 100644 index 000000000..35f577096 --- /dev/null +++ b/examples/blocksparse_attention/bench_example_blocksparse_attention.py @@ -0,0 +1,22 @@ +import tilelang.tools.bench +import example_tilelang_block_sparse_attn +import example_tilelang_sparse_gqa_decode_varlen_indice +import example_tilelang_sparse_gqa_decode_varlen_mask + + +def bench_example_tilelang_block_sparse_attn(): + tilelang.tools.bench.process_func(example_tilelang_block_sparse_attn.benchmark) + + +def bench_example_tilelang_sparse_gqa_decode_varlen_indice(): + tilelang.tools.bench.process_func( + example_tilelang_sparse_gqa_decode_varlen_indice.benchmark, batch=1, max_cache_seqlen=2048) + + +def bench_example_tilelang_sparse_gqa_decode_varlen_mask(): + tilelang.tools.bench.process_func( + example_tilelang_sparse_gqa_decode_varlen_mask.benchmark, batch=1, max_cache_seqlen=2048) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/blocksparse_attention/block_sparse_attn_triton.py b/examples/blocksparse_attention/block_sparse_attn_triton.py index 014f0c5fc..5d6a4a719 100644 --- a/examples/blocksparse_attention/block_sparse_attn_triton.py +++ b/examples/blocksparse_attention/block_sparse_attn_triton.py @@ -1,7 +1,6 @@ # ruff: noqa: E712 import math import torch - import triton import triton.language as tl import torch.nn.functional as F diff --git a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py index 7e90db7e5..18eac0f48 100644 --- a/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py +++ b/examples/blocksparse_attention/example_tilelang_block_sparse_attn.py @@ -1,8 +1,8 @@ import math 
import torch - import tilelang import tilelang.language as T +from tilelang.profiler import do_bench import torch.nn.functional as F @@ -224,5 +224,29 @@ def main(): test_topk_sparse_attention() +def benchmark(): + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 1, 1, 256, 64 + TOPK = 2 + BLOCK = 64 + torch.manual_seed(0) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], + device='cuda', + dtype=torch.bfloat16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + kernel = blocksparse_flashattn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + + def run_kernel_only(): + kernel(q, k, v, block_mask) + + return do_bench(run_kernel_only) + + if __name__ == "__main__": + benchmark() main() diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py index 1c4b847de..30b0f7237 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -8,6 +8,7 @@ import argparse import time import math +from tilelang.profiler import do_bench from heuristic import num_splits_heuristic @@ -572,6 +573,139 @@ def main(args): print(f"Speedup: {kernel_time_fa / kernel_time:.2f}x") +def benchmark(args): + + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = args.batch, args.heads, args.heads_kv, args.max_cache_seqlen, args.dim, args.dim_v + sparse_ratio = args.sparse_ratio + block_N = args.block_N + page_block_size = args.page_block_size + num_blocks = args.num_pages + max_selected_blocks = int(math.ceil(max_cache_seqlen / block_N)) + dtype = torch.float16 + Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') + cache_seqlens = torch.randint( + max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device='cuda') + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') + K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device='cuda') + V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), + dtype=dtype, + device='cuda') + max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) + block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device='cuda') + block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), + dtype=torch.int32, + device='cuda') + total_blocks_needed = sum( + int(math.ceil(cache_seqlens[seq_idx].item() / page_block_size)) for seq_idx in range(batch)) + available_blocks = list(range(total_blocks_needed)) + import random + random.seed(42) + random.shuffle(available_blocks) + block_assignment = {} + block_idx_counter = 0 + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + for block_idx in range(num_blocks_needed): + physical_block_idx = available_blocks[block_idx_counter] + block_table[seq_idx, block_idx] = physical_block_idx + block_assignment[(seq_idx, 
block_idx)] = physical_block_idx + block_idx_counter += 1 + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_blocks_needed = int(math.ceil(seq_len / page_block_size)) + for block_idx in range(num_blocks_needed): + physical_block_idx = block_assignment[(seq_idx, block_idx)] + start_token = block_idx * page_block_size + end_token = min(start_token + page_block_size, seq_len) + actual_block_size = end_token - start_token + K_cache[physical_block_idx, :actual_block_size, :, :] = K[seq_idx, + start_token:end_token, :, :] + V_cache[physical_block_idx, :actual_block_size, :, :] = V[seq_idx, + start_token:end_token, :, :] + for seq_idx in range(batch): + seq_len = cache_seqlens[seq_idx].item() + num_tile = int(math.ceil(seq_len / block_N)) + if sparse_ratio == 0.0: + selected_blocks = min(num_tile, max_selected_blocks) + for head_idx in range(heads_kv): + for i in range(selected_blocks): + block_indices[seq_idx, head_idx, i] = num_tile - 1 - i + for i in range(selected_blocks, max_selected_blocks): + block_indices[seq_idx, head_idx, i] = -1 + else: + num_selected = int(num_tile * (1.0 - sparse_ratio)) + num_selected = max(1, min(num_selected, max_selected_blocks)) + all_blocks = list(range(num_tile)) + for head_idx in range(heads_kv): + selected_blocks = [] + recent_blocks = 1 + selected_blocks.append(num_tile - 1) + if num_selected > recent_blocks: + remaining_blocks = [b for b in all_blocks if b not in selected_blocks] + if remaining_blocks: + import random + random.seed(42) + additional_blocks = random.sample( + remaining_blocks, + min(num_selected - recent_blocks, len(remaining_blocks))) + selected_blocks.extend(additional_blocks) + + selected_blocks.sort(reverse=True) + + for i in range(len(selected_blocks)): + block_indices[seq_idx, head_idx, i] = selected_blocks[i] + for i in range(len(selected_blocks), max_selected_blocks): + block_indices[seq_idx, head_idx, i] = -1 + + sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, + num_blocks) + kernel = sparse_attn.kernel + batch = sparse_attn.batch + heads = sparse_attn.heads + heads_kv = sparse_attn.heads_kv + dim_v = sparse_attn.dim_v + dim = sparse_attn.dim + block_size = sparse_attn.block_N + max_selected_blocks = block_indices.shape[-1] + + num_m_blocks = 1 * (heads // heads_kv + sparse_attn.block_H - 1) // sparse_attn.block_H + num_n_blocks = max_selected_blocks + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + + num_sm = sparse_attn.num_sm + + num_split = num_splits_heuristic( + total_mblocks, + num_sm, + num_n_blocks, + num_m_blocks, + size_one_kv_head, + is_causal_or_local=True, + max_splits=128) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') + output_partial = torch.empty((batch, heads, num_split, dim_v), + dtype=torch.float32, + device='cuda') + + def run_kernel_only(): + kernel( + Q, + K_cache, + V_cache, + block_indices, + cache_seqlens, + block_table, + glse, + output_partial, + ) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='batch size') diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index b30875228..4c1d322de 100644 --- 
a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -7,6 +7,7 @@ import time import math from heuristic import num_splits_heuristic +from tilelang.profiler import do_bench def flashattn(batch, heads, heads_kv, dim, dim_v): @@ -467,6 +468,74 @@ def main(batch=8, print("sparse time: ", (time.time() - start) / 100 * 1000) +def benchmark(batch=8, + heads=32, + heads_kv=8, + max_cache_seqlen=8192, + dim=128, + dim_v=128, + sparse_ratio=0.8, + block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + dtype = torch.float16 + Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + block_indices = torch.full((batch, heads_kv, max_selected_blocks), + -1, + dtype=torch.int32, + device='cuda') + + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() + if max_valid_block > 0: + for h in range(heads_kv): + valid_indices = torch.randperm( + max_valid_block, device='cuda', dtype=torch.int32)[:max_selected_blocks] + block_indices[b, h, :len(valid_indices)] = valid_indices + + block_indices, _ = block_indices.sort(dim=-1, descending=True) + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + batch = sparse_kernel.batch + heads = sparse_kernel.heads + heads_kv = sparse_kernel.heads_kv + dim_v = sparse_kernel.dim_v + dim = sparse_kernel.dim + block_size = sparse_kernel.block_size + max_selected_blocks = block_indices.shape[-1] + + num_m_blocks = 1 * (heads // heads_kv + sparse_kernel.block_H - 1) // sparse_kernel.block_H + num_n_blocks = max_selected_blocks + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = sparse_kernel.num_sm + + num_split = num_splits_heuristic( + total_mblocks, + num_sm, + num_n_blocks, + num_m_blocks, + size_one_kv_head, + is_causal_or_local=True, + max_splits=128) + + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') + output_partial = torch.empty((batch, heads, num_split, dim_v), + dtype=torch.float32, + device='cuda') + kernel = sparse_kernel.kernel + + def run_kernel_only(): + kernel(Q, K, V, block_indices, cache_seqlens, glse, output_partial) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='batch size') diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index 3417bd7f8..2efff2240 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -5,10 +5,10 @@ import tilelang.language as T from einops import rearrange, einsum import argparse - import time import math 
from heuristic import num_splits_heuristic +from tilelang.profiler import do_bench def flashattn(batch, heads, heads_kv, dim, dim_v): @@ -450,6 +450,77 @@ def main(batch=8, print("sparse time: ", (time.time() - start) / 100 * 1000) +def benchmark(batch=8, + heads=32, + heads_kv=8, + max_cache_seqlen=8192, + dim=128, + dim_v=128, + sparse_ratio=0.8, + block_size=32): + batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + sparse_ratio = sparse_ratio + block_size = block_size + max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) + dtype = torch.float16 + + Q = torch.randn((batch, heads, dim), dtype=dtype, device='cuda') + K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device='cuda') + V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device='cuda') + cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device='cuda') + random_index = torch.randint(0, batch, (1,), device='cuda').item() + cache_seqlens[random_index] = max_cache_seqlen + + num_blocks = (max_cache_seqlen + block_size - 1) // block_size + + valid_num_blocks = torch.ceil(cache_seqlens * (1 - sparse_ratio) / block_size).int() + max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() + block_mask = torch.zeros((batch, heads_kv, num_blocks), dtype=torch.bool, device='cuda') + + for b in range(batch): + max_valid_block = max_valid_num_blocks[b].item() + valid_num_block = valid_num_blocks[b].item() + if valid_num_block > 0: + for h in range(heads_kv): + perm = torch.randperm(max_valid_block, device='cuda')[:valid_num_block] + block_mask[b, h, perm] = True + + model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + batch = model.batch + heads = model.heads + heads_kv = model.heads_kv + dim_v = model.dim_v + dim = model.dim + block_size = model.block_size + block_H = model.block_H + max_cache_seqlen = K.shape[1] + max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size + num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_n_blocks = max_selected_blocks + + size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 + total_mblocks = batch * heads_kv * num_m_blocks + num_sm = model.num_sm + num_split = num_splits_heuristic( + total_mblocks, + num_sm, + num_n_blocks, + num_m_blocks, + size_one_kv_head, + is_causal_or_local=True, + max_splits=128) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device='cuda') + Output_partial = torch.empty((batch, heads, num_split, dim_v), + dtype=torch.float32, + device='cuda') + kernel = model.kernel + + def run_kernel_only(): + kernel(Q, K, V, block_mask, cache_seqlens, glse, Output_partial) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='batch size') diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py index 85b72b775..0c916b828 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py @@ -5,10 +5,10 @@ import argparse from einops import rearrange, einsum import torch.nn.functional as F - import math import time from heuristic import num_splits_heuristic +from 
tilelang.profiler import do_bench @triton.autotune( diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py index 348572526..52db5b4ff 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_mask.py @@ -4,12 +4,10 @@ import argparse from einops import rearrange, einsum import torch.nn.functional as F - import math import time from heuristic import num_splits_heuristic - @triton.autotune( configs=[ triton.Config({}, num_warps=num_warps, num_stages=num_stages) diff --git a/examples/blocksparse_gemm/bench_example_blocksparse_gemm.py b/examples/blocksparse_gemm/bench_example_blocksparse_gemm.py new file mode 100644 index 000000000..f28836f88 --- /dev/null +++ b/examples/blocksparse_gemm/bench_example_blocksparse_gemm.py @@ -0,0 +1,10 @@ +import tilelang.tools.bench +import example_blocksparse_gemm + + +def bench_example_blocksparse_gemm(): + tilelang.tools.bench.process_func(example_blocksparse_gemm.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/blocksparse_gemm/example_blocksparse_gemm.py b/examples/blocksparse_gemm/example_blocksparse_gemm.py index 8cd3a8218..89cd997d3 100644 --- a/examples/blocksparse_gemm/example_blocksparse_gemm.py +++ b/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -6,6 +6,7 @@ from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType import torch from typing import List +from tilelang.profiler import do_bench DEFAULT_BLOCK_M = 128 DEFAULT_BLOCK_N = 128 @@ -184,5 +185,30 @@ def main(): print(e) +def benchmark(): + + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + + kernel = blocksparse_matmul( + M, + N, + K, + block_M=DEFAULT_BLOCK_M, + block_N=DEFAULT_BLOCK_N, + block_K=DEFAULT_BLOCK_K, + num_stages=DEFAULT_NUM_STAGES, + thread_num=DEFAULT_THREAD_NUM, + enable_rasteration=DEFAULT_ENABLE_RASTERIZATION) + block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K + mask_shape = (M // block_M, N // block_N, K // block_K) + block_mask = torch.rand(mask_shape).cuda() > sparsity + + def run_kernel_only(): + kernel(a, b, block_mask) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": main() diff --git a/examples/cast/bench_example_cast.py b/examples/cast/bench_example_cast.py new file mode 100644 index 000000000..324311e4c --- /dev/null +++ b/examples/cast/bench_example_cast.py @@ -0,0 +1,22 @@ +import tilelang.tools.bench +import example_group_per_split_token_cast_to_fp8 +import example_per_token_cast_to_fp8 + + +def bench_example_group_per_split_token_cast_to_fp8(): + tilelang.tools.bench.process_func( + example_group_per_split_token_cast_to_fp8.benchmark, + M=1024, + N=1024, + BG=2, + blk_m=4, + batch_sizes=[128, 896]) + + +def bench_example_per_token_cast_to_fp8(): + tilelang.tools.bench.process_func( + example_per_token_cast_to_fp8.benchmark, M=2048, N=512, blk_m=8) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/cast/example_group_per_split_token_cast_to_fp8.py b/examples/cast/example_group_per_split_token_cast_to_fp8.py index 102ac2021..3e207ca48 100644 --- a/examples/cast/example_group_per_split_token_cast_to_fp8.py +++ b/examples/cast/example_group_per_split_token_cast_to_fp8.py @@ -206,5 +206,35 @@ def run_torch(): 
print("Torch: {:.2f} ms".format(latency)) +def benchmark(M=8192, N=8192, BG=2, blk_m=8, batch_sizes=None): + if batch_sizes is None: + batch_sizes = [2048, 6144] + if dtype == "float": + x = torch.randn(M, N, device="cuda", dtype=torch.float32) + elif dtype == "float16": + x = torch.randn(M, N, device="cuda", dtype=torch.float16) + elif dtype == "bfloat16": + x = torch.randn(M, N, device="cuda", dtype=torch.bfloat16) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + batch_sizes = torch.tensor(batch_sizes, device="cuda", dtype=torch.int32) + M_max = int(ceil_div(batch_sizes.max(), 128) * 128) + + kernel = group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m) + + x_fp8, x_amax = kernel(x, batch_sizes) + x_fp8_ref, x_amax_ref = ref_program(x, batch_sizes) + + torch_assert_close(x_fp8.to(torch.float32), x_fp8_ref.to(torch.float32), rtol=0.01, atol=0.01) + torch_assert_close(x_amax, x_amax_ref, rtol=0.01, atol=0.01) + + from tilelang.profiler import do_bench + + def run_tilelang(): + kernel(x, batch_sizes) + + return do_bench(run_tilelang) + + if __name__ == "__main__": main() diff --git a/examples/cast/example_per_token_cast_to_fp8.py b/examples/cast/example_per_token_cast_to_fp8.py index 484a092f0..f12bf196e 100644 --- a/examples/cast/example_per_token_cast_to_fp8.py +++ b/examples/cast/example_per_token_cast_to_fp8.py @@ -114,5 +114,16 @@ def run_triton(): print("Triton: {:.2f} ms".format(latency)) +def benchmark(M=8192, N=8192, blk_m=8): + kernel = per_token_cast_to_fp8(M, N, blk_m) + x = torch.randn(M, N, device="cuda", dtype=torch.float32) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(x) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": main() diff --git a/examples/convolution/bench_example_convolution.py b/examples/convolution/bench_example_convolution.py new file mode 100644 index 000000000..6368af9cd --- /dev/null +++ b/examples/convolution/bench_example_convolution.py @@ -0,0 +1,15 @@ +import tilelang.tools.bench +import example_convolution +import example_convolution_autotune + + +def bench_example_convolution(): + tilelang.tools.bench.process_func(example_convolution.benchmark) + + +def bench_example_convolution_autotune(): + tilelang.tools.bench.process_func(example_convolution_autotune.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/convolution/example_convolution.py b/examples/convolution/example_convolution.py index b2696ba8f..2b02f0c57 100644 --- a/examples/convolution/example_convolution.py +++ b/examples/convolution/example_convolution.py @@ -125,5 +125,30 @@ def main(argv=None): print("All checks passed.āœ…") +def benchmark(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument('--n', type=int, default=128, help='n') + parser.add_argument('--c', type=int, default=128, help='c') + parser.add_argument('--h', type=int, default=64, help='h') + parser.add_argument('--w', type=int, default=64, help='w') + parser.add_argument('--f', type=int, default=128, help='f') + parser.add_argument('--k', type=int, default=3, help='k') + parser.add_argument('--s', type=int, default=1, help='s') + parser.add_argument('--d', type=int, default=1, help='d') + parser.add_argument('--p', type=int, default=1, help='p') + + args = parser.parse_args(argv) + N, C, H, W, F, K, S, D, P = args.n, args.c, args.h, args.w, args.f, args.k, args.s, args.d, args.p + + block_m = 64 + block_n = 128 + block_k = 32 + num_stages = 3 + threads = 256 + kernel 
= convolution(N, C, H, W, F, K, S, D, P, block_m, block_n, block_k, num_stages, threads) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench() + + if __name__ == "__main__": main() diff --git a/examples/convolution/example_convolution_autotune.py b/examples/convolution/example_convolution_autotune.py index 393677489..c0a16ae32 100644 --- a/examples/convolution/example_convolution_autotune.py +++ b/examples/convolution/example_convolution_autotune.py @@ -194,6 +194,24 @@ def main(n: int = 128, print(f"Ref latency: {ref_latency}") +def benchmark(n: int = 128, + c: int = 128, + h: int = 64, + w: int = 64, + f: int = 128, + k: int = 3, + s: int = 1, + d: int = 1, + p: int = 1, + use_autotune: bool = False, + with_roller: bool = True): + N, C, H, W, F, K, S, D, P = n, c, h, w, f, k, s, d, p + config = get_heuristic_config() + kernel = convolution(N, C, H, W, F, K, S, D, P, **config) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench() + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") parser.add_argument('--n', type=int, default=128, help='n') diff --git a/examples/deepseek_mla/bench_example_mla_decode.py b/examples/deepseek_mla/bench_example_mla_decode.py new file mode 100644 index 000000000..343c386c7 --- /dev/null +++ b/examples/deepseek_mla/bench_example_mla_decode.py @@ -0,0 +1,10 @@ +import tilelang.tools.bench +import example_mla_decode + + +def bench_example_mla_decode(): + tilelang.tools.bench.process_func(example_mla_decode.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index 3932d112e..16c6d3e90 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -309,6 +309,26 @@ def main( print(f"TFlops: {total_flops / latency * 1e-9} TFlops") +def benchmark( + batch=1, + heads=128, + kv_heads=1, + kv_ctx=8192, + dim=512, + pe_dim=64, +): + BLOCK_N = 64 + BLOCK_H = min(64, heads // kv_heads) + num_split = 1 + softmax_scale = (dim + pe_dim)**-0.5 + + kernel = flashattn(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H, num_split, + softmax_scale) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn) + profiler.assert_allclose(ref_program, rtol=1e-4, atol=1e-4) + return profiler.do_bench(warmup=500) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=132, help='batch size') diff --git a/examples/deepseek_nsa/bench_example_tilelang_nsa.py b/examples/deepseek_nsa/bench_example_tilelang_nsa.py new file mode 100644 index 000000000..b856fd1dc --- /dev/null +++ b/examples/deepseek_nsa/bench_example_tilelang_nsa.py @@ -0,0 +1,15 @@ +import tilelang.tools.bench +import example_tilelang_nsa_fwd +import example_tilelang_nsa_decode + + +def bench_example_tilelang_nsa_fwd(): + tilelang.tools.bench.process_func(example_tilelang_nsa_fwd.benchmark) + + +def bench_example_tilelang_nsa_fwd_decode(): + tilelang.tools.bench.process_func(example_tilelang_nsa_decode.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/deepseek_nsa/example_tilelang_nsa_decode.py b/examples/deepseek_nsa/example_tilelang_nsa_decode.py index 58f435509..89334a532 100644 --- 
a/examples/deepseek_nsa/example_tilelang_nsa_decode.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_decode.py @@ -178,5 +178,42 @@ def main(): torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) +def benchmark(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype = 2, 64, 1, 16, 16, 1, 32, torch.float16 + groups = HQ // H + SEQ_LEN_Q = 1 + kernel = native_sparse_attention( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + ) + + Q = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) + + mask = torch.randint(0, 2, (B, SEQ_LEN, groups), device='cuda') + DO = torch.randn((B, SEQ_LEN_Q, HQ, D), dtype=dtype, device='cuda') + + block_indices = torch.full((B, SEQ_LEN_Q, H, S), SEQ_LEN, dtype=torch.long, device='cuda') + for b in range(B): + for t in range(SEQ_LEN_Q): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, :len(i_i)] = i_i + block_indices = block_indices.sort(-1)[0] + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, block_indices.to(torch.int32)) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": main() diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py index 0b71779b8..40903d607 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -184,5 +184,43 @@ def main(): torch.testing.assert_close(ref, out, atol=1e-2, rtol=1e-2) +def benchmark(): + B, SEQ_LEN, H, HQ, D, S, block_size, dtype, scale = 2, 64, 1, 16, 32, 1, 32, torch.float16, 0.1 + kernel = native_sparse_attention( + batch=B, + heads=HQ, + seq_len=SEQ_LEN, + dim=D, + is_causal=True, + block_size=block_size, + groups=HQ // H, + selected_blocks=S, + scale=scale, + ) + torch.random.manual_seed(0) + Q = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda').requires_grad_(True) + K = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) + V = torch.randn((B, SEQ_LEN, H, D), dtype=dtype, device='cuda').requires_grad_(True) + g_slc = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) + g_swa = torch.ones((B, SEQ_LEN, HQ), dtype=dtype, device='cuda').requires_grad_(True) + DO = torch.randn((B, SEQ_LEN, HQ, D), dtype=dtype, device='cuda') + block_indices = torch.full((B, SEQ_LEN, H, S), SEQ_LEN, dtype=torch.long, device='cuda') + block_counts = torch.zeros((B, SEQ_LEN, H), dtype=torch.long, device='cuda') + for b in range(B): + for t in range(SEQ_LEN): + for h in range(H): + i_i = torch.randperm(max(1, (t // block_size)))[:S] + block_indices[b, t, h, :len(i_i)] = i_i + block_counts[b, t, h] = (block_indices[b, t, h] != SEQ_LEN).sum().item() + block_indices = block_indices.sort(-1)[0] + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, block_indices.to(torch.int32)) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": main() diff --git a/examples/deepseek_v32/bench_tilelang_example_deepseek_v32.py b/examples/deepseek_v32/bench_tilelang_example_deepseek_v32.py new file mode 100644 index 000000000..d54797301 --- /dev/null +++ b/examples/deepseek_v32/bench_tilelang_example_deepseek_v32.py @@ 
-0,0 +1,58 @@ +import tilelang.tools.bench +import fp8_lighting_indexer +import sparse_mla_bwd +import sparse_mla_fwd +import sparse_mla_fwd_pipelined +import topk_selector + + +def bench_topk_selector(): + tilelang.tools.bench.process_func(topk_selector.benchmark) + + +def bench_fp8_lighting_indexer(): + tilelang.tools.bench.process_func( + fp8_lighting_indexer.benchmark, S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1) + + +def bench_sparse_mla_fwd(): + tilelang.tools.bench.process_func( + sparse_mla_fwd.benchmark, + S=256, + SKV=1024, + H=64, + HKV=1, + DQK=576, + DV=512, + topk=256, + check_correctness=False) + + +def bench_sparse_mla_fwd_pipelined(): + tilelang.tools.bench.process_func( + sparse_mla_fwd_pipelined.benchmark, + S=256, + SKV=512, + H=64, + HKV=1, + DQK=576, + DV=512, + topk=256, + check_correctness=False) + + +def bench_sparse_mla_bwd(): + tilelang.tools.bench.process_func( + sparse_mla_bwd.benchmark, + S=256, + SKV=512, + H=64, + HKV=1, + DQKV=576, + DV=512, + topk=256, + check_correctness=False) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/deepseek_v32/fp8_lighting_indexer.py b/examples/deepseek_v32/fp8_lighting_indexer.py index dd940648b..1e2bdbd3d 100644 --- a/examples/deepseek_v32/fp8_lighting_indexer.py +++ b/examples/deepseek_v32/fp8_lighting_indexer.py @@ -304,5 +304,45 @@ def logits_fn(): print(f"cost_ref: {cost_ref}") +def benchmark(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): + torch.manual_seed(0) + q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) + weights = torch.randn(S, H, device="cuda", dtype=torch.float32) + p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) + + ks, ke = generate_random_cu_seqlens( + per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) + + logits_ref, cost_ref = ref_fp8_mqa_logits( + q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) + + logits_tl = mqa_attn_return_logits_interface( + q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) + diff = validate_tensor_match( + logits_ref, logits_tl, tolerance=1e-14, tensor_name="logits", should_raise=False) + + from tilelang.profiler import do_bench + + def logits_fn(): + return mqa_attn_return_logits_interface( + q=q_fp8, + kv=kv_fp8, + kv_scales=kv_scales, + weights=weights, + cu_seqlen_ks=ks, + cu_seqlen_ke=ke) + + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: + logits_fn() + + print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50)) + + return do_bench(logits_fn, warmup=100, rep=100) + + if __name__ == "__main__": test_fp8_lighting_indexer() diff --git a/examples/deepseek_v32/sparse_mla_bwd.py b/examples/deepseek_v32/sparse_mla_bwd.py index e7f9c6093..9f9e5efeb 100644 --- a/examples/deepseek_v32/sparse_mla_bwd.py +++ b/examples/deepseek_v32/sparse_mla_bwd.py @@ -384,6 +384,47 @@ def fn(): print(f'bwd tflops = ', per_token_flop * S / (ms * 1e-3) / 1e12) +def benchmark(B=1, + S=4096, + SKV=8192, + H=64, + HKV=1, + DQKV=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + check_correctness=True): + q = torch.randn((B, S, H, DQKV), dtype=dtype, device='cuda').requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQKV), dtype=dtype, 
device='cuda').requires_grad_(True) + do = torch.randn((B, S, H, DV), dtype=dtype, device='cuda') + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda') + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, :len(i_i)] = i_i + + from sparse_mla_fwd import sparse_mla_fwd_interface + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices) + B, S, H, dim_plus_tail_dim = q.shape + _, S_kv, kv_group, _ = kv.shape + D = 512 + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + preprocess_kernel = preprocess(B, S, H, D) + bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, None, True) + delta = preprocess_kernel(tl_out, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + return bwd_kernel(q, kv, do, indices, tl_lse, delta, dkv) + + return do_bench(run_kernel_only, rep=100, warmup=250) + + if __name__ == "__main__": test_sparse_mla_bwd( B=1, diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index e65b89017..a15850e4b 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -301,6 +301,59 @@ def fn(): print("fwd tflops = ", (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) +def benchmark(B=1, + S=4096, + SKV=8192, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + block_I=64, + num_stages=2, + threads=256): + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(max(1, t))[:topk] + indices[b, t, h, :len(i_i)] = i_i + + is_casual = True + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + dim = 512 + tail_dim = dim_plus_tail_dim - dim + _, _, _, topk = indices.shape + kernel = sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group, + None, + is_casual, + block_I=block_I, + num_stages=num_stages, + threads=threads) + + def run_kernel_only(): + kernel(q, kv, indices) + + from tilelang.profiler import do_bench + + return do_bench( + run_kernel_only, + rep=100, + warmup=250, + ) + + if __name__ == "__main__": test_sparse_mla_fwd( B=1, diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py index 1621d85ba..38d76f358 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -456,6 +456,60 @@ def fn(): print(f'fwd tflops = ', (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12) +def benchmark(B=1, + S=4096, + SKV=8192, + H=128, + HKV=1, + DQK=576, + DV=512, + topk=2048, + dtype=torch.bfloat16, + q_start_s_index=1024, + check_correctness=True): + KV_stride = 1 + + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device='cuda').requires_grad_(True) / 10 + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device='cuda') + for b in range(B): + for t in range(S): + for h in range(HKV): + i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // 
KV_stride)), SKV))[:topk] + indices[b, t, h, :len(i_i)] = i_i + + kernel = sparse_mla_fwd_interface( + q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) + + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + sm_scale = None + is_casual = True + return_kernel = False + print_kernel = False + dim = 512 + tail_dim = dim_plus_tail_dim - dim + _, _, _, topk = indices.shape + CP0 = q_start_s_index == 0 + + kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, KV_stride, + kv_group, sm_scale, is_casual, CP0) + + def ran_kernel_only(): + kernel(q, kv, indices, torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda")) + + from tilelang.profiler import do_bench + return do_bench( + ran_kernel_only, + rep=100, + warmup=10, + ) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--test_correctness", action="store_true") diff --git a/examples/deepseek_v32/topk_selector.py b/examples/deepseek_v32/topk_selector.py index 4a4b43277..0012fa100 100644 --- a/examples/deepseek_v32/topk_selector.py +++ b/examples/deepseek_v32/topk_selector.py @@ -245,5 +245,37 @@ def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): print(f"Average torch.topk time: {elapsed_time_ms / n_iters:.3f} ms") +def benchmark(batch=64, seq_len=32 * 1024, topk=2048): + + batch = 64 + seq_len = 32 * 1024 + topk = 2048 + torch.manual_seed(1) + input = torch.randn(batch, seq_len, dtype=torch.float32).cuda() + starts = torch.zeros(batch, dtype=torch.int32).cuda() + ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len + + indexes = tl_topk(input, starts, ends, topk) + + indexes_ref = torch.topk(input, topk, dim=-1)[1] + + for i in range(batch): + ref_np = indexes_ref[i].cpu().to(torch.int32).numpy() + trt_np = indexes[i].cpu().to(torch.int32).numpy() + + set_ref = set(ref_np) + set_trt = set(trt_np) + intersection = set_ref & set_trt + print("selected/all:", len(intersection), "/", len(set_ref), "=", + len(intersection) / len(set_ref)) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + tl_topk(input, starts, ends, topk) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": test_topk_selector() diff --git a/examples/dequantize_gemm/bench_example_dequantize_gemm.py b/examples/dequantize_gemm/bench_example_dequantize_gemm.py new file mode 100644 index 000000000..b400bd80f --- /dev/null +++ b/examples/dequantize_gemm/bench_example_dequantize_gemm.py @@ -0,0 +1,40 @@ +import tilelang.tools.bench +import example_dequant_gemm_bf16_fp4_hopper +import example_dequant_gemm_bf16_mxfp4_hopper +import example_dequant_gemm_bf16_mxfp4_hopper_tma +import example_dequant_gemm_fp4_hopper +import example_dequant_gemm_w4a8 +import example_dequant_gemv_fp16xint4 +import example_dequant_groupedgemm_bf16_mxfp4_hopper + + +def bench_example_dequant_gemv_fp16xint4(): + tilelang.tools.bench.process_func(example_dequant_gemv_fp16xint4.benchmark) + + +def bench_example_dequant_gemm_fp4_hopper(): + tilelang.tools.bench.process_func(example_dequant_gemm_fp4_hopper.benchmark) + + +def bench_example_dequant_gemm_bf16_fp4_hopper(): + tilelang.tools.bench.process_func(example_dequant_gemm_bf16_fp4_hopper.benchmark) + + +def bench_example_dequant_gemm_bf16_mxfp4_hopper(): + tilelang.tools.bench.process_func(example_dequant_gemm_bf16_mxfp4_hopper.benchmark) + + +def bench_example_dequant_gemm_bf16_mxfp4_hopper_tma(): + 
tilelang.tools.bench.process_func(example_dequant_gemm_bf16_mxfp4_hopper_tma.benchmark) + + +def bench_example_dequant_groupedgemm_bf16_mxfp4_hopper(): + tilelang.tools.bench.process_func(example_dequant_groupedgemm_bf16_mxfp4_hopper.benchmark) + + +def bench_example_dequant_gemm_w4a8(): + tilelang.tools.bench.process_func(example_dequant_gemm_w4a8.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py index e30845b8d..c70f52cf0 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py @@ -437,6 +437,26 @@ def main(m=256, n=256, k=256, fast_dequant=True, tune=False): print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def benchmark(m=256, n=256, k=256, fast_dequant=True, tune=False): + kernel = matmul( + m, + n, + k, + "bfloat16", + "bfloat16", + "float32", + num_bits=4, + fast_dequant=fast_dequant, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + return profiler.do_bench(warmup=500) + + if __name__ == "__main__": main(256, 256, 256, True) main(256, 256, 256, False) diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py index ac1417aeb..c4f15b25e 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py @@ -538,6 +538,28 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def benchmark(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, tune=False): + kernel = matmul( + m, + n, + k, + "bfloat16", + "bfloat16", + "float32", + num_bits=4, + scale_size=scale_size, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + fast_dequant=fast_dequant, + with_bias=with_bias) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + return profiler.do_bench(warmup=500) + + if __name__ == "__main__": M, N, K = 256, 256, 256 scale_size = 32 diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py index 7dad79597..01b396ae3 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper_tma.py @@ -554,6 +554,28 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def benchmark(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False, tune=False): + kernel = matmul( + m, + n, + k, + "bfloat16", + "bfloat16", + "float32", + num_bits=4, + scale_size=scale_size, + block_M=256, + block_N=128, + block_K=128, + num_stages=2, + threads=256, + split=1, + fast_dequant=fast_dequant, + with_bias=with_bias) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto) + return profiler.do_bench(warmup=500) + + if __name__ == "__main__": M, N, K = 256, 256, 256 scale_size = 32 diff --git a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py 
b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py index c5588d516..dfd2813e9 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fp4_hopper.py @@ -291,6 +291,14 @@ def main(m=256, n=256, k=256, tune=False): print(f"Best config: {best_config}") +def benchmark(m=256, n=256, k=256): + kernel = matmul( + m, n, k, "float16", "float16", "float32", num_bits=4, tune=False)( + block_M=128, block_N=128, block_K=128, num_stages=2, threads=256, split=1) + profiler = kernel.get_profiler(tilelang.TensorSupplyType.Integer) + return profiler.do_bench(warmup=10, rep=100) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--m', type=int, default=256, help='M') diff --git a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py index 52ee8216f..431f74d45 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py +++ b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py @@ -187,6 +187,14 @@ def main(m=128, n=256, k=256, tune=False): print(f"Best tflops: {total_flops / best_latency * 1e-9}") +def benchmark(m=128, n=256, k=256): + kernel = matmul_int8xint4( + m, n, k, "int8", "int32", "int32", num_bits=4, tune=False)( + block_M=32, block_N=32, block_K=128, num_stages=1, threads=128) + profiler = kernel.get_profiler() + return profiler.do_bench(warmup=10, rep=100) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--m", type=int, default=512, help="Matrix dimension M") diff --git a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py index d3e90ec93..7980fa476 100644 --- a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py +++ b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -205,5 +205,47 @@ def main() -> None: torch.testing.assert_close(C, ref_c, atol=1e3, rtol=1e-1) +def benchmark(): + M = 1 + N = 1024 + K = 1024 + in_dtype = "float16" + out_dtype = "float16" + accum_dtype = "float16" + num_bits = 4 + storage_dtype = "int8" + source_format = "uint" + n_partition = 4 + reduce_thread = 32 + fast_decoding = True + trans_A = False + trans_B = True + group_size = -1 + with_scaling = False + + kernel = dequantize_gemv(M, N, K, in_dtype, out_dtype, accum_dtype, num_bits, storage_dtype, + source_format, n_partition, reduce_thread, fast_decoding, trans_A, + trans_B, group_size, with_scaling) + + storage_nbit = int("".join(c for c in storage_dtype if c.isdigit())) + num_elems_per_byte = storage_nbit // num_bits + A = torch.rand(M, K, dtype=getattr(torch, in_dtype)).cuda() + qB = torch.randint( + 0, 127, (N, K // num_elems_per_byte), dtype=getattr(torch, storage_dtype)).cuda() + C = torch.zeros(M, N, dtype=getattr(torch, accum_dtype)).cuda() + + if fast_decoding: + from tilelang.quantize.utils import interleave_weight + qB = interleave_weight(qB, num_bits, in_dtype) + kernel(A, qB, C) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(A, qB, C) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": main() diff --git a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py index c4cf5fb50..d5c889952 100644 --- a/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py +++ 
b/examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py @@ -531,6 +531,69 @@ def main(m=256, print("All checks pass. āœ…") +def benchmark(m=256, + n=256, + k=256, + scale_size=32, + topk=4, + E=32, + fast_dequant=True, + with_bias=False, + tune=False): + block_M, block_N, block_K = 128, 256, 128 + num_stages = 1 + threads = 512 + split = 1 + num_bits = 4 + num_elems_per_byte = 8 // num_bits + qk = k // num_elems_per_byte + A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, padding_M = get_data( + m, n, k, qk, scale_size, topk, E, block_M) + + if tune: + with set_autotune_inputs([A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids]): + kernel = matmul( + m, + n, + k, + topk, + E, + padding_M, + "bfloat16", + "bfloat16", + "float32", + num_bits=num_bits, + scale_size=scale_size, + fast_dequant=fast_dequant, + with_bias=with_bias, + ) + else: + kernel = matmul( + m, + n, + k, + topk, + E, + padding_M, + "bfloat16", + "bfloat16", + "float32", + num_bits=num_bits, + scale_size=scale_size, + fast_dequant=fast_dequant, + with_bias=with_bias, + block_M=block_M, + block_N=block_N, + block_K=block_K, + num_stages=num_stages, + threads=threads, + split=split, + ) + + return tilelang.profiler.do_bench( + lambda: kernel(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids), warmup=100) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( diff --git a/examples/dynamic_shape/bench_example_dynamic.py b/examples/dynamic_shape/bench_example_dynamic.py new file mode 100644 index 000000000..cf1eddb22 --- /dev/null +++ b/examples/dynamic_shape/bench_example_dynamic.py @@ -0,0 +1,10 @@ +import tilelang.tools.bench +import example_dynamic + + +def bench_example_dynamic(): + tilelang.tools.bench.process_func(example_dynamic.benchmark, M=1024, N=1024, K=1024) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/dynamic_shape/example_dynamic.py b/examples/dynamic_shape/example_dynamic.py index be018c8b7..5011fcd96 100644 --- a/examples/dynamic_shape/example_dynamic.py +++ b/examples/dynamic_shape/example_dynamic.py @@ -107,5 +107,28 @@ def main(M=16384, N=16384, K=16384): accum_dtype, num_stages, threads) +def benchmark(M, N, K): + block_M, block_N, block_K = 128, 128, 32 + trans_A, trans_B = False, False + in_dtype, out_dtype = "float16", "float16" + accum_dtype = "float32" + num_stages = 3 + threads = 128 + kernel = matmul_dynamic_mnk(block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, + accum_dtype, num_stages, threads) + import torch + if trans_A: + A = torch.rand(K, M, device="cuda", dtype=getattr(torch, in_dtype)) + else: + A = torch.rand(M, K, device="cuda", dtype=getattr(torch, in_dtype)) + if trans_B: + B = torch.rand(N, K, device="cuda", dtype=getattr(torch, in_dtype)) + else: + B = torch.rand(K, N, device="cuda", dtype=getattr(torch, in_dtype)) + C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype)) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(input_tensors=[A, B, C]) + + if __name__ == "__main__": main() diff --git a/examples/elementwise/bench_example_elementwise.py b/examples/elementwise/bench_example_elementwise.py new file mode 100644 index 000000000..71923fa3c --- /dev/null +++ b/examples/elementwise/bench_example_elementwise.py @@ -0,0 +1,10 @@ +import tilelang.tools.bench +import example_elementwise_add + + +def bench_example_elementwise_add(): + 
tilelang.tools.bench.process_func(example_elementwise_add.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index bc9bb4df5..65e9c13b8 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -4,6 +4,7 @@ import tilelang import tilelang.language as T from tilelang.autotuner import AutoTuner +from tilelang.profiler import do_bench def ref_program(x, y): @@ -80,5 +81,22 @@ def main(): torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) +def benchmark(): + parser = argparse.ArgumentParser() + parser.add_argument("--m", type=int, default=1024) + parser.add_argument("--n", type=int, default=1024) + args, _ = parser.parse_known_args() + M, N = args.m, args.n + a = torch.randn(M, N, dtype=torch.float32, device="cuda") + b = torch.randn(M, N, dtype=torch.float32, device="cuda") + config = {"block_M": 32, "block_N": 32, "threads": 128} + kernel = elementwise_add(M, N, **config, in_dtype="float32", out_dtype="float32") + + def run_kernel_only(): + kernel(a, b) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": main() diff --git a/examples/flash_attention/bench_example_flash_attention.py b/examples/flash_attention/bench_example_flash_attention.py new file mode 100644 index 000000000..7ad16b559 --- /dev/null +++ b/examples/flash_attention/bench_example_flash_attention.py @@ -0,0 +1,91 @@ +import tilelang.tools.bench +import example_gqa_fwd_bshd +import example_gqa_fwd_bshd_wgmma_pipelined +import example_mha_fwd_bhsd +import example_mha_fwd_bhsd_wgmma_pipelined +import example_mha_fwd_bshd +import example_mha_fwd_bshd_wgmma_pipelined +import example_mha_fwd_varlen +import example_gqa_bwd_tma_reduce_varlen +import example_gqa_bwd +import example_gqa_bwd_wgmma_pipelined +import example_mha_bwd_bshd +import example_mha_bwd_bhsd +import example_mha_bwd_bshd_wgmma_pipelined + + +def bench_example_gqa_bwd_tma_reduce_varlen(): + tilelang.tools.bench.process_func( + example_gqa_bwd_tma_reduce_varlen.benchmark, name="example_gqa_bwd_tma_reduce_varlen") + + +def bench_example_gqa_bwd(): + tilelang.tools.bench.process_func(example_gqa_bwd.benchmark, name="example_gqa_bwd") + + +def bench_example_gqa_bwd_wgmma_pipelined(): + tilelang.tools.bench.process_func( + example_gqa_bwd_wgmma_pipelined.benchmark, name="example_gqa_bwd_wgmma_pipelined") + + +def bench_example_mha_bwd_bshd(): + tilelang.tools.bench.process_func(example_mha_bwd_bshd.benchmark, name="example_mha_bwd_bshd") + + +def bench_example_mha_bwd_bhsd(): + tilelang.tools.bench.process_func(example_mha_bwd_bhsd.benchmark, name="example_mha_bwd_bhsd") + + +def bench_example_mha_bwd_bshd_wgmma_pipelined(): + tilelang.tools.bench.process_func( + example_mha_bwd_bshd_wgmma_pipelined.benchmark, name="example_mha_bwd_bshd_wgmma_pipelined") + + +def bench_example_gqa_fwd_bshd_wgmma_pipelined(): + tilelang.tools.bench.process_func( + example_gqa_fwd_bshd_wgmma_pipelined.benchmark, + batch=1, + heads=16, + seq_len=1024, + dim=128, + is_causal=False, + groups=16, + tune=False) + + +def bench_example_gqa_fwd_bshd(): + tilelang.tools.bench.process_func( + example_gqa_fwd_bshd.benchmark, + batch=1, + heads=16, + seq_len=1024, + dim=128, + is_causal=False, + groups=16, + tune=False) + + +def bench_example_mha_fwd_bhsd_wgmma_pipelined(): + 
tilelang.tools.bench.process_func(example_mha_fwd_bhsd_wgmma_pipelined.benchmark) + + +def bench_example_mha_fwd_bhsd(): + tilelang.tools.bench.process_func(example_mha_fwd_bhsd.benchmark) + + +def bench_example_mha_fwd_bshd_wgmma_pipelined(): + tilelang.tools.bench.process_func( + example_mha_fwd_bshd_wgmma_pipelined.benchmark, batch=1, heads=32, seq_len=256) + + +def bench_example_mha_fwd_bshd(): + tilelang.tools.bench.process_func(example_mha_fwd_bshd.benchmark, batch=1, seq_len=256) + + +def bench_example_mha_fwd_varlen(): + tilelang.tools.bench.process_func( + example_mha_fwd_varlen.benchmark, batch=4, heads=16, seq_len=512, dim=64) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/flash_attention/example_gqa_bwd.py b/examples/flash_attention/example_gqa_bwd.py index 968d1de33..7ec8c1591 100644 --- a/examples/flash_attention/example_gqa_bwd.py +++ b/examples/flash_attention/example_gqa_bwd.py @@ -526,6 +526,50 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def benchmark(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD_QK = 192 + D_HEAD_V = 128 + groups = 16 + causal = False + device = "cuda" + torch.manual_seed(42) + head_kv = H // groups + Q = torch.randn(BATCH, N_CTX, H, D_HEAD_QK, device=device, dtype=torch.half) + K = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_QK, device=device, dtype=torch.half) + V = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_V, device=device, dtype=torch.half) + O = torch.randn(BATCH, N_CTX, H, D_HEAD_V, device=device, dtype=torch.half) + dO = torch.randn(BATCH, N_CTX, H, D_HEAD_V, device=device, dtype=torch.half) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) + kernel = flashattn_bwd_split( + BATCH, + H, + N_CTX, + D_HEAD_QK, + D_HEAD_V, + causal, + block_M=128, + block_N=32, + threads=256, + num_stages=2, + groups=groups, + ) + dQ = torch.zeros_like(Q, dtype=torch.float32) + dK = torch.zeros(groups, BATCH, N_CTX, head_kv, D_HEAD_QK, device=device, dtype=torch.float16) + dV = torch.zeros(groups, BATCH, N_CTX, head_kv, D_HEAD_V, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='Batch size') diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py index a9604f4de..6c1f00d52 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -767,6 +767,58 @@ def run1(): ) +def benchmark(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD_QK = 192 + D_HEAD_V = 128 + groups = 16 + causal = False + device = "cuda" + torch.manual_seed(42) + total_q = BATCH * N_CTX + total_kv = BATCH * N_CTX + head_kv = H // groups + Q = torch.randn(total_q, H, D_HEAD_QK, device=device, dtype=torch.half) + K = torch.randn(total_kv, head_kv, D_HEAD_QK, device=device, dtype=torch.half) + V = torch.randn(total_kv, head_kv, D_HEAD_V, device=device, dtype=torch.half) + O = torch.randn(total_q, H, D_HEAD_V, device=device, dtype=torch.half) + dO = torch.randn(total_q, H, D_HEAD_V, device=device, dtype=torch.half) + cu_seqlens_q = 
torch.arange(0, (BATCH + 1) * N_CTX, N_CTX, device=device, dtype=torch.int32) + cu_seqlens_k = cu_seqlens_q + max_seqlen_q = N_CTX + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, max_seqlen_q, D_HEAD_V) + kernel = flashattn_bwd_split( + BATCH, + total_q, + total_kv, + N_CTX, + H, + max_seqlen_q, + D_HEAD_QK, + D_HEAD_V, + causal, + block_M=128, + block_N=32, + threads=256, + num_stages=2, + groups=groups, + ) + dQ = torch.zeros_like(Q, dtype=torch.float32) + dK = torch.zeros(groups, total_kv, head_kv, D_HEAD_QK, device=device, dtype=torch.float16) + dV = torch.zeros(groups, total_kv, head_kv, D_HEAD_V, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO, cu_seqlens_q) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, cu_seqlens_q, cu_seqlens_k, dQ, dK, dV) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": arch = nvcc.get_target_compute_version() print(f"Detected GPU compute capability: {arch}") diff --git a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py index e916812f5..6defa74b2 100644 --- a/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_bwd_wgmma_pipelined.py @@ -378,6 +378,49 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def benchmark(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD_QK = 192 + D_HEAD_V = 128 + groups = 16 + causal = False + device = "cuda" + torch.manual_seed(42) + head_kv = H // groups + Q = torch.randn(BATCH, N_CTX, H, D_HEAD_QK, device=device, dtype=torch.half) + K = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_QK, device=device, dtype=torch.half) + V = torch.randn(BATCH, N_CTX, head_kv, D_HEAD_V, device=device, dtype=torch.half) + O = torch.randn(BATCH, N_CTX, H, D_HEAD_V, device=device, dtype=torch.half) + dO = torch.randn(BATCH, N_CTX, H, D_HEAD_V, device=device, dtype=torch.half) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) + kernel = flashattn_bwd( + BATCH, + H, + N_CTX, + D_HEAD_QK, + D_HEAD_V, + causal, + block_M=128, + block_N=32, + threads=256, + num_stages=2, + groups=groups) + dQ = torch.zeros_like(Q, dtype=torch.float32) + dK = torch.zeros(BATCH, N_CTX, head_kv, D_HEAD_QK, device=device, dtype=torch.float32) + dV = torch.zeros(BATCH, N_CTX, head_kv, D_HEAD_V, device=device, dtype=torch.float32) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='Batch size') diff --git a/examples/flash_attention/example_gqa_fwd_bshd.py b/examples/flash_attention/example_gqa_fwd_bshd.py index a6d3b5f20..51fb84373 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd.py +++ b/examples/flash_attention/example_gqa_fwd_bshd.py @@ -268,6 +268,33 @@ def main(batch: int = 1, print(f"Ref latency: {ref_latency}") +def benchmark(batch: int = 1, + heads: int = 64, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + groups: int = 16, + tune: bool = False): + flops_per_matmul = 2.0 * batch * heads * seq_len * 
seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + kernel = flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + groups=groups, + block_M=64, + block_N=64, + num_stages=2, + threads=128) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(warmup=500) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=1, help='batch size') diff --git a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py index 03ad15e94..214a10389 100644 --- a/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_gqa_fwd_bshd_wgmma_pipelined.py @@ -242,6 +242,36 @@ def main( print(f"Ref latency: {ref_latency}") +def benchmark( + batch: int = 1, + heads: int = 64, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + groups: int = 16, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + kernel = flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + groups=groups, + block_M=128, + block_N=128, + num_stages=2, + threads=256) + + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(warmup=500) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=1, help='batch size') diff --git a/examples/flash_attention/example_mha_bwd_bhsd.py b/examples/flash_attention/example_mha_bwd_bhsd.py index d91d1770f..c2b8c5967 100644 --- a/examples/flash_attention/example_mha_bwd_bhsd.py +++ b/examples/flash_attention/example_mha_bwd_bhsd.py @@ -351,6 +351,37 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def benchmark(): + BATCH = 1 + H = 16 + N_CTX = 512 + D_HEAD = 64 + causal = False + device = "cuda" + torch.manual_seed(0) + block_M = 64 + block_N = 64 if D_HEAD <= 64 else 32 + Q = torch.randn(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.half) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + O = torch.randn_like(Q) + dO = torch.randn_like(Q) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + dQ = torch.zeros(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.float32) + dK = torch.zeros(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.float16) + dV = torch.zeros(BATCH, H, N_CTX, D_HEAD, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='Batch size') diff --git a/examples/flash_attention/example_mha_bwd_bshd.py b/examples/flash_attention/example_mha_bwd_bshd.py index 7c85f982e..99fa1c840 100644 --- a/examples/flash_attention/example_mha_bwd_bshd.py +++ b/examples/flash_attention/example_mha_bwd_bshd.py @@ -342,6 +342,37 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def benchmark(): + BATCH = 1 + H = 16 + N_CTX = 512 
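+    # Head dim of 64 keeps the block_N choice below at 64 (it drops to 32 for larger head dims).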
+ D_HEAD = 64 + causal = False + device = "cuda" + torch.manual_seed(42) + block_M = 64 + block_N = 64 if D_HEAD <= 64 else 32 + Q = torch.randn(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.half) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + O = torch.randn_like(Q) + dO = torch.randn_like(Q) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + dQ = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float32) + dK = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float16) + dV = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float16) + Delta = mod_prep(O, dO) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, warmup=100, rep=1000) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='Batch size') diff --git a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py index e8ee5d973..38faa7637 100644 --- a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py @@ -327,6 +327,38 @@ def run1(): print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) +def benchmark(): + BATCH = 1 + H = 32 + N_CTX = 256 + D_HEAD = 64 + causal = False + device = "cuda" + torch.manual_seed(0) + block_M = 128 + block_N = 128 if D_HEAD <= 64 else 32 + Q = torch.randn(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.half) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + O = torch.randn_like(Q) + dO = torch.randn_like(Q) + lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) + with torch.no_grad(): + mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) + kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) + dQ = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float32) + dK = torch.zeros_like(Q, dtype=torch.float16) + dV = torch.zeros_like(Q, dtype=torch.float16) + Delta = mod_prep(O, dO) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='Batch size') diff --git a/examples/flash_attention/example_mha_fwd_bhsd.py b/examples/flash_attention/example_mha_fwd_bhsd.py index e0e0bca22..e64511083 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd.py +++ b/examples/flash_attention/example_mha_fwd_bhsd.py @@ -225,6 +225,35 @@ def main( print(f"Ref latency: {ref_latency}") +def benchmark( + batch: int = 1, + heads: int = 1, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 64, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + is_causal, + block_M=64, + block_N=64, + num_stages=1, + threads=128) + profiler = kernel.get_profiler() + return profiler.do_bench(warmup=500) + + if __name__ == "__main__": parser = 
argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=1, help='batch size') diff --git a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py index b797bbcc6..0aba68031 100644 --- a/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bhsd_wgmma_pipelined.py @@ -228,6 +228,34 @@ def main( print(f"Ref latency: {ref_latency}") +def benchmark( + batch: int = 1, + heads: int = 32, + seq_q: int = 256, + seq_kv: int = 256, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_q * seq_kv * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + kernel = flashattn( + batch, + heads, + seq_q, + seq_kv, + dim, + is_causal, + block_M=128, + block_N=128, + num_stages=2, + threads=256) + profiler = kernel.get_profiler() + return profiler.do_bench(warmup=500) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='batch size') diff --git a/examples/flash_attention/example_mha_fwd_bshd.py b/examples/flash_attention/example_mha_fwd_bshd.py index b5b728287..a847153fc 100644 --- a/examples/flash_attention/example_mha_fwd_bshd.py +++ b/examples/flash_attention/example_mha_fwd_bshd.py @@ -211,6 +211,25 @@ def main( print(f"Ref latency: {ref_latency}") +def benchmark( + batch: int = 8, + heads: int = 32, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + kernel = flashattn( + batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=1, threads=128) + profiler = kernel.get_profiler() + return profiler.do_bench(warmup=500) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='batch size') diff --git a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py index 02d8baef2..71ca474f2 100644 --- a/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py +++ b/examples/flash_attention/example_mha_fwd_bshd_wgmma_pipelined.py @@ -216,6 +216,25 @@ def main( print(f"Ref latency: {ref_latency}") +def benchmark( + batch: int = 8, + heads: int = 32, + seq_len: int = 4096, + dim: int = 128, + is_causal: bool = False, + tune: bool = False, +): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + kernel = flashattn( + batch, heads, seq_len, dim, is_causal, block_M=128, block_N=128, num_stages=2, threads=256) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(warmup=500) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='batch size') diff --git a/examples/flash_attention/example_mha_fwd_varlen.py b/examples/flash_attention/example_mha_fwd_varlen.py index bbb4546ca..cd780998d 100644 --- a/examples/flash_attention/example_mha_fwd_varlen.py +++ b/examples/flash_attention/example_mha_fwd_varlen.py @@ -285,6 +285,50 @@ def main(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): print("All checks passed.āœ…") +def 
benchmark(batch: int = 8, heads: int = 64, seq_len: int = 2048, dim: int = 128): + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + tilelang.testing.set_random_seed(0) + causal = False + if causal: + total_flops *= 0.5 + dtype = torch.float16 + device = torch.device("cuda") + window_size = (-1, -1) + q = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + k = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + v = torch.randn(batch, seq_len, heads, dim, dtype=dtype, requires_grad=True).to(device) + query_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random") + key_padding_mask = generate_random_padding_mask(seq_len, batch, device, mode="random") + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + dq_pad_fn, + dk_pad_fn, + ) = generate_qkv( + q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + UQ = q_unpad.shape[0] + UK = k_unpad.shape[0] + UKV = k_unpad.shape[0] + kernel = flashattn(batch, UQ, UKV, heads, dim, causal) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=8, help='batch size') diff --git a/examples/flash_decoding/bench_example_flash_decoding.py b/examples/flash_decoding/bench_example_flash_decoding.py new file mode 100644 index 000000000..66dd3d1eb --- /dev/null +++ b/examples/flash_decoding/bench_example_flash_decoding.py @@ -0,0 +1,22 @@ +import tilelang.tools.bench +import example_gqa_decode +import example_mha_inference + + +def bench_example_gqa_decode(): + tilelang.tools.bench.process_func(example_gqa_decode.benchmark) + + +def bench_example_mha_inference(): + tilelang.tools.bench.process_func( + example_mha_inference.benchmark, + BATCH=1, + H=32, + Q_CTX=128, + KV_CTX=2048, + D_HEAD=128, + causal=False) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/flash_decoding/example_gqa_decode.py b/examples/flash_decoding/example_gqa_decode.py index 46d9beeaa..2e1c392cd 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -494,6 +494,19 @@ def main(batch: int = 1, print(f"Ref latency: {ref_latency}") +def benchmark(batch: int = 1, + heads: int = 32, + groups: int = 8, + kv_seqlen: int = 8192, + dim: int = 128, + tune: bool = False): + batch, heads, groups, kv_seqlen, dim = batch, heads, groups, kv_seqlen, dim + config, sm_version = get_heuristic_config() + kernel = flashattn(batch, heads, groups, kv_seqlen, dim, **config) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto) + return profiler.do_bench(warmup=500) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--batch', type=int, default=1, help='batch size') diff --git a/examples/flash_decoding/example_mha_inference.py b/examples/flash_decoding/example_mha_inference.py index 0360b3e2b..bf0bd82a9 100644 --- a/examples/flash_decoding/example_mha_inference.py +++ b/examples/flash_decoding/example_mha_inference.py @@ -325,5 +325,17 @@ def main(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): print("{:.2f} 
TFlops".format(total_flops / latency * 1e-9)) +def benchmark(BATCH=1, H=32, Q_CTX=128, KV_CTX=8192, D_HEAD=128, causal=False): + flops_per_matmul = 2.0 * BATCH * H * Q_CTX * KV_CTX * D_HEAD + total_flops = 2 * flops_per_matmul + if causal: + total_flops *= 0.5 + BLOCK_M = 128 + BLOCK_N = 64 + kernel = flashattn(BATCH, H, Q_CTX, KV_CTX, D_HEAD, causal, BLOCK_M, BLOCK_N) + profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal) + return profiler.do_bench(n_warmup=10, n_repeat=10) + + if __name__ == "__main__": main() diff --git a/examples/fusedmoe/bench_example_fusedmoe.py b/examples/fusedmoe/bench_example_fusedmoe.py new file mode 100644 index 000000000..65698fcdd --- /dev/null +++ b/examples/fusedmoe/bench_example_fusedmoe.py @@ -0,0 +1,18 @@ +import tilelang.tools.bench +import example_fusedmoe_tilelang + + +def bench_example_fusedmoe_tilelang(): + tilelang.tools.bench.process_func( + example_fusedmoe_tilelang.benchmark, + d_hidden=1024, + d_expert=256, + n_routed_experts=8, + n_shared_experts=1, + n_experts_per_token=4, + batch_size=1, + seq_len=1024) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py b/examples/fusedmoe/example_fusedmoe_tilelang.py index a8d684965..b7c89d08a 100644 --- a/examples/fusedmoe/example_fusedmoe_tilelang.py +++ b/examples/fusedmoe/example_fusedmoe_tilelang.py @@ -551,5 +551,31 @@ def main(d_hidden=7168, print("āœ… Tilelang and Torch match") +def benchmark(d_hidden=7168, + d_expert=2048, + n_routed_experts=8, + n_shared_experts=1, + n_experts_per_token=4, + batch_size=1, + seq_len=8192): + config = { + "dhidden": d_hidden, + "dexpert": d_expert, + "nroutedexperts": n_routed_experts, + "nsharedexperts": n_shared_experts, + "nexpertspertoken": n_experts_per_token, + "bs": batch_size, + "seqlen": seq_len, + "seed": 81394 + } + from tilelang.profiler import do_bench + data = generate_input(**config) + + def run_custom_kernel(): + custom_kernel(data).to(torch.float32) + + return do_bench(run_custom_kernel, warmup=100, rep=1000) + + if __name__ == "__main__": main() diff --git a/examples/gemm/bench_example_gemm.py b/examples/gemm/bench_example_gemm.py new file mode 100644 index 000000000..78576f10d --- /dev/null +++ b/examples/gemm/bench_example_gemm.py @@ -0,0 +1,25 @@ +import tilelang.tools.bench +import example_gemm +import example_gemm_autotune +import example_gemm_intrinsics +import example_gemm_schedule + + +def bench_example_gemm_autotune(): + tilelang.tools.bench.process_func(example_gemm_autotune.benchmark, M=1024, N=1024, K=1024) + + +def bench_example_gemm_intrinsics(): + tilelang.tools.bench.process_func(example_gemm_intrinsics.benchmark, M=1024, N=1024, K=1024) + + +def bench_example_gemm_schedule(): + tilelang.tools.bench.process_func(example_gemm_schedule.benchmark) + + +def bench_example_gemm(): + tilelang.tools.bench.process_func(example_gemm.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/gemm/example_gemm.py b/examples/gemm/example_gemm.py index f18cd388a..35fb4836b 100644 --- a/examples/gemm/example_gemm.py +++ b/examples/gemm/example_gemm.py @@ -58,5 +58,11 @@ def main(): print(f"tilelang Latency: {latency}ms") +def benchmark(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + profiler = kernel.get_profiler() + return profiler.do_bench(backend="cupti") + + if __name__ == "__main__": main() diff --git a/examples/gemm/example_gemm_autotune.py 
b/examples/gemm/example_gemm_autotune.py index 661ef1276..e406a2d85 100644 --- a/examples/gemm/example_gemm_autotune.py +++ b/examples/gemm/example_gemm_autotune.py @@ -261,6 +261,14 @@ def main(M: int = 4096, print(f"Ref TFlops: {2 * M * N * K / ref_latency * 1e-9}") +def benchmark(M: int = 4096, N: int = 4096, K: int = 4096): + + config = get_heuristic_config() + kernel = matmul(M, N, K, **config) + profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) + return profiler.do_bench() + + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Autotuned MatMul Benchmark") parser.add_argument("--m", type=int, default=4096, help="Matrix dimension M") diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index 5c014ce3a..fad92804b 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -181,5 +181,12 @@ def main(M=4096, N=4096, K=4096): profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) +def benchmark(M=4096, N=4096, K=4096): + in_dtype, out_dtype, accum_dtype = "float16", "float16", "float32" + kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + profiler = kernel.get_profiler() + return profiler.do_bench(profiler.func, warmup=25) + + if __name__ == "__main__": main(M=4096, N=4096, K=4096) diff --git a/examples/gemm/example_gemm_persistent.py b/examples/gemm/example_gemm_persistent.py index a2a7122d3..b8e9b7318 100644 --- a/examples/gemm/example_gemm_persistent.py +++ b/examples/gemm/example_gemm_persistent.py @@ -149,6 +149,18 @@ def main(M=4096, N=4096, K=4096): print(f"Persistent GEMM Speedup: {non_persistent_latency / persistent_latency}") +def benchmark(M=4096, N=4096, K=4096): + BLOCK_M = 128 + BLOCK_N = 256 + BLOCK_K = 64 + threads = 256 + num_stages = 3 + persistent_kernel = matmul_persistent(M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, threads, num_stages) + persistent_profiler = persistent_kernel.get_profiler( + tensor_supply_type=tilelang.TensorSupplyType.Randn) + return persistent_profiler.do_bench(warmup=500) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--M', type=int, default=8192, help='M dimension') diff --git a/examples/gemm/example_gemm_schedule.py b/examples/gemm/example_gemm_schedule.py index f4727412b..b74e98c30 100644 --- a/examples/gemm/example_gemm_schedule.py +++ b/examples/gemm/example_gemm_schedule.py @@ -65,5 +65,18 @@ def main(): print(kernel.get_kernel_source()) +def benchmark(): + kernel = matmul(1024, 1024, 1024, 128, 128, 32) + import torch + a = torch.randn(1024, 1024).cuda().half() + b = torch.randn(1024, 1024).cuda().half() + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(a, b) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": main() diff --git a/examples/gemm_fp8/bench_example_gemm_fp8.py b/examples/gemm_fp8/bench_example_gemm_fp8.py new file mode 100644 index 000000000..9b82da922 --- /dev/null +++ b/examples/gemm_fp8/bench_example_gemm_fp8.py @@ -0,0 +1,20 @@ +import tilelang.tools.bench +import example_tilelang_gemm_fp8 +import example_tilelang_gemm_fp8_2xAcc +import example_tilelang_gemm_fp8_intrinsic + + +def bench_example_tilelang_gemm_fp8_2xAcc(): + tilelang.tools.bench.process_func(example_tilelang_gemm_fp8_2xAcc.benchmark) + + +def bench_example_tilelang_gemm_fp8_intrinsic(): + tilelang.tools.bench.process_func(example_tilelang_gemm_fp8_intrinsic.benchmark) + + +def bench_example_tilelang_gemm_fp8(): + 
tilelang.tools.bench.process_func(example_tilelang_gemm_fp8.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8.py b/examples/gemm_fp8/example_tilelang_gemm_fp8.py index a403ed068..cb843707e 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8.py @@ -61,5 +61,18 @@ def main(): test_gemm_fp8(1024, 1024, 1024, 'float8_e5m2') +def benchmark(): + M, N, K = 1024, 1024, 1024 + dtype = "float8_e4m3" + kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e4m3 = profiler_e4m3.do_bench(warmup=25) + dtype = "float8_e5m2" + kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(warmup=25) + return (latency_e4m3 + latency_e5m2) / 2 + + if __name__ == "__main__": main() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py index 1d9207aff..650b0fa55 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -78,5 +78,18 @@ def main(): test_gemm_fp8(1024, 1024, 8192, 'float8_e5m2') +def benchmark(): + M, N, K = 1024, 1024, 8192 + dtype = "float8_e4m3" + kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e4m3 = profiler_e4m3.do_bench(warmup=25) + dtype = "float8_e5m2" + kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(warmup=25) + return (latency_e4m3 + latency_e5m2) / 2 + + if __name__ == "__main__": main() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index ed44aab69..c8808fdfa 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -219,5 +219,19 @@ def main(): assert_tl_matmul_correctness(128, 128, 128, "float8_e5m2", "float32", "float32") +def benchmark(): + M, N, K = 128, 128, 128 + out_dtype, accum_dtype = "float32", "float32" + in_dtype = "float8_e4m3" + kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e4m3 = profiler_e4m3.do_bench(warmup=25) + in_dtype = "float8_e5m2" + kernel_e5m2 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(warmup=25) + return (latency_e4m3 + latency_e5m2) / 2 + + if __name__ == "__main__": main() diff --git a/examples/gemm_splitk/bench_example_gemm_splitk.py b/examples/gemm_splitk/bench_example_gemm_splitk.py new file mode 100644 index 000000000..9b99924e5 --- /dev/null +++ b/examples/gemm_splitk/bench_example_gemm_splitk.py @@ -0,0 +1,15 @@ +import tilelang.tools.bench +import example_tilelang_gemm_splitk +import example_tilelang_gemm_splitk_vectorize_atomicadd + + +def bench_example_tilelang_gemm_splitk(): + tilelang.tools.bench.process_func(example_tilelang_gemm_splitk.benchmark) + + +def bench_example_tilelang_gemm_splitk_vectorize_atomicadd(): + 
tilelang.tools.bench.process_func(example_tilelang_gemm_splitk_vectorize_atomicadd.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk.py b/examples/gemm_splitk/example_tilelang_gemm_splitk.py index c96669711..acadd5aee 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk.py @@ -67,5 +67,27 @@ def main(): torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) +def benchmark(): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + split_k = 4 + kernel = matmul(M, N, K, block_M, block_N, block_K, split_k) + import torch + torch.random.manual_seed(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + c = torch.zeros(M, N).cuda().float() + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(a, b, c) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": main() diff --git a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py index 145d622ed..c114f8ff6 100644 --- a/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py +++ b/examples/gemm_splitk/example_tilelang_gemm_splitk_vectorize_atomicadd.py @@ -66,5 +66,28 @@ def main(): torch.testing.assert_close(c, ref_c.to(c.dtype), rtol=1e-2, atol=1e-2) +def benchmark(): + M = 1024 + N = 1024 + K = 1024 + block_M = 128 + block_N = 128 + block_K = 32 + split_k = 4 + + kernel = matmul(M, N, K, block_M, block_N, block_K, split_k) + import torch + torch.random.manual_seed(42) + a = torch.randn(M, K).cuda().half() + b = torch.randn(K, N).cuda().half() + c = torch.zeros(M, N).cuda().float() + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(a, b, c) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": main() diff --git a/examples/gemm_streamk/bench_example_tilelang_gemm_splitk.py b/examples/gemm_streamk/bench_example_tilelang_gemm_splitk.py new file mode 100644 index 000000000..54d908a14 --- /dev/null +++ b/examples/gemm_streamk/bench_example_tilelang_gemm_splitk.py @@ -0,0 +1,10 @@ +import tilelang.tools.bench +import example_tilelang_gemm_streamk + + +def bench_example_tilelang_gemm_streamk(): + tilelang.tools.bench.process_func(example_tilelang_gemm_streamk.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py index 31cf40647..09c639be9 100644 --- a/examples/gemm_streamk/example_tilelang_gemm_streamk.py +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -201,5 +201,32 @@ def main(): torch.testing.assert_close(C, b_c, rtol=1e-2, atol=1e-2) +def benchmark(): + kernel = tl_matmul_streamk( + m, + n, + k, + streamk_tiles, + BLOCK_SIZE_M, + BLOCK_SIZE_N, + BLOCK_SIZE_K, + False, + True, + "float16", + "float16", + "float32", + 2, + 64, + ) + b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16) + kernel(A, B, b_c) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(A, B, b_c) + + return do_bench(run_kernel_only, warmup=10, rep=100) + + if __name__ == "__main__": main() diff --git a/examples/gemv/bench_example_gemv.py b/examples/gemv/bench_example_gemv.py new file mode 100644 index 
000000000..33bf53723 --- /dev/null +++ b/examples/gemv/bench_example_gemv.py @@ -0,0 +1,10 @@ +import tilelang.tools.bench +import example_gemv + + +def bench_example_gemv(): + tilelang.tools.bench.process_func(example_gemv.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py index 58e0114be..0c11e7837 100644 --- a/examples/gemv/example_gemv.py +++ b/examples/gemv/example_gemv.py @@ -375,5 +375,23 @@ def main(do_bench: bool = True): print(f"TileLang BlockReduce Latency: {tilelang_tile_latency} ms\n") +def benchmark(): + N, K = 1024, 1024 + latency = 0.0 + kernel_list = [ + naive_gemv(N, K, 128, 128), + naive_splitk_gemv(N, K, 32, 32), + splitk_gemv(N, K, 32, 32, 32), + splitk_gemv_vectorized(N, K, 2, 32), + splitk_gemv_vectorized_tvm(N, K, 2, 32), + gemv_alloc_reducer(N, K, block_M=128, block_N=128) + ] + for kernel in kernel_list: + profiler = kernel.get_profiler() + # Benchmark the TileLang kernel itself, not the PyTorch reference. + latency += profiler.do_bench(warmup=50) + return latency / len(kernel_list) + + if __name__ == "__main__": main() diff --git a/examples/linear_attention/bench_linear_attn.py b/examples/linear_attention/bench_linear_attn.py new file mode 100644 index 000000000..68be8bdfd --- /dev/null +++ b/examples/linear_attention/bench_linear_attn.py @@ -0,0 +1,15 @@ +import tilelang.tools.bench +import example_linear_attn_bwd +import example_linear_attn_fwd + + +def bench_example_linear_attn_fwd(): + tilelang.tools.bench.process_func(example_linear_attn_fwd.benchmark) + + +def bench_example_linear_attn_bwd(): + tilelang.tools.bench.process_func(example_linear_attn_bwd.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/linear_attention/example_linear_attn_bwd.py b/examples/linear_attention/example_linear_attn_bwd.py index 568bcc55f..0f60193d4 100644 --- a/examples/linear_attention/example_linear_attn_bwd.py +++ b/examples/linear_attention/example_linear_attn_bwd.py @@ -211,6 +211,21 @@ def main(B=1, S=1024, H=16, D=128): print(f'Speedup: {t1/t2:.3f}x') +def benchmark(B=1, S=1024, H=16, D=128): + q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) + k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) + v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True) + do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + q = l2norm_fwd(q)[0].requires_grad_(True) + k = l2norm_fwd(k)[0].requires_grad_(True) + kernel = tl_fused_chunk_bwd_kernel(B, S, H, D, D) + dQ = torch.zeros_like(q, dtype=torch.float32) + dK = torch.zeros_like(k, dtype=torch.float32) + dV = torch.zeros_like(v, dtype=torch.float32) + kernel(q, k, v, do, dQ, dK, dV) + return do_bench(lambda: kernel(q, k, v, do, dQ, dK, dV), backend='cupti') + + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--B', type=int, default=8, help='Batch size') diff --git a/examples/linear_attention/example_linear_attn_fwd.py b/examples/linear_attention/example_linear_attn_fwd.py index 03900a7e6..bbf1c9b79 100644 --- a/examples/linear_attention/example_linear_attn_fwd.py +++ b/examples/linear_attention/example_linear_attn_fwd.py @@ -144,6 +144,18 @@ def main(B=1, S=512, H=16, D=128): print(f'Speedup: {t1/t2:.3f}x') +def benchmark(B=1, S=512, H=16, D=128): + q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + k = 
torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16) + q, _ = l2norm_fwd(q) + k, _ = l2norm_fwd(k) + B, S, H, D = q.shape + kernel = tl_fused_chunk_fwd_kernel(B, S, H, D, D) + o = torch.zeros((B, S, H, D), device='cuda', dtype=torch.float32) + return do_bench(lambda: kernel(q, k, v, o), backend='cupti') + + if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--B', type=int, default=8, help='Batch size') diff --git a/examples/minference/bench_vs_sparse_attn.py b/examples/minference/bench_vs_sparse_attn.py new file mode 100644 index 000000000..373719669 --- /dev/null +++ b/examples/minference/bench_vs_sparse_attn.py @@ -0,0 +1,10 @@ +import tilelang.tools.bench +import example_vertical_slash_sparse_attn + + +def bench_example_vertical_slash_sparse_attn(): + tilelang.tools.bench.process_func(example_vertical_slash_sparse_attn.benchmark, argv=[]) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/minference/example_vertical_slash_sparse_attn.py b/examples/minference/example_vertical_slash_sparse_attn.py index 48df3e091..aa133639c 100644 --- a/examples/minference/example_vertical_slash_sparse_attn.py +++ b/examples/minference/example_vertical_slash_sparse_attn.py @@ -596,5 +596,70 @@ def main(argv=None): print(f"speedup: {triton_time / tilelang_time:.2f}x") +def benchmark(argv=None): + parser = argparse.ArgumentParser() + + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--heads", type=int, default=1) + parser.add_argument("--seq_len", type=int, default=16384) + parser.add_argument("--head_dim", type=int, default=64) + parser.add_argument("--vertical_size", type=int, default=1000) + parser.add_argument("--slash_size", type=int, default=200) + + args = parser.parse_args(argv) + + BATCH, N_HEADS, SEQ_LEN, D_HEAD = args.batch, args.heads, args.seq_len, args.head_dim + + vertical_size, slash_size = args.vertical_size, args.slash_size + + torch.manual_seed(0) + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + + q_len = SEQ_LEN + + vertical_size, slash_size = min(q_len, vertical_size), min(q_len, slash_size) + last_q = 64 + qk = torch.einsum('bhmk, bhnk -> bhmn', q[:, :, -last_q:, :], k) + arange = torch.arange(last_q, device="cuda") + qk[:, :, :, -last_q:] = torch.where(arange[None, None, :, None] >= arange[None, None, None, :], + qk[:, :, :, -last_q:], -torch.inf) + qk = torch.nn.functional.softmax(qk, dim=-1, dtype=torch.float32) + vertical = qk.sum(-2, keepdim=True) + vertical[..., :30] = torch.inf + vertical_topk = torch.topk(vertical, vertical_size, -1).indices + + slash = sum_all_diagonal_matrix(qk)[..., :-last_q + 1] + slash[..., -30:] = torch.inf + + slash = (q_len - 1) - torch.topk(slash, slash_size, -1).indices + + block_size_M = 64 + batch_size, num_heads, context_size, head_dim = q.shape + pad = (block_size_M - context_size) & (block_size_M - 1) + if pad == block_size_M: + pad = 0 + q = torch.nn.functional.pad(q, [0, 0, 0, pad, 0, 0, 0, 0]) + k = torch.nn.functional.pad(k, [0, 0, 0, pad, 0, 0, 0, 0]) + v = torch.nn.functional.pad(v, [0, 0, 0, pad, 0, 0, 0, 0]) + + if head_dim not in [16, 32, 64, 128, 256, 512]: + target_dim = 2**math.ceil(math.log2(head_dim)) - head_dim + q = 
torch.nn.functional.pad(q, [0, target_dim, 0, 0, 0, 0, 0, 0]) + k = torch.nn.functional.pad(k, [0, target_dim, 0, 0, 0, 0, 0, 0]) + v = torch.nn.functional.pad(v, [0, target_dim, 0, 0, 0, 0, 0, 0]) + + vertical_topk = vertical_topk.to(torch.int32).reshape((batch_size, num_heads, -1)).sort( + dim=-1, descending=False)[0] + slash = slash.to(torch.int32).reshape((batch_size, num_heads, -1)).sort( + dim=-1, descending=True)[0] + + tl_kernel = _tl_vs_sparse_flashattn(batch_size, num_heads, context_size, head_dim, + vertical_topk.shape[2], slash.shape[2]) + + return do_bench(lambda: tl_kernel) + + if __name__ == "__main__": main() diff --git a/examples/seer_attention/bench_block_sparse_attn_tilelang.py b/examples/seer_attention/bench_block_sparse_attn_tilelang.py new file mode 100644 index 000000000..a7a7a4643 --- /dev/null +++ b/examples/seer_attention/bench_block_sparse_attn_tilelang.py @@ -0,0 +1,10 @@ +import tilelang.tools.bench +import block_sparse_attn_tilelang + + +def bench_block_sparse_attn_tilelang(): + tilelang.tools.bench.process_func(block_sparse_attn_tilelang.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/seer_attention/block_sparse_attn_tilelang.py b/examples/seer_attention/block_sparse_attn_tilelang.py index 219d3ee35..c6ca0b106 100644 --- a/examples/seer_attention/block_sparse_attn_tilelang.py +++ b/examples/seer_attention/block_sparse_attn_tilelang.py @@ -266,5 +266,61 @@ def main(): test_topk_sparse_attention_qlen_lt_klen() +def benchmark(): + BATCH, N_HEADS, SEQ_LEN, D_HEAD = 4, 2, 256, 64 + TOPK = 2 + BLOCK = 64 + torch.manual_seed(0) + + q = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, SEQ_LEN, D_HEAD, device='cuda', dtype=torch.float16) + + downsample_factor = BLOCK + downsample_len = math.ceil(SEQ_LEN / downsample_factor) + x_ds = torch.randn([BATCH, N_HEADS, downsample_len, downsample_len], + device='cuda', + dtype=torch.float16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + kernel = blocksparse_flashattn( + BATCH, N_HEADS, SEQ_LEN, SEQ_LEN, D_HEAD, downsample_len, is_causal=True) + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(q, k, v, block_mask.to(torch.int8)) + + latency_1 = do_bench(run_kernel_only) + + BATCH, N_HEADS = 1, 1 + Q_LEN, K_LEN, D_HEAD = 128, 256, 64 + TOPK = 1 + BLOCK = 64 + torch.manual_seed(0) + + q = torch.randn(BATCH, N_HEADS, Q_LEN, D_HEAD, device='cuda', dtype=torch.float16) + k = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.float16) + v = torch.randn(BATCH, N_HEADS, K_LEN, D_HEAD, device='cuda', dtype=torch.float16) + + downsample_factor = BLOCK + downsample_len = math.ceil(K_LEN / downsample_factor) + x_ds = torch.randn( + BATCH, N_HEADS, downsample_len, downsample_len, device='cuda', dtype=torch.float16) + x_ds[:, :, :, 0] = 100 + block_mask = get_sparse_attn_mask_from_topk(x_ds, topk=TOPK) + + kernel = blocksparse_flashattn( + BATCH, N_HEADS, Q_LEN, K_LEN, D_HEAD, downsample_len, is_causal=True) + print(kernel.get_kernel_source()) + + def run_kernel_only2(): + kernel(q, k, v, block_mask.to(torch.int8)) + + latency_2 = do_bench(run_kernel_only2) + + return (latency_1 + latency_2) / 2 + + if __name__ == "__main__": main() diff --git a/examples/sparse_tensorcore/bench_example_sparse_tensorcore.py 
b/examples/sparse_tensorcore/bench_example_sparse_tensorcore.py new file mode 100644 index 000000000..40f491e30 --- /dev/null +++ b/examples/sparse_tensorcore/bench_example_sparse_tensorcore.py @@ -0,0 +1,11 @@ +import tilelang.tools.bench +import tilelang +import tilelang_example_sparse_tensorcore + + +def bench_example_sparse_tensorcore(): + tilelang.tools.bench.process_func(tilelang_example_sparse_tensorcore.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py index 59c79c283..b4386c4af 100644 --- a/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py +++ b/examples/sparse_tensorcore/tilelang_example_sparse_tensorcore.py @@ -120,5 +120,32 @@ def main(): run_gemm_sp(512, 1024, 768, "float16", "float16", "float32", 128, 128, 128, 2, 128) +def benchmark(): + M, N, K, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages, num_threads = 512, 1024, 768, 128, 128, 128, "float16", "float16", "float32", 2, 128 + kernel = matmul_sp( + M, + N, + K, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + num_threads, + ) + A = generate_2_to_4_sparse_tensor((M, K), dtype=torch.float16, device='cuda') + A_sparse, E = compress_sm90(A, block_k=block_K, transposed=False) + B = torch.randn((K, N), device='cuda', dtype=torch.float16) + + from tilelang.profiler import do_bench + + def run_kernel_only(): + kernel(A_sparse, E, B) + + return do_bench(run_kernel_only) + + if __name__ == "__main__": main() diff --git a/examples/topk/bench_topk_tilelang.py b/examples/topk/bench_topk_tilelang.py new file mode 100644 index 000000000..2ec910743 --- /dev/null +++ b/examples/topk/bench_topk_tilelang.py @@ -0,0 +1,10 @@ +import tilelang.tools.bench +import example_topk + + +def bench_example_topk(): + tilelang.tools.bench.process_func(example_topk.benchmark) + + +if globals().get("__name__") == "__main__": + tilelang.tools.bench.main() diff --git a/examples/topk/example_topk.py b/examples/topk/example_topk.py index 0ca19fb18..db46fd7c3 100644 --- a/examples/topk/example_topk.py +++ b/examples/topk/example_topk.py @@ -93,5 +93,29 @@ def main(argv=None): print(f"Tilelang latency: {tilelang_latency}") +def benchmark(argv=None): + parser = argparse.ArgumentParser() + parser.add_argument("--M", type=int, default=320, help="num_tokens") + parser.add_argument("--N", type=int, default=128, help="num_experts") + parser.add_argument("--topk", type=int, default=6, help="topk") + parser.add_argument("--blk_m", type=int, default=64, help="blk_m") + # In benchmark mode, ignore process-wide sys.argv unless an explicit argv is provided. 
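+    # With argv=None this parses an empty list, so the defaults above apply instead of the harness's sys.argv.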
+    args = parser.parse_args(argv or [])
+    M, N, topk, blk_m = args.M, args.N, args.topk, args.blk_m
+
+    logits = torch.rand((M, N), device="cuda", dtype=torch.float32)
+
+    kernel = tl_topk(M=M, N=N, topk=topk, blk_m=blk_m)
+    tl_gates, tl_indices = kernel(logits)
+
+    torch_gates, torch_indices = ref_program(logits, topk)
+
+    torch.testing.assert_close(tl_gates, torch_gates)
+    torch.testing.assert_close(tl_indices, torch_indices)
+
+    profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Auto)
+    return profiler.do_bench()
+
+
 if __name__ == "__main__":
     main()
diff --git a/examples/warp_specialize/bench_example_warp_specialize.py b/examples/warp_specialize/bench_example_warp_specialize.py
new file mode 100644
index 000000000..3d8c90db1
--- /dev/null
+++ b/examples/warp_specialize/bench_example_warp_specialize.py
@@ -0,0 +1,29 @@
+import tilelang.tools.bench
+import example_warp_specialize_gemm_barrierpipe_stage2
+import example_warp_specialize_gemm_copy_0_gemm_1
+import example_warp_specialize_gemm_copy_1_gemm_0
+import example_warp_specialize_gemm_softpipe_stage2
+
+
+def bench_example_warp_specialize_gemm_barrierpipe_stage2():
+    tilelang.tools.bench.process_func(
+        example_warp_specialize_gemm_barrierpipe_stage2.benchmark, M=1024, N=1024, K=1024)
+
+
+def bench_example_warp_specialize_gemm_copy_0_gemm_1():
+    tilelang.tools.bench.process_func(
+        example_warp_specialize_gemm_copy_0_gemm_1.benchmark, M=1024, N=1024, K=1024)
+
+
+def bench_example_warp_specialize_gemm_copy_1_gemm_0():
+    tilelang.tools.bench.process_func(
+        example_warp_specialize_gemm_copy_1_gemm_0.benchmark, M=1024, N=1024, K=1024)
+
+
+def bench_example_warp_specialize_gemm_softpipe_stage2():
+    tilelang.tools.bench.process_func(
+        example_warp_specialize_gemm_softpipe_stage2.benchmark, M=1024, N=1024, K=1024)
+
+
+if globals().get("__name__") == "__main__":
+    tilelang.tools.bench.main()
diff --git a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
index b738a4b9c..0c22d41af 100644
--- a/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
+++ b/examples/warp_specialize/example_warp_specialize_gemm_barrierpipe_stage2.py
@@ -89,5 +89,28 @@ def main(M=16384, N=16384, K=16384):
     print(f"Latency: {latency} ms")
 
+
+def benchmark(M=16384, N=16384, K=16384):
+    tilelang.disable_cache()
+    block_M = 128
+    block_N = 128
+    block_K = 64
+    jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
+
+    import torch
+
+    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.float16)
+
+    c = jit_kernel(a, b)
+
+    ref_c = a @ b
+
+    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+
+    profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
+
+    return profiler.do_bench()
+
+
 if __name__ == "__main__":
     main()
diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py
index 9ba9f6816..6649953e0 100644
--- a/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py
+++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_0_gemm_1.py
@@ -82,5 +82,27 @@ def main(M=1024, N=1024, K=1024):
     print(f"Latency: {latency} ms")
 
+
+def benchmark(M=1024, N=1024, K=1024):
+    block_M = 128
+    block_N = 128
+    block_K = 64
+
+    jit_kernel = matmul_warp_specialize_copy_0_gemm_1(M, N, K, block_M, block_N, block_K)
+
+    import torch
+
+    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.float16)
+
+    c = jit_kernel(a, b)
+    ref_c = a @ b
+
+    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+
+    profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
+
+    return profiler.do_bench()
+
+
 if __name__ == "__main__":
     main()
diff --git a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py
index faaf48c64..47f011a22 100644
--- a/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py
+++ b/examples/warp_specialize/example_warp_specialize_gemm_copy_1_gemm_0.py
@@ -83,5 +83,28 @@ def main(M=16384, N=16384, K=16384):
     print(f"Latency: {latency} ms")
 
+
+def benchmark(M=16384, N=16384, K=16384):
+    block_M = 128
+    block_N = 128
+    block_K = 64
+
+    jit_kernel = matmul_warp_specialize_copy_1_gemm_0(M, N, K, block_M, block_N, block_K)
+
+    import torch
+
+    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.float16)
+
+    c = jit_kernel(a, b)
+
+    ref_c = a @ b
+
+    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+
+    profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
+
+    return profiler.do_bench()
+
+
 if __name__ == "__main__":
     main()
diff --git a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py
index 3b1d86719..37c587857 100644
--- a/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py
+++ b/examples/warp_specialize/example_warp_specialize_gemm_softpipe_stage2.py
@@ -79,5 +79,28 @@ def main(M=16384, N=16384, K=16384):
     print(f"Latency: {latency} ms")
 
+
+def benchmark(M=16384, N=16384, K=16384):
+    block_M = 128
+    block_N = 128
+    block_K = 64
+
+    jit_kernel = matmul(M, N, K, block_M, block_N, block_K)
+
+    import torch
+
+    a = torch.randn(M, K, device="cuda", dtype=torch.float16)
+    b = torch.randn(K, N, device="cuda", dtype=torch.float16)
+
+    c = jit_kernel(a, b)
+
+    ref_c = a @ b
+
+    torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2)
+
+    profiler = jit_kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Normal)
+
+    return profiler.do_bench()
+
+
 if __name__ == "__main__":
     main()
diff --git a/maint/scripts/bench_entry.py b/maint/scripts/bench_entry.py
new file mode 100644
index 000000000..69d2ad02e
--- /dev/null
+++ b/maint/scripts/bench_entry.py
@@ -0,0 +1,4 @@
+import tilelang.tools.bench as b
+
+if __name__ == "__main__":
+    b.bench_all()
diff --git a/maint/scripts/ci_performance.py b/maint/scripts/ci_performance.py
index 998e7b650..3a1ac876c 100644
--- a/maint/scripts/ci_performance.py
+++ b/maint/scripts/ci_performance.py
@@ -1,49 +1,75 @@
 import subprocess
 import re
 from tabulate import tabulate
+import tilelang
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
 
-import os
-
-env = os.environ.copy()
-env["TILELANG_CLEAR_CACHE"] = "1"
+tilelang.disable_cache()
 
 
 def parse_output(output):
     data = {}
     for line in output.split('\n'):
         line = line.strip()
-        if line.startswith('Latency:'):
-            match = re.search(r'Latency: ([\d.]+)', line)
-            data['latency'] = match.group(1) if match else 'N/A'
-        elif line.startswith('TFlops:'):
-            match = re.search(r'TFlops: ([\d.]+)', line)
-            data['best_tflops'] = match.group(1) if match else 'N/A'
-        elif line.startswith('Config:'):
-            data['config'] = line.split('Config: ')[-1]
-        elif line.startswith('Reference TFlops:'):
-            match = re.search(r'Reference TFlops: ([\d.]+)', line)
-            data['ref_tflops'] = match.group(1) if match else 'N/A'
+        m = re.match(r"\|\s*([^\|]+)\s*\|\s*([0-9\.]+)\s*\|", line)
+        if m is not None:
+            data[m.group(1).strip()] = float(m.group(2))
     return data
 
 
-output_v1 = subprocess.run(['./tl/bin/python', './maint/scripts/performance.py'],
-                           capture_output=True,
-                           text=True,
-                           env=env).stdout
-data_v1 = parse_output(output_v1)
+output_v1 = subprocess.run(
+    ['./tl/bin/python', '-c', 'import tilelang.tools.bench as b; b.bench_all()'],
+    capture_output=True,
+    text=True).stdout
+output_v2 = subprocess.run(
+    ['./tll/bin/python', '-c', 'import tilelang.tools.bench as b; b.bench_all()'],
+    capture_output=True,
+    text=True).stdout
 
-output_v2 = subprocess.run(['./tll/bin/python', './maint/scripts/performance.py'],
-                           capture_output=True,
-                           text=True,
-                           env=env).stdout
+data_v1 = parse_output(output_v1)
 data_v2 = parse_output(output_v2)
 
+table = []
+for key in data_v1.keys() & data_v2.keys():  # only compare benchmarks present in both runs
+    speedup = data_v1[key] / data_v2[key]
+    table.append([key, data_v1[key], data_v2[key], speedup])
+table.sort(key=lambda x: x[-1])
+
+headers = ["File", "Original Latency", "Current Latency", "Speedup"]
+
+with open("bench.md", "w") as f:
+    f.write(
+        tabulate(table, headers=headers, tablefmt="github", stralign="left", numalign="decimal"))
+    f.write("\n")
 
-table = [[
-    "original", data_v1['latency'], data_v1['best_tflops'], data_v1['ref_tflops'], data_v1['config']
-], [
-    "current", data_v2['latency'], data_v2['best_tflops'], data_v2['ref_tflops'], data_v2['config']
-]]
+df = pd.DataFrame(table, columns=headers)
+df = df.sort_values("Speedup", ascending=False).reset_index(drop=True)
+fig_width = max(0, len(df) * 0.35)
+plt.figure(figsize=(fig_width, 8))
+sns.set_theme(style="whitegrid", font_scale=0.9)
+bar_colors = sns.color_palette("magma", len(df))
+bars = plt.bar(range(len(df)), df["Speedup"], color=bar_colors, edgecolor="black")
+top3_idx = df.nlargest(3, "Speedup").index
+bot3_idx = df.nsmallest(3, "Speedup").index
+label_idx = set(top3_idx.tolist() + bot3_idx.tolist())
 
-headers = ["version", "Best Latency (s)", "Best TFlops", "Reference TFlops", "Best Config"]
+for i, val in enumerate(df["Speedup"]):
+    if i in label_idx:
+        plt.text(
+            i,
+            val + 0.02,
+            f"{val:.2f}x",
+            ha="center",
+            va="bottom",
+            color="red",
+            fontsize=8,
+            fontweight="bold")
 
-print(tabulate(table, headers=headers, tablefmt="github", stralign="left", numalign="decimal"))
+plt.xticks(range(len(df)), df["File"], rotation=70, ha='right', fontsize=12)
+plt.ylabel("Current Speedup vs Original", fontsize=14)
+plt.title("Current Speedup vs Original", fontsize=14, fontweight="bold")
+plt.ylim(0, max(df["Speedup"]) * 1.2)
+sns.despine()
+plt.tight_layout()
+plt.savefig("bench.png", dpi=300)
diff --git a/tilelang/tools/bench.py b/tilelang/tools/bench.py
new file mode 100644
index 000000000..a9f56de5c
--- /dev/null
+++ b/tilelang/tools/bench.py
@@ -0,0 +1,184 @@
+import os
+import re
+import sys
+import inspect
+import traceback
+import contextlib
+import warnings
+from tabulate import tabulate
+import matplotlib.pyplot as plt
+import importlib.util
+import multiprocessing as mp
+
+__all__ = ["main", "process_func"]
+_RECORDS = []
+
+
+@contextlib.contextmanager
+def suppress_output():
+    # Context manager that redirects stdout/stderr to os.devnull (supports fileno)
+    devnull = open(os.devnull, "w")
+    saved_stdout = sys.stdout
+    saved_stderr = sys.stderr
+    sys.stdout = devnull
+    sys.stderr = devnull
+    try:
+        yield
+    finally:
+        sys.stdout = saved_stdout
+        sys.stderr = saved_stderr
+        devnull.close()
+
+
+def process_func(func, *args, name=None, **kwargs):
+    import torch
+    latency = None
+    try:
+        with suppress_output():
+            latency = func(*args, **kwargs)
+            torch.cuda.synchronize()
+    except Exception:
+        pass
+
+    if name is None:
+        name = func.__module__
+    if latency is not None:
+        _RECORDS.append((f"{name}", latency))
+        print(f"{name}", latency)
+    else:
+        warnings.warn(
+            f"benchmark for {name} failed",
+            RuntimeWarning,
+            stacklevel=2,
+        )
+
+
+def analyze_records(records, out_dir):
+    # Analyze the data and draw a chart
+    records.sort(key=lambda x: x[1])
+    headers = ["Functions", "Avg Latency (ms)"]
+    print(
+        tabulate(records, headers=headers, tablefmt="github", stralign="left", numalign="decimal"))
+
+    names = [r[0] for r in records]
+    lats = [r[1] for r in records]
+    plt.figure(figsize=(max(len(names) * 2.2, 6), 6))
+    plt.bar(names, lats)
+    plt.ylabel("Latency (ms)")
+    plt.title("Benchmark Results")
+    out_path = os.path.join(out_dir, "bench_result.png")
+
+    plt.tight_layout()
+    plt.savefig(out_path, dpi=200)
+    plt.close()
+
+    print(f"Saved Bar chart to {out_path}")
+
+
+def _load_module(full_path):
+    module_name = os.path.splitext(os.path.basename(full_path))[0]
+    spec = importlib.util.spec_from_file_location(module_name, full_path)
+    mod = importlib.util.module_from_spec(spec)
+    spec.loader.exec_module(mod)
+    return mod
+
+
+def _bench_worker(file_path, func_name, queue):
+    import torch
+    a = torch.randn(1, device="cuda")
+    b = torch.randn(1, device="cuda")
+    (a + b).sum().item()
+    torch.cuda.synchronize()
+    local_records = []
+    global _RECORDS
+    _RECORDS = local_records
+    try:
+        mod = _load_module(file_path)
+        func = getattr(mod, func_name)
+        func()
+    except Exception:
+        traceback.print_exc()
+    finally:
+        queue.put(local_records)
+
+
+def main():
+    # Entry point — automatically run all bench_* functions in caller file.
+    mp.set_start_method("spawn", force=True)
+    test_file = inspect.getsourcefile(sys._getframe(1))
+    out_dir = os.path.dirname(test_file)
+    module = {}
+    with open(test_file) as f:
+        exec(f.read(), module)
+
+    bench_funcs = []
+    for name, func in module.items():
+        if name.startswith("bench_") and callable(func):
+            bench_funcs.append((test_file, name))
+
+    queue = mp.Queue()
+
+    for file_path, func_name in bench_funcs:
+        p = mp.Process(target=_bench_worker, args=(file_path, func_name, queue))
+        p.start()
+        p.join()
+
+        if p.exitcode == 0:
+            try:
+                child_records = queue.get_nowait()
+            except Exception:
+                child_records = []
+            _RECORDS.extend(child_records)
+        else:
+            print(f"[SKIP] {file_path}:{func_name} crashed, skipping this benchmark.")
+
+    print(len(_RECORDS))
+
+    analyze_records(_RECORDS, out_dir)
+
+
+def bench_all():
+    # Do benchmark for all bench_* functions in examples
+    mp.set_start_method("spawn", force=True)
+    current_dir = os.path.dirname(os.path.abspath(__file__))
+    examples_root = os.path.abspath(os.path.join(current_dir, "../../examples"))
+
+    bench_funcs = []
+    added_roots = set()
+
+    for root, _, files in os.walk(examples_root):
+        for file_name in files:
+            if re.match(r"^bench_.*\.py$", file_name):
+                full_path = os.path.join(root, file_name)
+                if root not in added_roots:
+                    sys.path.insert(0, root)
+                    added_roots.add(root)
+                mod = _load_module(full_path)
+                for name in dir(mod):
+                    if name.startswith("bench_"):
+                        func = getattr(mod, name)
+                        if callable(func):
+                            bench_funcs.append((full_path, name))
+
+    queue = mp.Queue()
+
+    for file_path, func_name in bench_funcs:
+        p = mp.Process(target=_bench_worker, args=(file_path, func_name, queue))
+        p.start()
+        p.join()
+
+        if p.exitcode == 0:
+            try:
+                child_records = queue.get_nowait()
+            except Exception:
+                child_records = []
+            _RECORDS.extend(child_records)
+        else:
+            print(f"[SKIP] {file_path}:{func_name} crashed, skipping this benchmark.")
+
+    print(len(_RECORDS))
+
+    if _RECORDS:
+        print(tabulate(_RECORDS, tablefmt="github", stralign="left", numalign="decimal"))
+    else:
+        print("[WARN] no benchmark records collected.")
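Note for reviewers: the sketch below is not part of the patch. It shows one way to exercise the new tooling locally before relying on the CI bot, mirroring the command that maint/scripts/ci_performance.py runs inside each virtualenv; the bare `python` interpreter path and the presence of a CUDA-capable GPU are assumptions made for illustration only.

    # Hypothetical local run of the benchmark sweep added in this patch.
    # Assumes tilelang (with these changes) is installed in the current
    # environment and that a CUDA device is available.
    import subprocess

    out = subprocess.run(
        ["python", "-c", "import tilelang.tools.bench as b; b.bench_all()"],
        capture_output=True,
        text=True,
    ).stdout
    # bench_all() prints github-style table rows such as "| example_topk | 0.123 |",
    # which is the format parse_output() in ci_performance.py matches with its regex.
    print(out)

Running the same command from both the baseline and the PR virtualenvs and comparing the two tables reproduces the speedup column that the workflow posts back to the pull request.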