diff --git a/torchao/prototype/mx_formats/utils.py b/torchao/prototype/mx_formats/utils.py index 67c914c124..05a681591e 100644 --- a/torchao/prototype/mx_formats/utils.py +++ b/torchao/prototype/mx_formats/utils.py @@ -35,13 +35,14 @@ def to_blocked(input_matrix) -> Tensor: padded_cols = n_col_blocks * 4 padded = input_matrix - # if (rows, cols) != (padded_rows, padded_cols): - padded = torch.zeros( - (padded_rows, padded_cols), - device=input_matrix.device, - dtype=input_matrix.dtype, - ) - padded[:rows, :cols] = input_matrix + # TODO This is to work around VLLM's usage of compile w/ dynamic shapes + if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), + device=input_matrix.device, + dtype=input_matrix.dtype, + ) + padded[:rows, :cols] = input_matrix # Rearrange the blocks blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)