diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py
index 5f97c6092c..55f5495c53 100644
--- a/tests/trainer/test_fsdp_checkpoint.py
+++ b/tests/trainer/test_fsdp_checkpoint.py
@@ -796,29 +796,29 @@ def mock_get_checkpoint_validation_function():
 @pytest.mark.gpu
-@pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
+# @pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
 @pytest.mark.parametrize(
     'weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp,use_hsdp',
     [
-        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(True, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(False, 'adam', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(False, 'adamw', 'amp_fp16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(False, 'adamw', 'amp_bf16', True, None, False, False, False, marks=pytest.mark.world_size(2)),
-        pytest.param(
-            False,
-            'adamw',
-            'amp_bf16',
-            False,
-            ['rng'],
-            False,
-            False,
-            False,
-            marks=pytest.mark.world_size(2),
-        ),
-        pytest.param(False, 'adamw', 'amp_bf16', False, None, True, False, False, marks=pytest.mark.world_size(2)),
+        # pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        # pytest.param(True, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        # pytest.param(False, 'adam', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        # pytest.param(False, 'adamw', 'amp_fp16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
+        # pytest.param(False, 'adamw', 'amp_bf16', True, None, False, False, False, marks=pytest.mark.world_size(2)),
+        # pytest.param(
+        #     False,
+        #     'adamw',
+        #     'amp_bf16',
+        #     False,
+        #     ['rng'],
+        #     False,
+        #     False,
+        #     False,
+        #     marks=pytest.mark.world_size(2),
+        # ),
+        # pytest.param(False, 'adamw', 'amp_bf16', False, None, True, False, False, marks=pytest.mark.world_size(2)),
         pytest.param(False, 'adamw', 'amp_bf16', False, None, False, True, False, marks=pytest.mark.world_size(4)),
-        pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
+        # pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
     ],
 )
 @pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
@@ -835,13 +835,13 @@ def test_fsdp_partitioned_state_dict_load(
     use_symlink: bool,
     use_tp: bool,
     use_hsdp: bool,
-    use_remote,
     s3_bucket,
     s3_ephemeral_prefix,
     request,
+    use_remote = False,
 ):
-    if use_tp:
-        pytest.skip('TP on PyTorch 2.3 has sharded state dict issues.')
+    if use_tp and version.parse(torch.__version__) < version.parse('2.4.0'):
+        pytest.skip('TP has sharded state dict issues before PyTorch 2.4.')
     if weights_only and autoresume:
         pytest.skip('Weights only with autoresume is not supported')
     if (use_tp or use_hsdp) and version.parse(torch.__version__) < version.parse('2.3.0'):