Skip to content

Enhance test_autoquant_compile to support ROCm #2100

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions test/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,9 @@ def _test_smooth_linear_impl(self, x_shape, lin_shape, device):
sqnr_fq = compute_error(y_smooth_fq_only, y_dynamic_q)
# print('sqnr_smooth', sqnr_smooth_fq, 'sqnr_dynamic', sqnr_dynamic_q, 'sqnr_fq', sqnr_fq)

assert torch.allclose(
y_ref, y_smooth_nocalib
), "y_ref not close to y_smooth_nocalib"
assert torch.allclose(y_ref, y_smooth_nocalib), (
"y_ref not close to y_smooth_nocalib"
)
# after https://github.com/pytorch-labs/ao_benchmarks/pull/32,
# numerics do not match exactly between production c++ code
# and this Python code
Expand Down Expand Up @@ -1338,9 +1338,9 @@ def forward(self, x):
model_qc = torch.compile(model, mode="max-autotune")
ref_q = model_qc(x).detach()

assert (
SQNR(ref_f, ref_q) > min_sqnr
), f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}"
assert SQNR(ref_f, ref_q) > min_sqnr, (
f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}"
)

# load model structure
with torch.device("meta"):
Expand All @@ -1360,9 +1360,9 @@ def forward(self, x):
model_qc = torch.compile(model, mode="max-autotune")
test = model_qc(x).detach()

assert (
SQNR(ref_f, test) > min_sqnr
), f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}"
assert SQNR(ref_f, test) > min_sqnr, (
f"got sqnr: {SQNR(ref_f, ref_q)}, expected: {min_sqnr}"
)
self.assertTrue(torch.equal(ref_q, test))

@parameterized.expand(COMMON_DEVICE_DTYPE)
Expand Down Expand Up @@ -1593,15 +1593,28 @@ def test_autoquant_one_input(self, device, dtype, m, k, n):
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant requires 2.5+.")
def test_autoquant_compile(self, device, dtype, m1, m2, k, n):
undo_recommended_configs()
if device != "cuda" or not torch.cuda.is_available():

# Check if we're running on a supported device
is_cuda_available = torch.cuda.is_available()
is_rocm_available = torch.version.hip is not None
is_supported_device = device == "cuda" and (
is_cuda_available or is_rocm_available
)

if not is_supported_device:
self.skipTest(f"autoquant currently does not support {device}")
if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
if dtype == torch.bfloat16:
self.skipTest("bfloat16 requires sm80+")
if m1 == 1 or m2 == 1:
self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")
# This test fails on v0.4.0 and torch 2.4, so skipping for now.
if m1 == 1 or m2 == 1 and not TORCH_VERSION_AT_LEAST_2_5:

# Check CUDA-specific requirements if running on CUDA
if is_cuda_available and not is_rocm_available:
device_capability = torch.cuda.get_device_capability()
if device_capability < (8, 0):
if dtype == torch.bfloat16:
self.skipTest("bfloat16 requires sm80+")
if m1 == 1 or m2 == 1:
self.skipTest(f"Shape {(m1, m2, k, n)} requires sm80+")

# Skip certain shapes on older PyTorch versions
if (m1 == 1 or m2 == 1) and not TORCH_VERSION_AT_LEAST_2_5:
self.skipTest(f"Shape {(m1, m2, k, n)} requires torch version > 2.4")
model = (
torch.nn.Sequential(
Expand Down
Loading