From e9119cb05fa3fc9e23a5ac705828a0fb74e4ba68 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Fri, 2 May 2025 14:02:37 -0700
Subject: [PATCH] VLLM Workaround

stack-info: PR: https://github.com/pytorch/ao/pull/2165, branch: drisspg/stack/52
---
 torchao/prototype/mx_formats/utils.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py
index 67c914c124..05a681591e 100644
--- a/torchao/prototype/mx_formats/utils.py
+++ b/torchao/prototype/mx_formats/utils.py
@@ -35,13 +35,14 @@ def to_blocked(input_matrix) -> Tensor:
     padded_cols = n_col_blocks * 4
 
     padded = input_matrix
-    # if (rows, cols) != (padded_rows, padded_cols):
-    padded = torch.zeros(
-        (padded_rows, padded_cols),
-        device=input_matrix.device,
-        dtype=input_matrix.dtype,
-    )
-    padded[:rows, :cols] = input_matrix
+    # TODO This is to work around VLLM's usage of compile w/ dynamic shapes
+    if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols):
+        padded = torch.zeros(
+            (padded_rows, padded_cols),
+            device=input_matrix.device,
+            dtype=input_matrix.dtype,
+        )
+        padded[:rows, :cols] = input_matrix
 
     # Rearrange the blocks
     blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
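
For reference, below is a minimal standalone sketch (not part of the patch) of the same workaround: under torch.compile the padding branch is taken unconditionally, so tracing never depends on the data-dependent shape comparison. It assumes PyTorch >= 2.3 for torch.compiler.is_compiling(); pad_to_blocks and ceil_div are hypothetical stand-ins for the padding step inside torchao's to_blocked, and the shapes used are illustrative only.

import torch


def ceil_div(a: int, b: int) -> int:
    # Integer ceiling division.
    return -(a // -b)


def pad_to_blocks(scales: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the padding step of torchao's to_blocked:
    # pad a scale matrix up to multiples of (128, 4).
    rows, cols = scales.shape
    padded_rows = ceil_div(rows, 128) * 128
    padded_cols = ceil_div(cols, 4) * 4

    padded = scales
    # When tracing under torch.compile, always take the padding path so the
    # branch does not depend on the (dynamic) input shape, mirroring the patch.
    if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols):
        padded = torch.zeros(
            (padded_rows, padded_cols), device=scales.device, dtype=scales.dtype
        )
        padded[:rows, :cols] = scales
    return padded


if __name__ == "__main__":
    eager_out = pad_to_blocks(torch.randn(100, 3))  # padding branch taken (shape mismatch)
    compiled = torch.compile(pad_to_blocks, dynamic=True)
    compiled_out = compiled(torch.randn(128, 4))  # padding branch forced while compiling
    print(eager_out.shape, compiled_out.shape)  # torch.Size([128, 4]) in both cases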