|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py |
| 7 | + |
| 8 | +import itertools |
| 9 | +import time |
| 10 | +from dataclasses import dataclass |
| 11 | +from typing import List |
| 12 | + |
| 13 | +import torch |
| 14 | +from tabulate import tabulate |
| 15 | +from tqdm import tqdm |
| 16 | + |
| 17 | +from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import ( |
| 18 | + triton_fp8_col_major_jagged_colwise_scales, |
| 19 | + triton_fp8_row_major_jagged_rowwise_scales, |
| 20 | +) |
| 21 | +from torchao.prototype.scaled_grouped_mm.scaled_grouped_mm import ( |
| 22 | + _to_2d_jagged_float8_tensor_colwise, |
| 23 | + _to_2d_jagged_float8_tensor_rowwise, |
| 24 | +) |
| 25 | + |
| 26 | +device = torch.device("cuda") |
| 27 | + |
| 28 | +# Needed since changing args to function causes recompiles |
| 29 | +torch._dynamo.config.cache_size_limit = 1000 |
| 30 | + |
| 31 | + |
| 32 | +@dataclass(frozen=True) |
| 33 | +class ExperimentConfig: |
| 34 | + high_precision_dtype: torch.dtype |
| 35 | + input_shape: tuple[int] |
| 36 | + n_groups: int |
| 37 | + |
| 38 | + |
| 39 | +@dataclass(frozen=True) |
| 40 | +class ExperimentResult: |
| 41 | + torch_time_us: float |
| 42 | + triton_time_us: float |
| 43 | + |
| 44 | + |
| 45 | +@dataclass(frozen=True) |
| 46 | +class Experiment: |
| 47 | + config: ExperimentConfig |
| 48 | + result: ExperimentResult |
| 49 | + |
| 50 | + |
| 51 | +def get_configs() -> List[ExperimentConfig]: |
| 52 | + input_shapes = [(2**8, 4096), (2**12, 4096), (2**16, 4096)] |
| 53 | + n_groups_list = [4, 8, 16] |
| 54 | + high_precision_dtypes = [torch.bfloat16] |
| 55 | + configs = [] |
| 56 | + for input_shape, n_groups, high_precision_dtype in itertools.product( |
| 57 | + input_shapes, n_groups_list, high_precision_dtypes |
| 58 | + ): |
| 59 | + configs.append( |
| 60 | + ExperimentConfig( |
| 61 | + input_shape=input_shape, |
| 62 | + n_groups=n_groups, |
| 63 | + high_precision_dtype=high_precision_dtype, |
| 64 | + ) |
| 65 | + ) |
| 66 | + return configs |
| 67 | + |
| 68 | + |
| 69 | +def run_experiment(config: ExperimentConfig) -> ExperimentResult: |
| 70 | + # define test inputs |
| 71 | + input_tensor = torch.randn( |
| 72 | + *config.input_shape, |
| 73 | + dtype=config.high_precision_dtype, |
| 74 | + device=device, |
| 75 | + ) |
| 76 | + input_row_major = input_tensor.clone().detach() |
| 77 | + input_col_major = input_tensor.clone().detach().t() |
| 78 | + |
| 79 | + # - configure input to be row-major with groups divided along the column dimension, |
| 80 | + # representing the left operand of grad_weight = grad_output_t @ input |
| 81 | + # that occurs in the backward pass of the differentiable scaled grouped mm. |
| 82 | + # - the transposed tensor in col-major format with groups along the row dimension, |
| 83 | + # which represents the right operand. |
| 84 | + group_size = input_row_major.shape[1] // config.n_groups |
| 85 | + n_groups = config.n_groups |
| 86 | + offs = torch.arange( |
| 87 | + group_size, |
| 88 | + group_size * n_groups + 1, |
| 89 | + group_size, |
| 90 | + device=device, |
| 91 | + dtype=torch.int32, |
| 92 | + ) |
| 93 | + |
| 94 | + def warmup(func, *args, **kwargs): |
| 95 | + for _ in range(10): |
| 96 | + func(*args, **kwargs) |
| 97 | + |
| 98 | + def run_torch( |
| 99 | + input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor |
| 100 | + ): |
| 101 | + _ = _to_2d_jagged_float8_tensor_rowwise( |
| 102 | + input_row_major, |
| 103 | + offs, |
| 104 | + target_dtype=torch.float8_e4m3fn, |
| 105 | + round_scales_to_power_of_2=True, |
| 106 | + ) |
| 107 | + _ = _to_2d_jagged_float8_tensor_colwise( |
| 108 | + input_col_major, |
| 109 | + offs, |
| 110 | + target_dtype=torch.float8_e4m3fn, |
| 111 | + round_scales_to_power_of_2=True, |
| 112 | + ) |
| 113 | + |
| 114 | + def run_triton( |
| 115 | + input_row_major: torch.Tensor, input_col_major: torch.Tensor, offs: torch.Tensor |
| 116 | + ): |
| 117 | + _ = triton_fp8_row_major_jagged_rowwise_scales( |
| 118 | + input_row_major, |
| 119 | + offs, |
| 120 | + output_dtype=torch.float8_e4m3fn, |
| 121 | + round_scales_to_power_of_2=True, |
| 122 | + ) |
| 123 | + _ = triton_fp8_col_major_jagged_colwise_scales( |
| 124 | + input_col_major, |
| 125 | + offs, |
| 126 | + output_dtype=torch.float8_e4m3fn, |
| 127 | + round_scales_to_power_of_2=True, |
| 128 | + ) |
| 129 | + |
| 130 | + # bench torch |
| 131 | + compiled_run_torch = torch.compile(run_torch) |
| 132 | + warmup(compiled_run_torch, input_row_major, input_col_major, offs) |
| 133 | + start_time_ns = time.perf_counter_ns() |
| 134 | + compiled_run_torch(input_row_major, input_col_major, offs) |
| 135 | + torch_time_ns = time.perf_counter_ns() - start_time_ns |
| 136 | + torch_time_us = torch_time_ns / 1e3 |
| 137 | + |
| 138 | + # bench triton |
| 139 | + warmup(run_triton, input_row_major, input_col_major, offs) |
| 140 | + start_time_ns = time.perf_counter_ns() |
| 141 | + run_triton(input_row_major, input_col_major, offs) |
| 142 | + triton_time_ns = time.perf_counter_ns() - start_time_ns |
| 143 | + triton_time_us = triton_time_ns / 1e3 |
| 144 | + |
| 145 | + return ExperimentResult( |
| 146 | + torch_time_us=torch_time_us, |
| 147 | + triton_time_us=triton_time_us, |
| 148 | + ) |
| 149 | + |
| 150 | + |
| 151 | +def print_results(experiments: List[Experiment]): |
| 152 | + headers = [ |
| 153 | + "input_shape", |
| 154 | + "n_groups", |
| 155 | + "high_precision_dtype", |
| 156 | + "torch_time_us", |
| 157 | + "triton_time_us", |
| 158 | + ] |
| 159 | + rows = [] |
| 160 | + for experiment in experiments: |
| 161 | + input_shape = ( |
| 162 | + f"({experiment.config.input_shape[0]}, {experiment.config.input_shape[1]})" |
| 163 | + ) |
| 164 | + rows.append( |
| 165 | + [ |
| 166 | + input_shape, |
| 167 | + experiment.config.n_groups, |
| 168 | + experiment.config.high_precision_dtype, |
| 169 | + experiment.result.torch_time_us, |
| 170 | + experiment.result.triton_time_us, |
| 171 | + ] |
| 172 | + ) |
| 173 | + print(tabulate(rows, headers=headers)) |
| 174 | + |
| 175 | + |
| 176 | +def main(): |
| 177 | + torch.random.manual_seed(123) |
| 178 | + configs = get_configs() |
| 179 | + results = [] |
| 180 | + for config in tqdm(configs): |
| 181 | + result = run_experiment(config) |
| 182 | + results.append(Experiment(config=config, result=result)) |
| 183 | + |
| 184 | + # Use Tabulate to print results |
| 185 | + print_results(results) |
| 186 | + |
| 187 | + |
| 188 | +if __name__ == "__main__": |
| 189 | + main() |
0 commit comments