Commit 3a14328

Refine funcs (#446)
1 parent 9aa8803 commit 3a14328

4 files changed: +24 -19 lines changed

auto_round/calib_dataset.py (+11 -9)

@@ -150,7 +150,7 @@ def get_pile_val_dataset(tokenizer, seqlen, dataset_name="swift/pile-val-backup"


 @register_dataset("BAAI/CCI3-HQ")
-def get_CCI3_HQ_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_chat_template=False):
+def get_cci3_hq_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ", split=None, seed=42, apply_chat_template=False):
     """Returns a dataloader for the specified dataset and split.

     Args:
@@ -235,18 +235,19 @@ def get_new_chinese_title_dataset(
     seed=42,
     apply_chat_template=False
 ):
-    """Returns a dataloader for the specified dataset and split.
+    """
+    Returns a tokenized dataset for the specified parameters.

     Args:
-        tokenizer: The tokenizer to be used for tokenization.
-        seqlen: The maximum sequence length.
-        data_name: The name of the dataset.
-        split: The data split to be used (e.g., "train", "test").
-        seed: The random seed for shuffling the dataset.
-        apply_chat_template: Whether to apply chat template in tokenization.
+        tokenizer: The tokenizer to use.
+        seqlen: Maximum sequence length.
+        dataset_name: Name of the dataset to load.
+        split: Which split of the dataset to use.
+        seed: Random seed for shuffling.
+        apply_template: Whether to apply a template to the data.

     Returns:
-        A dataloader for the specified dataset and split, using the provided tokenizer and sequence length.
+        A tokenized and shuffled dataset.
     """

     def get_tokenizer_function(tokenizer, seqlen, apply_chat_template=apply_chat_template):
@@ -639,3 +640,4 @@ def collate_batch(batch):

     calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch)
     return calib_dataloader
+

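For context on why the snake_case rename is transparent to callers, below is a minimal sketch of the decorator-based registry that @register_dataset implements; the CALIB_DATASETS name and the stub body are illustrative assumptions, not auto_round's actual code. Loaders are resolved by the registered dataset key, not by the Python function name.

# Illustrative sketch only; CALIB_DATASETS and the stub body are assumptions.
CALIB_DATASETS = {}

def register_dataset(name):
    def wrapper(func):
        CALIB_DATASETS[name] = func
        return func
    return wrapper

@register_dataset("BAAI/CCI3-HQ")
def get_cci3_hq_dataset(tokenizer, seqlen, dataset_name="BAAI/CCI3-HQ",
                        split=None, seed=42, apply_chat_template=False):
    ...  # tokenize, shuffle, and return calibration samples

# Callers look loaders up by key, so renaming the function changes nothing here:
loader = CALIB_DATASETS["BAAI/CCI3-HQ"]
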
auto_round/export/export_to_autoround/export.py (+3 -2)

@@ -50,7 +50,7 @@ def check_neq_config(config, data_type, bits, group_size, sym):
     return [key for key, expected_value in expected_config.items() if config.get(key) != expected_value]


-def dynamic_import_quantLinear_for_packing(backend, bits, group_size, sym):
+def dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym):
     """
     Dynamically imports and returns the appropriate QuantLinear class based on the specified backend and parameters.

@@ -97,7 +97,7 @@ def pack_layer(name, model, layer_config, backend, pbar):
     layer = get_module(model, name)
     device = layer.weight.device

-    QuantLinear = dynamic_import_quantLinear_for_packing(backend, bits, group_size, sym)
+    QuantLinear = dynamic_import_quant_linear_for_packing(backend, bits, group_size, sym)

     if isinstance(layer, nn.Linear):
         in_features = layer.in_features
@@ -286,3 +286,4 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_seri
     if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
         with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
             json.dump(model.config.quantization_config, f, indent=2)
+

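As a rough illustration of what a helper like dynamic_import_quant_linear_for_packing does (defer the QuantLinear import to pack time and select it from the backend string), here is a generic sketch. The mapping table, module, and class names are placeholders, not auto_round's real backend table, and the function name is deliberately different to avoid suggesting this is the project's implementation.

import importlib

# Hypothetical mapping from backend key to "module:ClassName"; not auto_round's
# actual backends.
_QUANT_LINEAR_TARGETS = {
    "example_backend": "example_kernels.qlinear:QuantLinear",
}

def pick_quant_linear_for_packing(backend, bits, group_size, sym):
    # Import lazily so optional kernel packages are only needed when packing.
    module_name, class_name = _QUANT_LINEAR_TARGETS[backend].split(":")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)
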
auto_round/low_cpu_mem/utils.py (+4 -3)

@@ -78,7 +78,7 @@ def get_named_children(model, pre=[]):
     return module_list


-def dowload_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
+def download_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
     """Download hugging face model from hf hub."""
     from huggingface_hub.constants import DEFAULT_REVISION, HUGGINGFACE_HUB_CACHE
     from huggingface_hub.file_download import REGEX_COMMIT_HASH, repo_folder_name
@@ -116,7 +116,7 @@ def load_empty_model(pretrained_model_name_or_path, cls=AutoModelForCausalLM, sa
     if is_local: # pragma: no cover
         path = pretrained_model_name_or_path
     else:
-        path = dowload_hf_model(pretrained_model_name_or_path)
+        path = download_hf_model(pretrained_model_name_or_path)
     torch_dtype = kwargs.pop("torch_dtype", None)
     if cls.__base__ == _BaseAutoModelClass:
         config = AutoConfig.from_pretrained(path, **kwargs)
@@ -258,7 +258,7 @@ def _get_path(pretrained_model_name_or_path):
     if is_local: # pragma: no cover
         path = pretrained_model_name_or_path
     else:
-        path = dowload_hf_model(pretrained_model_name_or_path)
+        path = download_hf_model(pretrained_model_name_or_path)
     return path


@@ -471,3 +471,4 @@ def layer_wise_load(path):
         d = pickle.loads(d)
         state_dict.update(d)
     return state_dict
+

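The typo fix also makes the name describe what the helper does: download a Hugging Face model snapshot and return its local path. A rough equivalent using huggingface_hub's high-level API is sketched below; the real function resolves cached revisions through huggingface_hub's lower-level constants/file_download helpers, so treat this as an approximation, and the repo id in the comment is an arbitrary example.

from huggingface_hub import snapshot_download

# Approximation of the call pattern covered by download_hf_model; the actual
# implementation in auto_round/low_cpu_mem/utils.py works with the hub cache
# layout directly rather than calling snapshot_download.
def download_hf_model(repo_id, cache_dir=None, repo_type=None, revision=None):
    return snapshot_download(repo_id, cache_dir=cache_dir,
                             repo_type=repo_type, revision=revision)

# local_path = download_hf_model("facebook/opt-125m")  # arbitrary example repo
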
auto_round/mllm/processor.py (+6 -5)

@@ -20,15 +20,15 @@
 PROCESSORS = {}


-def regist_processor(name):
+def register_processor(name):
     def register(processor):
         PROCESSORS[name] = processor
         return processor

     return register


-@regist_processor("basic")
+@register_processor("basic")
 class BasicProcessor:
     def __init__(self):
         pass
@@ -111,7 +111,7 @@ def squeeze_result(ret):
         return ret


-@regist_processor("qwen2_vl")
+@register_processor("qwen2_vl")
 class Qwen2VLProcessor(BasicProcessor):
     @staticmethod
     def squeeze_result(ret):
@@ -122,7 +122,7 @@ def squeeze_result(ret):
         return ret


-@regist_processor("cogvlm2")
+@register_processor("cogvlm2")
 class CogVLM2Processor(BasicProcessor):
     def get_input(
         self, text, images, truncation=False,
@@ -205,7 +205,7 @@ def default_image_processor(image_path_or_url):
 llava_train = LazyImport("llava.train.train")


-@regist_processor("llava")
+@register_processor("llava")
 class LlavaProcessor(BasicProcessor):
     def post_init(self, model, tokenizer, image_processor=None, **kwargs):
         self.model = model
@@ -245,3 +245,4 @@ class DataArgs:

     def data_collator(self, batch):
         return self.collator_func(batch)
+

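The renamed register_processor is the decorator-registry pattern already visible in the first hunk above: a processor class registers itself under a string key and is later fetched from PROCESSORS. A self-contained sketch of that usage follows; the lookup line at the end is an assumed caller, not code from this commit.

PROCESSORS = {}

def register_processor(name):
    def register(processor):
        PROCESSORS[name] = processor
        return processor
    return register

@register_processor("basic")
class BasicProcessor:
    pass

# An assumed caller then selects a processor by its registered key:
processor_cls = PROCESSORS["basic"]  # -> BasicProcessor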