@@ -1,11 +1,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
-from typing import Dict, List, Optional
+from typing import List, Optional
 
 import json
 import torch
-import transformers
-from packaging import version
 
 from swift.llm import get_model_tokenizer, get_template
 from swift.utils import (check_json_format, get_logger, get_main, get_model_info, push_to_ms_hub, seed_everything,
@@ -66,67 +64,8 @@ def _get_dataset(*args, **kwargs):
 
 def awq_model_quantize(awq_model, tokenizer, batch_size) -> None:
 
-    def _llama_rotary_emb_forward(self, x, position_ids):
-        with torch.no_grad():
-            if 'dynamic' in self.rope_type:
-                self._dynamic_frequency_update(position_ids, device=x.device)
-
-            # Core RoPE block
-            inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-            position_ids_expanded = position_ids[:, None, :].float()
-            # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
-            device_type = x.device.type
-            device_type = device_type if isinstance(device_type, str) and device_type != 'mps' else 'cpu'
-            with torch.autocast(device_type=device_type, enabled=False):
-                inv_freq_expanded = inv_freq_expanded.to(position_ids_expanded.device)
-                freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-                emb = torch.cat((freqs, freqs), dim=-1)
-                cos = emb.cos()
-                sin = emb.sin()
-
-            # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
-            cos = cos * self.attention_scaling
-            sin = sin * self.attention_scaling
-
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-    @torch.no_grad()
-    def _module_forward(self, x: torch.Tensor, module: torch.nn.Module, module_kwargs: Dict) -> torch.Tensor:
-        # The original code of awq.AwqQuantizer._module_forward has a bug with n_parallel_calib_samples
-        if self.n_parallel_calib_samples is None:
-            # runs through all samples at once
-            module_output = module(x, **module_kwargs)
-            if isinstance(module_output, tuple):
-                module_output = module_output[0]
-        else:
-            # memory efficiently runs through all calibration samples
-            # but only n_parallel_calib_samples at a time
-            module_output = []
-            partitioned_inputs = torch.split(x, self.n_parallel_calib_samples)
-            for idx, x_partial in enumerate(partitioned_inputs):
-                tmp_module_kwargs = {**module_kwargs}
-                if tmp_module_kwargs.get('attention_mask'):
-                    tmp_module_kwargs['attention_mask'] = tmp_module_kwargs['attention_mask'][idx:idx + self.
-                                                                                              n_parallel_calib_samples]
-                partial_output = module(x_partial, **tmp_module_kwargs)
-
-                if isinstance(partial_output, tuple):
-                    partial_output = partial_output[0]
-
-                module_output.append(partial_output.cpu())
-
-            module_output = torch.cat(module_output, dim=0)
-
-        return module_output
-
-    import awq
     from awq.quantize import quantizer
     from transformers import AwqConfig
-    if version.parse(awq.__version__) >= version.parse('0.2.6'):
-        quantizer.AwqQuantizer._module_forward = _module_forward
-
-    if version.parse(transformers.__version__) >= version.parse('4.43.0'):
-        transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = _llama_rotary_emb_forward
 
     assert _args is not None
     logger.info(f'Quantization dataset: {_args.dataset}')
@@ -257,7 +196,6 @@ def llm_export(args: ExportArguments) -> None:
         model.config.quantization_config.pop('dataset', None)
         gptq_quantizer.save(model, args.quant_output_dir)
     elif args.quant_method == 'bnb':
-        args.quant_device_map = 'auto'  # cannot use cpu on bnb
         args.quantization_bit = args.quant_bits
         args.bnb_4bit_compute_dtype, args.load_in_4bit, args.load_in_8bit = args.select_bnb()
         model, template = prepare_model_template(args, device_map=args.quant_device_map, verbose=False)