Add tracing option to quantize bench (#3651)
Summary:

X-link: facebookresearch/FBGEMM#727

Adds support for the --trace option, which produces a GPU trace for each benchmarked operator. Tracing only works internally; in OSS the profiler import fails and we fall back to a nullcontext, so --trace is a no-op there (an OSS-side sketch follows below).

Reviewed By: jiawenliu64

Differential Revision: D68980020
jwfromm authored and facebook-github-bot committed Feb 3, 2025
1 parent 146442e commit ff0150b
Showing 2 changed files with 60 additions and 31 deletions.
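
Note on the OSS fallback mentioned in the summary: the profiler_or_nullcontext helper imported below lives in an internal package, so in open source --trace does nothing. As a rough, unofficial sketch (not part of this change), a comparable GPU trace could be captured in OSS with the public torch.profiler API; the wrapper name and output path are illustrative only:

from torch.profiler import ProfilerActivity, profile

def run_with_trace(benchmark_fn, out_path="quantize_bench_trace.json"):
    # Record CPU and CUDA activity with Python stacks, mirroring with_stack=True.
    with profile(
        activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
        with_stack=True,
    ) as prof:
        benchmark_fn()
    # Export a Chrome-trace-format file viewable in Perfetto or chrome://tracing.
    prof.export_chrome_trace(out_path)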
71 changes: 47 additions & 24 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -16,6 +16,16 @@
 import seaborn as sns
 import torch
 
+try:
+    from accelerators.pytorch.lib.utils.torch_profiler import profiler_or_nullcontext
+except ImportError:
+    from contextlib import nullcontext
+
+    class profiler_or_nullcontext(nullcontext):
+        def __init__(self, *args, **kwargs):
+            super().__init__()
+
+
 from .quantize_ops import get_quantize_ops, QuantizeOpBase
 
 
@@ -96,6 +106,7 @@ def benchmark_grouped(
     bench_quantize: bool = False,
     use_rotating_buffer_bench: bool = False,
     use_cuda_graph: bool = True,
+    trace: bool = False,
 ) -> Dict[str, Any]:
     num_groups = len(m)
     # Create input tensors.
@@ -143,19 +154,21 @@
         # Now perform benchmark.
         if bench_quantize:
             # Benchmark both quantize and compute.
-            ms_runtime = quantize_op.benchmark(
-                *preprocessed_args,
-                bench_quantize=True,
-                use_rotating_buffer_bench=use_rotating_buffer_bench,
-                use_cuda_graph=use_cuda_graph,
-            )
+            with profiler_or_nullcontext(enabled=trace, with_stack=True):
+                ms_runtime = quantize_op.benchmark(
+                    *preprocessed_args,
+                    bench_quantize=True,
+                    use_rotating_buffer_bench=use_rotating_buffer_bench,
+                    use_cuda_graph=use_cuda_graph,
+                )
         else:
-            ms_runtime = quantize_op.benchmark(
-                *quantized_vals,
-                bench_quantize=False,
-                use_rotating_buffer_bench=use_rotating_buffer_bench,
-                use_cuda_graph=use_cuda_graph,
-            )
+            with profiler_or_nullcontext(enabled=trace, with_stack=True):
+                ms_runtime = quantize_op.benchmark(
+                    *quantized_vals,
+                    bench_quantize=False,
+                    use_rotating_buffer_bench=use_rotating_buffer_bench,
+                    use_cuda_graph=use_cuda_graph,
+                )
 
         # Print out results for this op.
         tflops = 0
@@ -197,6 +210,7 @@ def benchmark(
     bench_quantize: bool = False,
     use_rotating_buffer_bench: bool = False,
     use_cuda_graph: bool = True,
+    trace: bool = False,
 ) -> Dict[str, Any]:
     # Create input tensors.
     if b > 1:
@@ -230,19 +244,21 @@
         # Now perform benchmark.
         if bench_quantize:
             # Benchmark both quantize and compute.
-            ms_runtime = quantize_op.benchmark(
-                *preprocessed_args,
-                bench_quantize=True,
-                use_rotating_buffer_bench=use_rotating_buffer_bench,
-                use_cuda_graph=use_cuda_graph,
-            )
+            with profiler_or_nullcontext(enabled=trace, with_stack=True):
+                ms_runtime = quantize_op.benchmark(
+                    *preprocessed_args,
+                    bench_quantize=True,
+                    use_rotating_buffer_bench=use_rotating_buffer_bench,
+                    use_cuda_graph=use_cuda_graph,
+                )
         else:
-            ms_runtime = quantize_op.benchmark(
-                *quantized_vals,
-                bench_quantize=False,
-                use_rotating_buffer_bench=use_rotating_buffer_bench,
-                use_cuda_graph=use_cuda_graph,
-            )
+            with profiler_or_nullcontext(enabled=trace, with_stack=True):
+                ms_runtime = quantize_op.benchmark(
+                    *quantized_vals,
+                    bench_quantize=False,
+                    use_rotating_buffer_bench=use_rotating_buffer_bench,
+                    use_cuda_graph=use_cuda_graph,
+                )
 
         # Print out results for this op.
         tflops = 2 * b * m * n * k / (ms_runtime / 1e3) / 1e12
@@ -370,6 +386,7 @@ def main(args: Any):
             args.bench_quantize,
             args.use_rotating_buffer_bench,
             not args.no_cuda_graph,
+            args.trace,
         )
         benchmark_results.append(quantize_measurements)
     if args.export_csv:
@@ -460,6 +477,12 @@ def invoke_main() -> None:
         action="store_true",
         help="If set, benchmark using fixed shapes relevant to ldm workloads.",
     )
+    parser.add_argument(
+        "--trace",
+        default=False,
+        action="store_true",
+        help="If set, produce a performance trace of the benchmark.",
+    )
 
     args = parser.parse_args()
     main(args)
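
A note on the fallback defined at the top of this file: when the internal profiler import fails, profiler_or_nullcontext subclasses contextlib.nullcontext but still accepts and discards the keyword arguments, so the call sites above can pass enabled= and with_stack= unconditionally instead of branching on availability. A minimal sketch of that behavior, assuming the internal import has failed (the with-body is a placeholder):

from contextlib import nullcontext

class profiler_or_nullcontext(nullcontext):
    # Accept the same keyword arguments as the internal profiler wrapper,
    # then ignore them so the context manager is a pure no-op.
    def __init__(self, *args, **kwargs):
        super().__init__()

with profiler_or_nullcontext(enabled=True, with_stack=True):
    pass  # benchmark body runs normally; no trace is produced in OSS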
20 changes: 13 additions & 7 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -411,20 +411,26 @@ class FP8RowwiseGemm(QuantizeOpBase):
     FP8 matmul with rowwise scaling.
     """
 
-    def quantize(self, x, w):
+    def preprocess(self, x, w):
+        # Prequantize weights.
+        if isinstance(w, (list, tuple)):
+            wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
+        else:
+            wq, w_scale = quantize_fp8_row(w)
+        if wq.dim() == 3:
+            w_scale = w_scale.view(wq.size(0), -1)
+        return x, wq, w_scale
+
+    def quantize(self, x, wq, w_scale):
         # Quantize both input tensors.
         # Handle both grouped and standard gemm.
         if isinstance(x, (list, tuple)):
             xq, x_scale = zip(*[quantize_fp8_row(i) for i in x])
-            wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
         else:
             xq, x_scale = quantize_fp8_row(x)
-            wq, w_scale = quantize_fp8_row(w)
         # Set proper batch dimension shapes.
         if xq.dim() == 3:
             x_scale = x_scale.view(xq.size(0), -1)
-        if wq.dim() == 3:
-            w_scale = w_scale.view(wq.size(0), -1)
         return xq, wq, x_scale, w_scale
 
     def compute(self, xq, wq, x_scale, w_scale):
@@ -451,8 +457,8 @@ def compute(self, xq, wq, x_scale, w_scale):
         # Otherwise return normal gemm result.
         return torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale)
 
-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale = self.quantize(x, w)
+    def quantize_and_compute(self, x, wq, w_scale):
+        xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
         return self.compute(xq, wq, x_scale, w_scale)
 
     @property
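
The quantize_ops.py change moves weight quantization into a separate preprocess step: weights are quantized once up front, and quantize then only handles the activations before compute runs the rowwise FP8 GEMM. A rough sketch of driving the new three-step flow (the harness and variable names are illustrative; quantize_op stands in for an FP8RowwiseGemm instance):

def run_once(quantize_op, x, w):
    # One-time weight quantization; intended to happen outside the timed/traced region.
    x, wq, w_scale = quantize_op.preprocess(x, w)
    # Per-iteration work: activation quantization followed by the FP8 GEMM.
    xq, wq, x_scale, w_scale = quantize_op.quantize(x, wq, w_scale)
    return quantize_op.compute(xq, wq, x_scale, w_scale)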
