
Add torch compliant grouped gemm API for CK FP8 rowwise #4486


Closed
cthi wants to merge 1 commit from the export-D78119166 branch

Conversation

cthi
Contributor

@cthi cthi commented Jul 14, 2025

Summary:
For PyTorch integration we will need to support several additional cases, as well as a slightly different API. This is best seen in the torch test cases, e.g. test_scaled_grouped_gemm_2d_3d and test_scaled_grouped_gemm_3d_2d.

In summary, we need to support these cases:

| Input Type | Notes |
|---|---|
| 2D-3D | same as fbgemm stacked for MoE |
| 3D-2D | not sure use-case for this yet |
| 2D-2D | I think this is for backward? |
| 3D-3D (BMM) | Could alternatively leverage FBGEMM BMM kernel |

The PyTorch API uses offsets instead of sizes, so we update the kernel code that sets the grouped gemm parameters to also take offsets, and add support for the above cases (a reference sketch of the offsets semantics follows at the end of this summary).

  • For BMM we could alternatively leverage the AMD FP8 BMM kernel in FBGEMM, but we get some "free" support by handling it in the grouped kernel.
  • We don't add support for 2D-2D yet; that will come later.
  • I've not yet updated the heuristics to properly account for the new cases. This will come later with a re-tune for generic shapes, as opposed to Llama-specific ones.

Differential Revision: D78119166
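
To make the offsets-based semantics concrete, below is a minimal reference sketch of the 2D-3D (stacked MoE) case in plain PyTorch ops. The helper name, the layouts, and the use of cumulative row offsets are illustrative assumptions; this is not the CK kernel or the actual FBGEMM/PyTorch API, and the FP8 rowwise quantization and scaling are omitted for brevity.

```python
import torch

def grouped_mm_2d_3d_reference(x, w, offsets):
    """Reference semantics only (hypothetical helper, not the CK kernel).

    x:       [total_M, K]  rows of all groups stacked along M
    w:       [G, N, K]     one weight matrix per group
    offsets: [G]           cumulative row offsets; group g owns rows
                           offsets[g-1]:offsets[g], with offsets[-1] == total_M
    returns: [total_M, N]
    """
    out = x.new_empty(x.shape[0], w.shape[1])
    start = 0
    for g in range(w.shape[0]):
        end = int(offsets[g])  # offsets are cumulative, unlike per-group sizes
        # Each group's slice of rows is multiplied with that group's weight.
        out[start:end] = x[start:end] @ w[g].t()
        start = end
    return out

# Tiny example: 3 groups owning 2, 3, and 1 rows of the stacked input.
x = torch.randn(6, 8)
w = torch.randn(3, 4, 8)
offsets = torch.tensor([2, 5, 6])
y = grouped_mm_2d_3d_reference(x, w, offsets)
assert y.shape == (6, 4)
```

For reference, per-group sizes are just the successive differences of the offsets (e.g. `offsets.diff(prepend=offsets.new_zeros(1))`), which is why a kernel taking offsets can also cover the size-based cases.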

@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D78119166


netlify bot commented Jul 14, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

| Name | Link |
|---|---|
| 🔨 Latest commit | 4758f24 |
| 🔍 Latest deploy log | https://app.netlify.com/projects/pytorch-fbgemm-docs/deploys/687fe832d12f13000886a81e |
| 😎 Deploy Preview | https://deploy-preview-4486--pytorch-fbgemm-docs.netlify.app |

@cthi cthi force-pushed the export-D78119166 branch from 53208aa to 5ce73a7 Compare July 15, 2025 21:09
cthi added a commit to cthi/FBGEMM-1 that referenced this pull request Jul 15, 2025
Summary:
X-link: facebookresearch/FBGEMM#1543


For PyTorch integration we will need to support several additional cases, as well as a slightly different API. This is best seen in the torch test cases, e.g. [test_scaled_grouped_gemm_2d_3d](https://www.internalfb.com/code/fbsource/[fbdb0063f1c1ecca30f5eab8b5341643f680ed51]/fbcode/caffe2/test/test_matmul_cuda.py?lines=1793) and [test_scaled_grouped_gemm_3d_2d](https://www.internalfb.com/code/fbsource/[fbdb0063f1c1ecca30f5eab8b5341643f680ed51]/fbcode/caffe2/test/test_matmul_cuda.py?lines=1854)
- [Natalia's grouped gemm API doc](https://docs.google.com/document/d/1985La6wUUVH1AGBkNhaGKUXzx-9ybtbUp567-vYVOM4/edit?tab=t.0#heading=h.g8lzbjnyzzx9)

**In summary, we need to support these cases:**

| **Input Type** | Notes |
|---|---|
| 2D-3D | same as fbgemm stacked for MoE |
| 3D-2D | not sure use-case for this yet |
| 2D-2D | I think this is for backward? |
| 3D-3D (BMM) | [Could alternatively leverage FBGEMM BMM kernel](https://www.internalfb.com/code/fbsource/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/) |

The PyTorch API uses offsets instead of sizes, so we update the kernel code that sets the grouped gemm parameters to also take offsets, and add support for the above cases.
- For BMM we could alternatively leverage the AMD FP8 BMM kernel in FBGEMM, but we get some "free" support by handling it in the grouped kernel.
- I've not yet updated the heuristics to properly account for the new cases. This will come later with a re-tune for generic shapes, as opposed to Llama-specific ones.

Differential Revision: D78119166
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D78119166

cthi added a commit to cthi/FBGEMM-1 that referenced this pull request Jul 15, 2025
@cthi cthi force-pushed the export-D78119166 branch from 5ce73a7 to f632ac9 Compare July 15, 2025 21:46
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D78119166

@cthi cthi force-pushed the export-D78119166 branch from f632ac9 to 91030cd Compare July 16, 2025 15:44
cthi added a commit to cthi/FBGEMM-1 that referenced this pull request Jul 16, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D78119166

cthi added a commit to cthi/FBGEMM-1 that referenced this pull request Jul 16, 2025
@cthi cthi force-pushed the export-D78119166 branch from 91030cd to e690e64 Compare July 16, 2025 17:08
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D78119166

cthi added a commit to cthi/FBGEMM-1 that referenced this pull request Jul 16, 2025
@cthi cthi force-pushed the export-D78119166 branch 2 times, most recently from 89cb88e to a920a74 Compare July 22, 2025 13:04
cthi added a commit to cthi/FBGEMM-1 that referenced this pull request Jul 22, 2025
Summary:
X-link: facebookresearch/FBGEMM#1543


For PyTorch integration we will need to support several additional cases, as well as a slightly different API. This is best seen in the torch test cases, e.g. [test_scaled_grouped_gemm_2d_3d](https://www.internalfb.com/code/fbsource/[fbdb0063f1c1ecca30f5eab8b5341643f680ed51]/fbcode/caffe2/test/test_matmul_cuda.py?lines=1793) and [test_scaled_grouped_gemm_3d_2d](https://www.internalfb.com/code/fbsource/[fbdb0063f1c1ecca30f5eab8b5341643f680ed51]/fbcode/caffe2/test/test_matmul_cuda.py?lines=1854)
- ngimel's [grouped gemm API doc](https://docs.google.com/document/d/1985La6wUUVH1AGBkNhaGKUXzx-9ybtbUp567-vYVOM4/edit?tab=t.0#heading=h.g8lzbjnyzzx9)

**In summary, we need to support these cases:**

| **Input Type** | Notes |
|---|---|
| 2D-3D | same as fbgemm stacked for MoE |
| 3D-2D | needed for backwards |
| 2D-2D | needed for backwards |
| 3D-3D (BMM) | [Could alternatively leverage FBGEMM BMM kernel](https://www.internalfb.com/code/fbsource/fbcode/deeplearning/fbgemm/fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/) |

The PyTorch API uses offsets instead of sizes, so we update the kernel code that sets the grouped gemm parameters to also take offsets, and add support for the above cases.
- For BMM we could alternatively leverage the AMD FP8 BMM kernel in FBGEMM, but we get some "free" support by handling it in the grouped kernel.
- I've not yet updated the heuristics to properly account for the new cases. This will come later with a re-tune for generic shapes, as opposed to Llama-specific ones.

Differential Revision: D78119166
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D78119166

cthi added a commit to cthi/FBGEMM-1 that referenced this pull request Jul 22, 2025
@cthi cthi force-pushed the export-D78119166 branch 2 times, most recently from 2b51947 to 4758f24 Compare July 22, 2025 19:36
@facebook-github-bot
Contributor

This pull request has been merged in 53cde4a.
