Commit fce1e18
Enabling MoE Quantization using linear decomposition [WIP]

Summary: This PR is a first step toward optimizing MoE inference using torchAO. The goal for this step is to enable existing quantization kernels and workflows to work for MoE quantization by decomposing the grouped gemm into a sequence of unbalanced linear ops that can use the existing quantized kernels. To enable this, we had to add support for quantizing these 3D tensors, as well as for slicing and indexing them. The current tests run locally and will be added to the repo once they are working. Currently int8wo and int8dq work for both multi-token and single-token MoE inference, while int4wo is being finished up. TODO: move the test set into ao, move the quantizable MoE module code into ao, test on a HF model definition.
1 parent a81322e commit fce1e18
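To make the decomposition concrete, here is a minimal sketch (not code from this commit; all names are illustrative, not torchao APIs): the grouped gemm over experts becomes a loop of ordinary, unbalanced F.linear calls, one per expert, each of which can use the existing quantized linear kernels.

import torch
import torch.nn.functional as F

def moe_mm_decomposed(x, w, expert_indices):
    # x: (num_tokens, hidden_dim) activations
    # w: (num_experts, out_dim, hidden_dim) stacked expert weights; may be
    #    a quantized tensor subclass once 3D quantization + indexing works
    # expert_indices: (num_tokens,) expert chosen by the router per token
    out = x.new_zeros(x.shape[0], w.shape[1])
    for e in range(w.shape[0]):
        mask = expert_indices == e
        if mask.any():
            # w[e] uses the select path added below; each expert sees a
            # different (unbalanced) number of tokens
            out[mask] = F.linear(x[mask], w[e])
    return out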

File tree

6 files changed: +133 -39 lines changed

torchao/dtypes/affine_quantized_tensor_ops.py (+40)

@@ -477,6 +477,46 @@ def _(func, types, args, kwargs):
     )
     return return_and_correct_aliasing(func, args, kwargs, new)
 
+@implements(aten.index.Tensor)
+def _(func, types, args, kwargs):
+    self, indices = args
+    assert len(indices) == 1, f"op {func} currently only implemented for single dimensional indexing but got indices: {indices}"
+
+    new_tensor_impl = aten.index.Tensor(self.tensor_impl, indices)
+    shape = tuple([indices[0].numel(), *self.shape[1:]])
+
+    block_size = self.block_size
+    new = self.__class__(
+        new_tensor_impl,
+        block_size,
+        shape,
+        self.quant_min,
+        self.quant_max,
+        self.zero_point_domain,
+        dtype=self.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+@implements(aten.select.int)
+def _(func, types, args, kwargs):
+    self, dim, index = fill_defaults(args, 3, [0, 0])
+    assert dim == 0, f"op {func} currently only implemented for dim=0 but got dim={dim}"
+    assert self.dim() == 3, f"op {func} currently only implemented for 3 dimensional tensors but got shape={self.shape}"
+
+    new_tensor_impl = aten.select.int(self.tensor_impl, dim, index)
+
+    shape = self.shape[1:]
+    block_size = self.block_size[1:]
+    new = self.__class__(
+        new_tensor_impl,
+        block_size,
+        shape,
+        self.quant_min,
+        self.quant_max,
+        self.zero_point_domain,
+        dtype=self.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
 
 # this is needed for DTensor.from_local() and for flattening tensor
 @implements(aten.view.default)
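These two overloads let routing code index a stacked quantized weight of shape (num_experts, n, k) directly. A plain-tensor stand-in showing the indexing semantics the overloads mirror (shapes are illustrative):

import torch

# stand-in for a stacked expert weight: (num_experts, n, k); the new
# overloads give an AffineQuantizedTensor the same indexing semantics
w = torch.randn(8, 256, 128)

# aten.select.int with dim=0: the leading expert dim is dropped from the
# shape (and, for the subclass, from block_size)
w0 = w[0]
assert w0.shape == (256, 128)

# aten.index.Tensor with a single index tensor: the leading dim becomes
# indices.numel(), e.g. the two experts the router picked for a token
picked = torch.tensor([3, 1])
w_pair = w[picked]
assert w_pair.shape == (2, 256, 128)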

torchao/dtypes/uintx/plain_layout.py (+11)

@@ -154,6 +154,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             )
             return return_and_correct_aliasing(func, args, kwargs, new)
 
+
+        elif func in [aten.select.int, aten.index.Tensor]:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0]._apply_fn_to_data(
+                    lambda x: func(x, *args[1:], **kwargs)
+                ),
+            )
+
         elif func is aten.slice.Tensor:
             self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
             if dim == 0:
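For the plain layout, int_data, scale, and zero_point all carry the expert dim in front, so select and index can be forwarded wholesale to each constituent tensor via _apply_fn_to_data. A toy stand-in for that pattern (PlainData is illustrative, not torchao code):

import torch

class PlainData:
    # toy stand-in for a plain layout impl holding int_data/scale/zero_point
    def __init__(self, int_data, scale, zero_point):
        self.int_data, self.scale, self.zero_point = int_data, scale, zero_point

    def _apply_fn_to_data(self, fn):
        # rebuild the impl with fn applied to each constituent tensor
        return PlainData(fn(self.int_data), fn(self.scale), fn(self.zero_point))

impl = PlainData(
    torch.zeros(8, 64, 32, dtype=torch.int8),  # (experts, n, k) int weights
    torch.ones(8, 64, 1),                      # per-row scales
    torch.zeros(8, 64, 1),                     # per-row zero points
)
# select expert 3 on all three tensors at once, as the new branch does
sub = impl._apply_fn_to_data(lambda x: torch.ops.aten.select.int(x, 0, 3))
assert sub.int_data.shape == (64, 32)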

torchao/dtypes/uintx/tensor_core_tiled_layout.py (+71 -31)

@@ -75,7 +75,6 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
         f"need input_tensor shape: {input_tensor.shape} final"
         f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
     )
-
     # TODO: check groupsize quantization
     # avoid circular dep, TODO: move this to a common util.py
     act_mat = input_tensor
@@ -97,7 +96,6 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
     y = torch.ops.aten._weight_int4pack_mm(
         act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
     )
-
     # remove out_feature padding
     orig_out_features = weight_tensor.shape[-2]
     y = y[:, :orig_out_features]
@@ -119,7 +117,7 @@ class TensorCoreTiledLayout(Layout):
     inner_k_tiles: int = 8
 
     def pre_process(self, input: torch.Tensor) -> torch.Tensor:
-        orig_out_features, orig_in_features = input.shape
+        orig_out_features, orig_in_features = input.shape[-2:]
         in_features = find_multiple(orig_in_features, 1024)
         out_features = find_multiple(orig_out_features, 8)
         input = torch.nn.functional.pad(
@@ -160,7 +158,7 @@ def post_process(
         zero_point: torch.Tensor,
         block_size: Tuple[int, ...],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        orig_out_features, orig_in_features = input.shape
+        orig_out_features, orig_in_features = input.shape[-2:]
         in_features = find_multiple(orig_in_features, 1024)
         out_features = find_multiple(orig_out_features, 8)
         input = torch.nn.functional.pad(
@@ -272,14 +270,28 @@ def from_plain(
         assert (
             int_data.dtype == torch.int32
         ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(
-            int_data, _layout.inner_k_tiles
-        )
-        scale = scale.reshape(int_data.shape[0], -1)
-        zero_point = zero_point.reshape(int_data.shape[0], -1)
+        def quant_2d(int_data):
+            return torch.ops.aten._convert_weight_to_int4pack(
+                int_data, _layout.inner_k_tiles
+            )
+        if int_data.shape[1] == 14336:
+            import fbvscode; fbvscode.set_trace()
+        if int_data.dim() == 3:  # for moe quant
+            num_experts = int_data.shape[0]
+            packed_weight_list = []
+            for expert in range(num_experts):
+                packed_weight_list.append(quant_2d(int_data[expert]).unsqueeze(0))
+            packed_weight = torch.cat(packed_weight_list, dim=0)
+            scale = scale.reshape(int_data.shape[0], int_data.shape[-2], -1)
+            zero_point = zero_point.reshape(int_data.shape[0], int_data.shape[-2], -1)
+        else:
+            packed_weight = quant_2d(int_data)
+            scale = scale.reshape(int_data.shape[0], -1)
+            zero_point = zero_point.reshape(int_data.shape[0], -1)
         from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
 
         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
+        import fbvscode; fbvscode.set_trace()
         return cls(packed_weight, scale_and_zero, False, _layout)
 
     def to(self, *args, **kwargs):
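The 3D branch in from_plain loops over experts because torch.ops.aten._convert_weight_to_int4pack packs a single 2D weight at a time; each expert is packed separately and the results are stacked along a new leading dim. The same logic in isolation (pack_experts_int4 is a hypothetical helper mirroring the loop above, with int_data dtype as required by the assert above):

import torch

def pack_experts_int4(int_data, inner_k_tiles):
    # int_data: (num_experts, n, k) integer weight tensor
    assert int_data.dim() == 3
    packed = [
        torch.ops.aten._convert_weight_to_int4pack(
            int_data[e], inner_k_tiles
        ).unsqueeze(0)
        for e in range(int_data.shape[0])
    ]
    # (num_experts, *packed_2d_shape)
    return torch.cat(packed, dim=0)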
@@ -336,6 +348,18 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 f"Not supported args for copy_ due to metadata mistach: {args[0], args[1]}"
             )
 
+        if func in [aten.select.int, aten.index.Tensor]:
+            assert not (func is aten.select.int and args[1] != 0), "aten.select.int currently only has support for dim=0"
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0]._apply_fn_to_data(
+                    lambda x: func(x, *args[1:], **kwargs)
+                ),
+            )
+
+
         if func is aten.t.default:
             """we don't need to repack the weight and just rely on external
             shape being changed and record the status of transpose/no-transpose
@@ -399,29 +423,45 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         )
         from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros
 
-        scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
-
+        def dequant_4d(self):
+            cur_shape = self.shape
+            assert len(cur_shape) == 4
+            inner_k_tiles = cur_shape[-1] * 2
+            original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
+            eye_shape = original_shape[1]
+            groupsize = int(original_shape[1] / scale.shape[-2])
+            block_size = (1, groupsize)
+            device = self.device
+            original_dtype = torch.bfloat16
+            target_dtype = torch.int32
+            quant_min = 0
+            quant_max = 15
+            zero_point_domain = ZeroPointDomain.FLOAT
+            assert len(block_size) == 2 and block_size[0] == 1
+            dequantized = torch.ops.aten._weight_int4pack_mm(
+                torch.eye(eye_shape, device=device, dtype=original_dtype),
+                self.packed_weight,
+                groupsize,
+                self.scale_and_zero,
+            )
+            dequantized = dequantized.t().contiguous()
+            return dequantized
+
         cur_shape = self.shape
-        assert len(cur_shape) == 4
-        inner_k_tiles = cur_shape[-1] * 2
-        original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
-        eye_shape = original_shape[1]
-        groupsize = int(original_shape[1] / scale.shape[-2])
-        block_size = (1, groupsize)
-        device = self.device
-        original_dtype = torch.bfloat16
-        target_dtype = torch.int32
-        quant_min = 0
-        quant_max = 15
-        zero_point_domain = ZeroPointDomain.FLOAT
-        assert len(block_size) == 2 and block_size[0] == 1
-        dequantized = torch.ops.aten._weight_int4pack_mm(
-            torch.eye(eye_shape, device=device, dtype=original_dtype),
-            self.packed_weight,
-            groupsize,
-            self.scale_and_zero,
-        )
-        dequantized = dequantized.t().contiguous()
+
+        if len(cur_shape) == 4:
+            dequantized = dequant_4d(self)
+        else:
+
+            assert len(cur_shape) == 5
+            num_experts = cur_shape[0]
+            dequantized_list = []
+            import fbvscode; fbvscode.set_trace()
+            for expert in range(num_experts):
+                dequantized_list.append(dequant_4d(self[expert]).unsqueeze(0))
+            dequantized = torch.cat(dequantized_list, dim=0)
+
+        scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
         # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
         scale = scale.reshape(scale.shape[:-1]).contiguous()
         zero = zero.reshape(zero.shape[:-1]).contiguous()
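dequant_4d recovers the plain weight without a dedicated unpacking kernel: since _weight_int4pack_mm computes act @ W.t() against the packed weight, feeding a (k, k) identity matrix as the activation returns W.t(), and a transpose yields the dequantized (n, k) weight. The trick in isolation, as a sketch (assumes a torch build where the int4 tinygemm op is available):

import torch

def dequant_via_eye(packed_weight, scale_and_zero, k, groupsize):
    # identity activation: I @ W.t() == W.t(), dequantized by the mm kernel
    eye = torch.eye(k, device=packed_weight.device, dtype=torch.bfloat16)
    w_t = torch.ops.aten._weight_int4pack_mm(
        eye, packed_weight, groupsize, scale_and_zero
    )
    return w_t.t().contiguous()  # dequantized (n, k) bfloat16 weight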

torchao/quantization/quant_api.py (+1 -1)

@@ -300,7 +300,7 @@ def _replace_with_custom_fn_if_matches_filter(
                 device,
                 extra_args,
             )
-            if new_child is not child:
+            if new_child is not child and new_child is not None:
                 setattr(model, name, new_child)
         if device is not None:
             model.to(device=device)  # move parent module to device
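With the added None check, a replacement_fn can now return None to mean "leave this child as it is" (for example after adjusting it in place), instead of having None assigned over the module. A hypothetical replacement_fn relying on that behavior:

from typing import Optional
import torch.nn as nn

def replacement_fn(module: nn.Module) -> Optional[nn.Module]:
    # hypothetical: modify qualifying modules in place and return None;
    # the new guard then skips setattr and keeps the original child mounted
    if isinstance(module, nn.Linear):
        module.weight.requires_grad_(False)
        return None
    return module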

torchao/quantization/utils.py (+7 -6)

@@ -366,22 +366,23 @@ def get_groupwise_affine_qparams(
 def pack_tinygemm_scales_and_zeros(scales, zeros, dtype=torch.bfloat16):
     guard_dtype_size(scales, "scales", dtype=dtype, size=zeros.size())
     guard_dtype_size(zeros, "zeros", dtype=dtype)
+    dim = scales.dim()
     return (
         torch.cat(
             [
-                scales.reshape(scales.size(0), scales.size(1), 1),
-                zeros.reshape(zeros.size(0), zeros.size(1), 1),
+                scales.unsqueeze(-1),
+                zeros.unsqueeze(-1),
             ],
-            2,
+            dim,
         )
-        .transpose(0, 1)
+        .transpose(-3, -2)
         .contiguous()
     )
 
 
 def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
-    assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
-    return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)
+    assert scales_and_zeros.shape[-1] == 2
+    return torch.split(scales_and_zeros.transpose(-3, -2), 1, -1)
 
 
 def convert_weight_to_int4pack_xpu(weight, zero_point_domain_is_int=False):
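With this patch applied, a quick shape check shows why the unsqueeze/negative-dim rewrite preserves the original 2D behavior while extending to stacked 3D expert scales (sizes are illustrative):

import torch
from torchao.quantization.utils import (
    pack_tinygemm_scales_and_zeros,
    unpack_tinygemm_scales_and_zeros,
)

# 2D, one linear: (n, n_groups) -> (n_groups, n, 2), unchanged behavior
s2 = torch.ones(64, 8, dtype=torch.bfloat16)
assert pack_tinygemm_scales_and_zeros(s2, s2).shape == (8, 64, 2)

# 3D, stacked experts: (E, n, n_groups) -> (E, n_groups, n, 2)
s3 = torch.ones(4, 64, 8, dtype=torch.bfloat16)
packed = pack_tinygemm_scales_and_zeros(s3, s3)
assert packed.shape == (4, 8, 64, 2)

# unpack reverses the transpose and splits the trailing pair
scale, zero = unpack_tinygemm_scales_and_zeros(packed)
assert scale.shape == (4, 64, 8, 1)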

torchao/utils.py (+3 -1)

@@ -422,7 +422,8 @@ class MyTensor(torch.Tensor):
         return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
 
     with torch._C.DisableTorchFunctionSubclass():
-        return func(*args, **kwargs)
+        out = func(*args, **kwargs)
+        return out
 
 
 def _dispatch__torch_dispatch__(cls, func, types, args, kwargs):
@@ -441,6 +442,7 @@ class MyTensor(torch.Tensor):
 
     arg_types = tuple(type(arg) for arg in args)
    kwarg_types = {k: type(arg) for k, arg in kwargs.items()}
+    # import fbvscode; fbvscode.set_trace()
     raise NotImplementedError(
         f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}"
     )
