Skip to content

Commit 00c43b4

Browse files
Josh Frommfacebook-github-bot
authored andcommitted
Add tracing option to quantize bench
Summary: Adds support for the --trace option which will produce gpu traces for each benchmarked operator. This only works internally so if tried in OSS we fall back to nullcontext. Differential Revision: D68980020
1 parent 23f9bf0 commit 00c43b4

File tree

2 files changed

+60
-31
lines changed

2 files changed

+60
-31
lines changed

fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,16 @@
1616
import seaborn as sns
1717
import torch
1818

19+
try:
20+
from accelerators.pytorch.lib.utils.torch_profiler import profiler_or_nullcontext
21+
except ImportError:
22+
from contextlib import nullcontext
23+
24+
class profiler_or_nullcontext(nullcontext):
25+
def __init__(self, *args, **kwargs):
26+
super().__init__()
27+
28+
1929
from .quantize_ops import get_quantize_ops, QuantizeOpBase
2030

2131

@@ -96,6 +106,7 @@ def benchmark_grouped(
96106
bench_quantize: bool = False,
97107
use_rotating_buffer_bench: bool = False,
98108
use_cuda_graph: bool = True,
109+
trace: bool = False,
99110
) -> Dict[str, Any]:
100111
num_groups = len(m)
101112
# Create input tensors.
@@ -143,19 +154,21 @@ def benchmark_grouped(
143154
# Now perform benchmark.
144155
if bench_quantize:
145156
# Benchmark both quantize and compute.
146-
ms_runtime = quantize_op.benchmark(
147-
*preprocessed_args,
148-
bench_quantize=True,
149-
use_rotating_buffer_bench=use_rotating_buffer_bench,
150-
use_cuda_graph=use_cuda_graph,
151-
)
157+
with profiler_or_nullcontext(enabled=trace, with_stack=True):
158+
ms_runtime = quantize_op.benchmark(
159+
*preprocessed_args,
160+
bench_quantize=True,
161+
use_rotating_buffer_bench=use_rotating_buffer_bench,
162+
use_cuda_graph=use_cuda_graph,
163+
)
152164
else:
153-
ms_runtime = quantize_op.benchmark(
154-
*quantized_vals,
155-
bench_quantize=False,
156-
use_rotating_buffer_bench=use_rotating_buffer_bench,
157-
use_cuda_graph=use_cuda_graph,
158-
)
165+
with profiler_or_nullcontext(enabled=trace, with_stack=True):
166+
ms_runtime = quantize_op.benchmark(
167+
*quantized_vals,
168+
bench_quantize=False,
169+
use_rotating_buffer_bench=use_rotating_buffer_bench,
170+
use_cuda_graph=use_cuda_graph,
171+
)
159172

160173
# Print out results for this op.
161174
tflops = 0
@@ -197,6 +210,7 @@ def benchmark(
197210
bench_quantize: bool = False,
198211
use_rotating_buffer_bench: bool = False,
199212
use_cuda_graph: bool = True,
213+
trace: bool = False,
200214
) -> Dict[str, Any]:
201215
# Create input tensors.
202216
if b > 1:
@@ -230,19 +244,21 @@ def benchmark(
230244
# Now perform benchmark.
231245
if bench_quantize:
232246
# Benchmark both quantize and compute.
233-
ms_runtime = quantize_op.benchmark(
234-
*preprocessed_args,
235-
bench_quantize=True,
236-
use_rotating_buffer_bench=use_rotating_buffer_bench,
237-
use_cuda_graph=use_cuda_graph,
238-
)
247+
with profiler_or_nullcontext(enabled=trace, with_stack=True):
248+
ms_runtime = quantize_op.benchmark(
249+
*preprocessed_args,
250+
bench_quantize=True,
251+
use_rotating_buffer_bench=use_rotating_buffer_bench,
252+
use_cuda_graph=use_cuda_graph,
253+
)
239254
else:
240-
ms_runtime = quantize_op.benchmark(
241-
*quantized_vals,
242-
bench_quantize=False,
243-
use_rotating_buffer_bench=use_rotating_buffer_bench,
244-
use_cuda_graph=use_cuda_graph,
245-
)
255+
with profiler_or_nullcontext(enabled=trace, with_stack=True):
256+
ms_runtime = quantize_op.benchmark(
257+
*quantized_vals,
258+
bench_quantize=False,
259+
use_rotating_buffer_bench=use_rotating_buffer_bench,
260+
use_cuda_graph=use_cuda_graph,
261+
)
246262

247263
# Print out results for this op.
248264
tflops = 2 * b * m * n * k / (ms_runtime / 1e3) / 1e12
@@ -370,6 +386,7 @@ def main(args: Any):
370386
args.bench_quantize,
371387
args.use_rotating_buffer_bench,
372388
not args.no_cuda_graph,
389+
args.trace,
373390
)
374391
benchmark_results.append(quantize_measurements)
375392
if args.export_csv:
@@ -460,6 +477,12 @@ def invoke_main() -> None:
460477
action="store_true",
461478
help="If set, benchmark using fixed shapes relevant to ldm workloads.",
462479
)
480+
parser.add_argument(
481+
"--trace",
482+
default=False,
483+
action="store_true",
484+
help="If set, produce a performance trace of the benchmark.",
485+
)
463486

464487
args = parser.parse_args()
465488
main(args)

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -411,20 +411,26 @@ class FP8RowwiseGemm(QuantizeOpBase):
411411
FP8 matmul with rowwise scaling.
412412
"""
413413

414-
def quantize(self, x, w):
414+
def preprocess(self, x, w):
415+
# Prequantize weights.
416+
if isinstance(w, (list, tuple)):
417+
wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
418+
else:
419+
wq, w_scale = quantize_fp8_row(w)
420+
if wq.dim() == 3:
421+
w_scale = w_scale.view(wq.size(0), -1)
422+
return x, wq, w_scale
423+
424+
def quantize(self, x, wq, w_scale):
415425
# Quantize both input tensors.
416426
# Handle both grouped and standard gemm.
417427
if isinstance(x, (list, tuple)):
418428
xq, x_scale = zip(*[quantize_fp8_row(i) for i in x])
419-
wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
420429
else:
421430
xq, x_scale = quantize_fp8_row(x)
422-
wq, w_scale = quantize_fp8_row(w)
423431
# Set proper batch dimension shapes.
424432
if xq.dim() == 3:
425433
x_scale = x_scale.view(xq.size(0), -1)
426-
if wq.dim() == 3:
427-
w_scale = w_scale.view(wq.size(0), -1)
428434
return xq, wq, x_scale, w_scale
429435

430436
def compute(self, xq, wq, x_scale, w_scale):
@@ -451,8 +457,8 @@ def compute(self, xq, wq, x_scale, w_scale):
451457
# Otherwise return normal gemm result.
452458
return torch.ops.fbgemm.f8f8bf16_rowwise(xq, wq, x_scale, w_scale)
453459

454-
def quantize_and_compute(self, x, w):
455-
xq, wq, x_scale, w_scale = self.quantize(x, w)
460+
def quantize_and_compute(self, x, wq, w_scale):
461+
xq, wq, x_scale, w_scale = self.quantize(x, wq, w_scale)
456462
return self.compute(xq, wq, x_scale, w_scale)
457463

458464
@property

0 commit comments

Comments
 (0)