@@ -75,7 +75,6 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
         f"need input_tensor shape: {input_tensor.shape} final"
         f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
     )
-
     # TODO: check groupsize quantization
     # avoid circular dep, TODO: move this to a common util.py
     act_mat = input_tensor
@@ -97,7 +96,6 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
     y = torch.ops.aten._weight_int4pack_mm(
         act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
    )
-
    # remove out_feature padding
    orig_out_features = weight_tensor.shape[-2]
    y = y[:, :orig_out_features]
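For context on the slice above: `pre_process` pads `out_features` up to a multiple of 8 for the tinygemm kernel, so the matmul output carries padded columns that have to be trimmed back to the caller-visible width. A plain-tensor sketch with made-up shapes (illustrative only, not part of the patch):

```python
import torch

# Hypothetical shapes: a weight padded from 11 to 16 out features produces a
# 16-column matmul output; slicing restores the original shape.
orig_out_features = 11
padded_out_features = 16  # find_multiple(11, 8)
y = torch.randn(4, padded_out_features)  # stand-in for _weight_int4pack_mm output
y = y[:, :orig_out_features]             # remove out_feature padding, as in the hunk
assert y.shape == (4, 11)
```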
@@ -119,7 +117,7 @@ class TensorCoreTiledLayout(Layout):
     inner_k_tiles: int = 8

     def pre_process(self, input: torch.Tensor) -> torch.Tensor:
-        orig_out_features, orig_in_features = input.shape
+        orig_out_features, orig_in_features = input.shape[-2:]
         in_features = find_multiple(orig_in_features, 1024)
         out_features = find_multiple(orig_out_features, 8)
         input = torch.nn.functional.pad(
@@ -160,7 +158,7 @@ def post_process(
         zero_point: torch.Tensor,
         block_size: Tuple[int, ...],
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        orig_out_features, orig_in_features = input.shape
+        orig_out_features, orig_in_features = input.shape[-2:]
         in_features = find_multiple(orig_in_features, 1024)
         out_features = find_multiple(orig_out_features, 8)
         input = torch.nn.functional.pad(
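Switching to `input.shape[-2:]` is what lets `pre_process`/`post_process` accept 3D moe weights: `F.pad`'s last two pad pairs address the last two dims regardless of rank. A quick sketch of the padding behavior, using a local stand-in for `torchao.utils.find_multiple`:

```python
import torch

def find_multiple(n: int, k: int) -> int:
    # same contract as torchao.utils.find_multiple: round n up to a multiple of k
    return n if n % k == 0 else n + k - (n % k)

# 3D moe weight: (num_experts, out_features, in_features). shape[-2:] works for
# both 2D and 3D inputs; the 4-element pad tuple touches only the last two dims.
w = torch.randn(4, 11, 1000)
orig_out_features, orig_in_features = w.shape[-2:]
in_features = find_multiple(orig_in_features, 1024)  # -> 1024
out_features = find_multiple(orig_out_features, 8)   # -> 16
w = torch.nn.functional.pad(
    w, (0, in_features - orig_in_features, 0, out_features - orig_out_features)
)
assert w.shape == (4, 16, 1024)
```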
@@ -272,14 +270,26 @@ def from_plain(
         assert (
             int_data.dtype == torch.int32
         ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
-        packed_weight = torch.ops.aten._convert_weight_to_int4pack(
-            int_data, _layout.inner_k_tiles
-        )
-        scale = scale.reshape(int_data.shape[0], -1)
-        zero_point = zero_point.reshape(int_data.shape[0], -1)
+        def quant_2d(int_data):
+            return torch.ops.aten._convert_weight_to_int4pack(
+                int_data, _layout.inner_k_tiles
+            )
+
+        if int_data.dim() == 3:  # for moe quant
+            num_experts = int_data.shape[0]
+            packed_weight_list = []
+            for expert in range(num_experts):
+                packed_weight_list.append(quant_2d(int_data[expert]).unsqueeze(0))
+            packed_weight = torch.cat(packed_weight_list, dim=0)
+            scale = scale.reshape(int_data.shape[0], int_data.shape[-2], -1)
+            zero_point = zero_point.reshape(int_data.shape[0], int_data.shape[-2], -1)
+        else:
+            packed_weight = quant_2d(int_data)
+            scale = scale.reshape(int_data.shape[0], -1)
+            zero_point = zero_point.reshape(int_data.shape[0], -1)
         from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

         scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
         return cls(packed_weight, scale_and_zero, False, _layout)

     def to(self, *args, **kwargs):
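The moe branch in `from_plain` loops because `torch.ops.aten._convert_weight_to_int4pack` only accepts a 2D `(n, k)` int32 weight; each expert is packed separately and the results are stacked along a new leading dim. A condensed sketch of the same flow (assumes an int4pack-capable build, e.g. CUDA with bf16 tinygemm; `pack_moe_weight` is a hypothetical name, not part of the patch):

```python
import torch

def pack_moe_weight(int_data: torch.Tensor, inner_k_tiles: int) -> torch.Tensor:
    # (num_experts, n, k) int32 weight -> per-expert int4pack layouts stacked on dim 0
    if int_data.dim() == 3:
        packed = [
            torch.ops.aten._convert_weight_to_int4pack(int_data[e], inner_k_tiles)
            for e in range(int_data.shape[0])
        ]
        return torch.stack(packed, dim=0)
    return torch.ops.aten._convert_weight_to_int4pack(int_data, inner_k_tiles)
```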
@@ -336,6 +348,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
             )

+        if func in [aten.select.int, aten.index.Tensor]:
+            assert not (func is aten.select.int and args[1] != 0), "aten.select.int currently only has support for dim=0"
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0]._apply_fn_to_data(
+                    lambda x: func(x, *args[1:], **kwargs)
+                ),
+            )
+
         if func is aten.t.default:
             """we don't need to repack the weight and just rely on external
             shape being changed and record the status of transpose/no-transpose
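The new `aten.select.int` / `aten.index.Tensor` handlers are what let the moe paths slice an expert out of a packed 3D weight (e.g. `self[expert]` in `get_plain`). A toy wrapper subclass showing the same dispatch pattern, assuming a recent PyTorch where `return_and_correct_aliasing` lives in `torch.utils._python_dispatch`; `Wrapper` is illustrative, not the patch's class:

```python
import torch
from torch.utils._python_dispatch import return_and_correct_aliasing

aten = torch.ops.aten

class Wrapper(torch.Tensor):
    """Toy subclass: route select/index ops to the inner data tensor."""

    @staticmethod
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape, dtype=data.dtype)

    def __init__(self, data):
        self.inner = data

    def _apply_fn_to_data(self, fn):
        return Wrapper(fn(self.inner))

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        kwargs = kwargs or {}
        if func in [aten.select.int, aten.index.Tensor]:
            return return_and_correct_aliasing(
                func,
                args,
                kwargs,
                args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)),
            )
        raise NotImplementedError(f"{func} not handled in this sketch")

w = Wrapper(torch.arange(24.0).reshape(2, 3, 4))
assert w[0].shape == (3, 4)                        # dispatches aten.select.int, dim=0
assert w[torch.tensor([0, 1])].shape == (2, 3, 4)  # dispatches aten.index.Tensor
```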
@@ -399,29 +423,45 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         )
         from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros

         scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)

+        def dequant_4d(self):
+            cur_shape = self.shape
+            assert len(cur_shape) == 4
+            inner_k_tiles = cur_shape[-1] * 2
+            original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
+            eye_shape = original_shape[1]
+            groupsize = int(original_shape[1] / scale.shape[-2])
+            block_size = (1, groupsize)
+            device = self.device
+            original_dtype = torch.bfloat16
+            target_dtype = torch.int32
+            quant_min = 0
+            quant_max = 15
+            zero_point_domain = ZeroPointDomain.FLOAT
+            assert len(block_size) == 2 and block_size[0] == 1
+            dequantized = torch.ops.aten._weight_int4pack_mm(
+                torch.eye(eye_shape, device=device, dtype=original_dtype),
+                self.packed_weight,
+                groupsize,
+                self.scale_and_zero,
+            )
+            dequantized = dequantized.t().contiguous()
+            return dequantized
+
         cur_shape = self.shape
-        assert len(cur_shape) == 4
-        inner_k_tiles = cur_shape[-1] * 2
-        original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
-        eye_shape = original_shape[1]
-        groupsize = int(original_shape[1] / scale.shape[-2])
-        block_size = (1, groupsize)
-        device = self.device
-        original_dtype = torch.bfloat16
-        target_dtype = torch.int32
-        quant_min = 0
-        quant_max = 15
-        zero_point_domain = ZeroPointDomain.FLOAT
-        assert len(block_size) == 2 and block_size[0] == 1
-        dequantized = torch.ops.aten._weight_int4pack_mm(
-            torch.eye(eye_shape, device=device, dtype=original_dtype),
-            self.packed_weight,
-            groupsize,
-            self.scale_and_zero,
-        )
-        dequantized = dequantized.t().contiguous()
+
+        if len(cur_shape) == 4:
+            dequantized = dequant_4d(self)
+        else:
+            assert len(cur_shape) == 5
+            num_experts = cur_shape[0]
+            dequantized_list = []
+            for expert in range(num_experts):
+                dequantized_list.append(dequant_4d(self[expert]).unsqueeze(0))
+            dequantized = torch.cat(dequantized_list, dim=0)
+
         # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
         scale = scale.reshape(scale.shape[:-1]).contiguous()
         zero = zero.reshape(zero.shape[:-1]).contiguous()
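Why the identity-matrix trick in `dequant_4d` works: `_weight_int4pack_mm` computes `act @ W_dq.t()` with dequantization fused in, so feeding `torch.eye(k)` as the activation returns `W_dq.t()` itself, and transposing recovers the dequantized weight. A plain-tensor analogue of the same identity:

```python
import torch

# Stand-in dequantized weight of shape (n, k): eye(k) @ W.t() == W.t(), so the
# transpose of the matmul output is exactly the weight.
W_dq = torch.randn(16, 32, dtype=torch.bfloat16)
y = torch.eye(32, dtype=torch.bfloat16) @ W_dq.t()  # identity "activations"
assert torch.equal(y.t().contiguous(), W_dq)
```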