88"""
99
1010from dataclasses import dataclass , field
11- from typing import Any , Callable , Dict , List , Optional , Union
11+ from typing import Any , Callable
1212
1313from transformers import DefaultDataCollator
1414
@@ -19,7 +19,7 @@ class DVCDatasetArguments:
1919 Arguments for training using DVC
2020 """
2121
22- dvc_data_repository : Optional [ str ] = field (
22+ dvc_data_repository : str | None = field (
2323 default = None ,
2424 metadata = {"help" : "Path to repository used for dvc_dataset_path" },
2525 )
@@ -31,7 +31,7 @@ class CustomDatasetArguments(DVCDatasetArguments):
3131 Arguments for training using custom datasets
3232 """
3333
34- dataset_path : Optional [ str ] = field (
34+ dataset_path : str | None = field (
3535 default = None ,
3636 metadata = {
3737 "help" : (
@@ -52,12 +52,12 @@ class CustomDatasetArguments(DVCDatasetArguments):
5252 },
5353 )
5454
55- remove_columns : Union [ None , str , List ] = field (
55+ remove_columns : str | list [ str ] | None = field (
5656 default = None ,
5757 metadata = {"help" : "Column names to remove after preprocessing (deprecated)" },
5858 )
5959
60- preprocessing_func : Union [ None , str , Callable ] = field (
60+ preprocessing_func : str | Callable | None = field (
6161 default = None ,
6262 metadata = {
6363 "help" : (
@@ -85,7 +85,7 @@ class DatasetArguments(CustomDatasetArguments):
8585 arguments to be able to specify them on the command line
8686 """
8787
88- dataset : Optional [ str ] = field (
88+ dataset : str | None = field (
8989 default = None ,
9090 metadata = {
9191 "help" : (
@@ -94,7 +94,7 @@ class DatasetArguments(CustomDatasetArguments):
9494 )
9595 },
9696 )
97- dataset_config_name : Optional [ str ] = field (
97+ dataset_config_name : str | None = field (
9898 default = None ,
9999 metadata = {
100100 "help" : ("The configuration name of the dataset to use" ),
@@ -114,15 +114,15 @@ class DatasetArguments(CustomDatasetArguments):
114114 "help" : "Whether or not to concatenate datapoints to fill max_seq_length"
115115 },
116116 )
117- raw_kwargs : Dict = field (
117+ raw_kwargs : dict = field (
118118 default_factory = dict ,
119119         metadata = {"help" : "Additional keyword args to pass to datasets load_data" },
120120 )
121- splits : Union [ None , str , List , Dict ] = field (
121+ splits : str | list [ str ] | dict [ str , str ] | None = field (
122122 default = None ,
123123 metadata = {"help" : "Optional percentages of each split to download" },
124124 )
125- num_calibration_samples : Optional [ int ] = field (
125+ num_calibration_samples : int | None = field (
126126 default = 512 ,
127127 metadata = {"help" : "Number of samples to use for one-shot calibration" },
128128 )
@@ -136,21 +136,21 @@ class DatasetArguments(CustomDatasetArguments):
136136 "module definitions"
137137 },
138138 )
139- shuffle_calibration_samples : Optional [ bool ] = field (
139+ shuffle_calibration_samples : bool | None = field (
140140 default = True ,
141141 metadata = {
142142 "help" : "whether to shuffle the dataset before selecting calibration data"
143143 },
144144 )
145- streaming : Optional [ bool ] = field (
145+ streaming : bool | None = field (
146146 default = False ,
147147 metadata = {"help" : "True to stream data from a cloud dataset" },
148148 )
149149 overwrite_cache : bool = field (
150150 default = False ,
151151 metadata = {"help" : "Overwrite the cached preprocessed datasets or not." },
152152 )
153- preprocessing_num_workers : Optional [ int ] = field (
153+ preprocessing_num_workers : int | None = field (
154154 default = None ,
155155 metadata = {"help" : "The number of processes to use for the preprocessing." },
156156 )
@@ -162,14 +162,14 @@ class DatasetArguments(CustomDatasetArguments):
162162 "in the batch (which can be faster on GPU but will be slower on TPU)."
163163 },
164164 )
165- max_train_samples : Optional [ int ] = field (
165+ max_train_samples : int | None = field (
166166 default = None ,
167167 metadata = {
168168 "help" : "For debugging purposes or quicker training, truncate the number "
169169 "of training examples to this value if set."
170170 },
171171 )
172- min_tokens_per_module : Optional [ float ] = field (
172+ min_tokens_per_module : float | None = field (
173173 default = None ,
174174 metadata = {
175175 "help" : (
@@ -182,15 +182,15 @@ class DatasetArguments(CustomDatasetArguments):
182182 },
183183 )
184184 # --- pipeline arguments --- #
185- pipeline : Optional [ str ] = field (
185+ pipeline : str | None = field (
186186 default = "independent" ,
187187 metadata = {
188188 "help" : "Calibration pipeline used to calibrate model"
189189 "Options: ['basic', 'datafree', 'sequential', 'layer_sequential', "
190190 "independent]"
191191 },
192192 )
193- tracing_ignore : List [str ] = field (
193+ tracing_ignore : list [str ] = field (
194194 default_factory = lambda : [
195195 "_update_causal_mask" ,
196196 "create_causal_mask" ,
@@ -209,7 +209,7 @@ class DatasetArguments(CustomDatasetArguments):
209209 "{module}.{method_name} or {function_name}"
210210 },
211211 )
212- sequential_targets : Optional [ List [ str ]] = field (
212+ sequential_targets : list [ str ] | None = field (
213213 default = None ,
214214 metadata = {
215215 "help" : "List of layer targets for the sequential pipeline. "
0 commit comments