@@ -294,9 +294,6 @@ def safe_globals():
294
294
if TYPE_CHECKING :
295
295
import optuna
296
296
297
- if is_datasets_available ():
298
- import datasets
299
-
300
297
logger = logging .get_logger (__name__ )
301
298
302
299
@@ -418,14 +415,14 @@ class Trainer:
418
415
def __init__ (
419
416
self ,
420
417
model : Union [PreTrainedModel , nn .Module , None ] = None ,
421
- args : TrainingArguments = None ,
418
+ args : Optional [ TrainingArguments ] = None ,
422
419
data_collator : Optional [DataCollator ] = None ,
423
420
train_dataset : Optional [Union [Dataset , IterableDataset , "datasets.Dataset" ]] = None ,
424
421
eval_dataset : Optional [Union [Dataset , dict [str , Dataset ], "datasets.Dataset" ]] = None ,
425
422
processing_class : Optional [
426
423
Union [PreTrainedTokenizerBase , BaseImageProcessor , FeatureExtractionMixin , ProcessorMixin ]
427
424
] = None ,
428
- model_init : Optional [Callable [[] , PreTrainedModel ]] = None ,
425
+ model_init : Optional [Callable [... , PreTrainedModel ]] = None ,
429
426
compute_loss_func : Optional [Callable ] = None ,
430
427
compute_metrics : Optional [Callable [[EvalPrediction ], dict ]] = None ,
431
428
callbacks : Optional [list [TrainerCallback ]] = None ,
0 commit comments