diff --git a/dspy/clients/lm_local.py b/dspy/clients/lm_local.py
index c937ee6e5a..698c4b44e0 100644
--- a/dspy/clients/lm_local.py
+++ b/dspy/clients/lm_local.py
@@ -253,25 +253,43 @@ def tokenize_function(example):
         task_type="CAUSAL_LM",
     )
 
-    sft_config = SFTConfig(
-        output_dir=train_kwargs["output_dir"],
-        num_train_epochs=train_kwargs["num_train_epochs"],
-        per_device_train_batch_size=train_kwargs["per_device_train_batch_size"],
-        gradient_accumulation_steps=train_kwargs["gradient_accumulation_steps"],
-        learning_rate=train_kwargs["learning_rate"],
-        max_grad_norm=2.0,  # note that the current SFTConfig default is 1.0
-        logging_steps=20,
-        warmup_ratio=0.03,
-        lr_scheduler_type="constant",
-        save_steps=10_000,
-        bf16=train_kwargs["bf16"],
-        max_seq_length=train_kwargs["max_seq_length"],
-        packing=train_kwargs["packing"],
-        dataset_kwargs={  # We need to pass dataset_kwargs because we are processing the dataset ourselves
+    # Handle compatibility between different TRL versions
+    # TRL >= 0.16.0 uses 'max_length' instead of 'max_seq_length' in SFTConfig
+    import inspect
+    sft_config_params = inspect.signature(SFTConfig.__init__).parameters
+
+    # Build config parameters, handling the max_seq_length -> max_length change
+    config_kwargs = {
+        "output_dir": train_kwargs["output_dir"],
+        "num_train_epochs": train_kwargs["num_train_epochs"],
+        "per_device_train_batch_size": train_kwargs["per_device_train_batch_size"],
+        "gradient_accumulation_steps": train_kwargs["gradient_accumulation_steps"],
+        "learning_rate": train_kwargs["learning_rate"],
+        "max_grad_norm": 2.0,  # note that the current SFTConfig default is 1.0
+        "logging_steps": 20,
+        "warmup_ratio": 0.03,
+        "lr_scheduler_type": "constant",
+        "save_steps": 10_000,
+        "bf16": train_kwargs["bf16"],
+        "packing": train_kwargs["packing"],
+        "dataset_kwargs": {  # We need to pass dataset_kwargs because we are processing the dataset ourselves
             "add_special_tokens": False,  # Special tokens handled by template
             "append_concat_token": False,  # No additional separator needed
         },
-    )
+    }
+
+    # Add the sequence length parameter using the appropriate name for the TRL version
+    if "max_seq_length" in sft_config_params:
+        # Older TRL versions (< 0.16.0)
+        config_kwargs["max_seq_length"] = train_kwargs["max_seq_length"]
+    elif "max_length" in sft_config_params:
+        # Newer TRL versions (>= 0.16.0)
+        config_kwargs["max_length"] = train_kwargs["max_seq_length"]
+    else:
+        logger.warning("Neither 'max_seq_length' nor 'max_length' parameter found in SFTConfig. "
+                       "This may indicate an incompatible TRL version.")
+
+    sft_config = SFTConfig(**config_kwargs)
 
     logger.info("Starting training")
     trainer = SFTTrainer(