
Commit b650162

set torch compile to false by default (#447)
* align auto_quantizer with main branch in Transformers
* rename torch_compile argument, set torch_compile to False by default
* Update auto_quantizer.py
* fix typos
* refine code
* fix typo and refine compile func
* fix scan issue
* fix typo

Signed-off-by: Zhang, Weiwei1 <[email protected]>
1 parent 3a14328 commit b650162

5 files changed: +38 −40 lines changed
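In short, this change makes `torch.compile` opt-in rather than on-by-default. A minimal usage sketch of the new behavior, assuming the `AutoRound` constructor shown in the diffs below (the model choice is purely illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

# Illustrative model; any HF causal LM works the same way.
model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# After this commit, enable_torch_compile defaults to False:
# torch.compile is applied only when explicitly requested.
autoround = AutoRound(model, tokenizer, bits=4, group_size=128,
                      enable_torch_compile=True)  # opt in; omit for the new default
autoround.quantize()
```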

auto_round/autoround.py (+19 −22)

@@ -52,7 +52,8 @@
     mv_module_from_gpu,
     unsupport_meta_device, clear_memory,
     compile_func,
-    find_matching_blocks, is_debug_mode
+    find_matching_blocks, is_debug_mode,
+    TORCH_VERSION_AT_LEAST_2_6
 )
 from .low_cpu_mem.utils import get_layers_before_block

@@ -159,7 +160,7 @@ def __init__(
             act_dynamic: bool = True,
             to_quant_block_names: Union[str, list] = None,
             enable_norm_bias_tuning: bool = False,
-            enable_torch_compile: bool = None,
+            enable_torch_compile: bool = False,
             device_map: Union[str, dict] = None,
             **kwargs,
     ):
@@ -232,19 +233,24 @@ def __init__(
         logger.info(f"using {self.model.dtype} for quantization tuning")

         self.enable_torch_compile = enable_torch_compile
-        if self.act_bits <= 8 and self.enable_torch_compile != False:
+        if not self.enable_torch_compile and TORCH_VERSION_AT_LEAST_2_6 and self.act_bits > 8 and not is_debug_mode() \
+                and self.low_cpu_mem_usage != True and "fp8" not in self.data_type and "fp8" not in self.act_data_type:
+            logger.info("'enable_torch_compile' is set to `False` by default. " \
+                        "Enabling it can reduce tuning cost by 20%, but it might throw an exception.")
+
+        if self.act_bits <= 8 and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as activation quantization is enabled")

-        if self.low_cpu_mem_usage == True and self.enable_torch_compile != False:
+        if self.low_cpu_mem_usage == True and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled")

-        if is_debug_mode() and self.enable_torch_compile != False:
+        if is_debug_mode() and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as debug mode is enabled")

-        if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile != False:
+        if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as fp8 is enabled")

@@ -493,13 +499,8 @@ def quant_layers(self, layer_names, layer_inputs):
         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
         clear_memory()
         device = next(self.model.parameters()).device
-        if self.enable_torch_compile != False:
-            try:
-                quant_layer = compile_func(self.quant_layer, self.device, self.enable_torch_compile)
-            except:
-                logger.warning("torch compile failed, reset it to `False`")
-                self.enable_torch_compile = False
-                quant_layer = self.quant_layer
+        if self.enable_torch_compile:
+            quant_layer = compile_func(self.quant_layer, self.device)
         else:
             quant_layer = self.quant_layer
         for layer_name in layer_names:
@@ -1311,13 +1312,8 @@ def quant_blocks(
             elif isinstance(input_others[key], list):
                 for i in range(len(input_others[key])):
                     to_dtype(input_others[key][i], tmp_dtype)
-        if self.enable_torch_compile != False:
-            try:
-                quant_block = compile_func(self.quant_block, device, self.enable_torch_compile)
-            except:
-                logger.warning("torch compile failed, reset it to `False`")
-                self.enable_torch_compile = False
-                quant_block = self.quant_block
+        if self.enable_torch_compile:
+            quant_block = compile_func(self.quant_block, device)
         else:
             quant_block = self.quant_block

@@ -1648,7 +1644,7 @@ def __init__(
             act_dynamic: bool = True,
             to_quant_block_names: Union[str, list] = None,
             enable_norm_bias_tuning: bool = False,
-            enable_torch_compile: bool = None,
+            enable_torch_compile: bool = False,
             device_map: Union[str, dict] = None,
             optimizer="AdamW",
             **kwargs,
@@ -1822,7 +1818,7 @@ def __init__(
             act_dynamic: bool = True,
             to_quant_block_names: Union[str, list] = None,
             enable_norm_bias_tuning: bool = False,
-            enable_torch_compile: bool = None,
+            enable_torch_compile: bool = False,
             device_map: Union[str, dict] = None,
             optimizer="AdamW",
             **kwargs,
@@ -1868,3 +1864,4 @@ def __init__(
             optimizer=optimizer,
             **kwargs,
         )
+
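The guard rewrites above (`!= False` → plain truthiness) follow from the type change: `enable_torch_compile` used to be tri-state, with `None` meaning "auto", so `!= False` matched both `None` and `True`. A small standalone sketch of the old versus new predicate (function names are ours, not the library's):

```python
def should_compile_old(flag, torch_at_least_2_6):
    # Old tri-state: None meant "auto" -> compile when torch >= 2.6.
    return bool(flag or (torch_at_least_2_6 and flag != False))

def should_compile_new(flag):
    # New behavior: a plain bool that defaults to False.
    return bool(flag)

assert should_compile_old(None, torch_at_least_2_6=True)   # implicit opt-in
assert not should_compile_new(False)                       # explicit opt-in only
```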

auto_round/mllm/autoround_mllm.py (+3 −2)

@@ -112,7 +112,7 @@ class AutoRoundMLLM(AutoRound):
         act_dynamic (bool): Whether to use dynamic activation quantization. Default is True.
         to_quant_block_names (str|list): A string or list whose elements are list of
                             block's layer names to be quantized.
-        enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer, torch>=2.6 True
+        enable_torch_compile (bool): Whether to enable torch compile to optimize quant_block/layer
        **kwargs: Additional keyword arguments.

@@ -160,7 +160,7 @@ def __init__(
             to_quant_block_names: Union[str, list] = None,
             enable_norm_bias_tuning: bool = False,
             truncation: bool = None,
-            enable_torch_compile: bool = None,
+            enable_torch_compile: bool = False,
             **kwargs,
     ):
         all_blocks = get_multimodal_block_names(model, quant_nontext_module)
@@ -410,3 +410,4 @@ def save_quantized(self, output_dir=None, format="auto_round", inplace=True, **k
         compressed_model = super().save_quantized(
             output_dir=output_dir, format=format, inplace=inplace, processor=self.processor, **kwargs)
         return compressed_model
+

auto_round/script/llm.py (+7 −6)

@@ -169,8 +169,8 @@ def __init__(self, *args, **kwargs):
                           type=str,
                           help="Names of quantitative blocks, please use commas to separate them.")

-        self.add_argument("--disable_torch_compile", action='store_true',
-                          help="whether to disable torch compile")
+        self.add_argument("--enable_torch_compile", action='store_true',
+                          help="whether to enable torch compile")

         self.add_argument("--act_data_type", default=None, type=str, help="activation data type")

@@ -353,9 +353,9 @@ def tune(args):
     # logger.info("`torch.use_deterministic_algorithms` is enabled by default for reproducibility "
     #             "and can be disabled using the `--disable_deterministic_algorithms` argument.")

-    if not args.disable_torch_compile:
-        logger.info("`torch.compile` is enabled by default to reduce tuning costs. "
-                    "If it causes issues, you can disable it using the `--disable_torch_compile` argument.")
+    if args.enable_torch_compile:
+        logger.info("`torch.compile` is enabled to reduce tuning costs. "
+                    "If it causes issues, you can disable it by removing the `--enable_torch_compile` argument.")

     model_name = args.model
     if model_name[-1] == "/":
@@ -482,7 +482,7 @@ def tune(args):
     if not awq_supported:
         logger.warning(f"The AutoAWQ format may not be supported due to {info}")

-    enable_torch_compile = False if "--disable_torch_compile" in sys.argv else None
+    enable_torch_compile = True if "--enable_torch_compile" in sys.argv else False

     autoround = round(
         model,
@@ -621,3 +621,4 @@ def eval_sequence(args):
     for key in res_keys:
         res_all[key].update(res[key])
     print(make_table(res_all))
+
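Since the CLI flag flipped from opt-out (`--disable_torch_compile`) to opt-in (`--enable_torch_compile`), an absent flag now maps to `False`. A self-contained sketch of the `store_true` semantics the script relies on:

```python
import argparse

parser = argparse.ArgumentParser()
# store_true flags default to False, matching the new library default.
parser.add_argument("--enable_torch_compile", action="store_true",
                    help="whether to enable torch compile")

assert parser.parse_args([]).enable_torch_compile is False
assert parser.parse_args(["--enable_torch_compile"]).enable_torch_compile is True
```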

auto_round/script/mllm.py (+4 −3)

@@ -152,8 +152,8 @@ def __init__(self, *args, **kwargs):
                           action='store_true',
                           help="whether to use the iter of best mes loss in the tuning phase")

-        self.add_argument("--disable_torch_compile", action='store_true',
-                          help="whether to disable torch compile")
+        self.add_argument("--enable_torch_compile", action='store_true',
+                          help="whether to enable torch compile")

         self.add_argument("--disable_deterministic_algorithms", action='store_true',
                           help="disable torch deterministic algorithms.")
@@ -446,7 +446,7 @@ def tune(args):
     if not awq_supported:
         logger.warning(f"The AutoAWQ format may not be supported due to {info}")

-    enable_torch_compile = False if "--disable_torch_compile" in sys.argv else None
+    enable_torch_compile = True if "--enable_torch_compile" in sys.argv else False

     autoround = round(
         model,
@@ -598,3 +598,4 @@ def lmms_eval(args):
     )
     return results

+

auto_round/utils.py (+5 −7)

@@ -1009,18 +1009,15 @@ def compile_func_on_hpu(func):
     return func


-def compile_func_on_cuda_or_cpu(func, enable_torch_compile):
-    if enable_torch_compile or (TORCH_VERSION_AT_LEAST_2_6_PRE_RELEASE and enable_torch_compile != False):
-        return torch.compile(func)
-    else:
-        return func
+def compile_func_on_cuda_or_cpu(func):
+    return torch.compile(func)


-def compile_func(fun, device, enable_torch_compile):
+def compile_func(fun, device):
     if "hpu" in str(device):
         return compile_func_on_hpu(fun) ## use auto by default
     else:
-        return compile_func_on_cuda_or_cpu(fun, enable_torch_compile)
+        return compile_func_on_cuda_or_cpu(fun)


 def is_numba_available(): # pragma: no cover
@@ -1201,3 +1198,4 @@ def is_debug_mode():
         bool: True if debugging is enabled, False otherwise.
     """
     return sys.gettrace() is not None or sys.flags.debug == 1
+
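With the `enable_torch_compile` parameter dropped from these helpers, `compile_func` now compiles unconditionally and the gating lives at each call site (see `quant_layers`/`quant_blocks` above). A minimal sketch of that pattern; `quant_block` here is a stand-in, not the real method:

```python
import torch

def quant_block(x):
    # Stand-in for the real per-block tuning step.
    return torch.relu(x) * 2.0

enable_torch_compile = False  # the new default

# Gate at the call site, then compile once if requested:
fn = torch.compile(quant_block) if enable_torch_compile else quant_block
print(fn(torch.randn(4)))
```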
