Handle fetch optimizer states for the KV ZCH load state dict case #4512
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Summary:
This diff updates
KVZCHCachedData
to hold multiple optimizer states per table in cached_optimizer_states_per_table, and updates apply_state_dict to handle writing out multiple optimizer states per table row to the cache. This is needed for enabling other optimizers to work with SSD TBE, such as Partial Rowwise Adam.There are 4 cases to handle when attempting to fetch the split optimizer states:
self.load_state_dict
isTrue
(i.e. fall back toself._cached_kvzch_data
)self.load_state_dict
isFalse
, andself.enable_optimizer_offloading
is falseself.load_state_dict
isFalse
, andself.enable_optimizer_offloading
isTrue
The diff completes the handling of returning optimizer states for the KV ZCH case, but where
self.load_state_dict
is true (case 2).Reviewed By: emlin
Differential Revision: D77771359