NeMo1 -> NeMo2 checkpoint conversion #180
Open
jstjohn wants to merge 26 commits into main from jstjohn/nemo1-checkpoint-connector
Conversation
akoumpa reviewed Sep 27, 2024
sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/connector.py (outdated; resolved)
Does this work for you? I don't think the state-dict loading line actually replaces the meta tensor; I think the contents are filled in.

    import torch
    import torch.nn as nn

    # Define a module with parameters initialized as meta tensors
    class MyModule(nn.Module):
        def __init__(self):
            super(MyModule, self).__init__()
            self.linear = nn.Linear(10, 10, device='meta')  # Meta device

    # Instantiate the module
    module = MyModule()

    # Print the current parameter device (meta)
    print(module.linear.weight.device)  # Should output "meta"

    # Create a state_dict with actual tensors
    state_dict = {
        'linear.weight': torch.randn(10, 10),
        'linear.bias': torch.randn(10),
    }

    # Load the state_dict into the module
    module.load_state_dict(state_dict)

    # Check that the parameters have been replaced with actual tensors
    print(module.linear.weight.device)  # Should output "cpu" or the device of the tensors in state_dict

This prints meta device. And there's a warning (yikes!), not an error!

On Sep 26, 2024, at 7:27 PM, Alexandros Koumparoulis wrote:
@akoumpa commented on this pull request.
In sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/connector.py:
+ def is_te_mapping(model: BioBertLightningModule) -> bool:
+ """Check for TE layers, for now infer this from the config."""
+ return model.config.biobert_spec_option in {
+ BiobertSpecOption.bert_layer_with_transformer_engine_spec,
+ BiobertSpecOption.bert_layer_with_transformer_engine_and_qk_ln_spec,
+ }
+
+ def convert_state(self, source: Dict[str, torch.Tensor], target: BioBertLightningModule) -> BioBertLightningModule:
+ """Convert the input state_dict keys from nemo1 biobert to nemo2 biobert."""
+ te_mapping = self.is_te_mapping(target) # check for TE layers.
+ target.module.cpu()
+ new_state_dict_from_old = {}
+ for k, v in source.items():
+ new_key = nemo1_to_nemo2_biobert_key_mapping(k, new_model_prefix="", te_mapping=te_mapping)
+ new_state_dict_from_old[new_key] = v
+ target.module.load_state_dict(new_state_dict_from_old, strict=not te_mapping)
@jstjohn I would add here something like the following
meta_tensors = list(filter(lambda x: isinstance(x[1], torch.Tensor) and x[1].device.type == 'meta', target.module.state_dict().items()))
assert len(meta_tensors) == 0, meta_tensors
This should print all the tensors that have meta device. The assumption here was that the input state_dict (in this case new_state_dict_from_old) contains all the parameters needed by target.module.
Please let me know if that works.
The other option would be to add a parameter to nemo_setup to initialize on CPU instead of meta, but I want to avoid this: for large models (e.g. 100B parameters) it takes too long, and it's not useful since the initialized parameters will be overwritten with the ones from the checkpoint.
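For reference, a minimal self-contained sketch of the behavior under discussion, assuming PyTorch >= 2.1 (this is an illustration, not a change made in this PR): passing assign=True to load_state_dict replaces meta-device parameters with the tensors from the state dict instead of copying into them in place, after which the suggested no-meta-tensors assertion passes.

    import torch
    import torch.nn as nn

    # A module whose parameters start out on the meta device, as after nemo_setup.
    module = nn.Linear(10, 10, device="meta")

    state_dict = {"weight": torch.randn(10, 10), "bias": torch.randn(10)}

    # assign=True (PyTorch >= 2.1) swaps the meta parameters for the real tensors.
    module.load_state_dict(state_dict, assign=True)

    # The check suggested above now passes: no meta tensors remain.
    meta_tensors = [name for name, t in module.state_dict().items() if t.device.type == "meta"]
    assert len(meta_tensors) == 0, meta_tensors
    print(module.weight.device)  # cpu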
Haha, ChatGPT wrote that code and was all like "this will work". I asked it to run it; it printed meta device. Anyway, I manually tried running this in a Colab instance since I'm AFK, and I can confirm that it just prints a warning and doesn't actually fill in the tensors.
pstjohn reviewed Oct 1, 2024
Yes, it's newer than TOT. It's on the commit I'm working on, on the NeMo side, to fix the checkpoint conversion stuff.

On Sep 30, 2024, at 7:01 PM, Peter St. John wrote:
@pstjohn commented on this pull request.
On 3rdparty/NeMo:
Is this on TOT? Let's just not downgrade if the dependabot updates bring us to a more recent commit
pstjohn reviewed Oct 1, 2024
sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py (outdated; resolved)
sub-packages/bionemo-geneformer/tests/bionemo/geneformer/test_model.py (outdated; resolved)
sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/connector.py (outdated; resolved)
sub-packages/bionemo-llm/src/bionemo/llm/model/biobert/connector.py (outdated; resolved)
sub-packages/bionemo-esm2/tests/bionemo/esm2/model/test_model.py (outdated; resolved)
jstjohn commented Oct 4, 2024
Comment on lines +26 to +37
"""Usage: | ||
# ESM2 3B | ||
## ESM2 3b checkpoint conversion: | ||
python scripts/protein/esm2/make_nemo2_checkpoints.py --s3-path s3://bionemo-ci/models/esm2nv_3B_converted.nemo --output-path ~/.cache/bionemo/checkpoints/esm2_3B_nemo2 | ||
## ESM2 3b checkpoint upload (recursive since it is a directory) | ||
aws s3 cp --recursive ~/.cache/bionemo/checkpoints/esm2_3B_nemo2 s3://bionemo-ci/models/esm2_3B_nemo2 | ||
# ESM2 650M | ||
## ESM2 650M checkpoint conversion | ||
python scripts/protein/esm2/make_nemo2_checkpoints.py --s3-path s3://bionemo-ci/models/esm2nv_650M_converted.nemo --output-path ~/.cache/bionemo/checkpoints/esm2_650M_nemo2 | ||
## ESM2 650M checkpoint upload | ||
aws s3 cp --recursive ~/.cache/bionemo/checkpoints/esm2_650M_nemo2 s3://bionemo-ci/models/esm2_650M_nemo2 | ||
""" |
TODO: add the steps to create the .tar.gz files, for example:

    cd esm2_650M_nemo2 && tar czvf ../esm2_650M_nemo2.tar.gz *
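If it helps when writing those steps up, here is an equivalent of that command using Python's tarfile module; the paths are the ones from the usage docstring above and are purely illustrative.

    import tarfile
    from pathlib import Path

    # Illustrative paths, mirroring `cd esm2_650M_nemo2 && tar czvf ../esm2_650M_nemo2.tar.gz *`
    ckpt_dir = Path.home() / ".cache/bionemo/checkpoints/esm2_650M_nemo2"
    archive = ckpt_dir.parent / "esm2_650M_nemo2.tar.gz"

    with tarfile.open(archive, "w:gz") as tar:
        # Add the directory contents at the archive root (not nested under the
        # directory name), matching the `*` in the shell invocation.
        for item in ckpt_dir.iterdir():
            tar.add(item, arcname=item.name)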
Summary
NeMo1 to NeMo2 checkpoint conversion
Details
Usage
This can be used either in an interactive session or placed in a script to do a one-off conversion of a specific checkpoint in NeMo1 format to a checkpoint in NeMo2 format.
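As a rough illustration of what the connector does internally, here is a minimal sketch of the key-remapping step. The import path for nemo1_to_nemo2_biobert_key_mapping and the model_weights.ckpt file name (extracted from the NeMo1 .nemo tar archive) are assumptions for the sake of the example, not verbatim from this PR.

    import torch

    # Assumed import path for the key-mapping helper used by the connector in this PR.
    from bionemo.llm.model.biobert.connector import nemo1_to_nemo2_biobert_key_mapping

    # Hypothetical path: the weights file extracted from the NeMo1 .nemo tar archive.
    old_state_dict = torch.load("model_weights.ckpt", map_location="cpu")

    # Remap NeMo1 BioBERT keys to their NeMo2 names; te_mapping=True when the target
    # config uses a Transformer Engine layer spec (see is_te_mapping in the connector).
    new_state_dict = {
        nemo1_to_nemo2_biobert_key_mapping(key, new_model_prefix="", te_mapping=True): value
        for key, value in old_state_dict.items()
    }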
Testing
The test checks that the checkpoint can be converted by this function, and that pointing a model at the new NeMo2 checkpoint works as expected when resuming for fine-tuning.
Tests for these changes can be run via: