Skip to content

Commit

Permalink
add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
jstjohn committed Sep 25, 2024
1 parent 9c91a2b commit 3b6be46
Showing 1 changed file with 7 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,15 +48,18 @@ class GenericBioBertNeMo1LightningModuleConnector(

@abstractmethod
def get_config_class(self) -> Type[BioBertGenericConfig[MegatronBioBertModelT]]:
"""Return the class of the config so that it's easier to subclass."""
raise NotImplementedError("Implement me")

def init(self) -> BioBertLightningModule:
"""Initialize the lightning module (no model initialized in it yet)."""
return BioBertLightningModule(
self.config,
self.tokenizer,
)

def apply(self, output_path: Path) -> Path:
"""Save this nemo1 checkpoint to the desired output path in nemo2 format."""
nemo1_path = str(self) # self is a Path object
with tarfile.open(nemo1_path, "r") as old_ckpt:
ckpt_file = old_ckpt.extractfile("./model_weights.ckpt")
Expand All @@ -82,6 +85,7 @@ def is_te_mapping(model: BioBertLightningModule) -> bool:
}

def convert_state(self, source: Dict[str, torch.Tensor], target: BioBertLightningModule) -> BioBertLightningModule:
"""Convert the input state_dict keys from nemo1 biobert to nemo2 biobert."""
te_mapping = self.is_te_mapping(target) # check for TE layers.
new_state_dict_from_old = {}
for k, v in source.items():
Expand All @@ -93,12 +97,11 @@ def convert_state(self, source: Dict[str, torch.Tensor], target: BioBertLightnin
@property
@abstractmethod
def tokenizer(self) -> "AutoTokenizer":
"""Generic method to return a tokenizer, override this for your implemented nemo1 to nemo2 biobert converter."""
raise NotImplementedError("Implement this method")
# from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

# return AutoTokenizer(self.save_hf_tokenizer_assets(str(self)))

def get_nemo1_config(self) -> Dict[str, Any]:
"""Return the nemo1 config from the checkpoint."""
# First read from nemo file and get config settings
nemo1_path = str(self) # self inherits from PosixPath
with tarfile.open(nemo1_path, "r") as old_ckpt:
Expand All @@ -109,6 +112,7 @@ def get_nemo1_config(self) -> Dict[str, Any]:

@property
def config(self) -> BioBertGenericConfig[MegatronBioBertModelT]:
"""Convert and return the nemo2 config from the nemo1 config."""
nemo1_settings = self.get_nemo1_config()
cfg_class = self.get_config_class()
autocast_dtype = get_autocast_dtype(nemo1_settings["precision"])
Expand Down

0 comments on commit 3b6be46

Please sign in to comment.