Commit da736d3

improve fp8 blockwise gemm perf
stack-info: PR: #2784, branch: danielvegamyhre/stack/43
1 parent 1526dfe commit da736d3

3 files changed: +246 additions, -25 deletions

New file (benchmarking script): 196 additions & 0 deletions
@@ -0,0 +1,196 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm
from triton.testing import do_bench

from torchao.prototype.blockwise_fp8_training.kernels import (
    blockwise_fp8_gemm_1x128_128x128,
    fp8_blockwise_act_quant_lhs,
    fp8_blockwise_weight_quant_transposed_rhs,
)

device = torch.device("cuda")

# This benchmark requires CUDA 12.9+
assert torch.version.cuda is not None, "CUDA is not available"
cuda_major, cuda_minor = map(int, torch.version.cuda.split("."))
assert cuda_major >= 12 and cuda_minor >= 9, "CUDA 12.9+ is required"

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    out_dtype: torch.dtype
    m: int
    n: int
    k: int


@dataclass(frozen=True)
class ExperimentResult:
    bf16_mm_us: float
    fp8_triton_us: float
    fp8_scaled_mm_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    mnk_list = [
        (16640, 5120, 8192),
    ]
    out_dtypes = [torch.float32, torch.bfloat16]
    configs = []
    for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
        m, n, k = mnk
        configs.append(
            ExperimentConfig(
                out_dtype=out_dtype,
                m=m,
                n=n,
                k=k,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    # Define test inputs. Simulate output = input @ weight.T
    M, N, K = config.m, config.n, config.k
    A = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
    B = torch.randn(N, K, dtype=config.out_dtype, device="cuda")
    A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn)
    B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(
        B, dtype=torch.float8_e4m3fn
    )

    def warmup(func, *args, **kwargs):
        for _ in range(10):
            func(*args, **kwargs)

    # Warm up then benchmark bf16 torch.mm
    warmup(torch.mm, A, B.t())

    bf16_mm_us = benchmark_cuda_function_in_microseconds(torch.mm, A, B.t())

    # Warm up then benchmark the Triton kernel
    warmup(
        blockwise_fp8_gemm_1x128_128x128,
        A_q,
        1.0 / A_s,
        B_t_q,
        1.0 / B_t_s,
    )

    fp8_triton_us = benchmark_cuda_function_in_microseconds(
        blockwise_fp8_gemm_1x128_128x128,
        A_q,
        1.0 / A_s,
        B_t_q,
        1.0 / B_t_s,
    )

    # Warm up then benchmark torch._scaled_mm
    # scaled_mm requires A_s and B_t_s to be in column-major format
    A_s = A_s.t().contiguous().t()

    warmup(
        torch._scaled_mm,
        A_q,
        B_t_q,
        1.0 / A_s,
        1.0 / B_t_s,
        out_dtype=config.out_dtype,
    )

    fp8_scaled_mm_us = benchmark_cuda_function_in_microseconds(
        torch._scaled_mm,
        A_q,
        B_t_q,
        1.0 / A_s,
        1.0 / B_t_s,
        out_dtype=config.out_dtype,
    )

    return ExperimentResult(
        bf16_mm_us=bf16_mm_us,
        fp8_triton_us=fp8_triton_us,
        fp8_scaled_mm_us=fp8_scaled_mm_us,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "M",
        "N",
        "K",
        "out_dtype",
        "bf16_mm_us",
        "fp8_triton_us",
        "fp8_scaled_mm_us",
        "bf16 tflops/sec",
        "triton tflops/sec",
        "scaled_mm tflops/sec",
    ]
    rows = []
    for experiment in experiments:
        m, n, k = experiment.config.m, experiment.config.n, experiment.config.k
        flops = 2 * m * n * k
        bf16_mm_tflops_per_sec = (flops / 1e12) / (experiment.result.bf16_mm_us / 1e6)
        triton_tflops_per_sec = (flops / 1e12) / (experiment.result.fp8_triton_us / 1e6)
        scaled_mm_tflops_per_sec = (flops / 1e12) / (
            experiment.result.fp8_scaled_mm_us / 1e6
        )
        rows.append(
            [
                m,
                n,
                k,
                experiment.config.out_dtype,
                experiment.result.bf16_mm_us,
                experiment.result.fp8_triton_us,
                experiment.result.fp8_scaled_mm_us,
                bf16_mm_tflops_per_sec,
                triton_tflops_per_sec,
                scaled_mm_tflops_per_sec,
            ]
        )
    print(tabulate(rows, headers=headers))


def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
    return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
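
For reference when reading the benchmark's output table: a GEMM of shape (M, N, K) performs 2 * M * N * K floating-point operations, and print_results divides that count by the median runtime reported by do_bench. A minimal worked example of the same arithmetic, using an assumed elapsed time rather than a real measurement:

    # Worked example of the TFLOPS math in print_results (elapsed_us is assumed, not measured)
    M, N, K = 16640, 5120, 8192
    flops = 2 * M * N * K                         # ~1.396e12 FLOPs for one GEMM
    elapsed_us = 500.0                            # hypothetical median kernel time in microseconds
    tflops_per_sec = (flops / 1e12) / (elapsed_us / 1e6)
    print(round(tflops_per_sec, 1))               # ~2791.7 TFLOPS/sec for these assumed numbers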

torchao/prototype/blockwise_fp8_training/kernels.py

Lines changed: 35 additions & 24 deletions
@@ -10,20 +10,20 @@
 import triton
 import triton.language as tl
 
+from torchao.prototype.moe_training.utils import (
+    _is_column_major,
+    _is_row_major,
+)
+
 fp8_gemm_configs_max_autotune = [
-    # Small
-    triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64}, num_warps=2),
-    # Medium
-    triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256}, num_warps=8),
-    # Large
-    triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_warps=8),
-    triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_warps=8),
-    triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_warps=4),
-    triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_warps=8),
+    triton.Config(
+        {"BLOCK_SIZE_M": block_size, "BLOCK_SIZE_N": block_size},
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
+    for block_size in [64, 128, 256]
+    for num_warps in [4, 8]
+    for num_stages in [2, 4]
 ]
 
 # For fast compile times during development.
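
The hand-enumerated autotune configs above are replaced by a comprehension over square tile sizes, warp counts, and pipeline stage counts. As a quick sanity check (a sketch, not part of the diff), the autotuner now sweeps the Cartesian product of those three lists:

    import itertools

    block_sizes, warp_counts, stage_counts = [64, 128, 256], [4, 8], [2, 4]
    candidates = list(itertools.product(block_sizes, warp_counts, stage_counts))
    print(len(candidates))  # 12 (BLOCK_SIZE, num_warps, num_stages) combinations
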
@@ -57,6 +57,7 @@ def blockwise_fp8_gemm_1x128_128x128_kernel(
     M,
     N: tl.constexpr,
     K: tl.constexpr,
+    out_dtype: tl.constexpr,
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
@@ -81,18 +82,16 @@ def blockwise_fp8_gemm_1x128_128x128_kernel(
     a_s_base_ptr = a_s_ptr + offs_m * a_s_stride_dim_0
     b_s_base_ptr = b_s_ptr + (offs_n // BLOCK_SIZE_K) * b_s_stride_dim_1
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
+    b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
     for k in range(0, k_num_blocks):
-        a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
         a = tl.load(a_ptrs, mask=a_mask, other=0.0)
-
-        b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N)
         b = tl.load(b_ptrs, mask=b_mask, other=0.0)
 
         # Reciprocal scales to scale back to dynamic range of output dtype
         a_s = tl.load(a_s_base_ptr + k * a_s_stride_dim_1)
         b_s = tl.load(b_s_base_ptr + k * b_s_stride_dim_0)
-
-        accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        accumulator += tl.dot(a, b) * a_s[:, None] * b_s
 
         a_ptrs += BLOCK_SIZE_K * a_stride_dim_1
         b_ptrs += BLOCK_SIZE_K * b_stride_dim_0
@@ -109,14 +108,22 @@ def blockwise_fp8_gemm_1x128_128x128(
     b: torch.Tensor,  # (K, N)
     b_s: torch.Tensor,  # (K // block_size, N // block_size)
     block_size: int = 128,
+    out_dtype: torch.dtype = torch.float32,
 ):
     # 'a' must be in row-major layout, 'b' must be in column-major layout
-    assert a.is_contiguous() and not b.is_contiguous()
-    assert a_s.is_contiguous() and b_s.is_contiguous()
+    assert _is_row_major(a) and _is_column_major(b), (
+        "a must be row-major, b must be column-major"
+    )
+
+    # a_scales must be row-major, b_scales must be column-major
+    assert _is_row_major(a_s) and _is_column_major(b_s), (
+        "a_s must be row-major, b_s must be column-major"
+    )
+
     M = a.size(0)
     K = a.size(1)
     N = b.size(1)
-    c = a.new_empty(M, N, dtype=torch.bfloat16)
+    c = a.new_empty(M, N, dtype=out_dtype)
     grid = lambda META: (
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -140,6 +147,7 @@ def blockwise_fp8_gemm_1x128_128x128(
         M,
         N,
         K,
+        out_dtype=out_dtype,
         BLOCK_SIZE_K=block_size,
     )
     return c
@@ -217,14 +225,15 @@ def blockwise_fp8_gemm_1x128_128x1(
     b: torch.Tensor,  # (K, N)
     b_s: torch.Tensor,  # (K // block_size, N) reciprocals of scales
     block_size: int = 128,
+    out_dtype: torch.dtype = torch.float32,
 ):
     # 'a' must be in row-major layout, 'b' must be in column-major layout
     assert a.is_contiguous() and not b.is_contiguous()
     assert a_s.is_contiguous() and b_s.is_contiguous()
     M = a.size(0)
     K = a.size(1)
     N = b.size(1)
-    c = a.new_empty(M, N, dtype=torch.bfloat16)
+    c = a.new_empty(M, N, dtype=out_dtype)
     grid = lambda META: (
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
@@ -674,8 +683,10 @@ def fp8_blockwise_weight_quant_transposed_rhs(
     M, N = x.size()
     y = torch.empty(N, M, dtype=dtype, device=x.device)
     y = y.as_strided(y.size(), (1, y.size(0)))  # Column major
-    s = x.new_empty(
-        triton.cdiv(N, block_size), triton.cdiv(M, block_size), dtype=torch.float32
+    n_blocks, m_blocks = triton.cdiv(N, block_size), triton.cdiv(M, block_size)
+    s = x.new_empty(n_blocks, m_blocks, dtype=torch.float32).as_strided(
+        (n_blocks, m_blocks),  # shape
+        (1, n_blocks),  # stride
     )
     grid = lambda meta: (
         triton.cdiv(M, meta["BLOCK_SIZE"]),

torchao/prototype/moe_training/utils.py

Lines changed: 15 additions & 1 deletion
@@ -290,7 +290,21 @@ def _is_column_major(x: torch.Tensor) -> bool:
         A boolean indicating whether the input tensor is column-major.
     """
     assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D"
-    return x.stride(-2) == 1 and x.stride(-1) > 1
+    return x.stride(-2) == 1
+
+
+def _is_row_major(x: torch.Tensor) -> bool:
+    """
+    This function checks if the input tensor is row-major.
+
+    Args:
+        x (torch.Tensor): The input tensor to be checked.
+
+    Returns:
+        A boolean indicating whether the input tensor is row-major.
+    """
+    assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D"
+    return x.stride(-1) == 1
 
 
 def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda"):
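
A small illustration (not part of the diff) of how these stride checks behave, and of the two column-major construction idioms used elsewhere in this commit: the as_strided allocation in fp8_blockwise_weight_quant_transposed_rhs and the t().contiguous().t() conversion in the benchmark script. Shapes below are arbitrary:

    import torch

    x = torch.randn(128, 256)   # contiguous: strides (256, 1), so _is_row_major(x) is True
    x_t = x.t()                 # transposed view: strides (1, 256), so _is_column_major(x_t) is True
    assert x.stride(-1) == 1 and x_t.stride(-2) == 1

    rows, cols = 4, 6
    # Column-major allocation via as_strided, as in fp8_blockwise_weight_quant_transposed_rhs
    a = torch.empty(rows, cols, dtype=torch.float32).as_strided((rows, cols), (1, rows))
    # Column-major conversion of an existing tensor, as in the benchmark's A_s handling
    b = torch.randn(rows, cols).t().contiguous().t()
    assert a.stride(-2) == 1 and b.stride(-2) == 1  # both satisfy the column-major check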
