- 
                Notifications
    You must be signed in to change notification settings 
- Fork 452
fix resumption when ref policy loading fails #1110
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
| Summary of ChangesHello @saurabh111233212, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the stability of job resumption by introducing error handling for the reference policy loading process. Previously, a corrupted or unmaterializable reference policy checkpoint would lead to a job crash. With this change, the system will now catch loading errors, log a warning, and continue execution using the base weights, ensuring more resilient operation. Highlights
 Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either  
 Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a  Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
 | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a try-except block to gracefully handle failures when loading the reference policy checkpoint, preventing job crashes due to corrupted or missing files. The change is correct and improves the robustness of the training resumption process. I've added one suggestion to make the exception handling even more robust and to ensure consistency with other parts of the code.
| try: | ||
| state_dict = torch.load(self.ref_policy_checkpoint_path, map_location=self.device) | ||
| if hasattr(self.ref_policy, "module"): | ||
| # If wrapped by DeepSpeed | ||
| self.ref_policy.module.load_state_dict(state_dict) | ||
| else: | ||
| self.ref_policy.load_state_dict(state_dict) | ||
| except (OSError, RuntimeError) as err: | ||
| logger.warning( | ||
| f"{self.rank=}: Failed to load reference policy from " | ||
| f"{self.ref_policy_checkpoint_path}: {err}. Proceeding with base weights." | ||
| ) | ||
| else: | ||
| self.ref_policy.load_state_dict(state_dict) | ||
| logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}") | ||
| logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}") | 
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The try-except block is a great addition for robustness. I have a couple of suggestions to make it even more robust:
- 
Broader Exception Handling: The current except (OSError, RuntimeError)might not catch all possible file corruption errors fromtorch.load, such aspickle.UnpicklingErrororzipfile.BadZipFile. Since loading the reference policy is a best-effort operation, it would be safer to catch a broaderExceptionto prevent any loading-related crash.
- 
Consistent weights_onlyparameter: Earlier in this file (lines 625-636),torch.loadis monkey-patched withweights_only=Falseto handle_pickle.UnpicklingError. For consistency and to prevent similar issues, especially as the default forweights_onlymay change in future PyTorch versions, it's a good practice to explicitly setweights_only=Falsein thistorch.loadcall as well.
Here is a suggested change that incorporates these points.
| try: | |
| state_dict = torch.load(self.ref_policy_checkpoint_path, map_location=self.device) | |
| if hasattr(self.ref_policy, "module"): | |
| # If wrapped by DeepSpeed | |
| self.ref_policy.module.load_state_dict(state_dict) | |
| else: | |
| self.ref_policy.load_state_dict(state_dict) | |
| except (OSError, RuntimeError) as err: | |
| logger.warning( | |
| f"{self.rank=}: Failed to load reference policy from " | |
| f"{self.ref_policy_checkpoint_path}: {err}. Proceeding with base weights." | |
| ) | |
| else: | |
| self.ref_policy.load_state_dict(state_dict) | |
| logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}") | |
| logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}") | |
| try: | |
| state_dict = torch.load(self.ref_policy_checkpoint_path, map_location=self.device, weights_only=False) | |
| if hasattr(self.ref_policy, "module"): | |
| # If wrapped by DeepSpeed | |
| self.ref_policy.module.load_state_dict(state_dict) | |
| else: | |
| self.ref_policy.load_state_dict(state_dict) | |
| except Exception as err: | |
| logger.warning( | |
| f"{self.rank=}: Failed to load reference policy from " | |
| f"{self.ref_policy_checkpoint_path}: {err}. Proceeding with base weights." | |
| ) | |
| else: | |
| logger.info(f"{self.rank=}: Loaded reference policy checkpoint from {self.ref_policy_checkpoint_path}") | 
when the ref policy is corrupted or fails to materialize, the resuming a job crashes. This fixes that