Skip to content

Commit 3c78e2a

Browse files
integrate torch._scaled_mm into Float8BlockwiseLinear and add bench script
stack-info: PR: #2785, branch: danielvegamyhre/stack/44
1 parent 253d65a commit 3c78e2a

File tree

9 files changed

+331
-98
lines changed

9 files changed

+331
-98
lines changed

benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
from triton.testing import do_bench
1616

1717
from torchao.prototype.blockwise_fp8_training.kernels import (
18-
blockwise_fp8_gemm_1x128_128x128,
1918
fp8_blockwise_act_quant_lhs,
2019
fp8_blockwise_weight_quant_transposed_rhs,
20+
triton_fp8_gemm_1x128_128x128,
2121
)
2222

2323
device = torch.device("cuda")
@@ -58,7 +58,7 @@ def get_configs() -> List[ExperimentConfig]:
5858
(16640, 5120, 8192),
5959
(16640, 8192, 5120),
6060
]
61-
out_dtypes = [torch.float32, torch.bfloat16]
61+
out_dtypes = [torch.bfloat16]
6262
configs = []
6363
for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
6464
m, n, k = mnk
@@ -94,19 +94,21 @@ def warmup(func, *args, **kwargs):
9494

9595
# Warm up then run triton bench
9696
warmup(
97-
blockwise_fp8_gemm_1x128_128x128,
97+
triton_fp8_gemm_1x128_128x128,
9898
A_q,
99-
1.0 / A_s,
10099
B_t_q,
100+
1.0 / A_s,
101101
1.0 / B_t_s,
102+
out_dtype=config.out_dtype,
102103
)
103104

104105
fp8_triton_us = benchmark_cuda_function_in_microseconds(
105-
blockwise_fp8_gemm_1x128_128x128,
106+
triton_fp8_gemm_1x128_128x128,
106107
A_q,
107-
1.0 / A_s,
108108
B_t_q,
109+
1.0 / A_s,
109110
1.0 / B_t_s,
111+
out_dtype=config.out_dtype,
110112
)
111113

112114
# Warm up then run torch bench

benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
from triton.testing import do_bench
1616

1717
from torchao.prototype.blockwise_fp8_training.kernels import (
18-
blockwise_fp8_gemm_1x128_128x1,
1918
fp8_blockwise_act_quant_rhs,
2019
fp8_blockwise_act_quant_transposed_lhs,
20+
triton_fp8_gemm_1x128_128x1,
2121
)
2222

2323
device = torch.device("cuda")
@@ -58,7 +58,7 @@ def get_configs() -> List[ExperimentConfig]:
5858
(16640, 5120, 8192),
5959
(16640, 8192, 5120),
6060
]
61-
out_dtypes = [torch.float32, torch.bfloat16]
61+
out_dtypes = [torch.bfloat16]
6262
configs = []
6363
for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
6464
m, n, k = mnk
@@ -92,24 +92,23 @@ def warmup(func, *args, **kwargs):
9292

9393
# Warm up then run triton bench
9494
warmup(
95-
blockwise_fp8_gemm_1x128_128x1,
95+
triton_fp8_gemm_1x128_128x1,
9696
A_t_q,
97-
1.0 / A_t_s,
9897
B_q,
98+
1.0 / A_t_s,
9999
1.0 / B_s,
100+
out_dtype=config.out_dtype,
100101
)
101102

102103
fp8_triton_us = benchmark_cuda_function_in_microseconds(
103-
blockwise_fp8_gemm_1x128_128x1,
104+
triton_fp8_gemm_1x128_128x1,
104105
A_t_q,
105-
1.0 / A_t_s,
106106
B_q,
107+
1.0 / A_t_s,
107108
1.0 / B_s,
109+
out_dtype=config.out_dtype,
108110
)
109111

110-
# torch._scaled_mm requires A_s and B_t_s be in column-major format
111-
A_t_s = A_t_s.t().contiguous().t()
112-
113112
# Warm up then run torch bench
114113
warmup(
115114
torch._scaled_mm,
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py
7+
8+
import itertools
9+
from dataclasses import dataclass
10+
from typing import List
11+
12+
import torch
13+
from tabulate import tabulate
14+
from torch.nn import functional as F
15+
from tqdm import tqdm
16+
from triton.testing import do_bench
17+
18+
from benchmarks.utils import bench_fwd_bwd_microseconds
19+
from torchao.prototype.blockwise_fp8_training.linear import Float8BlockwiseLinear
20+
21+
device = torch.device("cuda")
22+
23+
# This benchmark requires CUDA 12.9+
24+
assert torch.version.cuda is not None, "CUDA is not available"
25+
cuda_major, cuda_minor = map(int, torch.version.cuda.split("."))
26+
assert cuda_major >= 12 and cuda_minor >= 9, "CUDA 12.9+ is required"
27+
28+
# Needed since changing args to function causes recompiles
29+
torch._dynamo.config.cache_size_limit = 1000
30+
31+
32+
@dataclass(frozen=True)
33+
class ExperimentConfig:
34+
out_dtype: torch.dtype
35+
m: int
36+
n: int
37+
k: int
38+
39+
40+
@dataclass(frozen=True)
41+
class ExperimentResult:
42+
bf16_linear_us: float
43+
fp8_triton_linear_us: float
44+
fp8_scaled_mm_linear_us: float
45+
46+
47+
@dataclass(frozen=True)
48+
class Experiment:
49+
config: ExperimentConfig
50+
result: ExperimentResult
51+
52+
53+
def get_configs() -> List[ExperimentConfig]:
54+
mnk_list = [
55+
# Llama4 shapes
56+
(16640, 5120, 8192),
57+
(16640, 8192, 5120),
58+
]
59+
out_dtypes = [torch.bfloat16]
60+
configs = []
61+
for mnk, out_dtype in itertools.product(mnk_list, out_dtypes):
62+
m, n, k = mnk
63+
configs.append(
64+
ExperimentConfig(
65+
out_dtype=out_dtype,
66+
m=m,
67+
n=n,
68+
k=k,
69+
)
70+
)
71+
return configs
72+
73+
74+
def run_experiment(config: ExperimentConfig) -> ExperimentResult:
75+
M, N, K = config.m, config.n, config.k
76+
inputs = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
77+
bf16_linear = torch.nn.Linear(K, N, dtype=config.out_dtype, device="cuda")
78+
fp8_triton_linear = Float8BlockwiseLinear(
79+
K, N, dtype=config.out_dtype, device="cuda", use_triton=True
80+
)
81+
fp8_scaled_mm_linear = Float8BlockwiseLinear(
82+
K, N, dtype=config.out_dtype, device="cuda", use_triton=False
83+
)
84+
85+
def warmup(func, *args, **kwargs):
86+
for _ in range(10):
87+
func(*args, **kwargs)
88+
89+
def fwd_bwd(func, inputs, labels, *args, **kwargs):
90+
out = func(inputs, *args, **kwargs)
91+
loss = F.mse_loss(out, labels)
92+
loss.backward()
93+
torch.cuda.synchronize()
94+
95+
# Warmup then run bf16 torch.mm
96+
labels = inputs.new_empty(M, N).fill_(1.0)
97+
warmup(fwd_bwd, bf16_linear, inputs, labels)
98+
99+
bf16_linear_us = benchmark_cuda_function_in_microseconds(
100+
fwd_bwd, bf16_linear, inputs, labels
101+
)
102+
103+
# Warm up then run triton bench
104+
warmup(
105+
fwd_bwd,
106+
fp8_triton_linear,
107+
inputs,
108+
labels,
109+
)
110+
111+
fp8_triton_linear_us = bench_fwd_bwd_microseconds(
112+
fp8_triton_linear,
113+
inputs,
114+
labels=labels,
115+
)
116+
117+
warmup(
118+
fwd_bwd,
119+
fp8_scaled_mm_linear,
120+
inputs,
121+
labels,
122+
)
123+
124+
fp8_scaled_mm_linear_us = bench_fwd_bwd_microseconds(
125+
fp8_scaled_mm_linear,
126+
inputs,
127+
labels=labels,
128+
)
129+
130+
return ExperimentResult(
131+
bf16_linear_us=bf16_linear_us,
132+
fp8_triton_linear_us=fp8_triton_linear_us,
133+
fp8_scaled_mm_linear_us=fp8_scaled_mm_linear_us,
134+
)
135+
136+
137+
def print_results(experiments: List[Experiment]):
138+
headers = [
139+
"M",
140+
"N",
141+
"K",
142+
"out_dtype",
143+
"bf16_mm_linear_us",
144+
"fp8_triton_linear_us",
145+
"fp8_scaled_mm_linear_us",
146+
]
147+
rows = []
148+
for experiment in experiments:
149+
m, n, k = experiment.config.m, experiment.config.n, experiment.config.k
150+
rows.append(
151+
[
152+
m,
153+
n,
154+
k,
155+
experiment.config.out_dtype,
156+
experiment.result.bf16_linear_us,
157+
experiment.result.fp8_triton_linear_us,
158+
experiment.result.fp8_scaled_mm_linear_us,
159+
]
160+
)
161+
print(tabulate(rows, headers=headers))
162+
163+
164+
def benchmark_cuda_function_in_microseconds(f, *args, **kwargs):
165+
return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3
166+
167+
168+
def main():
169+
torch.random.manual_seed(123)
170+
configs = get_configs()
171+
results = []
172+
for config in tqdm(configs):
173+
result = run_experiment(config)
174+
results.append(Experiment(config=config, result=result))
175+
176+
# Use Tabulate to print results
177+
print_results(results)
178+
179+
180+
if __name__ == "__main__":
181+
main()

benchmarks/prototype/moe_training/benchmark_moe_fsdp.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,7 @@
2222
from torch.distributed._composable.fsdp import fully_shard
2323
from torch.nn import functional as F
2424

25-
from benchmarks.prototype.moe_training.utils import (
26-
bench_fwd_bwd_microseconds,
27-
profile_fwd_bwd,
28-
)
25+
from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
2926

3027
# this feature requires CUDA and SM89+
3128
if not torch.cuda.is_available() or torch.cuda.get_device_capability() < (8, 9):

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@
1212
import torch
1313
from tabulate import tabulate
1414
from tqdm import tqdm
15-
from utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
1615

16+
from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
1717
from torchao.prototype.moe_training import _scaled_grouped_mm
1818
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
1919
from torchao.prototype.moe_training.utils import generate_jagged_offs
File renamed without changes.

0 commit comments

Comments
 (0)