Commit ef6fc50

integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script
stack-info: PR: #2785, branch: danielvegamyhre/stack/44
1 parent fbe08c3 commit ef6fc50

File tree: 6 files changed, +340 additions, −103 deletions


benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py

Lines changed: 8 additions & 6 deletions
@@ -15,9 +15,9 @@
 from triton.testing import do_bench
 
 from torchao.prototype.blockwise_fp8_training.kernels import (
-    blockwise_fp8_gemm_1x128_128x128,
     fp8_blockwise_act_quant_lhs,
     fp8_blockwise_weight_quant_transposed_rhs,
+    triton_fp8_gemm_1x128_128x128,
 )
 
 device = torch.device("cuda")
@@ -58,7 +58,7 @@ def get_configs() -> List[ExperimentConfig]:
         (16640, 5120, 8192),
         (16640, 8192, 5120),
     ]
-    out_dtypes = [torch.float32, torch.bfloat16]
+    out_dtypes = [torch.bfloat16]
     configs = []
     for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
         m, n, k = mnk
@@ -94,19 +94,21 @@ def warmup(func, *args, **kwargs):
 
     # Warm up then run triton bench
    warmup(
-        blockwise_fp8_gemm_1x128_128x128,
+        triton_fp8_gemm_1x128_128x128,
         A_q,
-        1.0 / A_s,
         B_t_q,
+        1.0 / A_s,
         1.0 / B_t_s,
+        out_dtype=config.out_dtype,
     )
 
     fp8_triton_us = benchmark_cuda_function_in_microseconds(
-        blockwise_fp8_gemm_1x128_128x128,
+        triton_fp8_gemm_1x128_128x128,
         A_q,
-        1.0 / A_s,
         B_t_q,
+        1.0 / A_s,
         1.0 / B_t_s,
+        out_dtype=config.out_dtype,
     )
 
     # Warm up then run torch bench
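
Note the two call-site changes in this hunk: the renamed kernel now takes both quantized operands before their reciprocal scales, and the output dtype is passed explicitly. A minimal sketch of the new calling convention, assuming (as the benchmark's variable names suggest) that each quant helper returns a (quantized fp8 tensor, scales) pair:

import torch

from torchao.prototype.blockwise_fp8_training.kernels import (
    fp8_blockwise_act_quant_lhs,
    fp8_blockwise_weight_quant_transposed_rhs,
    triton_fp8_gemm_1x128_128x128,
)

# Illustrative shapes taken from the benchmark's Llama4 configs.
M, N, K = 16640, 8192, 5120
A = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
B = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# Quantize: 1x128 scaling blocks for the activation (LHS),
# 128x128 blocks for the transposed weight (RHS).
A_q, A_s = fp8_blockwise_act_quant_lhs(A)
B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B)

# New argument order: operands first, then reciprocal scales, then out_dtype.
out = triton_fp8_gemm_1x128_128x128(
    A_q, B_t_q, 1.0 / A_s, 1.0 / B_t_s, out_dtype=torch.bfloat16
)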

benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py

Lines changed: 8 additions & 9 deletions
@@ -15,9 +15,9 @@
 from triton.testing import do_bench
 
 from torchao.prototype.blockwise_fp8_training.kernels import (
-    blockwise_fp8_gemm_1x128_128x1,
     fp8_blockwise_act_quant_rhs,
     fp8_blockwise_act_quant_transposed_lhs,
+    triton_fp8_gemm_1x128_128x1,
 )
 
 device = torch.device("cuda")
@@ -58,7 +58,7 @@ def get_configs() -> List[ExperimentConfig]:
         (16640, 5120, 8192),
         (16640, 8192, 5120),
     ]
-    out_dtypes = [torch.float32, torch.bfloat16]
+    out_dtypes = [torch.bfloat16]
     configs = []
     for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
         m, n, k = mnk
@@ -92,24 +92,23 @@ def warmup(func, *args, **kwargs):
 
     # Warm up then run triton bench
     warmup(
-        blockwise_fp8_gemm_1x128_128x1,
+        triton_fp8_gemm_1x128_128x1,
         A_t_q,
-        1.0 / A_t_s,
         B_q,
+        1.0 / A_t_s,
         1.0 / B_s,
+        out_dtype=config.out_dtype,
     )
 
     fp8_triton_us = benchmark_cuda_function_in_microseconds(
-        blockwise_fp8_gemm_1x128_128x1,
+        triton_fp8_gemm_1x128_128x1,
         A_t_q,
-        1.0 / A_t_s,
         B_q,
+        1.0 / A_t_s,
         1.0 / B_s,
+        out_dtype=config.out_dtype,
     )
 
-    # torch._scaled_mm requires A_s and B_t_s be in column-major format
-    A_t_s = A_t_s.t().contiguous().t()
-
     # Warm up then run torch bench
     warmup(
         torch._scaled_mm,
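
The deleted lines above also drop a layout workaround: the scale tensor no longer needs to be forced into column-major format before the torch._scaled_mm call at this site. For context, torch._scaled_mm is a private PyTorch API whose signature may change between releases; a minimal tensorwise-scaling sketch of its calling convention (the benchmark itself passes blockwise scales):

import torch

M, K, N = 256, 512, 1024
# mat1 must be row-major; mat2 must be column-major, hence the transpose
# of a row-major (N, K) tensor.
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()
scale_a = torch.tensor(1.0, device="cuda")  # float32 scale tensors
scale_b = torch.tensor(1.0, device="cuda")
out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)

# The removed `A_t_s.t().contiguous().t()` idiom forces a tensor into
# column-major layout without changing its logical shape.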
Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import itertools
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from torch.nn import functional as F
from tqdm import tqdm
from triton.testing import do_bench

from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear

device = torch.device("cuda")

# This benchmark requires CUDA 12.9+
assert torch.version.cuda is not None, "CUDA is not available"
cuda_major, cuda_minor = map(int, torch.version.cuda.split("."))
assert (cuda_major, cuda_minor) >= (12, 9), "CUDA 12.9+ is required"

# Needed since changing args to function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    out_dtype: torch.dtype
    m: int
    n: int
    k: int


@dataclass(frozen=True)
class ExperimentResult:
    bf16_linear_us: float
    fp8_triton_linear_us: float
    fp8_scaled_mm_linear_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    mnk_list = [
        # Llama4 shapes
        (16640, 5120, 8192),
        (16640, 8192, 5120),
    ]
    out_dtypes = [torch.bfloat16]
    configs = []
    for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
        m, n, k = mnk
        configs.append(
            ExperimentConfig(
                out_dtype=out_dtype,
                m=m,
                n=n,
                k=k,
            )
        )
    return configs


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    M, N, K = config.m, config.n, config.k
    inputs = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
    bf16_linear = torch.nn.Linear(K, N, dtype=config.out_dtype, device="cuda")
    fp8_triton_linear = Float8BlockwiseLinear(
        K, N, dtype=config.out_dtype, device="cuda", use_triton=True
    )
    fp8_scaled_mm_linear = Float8BlockwiseLinear(
        K, N, dtype=config.out_dtype, device="cuda", use_triton=False
    )

    def warmup(func, *args, **kwargs):
        for _ in range(10):
            func(*args, **kwargs)

    def fwd_bwd(func, inputs, labels, *args, **kwargs):
        out = func(inputs, *args, **kwargs)
        loss = F.mse_loss(out, labels)
        loss.backward()
        torch.cuda.synchronize()

    # Warm up then run the bf16 linear baseline
    labels = inputs.new_empty(M, N).fill_(1.0)
    warmup(fwd_bwd, bf16_linear, inputs, labels)

    bf16_linear_us = benchmark_cuda_function_in_microseconds(
        fwd_bwd, bf16_linear, inputs, labels
    )

    # Warm up then run triton bench
    warmup(
        fwd_bwd,
        fp8_triton_linear,
        inputs,
        labels,
    )

    fp8_triton_linear_us = benchmark_cuda_function_in_microseconds(
        fwd_bwd,
        fp8_triton_linear,
        inputs,
        labels,
    )

    warmup(
        fwd_bwd,
        fp8_scaled_mm_linear,
        inputs,
        labels,
    )

    fp8_scaled_mm_linear_us = benchmark_cuda_function_in_microseconds(
        fwd_bwd,
        fp8_scaled_mm_linear,
        inputs,
        labels,
    )

    return ExperimentResult(
        bf16_linear_us=bf16_linear_us,
        fp8_triton_linear_us=fp8_triton_linear_us,
        fp8_scaled_mm_linear_us=fp8_scaled_mm_linear_us,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "M",
        "N",
        "K",
        "out_dtype",
        "bf16_mm_linear_us",
        "fp8_triton_linear_us",
        "fp8_scaled_mm_linear_us",
    ]
    rows = []
    for experiment in experiments:
        m, n, k = experiment.config.m, experiment.config.n, experiment.config.k
        rows.append(
            [
                m,
                n,
                k,
                experiment.config.out_dtype,
                experiment.result.bf16_linear_us,
                experiment.result.fp8_triton_linear_us,
                experiment.result.fp8_scaled_mm_linear_us,
            ]
        )
    print(tabulate(rows, headers=headers))


def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
    # do_bench returns milliseconds; convert the median to microseconds
    return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use Tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
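
The new script benchmarks a full forward + backward pass of Float8BlockwiseLinear against a bf16 nn.Linear baseline, with use_triton toggling between the Triton GEMM kernels and the torch._scaled_mm path added in this commit. A minimal standalone smoke test following the constructor usage shown above:

import torch
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear

# use_triton=False selects the torch._scaled_mm path per this commit.
layer = Float8BlockwiseLinear(
    8192, 5120, dtype=torch.bfloat16, device="cuda", use_triton=False
)
x = torch.randn(16640, 8192, dtype=torch.bfloat16, device="cuda", requires_grad=True)
y = layer(x)
y.sum().backward()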
