Commit 1599812

Modernize type hints - dataclass arguments and recipe metadata

Signed-off-by: ojeda-e <[email protected]>
1 parent 5061adf commit 1599812

File tree: 6 files changed, +48 −47 lines
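
The change is mechanical throughout: `Optional[X]` becomes `X | None` (PEP 604), `List`/`Dict` become the builtin generics `list`/`dict` (PEP 585), and the now-unused `typing` imports are dropped. A minimal before/after sketch of the pattern, using field names from the diff below and assuming Python 3.10+, where `X | Y` in annotations evaluates natively:

```python
from dataclasses import dataclass, field


@dataclass
class DatasetArgumentsSketch:
    # Before: dataset: Optional[str] = field(default=None, ...)
    dataset: str | None = field(
        default=None,
        metadata={"help": "Name of the dataset (illustrative help text)"},
    )
    # Before: splits: Union[None, str, List, Dict] = field(default=None, ...)
    # Note the diff also parameterizes the previously bare List/Dict.
    splits: str | list[str] | dict[str, str] | None = field(
        default=None,
        metadata={"help": "Optional percentages of each split to download"},
    )
```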

src/llmcompressor/args/dataset_arguments.py

Lines changed: 18 additions & 18 deletions

```diff
@@ -8,7 +8,7 @@
 """
 
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable
 
 from transformers import DefaultDataCollator
 
@@ -19,7 +19,7 @@ class DVCDatasetArguments:
     Arguments for training using DVC
     """
 
-    dvc_data_repository: Optional[str] = field(
+    dvc_data_repository: str | None = field(
         default=None,
         metadata={"help": "Path to repository used for dvc_dataset_path"},
     )
@@ -31,7 +31,7 @@ class CustomDatasetArguments(DVCDatasetArguments):
     Arguments for training using custom datasets
     """
 
-    dataset_path: Optional[str] = field(
+    dataset_path: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -52,12 +52,12 @@ class CustomDatasetArguments(DVCDatasetArguments):
         },
     )
 
-    remove_columns: Union[None, str, List] = field(
+    remove_columns: str | list[str] | None = field(
         default=None,
         metadata={"help": "Column names to remove after preprocessing (deprecated)"},
     )
 
-    preprocessing_func: Union[None, str, Callable] = field(
+    preprocessing_func: str | Callable | None = field(
         default=None,
         metadata={
             "help": (
@@ -85,7 +85,7 @@ class DatasetArguments(CustomDatasetArguments):
     arguments to be able to specify them on the command line
     """
 
-    dataset: Optional[str] = field(
+    dataset: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -94,7 +94,7 @@ class DatasetArguments(CustomDatasetArguments):
             )
         },
     )
-    dataset_config_name: Optional[str] = field(
+    dataset_config_name: str | None = field(
         default=None,
         metadata={
             "help": ("The configuration name of the dataset to use"),
@@ -114,15 +114,15 @@ class DatasetArguments(CustomDatasetArguments):
             "help": "Whether or not to concatenate datapoints to fill max_seq_length"
         },
     )
-    raw_kwargs: Dict = field(
+    raw_kwargs: dict = field(
         default_factory=dict,
         metadata={"help": "Additional keyboard args to pass to datasets load_data"},
     )
-    splits: Union[None, str, List, Dict] = field(
+    splits: str | list[str] | dict[str, str] | None = field(
         default=None,
         metadata={"help": "Optional percentages of each split to download"},
     )
-    num_calibration_samples: Optional[int] = field(
+    num_calibration_samples: int | None = field(
         default=512,
         metadata={"help": "Number of samples to use for one-shot calibration"},
     )
@@ -136,21 +136,21 @@ class DatasetArguments(CustomDatasetArguments):
             "module definitions"
         },
     )
-    shuffle_calibration_samples: Optional[bool] = field(
+    shuffle_calibration_samples: bool | None = field(
         default=True,
         metadata={
             "help": "whether to shuffle the dataset before selecting calibration data"
         },
     )
-    streaming: Optional[bool] = field(
+    streaming: bool | None = field(
         default=False,
         metadata={"help": "True to stream data from a cloud dataset"},
     )
     overwrite_cache: bool = field(
         default=False,
         metadata={"help": "Overwrite the cached preprocessed datasets or not."},
     )
-    preprocessing_num_workers: Optional[int] = field(
+    preprocessing_num_workers: int | None = field(
         default=None,
         metadata={"help": "The number of processes to use for the preprocessing."},
     )
@@ -162,14 +162,14 @@ class DatasetArguments(CustomDatasetArguments):
             "in the batch (which can be faster on GPU but will be slower on TPU)."
         },
     )
-    max_train_samples: Optional[int] = field(
+    max_train_samples: int | None = field(
         default=None,
         metadata={
             "help": "For debugging purposes or quicker training, truncate the number "
             "of training examples to this value if set."
         },
     )
-    min_tokens_per_module: Optional[float] = field(
+    min_tokens_per_module: float | None = field(
         default=None,
         metadata={
             "help": (
@@ -182,15 +182,15 @@ class DatasetArguments(CustomDatasetArguments):
         },
     )
     # --- pipeline arguments --- #
-    pipeline: Optional[str] = field(
+    pipeline: str | None = field(
         default="independent",
         metadata={
             "help": "Calibration pipeline used to calibrate model"
             "Options: ['basic', 'datafree', 'sequential', 'layer_sequential', "
             "independent]"
         },
    )
-    tracing_ignore: List[str] = field(
+    tracing_ignore: list[str] = field(
         default_factory=lambda: [
             "_update_causal_mask",
             "create_causal_mask",
@@ -209,7 +209,7 @@ class DatasetArguments(CustomDatasetArguments):
             "{module}.{method_name} or {function_name}"
         },
     )
-    sequential_targets: Optional[List[str]] = field(
+    sequential_targets: list[str] | None = field(
         default=None,
         metadata={
             "help": "List of layer targets for the sequential pipeline. "
```

src/llmcompressor/args/model_arguments.py

Lines changed: 6 additions & 7 deletions

```diff
@@ -8,7 +8,6 @@
 """
 
 from dataclasses import dataclass, field
-from typing import Optional
 
 
 @dataclass
@@ -27,31 +26,31 @@ class ModelArguments:
             )
         },
     )
-    distill_teacher: Optional[str] = field(
+    distill_teacher: str | None = field(
         default=None,
         metadata={
             "help": "Teacher model (a trained text generation model)",
         },
     )
-    config_name: Optional[str] = field(
+    config_name: str | None = field(
         default=None,
         metadata={
             "help": "Pretrained config name or path if not the same as model_name"
         },
     )
-    tokenizer: Optional[str] = field(
+    tokenizer: str | None = field(
         default=None,
         metadata={
             "help": "Pretrained tokenizer name or path if not the same as model_name"
         },
     )
-    processor: Optional[str] = field(
+    processor: str | None = field(
         default=None,
         metadata={
             "help": "Pretrained processor name or path if not the same as model_name"
         },
     )
-    cache_dir: Optional[str] = field(
+    cache_dir: str | None = field(
         default=None,
         metadata={"help": "Where to store the pretrained data from huggingface.co"},
     )
@@ -85,7 +84,7 @@ class ModelArguments:
         },
     )
     # TODO: potentialy separate out/expand to additional saving args
-    save_compressed: Optional[bool] = field(
+    save_compressed: bool | None = field(
         default=True,
         metadata={"help": "Whether to compress sparse models during save"},
     )
```

src/llmcompressor/args/recipe_arguments.py

Lines changed: 4 additions & 5 deletions

```diff
@@ -7,20 +7,19 @@
 """
 
 from dataclasses import dataclass, field
-from typing import List, Optional
 
 
 @dataclass
 class RecipeArguments:
     """Recipe and session variables"""
 
-    recipe: Optional[str] = field(
+    recipe: str | None = field(
         default=None,
         metadata={
             "help": "Path to a LLM Compressor sparsification recipe",
         },
     )
-    recipe_args: Optional[List[str]] = field(
+    recipe_args: list[str] | None = field(
         default=None,
         metadata={
             "help": (
@@ -29,7 +28,7 @@ class RecipeArguments:
             )
         },
     )
-    clear_sparse_session: Optional[bool] = field(
+    clear_sparse_session: bool | None = field(
         default=False,
         metadata={
             "help": (
@@ -38,7 +37,7 @@ class RecipeArguments:
             )
         },
     )
-    stage: Optional[str] = field(
+    stage: str | None = field(
         default=None,
         metadata={"help": ("The stage of the recipe to use for oneshot / train.",)},
     )
```

src/llmcompressor/args/training_arguments.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -8,7 +8,6 @@
 """
 
 from dataclasses import dataclass, field
-from typing import Optional
 
 from transformers import TrainingArguments as HFTrainingArgs
 
@@ -25,11 +24,11 @@ class TrainingArguments(HFTrainingArgs):
 
     """
 
-    do_oneshot: Optional[bool] = field(
+    do_oneshot: bool | None = field(
         default=False,
         metadata={"help": "Whether to run one-shot calibration in stages"},
     )
-    run_stages: Optional[bool] = field(
+    run_stages: bool | None = field(
         default=False, metadata={"help": "Whether to trigger recipe stage by stage"}
     )
     output_dir: str = field(
```
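
The pattern here is extending Hugging Face's `TrainingArguments` with extra dataclass fields, and the modernized annotations slot into it unchanged. A self-contained sketch with the two fields from the diff above (requires `transformers`; names and defaults mirror the diff):

```python
from dataclasses import dataclass, field

from transformers import TrainingArguments as HFTrainingArgs


@dataclass
class TrainingArguments(HFTrainingArgs):
    # Stage-control flags layered on top of the stock HF arguments.
    do_oneshot: bool | None = field(
        default=False,
        metadata={"help": "Whether to run one-shot calibration in stages"},
    )
    run_stages: bool | None = field(
        default=False, metadata={"help": "Whether to trigger recipe stage by stage"}
    )
```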

src/llmcompressor/args/utils.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -7,8 +7,6 @@
 warnings, and processor resolution.
 """
 
-from typing import Tuple
-
 from loguru import logger
 from transformers import HfArgumentParser
 
@@ -23,7 +21,13 @@
 
 def parse_args(
     include_training_args: bool = False, **kwargs
-) -> Tuple[ModelArguments, DatasetArguments, RecipeArguments, TrainingArguments, str]:
+) -> tuple[
+    ModelArguments,
+    DatasetArguments,
+    RecipeArguments,
+    TrainingArguments | None,
+    str | None,
+]:
     """
     Keyword arguments passed in from `oneshot` or `train` will
     separate the arguments into the following:
```
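
Beyond the lowercase `tuple`, the reworked signature makes the optionality of the last two elements explicit: `TrainingArguments | None` and `str | None` read as values that are only produced when `include_training_args=True`. A hedged sketch of a caller consuming that contract (the import path is inferred from the file location; the call and the guard are illustrative, not from this commit):

```python
from llmcompressor.args.utils import parse_args

# The 5-tuple mirrors the annotation in the diff above.
model_args, dataset_args, recipe_args, training_args, output_dir = parse_args(
    include_training_args=False
)

if training_args is None:
    # One-shot path: no HF TrainingArguments were parsed.
    print("running without training arguments")
```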

src/llmcompressor/recipe/metadata.py

Lines changed: 11 additions & 11 deletions

```diff
@@ -6,7 +6,7 @@
 structured data containers for recipe configuration and execution tracking.
 """
 
-from typing import Any, Dict, List, Optional
+from typing import Any
 
 from pydantic import BaseModel, Field
 
@@ -22,7 +22,7 @@ class DatasetMetaData(BaseModel):
     name: str = None
     version: str = None
     hash: str = None
-    shape: List[int] = Field(default_factory=list)
+    shape: list[int] = Field(default_factory=list)
     num_classes: int = None
     num_train_samples: int = None
     num_val_samples: int = None
@@ -31,24 +31,24 @@ class DatasetMetaData(BaseModel):
 
 class ParamMetaData(BaseModel):
     name: str = None
-    shape: List[int] = None
+    shape: list[int] = None
     weight_hash: str = None
 
 
 class LayerMetaData(BaseModel):
     name: str = None
     type: str = None
     index: int = None
-    attributes: Dict[str, Any] = None
-    input_shapes: List[List[int]] = None
-    output_shapes: List[List[int]] = None
-    params: Dict[str, ParamMetaData] = None
+    attributes: dict[str, Any] = None
+    input_shapes: list[list[int]] = None
+    output_shapes: list[list[int]] = None
+    params: dict[str, ParamMetaData] = None
 
 
 class ModelMetaData(BaseModel):
     architecture: str = None
     sub_architecture: str = None
-    input_shapes: List[List[int]] = None
-    output_shapes: List[List[int]] = None
-    layers: List[LayerMetaData] = Field(default_factory=list)
-    layer_prefix: Optional[str] = None
+    input_shapes: list[list[int]] = None
+    output_shapes: list[list[int]] = None
+    layers: list[LayerMetaData] = Field(default_factory=list)
+    layer_prefix: str | None = None
```
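
The same modernization carries over to the pydantic models, which resolve `list[int]` and `str | None` exactly as they did the `typing` spellings. A minimal sketch of the resulting behavior, assuming pydantic v2 (where `.model_dump()` is the serialization entry point); the fields are abridged from `DatasetMetaData` above, with `str | None` written out for clarity even though the original keeps `name: str = None`:

```python
from pydantic import BaseModel, Field


class DatasetMetaDataSketch(BaseModel):
    name: str | None = None
    shape: list[int] = Field(default_factory=list)


meta = DatasetMetaDataSketch(name="c4", shape=[512])
print(meta.model_dump())  # {'name': 'c4', 'shape': [512]}
```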
