From fecddf5ccba5fc29eb52b2af7b7c6a0e58ce6b25 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franck=20St=C3=A9phane=20Ndzomga?=
 <101533724+fsndzomga@users.noreply.github.com>
Date: Sat, 20 Sep 2025 21:46:29 +0000
Subject: [PATCH 1/2] Fix SFTConfig max_seq_length compatibility issue with
 newer TRL versions

- Add a compatibility layer for TRL versions >= 0.16.0, which use
  'max_length' instead of 'max_seq_length'
- Dynamically detect which parameter is supported by inspecting the
  SFTConfig signature
- Maintain backward compatibility with older TRL versions
- Fix TypeError: SFTConfig.__init__() got an unexpected keyword argument
  'max_seq_length'

Resolves #8762
---
 dspy/clients/lm_local.py | 50 ++++++++++++++++++++++++++++++++++----------------
 1 file changed, 34 insertions(+), 16 deletions(-)

diff --git a/dspy/clients/lm_local.py b/dspy/clients/lm_local.py
index c937ee6e5a..3c3f705517 100644
--- a/dspy/clients/lm_local.py
+++ b/dspy/clients/lm_local.py
@@ -253,25 +253,43 @@ def tokenize_function(example):
             task_type="CAUSAL_LM",
         )
 
-    sft_config = SFTConfig(
-        output_dir=train_kwargs["output_dir"],
-        num_train_epochs=train_kwargs["num_train_epochs"],
-        per_device_train_batch_size=train_kwargs["per_device_train_batch_size"],
-        gradient_accumulation_steps=train_kwargs["gradient_accumulation_steps"],
-        learning_rate=train_kwargs["learning_rate"],
-        max_grad_norm=2.0,  # note that the current SFTConfig default is 1.0
-        logging_steps=20,
-        warmup_ratio=0.03,
-        lr_scheduler_type="constant",
-        save_steps=10_000,
-        bf16=train_kwargs["bf16"],
-        max_seq_length=train_kwargs["max_seq_length"],
-        packing=train_kwargs["packing"],
-        dataset_kwargs={  # We need to pass dataset_kwargs because we are processing the dataset ourselves
+    # Handle compatibility between different TRL versions
+    # TRL >= 0.16.0 uses 'max_length' instead of 'max_seq_length' in SFTConfig
+    import inspect
+    sft_config_params = inspect.signature(SFTConfig.__init__).parameters
+    
+    # Build config parameters, handling the max_seq_length -> max_length change
+    config_kwargs = {
+        "output_dir": train_kwargs["output_dir"],
+        "num_train_epochs": train_kwargs["num_train_epochs"],
+        "per_device_train_batch_size": train_kwargs["per_device_train_batch_size"],
+        "gradient_accumulation_steps": train_kwargs["gradient_accumulation_steps"],
+        "learning_rate": train_kwargs["learning_rate"],
+        "max_grad_norm": 2.0,  # note that the current SFTConfig default is 1.0
+        "logging_steps": 20,
+        "warmup_ratio": 0.03,
+        "lr_scheduler_type": "constant",
+        "save_steps": 10_000,
+        "bf16": train_kwargs["bf16"],
+        "packing": train_kwargs["packing"],
+        "dataset_kwargs": {  # We need to pass dataset_kwargs because we are processing the dataset ourselves
             "add_special_tokens": False,  # Special tokens handled by template
             "append_concat_token": False,  # No additional separator needed
         },
-    )
+    }
+    
+    # Add the sequence length parameter using the appropriate name for the TRL version
+    if "max_seq_length" in sft_config_params:
+        # Older TRL versions (< 0.16.0)
+        config_kwargs["max_seq_length"] = train_kwargs["max_seq_length"]
+    elif "max_length" in sft_config_params:
+        # Newer TRL versions (>= 0.16.0) 
+        config_kwargs["max_length"] = train_kwargs["max_seq_length"]
+    else:
+        logger.warning("Neither 'max_seq_length' nor 'max_length' parameter found in SFTConfig. "
+                       "This may indicate an incompatible TRL version.")
+    
+    sft_config = SFTConfig(**config_kwargs)
 
     logger.info("Starting training")
     trainer = SFTTrainer(

From 80a0903b1cfcb3a865ddd290abb4746900da6451 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Franck=20St=C3=A9phane=20Ndzomga?=
 <101533724+fsndzomga@users.noreply.github.com>
Date: Sat, 20 Sep 2025 21:51:17 +0000
Subject: [PATCH 2/2] Fix formatting issues found by ruff

- Remove trailing whitespace
- Fix line spacing consistency
---
 dspy/clients/lm_local.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/dspy/clients/lm_local.py b/dspy/clients/lm_local.py
index 3c3f705517..698c4b44e0 100644
--- a/dspy/clients/lm_local.py
+++ b/dspy/clients/lm_local.py
@@ -257,7 +257,7 @@ def tokenize_function(example):
     # TRL >= 0.16.0 uses 'max_length' instead of 'max_seq_length' in SFTConfig
     import inspect
     sft_config_params = inspect.signature(SFTConfig.__init__).parameters
-    
+
     # Build config parameters, handling the max_seq_length -> max_length change
     config_kwargs = {
         "output_dir": train_kwargs["output_dir"],
@@ -277,18 +277,18 @@ def tokenize_function(example):
             "append_concat_token": False,  # No additional separator needed
         },
     }
-    
+
     # Add the sequence length parameter using the appropriate name for the TRL version
     if "max_seq_length" in sft_config_params:
         # Older TRL versions (< 0.16.0)
         config_kwargs["max_seq_length"] = train_kwargs["max_seq_length"]
     elif "max_length" in sft_config_params:
-        # Newer TRL versions (>= 0.16.0) 
+        # Newer TRL versions (>= 0.16.0)
         config_kwargs["max_length"] = train_kwargs["max_seq_length"]
     else:
         logger.warning("Neither 'max_seq_length' nor 'max_length' parameter found in SFTConfig. "
                        "This may indicate an incompatible TRL version.")
-    
+
     sft_config = SFTConfig(**config_kwargs)
 
     logger.info("Starting training")