
Conversation


@yaox12 yaox12 commented Aug 21, 2025

Description

The functionality is ready, but we're not seeing a performance gain due to a regression in the fused activation + quantization kernels. Take an input of shape (8*4000, 4096) for example:

  • SwiGLU + MXFP8 Quantization

    • BF16 SwiGLU + 8x MXFP8 Quantization: ~256 us
    • SwiGLU + MXFP8 fusion: ~343 us + 2 padding scaling-factor kernels (~84 us)
  • SReLU + MXFP8 Quantization

    • BF16 SReLU + 8x MXFP8 Quantization: ~187 us
    • SReLU + MXFP8 fusion: ~142 us + 2 padding scaling-factor kernels (~75 us)

For the SReLU case, the fused path still spends more GPU time (~142 us + 75 us ≈ 217 us vs. ~187 us), but its CPU overhead is lower, so overall we can get a slight speedup.
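As a quick sanity check on the figures above, the totals work out as follows (plain arithmetic over the quoted profiling numbers, not profiling code):

```python
# GPU-time totals from the profiling figures quoted above, in microseconds.
swiglu_baseline = 256        # BF16 SwiGLU + 8x MXFP8 quantization
swiglu_fused = 343 + 84      # fused kernel + 2 padding scaling-factor kernels
srelu_baseline = 187         # BF16 SReLU + 8x MXFP8 quantization
srelu_fused = 142 + 75       # fused kernel + 2 padding scaling-factor kernels

# Both fused paths spend more GPU time than the unfused baselines; the
# slight overall SReLU speedup comes from lower CPU-side overhead.
assert swiglu_fused == 427 and swiglu_fused > swiglu_baseline
assert srelu_fused == 217 and srelu_fused > srelu_baseline
```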

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 yaox12 force-pushed the xiny/quantized_input branch 2 times, most recently from 0163f2f to 35a25f1 Compare August 26, 2025 02:28
@yaox12 yaox12 marked this pull request as ready for review August 26, 2025 02:28

yaox12 commented Aug 26, 2025

/te-ci pytorch


yaox12 commented Aug 26, 2025

/te-ci pytorch

@yaox12 yaox12 requested review from timmoon10 and zhongbozhu August 26, 2025 05:33

ptrendx commented Aug 26, 2025

@yaox12 What do you mean by the "padding scaling factor kernels"? Is that swizzle? If so, then we definitely need to optimize them, they should take max a few percent of the time to do the quantize.


yaox12 commented Aug 27, 2025

> @yaox12 What do you mean by the "padding scaling factor kernels"? Is that swizzle? If so, then we definitely need to optimize them, they should take max a few percent of the time to do the quantize.

No, it's not swizzle. What I'm doing here is quantizing (or applying the activation and quantizing) multiple input tensors as a whole, versus splitting them into chunks and quantizing them one by one. Both methods produce exactly the same quantized data, but the scaling factors may differ due to padding. So for the "quantize as a whole" approach, we need to pad the scaling factors manually.

For example, if we have two input tensors (for two experts) both in shape [64, 128]:

  • Previous way:

    • Apply the activation function in BF16 on the concatenated tensor of shape [128, 128].
    • Split it into two tensors of shape [64, 128] and quantize them one by one; we then get two quantized tensors with data of shape [64, 128], a rowwise scaling factor of shape [128, 4] (rows padded to 128), and a colwise scaling factor of shape [4, 128].
  • With this PR:

    • We do the act + quantize fusion on the concatenated tensor of shape [128, 128], getting a single quantized tensor with data of shape [128, 128], a rowwise scaling factor of shape [128, 4], and a colwise scaling factor of shape [4, 128].
    • Then we need to split the whole quantized tensor into two. The quantized data can be split directly, while for the rowwise scaling factor, we first split it into two [64, 4] tensors and pad both to [128, 4]. This is the padding I'm talking about.
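The split-and-pad step for the rowwise scaling factors can be sketched in plain Python (a hypothetical illustration only; shapes follow the [64, 128]-per-expert example above, with one scaling factor per 32-element block and rows padded to 128, and dummy values standing in for real scaling-factor bytes):

```python
num_experts = 2
rows_per_expert = 64
sf_cols = 128 // 32   # one scaling factor per 32-element block -> 4 per row
row_pad = 128         # rowwise scaling factors are padded to 128 rows

# Quantizing the concatenated [128, 128] tensor yields one rowwise
# scaling-factor tensor of shape [128, 4] (dummy values here).
fused_rowwise_sf = [[1] * sf_cols for _ in range(num_experts * rows_per_expert)]

# Split per expert, then pad each [64, 4] chunk to [128, 4] with zero rows,
# matching the layout that per-expert quantization would have produced.
padded_sfs = []
for i in range(num_experts):
    chunk = fused_rowwise_sf[i * rows_per_expert:(i + 1) * rows_per_expert]
    chunk = chunk + [[0] * sf_cols for _ in range(row_pad - len(chunk))]
    padded_sfs.append(chunk)

assert all(len(sf) == 128 and len(sf[0]) == 4 for sf in padded_sfs)
```

These two extra padding passes per split are exactly the "padding scaling factor kernels" counted in the timings above.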

@yaox12 yaox12 force-pushed the xiny/quantized_input branch 2 times, most recently from ef68f89 to fcd52fc Compare August 27, 2025 10:25
```diff
@@ -28,7 +30,7 @@ __device__ inline OType dgelu(const IType val, const Empty&) {
 template <typename OType, typename IType>
 __device__ inline OType sigmoid(const IType val, const Empty&) {
   const float cval = val;
-  return 1.f / (1.f + expf(-cval));
+  return sigmoidf(cval);
```
Unify the implementation with that in cast_gated_kernels.cuh.

```cpp
ComputeType after_dgate = grad_val * Activation(gelu_in, p);
ComputeType act_in, dact_in;
if constexpr ((Activation == &silu<fp32, fp32>) && (Dactivation == &dsilu<fp32, fp32>)) {
  const float s = sigmoidf(gelu_in);
```
Unify the implementation with that in cast_gated_kernels.cuh.

@yaox12 yaox12 changed the title [PyTorch] Let GroupedLinear accept MXFP8 input [PyTorch] Let GroupedLinear accept MXFP8 input and gradient Aug 27, 2025

yaox12 commented Aug 28, 2025

/te-ci

@yaox12 yaox12 force-pushed the xiny/quantized_input branch from c1b8230 to 9461308 Compare August 29, 2025 09:48

yaox12 commented Aug 29, 2025

/te-ci
