Skip to content

feat: refactor main_ds.py (3/n) Checkpointer Class #605

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from

Conversation

cdoern
Copy link
Contributor

@cdoern cdoern commented Jun 10, 2025

Introduce a new design for key components of main_ds.py. Namely splitting Model initialization, Accelerator initialization, Optimizer initialization, and Checkpoint saving initialization
into classes:

  1. Model
  2. Accelerator
  3. Checkpointer

The Checkpointer class introduces a unified approach to our various checkpointing techniques. A user can pass in their checkpointing style (full_state or hf_format), and the checkpointer, via checkpointer.checkpoint, will save the model using the selected method and other techniques (LoRA).

This PR adds the new class and unit tests for the class

see previous PRs #572 and #594

note: this is probably the last of these large refactor for now with subsequent smaller followup PRs for cleanup.

@mergify mergify bot added testing Relates to testing ci-failure labels Jun 10, 2025
@cdoern cdoern force-pushed the refactor-checkpoint branch from 5dca802 to 276e9b3 Compare June 10, 2025 15:15
@mergify mergify bot removed the ci-failure label Jun 10, 2025
model_conf from `AutoConfig` has some key info we need in the checkpointer. Associate it with the model class and its subclasses

Signed-off-by: Charlie Doern <[email protected]>
Copy link

E2E (NVIDIA L40S x4) (python 3.11) workflow launched on this PR: View run

Copy link

e2e workflow succeeded on this PR: View run, congrats!

Copy link
Collaborator

@fynnsu fynnsu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left some comments below.

I'm also assuming that the contents of the specific methods (like save_fsdp_lora_model) is largely unchanged. Is that correct?

print("[None] Skipping checkpointing.")

# pylint: disable=unused-argument
def save_fsdp_lora_model(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potentially for a future PR, but I think it would be cleaner to have a base Checkpointer abstract class and then have FSDPLoRACheckpointer, HFFormatAccelerateCheckpointer, etc. subclasses which each implement their own checkpoint method. Instead of doing our own custom routing with self._checkpoint_fn

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a similar argument for Model class before... Class hierarchies are exactly meant for such scenarios.

accelerator,
samples_seen,
is_lora=bool(args.lora_r),
checkpointer.save_hf_format_accelerate(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be checkpointer.checkpoint()

@@ -50,11 +50,13 @@ def __init__(
flash_enabled: bool = False,
lora_config: Optional[LoraConfig] = None,
lora_quant_bits: int = 0,
model_conf=None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is model_conf for? Currently we only seem to be using it to access model_conf.model_type. Could we just store model_type instead?

Also currently none of those accesses of model_conf.model_type check that model_conf is not None before trying to access the attribute, so this will raise an error if model_conf is ever actually None (it's current default value).

checkpointer = Checkpointer(
strategy=strategy, model=m, optimizer=optimizer, accelerator=accelerator
)
checkpointer.load_latest_full_state(Path(args.output_dir))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can args.output_dir be set in the Checkpointer.__init__? It seems like we're currently passing it into every load/checkpoint function but it doesn't seem like it's changing values (or should change values) currently.

def save_fsdp_lora_model(
self,
output_dir: Path,
**kwargs,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please remove kwargs that are not used; transform those used into specific arguments for the information that needs to be passed (with proper names, types etc.)

model: Model,
optimizer: torch.optim.Optimizer,
accelerator: Accelerator,
strategy="all",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make it enum

# pylint: disable=unused-argument
def save_full_state(
self,
output_dir,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

define all args' types.

print("[None] Skipping checkpointing.")

# pylint: disable=unused-argument
def save_fsdp_lora_model(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made a similar argument for Model class before... Class hierarchies are exactly meant for such scenarios.


@pytest.fixture
def mock_accelerator():
accelerator = MagicMock(spec=Accelerator)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of mocks, you could introduce a new subclass for TestAccelerator that would "do nothing" / "do bare minimum" for test purposes. Same for the rest. Why do we have to have mocks just to create an object? Are init methods destructive / invasive? (Maybe it should be fixed then - it should be generally safe / cheap to create objects.)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
testing Relates to testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants