Skip to content

Commit

Permalink
only look at relevant tests
Browse files Browse the repository at this point in the history
  • Loading branch information
eitanturok committed Sep 6, 2024
1 parent 237e7b1 commit 6116c21
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions tests/trainer/test_fsdp_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,29 +796,29 @@ def mock_get_checkpoint_validation_function():


@pytest.mark.gpu
@pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
# @pytest.mark.parametrize('use_remote', [pytest.param(True, marks=pytest.mark.remote), False])
@pytest.mark.parametrize(
'weights_only,optimizer,precision,autoresume,load_ignore_keys,use_symlink,use_tp,use_hsdp',
[
pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param(True, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param(False, 'adam', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param(False, 'adamw', 'amp_fp16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param(False, 'adamw', 'amp_bf16', True, None, False, False, False, marks=pytest.mark.world_size(2)),
pytest.param(
False,
'adamw',
'amp_bf16',
False,
['rng'],
False,
False,
False,
marks=pytest.mark.world_size(2),
),
pytest.param(False, 'adamw', 'amp_bf16', False, None, True, False, False, marks=pytest.mark.world_size(2)),
# pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
# pytest.param(True, 'adamw', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
# pytest.param(False, 'adam', 'amp_bf16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
# pytest.param(False, 'adamw', 'amp_fp16', False, None, False, False, False, marks=pytest.mark.world_size(2)),
# pytest.param(False, 'adamw', 'amp_bf16', True, None, False, False, False, marks=pytest.mark.world_size(2)),
# pytest.param(
# False,
# 'adamw',
# 'amp_bf16',
# False,
# ['rng'],
# False,
# False,
# False,
# marks=pytest.mark.world_size(2),
# ),
# pytest.param(False, 'adamw', 'amp_bf16', False, None, True, False, False, marks=pytest.mark.world_size(2)),
pytest.param(False, 'adamw', 'amp_bf16', False, None, False, True, False, marks=pytest.mark.world_size(4)),
pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
# pytest.param(False, 'adamw', 'amp_bf16', False, None, False, False, True, marks=pytest.mark.world_size(4)),
],
)
@pytest.mark.filterwarnings(r'ignore:TypedStorage is deprecated.:UserWarning')
Expand All @@ -835,13 +835,13 @@ def test_fsdp_partitioned_state_dict_load(
use_symlink: bool,
use_tp: bool,
use_hsdp: bool,
use_remote,
s3_bucket,
s3_ephemeral_prefix,
request,
use_remote = False,
):
if use_tp and version.parse(torch.__version__) < version.parse('2.4.0'):
pytest.skip('TP has full state dict issues before PyTorch 2.4.')
pytest.skip('TP has sharded state dict issues before PyTorch 2.4.')
if weights_only and autoresume:
pytest.skip('Weights only with autoresume is not supported')
if (use_tp or use_hsdp) and version.parse(torch.__version__) < version.parse('2.3.0'):
Expand Down

0 comments on commit 6116c21

Please sign in to comment.