|
16 | 16 | import seaborn as sns |
17 | 17 | import torch |
18 | 18 |
|
try:
    from accelerators.pytorch.lib.utils.torch_profiler import profiler_or_nullcontext
except ImportError:
    import warnings
    from contextlib import nullcontext

    class profiler_or_nullcontext(nullcontext):
        """No-op stand-in used when the profiler helper cannot be imported.

        Accepts and discards any positional/keyword arguments so call sites
        such as ``profiler_or_nullcontext(enabled=trace, with_stack=True)``
        keep working; entering the context yields ``None`` and does nothing.

        If the caller passes a truthy ``enabled`` keyword, a ``RuntimeWarning``
        is emitted so a requested trace is not dropped silently.
        """

        def __init__(self, *args, **kwargs):
            # nullcontext only takes an optional enter_result, so the profiler
            # options are intentionally not forwarded.
            super().__init__()
            # NOTE(review): only the keyword form of `enabled` is inspected,
            # since that is how the visible call sites pass it; the real
            # helper's positional signature is not visible here.
            if kwargs.get("enabled"):
                warnings.warn(
                    "Tracing was requested but the torch profiler helper is "
                    "unavailable; running without producing a trace.",
                    RuntimeWarning,
                    stacklevel=2,
                )
| 27 | + |
| 28 | + |
19 | 29 | from .quantize_ops import get_quantize_ops, QuantizeOpBase |
20 | 30 |
|
21 | 31 |
|
@@ -96,6 +106,7 @@ def benchmark_grouped( |
96 | 106 | bench_quantize: bool = False, |
97 | 107 | use_rotating_buffer_bench: bool = False, |
98 | 108 | use_cuda_graph: bool = True, |
| 109 | + trace: bool = False, |
99 | 110 | ) -> Dict[str, Any]: |
100 | 111 | num_groups = len(m) |
101 | 112 | # Create input tensors. |
@@ -143,19 +154,21 @@ def benchmark_grouped( |
143 | 154 | # Now perform benchmark. |
144 | 155 | if bench_quantize: |
145 | 156 | # Benchmark both quantize and compute. |
146 | | - ms_runtime = quantize_op.benchmark( |
147 | | - *preprocessed_args, |
148 | | - bench_quantize=True, |
149 | | - use_rotating_buffer_bench=use_rotating_buffer_bench, |
150 | | - use_cuda_graph=use_cuda_graph, |
151 | | - ) |
| 157 | + with profiler_or_nullcontext(enabled=trace, with_stack=True): |
| 158 | + ms_runtime = quantize_op.benchmark( |
| 159 | + *preprocessed_args, |
| 160 | + bench_quantize=True, |
| 161 | + use_rotating_buffer_bench=use_rotating_buffer_bench, |
| 162 | + use_cuda_graph=use_cuda_graph, |
| 163 | + ) |
152 | 164 | else: |
153 | | - ms_runtime = quantize_op.benchmark( |
154 | | - *quantized_vals, |
155 | | - bench_quantize=False, |
156 | | - use_rotating_buffer_bench=use_rotating_buffer_bench, |
157 | | - use_cuda_graph=use_cuda_graph, |
158 | | - ) |
| 165 | + with profiler_or_nullcontext(enabled=trace, with_stack=True): |
| 166 | + ms_runtime = quantize_op.benchmark( |
| 167 | + *quantized_vals, |
| 168 | + bench_quantize=False, |
| 169 | + use_rotating_buffer_bench=use_rotating_buffer_bench, |
| 170 | + use_cuda_graph=use_cuda_graph, |
| 171 | + ) |
159 | 172 |
|
160 | 173 | # Print out results for this op. |
161 | 174 | tflops = 0 |
@@ -197,6 +210,7 @@ def benchmark( |
197 | 210 | bench_quantize: bool = False, |
198 | 211 | use_rotating_buffer_bench: bool = False, |
199 | 212 | use_cuda_graph: bool = True, |
| 213 | + trace: bool = False, |
200 | 214 | ) -> Dict[str, Any]: |
201 | 215 | # Create input tensors. |
202 | 216 | if b > 1: |
@@ -230,19 +244,21 @@ def benchmark( |
230 | 244 | # Now perform benchmark. |
231 | 245 | if bench_quantize: |
232 | 246 | # Benchmark both quantize and compute. |
233 | | - ms_runtime = quantize_op.benchmark( |
234 | | - *preprocessed_args, |
235 | | - bench_quantize=True, |
236 | | - use_rotating_buffer_bench=use_rotating_buffer_bench, |
237 | | - use_cuda_graph=use_cuda_graph, |
238 | | - ) |
| 247 | + with profiler_or_nullcontext(enabled=trace, with_stack=True): |
| 248 | + ms_runtime = quantize_op.benchmark( |
| 249 | + *preprocessed_args, |
| 250 | + bench_quantize=True, |
| 251 | + use_rotating_buffer_bench=use_rotating_buffer_bench, |
| 252 | + use_cuda_graph=use_cuda_graph, |
| 253 | + ) |
239 | 254 | else: |
240 | | - ms_runtime = quantize_op.benchmark( |
241 | | - *quantized_vals, |
242 | | - bench_quantize=False, |
243 | | - use_rotating_buffer_bench=use_rotating_buffer_bench, |
244 | | - use_cuda_graph=use_cuda_graph, |
245 | | - ) |
| 255 | + with profiler_or_nullcontext(enabled=trace, with_stack=True): |
| 256 | + ms_runtime = quantize_op.benchmark( |
| 257 | + *quantized_vals, |
| 258 | + bench_quantize=False, |
| 259 | + use_rotating_buffer_bench=use_rotating_buffer_bench, |
| 260 | + use_cuda_graph=use_cuda_graph, |
| 261 | + ) |
246 | 262 |
|
247 | 263 | # Print out results for this op. |
248 | 264 | tflops = 2 * b * m * n * k / (ms_runtime / 1e3) / 1e12 |
@@ -370,6 +386,7 @@ def main(args: Any): |
370 | 386 | args.bench_quantize, |
371 | 387 | args.use_rotating_buffer_bench, |
372 | 388 | not args.no_cuda_graph, |
| 389 | + args.trace, |
373 | 390 | ) |
374 | 391 | benchmark_results.append(quantize_measurements) |
375 | 392 | if args.export_csv: |
@@ -460,6 +477,12 @@ def invoke_main() -> None: |
460 | 477 | action="store_true", |
461 | 478 | help="If set, benchmark using fixed shapes relevant to ldm workloads.", |
462 | 479 | ) |
| 480 | + parser.add_argument( |
| 481 | + "--trace", |
| 482 | + default=False, |
| 483 | + action="store_true", |
| 484 | + help="If set, produce a performance trace of the benchmark.", |
| 485 | + ) |
463 | 486 |
|
464 | 487 | args = parser.parse_args() |
465 | 488 | main(args) |
0 commit comments