Skip to content

Commit 8812365

Browse files
authored
Fix autoquant tests failed due to changes to benchmark_gpu (#2818)
Skip test failing only in CI Summary: att Test Plan: CI Reviewers: Subscribers: Tasks: Tags:
1 parent 481a8ab commit 8812365

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

torchao/quantization/autoquant.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
TorchAOBaseTensor,
3535
is_sm_at_least_89,
3636
is_sm_at_least_90,
37+
torch_version_at_least,
3738
)
3839

3940
from .granularity import (
@@ -343,9 +344,18 @@ def do_autoquant_bench(op, *args, **kwargs):
343344
graph = torch.cuda.CUDAGraph()
344345
with torch.cuda.graph(graph, stream=stream):
345346
op(*args, **kwargs)
346-
res = benchmarker.benchmark_gpu(
347-
lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median"
348-
)
347+
# TODO: update to 2.8.0 after https://github.com/pytorch/ao/pull/2786 is landed
348+
if torch_version_at_least("2.9.0"):
349+
from statistics import median
350+
351+
res = benchmarker.benchmark_gpu(
352+
lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="all"
353+
)
354+
res = median(res)
355+
else:
356+
res = benchmarker.benchmark_gpu(
357+
lambda: graph.replay(), warmup=warmup, rep=rep, return_mode="median"
358+
)
349359
return res
350360

351361

0 commit comments

Comments
 (0)