Skip to content

Commit c288547

Browse files
improve fp8 blockwise gemm perf
stack-info: PR: #2784, branch: danielvegamyhre/stack/43
1 parent 1526dfe commit c288547

File tree

4 files changed

+442
-25
lines changed

4 files changed

+442
-25
lines changed
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import itertools
9+
from dataclasses import dataclass
10+
from typing import List
11+
12+
import torch
13+
from tabulate import tabulate
14+
from tqdm import tqdm
15+
from triton.testing import do_bench
16+
17+
from torchao.prototype.blockwise_fp8_training.kernels import (
18+
blockwise_fp8_gemm_1x128_128x128,
19+
fp8_blockwise_act_quant_lhs,
20+
fp8_blockwise_weight_quant_transposed_rhs,
21+
)
22+
23+
device = torch.device("cuda")
24+
25+
# This benchmark requires CUDA 12.9+
26+
assert torch.version.cuda is not None, "CUDA is not available"
27+
cuda_major, cuda_minor = map(int, torch.version.cuda.split("."))
28+
assert cuda_major >= 12 and cuda_minor >= 9, "CUDA 12.9+ is required"
29+
30+
# Needed since changing args to function causes recompiles
31+
torch._dynamo.config.cache_size_limit = 1000
32+
33+
34+
@dataclass(frozen=True)
35+
class ExperimentConfig:
36+
out_dtype: torch.dtype
37+
m: int
38+
n: int
39+
k: int
40+
41+
42+
@dataclass(frozen=True)
43+
class ExperimentResult:
44+
bf16_mm_us: float
45+
fp8_triton_us: float
46+
fp8_scaled_mm_us: float
47+
48+
49+
@dataclass(frozen=True)
50+
class Experiment:
51+
config: ExperimentConfig
52+
result: ExperimentResult
53+
54+
55+
def get_configs() -> List[ExperimentConfig]:
56+
mnk_list = [
57+
# Llama4 shapes
58+
(16640, 5120, 8192),
59+
(16640, 8192, 5120),
60+
]
61+
out_dtypes = [torch.float32, torch.bfloat16]
62+
configs = []
63+
for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
64+
m, n, k = mnk
65+
configs.append(
66+
ExperimentConfig(
67+
out_dtype=out_dtype,
68+
m=m,
69+
n=n,
70+
k=k,
71+
)
72+
)
73+
return configs
74+
75+
76+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
77+
# Simulate `grad_input = grad_output @ weight`
78+
M, N, K = config.m, config.n, config.k
79+
A = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
80+
B = torch.randn(N, K, dtype=config.out_dtype, device="cuda")
81+
A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn)
82+
B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(
83+
B, dtype=torch.float8_e4m3fn
84+
)
85+
86+
def warmup(func, *args, **kwargs):
87+
for _ in range(10):
88+
func(*args, **kwargs)
89+
90+
# Warmup then run bf16 torch.mm
91+
warmup(torch.mm, A, B.t())
92+
93+
bf16_mm_us = benchmark_cuda_function_in_microseconds(torch.mm, A, B.t())
94+
95+
# Warm up then run triton bench
96+
warmup(
97+
blockwise_fp8_gemm_1x128_128x128,
98+
A_q,
99+
1.0 / A_s,
100+
B_t_q,
101+
1.0 / B_t_s,
102+
)
103+
104+
fp8_triton_us = benchmark_cuda_function_in_microseconds(
105+
blockwise_fp8_gemm_1x128_128x128,
106+
A_q,
107+
1.0 / A_s,
108+
B_t_q,
109+
1.0 / B_t_s,
110+
)
111+
112+
# Warm up then run torch bench
113+
# scaled_mm requires A_s and B_t_s be in column-major format
114+
A_s = A_s.t().contiguous().t()
115+
116+
warmup(
117+
torch._scaled_mm,
118+
A_q,
119+
B_t_q,
120+
1.0 / A_s,
121+
1.0 / B_t_s,
122+
out_dtype=config.out_dtype,
123+
)
124+
125+
fp8_scaled_mm_us = benchmark_cuda_function_in_microseconds(
126+
torch._scaled_mm,
127+
A_q,
128+
B_t_q,
129+
1.0 / A_s,
130+
1.0 / B_t_s,
131+
out_dtype=config.out_dtype,
132+
)
133+
134+
return ExperimentResult(
135+
bf16_mm_us=bf16_mm_us,
136+
fp8_triton_us=fp8_triton_us,
137+
fp8_scaled_mm_us=fp8_scaled_mm_us,
138+
)
139+
140+
141+
def print_results(experiments: List[Experiment]):
142+
headers = [
143+
"M",
144+
"N",
145+
"K",
146+
"out_dtype",
147+
"bf16_mm_us",
148+
"fp8_triton_us",
149+
"fp8_scaled_mm_us",
150+
"bf16 tflops/sec",
151+
"triton tflops/sec",
152+
"scaled_mm tflops/sec",
153+
]
154+
rows = []
155+
for experiment in experiments:
156+
m, n, k = experiment.config.m, experiment.config.n, experiment.config.k
157+
flops = 2 * m * n * k
158+
bf16_mm_tflops_per_sec = (flops / 1e12) / (experiment.result.bf16_mm_us / 1e6)
159+
triton_tflops_per_sec = (flops / 1e12) / (experiment.result.fp8_triton_us / 1e6)
160+
scaled_mm_tflops_per_sec = (flops / 1e12) / (
161+
experiment.result.fp8_scaled_mm_us / 1e6
162+
)
163+
rows.append(
164+
[
165+
m,
166+
n,
167+
k,
168+
experiment.config.out_dtype,
169+
experiment.result.bf16_mm_us,
170+
experiment.result.fp8_triton_us,
171+
experiment.result.fp8_scaled_mm_us,
172+
bf16_mm_tflops_per_sec,
173+
triton_tflops_per_sec,
174+
scaled_mm_tflops_per_sec,
175+
]
176+
)
177+
print(tabulate(rows, headers=headers))
178+
179+
180+
def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
181+
return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3
182+
183+
184+
def main():
185+
torch.random.manual_seed(123)
186+
configs = get_configs()
187+
results = []
188+
for config in tqdm(configs):
189+
result = run_experiment(config)
190+
results.append(Experiment(config=config, result=result))
191+
192+
# Use Tabulate to print results
193+
print_results(results)
194+
195+
196+
if __name__ == "__main__":
197+
main()
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import itertools
9+
from dataclasses import dataclass
10+
from typing import List
11+
12+
import torch
13+
from tabulate import tabulate
14+
from tqdm import tqdm
15+
from triton.testing import do_bench
16+
17+
from torchao.prototype.blockwise_fp8_training.kernels import (
18+
blockwise_fp8_gemm_1x128_128x1,
19+
fp8_blockwise_act_quant_rhs,
20+
fp8_blockwise_act_quant_transposed_lhs,
21+
)
22+
23+
device = torch.device("cuda")
24+
25+
# This benchmark requires CUDA 12.9+
26+
assert torch.version.cuda is not None, "CUDA is not available"
27+
cuda_major, cuda_minor = map(int, torch.version.cuda.split("."))
28+
assert cuda_major >= 12 and cuda_minor >= 9, "CUDA 12.9+ is required"
29+
30+
# Needed since changing args to function causes recompiles
31+
torch._dynamo.config.cache_size_limit = 1000
32+
33+
34+
@dataclass(frozen=True)
35+
class ExperimentConfig:
36+
out_dtype: torch.dtype
37+
m: int
38+
n: int
39+
k: int
40+
41+
42+
@dataclass(frozen=True)
43+
class ExperimentResult:
44+
bf16_mm_us: float
45+
fp8_triton_us: float
46+
fp8_scaled_mm_us: float
47+
48+
49+
@dataclass(frozen=True)
50+
class Experiment:
51+
config: ExperimentConfig
52+
result: ExperimentResult
53+
54+
55+
def get_configs() -> List[ExperimentConfig]:
56+
mnk_list = [
57+
# Llama4 shapes
58+
(16640, 5120, 8192),
59+
(16640, 8192, 5120),
60+
]
61+
out_dtypes = [torch.float32, torch.bfloat16]
62+
configs = []
63+
for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
64+
m, n, k = mnk
65+
configs.append(
66+
ExperimentConfig(
67+
out_dtype=out_dtype,
68+
m=m,
69+
n=n,
70+
k=k,
71+
)
72+
)
73+
return configs
74+
75+
76+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
77+
# Simulate `grad_weight = grad_output_t @ input`
78+
M, N, K = config.m, config.n, config.k
79+
A = torch.randn(M, N, dtype=config.out_dtype, device="cuda")
80+
B = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
81+
A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=torch.float8_e4m3fn)
82+
B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=torch.float8_e4m3fn)
83+
84+
def warmup(func, *args, **kwargs):
85+
for _ in range(10):
86+
func(*args, **kwargs)
87+
88+
# Warmup then run bf16 torch.mm
89+
warmup(torch.mm, A.t(), B)
90+
91+
bf16_mm_us = benchmark_cuda_function_in_microseconds(torch.mm, A.t(), B)
92+
93+
# Warm up then run triton bench
94+
warmup(
95+
blockwise_fp8_gemm_1x128_128x1,
96+
A_t_q,
97+
1.0 / A_t_s,
98+
B_q,
99+
1.0 / B_s,
100+
)
101+
102+
fp8_triton_us = benchmark_cuda_function_in_microseconds(
103+
blockwise_fp8_gemm_1x128_128x1,
104+
A_t_q,
105+
1.0 / A_t_s,
106+
B_q,
107+
1.0 / B_s,
108+
)
109+
110+
# torch._scaled_mm requires A_s and B_t_s be in column-major format
111+
A_t_s = A_t_s.t().contiguous().t()
112+
113+
# Warm up then run torch bench
114+
warmup(
115+
torch._scaled_mm,
116+
A_t_q,
117+
B_q,
118+
1.0 / A_t_s,
119+
1.0 / B_s,
120+
out_dtype=config.out_dtype,
121+
)
122+
123+
fp8_scaled_mm_us = benchmark_cuda_function_in_microseconds(
124+
torch._scaled_mm,
125+
A_t_q,
126+
B_q,
127+
1.0 / A_t_s,
128+
1.0 / B_s,
129+
out_dtype=config.out_dtype,
130+
)
131+
132+
return ExperimentResult(
133+
bf16_mm_us=bf16_mm_us,
134+
fp8_triton_us=fp8_triton_us,
135+
fp8_scaled_mm_us=fp8_scaled_mm_us,
136+
)
137+
138+
139+
def print_results(experiments: List[Experiment]):
140+
headers = [
141+
"M",
142+
"N",
143+
"K",
144+
"out_dtype",
145+
"bf16_mm_us",
146+
"fp8_triton_us",
147+
"fp8_scaled_mm_us",
148+
"bf16 tflops/sec",
149+
"triton tflops/sec",
150+
"scaled_mm tflops/sec",
151+
]
152+
rows = []
153+
for experiment in experiments:
154+
m, n, k = experiment.config.m, experiment.config.n, experiment.config.k
155+
flops = 2 * m * n * k
156+
bf16_mm_tflops_per_sec = (flops / 1e12) / (experiment.result.bf16_mm_us / 1e6)
157+
triton_tflops_per_sec = (flops / 1e12) / (experiment.result.fp8_triton_us / 1e6)
158+
scaled_mm_tflops_per_sec = (flops / 1e12) / (
159+
experiment.result.fp8_scaled_mm_us / 1e6
160+
)
161+
rows.append(
162+
[
163+
m,
164+
n,
165+
k,
166+
experiment.config.out_dtype,
167+
experiment.result.bf16_mm_us,
168+
experiment.result.fp8_triton_us,
169+
experiment.result.fp8_scaled_mm_us,
170+
bf16_mm_tflops_per_sec,
171+
triton_tflops_per_sec,
172+
scaled_mm_tflops_per_sec,
173+
]
174+
)
175+
print(tabulate(rows, headers=headers))
176+
177+
178+
def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
179+
return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3
180+
181+
182+
def main():
183+
torch.random.manual_seed(123)
184+
configs = get_configs()
185+
results = []
186+
for config in tqdm(configs):
187+
result = run_experiment(config)
188+
results.append(Experiment(config=config, result=result))
189+
190+
# Use Tabulate to print results
191+
print_results(results)
192+
193+
194+
if __name__ == "__main__":
195+
main()

0 commit comments

Comments
 (0)