
Commit 7b05105

[scaled grouped mm] add triton kernels for float8 rowwise quantization with per-group/jagged scales (#2064)
1 parent 8c5eeac commit 7b05105
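
The new kernels quantize a 2D tensor to float8 with rowwise (or columnwise) scales computed per group, where group boundaries along one dimension are given by an offsets tensor ("jagged" scales). Below is a minimal pure-PyTorch sketch of what jagged rowwise scaling computes, assuming offs holds the end column index of each group; the function name, return layout, and the amax clamp epsilon are illustrative, not taken from the commit:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def jagged_rowwise_quant_reference(x: torch.Tensor, offs: torch.Tensor):
    # Quantize a row-major 2D tensor to float8, computing one scale per
    # (row, group) slice; offs holds the end column index of each group.
    out = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    scales = []
    start = 0
    for end in offs.tolist():
        block = x[:, start:end].to(torch.float32)
        # one scale per row of this group; clamp avoids division by zero
        amax = block.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
        scale = FP8_MAX / amax
        # mirrors round_scales_to_power_of_2=True: round scales down to a power of 2
        scale = torch.exp2(torch.floor(torch.log2(scale)))
        out[:, start:end] = (block * scale).clamp(-FP8_MAX, FP8_MAX).to(
            torch.float8_e4m3fn
        )
        scales.append(scale)
        start = end
    return out, torch.cat(scales, dim=1)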

File tree

7 files changed: +655 −16 lines changed


torchao/prototype/scaled_grouped_mm/kernels/__init__.py

Whitespace-only changes.
@@ -0,0 +1,189 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
# This benchmarking script is a modified version of the original script from:
# https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py

import itertools
import time
from dataclasses import dataclass
from typing import List

import torch
from tabulate import tabulate
from tqdm import tqdm

from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
    triton_fp8_col_major_jagged_colwise_scales,
    triton_fp8_row_major_jagged_rowwise_scales,
)
from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import (
    _to_2d_jagged_float8_tensor_colwise,
    _to_2d_jagged_float8_tensor_rowwise,
)

device = torch.device("cuda")

# Needed since changing args to the compiled function causes recompiles
torch._dynamo.config.cache_size_limit = 1000


@dataclass(frozen=True)
class ExperimentConfig:
    high_precision_dtype: torch.dtype
    input_shape: tuple[int, int]
    n_groups: int


@dataclass(frozen=True)
class ExperimentResult:
    torch_time_us: float
    triton_time_us: float


@dataclass(frozen=True)
class Experiment:
    config: ExperimentConfig
    result: ExperimentResult


def get_configs() -> List[ExperimentConfig]:
    input_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)]
    n_groups_list = [4, 8, 16]
    high_precision_dtypes = [torch.bfloat16]
    configs = []
    for input_shape, n_groups, high_precision_dtype in itertools.product(
        input_shapes, n_groups_list, high_precision_dtypes
    ):
        configs.append(
            ExperimentConfig(
                input_shape=input_shape,
                n_groups=n_groups,
                high_precision_dtype=high_precision_dtype,
            )
        )
    return configs
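

# Note: the itertools.product above yields 3 shapes x 3 group counts x 1 dtype = 9 configs.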


def run_experiment(config: ExperimentConfig) -> ExperimentResult:
    # define test inputs
    input_tensor = torch.randn(
        *config.input_shape,
        dtype=config.high_precision_dtype,
        device=device,
    )
    input_row_major = input_tensor.clone().detach()
    input_col_major = input_tensor.clone().detach().t()

    # - the input is row-major with groups divided along the column dimension,
    #   representing the left operand of grad_weight = grad_output_t @ input
    #   that occurs in the backward pass of the differentiable scaled grouped mm.
    # - the transposed tensor is col-major with groups divided along the row
    #   dimension, representing the right operand.
    group_size = input_row_major.shape[1] // config.n_groups
    n_groups = config.n_groups
    offs = torch.arange(
        group_size,
        group_size * n_groups + 1,
        group_size,
        device=device,
        dtype=torch.int32,
    )
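    # e.g. with input width 4096 and n_groups=4 (one of the benchmarked configs),
    # offs == [1024, 2048, 3072, 4096]: the end column index of each group.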

    def warmup(func, *args, **kwargs):
        for _ in range(10):
            func(*args, **kwargs)

    def run_torch(
        input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
    ):
        _ = _to_2d_jagged_float8_tensor_rowwise(
            input_row_major,
            offs,
            target_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )
        _ = _to_2d_jagged_float8_tensor_colwise(
            input_col_major,
            offs,
            target_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )

    def run_triton(
        input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor
    ):
        _ = triton_fp8_row_major_jagged_rowwise_scales(
            input_row_major,
            offs,
            output_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )
        _ = triton_fp8_col_major_jagged_colwise_scales(
            input_col_major,
            offs,
            output_dtype=torch.float8_e4m3fn,
            round_scales_to_power_of_2=True,
        )
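    # Both paths are warmed up before timing; the torch path is additionally
    # wrapped in torch.compile, while the triton path launches the handwritten
    # kernels directly.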
    # bench torch
    compiled_run_torch = torch.compile(run_torch)
    warmup(compiled_run_torch, input_row_major, input_col_major, offs)
    torch.cuda.synchronize()  # ensure warmup kernels finish before timing
    start_time_ns = time.perf_counter_ns()
    compiled_run_torch(input_row_major, input_col_major, offs)
    torch.cuda.synchronize()  # kernels launch asynchronously; wait for completion
    torch_time_ns = time.perf_counter_ns() - start_time_ns
    torch_time_us = torch_time_ns / 1e3

    # bench triton
    warmup(run_triton, input_row_major, input_col_major, offs)
    torch.cuda.synchronize()
    start_time_ns = time.perf_counter_ns()
    run_triton(input_row_major, input_col_major, offs)
    torch.cuda.synchronize()
    triton_time_ns = time.perf_counter_ns() - start_time_ns
    triton_time_us = triton_time_ns / 1e3

    return ExperimentResult(
        torch_time_us=torch_time_us,
        triton_time_us=triton_time_us,
    )


def print_results(experiments: List[Experiment]):
    headers = [
        "input_shape",
        "n_groups",
        "high_precision_dtype",
        "torch_time_us",
        "triton_time_us",
    ]
    rows = []
    for experiment in experiments:
        input_shape = (
            f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})"
        )
        rows.append(
            [
                input_shape,
                experiment.config.n_groups,
                experiment.config.high_precision_dtype,
                experiment.result.torch_time_us,
                experiment.result.triton_time_us,
            ]
        )
    print(tabulate(rows, headers=headers))


def main():
    torch.random.manual_seed(123)
    configs = get_configs()
    results = []
    for config in tqdm(configs):
        result = run_experiment(config)
        results.append(Experiment(config=config, result=result))

    # Use tabulate to print results
    print_results(results)


if __name__ == "__main__":
    main()
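
Running the script on a CUDA machine prints a tabulate table comparing the compiled torch path against the triton kernels for each config. For standalone experimentation, the new kernels can also be invoked directly; a minimal sketch, mirroring the call signature used in the benchmark above (assumes a CUDA device with float8 support):

import torch
from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
    triton_fp8_row_major_jagged_rowwise_scales,
)

x = torch.randn(2**8, 4096, dtype=torch.bfloat16, device="cuda")
# end column index of each of the 4 groups
offs = torch.arange(1024, 4096 + 1, 1024, device="cuda", dtype=torch.int32)
result = triton_fp8_row_major_jagged_rowwise_scales(
    x,
    offs,
    output_dtype=torch.float8_e4m3fn,
    round_scales_to_power_of_2=True,
)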
