@@ -52,7 +52,8 @@
     mv_module_from_gpu,
     unsupport_meta_device, clear_memory,
     compile_func,
-    find_matching_blocks, is_debug_mode
+    find_matching_blocks, is_debug_mode,
+    TORCH_VERSION_AT_LEAST_2_6
 )
 from .low_cpu_mem.utils import get_layers_before_block
 
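For context, the newly imported `TORCH_VERSION_AT_LEAST_2_6` is a version gate used below to decide when to hint at enabling `torch.compile`. A minimal sketch of how such a flag is typically derived (an assumption for illustration, not the library's actual definition):

```python
# Hypothetical sketch: how a flag like TORCH_VERSION_AT_LEAST_2_6 is commonly computed.
# The real definition lives in the library's utils; this is only an illustration.
from packaging import version

import torch

# Strip any local build suffix (e.g. "2.6.0+cu124") before comparing versions.
TORCH_VERSION_AT_LEAST_2_6 = version.parse(torch.__version__.split("+")[0]) >= version.parse("2.6")
```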
@@ -159,7 +160,7 @@ def __init__(
         act_dynamic: bool = True,
         to_quant_block_names: Union[str, list] = None,
         enable_norm_bias_tuning: bool = False,
-        enable_torch_compile: bool = None,
+        enable_torch_compile: bool = False,
         device_map: Union[str, dict] = None,
         **kwargs,
     ):
@@ -232,19 +233,24 @@ def __init__(
         logger.info(f"using {self.model.dtype} for quantization tuning")
 
         self.enable_torch_compile = enable_torch_compile
-        if self.act_bits <= 8 and self.enable_torch_compile != False:
+        if not self.enable_torch_compile and TORCH_VERSION_AT_LEAST_2_6 and self.act_bits > 8 and not is_debug_mode() \
+                and self.low_cpu_mem_usage != True and "fp8" not in self.data_type and "fp8" not in self.act_data_type:
+            logger.info("'enable_torch_compile' is set to `False` by default. " \
+                        "Enabling it can reduce tuning cost by 20%, but it might throw an exception.")
+
+        if self.act_bits <= 8 and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as activation quantization is enabled")
 
-        if self.low_cpu_mem_usage == True and self.enable_torch_compile != False:
+        if self.low_cpu_mem_usage == True and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as low_cpu_mem_usage is enabled")
 
-        if is_debug_mode() and self.enable_torch_compile != False:
+        if is_debug_mode() and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as debug mode is enabled")
 
-        if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile != False:
+        if ("fp8" in self.data_type or "fp8" in self.act_data_type) and self.enable_torch_compile:
             self.enable_torch_compile = False
             logger.warning("reset enable_torch_compile to `False` as fp8 is enabled")
 
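With this hunk, `torch.compile` becomes opt-in instead of being inferred from a `None` default, and the guards above still force it back off for unsupported configurations. A hedged usage sketch of the new default (checkpoint name and call pattern are illustrative, following the library's documented `AutoRound(model, tokenizer, ...)` entry point):

```python
# Illustrative usage only: enable_torch_compile now defaults to False,
# so tuning-time compilation must be requested explicitly.
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # assumed example checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

autoround = AutoRound(
    model,
    tokenizer,
    bits=4,
    group_size=128,
    enable_torch_compile=True,  # opt in; still reset to False for act_bits <= 8,
                                # low_cpu_mem_usage, debug mode, or fp8 data types
)
autoround.quantize()
```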
@@ -493,13 +499,8 @@ def quant_layers(self, layer_names, layer_inputs):
         self.model = mv_module_from_gpu(self.model, self.low_cpu_mem_usage)
         clear_memory()
         device = next(self.model.parameters()).device
-        if self.enable_torch_compile != False:
-            try:
-                quant_layer = compile_func(self.quant_layer, self.device, self.enable_torch_compile)
-            except:
-                logger.warning("torch compile failed, reset it to `False`")
-                self.enable_torch_compile = False
-                quant_layer = self.quant_layer
+        if self.enable_torch_compile:
+            quant_layer = compile_func(self.quant_layer, self.device)
         else:
             quant_layer = self.quant_layer
         for layer_name in layer_names:
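The bare try/except fallback is dropped here and in `quant_blocks` below; compilation is now only attempted when the guarded flag is truthy, and `compile_func` takes just the function and the device. For orientation, a minimal sketch of what a helper in that shape could reduce to (a hypothetical illustration, not the library's actual utility):

```python
# Hypothetical compile_func-style helper: compile where torch.compile is expected
# to pay off, otherwise return the function unchanged.
import torch


def compile_func_sketch(func, device):
    if "cpu" in str(device):
        # Eager fallback where compilation tends to be unprofitable or unsupported.
        return func
    return torch.compile(func)
```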
@@ -1311,13 +1312,8 @@ def quant_blocks(
             elif isinstance(input_others[key], list):
                 for i in range(len(input_others[key])):
                     to_dtype(input_others[key][i], tmp_dtype)
-        if self.enable_torch_compile != False:
-            try:
-                quant_block = compile_func(self.quant_block, device, self.enable_torch_compile)
-            except:
-                logger.warning("torch compile failed, reset it to `False`")
-                self.enable_torch_compile = False
-                quant_block = self.quant_block
+        if self.enable_torch_compile:
+            quant_block = compile_func(self.quant_block, device)
         else:
             quant_block = self.quant_block
 
@@ -1648,7 +1644,7 @@ def __init__(
         act_dynamic: bool = True,
         to_quant_block_names: Union[str, list] = None,
         enable_norm_bias_tuning: bool = False,
-        enable_torch_compile: bool = None,
+        enable_torch_compile: bool = False,
         device_map: Union[str, dict] = None,
         optimizer="AdamW",
         **kwargs,
@@ -1822,7 +1818,7 @@ def __init__(
         act_dynamic: bool = True,
         to_quant_block_names: Union[str, list] = None,
         enable_norm_bias_tuning: bool = False,
-        enable_torch_compile: bool = None,
+        enable_torch_compile: bool = False,
         device_map: Union[str, dict] = None,
         optimizer="AdamW",
         **kwargs,
@@ -1868,3 +1864,4 @@ def __init__(
             optimizer=optimizer,
             **kwargs,
         )
+