diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py index f81c6823716..867850923d6 100644 --- a/tests/slow/test_grpo_slow.py +++ b/tests/slow/test_grpo_slow.py @@ -19,8 +19,10 @@ import numpy as np import pytest import torch +import transformers from accelerate.utils.memory import release_memory from datasets import Dataset, Features, Image, Value, load_dataset +from packaging.version import Version from parameterized import parameterized from transformers import ( AutoModelForCausalLM, @@ -171,6 +173,8 @@ def test_training_with_liger_grpo_loss_and_peft(self, model_name): @parameterized.expand(MODELS_TO_TEST) def test_training_with_transformers_paged(self, model_name): """Test that training works with transformers paged implementation (requires GPU).""" + if Version(transformers.__version__) < Version("4.56.2"): + pytest.xfail("Upstream bug in transformers (GH#40692). Fix merged; awaiting release >= 4.56.2") training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index b26eb299337..47fbd1f5a1f 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. - import pytest +import transformers from datasets import Dataset, features, load_dataset +from packaging.version import Version from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer from transformers.testing_utils import require_peft, require_torch_accelerator, require_vision @@ -421,6 +422,8 @@ def test_generation_config_setup(self): @require_torch_accelerator @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) def test_training_with_transformers_paged(self, config_name): + if Version(transformers.__version__) < Version("4.56.2"): + pytest.xfail("Upstream bug in transformers (GH#40692). Fix merged; awaiting release >= 4.56.2") training_args = OnlineDPOConfig( output_dir=self.tmp_dir, per_device_train_batch_size=2,