refactor: backend_requirement + supported_compute_capability decorator for gemm #2000
base: main
Conversation
Note

Other AI code review bot(s) detected: CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

This PR introduces comprehensive GPU infrastructure enhancements across Hopper architecture support, quantization frameworks, and performance profiling. Key additions include Hopper GMMA-based fused attention (FMHA v2), batch POD (prefill+decode) attention kernels, explicit validation guards for grouped FP8 GEMM operations, extensive benchmark suite improvements, and systematic replacement of direct CUDA device management with RAII guards. Multiple deprecations and normalizations of backend naming conventions occur alongside CI/workflow configuration updates.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Prefill as Prefill Phase
    participant KVCache as KV Cache
    participant Decode as Decode Phase
    participant Output as Output
    Prefill->>KVCache: Store Q/K/V into paged cache<br/>(with optional FP8 quantization)
    Prefill->>Output: Write initial attention output
    Note over Prefill,KVCache: Prefill completes,<br/>plan_info prepared
    Decode->>KVCache: Load K/V from paged cache<br/>(retrieve block offsets, indices)
    Decode->>Decode: Compute Q·K attention scores<br/>(with optional attention sink masking)
    Decode->>Decode: Apply softmax normalization
    Decode->>KVCache: Load V from cache
    Decode->>Output: Write final attention output
    Note over Decode,Output: Decode batches process<br/>with separate workspace
```
Estimated code review effort: 5 (Critical) | ~120+ minutes. Areas requiring extra attention:
Possibly related issues
Possibly related PRs
Suggested labels
Suggested reviewers
Pre-merge checks and finishing touches

Failed checks (2 warnings)
Passed checks (1 passed)
Finishing touches
Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Summary of Changes

Hello @jimmyzho, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly refactors the General Matrix Multiply (GEMM) operations by implementing a new decorator-based system. The primary goal is to streamline and consolidate the validation of inputs and hardware compatibility requirements for various GEMM functions. By abstracting these checks into decorators and dedicated problem-size validation functions, the core logic of the GEMM implementations becomes cleaner and more focused, while ensuring consistent and informative error handling across the codebase.

Highlights
Code Review
This pull request is a great refactoring that introduces @backend_requirement and @supported_compute_capability decorators to centralize validation logic for GEMM functions. This improves code clarity and maintainability by separating validation from the core logic. The conversion of assert statements to ValueError exceptions is also a good improvement for more robust error handling.
I've found a critical bug that would lead to a runtime error, and a couple of missing validation checks that were present in the original code. My review includes suggestions to fix these issues. Overall, the direction of the changes is excellent.
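As context for the discussion below, here is a minimal, self-contained sketch of the decorator pattern this PR introduces. The real `supported_compute_capability` and `backend_requirement` live in `flashinfer/utils.py` and do considerably more (backend dispatch, compute-capability probing, richer error reporting); the stand-in implementations and every other name here (`_cutlass_requirement`, `_check_problem_size`, `example_mm`) are illustrative assumptions, not the actual flashinfer code.

```python
import functools
from typing import Callable, Dict, Optional

import torch


def supported_compute_capability(ccs):
    """Stand-in: record the compute capabilities (e.g. 100 for SM100) a checker supports."""
    def decorate(fn: Callable) -> Callable:
        fn._supported_ccs = set(ccs)
        return fn
    return decorate


def backend_requirement(backend_checks: Dict[str, Callable], common_check: Optional[Callable] = None):
    """Stand-in: run per-backend and common validation before the wrapped entry point."""
    def decorate(fn: Callable) -> Callable:
        @functools.wraps(fn)
        def wrapper(*args, backend: str = "cutlass", **kwargs):
            if backend in backend_checks:
                backend_checks[backend](*args, backend=backend, **kwargs)
            if common_check is not None:
                common_check(*args, backend=backend, **kwargs)
            return fn(*args, backend=backend, **kwargs)
        return wrapper
    return decorate


@supported_compute_capability([100, 103])
def _cutlass_requirement(a, b, backend="cutlass"):
    # Per-backend constraint (illustrative only).
    if a.dtype not in (torch.float16, torch.bfloat16):
        raise ValueError(f"cutlass path expects fp16/bf16 inputs, got {a.dtype}")
    return True


def _check_problem_size(a, b, backend="cutlass"):
    # ValueError instead of assert, so callers get an actionable message.
    if a.shape[1] != b.shape[0]:
        raise ValueError(f"Shape mismatch: a is {tuple(a.shape)}, b is {tuple(b.shape)}")
    return True


@backend_requirement({"cutlass": _cutlass_requirement}, common_check=_check_problem_size)
def example_mm(a, b, backend="cutlass"):
    # The core logic stays free of validation; the decorators run the checks first.
    return a @ b
```

The point of the split is that problem-size validation raises a descriptive ValueError before the kernel entry point runs, keeping the core GEMM body free of asserts.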
flashinfer/gemm.py
Outdated
```python
if out is None:
    if out_dtype is None:
        out_dtype = torch.bfloat16
else:
    if out_dtype is None:
        out_dtype = out.dtype
    if out.shape != (a.shape[0], n):
        raise ValueError(f"Shape mismatch. out.shape = {out.shape}, (a.shape[0], n) = {(a.shape[0], n)}")
    if out.dtype != out_dtype:
        raise ValueError(f"dtype mismatch. out.dtype = {out.dtype}, out_dtype = {out_dtype}")

_validate_fp8_output_dtype(out_dtype)

n = b.shape[1]
k = b.shape[2]
```
The variable n is used on line 2624 before it is assigned a value on line 2631. This will raise an UnboundLocalError if out is not None. The shape variables n and k should be defined before they are used in the shape checks.
Suggested change:

```diff
-if out is None:
-    if out_dtype is None:
-        out_dtype = torch.bfloat16
-else:
-    if out_dtype is None:
-        out_dtype = out.dtype
-    if out.shape != (a.shape[0], n):
-        raise ValueError(f"Shape mismatch. out.shape = {out.shape}, (a.shape[0], n) = {(a.shape[0], n)}")
-    if out.dtype != out_dtype:
-        raise ValueError(f"dtype mismatch. out.dtype = {out.dtype}, out_dtype = {out_dtype}")
-_validate_fp8_output_dtype(out_dtype)
-n = b.shape[1]
-k = b.shape[2]
+n = b.shape[1]
+k = b.shape[2]
+if out is None:
+    if out_dtype is None:
+        out_dtype = torch.bfloat16
+else:
+    if out_dtype is None:
+        out_dtype = out.dtype
+    if out.shape != (a.shape[0], n):
+        raise ValueError(f"Shape mismatch. out.shape = {out.shape}, (a.shape[0], n) = {(a.shape[0], n)}")
+    if out.dtype != out_dtype:
+        raise ValueError(f"dtype mismatch. out.dtype = {out.dtype}, out_dtype = {out_dtype}")
+_validate_fp8_output_dtype(out_dtype)
```
```python
num_groups, n, k_ = b.shape
m_, n_ = d.shape
m__ = m_indices.numel()
```
The check for positive dimensions n, k, and num_groups is missing. The original code had assert n > 0 and k > 0 and num_groups > 0. This check is important to prevent unexpected behavior with empty or invalid dimensions.
Suggested change:

```python
if n <= 0 or k <= 0 or num_groups <= 0:
    raise ValueError(f"n, k, and num_groups must be positive, but got n={n}, k={k}, num_groups={num_groups}")
```
```python
raise ValueError(f"k must be a multiple of 16, but got {k}")

num_groups = m_indptr.shape[0] - 1
```
A check to ensure that the number of groups in b matches num_groups derived from m_indptr is missing. This was present in the original code (assert b.shape[0] == num_groups) and is important for correctness.
Suggested change:

```python
if b.shape[0] != num_groups:
    raise ValueError(f"The number of groups in b ({b.shape[0]}) does not match the number of groups in m_indptr ({num_groups}).")
```
Actionable comments posted: 3
Caution
Some comments are outside the diff and can't be posted inline due to platform limitations.
Outside diff range comments (2)
flashinfer/deep_gemm.py (1)
1365-1410: Fix incorrect shape checks and over-strict dtype validation in contiguous checker.

- `num_groups != m__` is wrong: `m_indices` is per-row (length m), not per-group. This will falsely reject valid inputs.
- Allow `d.dtype` to be either bf16 or float32, which the runtime supports.

Apply:

```diff
@@
-    if m != m_ or k != k_ or n != n_ or m__ != m_ or num_groups != m__:
-        raise ValueError(f"Shape mismatch. m = {m}, m_ = {m_}, k = {k}, k_ = {k_}, n = {n}, n_ = {n_}, m__ = {m__}")
+    if m != m_ or k != k_ or n != n_ or m__ != m_:
+        raise ValueError(
+            f"Shape mismatch. m={m}, m_={m_}, k={k}, k_={k_}, n={n}, n_={n_}, |m_indices|={m__}"
+        )
@@
-    if d.dtype != torch.bfloat16:
-        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+    if d.dtype not in (torch.bfloat16, torch.float):
+        raise ValueError(f"d must be bfloat16 or float32, got {d.dtype}")
```

Optionally, quiet Ruff warnings:

```diff
-    a, sfa = a_fp8
-    b, sfb = b_fp8
+    a, _ = a_fp8
+    b, _ = b_fp8
```

flashinfer/gemm.py (1)
2589-2665: Fix F821: `n` used before assignment in groupwise problem-size check.

Move `n`/`k` derivation before using `n` in the output shape checks; keep the dtype validation.

Apply:

```diff
@@
-    if out is None:
-        if out_dtype is None:
-            out_dtype = torch.bfloat16
-    else:
-        if out_dtype is None:
-            out_dtype = out.dtype
-        if out.shape != (a.shape[0], n):
-            raise ValueError(f"Shape mismatch. out.shape = {out.shape}, (a.shape[0], n) = {(a.shape[0], n)}")
-        if out.dtype != out_dtype:
-            raise ValueError(f"dtype mismatch. out.dtype = {out.dtype}, out_dtype = {out_dtype}")
-
-    _validate_fp8_output_dtype(out_dtype)
-
-    n = b.shape[1]
-    k = b.shape[2]
+    n = b.shape[1]
+    k = b.shape[2]
+
+    if out is None:
+        if out_dtype is None:
+            out_dtype = torch.bfloat16
+    else:
+        if out_dtype is None:
+            out_dtype = out.dtype
+        if out.shape != (a.shape[0], n):
+            raise ValueError(
+                f"Shape mismatch. out.shape={out.shape}, expected={(a.shape[0], n)}"
+            )
+        if out.dtype != out_dtype:
+            raise ValueError(f"dtype mismatch. out.dtype={out.dtype}, out_dtype={out_dtype}")
+
+    _validate_fp8_output_dtype(out_dtype)
```
Nitpick comments (4)
flashinfer/deep_gemm.py (2)
1428-1432: Nit: unused variable `k_`. Use `_` to silence RUF059.

```diff
-    num_groups, n, k_ = b.shape
+    num_groups, n, _ = b.shape
```

1461-1514: Masked checker: align dtype, contiguity, and unused unpack.

- Allow `d.dtype` to be bf16 or f32 to match runtime support.
- Use `_` for unused unpacked scale tensors.

Apply:

```diff
-    a, sfa = a_fp8
-    b, sfb = b_fp8
+    a, _ = a_fp8
+    b, _ = b_fp8
@@
-    if d.dtype != torch.bfloat16:
-        raise ValueError(f"d must be bfloat16, but got {d.dtype}")
+    if d.dtype not in (torch.bfloat16, torch.float):
+        raise ValueError(f"d must be bfloat16 or float32, got {d.dtype}")
```

flashinfer/gemm.py (2)

2734-2734: Silence F841: unused local.

```diff
-num_groups = m_indptr.shape[0] - 1
+_num_groups = m_indptr.shape[0] - 1
```

2940-2940: Silence F841: unused local.

```diff
-num_groups = m_indptr.shape[0] - 1
+_num_groups = m_indptr.shape[0] - 1
```
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

Files selected for processing (2)
- flashinfer/deep_gemm.py (4 hunks)
- flashinfer/gemm.py (9 hunks)

Additional context used

Code graph analysis (2)

flashinfer/gemm.py (2)
flashinfer/utils.py (4)
- supported_compute_capability (772-852)
- backend_requirement (855-1028)
- is_sm120a_supported (504-506)
- is_sm121a_supported (509-511)
flashinfer/deep_gemm.py (2)
- _check_group_deepgemm_fp8_nt_contiguous_problem_size (1366-1409)
- _check_m_grouped_fp8_gemm_nt_masked_problem_size (1462-1509)

flashinfer/deep_gemm.py (1)
flashinfer/utils.py (4)
- ceil_div (575-586)
- round_up (589-591)
- supported_compute_capability (772-852)
- backend_requirement (855-1028)
GitHub Actions: pre-commit
flashinfer/gemm.py
[error] 2624-2624: ruff (F821): Undefined name 'n'. Cannot determine type of 'n'.
[error] 2625-2625: ruff (F821): Undefined name 'n'. Cannot determine type of 'n'.
[error] 2734-2734: ruff (F841): Local variable 'num_groups' is assigned to but never used.
[error] 2940-2940: ruff (F841): Local variable 'num_groups' is assigned to but never used.
[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook.
flashinfer/deep_gemm.py
[error] 1411-1411: mypy: Missing positional argument 'backend_checks' in call to 'backend_requirement'.
[error] 1511-1511: mypy: Missing positional argument 'backend_checks' in call to 'backend_requirement'.
[error] 1-1: Trailing whitespace detected and fixed by pre-commit hook.
Ruff (0.14.2)
flashinfer/gemm.py
2012-2012: Unused function argument: A
(ARG001)
2013-2013: Unused function argument: B
(ARG001)
2014-2014: Unused function argument: A_scale
(ARG001)
2015-2015: Unused function argument: B_scale
(ARG001)
2016-2016: Unused function argument: dtype
(ARG001)
2017-2017: Unused function argument: out
(ARG001)
2018-2018: Unused function argument: backend
(ARG001)
2026-2026: Unused function argument: A
(ARG001)
2027-2027: Unused function argument: B
(ARG001)
2028-2028: Unused function argument: A_scale
(ARG001)
2029-2029: Unused function argument: B_scale
(ARG001)
2030-2030: Unused function argument: dtype
(ARG001)
2031-2031: Unused function argument: out
(ARG001)
2032-2032: Unused function argument: backend
(ARG001)
2041-2041: Unused function argument: A_scale
(ARG001)
2042-2042: Unused function argument: B_scale
(ARG001)
2043-2043: Unused function argument: dtype
(ARG001)
2044-2044: Unused function argument: out
(ARG001)
2045-2045: Unused function argument: backend
(ARG001)
2048-2048: Avoid specifying long messages outside the exception class
(TRY003)
2052-2052: Unused function argument: A
(ARG001)
2053-2053: Unused function argument: B
(ARG001)
2054-2054: Unused function argument: A_scale
(ARG001)
2055-2055: Unused function argument: B_scale
(ARG001)
2057-2057: Unused function argument: out
(ARG001)
2058-2058: Unused function argument: backend
(ARG001)
2165-2165: Unused function argument: a
(ARG001)
2166-2166: Unused function argument: b
(ARG001)
2167-2167: Unused function argument: a_scale
(ARG001)
2168-2168: Unused function argument: b_scale
(ARG001)
2170-2170: Unused function argument: mma_sm
(ARG001)
2171-2171: Unused function argument: scale_granularity_mnk
(ARG001)
2172-2172: Unused function argument: out
(ARG001)
2173-2173: Unused function argument: out_dtype
(ARG001)
2174-2174: Unused function argument: backend
(ARG001)
2177-2177: Avoid specifying long messages outside the exception class
(TRY003)
2185-2185: Unused function argument: b
(ARG001)
2186-2186: Unused function argument: a_scale
(ARG001)
2187-2187: Unused function argument: b_scale
(ARG001)
2188-2188: Unused function argument: scale_major_mode
(ARG001)
2189-2189: Unused function argument: mma_sm
(ARG001)
2191-2191: Unused function argument: out
(ARG001)
2192-2192: Unused function argument: out_dtype
(ARG001)
2193-2193: Unused function argument: backend
(ARG001)
2196-2196: Avoid specifying long messages outside the exception class
(TRY003)
2198-2198: Avoid specifying long messages outside the exception class
(TRY003)
2206-2206: Unused function argument: a_scale
(ARG001)
2207-2207: Unused function argument: b_scale
(ARG001)
2208-2208: Unused function argument: scale_major_mode
(ARG001)
2209-2209: Unused function argument: mma_sm
(ARG001)
2210-2210: Unused function argument: scale_granularity_mnk
(ARG001)
2211-2211: Unused function argument: out
(ARG001)
2213-2213: Unused function argument: backend
(ARG001)
2216-2216: Avoid specifying long messages outside the exception class
(TRY003)
2219-2221: Avoid specifying long messages outside the exception class
(TRY003)
2595-2595: Unused function argument: scale_granularity_mnk
(ARG001)
2602-2602: Avoid specifying long messages outside the exception class
(TRY003)
2604-2604: Avoid specifying long messages outside the exception class
(TRY003)
2606-2606: Avoid specifying long messages outside the exception class
(TRY003)
2608-2608: Avoid specifying long messages outside the exception class
(TRY003)
2610-2610: Avoid specifying long messages outside the exception class
(TRY003)
2612-2612: Avoid specifying long messages outside the exception class
(TRY003)
2614-2614: Avoid specifying long messages outside the exception class
(TRY003)
2624-2624: Undefined name n
(F821)
2625-2625: Avoid specifying long messages outside the exception class
(TRY003)
2625-2625: Undefined name n
(F821)
2627-2627: Avoid specifying long messages outside the exception class
(TRY003)
2635-2635: Avoid specifying long messages outside the exception class
(TRY003)
2637-2637: Avoid specifying long messages outside the exception class
(TRY003)
2639-2639: Avoid specifying long messages outside the exception class
(TRY003)
2645-2647: Avoid specifying long messages outside the exception class
(TRY003)
2793-2793: Avoid specifying long messages outside the exception class
(TRY003)
2795-2795: Avoid specifying long messages outside the exception class
(TRY003)
2797-2797: Avoid specifying long messages outside the exception class
(TRY003)
2799-2799: Avoid specifying long messages outside the exception class
(TRY003)
2801-2801: Avoid specifying long messages outside the exception class
(TRY003)
2803-2803: Avoid specifying long messages outside the exception class
(TRY003)
2805-2805: Avoid specifying long messages outside the exception class
(TRY003)
2807-2807: Avoid specifying long messages outside the exception class
(TRY003)
2809-2809: Avoid specifying long messages outside the exception class
(TRY003)
2811-2811: Avoid specifying long messages outside the exception class
(TRY003)
2822-2822: Avoid specifying long messages outside the exception class
(TRY003)
2826-2826: Avoid specifying long messages outside the exception class
(TRY003)
2833-2833: Avoid specifying long messages outside the exception class
(TRY003)
2838-2838: Avoid specifying long messages outside the exception class
(TRY003)
2840-2840: Avoid specifying long messages outside the exception class
(TRY003)
2845-2845: Avoid specifying long messages outside the exception class
(TRY003)
2847-2847: Avoid specifying long messages outside the exception class
(TRY003)
3010-3010: Unused function argument: out_dtype
(ARG001)
3154-3154: Unused function argument: out_dtype
(ARG001)
flashinfer/deep_gemm.py
1371-1371: Unused function argument: recipe
(ARG001)
1372-1372: Unused function argument: compiled_dims
(ARG001)
1379-1379: Avoid specifying long messages outside the exception class
(TRY003)
1381-1381: Avoid specifying long messages outside the exception class
(TRY003)
1384-1384: Avoid specifying long messages outside the exception class
(TRY003)
1395-1395: Avoid specifying long messages outside the exception class
(TRY003)
1397-1397: Avoid specifying long messages outside the exception class
(TRY003)
1399-1399: Avoid specifying long messages outside the exception class
(TRY003)
1401-1401: Avoid specifying long messages outside the exception class
(TRY003)
1403-1403: Avoid specifying long messages outside the exception class
(TRY003)
1407-1407: Avoid specifying long messages outside the exception class
(TRY003)
1431-1431: Unpacked variable k_ is never used
(RUF059)
1468-1468: Unused function argument: recipe
(ARG001)
1469-1469: Unused function argument: compiled_dims
(ARG001)
1475-1475: Avoid specifying long messages outside the exception class
(TRY003)
1477-1477: Avoid specifying long messages outside the exception class
(TRY003)
1480-1480: Avoid specifying long messages outside the exception class
(TRY003)
1482-1482: Unpacked variable sfa is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1483-1483: Unpacked variable sfb is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1491-1491: Avoid specifying long messages outside the exception class
(TRY003)
1493-1493: Avoid specifying long messages outside the exception class
(TRY003)
1495-1495: Avoid specifying long messages outside the exception class
(TRY003)
1497-1497: Avoid specifying long messages outside the exception class
(TRY003)
1499-1499: Avoid specifying long messages outside the exception class
(TRY003)
1501-1501: Avoid specifying long messages outside the exception class
(TRY003)
1503-1503: Avoid specifying long messages outside the exception class
(TRY003)
1507-1507: Avoid specifying long messages outside the exception class
(TRY003)
Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
Additional comments (2)
flashinfer/gemm.py (2)
3012-3014: Forwarding check OK, but inherits dtype/shape constraints from deep_gemm.

Once the deep_gemm checker is fixed to allow bf16/float32 and correct `m_indices` semantics, this indirection remains valid. No changes needed here.

3161-3164: Forwarding check OK (batch path).

Same note as above; relies on the deep_gemm masked checker. No change required.
```python
@backend_requirement(
    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
)
def m_grouped_fp8_gemm_nt_contiguous(
```
Decorator misuse: missing backend_checks causes mypy error and runtime failure.
@backend_requirement requires a backend map. Also, calls pass no backend kwarg, so use a None key as the default backend and gate CC via a trivial requirement.
Apply:
-@backend_requirement(
- common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
-)
+@backend_requirement(
+ {None: _deepgemm_sm10x_requirement},
+ common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
+)

And add this requirement helper near the checker:
+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement(
+ a_fp8, b_fp8, d, m_indices, recipe=None, compiled_dims="nk"
+):
+ # CC gating only; shape/dtype validated by common_check.
+ return True

Committable suggestion skipped: line range outside the PR's diff.
Tools
GitHub Actions: pre-commit
[error] 1411-1411: mypy: Missing positional argument 'backend_checks' in call to 'backend_requirement'.
Prompt for AI Agents
In flashinfer/deep_gemm.py around lines 1411-1414, the @backend_requirement
decorator is missing a backend_checks mapping which causes mypy/runtime failures
because callers pass no backend kwarg; add a trivial requirement helper (e.g., a
function that returns True or otherwise acts as a no-op requirement) near the
existing _check_group_deepgemm_fp8_nt_contiguous_problem_size checker, then
update the decorator call to include backend_checks with a None key mapped to
that trivial requirement (backend_checks={None: trivial_requirement}) so the
decorator has a default backend entry and CC gating remains intact.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can't be posted inline due to platform limitations.
Outside diff range comments (1)
flashinfer/deep_gemm.py (1)
1366-1414: Add missing positive-dimension checks and silence unused params.
- Validate n, k, num_groups > 0 to prevent empty/invalid shapes (was present before per prior review).
- Prefix unused parameters to satisfy linters.
Apply:
```diff
-@supported_compute_capability([100, 103])
-def _check_group_deepgemm_fp8_nt_contiguous_problem_size(
+@supported_compute_capability([100, 103])
+def _check_group_deepgemm_fp8_nt_contiguous_problem_size(
     a_fp8: Tuple[torch.Tensor, torch.Tensor],
     b_fp8: Tuple[torch.Tensor, torch.Tensor],
     d: torch.Tensor,
     m_indices: torch.Tensor,
-    recipe: Optional[Tuple[int, int, int]] = None,
-    compiled_dims: str = "nk",
+    recipe: Optional[Tuple[int, int, int]] = None,
+    compiled_dims: str = "nk",
 ) -> bool:
@@
     m, k = a.shape
     num_groups, n, k_ = b.shape
     m_, n_ = d.shape
     m__ = m_indices.numel()

     # Type and shape checks
@@
     if m_indices.dtype != torch.int32:
         raise ValueError(f"m_indices must be int32, but got {m_indices.dtype}")
+    # Positive-dimension checks
+    if n <= 0 or k <= 0 or num_groups <= 0:
+        raise ValueError(
+            f"n, k, and num_groups must be positive, got n={n}, k={k}, num_groups={num_groups}"
+        )

     # D must be N-major
```
Duplicate comments (5)
flashinfer/gemm.py (3)
2557-2560: Critical: Empty backend map causes gemm_fp8_nt_blockscaled to always error.

This mirrors a previously reported issue. Add a default requirement with CC gating.

```diff
+@supported_compute_capability([100, 103, 120, 121])
+def _cutlass_gemm_fp8_nt_blockscaled_requirement(
+    a, b, a_scale, b_scale, scale_major_mode="MN", mma_sm=1, out=None, out_dtype=None
+):
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_gemm_fp8_nt_blockscaled_problem_size,
-)
+@backend_requirement(
+    {None: _cutlass_gemm_fp8_nt_blockscaled_requirement},
+    common_check=_check_gemm_fp8_nt_blockscaled_problem_size,
+)
```

2590-2655: Add missing groups check and validate out dtype early.

Ensure b.shape[0] matches num_groups derived from m_indptr; set/validate out_dtype like other paths.

```diff
@@
-    num_groups = m_indptr.shape[0] - 1
+    num_groups = m_indptr.shape[0] - 1
+    if b.shape[0] != num_groups:
+        raise ValueError(
+            f"b.shape[0] ({b.shape[0]}) must equal num_groups ({num_groups})"
+        )
```

2865-2868: Critical: group_gemm_mxfp8_mxfp4_nt_groupwise also uses {}.

Provide a default requirement with CC gating.

```diff
+@supported_compute_capability([100, 103, 110, 120, 121])
+def _cutlass_group_gemm_mxfp8_mxfp4_requirement(
+    a, b, a_scale, b_scale, m_indptr,
+    mma_sm=1, tile_m=128, tile_n=128, tile_k=128, swap_ab=True,
+    out=None, out_dtype=None
+):
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_group_gemm_mxfp8_mxfp4_nt_groupwise_problem_size,
-)
+@backend_requirement(
+    {None: _cutlass_group_gemm_mxfp8_mxfp4_requirement},
+    common_check=_check_group_gemm_mxfp8_mxfp4_nt_groupwise_problem_size,
+)
```

flashinfer/deep_gemm.py (2)
1416-1420: Critical: backend_requirement with empty backend map always fails at runtime.

Passing {} and not supplying a backend kwarg makes the wrapper raise BackendSupportedError. Provide a default backend entry keyed by None and a trivial requirement function with proper CC gating.

Apply:

```diff
+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement(
+    a_fp8, b_fp8, d, m_indices, recipe=None, compiled_dims="nk"
+):
+    # CC gating only; shape/dtype validated by common_check.
+    return True
+
-@backend_requirement(
-    {},
-    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
-)
+@backend_requirement(
+    {None: _deepgemm_sm10x_requirement},
+    common_check=_check_group_deepgemm_fp8_nt_contiguous_problem_size,
+)
```

1529-1533: Critical: same empty backend map issue for masked path.

Mirror the contiguous fix by adding a default backend requirement with CC gating.

```diff
-@backend_requirement(
-    {},
-    common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size,
-)
+@backend_requirement(
+    {None: _deepgemm_sm10x_requirement},
+    common_check=_check_m_grouped_fp8_gemm_nt_masked_problem_size,
+)
```
Nitpick comments (3)
flashinfer/deep_gemm.py (2)
1431-1438: Minor: avoid unused `k_`. Use `_` to silence the unused binding.

```diff
-    num_groups, n, k_ = b.shape
+    num_groups, n, _ = b.shape
```

1467-1527: LGTM with tiny lint nits.

Logic is correct and includes positivity checks. Prefix unused `sfa`/`sfb` to satisfy linters.

```diff
-    a, sfa = a_fp8
-    b, sfb = b_fp8
+    a, _sfa = a_fp8
+    b, _sfb = b_fp8
```

flashinfer/gemm.py (1)

2165-2235: NT-groupwise single-GEMM backend checks: solid, consider reusing for blockscaled.

The requirement/check split is good. If desired, mirror the CUTLASS CC list for the blockscaled helper to keep consistency.
Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

Files selected for processing (2)
- flashinfer/deep_gemm.py (4 hunks)
- flashinfer/gemm.py (9 hunks)

Additional context used

Code graph analysis (2)

flashinfer/deep_gemm.py (1)
flashinfer/utils.py (4)
- ceil_div (575-586)
- round_up (589-591)
- supported_compute_capability (772-852)
- backend_requirement (855-1028)

flashinfer/gemm.py (2)
flashinfer/utils.py (4)
- supported_compute_capability (772-852)
- backend_requirement (855-1028)
- is_sm120a_supported (504-506)
- is_sm121a_supported (509-511)
flashinfer/deep_gemm.py (2)
- _check_group_deepgemm_fp8_nt_contiguous_problem_size (1367-1413)
- _check_m_grouped_fp8_gemm_nt_masked_problem_size (1468-1526)

Ruff (0.14.2)
flashinfer/deep_gemm.py
1372-1372: Unused function argument: recipe
(ARG001)
1373-1373: Unused function argument: compiled_dims
(ARG001)
1379-1379: Avoid specifying long messages outside the exception class
(TRY003)
1381-1381: Avoid specifying long messages outside the exception class
(TRY003)
1384-1386: Avoid specifying long messages outside the exception class
(TRY003)
1397-1399: Avoid specifying long messages outside the exception class
(TRY003)
1401-1401: Avoid specifying long messages outside the exception class
(TRY003)
1403-1403: Avoid specifying long messages outside the exception class
(TRY003)
1405-1405: Avoid specifying long messages outside the exception class
(TRY003)
1407-1407: Avoid specifying long messages outside the exception class
(TRY003)
1411-1411: Avoid specifying long messages outside the exception class
(TRY003)
1437-1437: Unpacked variable k_ is never used
(RUF059)
1474-1474: Unused function argument: recipe
(ARG001)
1475-1475: Unused function argument: compiled_dims
(ARG001)
1480-1480: Avoid specifying long messages outside the exception class
(TRY003)
1482-1482: Avoid specifying long messages outside the exception class
(TRY003)
1485-1487: Avoid specifying long messages outside the exception class
(TRY003)
1489-1489: Unpacked variable sfa is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1490-1490: Unpacked variable sfb is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
1502-1504: Avoid specifying long messages outside the exception class
(TRY003)
1506-1508: Avoid specifying long messages outside the exception class
(TRY003)
1510-1512: Avoid specifying long messages outside the exception class
(TRY003)
1514-1514: Avoid specifying long messages outside the exception class
(TRY003)
1516-1516: Avoid specifying long messages outside the exception class
(TRY003)
1518-1518: Avoid specifying long messages outside the exception class
(TRY003)
1520-1520: Avoid specifying long messages outside the exception class
(TRY003)
1524-1524: Avoid specifying long messages outside the exception class
(TRY003)
flashinfer/gemm.py
2011-2011: Unused function argument: A
(ARG001)
2012-2012: Unused function argument: B
(ARG001)
2013-2013: Unused function argument: A_scale
(ARG001)
2014-2014: Unused function argument: B_scale
(ARG001)
2015-2015: Unused function argument: dtype
(ARG001)
2016-2016: Unused function argument: out
(ARG001)
2017-2017: Unused function argument: backend
(ARG001)
2025-2025: Unused function argument: A
(ARG001)
2026-2026: Unused function argument: B
(ARG001)
2027-2027: Unused function argument: A_scale
(ARG001)
2028-2028: Unused function argument: B_scale
(ARG001)
2029-2029: Unused function argument: dtype
(ARG001)
2030-2030: Unused function argument: out
(ARG001)
2031-2031: Unused function argument: backend
(ARG001)
2040-2040: Unused function argument: A_scale
(ARG001)
2041-2041: Unused function argument: B_scale
(ARG001)
2042-2042: Unused function argument: dtype
(ARG001)
2043-2043: Unused function argument: out
(ARG001)
2044-2044: Unused function argument: backend
(ARG001)
2047-2047: Avoid specifying long messages outside the exception class
(TRY003)
2052-2052: Unused function argument: A
(ARG001)
2053-2053: Unused function argument: B
(ARG001)
2054-2054: Unused function argument: A_scale
(ARG001)
2055-2055: Unused function argument: B_scale
(ARG001)
2057-2057: Unused function argument: out
(ARG001)
2058-2058: Unused function argument: backend
(ARG001)
2166-2166: Unused function argument: a
(ARG001)
2167-2167: Unused function argument: b
(ARG001)
2168-2168: Unused function argument: a_scale
(ARG001)
2169-2169: Unused function argument: b_scale
(ARG001)
2171-2171: Unused function argument: mma_sm
(ARG001)
2172-2172: Unused function argument: scale_granularity_mnk
(ARG001)
2173-2173: Unused function argument: out
(ARG001)
2174-2174: Unused function argument: out_dtype
(ARG001)
2175-2175: Unused function argument: backend
(ARG001)
2178-2178: Avoid specifying long messages outside the exception class
(TRY003)
2186-2186: Unused function argument: b
(ARG001)
2187-2187: Unused function argument: a_scale
(ARG001)
2188-2188: Unused function argument: b_scale
(ARG001)
2189-2189: Unused function argument: scale_major_mode
(ARG001)
2190-2190: Unused function argument: mma_sm
(ARG001)
2192-2192: Unused function argument: out
(ARG001)
2193-2193: Unused function argument: out_dtype
(ARG001)
2194-2194: Unused function argument: backend
(ARG001)
2197-2197: Avoid specifying long messages outside the exception class
(TRY003)
2199-2199: Avoid specifying long messages outside the exception class
(TRY003)
2207-2207: Unused function argument: a_scale
(ARG001)
2208-2208: Unused function argument: b_scale
(ARG001)
2209-2209: Unused function argument: scale_major_mode
(ARG001)
2210-2210: Unused function argument: mma_sm
(ARG001)
2211-2211: Unused function argument: scale_granularity_mnk
(ARG001)
2212-2212: Unused function argument: out
(ARG001)
2214-2214: Unused function argument: backend
(ARG001)
2217-2217: Avoid specifying long messages outside the exception class
(TRY003)
2220-2222: Avoid specifying long messages outside the exception class
(TRY003)
2596-2596: Unused function argument: scale_granularity_mnk
(ARG001)
2603-2603: Avoid specifying long messages outside the exception class
(TRY003)
2605-2605: Avoid specifying long messages outside the exception class
(TRY003)
2607-2607: Avoid specifying long messages outside the exception class
(TRY003)
2609-2609: Avoid specifying long messages outside the exception class
(TRY003)
2611-2611: Avoid specifying long messages outside the exception class
(TRY003)
2613-2615: Avoid specifying long messages outside the exception class
(TRY003)
2617-2617: Avoid specifying long messages outside the exception class
(TRY003)
2630-2632: Avoid specifying long messages outside the exception class
(TRY003)
2634-2636: Avoid specifying long messages outside the exception class
(TRY003)
2641-2641: Avoid specifying long messages outside the exception class
(TRY003)
2643-2643: Avoid specifying long messages outside the exception class
(TRY003)
2645-2645: Avoid specifying long messages outside the exception class
(TRY003)
2651-2653: Avoid specifying long messages outside the exception class
(TRY003)
2798-2800: Avoid specifying long messages outside the exception class
(TRY003)
2802-2802: Avoid specifying long messages outside the exception class
(TRY003)
2804-2804: Avoid specifying long messages outside the exception class
(TRY003)
2806-2806: Avoid specifying long messages outside the exception class
(TRY003)
2808-2808: Avoid specifying long messages outside the exception class
(TRY003)
2810-2810: Avoid specifying long messages outside the exception class
(TRY003)
2812-2812: Avoid specifying long messages outside the exception class
(TRY003)
2814-2814: Avoid specifying long messages outside the exception class
(TRY003)
2816-2816: Avoid specifying long messages outside the exception class
(TRY003)
2818-2818: Avoid specifying long messages outside the exception class
(TRY003)
2829-2831: Avoid specifying long messages outside the exception class
(TRY003)
2835-2837: Avoid specifying long messages outside the exception class
(TRY003)
2844-2846: Avoid specifying long messages outside the exception class
(TRY003)
2851-2851: Avoid specifying long messages outside the exception class
(TRY003)
2853-2853: Avoid specifying long messages outside the exception class
(TRY003)
2858-2858: Avoid specifying long messages outside the exception class
(TRY003)
2860-2860: Avoid specifying long messages outside the exception class
(TRY003)
3022-3022: Unused function argument: out_dtype
(ARG001)
3172-3172: Unused function argument: out_dtype
(ARG001)
Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Deploy Docs
Additional comments (2)
flashinfer/deep_gemm.py (1)
48-53: Imports look good.

The utils import now brings in supported_compute_capability and backend_requirement correctly.
flashinfer/gemm.py (1)
2009-2072: bmm_fp8 requirements mapping looks correct.

Backends and CC gating are wired; the common check validates the output dtype.
```python
@backend_requirement(
    {},
    common_check=_check_group_gemm_fp8_nt_groupwise_problem_size,
)
```
Critical: group_gemm_fp8_nt_groupwise uses {} backend map.
Add a default backend requirement with CC gating; signature must match this API.
+@supported_compute_capability([100, 120, 121])
+def _cutlass_group_gemm_fp8_nt_groupwise_requirement(
+ a, b, a_scale, b_scale, m_indptr,
+ scale_granularity_mnk=(1, 128, 128),
+ scale_major_mode="MN",
+ mma_sm=1,
+ out=None, out_dtype=None
+):
+ return True
+
-@backend_requirement(
- {},
- common_check=_check_group_gemm_fp8_nt_groupwise_problem_size,
-)
+@backend_requirement(
+ {None: _cutlass_group_gemm_fp8_nt_groupwise_requirement},
+ common_check=_check_group_gemm_fp8_nt_groupwise_problem_size,
Committable suggestion
IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-@backend_requirement(
-    {},
-    common_check=_check_group_gemm_fp8_nt_groupwise_problem_size,
-)
+@supported_compute_capability([100, 120, 121])
+def _cutlass_group_gemm_fp8_nt_groupwise_requirement(
+    a, b, a_scale, b_scale, m_indptr,
+    scale_granularity_mnk=(1, 128, 128),
+    scale_major_mode="MN",
+    mma_sm=1,
+    out=None, out_dtype=None
+):
+    return True
+
+@backend_requirement(
+    {None: _cutlass_group_gemm_fp8_nt_groupwise_requirement},
+    common_check=_check_group_gemm_fp8_nt_groupwise_problem_size,
+)
```
```python
@backend_requirement(
    {},
    common_check=_check_group_deepgemm_fp8_nt_groupwise_problem_size,
)
```
Critical: DeepGEMM group path requires a default backend requirement.
Add CC-gated default backend for SM100/103.
+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement_group(
+ a, b, a_scale, b_scale, m_indices, scale_granularity_mnk=(1, 128, 128),
+ out=None, out_dtype=None
+):
+ return True
+
-@backend_requirement(
- {},
- common_check=_check_group_deepgemm_fp8_nt_groupwise_problem_size,
-)
+@backend_requirement(
+ {None: _deepgemm_sm10x_requirement_group},
+ common_check=_check_group_deepgemm_fp8_nt_groupwise_problem_size,
+)Committable suggestion skipped: line range outside the PR's diff.
Prompt for AI Agents
In flashinfer/gemm.py around lines 3033 to 3036, the @backend_requirement
decorator on the DeepGEMM group path is missing a CC-gated default backend; add
a default backend entry that targets SM100 and SM103 (e.g. a "default" key with
a CUDA backend CC filter for 100 and 103) so the decorator includes both the
existing specific requirements and a default backend gated to compute capability
100 and 103.
```python
@backend_requirement(
    {},
    common_check=_check_batch_deepgemm_fp8_nt_groupwise,
)
```
Critical: DeepGEMM batch path also uses {}.
Add CC-gated default backend.
+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement_batch(
+ a, b, a_scale, b_scale, masked_m, expected_m,
+ scale_granularity_mnk=(1, 128, 128),
+ out=None, out_dtype=None
+):
+ return True
+
-@backend_requirement(
- {},
- common_check=_check_batch_deepgemm_fp8_nt_groupwise,
-)
+@backend_requirement(
+ {None: _deepgemm_sm10x_requirement_batch},
+ common_check=_check_batch_deepgemm_fp8_nt_groupwise,
Committable suggestion
IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```diff
-@backend_requirement(
-    {},
-    common_check=_check_batch_deepgemm_fp8_nt_groupwise,
-)
+@supported_compute_capability([100, 103])
+def _deepgemm_sm10x_requirement_batch(
+    a, b, a_scale, b_scale, masked_m, expected_m,
+    scale_granularity_mnk=(1, 128, 128),
+    out=None, out_dtype=None
+):
+    return True
+
+@backend_requirement(
+    {None: _deepgemm_sm10x_requirement_batch},
+    common_check=_check_batch_deepgemm_fp8_nt_groupwise,
+)
```
Prompt for AI Agents
In flashinfer/gemm.py around lines 3181 to 3184, the @backend_requirement
decorator for the DeepGEMM batch path currently passes an empty dict for
backends ({}); change it to include a CC-gated default backend entry so the
batch path is enabled only when the compute-capability gate is satisfied. Update
the decorator's first argument to supply a default backend dict guarded by the
same CC check used elsewhere (i.e., add the CC-gated backend entry instead of
{}), keeping the existing common_check=_check_batch_deepgemm_fp8_nt_groupwise.
/bot run

[FAILED] Pipeline #37488727: 1/17 passed
## Description

- Update the autotune logic in trtllm-gen moe. Instead of using a fixed `tile_tokens_dim`, tune in a range of `[max(8,tile_token_dim/2), tile_token_dim, min(128, tile_token_dim*2), min(128, tile_token_dim*4)]`.
- Add FP8 MOE autotune logic, initial PR flashinfer-ai#1494 from @aleozlx, updated to sync with the new autotuner.
- Update logic in `test_trtllm_gen_fused_moe.py`.
- Update `conftest.py` to speed up tests; previously it used `try_first`, which introduced a duplicate run.
- Add log_once in logger.

## Pull Request Checklist

### Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * Runtime autotuning with per-tile dynamic routing and selectable MoE runner options (gated activation, shuffled-weight, weight-layout).
  * One-time (deduplicated) logging helpers added to JIT logger.
* **Deprecations**
  * tile_tokens_dim removed from new paths and marked deprecated in legacy entry points; new tuning parameters introduced for autotuning.
* **Tests**
  * Tests refactored for autotuning/routing with new helpers and improved handling/reporting for missing JIT cache.

---------

Signed-off-by: jiahanc <[email protected]>
Co-authored-by: yzh119 <[email protected]>
```python
def m_grouped_fp8_gemm_nt_contiguous(
@supported_compute_capability([100, 103])
```
Ah, the compute capability will not actually be checked on the common_check function.

We would need to do one of the following:
- Redesign the decorator so that we can just add @supported_compute_capability to m_grouped_fp8_gemm_nt_contiguous directly; OR
- Change the decorator's implementations of wrapper and is_compute_capability_supported so it also checks this on the common_check.
- Add a backend parameter to m_grouped_fp8_gemm_nt_contiguous. Though this is deepgemm and I don't think we want to call this a backend for now.

My preference would go to option 2.
I think option 2 also makes the most sense. A lot of the APIs also have no 'backend' arg to be passed in, so we can't only check @supported_compute_capability there. I can change this in a separate PR.
A separate PR is fine, since we wouldn't cause a regression (we didn't have CC checks before the current PR anyway).
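For concreteness, here is a rough sketch of what option 2 could look like, assuming @supported_compute_capability records its allowed list on the checker under an attribute such as `_supported_ccs`. The attribute name and helper below are hypothetical, not the actual flashinfer implementation.

```python
import torch


def _common_check_cc_ok(common_check) -> bool:
    """Return True if the current device satisfies the CC list recorded on the checker."""
    allowed = getattr(common_check, "_supported_ccs", None)  # hypothetical attribute name
    if allowed is None:
        return True  # checker was not CC-gated
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor in allowed


# Inside backend_requirement's wrapper, before invoking common_check:
#     if common_check is not None and not _common_check_cc_ok(common_check):
#         raise ValueError("current GPU compute capability is not supported by this kernel")
```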
```python
impl(a, sfa, b, sfb, d, m_indices)
@supported_compute_capability([100, 103])
```
Same here: the compute capabilities will be ignored.
flashinfer/gemm.py
Outdated
```python
"cudnn": _cudnn_bmm_fp8_requirement,
"cublas": _cublas_bmm_fp8_requirement,
"cutlass": _cutlass_bmm_fp8_requirement,
"auto": _cublas_bmm_fp8_requirement,  # cublas default
```
auto doesn't default to cublas. It actually sets backends = ["cutlass", "cublas", "cudnn"], which is usually autotuned across backends.

That's not trivial to resolve. One way may be to just create an aggregate function "check_all_backends" that calls the 3 checker functions.

Note that the function that takes all these backends is fp8_gemm_sm100, which also does a bunch of checks. So perhaps we need to rely on those, since the checks here are very minimal anyway.

But this does remain a general issue: "How do we isolate backend support checks when the auto backend is chosen?" I think the behavior we want is "filter out unsuitable backends", rather than just rejecting auto when one of the backends fails. Maybe this can be an additional "attribute" that the decorator adds for the special auto backend. To illustrate, we would have:

bmm_fp8.suitable_auto_backends() returns ["cutlass", "cublas", "cudnn"] if all checks pass, or ["cutlass", "cublas"] if cudnn's check failed. bmm_fp8.suitable_auto_backends() would return None if the selected backend was not "auto".
> bmm_fp8.suitable_auto_backends() returns ["cutlass", "cublas", "cudnn"] if all checks pass, or ["cutlass", "cublas"] if cudnn's check failed.

In this decorator, we could set an attribute with the final list of supported backends for the actual function to reference when passing into the autotuner. Is this what you had in mind?
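To make that concrete, a rough, hypothetical sketch of the filtering idea (names and wiring are illustrative, not the current decorator): when backend == "auto", run each per-backend checker, keep the ones that pass, and expose the surviving list for the autotuner instead of failing the call outright.

```python
def _filter_auto_backends(backend_checks, args, kwargs):
    """Return the subset of backends whose requirement checks pass for these inputs."""
    suitable = []
    for name, check in backend_checks.items():
        if name == "auto":
            continue
        try:
            check(*args, **kwargs)
            suitable.append(name)
        except ValueError:
            pass  # drop unsuitable backends instead of rejecting "auto"
    return suitable


# The wrapper could then stash the result for the function body / autotuner:
#     if backend == "auto":
#         wrapper.suitable_auto_backends = _filter_auto_backends(backend_checks, args, kwargs)
```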
## Description

This PR fixes illegal memory access of trtllm-gen attention kernels. It changes the workspace buffer from `int_workspace_buffer` to `float_workspace_buffer`. `int_workspace_buffer` is a fixed sized buffer and not initialized to zero, which should not be used.

## Related Issues

Issue flashinfer-ai#1928

## Pull Request Checklist

### Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Bug Fixes**
  * Fixed memory allocation in the decode module to improve computation accuracy and stability during text generation.
## Description

Update version in `version.txt` to v0.5.0 as we prepare for v0.5.0rc1 release.

## Pull Request Checklist

### Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## Tests

- [x] Tests have been added or updated as needed.
- [x] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Chores**
  * Version bump to 0.5.0 (no functional changes)
## Description

The regex cannot recognize release candidates (`v0.5.0rc1`) or post releases (`v1.2.3.post1`): https://github.com/flashinfer-ai/flashinfer/actions/runs/18929490991/job/54049304551

This PR fixes the issue.

## Related Issues

https://github.com/flashinfer-ai/flashinfer/actions/runs/18929490991/job/54049304551

## Pull Request Checklist

### Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## Tests

- [x] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **Chores**
  * Enhanced version string parsing in the wheel package indexing process to support more complex version formats, including pre-release, post-release, and development versions, ensuring compatibility with PEP 440 versioning standards.
## Description

In flashinfer-ai#1809 we previously added a compute-capability-based support check for `mm_fp4`. However, we missed enabling SM121 for backend = `cudnn` and `cutlass`. Additionally, we marked `trtllm` as supported on SM120 when it is not. Current PR fixes it.

Example benchmark and pytest command on SM121 after the fix:

```
(py312) root@f414f262f02a:/flashinfer/benchmarks# python3 flashinfer_benchmark.py --routine mm_fp4 --m 8192 --n 7168 --k 512 --out_dtype bfloat16 --backends cudnn cutlass --use_128x4_sf_layout --use_nvfp4 --refcheck --use_cupti
/opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: Found GPU0 NVIDIA GB10 which is of cuda capability 12.1. Minimum and Maximum cuda capability supported by this version of PyTorch is (8.0) - (12.0)
  warnings.warn(
[PERF] cudnn   :: median time 0.656 ms; std 0.025 ms; achieved tflops 91.701 TFLOPs/sec; achieved tb_per_sec 0.185 TB/sec
[PERF] cutlass :: median time 0.669 ms; std 0.022 ms; achieved tflops 89.859 TFLOPs/sec; achieved tb_per_sec 0.181 TB/sec

(py312) root@f414f262f02a:/flashinfer# pytest tests/gemm/test_mm_fp4.py
============================== test session starts ==============================
platform linux -- Python 3.12.11, pytest-8.4.2, pluggy-1.6.0
rootdir: /flashinfer
configfile: pytest.ini
collected 3240 items
...
=============================== warnings summary ================================
../opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285
  /opt/conda/envs/py312/lib/python3.12/site-packages/torch/cuda/__init__.py:285: UserWarning: Found GPU0 NVIDIA GB10 which is of cuda capability 12.1. Minimum and Maximum cuda capability supported by this version of PyTorch is (8.0) - (12.0)
    warnings.warn(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================== 450 passed, 2790 skipped, 1 warning in 8.24s ==================
```

## Pull Request Checklist

### Pre-commit Checks

- [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [x] I have installed the hooks with `pre-commit install`.
- [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Summary by CodeRabbit

* **New Features**
  * Expanded hardware compatibility by adding support for newer NVIDIA GPU architectures.
  * FP4 quantized operations now available across multiple backends on supported devices.
## Description

Fix the regression in vLLM and SGLang with FI 0.4.0 in bmm_fp8.

## Pull Request Checklist

### Pre-commit Checks

- [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- [ ] I have installed the hooks with `pre-commit install`.
- [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

## Tests

- [ ] Tests have been added or updated as needed.
- [ ] All tests are passing (`unittest`, etc.).

## Reviewer Notes

cc: @yzh119

## Summary by CodeRabbit

* **Bug Fixes**
  * Fixed memory layout handling for tensor operations in GPU computations to ensure proper alignment, improving stability and performance.
<!-- .github/pull_request_template.md --> ## π Description Use real head sizes, seq lens and add comparison with sequential prefill + decode. Results on H100 (without overlap, which only adds ~150GB/s for persistent): <img width="433" height="571" alt="image" src="https://github.com/user-attachments/assets/50de01cd-e5ca-450c-9cc0-521d83b7e487" /> cc @yzh119 ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added comprehensive performance benchmarking for batch attention operations with detailed timing measurements. * Introduced sequential dual-kernel benchmark path with extended memory bandwidth reporting. * **Tests** * Updated benchmark test configurations to use deterministic, fixed values for improved reproducibility. * Adjusted benchmark parameters for consistency across test iterations. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
β¦shinfer-ai#2011) ## π Description This PR removes an assertion in the cutlass fused moe bindings to enable non-gated activations in nvfp4. It also adds a test for this path with relu2 activation. ## π Related Issues N/A ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [v] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [v] I have installed the hooks with `pre-commit install`. - [v] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [v] Tests have been added or updated as needed. - [v] All tests are passing (`unittest`, etc.). ## Reviewer Notes N/A <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Enhanced quantized Mixture of Experts models to support configurable activation types (Swiglu and ReLU2) in the NVFP4 quantization path. * Improved parameter handling to correctly adapt weight shapes and quantization settings based on the selected activation type. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Omer Ullman Argov <[email protected]>
β¦qa backend (flashinfer-ai#2001) <!-- .github/pull_request_template.md --> ## π Description Expose xqa backend to trtllm attention interface, and improve layout coverage of trtllm-gen and xqa backends. Now both trtllm-gen/xqa supports NHD/HND kv-cache layout. * support NHD layout for trtllm-gen * refactor xqa (flashinfer-ai@869c0c1) * allow user passed stride_page/head/token * support both HND and NHD * remove macros such as PAGED_KV_CACHE_LAYOUT and USE_PAGED_KV_CACHE * adding unittests for both trtllm-gen/xqa on NHD/HND * adding unified API for trtllm-gen/xqa, and unified unittest ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [x] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added xqa-based batch decode API and public kv_layout option (NHD/HND); added enable_pdl toggle to inference wrappers. * **Improvements** * Automatic backend selection for decoding, consistent KV-layout normalization across paths, and unified stride-aware paged-KV handling with layout-aware shapes, scales, and workspace handling. * **Tests** * Expanded tests to cover both KV layouts, enable_pdl, new batch-decode workflows, backend/layout permutations, and fp8/mixed-dtype scenarios. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Qidi Sang <[email protected]> Co-authored-by: yzh119 <[email protected]> Co-authored-by: Zihao Ye <[email protected]>
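For readers unfamiliar with the two kv-cache layouts mentioned above, the sketch below shows how NHD and HND differ only in axis order, and therefore in the page/token/head strides a layout-aware kernel consumes. Dimension sizes are arbitrary and the variable names are not FlashInfer API.

```python
# Illustrative only: NHD vs. HND axis order for a paged KV cache.
import torch

num_pages, page_size, num_kv_heads, head_dim = 8, 16, 4, 128

# NHD: [num_pages, page_size (tokens), num_kv_heads, head_dim]
k_nhd = torch.zeros(num_pages, page_size, num_kv_heads, head_dim)

# HND: [num_pages, num_kv_heads, page_size (tokens), head_dim]
k_hnd = k_nhd.permute(0, 2, 1, 3).contiguous()

# A stride-aware kernel only needs (stride_page, stride_token, stride_head),
# which is exactly what differs between the two layouts.
print("NHD strides:", k_nhd.stride())
print("HND strides:", k_hnd.stride())
```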
<!-- .github/pull_request_template.md --> ## π Description Add the MxInt4 x BF16 TRTLLM GEN moe ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * MXInt4 MoE inference path with public API and test coverage; MXInt4 + BF16 supported end-to-end. * Exposed new MXInt4 op and helper in the package exports. * **Refactor** * Block-scale/interleave routines generalized to support uint8 and bfloat16 inputs and outputs. * GEMM/BatchedGemm configs now include an element-wise activation option and are arch-aware (CUDA arch). * **Tests** * Added MXInt4 quantization and runtime tests for MoE. * **Chores** * Updated packaged artifact path/checksum. <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: Siyuan Fu <[email protected]> Signed-off-by: Nikita Korobov <[email protected]> Co-authored-by: Siyuan Fu <[email protected]>
β¦tion (flashinfer-ai#2163) <!-- .github/pull_request_template.md --> ## π Description `trtllm_batch_decode_with_kv_cache_mla` and `xqa_batch_decode_with_kv_cache_mla` currently reside in `decode.py`. This PR moves them to `mla.py` and makes them show up in the documentation via adding them to `attention.rst`. Note that the addition to documentation at the correct place requires this refactor as the docs generator looks at each module for indexing. <!-- What does this PR do? Briefly describe the changes and why theyβre needed. --> ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added new batch decode entry points for MLA with KV cache, supporting additional configuration parameters for enhanced control. * XQA MLA functionality now officially documented and available. * **Documentation** * Updated API reference with expanded MLA module documentation. <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
<!-- .github/pull_request_template.md --> ## π Description - Fix test_cute_dsl_gemm_allreduce_two_shot.py regression from nvidia-cutlass-dsl upgrade to 4.3.1 (removed helper functions) - GB300 enabled for this kernel as well need 8xB200 pytest tests/gemm/test_cute_dsl_gemm_allreduce_two_shot.py -v ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit ## Release Notes * **New Features** * Added new synchronization helper functions for inter-GPU coordination: `spin_lock_multimem_arrive`, `spin_lock_atom_cas_acquire_wait`, and `sm_wise_inter_gpu_multimem_barrier`. * **Tests** * Extended test coverage to support additional GPU architectures (SM10.3). <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
This PR updates the Docker CI image tags to the latest version: `20251205-54c1678` Updated images: - flashinfer/flashinfer-ci-cu126:20251205-54c1678 - flashinfer/flashinfer-ci-cu128:20251205-54c1678 - flashinfer/flashinfer-ci-cu129:20251205-54c1678 - flashinfer/flashinfer-ci-cu130:20251205-54c1678 Auto-generated by [release-ci-docker workflow](https://github.com/flashinfer-ai/flashinfer/actions/runs/19960576464) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Chores** * Updated internal Docker image tags to latest versions for infrastructure maintenance. <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> Co-authored-by: yzh119 <[email protected]> Co-authored-by: Zihao Ye <[email protected]>
<!-- .github/pull_request_template.md --> ## π Description <!-- What does this PR do? Briefly describe the changes and why theyβre needed. --> ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Refactor** * Updated DeepSeek routing kernel naming within the public API for improved consistency * **Tests** * Updated test suite to reflect the updated kernel naming convention <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
β¦infer-ai#2111) <!-- .github/pull_request_template.md --> ## π Description This PR refactors the out-dated fa3 codebase, more specifically, for page_size>1, the page offset calculation is performed inside the kernel, without the need of a standalone function call to block_sparse_indices_to_vector_sparse_offsets, and optimize the kv_offset calculation with prefetching and shuffling. This PR also fixes the failed unittest on hopper. However, the FA3 structure in our codebase is still terrible outdated without important features such as `IntraWGOverlap` and `RescaleOBeforeGemm`, will follow up soon in a later PR. ## π Related Issues This PR should fixes flashinfer-ai#1647 ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Enforced consistent K/V page-strides and stride_n for sparse paged attention with runtime validation. * **New Features** * Exposed page-interval stride metadata for paged KV caches. * Expanded FP8 support: per-head/per-tensor scale propagation and configurable output dtype; added ragged-KV execution path and benchmarks. * **Refactor** * Unified sparse attention to a page-based addressing model and simplified workspace/epilogue behavior; removed legacy block-to-vector conversion helpers and related APIs. * **Tests** * Added FP8 ragged/paged tests, expanded page-size coverage, and removed obsolete vector-sparse unit test. <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
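As a rough host-side picture of the in-kernel change described above (computing paged offsets directly instead of pre-expanding them with `block_sparse_indices_to_vector_sparse_offsets`), the snippet below maps each token position of a sequence to a row of a flattened paged cache. Strides and names are simplified assumptions, not the kernel's actual addressing code.

```python
# Hedged sketch: per-token page-offset arithmetic for a paged KV cache.
import torch

page_size = 4
kv_indices = torch.tensor([7, 2, 9])  # physical pages backing one sequence

token_pos = torch.arange(page_size * kv_indices.numel())  # logical token index
page_id = kv_indices[token_pos // page_size]              # which physical page
in_page = token_pos % page_size                           # offset inside the page

# Row index of each token in a cache viewed as [num_pages * page_size, heads, dim].
row = page_id * page_size + in_page
print(row)  # tensor([28, 29, 30, 31,  8,  9, 10, 11, 36, 37, 38, 39])
```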
<!-- .github/pull_request_template.md --> ## π Description We currently have unit tests failing as: ``` ========================================== Running: pytest --continue-on-collection-errors -s --junitxml=/junit/tests/comm/test_trtllm_mnnvl_allreduce.py.xml "tests/comm/test_trtllm_mnnvl_allreduce.py" ========================================== Abort(1090447) on node 0 (rank 0 in comm 0): Fatal error in PMPI_Init_thread: Unknown error class, error stack: MPIR_Init_thread(192)........: MPID_Init(1665)..............: MPIDI_OFI_mpi_init_hook(1586): (unknown)(): Unknown error class [unset]: write_line error; fd=-1 buf=:cmd=abort exitcode=1090447 : system msg for write_line failure : Bad file descriptor Abort(1090447) on node 0 (rank 0 in comm 0): Fatal error in PMPI_Init_thread: Unknown error class, error stack: MPIR_Init_thread(192)........: MPID_Init(1665)..............: MPIDI_OFI_mpi_init_hook(1586): ... Extension modules: numpy._core._multiarray_umath, numpy.linalg._umath_linalg, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, cuda.bindings._bindings.cydriver, cuda.bindings.cydriver, cuda.bindings.driver, tvm_ffi.core, markupsafe._speedups, charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, mpi4py.MPI (total: 22) !!!!!!! Segfault encountered !!!!!!! ... β FAILED: tests/comm/test_trtllm_mnnvl_allreduce.py ``` These tests should be skipping in a single GPU environment, but are failing, which indicates that they are failing at MPI module load time. The current `dockerfile.cuXXX` installs MPI via `RUN conda install -n py312 -y mpi4py`. Upon investigating the docker build logs, [A month ago (Nov. 4)](https://github.com/flashinfer-ai/flashinfer/actions/runs/19084098717/job/54520197904#step:6:802), ``` flashinfer-ai#17 13.68 mpi-1.0.1 | mpich 6 KB conda-forge flashinfer-ai#17 13.68 mpi4py-4.1.1 |py312hd0af0b3_100 866 KB conda-forge flashinfer-ai#17 13.68 mpich-4.3.2 | h79b1c89_100 5.4 MB conda-forge ``` was being installed, [but yesterday](https://github.com/flashinfer-ai/flashinfer/actions/runs/19960576464/job/57239792717#step:6:673): ``` flashinfer-ai#17 13.59 impi_rt-2021.13.1 | ha770c72_769 41.7 MB conda-forge flashinfer-ai#17 13.59 mpi-1.0 | impi 6 KB conda-forge flashinfer-ai#17 13.59 mpi4py-4.1.1 |py312h18f78f0_102 864 KB conda-forge ``` is being installed. The mpich vs. impi are Implementations to the MPI: MPICH vs. Intel MPI. This is currently the suspected issue underlying the MPI load failures. Current PR specifies the MPI implementation via `RUN conda install -n py312 -y mpi4py mpich`. The result of the current PR produces ([build log](https://github.com/flashinfer-ai/flashinfer/actions/runs/19976372640/job/57293423165?pr=2182#step:6:436)): ``` flashinfer-ai#15 14.63 mpi-1.0.1 | mpich 6 KB conda-forge flashinfer-ai#15 14.63 mpi4py-4.1.1 |py312hd0af0b3_102 865 KB conda-forge flashinfer-ai#15 14.63 mpich-4.3.2 | h79b1c89_100 5.4 MB conda-forge ``` which now matches what we had before <!-- What does this PR do? Briefly describe the changes and why theyβre needed. --> ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. 
### β Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. -->
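When debugging the kind of MPICH vs. Intel MPI mismatch described in this PR, mpi4py's own introspection helpers can confirm which implementation the package was built against; note that importing `mpi4py.MPI` initializes MPI, which is exactly the step that aborts when the runtime does not match.

```python
# Diagnostic sketch using standard mpi4py APIs; the output shown in comments is
# illustrative, not captured from the CI environment.
import mpi4py

print(mpi4py.get_config())  # build-time MPI configuration, no MPI init needed

from mpi4py import MPI  # this import initializes MPI

print(MPI.Get_library_version().splitlines()[0])  # e.g. "MPICH Version: 4.3.2"
```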
This PR updates the Docker CI image tags to the latest version: `20251206-185d63a` Updated images: - flashinfer/flashinfer-ci-cu126:20251206-185d63a - flashinfer/flashinfer-ci-cu128:20251206-185d63a - flashinfer/flashinfer-ci-cu129:20251206-185d63a - flashinfer/flashinfer-ci-cu130:20251206-185d63a Auto-generated by [release-ci-docker workflow](https://github.com/flashinfer-ai/flashinfer/actions/runs/19993421764) <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Chores** * Updated CI build container images to latest versions for improved build reliability and security. <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> Co-authored-by: yzh119 <[email protected]>
β¦-ai#2189) <!-- .github/pull_request_template.md --> ## π Description `test_jit_warmup.py` contains two tests `test_warmpup_llama` and `test_warmpup_llama_sm90` where the _sm90 variant is the first one plus an additional fa3 prefill compilation only supported on SM90. While under general circumstances, this test should pass on non-sm90 cards, it might fail if the environment contains an incomplete `flashinfer-jit-cache` that is missing sm90 kernels and `FLASHINFER_CUDA_ARCH_LIST` is not correctly set to include 9.0a. Since: 1. This can interfere with unit testing by causing unexpected failures on non SM90 cards and 2. On non SM90 cards, there is no additional gain from testing sm90 kernel compilation current PR disables it. <!-- What does this PR do? Briefly describe the changes and why theyβre needed. --> ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Updated test execution conditions to ensure tests run only on supported hardware configurations. <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
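A minimal way to express the gating this PR adds is a capability-based skip marker; the marker and helper below are assumptions for illustration, not the repository's actual test code.

```python
# Hypothetical sketch of restricting an SM90-only compilation test to Hopper GPUs.
import pytest
import torch


def _is_sm90() -> bool:
    return torch.cuda.get_device_capability() == (9, 0)


@pytest.mark.skipif(
    not torch.cuda.is_available() or not _is_sm90(),
    reason="fa3 prefill compilation is only supported on SM90 (Hopper)",
)
def test_warmup_llama_sm90():
    ...  # compile the SM90 prefill kernels here
```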
β¦r-ai#2188) <!-- .github/pull_request_template.md --> ## π Description [CUDA Release notes](https://docs.nvidia.com/cuda/parallel-thread-execution/#release-notes) indicate full sm120 support on CUDA 12.8 and sm121 on 12.9. Changing the version noted in `aot.py`. The change alters the coverage in `flashinfer-jit-cache` -- today some sm120 kernels are not compiled & shipped in cuda 12.9, but are included in 13.0. The change here will insure kernels are included. <!-- What does this PR do? Briefly describe the changes and why theyβre needed. --> ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Adjusted GPU capability detection and CUDA version requirements: * SM110 now requires CUDA 13.0 (raised from 12.9). * SM120 support expanded to CUDA 12.8 (lowered from 13.0). * SM121 support updated to CUDA 12.9 (from 13.0). * **Impact** * Broader compatibility for SM120/SM121 on slightly older CUDA; SM110 now requires a newer CUDA runtime. <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
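The arch-to-minimum-CUDA mapping the PR adjusts can be pictured as a small table plus a comparison; the helper below is hypothetical, and only the three version pairs come from the PR description.

```python
# Sketch of a compute-capability -> minimum-CUDA-version table, per the PR text.
MIN_CUDA_FOR_ARCH = {
    (11, 0): (13, 0),  # SM110 now requires CUDA 13.0
    (12, 0): (12, 8),  # SM120 supported from CUDA 12.8
    (12, 1): (12, 9),  # SM121 supported from CUDA 12.9
}


def arch_is_buildable(arch, cuda_version):
    """Return True if kernels for `arch` can be compiled with `cuda_version`."""
    required = MIN_CUDA_FOR_ARCH.get(arch)
    return required is not None and cuda_version >= required


assert arch_is_buildable((12, 0), (12, 9))      # SM120 on CUDA 12.9
assert not arch_is_buildable((11, 0), (12, 9))  # SM110 needs CUDA 13.0
```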
<!-- .github/pull_request_template.md --> ## π Description https://pypi.org/project/flashinfer-python/ seems we have 0.5.3, but the nightly and version is still 0.5.2 <!-- What does this PR do? Briefly describe the changes and why theyβre needed. --> ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Chores** * Version updated to 0.5.3 <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
β¦lashinfer-ai#2197) <!-- .github/pull_request_template.md --> ## π Description Add this to public docs CUDA version support 12.6, 12.8, 13.0, 13.1 and state that our goal is to follow PyTorch support + latest CUDA version <!-- What does this PR do? Briefly describe the changes and why theyβre needed. --> ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Documentation** * Updated CUDA support documentation with explicit supported versions: 12.6, 12.8, 13.0, and 13.1. * Documented CUDA version prerequisites for installation. * Added clarification that supported CUDA versions align with PyTorch's officially supported versions and the latest CUDA release. * Enhanced GPU support section header for better clarity. <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
…-ai#2196) <!-- .github/pull_request_template.md --> ## Description This PR addresses the documentation issues and adds proper validation for unsupported configurations in the attention prefill wrappers raised in flashinfer-ai#1709. Changes * `BatchPrefillWithPagedKVCacheWrapper`: * Updated docstring to include `trtllm-gen` as a valid backend option, which was supported but undocumented. * Added validation for non-causal attention with the `trtllm-gen` backend: the kernel only supports causal attention. Previously, using `causal=False` would silently run causal attention anyway (causing incorrect outputs); now it raises a clear NotImplementedError. * `BatchPrefillWithRaggedKVCacheWrapper`: Fixed docstring to list `cutlass` instead of `trtllm-gen` as a backend option. * The `cutlass` backend is supported but was missing from docs, while `trtllm-gen` raises NotImplementedError for ragged KV cache. <!-- What does this PR do? Briefly describe the changes and why they're needed. --> ## Related Issues flashinfer-ai#1709 <!-- Link any related issues here --> ## Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added support for trtllm-gen and cutlass backend options * **Documentation** * Clarified available backends and default configurations * Improved error guidance for unsupported backend combinations <sub>Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
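A minimal sketch of the guard described in this PR, assuming a free-standing check; the real validation lives inside the wrapper's plan/run path and may differ in detail.

```python
# Hedged sketch: reject non-causal attention when the trtllm-gen backend is chosen.
def check_trtllm_gen_constraints(backend: str, causal: bool) -> None:
    if backend == "trtllm-gen" and not causal:
        raise NotImplementedError(
            "trtllm-gen prefill only supports causal attention; "
            "use causal=True or choose another backend (e.g. fa2/cutlass)"
        )


check_trtllm_gen_constraints("fa2", causal=False)        # ok
check_trtllm_gen_constraints("trtllm-gen", causal=True)  # ok
try:
    check_trtllm_gen_constraints("trtllm-gen", causal=False)
except NotImplementedError as err:
    print(err)
```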
β¦lashinfer-ai#2193) <!-- .github/pull_request_template.md --> ## π Description <!-- What does this PR do? Briefly describe the changes and why theyβre needed. --> ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added packedβweights support for fused MoE calls with a new public flag to select packed vs. unpacked weight handling; runtime paths now distinguish runners by this flag. * Exposed the flag across public wrappers so callers can opt into packedβweights behavior. * **Tests** * Added W4A8 quantization pathway tests and helpers to validate INT4 packing/dequantization and endβtoβend W4A8 behavior. <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
<!-- .github/pull_request_template.md --> ## π Description This PR will permute page tables during benchmarking of different backends and ensure the correct page table gets passed to the backends. <!-- What does this PR do? Briefly describe the changes and why theyβre needed. --> ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Tests** * Improved batch operation tests by introducing randomized page ordering and updated indexing to validate behavior under non-sequential page arrangements, increasing robustness and coverage. * Confirmed no changes to public interfaces or external behavior; these updates only strengthen internal test validation. <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: Joshua Park <[email protected]>
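The permutation idea is simple enough to state in a couple of lines; the names below are illustrative, and the key point from the PR is that every backend under comparison must receive the same permuted table.

```python
# Hedged sketch: give each logical page a random physical page, then reuse the
# resulting table for every backend being benchmarked.
import torch

num_physical_pages, pages_per_seq = 64, 8
page_table = torch.randperm(num_physical_pages)[:pages_per_seq]
```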
<!-- .github/pull_request_template.md --> ## π Description SM100 configs are generated for SM110 devices, but of the valid MOE kernels for SM110, only TMA Epilogue Scheduling is supported. In the generated configs, `EpilogueScheduleType::AUTO` is not handled well in the `getDispatchFunctionForSM100` dispatch function (it is never resolved to a concrete value, either TMA or NO_SMEM), so all configs are rejected. This PR pins the generated configs to `EpilogueScheduleType::TMA` for SM110 devices. It also allows SM100 configs to run on SM110 devices. Future work would include generating SM110 configs specifically for SM110 devices, rather than rely on SM100 configs. ## π Related Issues <!-- Link any related issues here --> ## π Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### β Pre-commit Checks - [x] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [x] I have installed the hooks with `pre-commit install`. - [x] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## π§ͺ Tests - [x] Tests have been added or updated as needed. - [x] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Fixed GPU configuration validation to support latest hardware generations. * **Refactor** * Optimized kernel execution scheduling for improved performance on newer devices. <sub>βοΈ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
Actionable comments posted: 9
Note
Due to the large number of review comments, Critical severity comments were prioritized as inline comments.
Caution
Some comments are outside the diff and can't be posted inline due to platform limitations.
Outside diff range comments (7)
benchmarks/routines/gemm.py (1)
842-883: Use `block_size` variable instead of duplicating the ternary expression.

On line 843, `block_size` is computed, but the validation call on lines 864-866 uses a hardcoded ternary instead of the variable. This creates a maintenance risk if the logic diverges.

Additionally, line 883 returns `None` instead of `res`, which is inconsistent with other early-exit paths in the codebase (e.g., lines 399-400 return `res`).

```diff
-            block_size=16
-            if use_nvfp4
-            else 32,  # nvfp4 only supports 16; mxfp4 only supports 32.
+            block_size=block_size,
```

```diff
     if len(backends) == 0:
         print("[ERROR] No backends passed validation. Exiting.")
-        return
+        return res
```

csrc/batch_mla_run.cu (1)
30-42: Fix dtype check for `kv_indices` to match the templated `IdType` parameter instead of hardcoding `dl_int32`.

The `CHECK_INPUT_TYPE(kv_indices, dl_int32)` at line 40 is incorrect because `IdType` is a template parameter determined by `dtype_idx` (which comes from `qo_indptr.dtype` in `mla.py`). If a user passes `qo_indptr` with dtype `int64`, the generated kernel will have `IdType = int64_t`, but the input validation will still expect `int32`. This creates a type mismatch with the subsequent cast at line 77: `params.kv_indices = static_cast<IdType*>(kv_indices.data_ptr());`

Either:
- Validate the `kv_indices` dtype dynamically based on `IdType`, or
- Document and enforce that `kv_indices` must always be `int32`, and ensure `dtype_idx` is always `int32_t` in practice.

Also verify that all callers of `BatchMLAPagedAttentionRun` pass the `return_lse_base_on_e` parameter (line 35, 118) or rely on a default value wrapper.

benchmarks/routines/attention.py (4)
518-528: Pass `seq_lens` as a 1D tensor for trtllm-native decode

The `flashinfer.decode.trtllm_batch_decode_with_kv_cache` function expects `seq_lens` with shape `[batch_size]`, but `actual_seq_lens_kv` is created with shape `[batch_size, 1, 1, 1]`. Use `.flatten()` to reshape:

```diff
- seq_lens=actual_seq_lens_kv,
+ seq_lens=actual_seq_lens_kv.flatten(),
```
1827-1841: Remove the `kv_cache.unsqueeze(1)` and add validation for MLA parameter constraints.

The `kv_cache.unsqueeze(1)` is correct and necessary; `flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla` requires 4D tensors for both `query` and `kv_cache` (validated internally). However, the benchmark passes `head_dim_ckv` and `head_dim_kpe` as user-provided arguments without validating that they match flashinfer's hardcoded constraints: `kv_lora_rank` must be 512 and `qk_rope_head_dim` must be 64. If incorrect values are supplied, the function will fail at runtime. Add explicit validation near argument parsing to fail fast with a clear error message:

- required `head_dim_ckv == 512` (mapped to `kv_lora_rank`)
- required `head_dim_kpe == 64` (mapped to `qk_rope_head_dim`)
- clarify why `qk_nope_head_dim=128` is hardcoded
1428-1447: Flatten `seq_lens` before passing to trtllm_ragged_attention_deepseek

The `seq_lens` parameter receives `actual_seq_lens_kv_device` with shape `[B,1,1,1]` instead of the expected `[B]`. Call `.view(-1)` or `.flatten()` on the tensor before passing it: `seq_lens=actual_seq_lens_kv_device.view(-1),`

While PyTorch's broadcasting may silently handle the shape mismatch, explicit flattening ensures the function receives the correct expected shape.
979-993: Re-evaluate shape concern for seq_lens parameter

The `seq_lens` parameter receives `actual_seq_lens_kv_device`, which is constructed as `torch.tensor(batch_size_per_seq, dtype=torch.int32)` (line 768). This creates a 1D tensor of shape `[batch_size]`, not `[B,1,1,1]`. The parameter shapes appear correct and identical to what the hopper backend uses (line 934).

However, verify that `cum_seq_lens_q` and `cum_seq_lens_kv` are provided as `torch.int32` as expected by the flashinfer API; both currently use `torch.arange(0, batch_size + 1, dtype=torch.int32)`. If the API requires `torch.int64` instead, apply the same dtype conversion that may be needed for the hopper backend.

csrc/batch_decode.cu (1)
83-94: Ensure dtype checks match the dispatched `IdType`.

The `CHECK_INPUT_TYPE` calls for `paged_kv_indptr`, `paged_kv_indices`, and `paged_kv_last_page_len` currently hard-code `dl_int32`, but these tensors are cast to `IdType*` in the `paged_kv_t<DTypeKV, IdType>` constructor (lines 149-151). Since `IdType` is a template parameter, the dtype checks should be validated against the actual `IdType` used in each generated configuration to avoid future inconsistencies.
Major comments (20)
benchmarks/bench_logging_overhead.py-31-36 (1)
31-36: Guard LOG_DEST semantics + avoid deleting arbitrary files; ensure decorator sees the same dest
Right now you (a) default LOG_DEST to a file path even though the decorator defaults to stdout, and (b)os.remove(LOG_DEST)can delete unintended files (including a file literally namedstdout). Also, if you ever decide to programmatically set a dest, it must happen before importingflashinfer_apisince config is import-time.@@ -LOGGING_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0")) -LOG_DEST = os.environ.get("FLASHINFER_LOGDEST", "/tmp/flashinfer_benchmark_log.txt") +LOGGING_LEVEL = int(os.environ.get("FLASHINFER_LOGLEVEL", "0")) +# Keep default aligned with the decorator ("stdout") unless explicitly set by the user. +LOG_DEST = os.environ.get("FLASHINFER_LOGDEST", "stdout") + +def _is_file_log_dest(dest: str) -> bool: + return dest not in ("stdout", "stderr") @@ -from flashinfer.api_logging import flashinfer_api +from flashinfer.api_logging import flashinfer_api @@ - # Clear log file before starting - if os.path.exists(LOG_DEST): - os.remove(LOG_DEST) + # Clear log file before starting (only when logging to a real file path) + if _is_file_log_dest(LOG_DEST) and os.path.exists(LOG_DEST): + os.remove(LOG_DEST) @@ - if LOGGING_LEVEL > 0 and os.path.exists(LOG_DEST): + if LOGGING_LEVEL > 0 and _is_file_log_dest(LOG_DEST) and os.path.exists(LOG_DEST): @@ - if cleanup_log: + if cleanup_log: os.remove(LOG_DEST) print("\n Log file removed (set CLEANUP_LOG=false to keep it)")If you do want a file default, prefer a tempfile (see next comment) and also set
FLASHINFER_LOGDESTexplicitly before importingflashinfer_api.Also applies to: 248-252, 292-310
csrc/fmha_cutlass_sm100.cu-99-101 (1)
99-101: Use consistent device for guard and stream; add device validation for all tensors.The code currently initializes
CUDADeviceGuardfromqo_segment_offsets.device().device_idbut gets the stream fromo.device(). If these tensors land on different CUDA devices, you'll set the device context to one device while launching work on another, causing silent corruption or crashes.Use a canonical device (e.g.,
q.device()) for both the guard and stream, and validate all input tensors are on that device using the existingCHECK_DEVICEmacro:+ const auto dev = q.device(); + CHECK_DEVICE(q, k); + CHECK_DEVICE(q, v); + CHECK_DEVICE(q, qo_segment_offsets); + CHECK_DEVICE(q, kv_segment_offsets); + CHECK_DEVICE(q, work_indptr); + CHECK_DEVICE(q, qo_tile_indices); + CHECK_DEVICE(q, qo_head_indices); + CHECK_DEVICE(q, batch_indices); + CHECK_DEVICE(q, o); + if (maybe_lse.has_value()) CHECK_DEVICE(q, maybe_lse.value()); + - ffi::CUDADeviceGuard device_guard(qo_segment_offsets.device().device_id); - const cudaStream_t stream = get_stream(o.device()); + ffi::CUDADeviceGuard device_guard(dev.device_id); + const cudaStream_t stream = get_stream(dev);csrc/dsv3_router_gemm.cu-118-147 (1)
118-147: Missing matrix dimension compatibility checks.The validation is incomplete:
Missing K-dimension check: No verification that
mat_a.sizes()[1] == mat_b.sizes()[0](inner dimensions must match for GEMM). Ifmat_bhas an incorrect first dimension, this could cause out-of-bounds memory access.Missing output shape validation: No check that
outhas shape[num_tokens, num_experts].Dead code:
output_size(line 126) is computed but never used.constexpr int kNumExperts = 256; constexpr int kHiddenDim = 7168; - std::vector<int64_t> output_size = {mat_a.sizes()[0], mat_b.sizes()[1]}; TVM_FFI_ICHECK(mat_a.dim() == 2 && mat_b.dim() == 2) << "mat_a and mat_b must be 2D tensors"; + TVM_FFI_ICHECK(out.dim() == 2) << "out must be a 2D tensor"; + TVM_FFI_ICHECK(mat_a.sizes()[1] == mat_b.sizes()[0]) + << "mat_a and mat_b inner dimensions must match for GEMM"; + TVM_FFI_ICHECK(out.sizes()[0] == num_tokens && out.sizes()[1] == num_experts) + << "out shape must be [num_tokens, num_experts]"; TVM_FFI_ICHECK(mat_a.strides()[1] == 1 && out.strides()[1] == 1) << "mat_a and out must be row-major";csrc/fmha_v2/fmha/alibi_params.h-34-38 (1)
34-38: Consider overflow protection for integer multiplications.The constructor performs multiple integer multiplications (
h * tp_size,h * rank,s * rank) without overflow checks. While typical values should be safe, large inputs could cause integer overflow and unexpected behavior.Consider adding range validation or using checked arithmetic:
AlibiParams(int h, int s, int tp_size, int rank, float scale_after_alibi = 1.f) : AlibiParams(h * tp_size, scale_after_alibi) { + // Validate to prevent overflow + if (h > INT_MAX / tp_size || h > INT_MAX / rank || s > INT_MAX / rank) { + // Handle error appropriately + } head_idx_offset = h * rank; sequence_pos_offset = s * rank; }Committable suggestion skipped: line range outside the PR's diff.
csrc/fmha_v2/convert.cu-99-149 (1)
99-149: Same divisibility-by-4 concern applies to these templated kernels.The conversion kernels share the same
ii < n / 4pattern. The run functions that call these should also ensuren % 4 == 0to avoid silent data loss.Consider adding assertions to all
run_conversion_*functions:void run_conversion_fp32_to_e4m3(void* dst, void const* src, size_t n, float scale_o) { assert(n % 4 == 0 && "Total elements must be divisible by 4"); convert_fp32_to_T_kernel<fmha::e4m3_t><<<512, 256>>>(dst, src, n, scale_o); }csrc/fmha_v2/convert.cu-19-62 (1)
19-62: Tail elements may be silently dropped ifnis not divisible by 4.The kernel processes elements in groups of 4 (
ii < n / 4). Ifn % 4 != 0, the remaining 1-3 elements will not be converted, potentially causing incorrect output for non-aligned tensor sizes.Consider adding an assertion or handling tail elements:
void run_conversion_int32_to_int8(void* dst, void const* src, int s, int b, int h, int d, float scale) { size_t n = (size_t)s * b * h * d; + assert(n % 4 == 0 && "Total elements must be divisible by 4"); convert_int32_to_int8_kernel<<<512, 256>>>(dst, src, n, scale); }Alternatively, add a scalar fallback loop for tail elements in the kernel.
csrc/fmha_v2/fmha/hopper/smem_tile.h-161-179 (1)
161-179: Initializesmem_read_offset_(currently uninitialized / UB if ever read).
smem_read_offset_is declared but never set in these constructors; if later code starts using it (or itβs used under an#if), this becomes UB.--- a/csrc/fmha_v2/fmha/hopper/smem_tile.h +++ b/csrc/fmha_v2/fmha/hopper/smem_tile.h @@ -161,8 +161,13 @@ struct Smem_tile_hopper_gmma_row_a { // Ctor. inline __device__ Smem_tile_hopper_gmma_row_a(char* smem, int tidx) - : smem_(__nvvm_get_smem_pointer(smem)) { + : smem_(__nvvm_get_smem_pointer(smem)), + smem_read_offset_(0), + smem_write_offset_(0), + smem_read_buffer_(0), + smem_write_buffer_(0) { @@ -334,8 +339,13 @@ struct Smem_tile_hopper_gmma_col_b { // Ctor. inline __device__ Smem_tile_hopper_gmma_col_b(char* smem, int tidx) - : smem_(__nvvm_get_smem_pointer(smem)) { + : smem_(__nvvm_get_smem_pointer(smem)), + smem_read_offset_(0), + smem_write_offset_(0), + smem_read_buffer_(0), + smem_write_buffer_(0) { @@ -504,8 +514,13 @@ struct Smem_tile_hopper_gmma_row_b { // Ctor. inline __device__ Smem_tile_hopper_gmma_row_b(char* smem, int tidx) - : smem_(__nvvm_get_smem_pointer(smem)) { + : smem_(__nvvm_get_smem_pointer(smem)), + smem_read_offset_(0), + smem_write_offset_(0), + smem_read_buffer_(0), + smem_write_buffer_(0) {Also applies to: 334-352, 504-513
csrc/fmha_v2/fmha/hopper/smem_tile.h-234-401 (1)
234-401:move_next_write_buffer()is a no-op for B tiles; multi-buffer pipelines will silently break.
BothSmem_tile_hopper_gmma_col_bandSmem_tile_hopper_gmma_row_bhave the write-buffer advance commented out. If anything instantiates these withBUFFERS_PER_TILE_ > 1, stores will repeatedly target the same buffer.If multi-buffer isnβt supported here yet, add a guard; otherwise implement the same wrap logic as
Smem_tile_hopper_gmma_row_a:--- a/csrc/fmha_v2/fmha/hopper/smem_tile.h +++ b/csrc/fmha_v2/fmha/hopper/smem_tile.h @@ -376,15 +376,13 @@ struct Smem_tile_hopper_gmma_col_b { // Move the write offset to next buffer. inline __device__ void move_next_write_buffer() { - // if( BUFFERS_PER_TILE > 1 ) { - // this->smem_write_offset_ += ( smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) - // ? -BYTES_PER_TILE_INC_BOUNDARY - // : BYTES_PER_BUFFER; - // } + if (BUFFERS_PER_TILE > 1) { + this->smem_write_offset_ += (smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY) + ? -BYTES_PER_TILE_INC_BOUNDARY + : BYTES_PER_BUFFER; + } } @@ -538,15 +536,13 @@ struct Smem_tile_hopper_gmma_row_b { // Move the write offset to next buffer. inline __device__ void move_next_write_buffer() { - // if( BUFFERS_PER_TILE > 1 ) { - // this->smem_write_offset_ += ( smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY ) - // ? -BYTES_PER_TILE_INC_BOUNDARY - // : BYTES_PER_BUFFER; - // } + if (BUFFERS_PER_TILE > 1) { + this->smem_write_offset_ += (smem_write_offset_ >= BYTES_PER_TILE_INC_BOUNDARY) + ? -BYTES_PER_TILE_INC_BOUNDARY + : BYTES_PER_BUFFER; + } }Also applies to: 407-562
csrc/fmha_v2/fmha/fragment.h-785-809 (1)
785-809: Type mismatch: Comment indicates FP32 accumulation but Fragment usesbf16_t.The comment on line 785 states "BF16 MMA must accumulate with at least FP32", and the PTX instruction uses
.f32for the accumulator type. However, the struct inherits fromFragment<bf16_t, 8>which storesbf16_telements, notfloat. This appears to be a bug - the accumulator should usefloatto match the PTX semantics.// BF16 MMA must accumulate with at least FP32 template <> -struct Fragment_accumulator<Ampere_hmma_bf16_bf16_traits> : public Fragment<bf16_t, 8> { +struct Fragment_accumulator<Ampere_hmma_bf16_bf16_traits> : public Fragment<float, 8> {Alternatively, if
bf16_taccumulation is intentional, the PTX instruction should be changed to use.bf16accumulator (if supported).csrc/fmha_v2/fmha/gmem_tile_o.h-275-315 (1)
275-315: Potential UB: `o_scratch_ptr_` pointer arithmetic when it can be `nullptr`.

Under `USE_DEMO_BERT_PARAMS`, `o_scratch_ptr_` is initialized to `nullptr` but later you still do `o_scratch_ptr_ += ...`. Even if not dereferenced, pointer arithmetic on null is UB in C++. Consider guarding the adjustment under the same macro (or initialize to a valid dummy pointer that is never dereferenced).

csrc/fmha_v2/fmha/hopper/gmem_tile_qkv_packed.h-76-105 (1)
76-105: `cta_row_offset`/`tidx` unused: likely a correctness gap (or at least confusing API).

`cta_row_offset` is accepted but not applied to `coord[1]` (row/sequence coordinate). If callers rely on `cta_row_offset` (as in the non-TMA tiles), this will read the wrong rows. If `block_info.sum_s` already includes the offset everywhere, please add a comment and consider removing the parameter to avoid drift.

csrc/fmha_v2/fmha/hopper/tma_descriptor.h-117-121 (1)
117-121:copy_to_deviceignores the stream parameter - uses synchronous memcpy.The
streamparameter is accepted but not used.cudaMemcpyis synchronous and will block the CPU. If async behavior is intended, usecudaMemcpyAsync.// copy the desc to device memory void copy_to_device(void* desc_ptr_d_, cudaStream_t stream = 0) { - FMHA_CHECK_CUDA(cudaMemcpy(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size, - cudaMemcpyHostToDevice)); + FMHA_CHECK_CUDA(cudaMemcpyAsync(desc_ptr_d_, desc_ptr_h, TMA_DESC_SIZE_IN_BYTE * batch_size, + cudaMemcpyHostToDevice, stream)); }csrc/fmha_v2/fmha/hopper/tma_descriptor.h-26-47 (1)
26-47: Class violates Rule of 5 - potential for double-free or use-after-free.The class manages heap memory via
new[]/delete[]but only defines a destructor. If an instance is copied, both objects will attempt todelete[]the same memory.Either delete the copy/move operations or implement them correctly:
// ctor Multiple_tma_descriptor() = default; + // Disable copy operations to prevent double-free + Multiple_tma_descriptor(const Multiple_tma_descriptor&) = delete; + Multiple_tma_descriptor& operator=(const Multiple_tma_descriptor&) = delete; + + // Enable move operations + Multiple_tma_descriptor(Multiple_tma_descriptor&& other) noexcept + : desc_ptr_h(other.desc_ptr_h), desc_ptr_d(other.desc_ptr_d), batch_size(other.batch_size) { + other.desc_ptr_h = nullptr; + other.desc_ptr_d = nullptr; + other.batch_size = 0; + } + + Multiple_tma_descriptor& operator=(Multiple_tma_descriptor&& other) noexcept { + if (this != &other) { + if (batch_size > 0) delete[] desc_ptr_h; + desc_ptr_h = other.desc_ptr_h; + desc_ptr_d = other.desc_ptr_d; + batch_size = other.batch_size; + other.desc_ptr_h = nullptr; + other.desc_ptr_d = nullptr; + other.batch_size = 0; + } + return *this; + } + // destructor. ~Multiple_tma_descriptor() {csrc/batch_decode.cu-139-141 (1)
139-141: Add same-device assertions for workspace buffers vsq(prevents cross-device pointer hazards).
This function reads/writes viafloat_workspace_buffer/int_workspace_buffer(Lines 122-124) but sets device+stream fromq(Lines 139-140). If callers ever pass buffers on a different GPU, this becomes undefined behavior.@@ void* float_buffer = static_cast<void*>(float_workspace_buffer.data_ptr()); void* int_buffer = static_cast<void*>(int_workspace_buffer.data_ptr()); + + TVM_FFI_ICHECK_EQ(float_workspace_buffer.device().device_id, q.device().device_id) + << "float_workspace_buffer and q must be on the same CUDA device"; + TVM_FFI_ICHECK_EQ(int_workspace_buffer.device().device_id, q.device().device_id) + << "int_workspace_buffer and q must be on the same CUDA device"; @@ ffi::CUDADeviceGuard device_guard(q.device().device_id); const cudaStream_t stream = get_stream(q.device());csrc/batch_pod.cu-48-60 (1)
48-60: `maybe_custom_mask_p/d` are unused: likely broken for `MaskMode::kCustom`. (Line 48-49, Line 57-58, Line 158-325)

If custom mask is supported here, plumb the mask buffer(s) into `PrefillParams`/`DecodeParams` under `if constexpr (MASK_MODE_* == MaskMode::kCustom)`; otherwise, reject custom mask modes with an explicit `TVM_FFI_ICHECK` to avoid silent misbehavior.

Also applies to: 158-325
csrc/batch_pod.cu-103-157 (1)
103-157: Enforce same-device (and ideally same-stream semantics) for prefill vs decode; current code assumes it implicitly. (Line 103-105, Line 154-157)

At minimum, add `TVM_FFI_ICHECK_EQ(float_workspace_buffer_p.device().device_id, float_workspace_buffer_d.device().device_id)` (and similarly for q/o and caches) before launching.

csrc/batch_prefill_fp8_sm90.cu-86-107 (1)
86-107: Useffi::CUDADeviceGuardhere too; avoid rawcudaSetDevicein the new Ragged run path. (Line 103-105)
This keeps device state exception-safe and consistent with the plan() and paged run() paths.- cudaSetDevice(float_workspace_buffer.device().device_id); + ffi::CUDADeviceGuard device_guard(float_workspace_buffer.device().device_id); const cudaStream_t stream = get_stream(float_workspace_buffer.device());csrc/batch_pod.cu-326-348 (1)
326-348: ReplaceassertwithTVM_FFI_ICHECKforsm_aware_schedvalidation (asserts compile out). (Line 334-336)- assert( - sm_aware_sched.ndim() == 1 && sm_aware_sched.size(0) == num_sm + 2 && - "sm_aware_sched tensor has incorrect shape or type, should be (num_sm + 2,) of int32"); + TVM_FFI_ICHECK(sm_aware_sched.ndim() == 1); + TVM_FFI_ICHECK_EQ(sm_aware_sched.size(0), num_sm + 2) + << "sm_aware_sched must be shape (num_sm + 2,)";csrc/batch_prefill_fp8_sm90.cu-133-140 (1)
133-140: Add a head-divisibility guard before computinggroup_size. (Line 135-139)params.num_qo_heads = q.size(1); params.num_kv_heads = k.size(1); + TVM_FFI_ICHECK_GT(params.num_kv_heads, 0); + TVM_FFI_ICHECK_EQ(params.num_qo_heads % params.num_kv_heads, 0) + << "num_qo_heads must be divisible by num_kv_heads"; params.group_size = params.num_qo_heads / params.num_kv_heads;csrc/batch_pod_customize_config.jinja-13-13 (1)
13-13: Avoid `using namespace flashinfer;` in a header-like template (namespace pollution / ODR hazards).

This template is effectively a header (`#pragma once`), so `using namespace flashinfer;` can leak into includers and create ambiguous-symbol or ADL surprises.

Suggested change (minimal, keeps call sites readable):

```diff
-using namespace flashinfer;
+namespace fi = flashinfer;
```

…and then qualify where needed (e.g., `fi::PosEncodingMode`).

Committable suggestion skipped: line range outside the PR's diff.
Minor comments (14)
benchmarks/bench_sampling.py-223-248 (1)
223-248: Bind loop variables in benchmark lambdas to avoid late-binding bugs (Ruff B023).
Ifbench_gpu_timedefers execution or stores the callable, the current lambdas can pick up mutated loop variables, producing incorrect measurements. Safer to bind via default args.- lambda: flashinfer.sampling.top_p_renorm_probs(probs, p), + lambda probs=probs, p=p: flashinfer.sampling.top_p_renorm_probs(probs, p), ... - lambda: flashinfer.sampling.top_k_renorm_probs(probs, k), + lambda probs=probs, k=k: flashinfer.sampling.top_k_renorm_probs(probs, k), ... - lambda: flashinfer.sampling.top_k_mask_logits(logits, k), + lambda logits=logits, k=k: flashinfer.sampling.top_k_mask_logits(logits, k),Also applies to: 250-275, 277-301
benchmarks/bench_logging_overhead.py-324-333 (1)
324-333: Don't swallow exceptions silently; keep a non-zero exit code (and consider narrowing the catch)
Catching `Exception` is fine for printing a helpful message, but ensure the process fails (CI/automation) and consider catching `RuntimeError` + re-raising unknowns.

@@
 if __name__ == "__main__":
     try:
         main()
     except KeyboardInterrupt:
         print("\n\nBenchmark interrupted by user.")
     except Exception as e:
         print(f"\n\nError during benchmark: {e}")
         import traceback
-        traceback.print_exc()
+        sys.exit(1)

Committable suggestion skipped: line range outside the PR's diff.
benchmarks/bench_logging_overhead.py-224-228 (1)
224-228: Prefer sys.exit over exit in scripts
`exit()` is intended for interactive use; `sys.exit()` is clearer and more consistent.

@@
     if not torch.cuda.is_available():
         print("\nError: CUDA is not available. This benchmark requires a CUDA device.")
-        exit(1)
+        sys.exit(1)

Committable suggestion skipped: line range outside the PR's diff.
benchmarks/bench_rope_quantize_fp8_append_cache.py-197-197 (1)
197-197: Useless conditional: both branches return "NHD".
The ternary expression always evaluates to "NHD" regardless of `config_name`. This appears to be dead code or an incomplete implementation. If MLA should use a different layout, fix the condition. Otherwise, simplify:

-        kv_layout="NHD" if config_name != "mla" else "NHD",
+        kv_layout="NHD",

benchmarks/README.md-169-169 (1)
169-169: Table formatting issue: extra column cell.
This row has an extra trailing pipe `|` creating an empty third cell, which doesn't match the 2-column table structure. This causes a markdown rendering issue. Apply this fix:

-| `--local_num_experts` | Number of experts handled by this device. Default: equals num_experts | |
+| `--local_num_experts` | Number of experts handled by this device. Default: equals num_experts |

.github/CODEOWNERS-26-26 (1)
26-26: Remove duplicate owner entry.
`@Anerudhan` appears twice in the owner list for `flashinfer/cudnn/`.

-flashinfer/cudnn/ @Anerudhan @yzh119 @bkryu @cyx-6 @Anerudhan
+flashinfer/cudnn/ @Anerudhan @yzh119 @bkryu @cyx-6

benchmarks/bench_mixed_attention.py-86-103 (1)
86-103: Add verification for persistent BatchAttention output or mark as unused.
`o_persistent` is computed but never verified against the baseline output `o`. Either add verification similar to the batch POD path (lines 149-151) or prefix with an underscore to indicate intentional discard.

-    o_persistent, _ = wrapper_persistent.run(q, kv_data)
+    o_persistent, _ = wrapper_persistent.run(q, kv_data)
+    torch.testing.assert_close(
+        o_persistent, o, rtol=4e-3, atol=4e-3, msg="Persistent BatchAttention output mismatch!"
+    )

Alternatively, if verification is not needed:

-    o_persistent, _ = wrapper_persistent.run(q, kv_data)
+    _o_persistent, _ = wrapper_persistent.run(q, kv_data)

csrc/fmha_v2/fmha/fragment.h-1821-1825 (1)
1821-1825: Array initializers only initialize the first element.
The brace-initializations `{-HUGE_VALF}` and `{0}` only initialize the first element of each array; the remaining elements are zero-initialized by default. While this works correctly for `{0}`, it may not be the intended behavior for `{-HUGE_VALF}`. Consider using an explicit loop or std::fill equivalent. The default constructor on lines 1697-1705 already handles proper initialization, so these member initializers could be removed or kept consistent:
-  float curr_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
-  float curr_sum_[ROWS_PER_THREAD] = {0};
-  float prev_max_[ROWS_PER_THREAD] = {-HUGE_VALF};
-  ;
-  float prev_sum_[ROWS_PER_THREAD] = {0};
+  float curr_max_[ROWS_PER_THREAD];
+  float curr_sum_[ROWS_PER_THREAD];
+  float prev_max_[ROWS_PER_THREAD];
+  float prev_sum_[ROWS_PER_THREAD];

csrc/fmha_v2/fmha/hopper/arrive_wait.h-153-189 (1)
153-189: Uninitialized `result32` when the predicate is false.
When `pred == 0`, the predicated PTX instruction `@P2` does not execute, leaving `result32` uninitialized. The `selp` instruction will still run but reads from an undefined predicate `P3`. Consider initializing `result32` to a defined value before the asm block.

   uint64_t* bar_ptr = reinterpret_cast<uint64_t*>(bar_base_ + id);
   unsigned smem_ptr = __nvvm_get_smem_pointer(bar_ptr);
-  uint32_t result32;
+  uint32_t result32 = 0;

csrc/fmha_v2/fmha/hopper/kernel_traits.h-223-227 (1)
223-227: Typo: "marjor" should be "major".

-  // We know V is row marjor. So we can also deduce the descriptor mode.
+  // We know V is row major. So we can also deduce the descriptor mode.

csrc/fmha_v2/fmha/hopper/tma_descriptor.h-50-81 (1)
50-81: Typo in method name: "desctriptor" should be "descriptor".
The method name has a typo that will affect API usability and discoverability.

-  int set_tma_desctriptor(
+  int set_tma_descriptor(

Apply the same fix to the second overload at line 84.

csrc/fmha_v2/fmha/hopper/compute_tile.h-160-178 (1)
160-178: Potential copy-paste error: using `Smem_tile_b` instead of `Smem_tile_a`.
In `increment_gmma_desc_a_group()`, line 164 iterates using `Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS` but operates on `gmma_desc_a_`. This appears to be a copy-paste error from `increment_gmma_desc_b_group()`. It should use `Smem_tile_a` for consistency.

   inline __device__ void increment_gmma_desc_a_group() {
     bool reset_buffer = gmma_desc_a_[0].get_descriptor(0) >= gmma_desc_a_[0].get_max_descriptor_0();
 #pragma unroll
-    for (int idx = 0; idx < Smem_tile_b::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) {
+    for (int idx = 0; idx < Smem_tile_a::Gmma_descriptor::NUM_DESCRIPTORS; ++idx) {

csrc/batch_pod_kernel_inst.jinja-16-20 (1)
16-20: Remove the "Not sure…" comment before merge. (Line 18)

-// Not sure about the below declaration
 constexpr auto POS_ENCODING_MODE = PosEncodingMode::kNone;

csrc/batch_pod_customize_config.jinja-31-33 (1)
31-33: `POS_ENCODING_MODE` / `USE_LOGITS_SOFT_CAP` constants look inconsistent with the per-(P/D) knobs.
You define `POS_ENCODING_MODE_P/D` and `USE_LOGITS_SOFT_CAP_P/D`, but then hard-code `POS_ENCODING_MODE = kNone` and `USE_LOGITS_SOFT_CAP = false`. If intentional (e.g., required by another include path), add a short comment here explaining why these "global" values must be overridden/ignored.
Diff hunk under review (benchmarks/routines/flashinfer_benchmark_utils.py):

   # NOTE: trtllm-native calls trtllm_batch_decode_with_kv_cache
   "7.5": ["fa2"],
   "8.0": ["fa2", "fa2_tc", "cudnn"],
   "8.6": ["fa2", "fa2_tc", "cudnn"],
   "8.9": ["fa2", "fa2_tc", "cudnn"],
-  "9.0": ["fa2", "fa2_tc", "cudnn"],
-  "10.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-gen-native"],
-  "10.3": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-gen-native"],
-  "12.0": ["fa2", "fa2_tc", "cudnn"],
+  "9.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"],
+  "10.0": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-native"],
+  "10.3": ["fa2", "fa2_tc", "cudnn", "trtllm-gen", "trtllm-native"],
+  "12.0": ["fa2", "fa2_tc", "cudnn", "trtllm-native"],
π§© Analysis chain
π Script executed:
#!/bin/bash
# Verify trtllm-native backend availability and function references
echo "=== Searching for trtllm-native backend references ==="
rg -n "trtllm-native" --type=py -C2
echo -e "\n=== Searching for trtllm_batch_decode_with_kv_cache function ==="
rg -n "trtllm_batch_decode_with_kv_cache" --type=py -C2
echo -e "\n=== Searching for trtllm backend implementations ==="
rg -n "trtllm" --type=py | grep -i "def\|class\|import" | head -20
echo -e "\n=== Checking benchmark utils file structure ==="
wc -l benchmarks/routines/flashinfer_benchmark_utils.py
echo -e "\n=== Reading lines 160-210 to understand context ==="
sed -n '160,210p' benchmarks/routines/flashinfer_benchmark_utils.pyRepository: flashinfer-ai/flashinfer
Length of output: 192
π Web query:
flashinfer trtllm-native backend support compute capability
π‘ Result:
Short answer: FlashInfer's generic flashinfer backend supports NVIDIA compute capabilities from SM75 (7.5) up through ~SM121 (docs show up to 12.1 in vLLM integration), but the TRTLLM (trtllm-native / TRTLLM attention) path requires SM100 (Blackwell) and access to NVIDIA's cubins/artifactory to run. [1][2][3]
Sources:
- vLLM flashinfer backend: supported compute-capability range and TRTLLM sink logic. [1]
- vLLM utils: TRTLLM attention is gated to SM100 and NVIDIA artifactory availability. [2]
- FlashInfer hardware/backend overview (support matrix showing SM75+ and architecture-specific notes). [3]
π Web query:
trtllm_batch_decode_with_kv_cache flashinfer
π‘ Result:
Brief summary: trtllm_batch_decode_with_kv_cache (FlashInfer)
- What it is: a FlashInfer low-level attention kernel that performs batched decode (incremental) attention using a paged KV cache (vLLM-style paging) and returns the attention output for decode requests. [1][2]
- Key inputs you must provide: query tensor ([num_tokens, num_heads, head_dim]), kv_cache (paged layout or tuple of K/V pages), a zero-initialized workspace_buffer, block_tables (page table: [batch_size, num_pages]), seq_lens ([batch_size] uint32), and max_seq_len. Many optional params control dtype/scales, kv_layout ('HND' or 'NHD'), backend ('auto'/'xqa'/'trtllm-gen'), window_left, sinks, mask, etc. [1]
- Layouts/backends: supports HND/NHD kv layouts; backend auto picks trtllm-gen or xqa based on SM and kernel availability (sm_100/103 -> trtllm-gen, sm_90/120 -> xqa). Programmatic Dependent Launch (enable_pdl) and fp8/fp4 options are supported. [1][2]
- Relation to vLLM/usage: vLLM integrates FlashInfer and calls trtllm_batch_decode_with_kv_cache for decode paths when the FlashInfer backend is enabled; ensure FlashInfer's page table/indirection buffers match vLLM expectations. Example wrappers exist (BatchDecodeWithPagedKVCacheWrapper) for easier integration. [4][2]
- Related ops: there are sibling APIs for cuDNN batched decode and prefill (cudnn_batch_decode_with_kv_cache, trtllm_batch_context_with_kv_cache / prefill). Use prefill for initial context fill and trtllm_batch_decode_with_kv_cache for incremental single-token decode. [6][5]
If you want, I can:
- show a minimal code example wiring query, kv_cache, workspace_buffer and block_tables; or
- point to the specific FlashInfer docs/pages or the vLLM integration snippet.
Sources: FlashInfer docs for trtllm_batch_decode_with_kv_cache and attention API [1][2], FlashInfer prefill docs [5], vLLM FlashInfer backend reference [4], cuDNN batch decode doc [6].
Remove trtllm-native from SM90 (compute capability 9.0).
According to FlashInfer documentation, the trtllm-native backend requires SM100 (Blackwell) and later, plus NVIDIA artifactory access. Adding it to SM90 (9.0) will cause runtime failures. Supported compute capabilities for trtllm-native should be 10.0, 10.3, and 12.0 only. Also verify that the backend name used in the codebase matches FlashInfer's expected backend identifier: the auto-selection logic typically chooses between trtllm-gen and xqa based on compute capability, not trtllm-native.
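For context, a standalone sketch of the kind of compute-capability gating being asked for here. The `supports_trtllm_native` helper and the SM threshold are illustrative only, not FlashInfer's actual selection logic:

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Illustrative helper: gate a backend on compute capability >= 10.0 (SM100).
// The real FlashInfer/vLLM selection logic is more involved; this only shows the pattern.
static bool supports_trtllm_native(int device_id) {
  cudaDeviceProp prop{};
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) {
    return false;
  }
  const int cc = prop.major * 10 + prop.minor;  // e.g. 9.0 -> 90, 10.0 -> 100
  return cc >= 100;
}

int main() {
  int device_id = 0;
  if (cudaGetDevice(&device_id) != cudaSuccess) {
    std::printf("no CUDA device available\n");
    return 0;
  }
  std::printf("trtllm-native eligible on device %d: %s\n", device_id,
              supports_trtllm_native(device_id) ? "yes" : "no");
  return 0;
}
```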
π€ Prompt for AI Agents
In benchmarks/routines/flashinfer_benchmark_utils.py around lines 164 to 172,
remove "trtllm-native" from the SM90 ("9.0") entry so it is only present for
SM100+ entries (keep it in "10.0", "10.3", and "12.0"); ensure the mapping lists
for "10.0", "10.3", and "12.0" include "trtllm-native" and "9.0" does not. Also
verify the backend identifier string used matches FlashInfer's expected name
(update any mismatched identifier in the file to the canonical FlashInfer
backend name) and confirm the auto-selection logic picks between the correct
backends (e.g., "trtllm-gen" vs the alternative used by your code) based on
compute capability rather than assuming "trtllm-native" on 9.0.
using PrefillParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
using DecodeParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
π§© Analysis chain
π Script executed:
#!/bin/bash
set -euo pipefail
# Search for BatchPrefillPagedParams usage and definition
echo "=== BatchPrefillPagedParams usage ==="
rg -n "BatchPrefillPagedParams\s*<" -S | head -20
echo ""
echo "=== Looking for BatchDecodePagedParams or similar decode params ==="
rg -n "BatchDecodePagedParams\s*<|Batch.*Decode.*Params\s*<" -S | head -20
echo ""
echo "=== Struct/class definitions for BatchPrefillPagedParams ==="
rg -n "struct\s+BatchPrefillPagedParams\b|class\s+BatchPrefillPagedParams\b|using\s+BatchPrefillPagedParams\b" -S | head -20
echo ""
echo "=== Struct/class definitions for BatchDecodePagedParams ==="
rg -n "struct\s+BatchDecodePagedParams\b|class\s+BatchDecodePagedParams\b|using\s+BatchDecodePagedParams\b" -S | head -20
echo ""
echo "=== File content around line 34-36 of batch_pod_customize_config.jinja ==="
find . -name "batch_pod_customize_config.jinja" -type f -exec cat -n {} \; | sed -n '25,45p'Repository: flashinfer-ai/flashinfer
Length of output: 529
π Script executed:
#!/bin/bash
set -euo pipefail
echo "=== Context around batch_pod_customize_config.jinja lines 30-50 ==="
cat -n csrc/batch_pod_customize_config.jinja | sed -n '30,50p'
echo ""
echo "=== Context around pod_customize_config.jinja lines 30-50 ==="
cat -n csrc/pod_customize_config.jinja | sed -n '30,50p'
echo ""
echo "=== Search for where DecodeParams and PrefillParams are used ==="
rg -n "DecodeParams\|PrefillParams" -B 2 -A 2 csrc/batch_pod_customize_config.jinja | head -50Repository: flashinfer-ai/flashinfer
Length of output: 2010
π Script executed:
#!/bin/bash
set -euo pipefail
echo "=== All Batch*Params type definitions (struct/class/using) ==="
rg -n "Batch.*Params\s*[{<|=]" --type-add 'cpp:*.{cpp,h,hpp,jinja}' | grep -E "(struct|class|using)" | head -30
echo ""
echo "=== Search for Decode in csrc files to understand decode param patterns ==="
rg -n "Decode.*Params|Decode.*<" csrc/ --type-add 'cpp:*.{cpp,h,hpp,jinja}' | head -30
echo ""
echo "=== Check if any kernel files instantiate these templates ==="
rg -n "BatchPrefillPagedParams|PrefillParams|DecodeParams" csrc/ -l | head -20Repository: flashinfer-ai/flashinfer
Length of output: 4370
π Script executed:
#!/bin/bash
set -euo pipefail
echo "=== BatchDecodeParams struct definition ==="
sed -n '100,150p' include/flashinfer/attention/default_decode_params.cuh
echo ""
echo "=== BatchPrefillPagedParams struct definition (excerpt) ==="
sed -n '280,320p' include/flashinfer/attention/default_prefill_params.cuhRepository: flashinfer-ai/flashinfer
Length of output: 2493
DecodeParams must use BatchDecodeParams instead of BatchPrefillPagedParams.
These are distinct types with incompatible struct layouts. BatchDecodeParams is optimized for single-token decode queries (with fields like q_rope_offset, request_indices, kv_tile_indices), while BatchPrefillPagedParams is optimized for variable-length prefill sequences (with fields like q_indptr, maybe_custom_mask, qo_tile_indices, merge_indptr). Using the wrong type will cause incorrect memory access patterns and missing/misaligned fields.
Update line 35 to:
-using DecodeParams = BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO, IdType>;
+using DecodeParams = BatchDecodeParams<DTypeQ, DTypeKV, DTypeO, IdType>;

Prompt for AI Agents
In csrc/batch_pod_customize_config.jinja around lines 34 to 36, the typedef for
DecodeParams incorrectly aliases BatchPrefillPagedParams which has a different
layout; change the alias so DecodeParams uses BatchDecodeParams<DTypeQ, DTypeKV,
DTypeO, IdType> instead of BatchPrefillPagedParams<DTypeQ, DTypeKV, DTypeO,
IdType> to match the single-token decode fields (q_rope_offset, request_indices,
kv_tile_indices) and avoid misaligned memory accesses.
constexpr static int round_down_to_power_two(int x) {
  x = x | (x >> 1);
  x = x | (x >> 2);
  x = x | (x >> 4);
  x = x | (x >> 8);
  x = x | (x >> 16);
  return x - (x >> 1);
}
Add input validation to prevent division by zero and undefined behavior.
The round_down_to_power_two function doesn't handle edge cases:

- For x = 0, it returns 0, which causes division by zero when computing `alibi_neg4_div_h = -4.0f / h_pow_2` in line 31.
- For x < 0, the right-shift operations exhibit implementation-defined behavior for signed integers in C++, potentially causing incorrect results.

Apply this diff to add validation:
constexpr static int round_down_to_power_two(int x) {
+ if (x <= 0) {
+ return 1; // or handle error appropriately
+ }
x = x | (x >> 1);
x = x | (x >> 2);
x = x | (x >> 4);
x = x | (x >> 8);
x = x | (x >> 16);
return x - (x >> 1);
}

Alternatively, document the precondition that x must be positive and add assertions in debug builds.
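As a quick sanity check of the bit-trick itself, a standalone sketch with a couple of worked values; the unsigned cast and the `assert` are one possible way to satisfy the precondition, not the file's actual code:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>

// Same fill-the-low-bits trick, but computed on an unsigned value and guarded
// by an assertion so x <= 0 can never reach the division that uses the result.
static int round_down_to_power_two_checked(int x) {
  assert(x > 0 && "round_down_to_power_two requires a positive input");
  uint32_t v = static_cast<uint32_t>(x);
  v |= v >> 1;
  v |= v >> 2;
  v |= v >> 4;
  v |= v >> 8;
  v |= v >> 16;
  return static_cast<int>(v - (v >> 1));
}

int main() {
  // Worked values: 1 -> 1, 5 -> 4, 32 -> 32, 100 -> 64.
  const int inputs[] = {1, 5, 32, 100};
  for (int x : inputs) {
    std::printf("%d -> %d\n", x, round_down_to_power_two_checked(x));
  }
  return 0;
}
```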
π Committable suggestion
βΌοΈ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
Original:

constexpr static int round_down_to_power_two(int x) {
  x = x | (x >> 1);
  x = x | (x >> 2);
  x = x | (x >> 4);
  x = x | (x >> 8);
  x = x | (x >> 16);
  return x - (x >> 1);
}

Suggested:

constexpr static int round_down_to_power_two(int x) {
  if (x <= 0) {
    return 1; // or handle error appropriately
  }
  x = x | (x >> 1);
  x = x | (x >> 2);
  x = x | (x >> 4);
  x = x | (x >> 8);
  x = x | (x >> 16);
  return x - (x >> 1);
}
π€ Prompt for AI Agents
In csrc/fmha_v2/fmha/alibi_params.h around lines 18 to 25, the
round_down_to_power_two function lacks input validation and uses signed right
shifts which are implementation-defined for negative values; add a precondition
check and safe unsigned behavior: require x > 0 (document it), and in code
either assert(x > 0) in debug builds or return a safe default/error for x == 0
to avoid downstream division by zero at the alibi_neg4_div_h computation, and
change the algorithm to operate on an unsigned integer (or cast x to unsigned)
before performing the bitwise right shifts to avoid undefined behavior for
negative inputs. Ensure the chosen behavior is documented in comments and
callers are updated if you convert the parameter to unsigned.
AlibiParams(int h, float scale_after_alibi = 1.f) : scale_after_alibi(scale_after_alibi) {
  h_pow_2 = round_down_to_power_two(h);
  alibi_neg4_div_h = -4.0f / h_pow_2;
}
Validate h parameter to prevent division by zero.
The constructor computes alibi_neg4_div_h = -4.0f / h_pow_2 without validating that h > 0. If an invalid h value is passed (e.g., 0 or negative), h_pow_2 will be 0, resulting in division by zero or infinity.
Consider adding validation:
AlibiParams(int h, float scale_after_alibi = 1.f) : scale_after_alibi(scale_after_alibi) {
+ // Validate h is positive
+ if (h <= 0) {
+ // Handle error: throw exception, assert, or set safe defaults
+ }
h_pow_2 = round_down_to_power_two(h);
alibi_neg4_div_h = -4.0f / h_pow_2;
}Committable suggestion skipped: line range outside the PR's diff.
π€ Prompt for AI Agents
In csrc/fmha_v2/fmha/alibi_params.h around lines 29 to 32, the constructor uses
h to compute h_pow_2 and then performs alibi_neg4_div_h = -4.0f / h_pow_2
without validating h, which can lead to division by zero if h <= 0; add
validation to ensure h is positive (e.g., if h <= 0 throw/abort or set a safe
default and log), compute h_pow_2 only after validation, and handle the error
path deterministically so alibi_neg4_div_h is never computed with h_pow_2 == 0.
| // Make sure we use a single register to store predicates. | ||
| static_assert(PRED_REGS == 1, ""); | ||
|
|
||
| // We do not use LDGSTS (for the moment). | ||
| enum { USE_LDGSTS = USE_LDGSTS_ }; | ||
|
|
||
| // Ctor. | ||
| template <typename Params, typename Block_info> | ||
| inline __device__ Gmem_tile_qkv(Params const& params, int qkv_offset, Block_info const& binfo, | ||
| int tidx, int cta_row_offset = 0) | ||
|
|
||
| // in PACKED_QKV, q_stride = k_stride = v_stride | ||
| : params_qkv_stride_in_bytes_(params.q_stride_in_bytes), | ||
| qkv_ptr_(reinterpret_cast<char const*>(params.qkv_ptr)) { | ||
| // Compute the position in the sequence (within the CTA for the moment). | ||
| int row = tidx / THREADS_PER_ROW; | ||
| // Compute the position of the thread in the row. | ||
| int col = tidx % THREADS_PER_ROW; | ||
|
|
||
| // Prepare predicates. | ||
| uint32_t preds[LDGS]; | ||
| #pragma unroll | ||
| for (int ii = 0; ii < LDGS; ++ii) { | ||
| preds[ii] = row + ii * ROWS_PER_LDG < ROWS; | ||
| } | ||
|
|
||
| // Pack the predicates. | ||
| preds_[0] = fmha::pack_predicates(preds); | ||
|
|
||
| // The row offset in the batched GEMM. For each seq element, we store QKV in that order. | ||
| int64_t row_offset = (int64_t)(row + cta_row_offset) * params_qkv_stride_in_bytes_; | ||
| // Add the block index. | ||
| int idx; | ||
| if (HEADS_INTERLEAVED) { | ||
| idx = binfo.bidx * NUM_MATS + qkv_offset; | ||
| } else { | ||
| idx = (params.b * params.s * NUM_MATS + qkv_offset) * params.h + binfo.bidh; | ||
| } | ||
| // Assemble the final pointer. | ||
| qkv_ptr_ += row_offset + idx * VALID_BYTES_PER_ROW + col * BYTES_PER_LDG; | ||
|
|
||
| // active threads | ||
| is_active_ = col < VALID_THREADS_PER_ROW; | ||
| } |
Addressing + predication bugs: non-interleaved idx and missing column bounds in load(Smem_tile&).
- `idx = (params.b * params.s * NUM_MATS + qkv_offset) * params.h + binfo.bidh;` looks wrong (it uses sizes, not indices) and will likely go out of bounds.
- `load(Smem_tile&)` runs for inactive columns too; the ctor predicates only check the row. This can dereference past `VALID_BYTES_PER_ROW`.

Suggested fix (layout-dependent parts marked by the most conservative assumption: `binfo.sum_s` is the per-batch sequence index):
#pragma once
+
+#include <fmha/utils.h> // Div_up, Compute_number_of_pred_regs, pack_predicates, ldg/stg, etc.
+#include <fmha/traits.h> // if needed for Traits/Cta_tile plumbing
+#include <cstdint>
namespace fmha {
namespace v1 {
@@
- // Make sure we use a single register to store predicates.
- static_assert(PRED_REGS == 1, "");
+ // Only require single-reg predicates when using LDGSTS (or when the implementation only supports it).
+ static_assert(!USE_LDGSTS_ || PRED_REGS == 1, "");
@@
// Prepare predicates.
uint32_t preds[LDGS];
#pragma unroll
for (int ii = 0; ii < LDGS; ++ii) {
- preds[ii] = row + ii * ROWS_PER_LDG < ROWS;
+ preds[ii] = (row + ii * ROWS_PER_LDG < ROWS) && (col < VALID_THREADS_PER_ROW);
}
// Pack the predicates.
preds_[0] = fmha::pack_predicates(preds);
@@
int idx;
if (HEADS_INTERLEAVED) {
idx = binfo.bidx * NUM_MATS + qkv_offset;
} else {
- idx = (params.b * params.s * NUM_MATS + qkv_offset) * params.h + binfo.bidh;
+ idx = (binfo.sum_s * NUM_MATS + qkv_offset) * params.h + binfo.bidh;
}
@@
inline __device__ void load(Smem_tile& smem_tile) {
void const* ptrs[LDGS];
#pragma unroll
for (int ii = 0; ii < LDGS; ++ii) {
ptrs[ii] = qkv_ptr_ + (int64_t)ii * ROWS_PER_LDG * params_qkv_stride_in_bytes_;
}
if (USE_LDGSTS) {
smem_tile.store(ptrs, preds_);
} else {
fmha::ldg(fetch_, ptrs, preds_);
}
}If the intended indexing is different, please adjust the idx formula accordinglyβbut the current one is almost certainly incorrect.
Also applies to: 122-147
π€ Prompt for AI Agents
In csrc/fmha_v2/fmha/gmem_tile_qkv.h around lines 69-112, the non-interleaved
idx computation uses size fields and can OOB and the constructor only predicates
rows (so inactive columns can still be used); replace the idx formula to use the
per-batch sequence index from Block_info (e.g. use binfo.sum_s or the
appropriate per-batch/per-seq index from binfo) so idx is computed from concrete
indices not size constants, add a column check when building preds (AND each
preds[ii] with (col + ii*COL_STRIDE < VALID_THREADS_PER_ROW) or simply include
col < VALID_THREADS_PER_ROW) and set/pack predicates accordingly, and ensure
load(Smem_tile&) checks is_active_ (or the packed predicate) before doing any
memory stores/loads; apply the same fixes to the corresponding code at lines
~122-147.
#if (CUDA_CP_ASYNC_MBARRIER_ARRIVE_ACTIVATED) && (CUDACC_VERSION >= 111)
#define CUDA_CP_ASYNC_MBARRIER_WAIT_ACTIVATED 1
#endif
π§© Analysis chain
π Script executed:
#!/bin/bash
# Search for CUDACC_VERSION in the codebase
echo "=== Searching for CUDACC_VERSION ==="
rg -n "CUDACC_VERSION" --type cpp --type h
echo ""
echo "=== Searching for __CUDACC_VER_ macros ==="
rg -n "__CUDACC_VER_" --type cpp --type h | head -20
echo ""
echo "=== Checking the file in question ==="
cat -n csrc/fmha_v2/fmha/hopper/arrive_wait.h | sed -n '50,65p'Repository: flashinfer-ai/flashinfer
Length of output: 3863
π Web query:
CUDA version macros __CUDACC_VERSION__ CUDACC_VERSION CUDA compiler version detection
π‘ Result:
Short answer - which macros to use and what they mean:
- __CUDACC__ / __NVCC__: defined when nvcc is compiling CUDA sources. Use these to detect nvcc compilation. [1][2]
- __CUDACC_VER_MAJOR__, __CUDACC_VER_MINOR__, __CUDACC_VER_BUILD__: predefined by nvcc with the nvcc compiler major/minor/build numbers; use them in preprocessor checks to detect the nvcc compiler version. Example: #if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) ... #endif. [1][4]
- __CUDA_ARCH__: defined in device compilation to indicate the target GPU architecture (SM) value (e.g. 800 for compute_80). Use it for device-only feature selection. [2]
- CUDA_VERSION / CUDART_VERSION: macros in the CUDA headers (cuda.h / cuda_runtime_api.h) that encode the toolkit/driver API versions (not the same as the nvcc macros). Use these to test toolkit/runtime API versions. Note these are header defines (runtime vs driver/toolkit semantics differ). [5]
Notes / alternatives:
- Newer NVIDIA libraries provide helper macros (e.g. _CCCL_CUDA_COMPILER / _CCCL_CUDACC_AT_LEAST) for readable version checks; these are library-specific utilities, not universal compiler defines. [1]
- For non-nvcc CUDA compilers (nvc++, clang's CUDA front end) there are different predefined macros (e.g. __NVCOMPILER, __NVCOMPILER_MAJOR__/__NVCOMPILER_MINOR__ for nvc++). Check that compiler's docs/headers. [3]
Sources:
[1] NVIDIA nvcc / CUDA compiler docs (nvcc predefined macros including _CUDACC_VER*).
[2] NVIDIA CUDA Programming Guide β CUDA_ARCH.
[3] StackOverflow / examples for nvc++ and __NVCOMPILER macros.
[4] nvcc docs / nvcc driver reference listing CUDACC_VER_MAJOR/MINOR/BUILD.
[5] StackOverflow discussion on CUDA_VERSION vs CUDART_VERSION.
If you want, I can paste minimal example code for common version checks (nvcc, toolkit, and device-arch).
Fix undefined macro CUDACC_VERSION on line 56.
The macro CUDACC_VERSION is not a standard CUDA compiler macro and is never defined in the codebase. The condition will always be false, preventing CUDA_CP_ASYNC_MBARRIER_WAIT_ACTIVATED from ever being defined. The file already uses the standard macros __CUDACC_VER_MAJOR__ and __CUDACC_VER_MINOR__ (lines 17-18, 31, 44). Line 56 should use one of these instead, likely (__CUDACC_VER_MAJOR__ >= 11) to match the pattern used elsewhere in the file.
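A compilable sketch of the conventional version gate described above; the `SKETCH_*` macro name is a stand-in for `CUDA_CP_ASYNC_MBARRIER_WAIT_ACTIVATED`, and the CUDA 11 threshold mirrors the pattern already used earlier in the file:

```cpp
// Standard nvcc version gate; also compiles as plain C++ because every macro
// test degrades gracefully when the nvcc-specific symbols are undefined.
#if defined(__CUDACC_VER_MAJOR__) && (__CUDACC_VER_MAJOR__ >= 11)
#define SKETCH_CP_ASYNC_MBARRIER_WAIT_AVAILABLE 1
#else
#define SKETCH_CP_ASYNC_MBARRIER_WAIT_AVAILABLE 0
#endif

#include <cstdio>

int main() {
  std::printf("mbarrier wait path enabled: %d\n", SKETCH_CP_ASYNC_MBARRIER_WAIT_AVAILABLE);
  return 0;
}
```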
π€ Prompt for AI Agents
In csrc/fmha_v2/fmha/hopper/arrive_wait.h around lines 56 to 58, the
preprocessor check uses the undefined macro CUDACC_VERSION which always
evaluates false; change the condition to use the existing standard macro form
(e.g. (__CUDACC_VER_MAJOR__ >= 11)) to match the rest of the file so
CUDA_CP_ASYNC_MBARRIER_WAIT_ACTIVATED is defined correctly when compiling with
CUDA 11+.
| template <typename Traits, typename Cta_tile, int L, int UNROLL_N = 1> | ||
| struct Transposer {}; | ||
|
|
||
| template <typename Traits, typename Cta_tile, int UNROLL_N> | ||
| struct Transposer<Traits, Cta_tile, 128, UNROLL_N> { | ||
| static_assert(Cta_tile::K % 128 == 0); | ||
|
|
||
| enum { | ||
| WARPS_M = Cta_tile::WARPS_M, | ||
| WARPS_N = Cta_tile::WARPS_N, | ||
| WARPS_K = Cta_tile::WARPS_K, | ||
| }; | ||
|
|
||
| enum { | ||
| WARPS_4x1x1 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 1), | ||
| WARPS_4x1x2 = (WARPS_M == 4 && WARPS_N == 1 && WARPS_K == 2), | ||
| }; | ||
|
|
||
| enum { BYTES_PER_LDS = 16 }; | ||
|
|
||
| enum { BYTES_PER_ROW = 128 }; | ||
|
|
||
| // D=64 and 4 warps. | ||
| // Per warp we load 32 rows x 16 columns with LDSM.Tx4, 128 rows per CTA. | ||
| enum { S = Cta_tile::K >= 128 ? 128 : Cta_tile::K }; // The sequence length. | ||
|
|
||
| enum { D = Cta_tile::N >= 128 ? 128 : Cta_tile::N }; // The head dimension. | ||
|
|
||
| // static_assert(S % 128 == 0); | ||
| static_assert(WARPS_4x1x1 || WARPS_4x1x2); | ||
| static_assert(D % (BYTES_PER_LDS * WARPS_K) == 0); | ||
|
|
||
| enum { ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING = 128 }; // LDSMx4 | ||
|
|
||
| enum { ROW_PACKING = BYTES_PER_ROW / (D * sizeof(typename Traits::B_type)) }; | ||
|
|
||
| enum { ROWS_PER_LDSM_PER_CTA = ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING / ROW_PACKING }; | ||
|
|
||
| enum { ROWS_PER_XOR_PATTERN = fmha::Rows_per_xor_pattern_ampere_b<Traits, S>::VALUE }; | ||
|
|
||
| static_assert(ROWS_PER_XOR_PATTERN == 8); | ||
|
|
||
| // The number of loads in K dimension. | ||
| enum { K = S / ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING }; | ||
|
|
||
| // static_assert(K * ROWS_PER_LDSM_PER_CTA_WITHOUT_PACKING == S); | ||
| // static_assert(K == 3); | ||
| // The number of loads in the D dimension. | ||
| enum { N = D / (BYTES_PER_LDS * WARPS_K) }; // 16 bytes per load | ||
|
|
||
| static_assert(N * BYTES_PER_LDS * WARPS_K == D); | ||
|
|
||
| uint4 regs_[UNROLL_N][K]; | ||
|
|
||
| uint32_t read_offset_; | ||
| uint32_t write_offset_; | ||
| uint32_t smem_read_loc_; | ||
| uint32_t smem_write_loc_; | ||
|
|
||
| inline __device__ Transposer(int tidx) { | ||
| int read_row, read_col; | ||
|
|
||
| if (WARPS_4x1x1 && N == 8) { // D=128, 1 warp in N | ||
| read_row = (tidx & 0x7f); | ||
| read_col = (tidx & 0x07); | ||
| } else if (WARPS_4x1x1 && N == 4) { // D=64, 1 warp in N | ||
| read_row = (tidx & 0xe0) / 2 + (tidx & 0x1e) / 2; | ||
| read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; | ||
| } else if (WARPS_4x1x1 && N == 2) { // D=32, 1 warp in N | ||
| read_row = (tidx & 0x60) / 4 + (tidx & 0x1c) / 4; | ||
| read_col = (tidx & 0x03) * 2; | ||
| read_col ^= (read_row & 0x01); | ||
| } else if (WARPS_4x1x2 && N == 4) { // D=128, 2 warps in N | ||
| read_row = (tidx & 0x7f); | ||
| read_col = (tidx & 0x07); | ||
| // For two warpgroups we do two steps in N at once. | ||
| read_col ^= (tidx & 0x80) / 128; | ||
| } else if (WARPS_4x1x2 && N == 2) { // D=64, 2 warps in N | ||
| read_row = (tidx & 0x60) / 2 + (tidx & 0x1e) / 2; | ||
| read_col = (tidx & 0x01) * 4 + (tidx & 0x06) / 2; | ||
| // For two warpgroups we do two steps in N at once. | ||
| read_col ^= (tidx & 0x80) / 128; | ||
| } else if (WARPS_4x1x2 && N == 1) { // D=32, 2 warps in N | ||
| read_row = (tidx & 0x60) / 4 + (tidx & 0x1c) / 4; | ||
| read_col = (tidx & 0x03) * 2; | ||
| read_col ^= (read_row & 0x01); | ||
| // For two warpgroups we do two steps in N at once. | ||
| read_col ^= (tidx & 0x80) / 128; | ||
| } else { | ||
| assert(false); | ||
| } | ||
|
|
||
| read_offset_ = read_row * BYTES_PER_ROW + read_col * BYTES_PER_LDS; | ||
|
|
||
| int write_row, write_col; | ||
| if (WARPS_4x1x1) { // swizzle_128byte | ||
| write_row = (tidx & 0x10) / 2 + (tidx & 0x07); | ||
| write_col = (tidx & 0x60) / 16 + (tidx & 0x08) / 8; | ||
| } else if (WARPS_4x1x2) { | ||
| // Same as above, with second warp group writing next 16 rows. | ||
| write_row = (tidx & 0x80) / 8 + (tidx & 0x10) / 2 + (tidx & 0x07); | ||
| write_col = (tidx & 0x60) / 16 + (tidx & 0x08) / 8; | ||
| } else { | ||
| assert(false); | ||
| } | ||
|
|
||
| write_col ^= (write_row & 0x07); | ||
|
|
||
| write_offset_ = write_row * BYTES_PER_ROW + write_col * BYTES_PER_LDS; | ||
| } | ||
|
|
||
| inline __device__ void transpose(int tidx, uint32_t smem) { transpose_<true>(tidx, smem, smem); } | ||
|
|
||
| template <bool SYNC> | ||
| inline __device__ void transpose_(uint32_t smem_src, uint32_t smem_dst) { | ||
| #pragma unroll | ||
| for (int n_begin = 0; n_begin < N; n_begin += UNROLL_N) { | ||
| transpose_ldmatrix(n_begin, smem_src); | ||
| transpose_stmatrix<SYNC>(n_begin, smem_dst); | ||
| } | ||
| } | ||
|
|
||
| inline __device__ void transpose_ldmatrix(int n_begin, uint32_t smem_src) { | ||
| static_assert(N % UNROLL_N == 0, ""); | ||
|
|
||
| uint4 tmp[UNROLL_N][K]; | ||
| if (n_begin == 0) { | ||
| smem_read_loc_ = smem_src + read_offset_; | ||
| } | ||
|
|
||
| #pragma unroll | ||
| for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { | ||
| int const nii = ni - n_begin; | ||
| #pragma unroll | ||
| for (int ki = 0; ki < K; ki++) { // 2 | ||
| fmha::ldsmt(tmp[nii][ki], smem_read_loc_ + ki * ROWS_PER_LDSM_PER_CTA * BYTES_PER_ROW); | ||
| } | ||
|
|
||
| if (WARPS_4x1x1 && N == 4) { // D=64, 1 warp in N | ||
| smem_read_loc_ ^= (ni % 2 == 0 ? 1 : 3) * 16; | ||
| } else if (WARPS_4x1x1 && N == 2) { // D=32, 1 warp in N | ||
| smem_read_loc_ ^= 16; | ||
| } else if (WARPS_4x1x2 && N == 2) { // D=64, 2 warps in N | ||
| smem_read_loc_ ^= 32; | ||
| } else if (WARPS_4x1x2 && N == 4) { // D=128, 2 warps in N | ||
| smem_read_loc_ ^= (ni % 2 == 0 ? 1 : 3) * 32; | ||
| } else if (WARPS_4x1x1 && N == 8) { // D=128, 1 warp in N | ||
| smem_read_loc_ ^= ((ni % 4 == 3) ? 7 : (ni % 2 == 1 ? 3 : 1)) * 16; | ||
| } else if (N != 1) { | ||
| assert(false); | ||
| } | ||
| } | ||
|
|
||
| #pragma unroll | ||
| for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { | ||
| int const nii = ni - n_begin; | ||
| #pragma unroll | ||
| for (int ki = 0; ki < K; ki++) { | ||
| fmha::swizzle_rows(regs_[nii][ki].x, regs_[nii][ki].z, tmp[nii][ki].x, | ||
| tmp[nii][ki].y); // PRMT 0+1 | ||
| fmha::swizzle_rows(regs_[nii][ki].y, regs_[nii][ki].w, tmp[nii][ki].z, | ||
| tmp[nii][ki].w); // PRMT 2+3 | ||
| } | ||
| } | ||
| } | ||
|
|
||
| template <bool SYNC> | ||
| inline __device__ void transpose_stmatrix(int n_begin, uint32_t smem_dst) { | ||
| // After LDSM.Tx4 registers hold 2x2 elts: | ||
| // [00, 01] | ||
| // [10, 11] | ||
| // With row offsets | ||
| // x: + 0 | ||
| // y: + 8 | ||
| // z: +16 (g) | ||
| // w: +24 (o) | ||
| // | ||
| // After PRMT 0, the : | ||
| // [00, 01] [80, 81] => x: [00, 10, 80, 90], i.e. col 0 | ||
| // [10, 11] [90, 91] => z: [01, 11, 81, 91], i.e. col 1 | ||
| // | ||
| // [g0, g1] [o0, o1] => y: [g0, h0, o0, p0], i.e. col 0 | ||
| // [h0, h1] [p0, p1] => w: [g1, h1, o1, p1], i.e. col 1 | ||
| // | ||
| // Therefore, when looking at the transpose, quad q holds cols 2 * q + [0, 1], i.e. | ||
| // - quad 0 holds cols 0, 1 | ||
| // - quad 1 holds cols 2, 3 | ||
| // - etc. | ||
| // | ||
| // This fits with the accumulator layout, since N strides in steps of 8 per thread. | ||
|
|
||
| if (SYNC) { // needed if src and dst are the same. | ||
| __syncthreads(); // LDSM.T done. We should now have a D x S tile in registers. SMEM can be | ||
| // written. | ||
| } | ||
|
|
||
| if (n_begin == 0) { | ||
| smem_write_loc_ = smem_dst + write_offset_; | ||
| } | ||
|
|
||
| #pragma unroll | ||
| for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) { | ||
| int const nii = ni - n_begin; | ||
| #pragma unroll | ||
| for (int ki = 0; ki < K; ki++) { | ||
| fmha::stsm(smem_write_loc_ + ki * BYTES_PER_ROW * D, regs_[nii][ki]); | ||
| } | ||
| if (WARPS_4x1x1) { // D=64, 1 warp in N. | ||
| smem_write_loc_ += 16 * BYTES_PER_ROW; | ||
| } else if (WARPS_4x1x2) { // D=64, 2 warps in N. | ||
| smem_write_loc_ += 32 * BYTES_PER_ROW; | ||
| } else { | ||
| assert(false); | ||
| } | ||
| } | ||
| } | ||
| }; |
Fix Transposer compile errors + OOB indexing in 64B specialization (will fail build / miscompile).
- `Transposer::transpose()` calls `transpose_` with the wrong number of arguments (both the 128B and 64B variants).
- `Transposer<..., 64, ...>::transpose_ldmatrix()` indexes `tmp[ni][ki]` even when `n_begin > 0` (out of bounds).
- `Smem_tile_v_gmma::Base::transpose_tile()` calls `transpose_` with a non-existent second template parameter.
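The second bullet is the classic unroll-window indexing slip; a minimal standalone reproduction of the `ni` vs `ni - n_begin` distinction (nothing FlashInfer-specific):

```cpp
#include <cstdio>

constexpr int N = 8;         // total column steps
constexpr int UNROLL_N = 2;  // columns processed per window

// Process columns [n_begin, n_begin + UNROLL_N) using a scratch array sized for
// one window. Indexing scratch with the global ni would run past the array for
// n_begin > 0; the window-relative nii = ni - n_begin stays in [0, UNROLL_N).
void process_window(int n_begin, int* out) {
  int scratch[UNROLL_N];
  for (int ni = n_begin; ni < n_begin + UNROLL_N; ++ni) {
    const int nii = ni - n_begin;  // correct: window-relative index
    scratch[nii] = ni * ni;        // scratch[ni] would be out of bounds when n_begin > 0
  }
  for (int nii = 0; nii < UNROLL_N; ++nii) {
    out[n_begin + nii] = scratch[nii];
  }
}

int main() {
  int out[N] = {};
  for (int n_begin = 0; n_begin < N; n_begin += UNROLL_N) {
    process_window(n_begin, out);
  }
  for (int i = 0; i < N; ++i) std::printf("%d ", out[i]);
  std::printf("\n");
  return 0;
}
```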
--- a/csrc/fmha_v2/fmha/hopper/smem_tile.h
+++ b/csrc/fmha_v2/fmha/hopper/smem_tile.h
@@ -1751,7 +1751,9 @@ struct Transposer<Traits, Cta_tile, 128, UNROLL_N> {
- inline __device__ void transpose(int tidx, uint32_t smem) { transpose_<true>(tidx, smem, smem); }
+ inline __device__ void transpose(int /*tidx*/, uint32_t smem) {
+ transpose_<true>(smem, smem);
+ }
@@ -1950,7 +1952,9 @@ struct Transposer<Traits, Cta_tile, 64, UNROLL_N> {
- inline __device__ void transpose(int tidx, uint32_t smem) { transpose_<true>(tidx, smem, smem); }
+ inline __device__ void transpose(int /*tidx*/, uint32_t smem) {
+ transpose_<true>(smem, smem);
+ }
@@ -1969,13 +1973,13 @@ struct Transposer<Traits, Cta_tile, 64, UNROLL_N> {
#pragma unroll
for (int ni = n_begin; ni < n_begin + UNROLL_N; ni++) {
+#pragma unroll
+ int const nii = ni - n_begin;
#pragma unroll
for (int ki = 0; ki < K; ki++) {
- int const nii = ni - n_begin;
- fmha::ldsmt(tmp[ni][ki], smem_read_loc_ + ki * ROWS_PER_LDSM_PER_CTA * BYTES_PER_ROW);
+ fmha::ldsmt(tmp[nii][ki], smem_read_loc_ + ki * ROWS_PER_LDSM_PER_CTA * BYTES_PER_ROW);
}
@@ -2226,7 +2230,7 @@ struct Smem_tile_v_gmma {
template <int UNROLL_N = Transposer::N>
inline __device__ void transpose_tile(uint32_t smem_src, uint32_t smem_dst) {
- transposer.template transpose_<false, UNROLL_N>(smem_src, smem_dst);
+ transposer.template transpose_<false>(smem_src, smem_dst);
}Also applies to: 1858-2046, 2223-2229
π€ Prompt for AI Agents
csrc/fmha_v2/fmha/hopper/smem_tile.h lines 1640-1856: three fixes required β 1)
Fix transpose() wrapper to call transpose_<true>(smem_src, smem_dst) with the
correct parameter order/signature used by transpose_ (make the call match the
template signature that takes (smem_src, smem_dst)); 2) In transpose_ldmatrix()
avoid OOB by indexing the local tmp array with the per-iteration offset nii
(tmp[nii][ki]) everywhere (not tmp[ni][ki]) so accesses are relative to the
UNROLL_N block; 3) In Smem_tile_v_gmma::Base::transpose_tile() (and other
affected places) remove the erroneous extra template parameter when calling
transpose_ so you call transpose_<SYNC>(smem_src, smem_dst) matching the
Transposer API. Ensure all overloads/specializations (128B and 64B) use the same
corrected call/signature.
template <bool TB>
struct Hgmma_rfa_fp16<8, TB> {
  static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[2]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
    int const trans_b = TB ? 1 : 0;
    asm volatile(
        "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
        "{%0, %1}, {%2, %3, %4, %5}, %6, 1, 1, 1, %7;\n"

        : "+r"(acc[0]), "+r"(acc[1])
        : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_a), "l"(desc_b), "n"(trans_b));
#endif
  }
};
Critical: Undefined variable desc_a will cause compilation failure.
The mma function references desc_a in the inline assembly inputs (line 450), but desc_a is not a parameter of this function. For RF-A (A from Register File) variants, operand A comes from registers, not a descriptor.
Apply this diff to fix the assembly constraint:
template <bool TB>
struct Hgmma_rfa_fp16<8, TB> {
static inline __device__ void mma(uint32_t const (&a)[4], uint64_t desc_b, uint32_t (&acc)[2]) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)
int const trans_b = TB ? 1 : 0;
asm volatile(
"wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 "
- "{%0, %1}, {%2, %3, %4, %5}, %6, 1, 1, 1, %7;\n"
+ "{%0, %1}, {%2, %3, %4, %5}, %6, 1, 1, 1, %7;\n"
: "+r"(acc[0]), "+r"(acc[1])
- : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_a), "l"(desc_b), "n"(trans_b));
+ : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "l"(desc_b), "n"(trans_b));
#endif
}
**Supported CUDA Versions:** 12.6, 12.8, 13.0, 13.1

> **Note:** FlashInfer strives to follow PyTorch's supported CUDA versions plus the latest CUDA release.
π§© Analysis chain
π Script executed:
#!/bin/bash
# Check workflow files for CUDA version references
find .github/workflows -name "*.yml" -o -name "*.yaml" 2>/dev/null | while read f; do
if [ -f "$f" ]; then
echo "=== File: $f ==="
grep -i "cuda" "$f" | head -20
fi
done
echo ""
echo "=== Searching for CUDA versions with pattern matching ==="
# Search for CUDA version patterns (12.x, 13.x)
rg "cuda.*(?:12\.|13\.)" --type yaml 2>/dev/null | head -30
echo ""
echo "=== Checking pyproject.toml and setup files ==="
fd -e toml -e cfg -e txt 2>/dev/null | xargs grep -i "cuda" 2>/dev/null | head -30Repository: flashinfer-ai/flashinfer
Length of output: 20500
Update CUDA version support list to match CI configurations.
The README lists CUDA versions as 12.6, 12.8, 13.0, 13.1, but CI/build configurations (release-ci-docker.yml, release.yml, nightly-release.yml) actually support 12.6, 12.8, 12.9, and 13.0. Update the README to include 12.9 and remove 13.1.
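If it helps to double-check a local environment against the documented list, the toolkit and driver versions can be read directly from the CUDA runtime. A standalone snippet, unrelated to FlashInfer's own build scripts:

```cpp
#include <cuda_runtime.h>
#include <cstdio>

// Prints the CUDA runtime and driver versions as MAJOR.MINOR (e.g. 12080 -> 12.8).
int main() {
  int runtime_version = 0, driver_version = 0;
  cudaRuntimeGetVersion(&runtime_version);
  cudaDriverGetVersion(&driver_version);
  std::printf("runtime: %d.%d, driver: %d.%d\n",
              runtime_version / 1000, (runtime_version % 1000) / 10,
              driver_version / 1000, (driver_version % 1000) / 10);
  return 0;
}
```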
π€ Prompt for AI Agents
In README.md around lines 194 to 196, the supported CUDA versions list is out of
sync with CI: it currently shows 12.6, 12.8, 13.0, 13.1 but CI config files
support 12.6, 12.8, 12.9, 13.0. Update the README list to include 12.9 and
remove 13.1 so it reads 12.6, 12.8, 12.9, 13.0, and ensure the accompanying note
remains accurate.
π Description
π Related Issues
π Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.

Tests
- Tests have been added or updated as needed.
- All tests are passing (`unittest`, etc.).

Reviewer Notes
Summary by CodeRabbit
Release Notes
New Features
Improvements
Documentation