feat: refactor main_ds.py (3/n) Checkpointer Class #605
base: main
Conversation
Signed-off-by: Charlie Doern <[email protected]>
force-pushed from 5dca802 to 276e9b3 (Compare)
model_conf from `AutoConfig` has some key info we need in the checkpointer. Associate it with the model class and its subclasses.
Signed-off-by: Charlie Doern <[email protected]>
E2E (NVIDIA L40S x4) (python 3.11) workflow launched on this PR: View run
e2e workflow succeeded on this PR: View run, congrats!
Left some comments below.
I'm also assuming that the contents of the specific methods (like save_fsdp_lora_model) are largely unchanged. Is that correct?
print("[None] Skipping checkpointing.") | ||
|
||
# pylint: disable=unused-argument | ||
def save_fsdp_lora_model( |
Potentially for a future PR, but I think it would be cleaner to have a base Checkpointer abstract class and then have FSDPLoRACheckpointer, HFFormatAccelerateCheckpointer, etc. subclasses which each implement their own checkpoint method, instead of doing our own custom routing with self._checkpoint_fn.
I made a similar argument for Model class before... Class hierarchies are exactly meant for such scenarios.
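For illustration, a minimal sketch of what that hierarchy could look like — the subclass and method names come from the suggestion above, but the constructor and parameters here are assumptions, not the PR's actual signatures:

```python
from abc import ABC, abstractmethod
from pathlib import Path


class Checkpointer(ABC):
    """Base class: shared setup lives here, each subclass owns one strategy."""

    def __init__(self, model, optimizer, accelerator):
        self.model = model
        self.optimizer = optimizer
        self.accelerator = accelerator

    @abstractmethod
    def checkpoint(self, output_dir: Path, samples_seen: int) -> None:
        """Persist model/optimizer state for this strategy."""


class FSDPLoRACheckpointer(Checkpointer):
    def checkpoint(self, output_dir: Path, samples_seen: int) -> None:
        ...  # FSDP + LoRA specific saving logic


class HFFormatAccelerateCheckpointer(Checkpointer):
    def checkpoint(self, output_dir: Path, samples_seen: int) -> None:
        ...  # HF-format / accelerate saving logic
```

The caller then just picks the subclass once and calls checkpoint(), rather than routing through self._checkpoint_fn.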
accelerator,
samples_seen,
is_lora=bool(args.lora_r),
checkpointer.save_hf_format_accelerate(
Shouldn't this be checkpointer.checkpoint()?
@@ -50,11 +50,13 @@ def __init__(
    flash_enabled: bool = False,
    lora_config: Optional[LoraConfig] = None,
    lora_quant_bits: int = 0,
    model_conf=None,
What is model_conf for? Currently we only seem to be using it to access model_conf.model_type. Could we just store model_type instead?
Also, currently none of those accesses of model_conf.model_type check that model_conf is not None before trying to access the attribute, so this will raise an error if model_conf is ever actually None (its current default value).
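A minimal sketch of that suggestion (the class shown is illustrative, not the PR's actual constructor):

```python
class Model:
    # Illustrative only: store just the field we use, guarding the None default.
    def __init__(self, model_conf=None):
        # model_conf may legitimately be None (its default), so check before
        # touching model_conf.model_type.
        self.model_type = (
            model_conf.model_type if model_conf is not None else None
        )
```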
checkpointer = Checkpointer(
    strategy=strategy, model=m, optimizer=optimizer, accelerator=accelerator
)
checkpointer.load_latest_full_state(Path(args.output_dir))
Can args.output_dir be set in the Checkpointer.__init__? It seems like we're currently passing it into every load/checkpoint function, but it doesn't seem like it's changing values (or should change values) currently.
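Something along these lines, mirroring the call site above — output_dir as a constructor parameter is the suggestion, not what the PR currently does, and the surrounding variables come from main_ds.py:

```python
# Hypothetical sketch: take output_dir once at construction time,
# so the load/checkpoint calls no longer need to repeat it.
checkpointer = Checkpointer(
    strategy=strategy,
    model=m,
    optimizer=optimizer,
    accelerator=accelerator,
    output_dir=Path(args.output_dir),  # assumed new parameter
)
checkpointer.load_latest_full_state()  # no per-call path needed
```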
def save_fsdp_lora_model(
    self,
    output_dir: Path,
    **kwargs,
Please remove kwargs that are not used; transform those that are used into specific arguments for the information that needs to be passed (with proper names, types, etc.).
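For example, something like the following — the specific parameters are only a guess at what the method actually reads from **kwargs today:

```python
from pathlib import Path


# Hypothetical sketch: replace the catch-all **kwargs with explicit,
# typed parameters (names here are assumed, not the real signature).
def save_fsdp_lora_model(
    self,
    output_dir: Path,
    samples_seen: int,
) -> None:
    ...
```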
model: Model,
optimizer: torch.optim.Optimizer,
accelerator: Accelerator,
strategy="all",
Make it an enum.
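For instance — the member values here are guesses based on the strategies mentioned in this PR:

```python
from enum import Enum


class CheckpointStrategy(str, Enum):
    # Values guessed from the strategies discussed in this PR.
    ALL = "all"
    FULL_STATE = "full_state"
    HF_FORMAT = "hf_format"


# strategy: CheckpointStrategy = CheckpointStrategy.ALL
```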
# pylint: disable=unused-argument
def save_full_state(
    self,
    output_dir,
Define types for all arguments.
print("[None] Skipping checkpointing.") | ||
|
||
# pylint: disable=unused-argument | ||
def save_fsdp_lora_model( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I made a similar argument for Model class before... Class hierarchies are exactly meant for such scenarios.
@pytest.fixture
def mock_accelerator():
    accelerator = MagicMock(spec=Accelerator)
Instead of mocks, you could introduce a new TestAccelerator subclass that would "do nothing" / "do the bare minimum" for test purposes. Same for the rest. Why do we have to have mocks just to create an object? Are the init methods destructive / invasive? (Maybe that should be fixed then; it should generally be safe / cheap to create objects.)
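A rough illustration of that idea — this class does not exist in the repo, and the overridden methods are just the ones such tests commonly touch:

```python
from accelerate import Accelerator


class TestAccelerator(Accelerator):
    """Hypothetical 'do nothing' stand-in for tests."""

    def __init__(self):
        # Intentionally skip Accelerator.__init__ so no process-group or
        # device state is set up; tests only need the interface.
        pass

    def wait_for_everyone(self):
        pass  # no-op in single-process tests
```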
Introduce a new design for key components of main_ds.py, splitting Model initialization, Accelerator initialization, Optimizer initialization, and checkpoint saving into classes:
The Checkpointer class introduces a unified approach to our various checkpointing techniques. A user can pass in their checkpointing style (full_state or hf_format), and the checkpointer, via checkpointer.checkpoint, will save the model using the selected method and other techniques (LoRA).
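For illustration, usage looks roughly like this — keyword names follow the diff snippets in this PR, the surrounding variables come from main_ds.py, and the exact checkpoint() signature may differ:

```python
from pathlib import Path

# Sketch based on the description above; exact argument names may differ.
checkpointer = Checkpointer(
    strategy="hf_format",  # or "full_state"
    model=m,
    optimizer=optimizer,
    accelerator=accelerator,
)
checkpointer.load_latest_full_state(Path(args.output_dir))
checkpointer.checkpoint(Path(args.output_dir))
```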
This PR adds the new class and unit tests for it.
See previous PRs #572 and #594.
Note: this is probably the last of these large refactors for now, with subsequent smaller follow-up PRs for cleanup.