Commit

Check in nemo change and updated packages
jstjohn committed Sep 26, 2024
1 parent 3b6be46 commit 144a0c4
Showing 5 changed files with 22 additions and 9 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/NeMo
Submodule NeMo updated 83 files
+19 −9 .github/workflows/cherry-pick-release-commit.yml
+50 −14 .github/workflows/cicd-main.yml
+39 −31 .github/workflows/release-freeze.yml
+10 −0 README.md
+3 −15 examples/asr/asr_cache_aware_streaming/speech_to_text_cache_aware_streaming_infer.py
+1 −11 examples/asr/asr_chunked_inference/aed/speech_to_text_aed_chunked_infer.py
+19 −23 examples/asr/asr_chunked_inference/ctc/speech_to_text_buffered_infer_ctc.py
+4 −3 examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py
+23 −19 examples/asr/asr_vad/speech_to_text_with_vad.py
+6 −11 examples/asr/experimental/sclite/speech_to_text_sclite.py
+6 −15 examples/asr/quantization/speech_to_text_calibrate.py
+6 −12 examples/asr/quantization/speech_to_text_quant_infer.py
+10 −13 examples/asr/quantization/speech_to_text_quant_infer_trt.py
+4 −12 examples/asr/speech_translation/translate_speech.py
+1 −11 examples/asr/transcribe_speech.py
+3 −2 examples/multimodal/multimodal_llm/neva/conf/neva_export.yaml
+1 −0 examples/multimodal/multimodal_llm/neva/conf/neva_trt_infer.yaml
+2 −0 examples/multimodal/multimodal_llm/neva/neva_export.py
+7 −0 examples/multimodal/multimodal_llm/neva/neva_trt_run.py
+7 −13 examples/slu/speech_intent_slot/eval_utils/inference.py
+2 −11 nemo/collections/asr/models/clustering_diarizer.py
+12 −4 nemo/collections/asr/parts/k2/graph_transducer.py
+2 −2 nemo/collections/asr/parts/preprocessing/features.py
+1 −1 nemo/collections/asr/parts/submodules/jasper.py
+2 −10 nemo/collections/asr/parts/utils/asr_confidence_benchmarking_utils.py
+3 −3 nemo/collections/asr/parts/utils/decoder_timestamps_utils.py
+3 −2 nemo/collections/asr/parts/utils/transcribe_utils.py
+1 −11 nemo/collections/asr/parts/utils/vad_utils.py
+4 −2 nemo/collections/audio/modules/masking.py
+4 −4 nemo/collections/audio/modules/transforms.py
+1 −1 nemo/collections/audio/parts/submodules/multichannel.py
+17 −11 nemo/collections/common/tokenizers/huggingface/auto_tokenizer.py
+6 −2 nemo/collections/llm/__init__.py
+1 −1 nemo/collections/llm/peft/lora.py
+0 −0 nemo/collections/llm/t5/__init__.py
+3 −0 nemo/collections/llm/t5/data/__init__.py
+329 −0 nemo/collections/llm/t5/data/pre_training.py
+19 −0 nemo/collections/llm/t5/model/__init__.py
+255 −0 nemo/collections/llm/t5/model/t5.py
+4 −2 nemo/collections/multimodal/data/__init__.py
+13 −15 nemo/collections/multimodal/models/vision_language_foundation/clip/megatron_clip_models.py
+7 −10 nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py
+13 −7 nemo/collections/nlp/parts/nlp_overrides.py
+23 −20 nemo/collections/tts/data/dataset.py
+8 −4 nemo/collections/tts/models/aligner.py
+10 −9 nemo/collections/tts/modules/common.py
+1 −1 nemo/collections/vlm/neva/model/llava.py
+5 −0 nemo/deploy/multimodal/query_multimodal.py
+60 −6 nemo/export/multimodal/build.py
+10 −1 nemo/export/multimodal/run.py
+45 −2 nemo/export/tensorrt_mm_exporter.py
+0 −3 nemo/export/trt_llm/converter/model_converter.py
+3 −2 nemo/export/trt_llm/tensorrt_llm_build.py
+15 −9 nemo/lightning/io/connector.py
+6 −3 nemo/lightning/io/mixin.py
+22 −2 nemo/lightning/io/state.py
+513 −202 nemo/lightning/megatron_parallel.py
+1 −1 nemo/lightning/pytorch/callbacks/model_checkpoint.py
+3 −4 nemo/lightning/pytorch/callbacks/peft.py
+11 −1 nemo/lightning/pytorch/optim/lr_scheduler.py
+19 −12 nemo/lightning/pytorch/plugins/data_sampler.py
+2 −2 nemo/lightning/pytorch/plugins/mixed_precision.py
+18 −11 nemo/lightning/pytorch/strategies/megatron_strategy.py
+66 −2 nemo/lightning/run/plugins.py
+6 −6 nemo/utils/cast_utils.py
+3 −1 nemo/utils/exp_manager.py
+1 −1 nemo/utils/export_utils.py
+4 −11 scripts/asr_context_biasing/eval_greedy_decoding_with_context_biasing.py
+1 −11 scripts/asr_language_modeling/neural_rescorer/eval_neural_rescorer.py
+4 −17 scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_ctc.py
+4 −17 scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer.py
+1 −16 scripts/asr_language_modeling/ngram_lm/eval_wfst_decoding_ctc.py
+38 −1 scripts/deploy/multimodal/deploy_triton.py
+9 −0 scripts/deploy/multimodal/query.py
+3 −5 scripts/export.py
+2 −8 scripts/speech_recognition/confidence/benchmark_asr_confidence.py
+4 −2 tests/collections/llm/gpt/model/megatron_ssm_finetuning.py
+3 −1 tests/collections/llm/gpt/model/megatron_ssm_pretraining.py
+115 −0 tests/collections/llm/gpt_finetuning.py
+141 −0 tests/collections/llm/megatron_t5_pretraining.py
+131 −0 tests/lightning/test_nemo_run.py
+2 −2 tutorials/asr/asr_adapters/Multi_Task_Adapters.ipynb
+42 −42 tutorials/nlp/lora.ipynb
7 changes: 5 additions & 2 deletions sub-packages/bionemo-geneformer/src/bionemo/geneformer/api.py
@@ -50,8 +50,11 @@ class GeneformerNeMo1LightningModuleConnector(GenericBioBertNeMo1LightningModule
     def tokenizer(self):
         nemo1_settings = self.get_nemo1_config()
         fmt_vocab, vocab_tar_path = nemo1_settings["tokenizer"]["vocab_file"].split(":")
-        fmt_medians, medians_tar_path = nemo1_settings["data"]["medians_file"].split(":")
-        assert fmt_vocab == fmt_medians and fmt_vocab == "nemo"
+        assert fmt_vocab == "nemo"
+        # TODO add another function to pull out the medians file from a nemo1 checkpoint, if the user wants it.
+        # It's not needed for checkpoint conversion though.
+        # fmt_medians, medians_tar_path = nemo1_settings["data"]["medians_file"].split(":")
+        # assert fmt_vocab == fmt_medians and fmt_vocab == "nemo"
         nemo1_path = str(self)
         with tarfile.open(nemo1_path, "r") as old_ckpt:
             vocab_gene_ens_dict = json.loads(old_ckpt.extractfile(f"./{vocab_tar_path}").readlines()[0])
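For readers unfamiliar with the pattern in this hunk: a NeMo-1 checkpoint is a tar archive, and the tokenizer vocab is referenced by a "fmt:path" string that must use the "nemo" format. Below is a minimal, hedged sketch of that extraction step using only the standard library; the helper name and the assumption that the vocab file is JSON stored at "./<path>" inside the archive are illustrative, not the project's API.

import json
import tarfile


def load_vocab_from_nemo1(nemo1_path: str, vocab_spec: str) -> dict:
    """Sketch: split a "fmt:path" vocab spec and read the JSON vocab out of a .nemo tar archive."""
    fmt, vocab_tar_path = vocab_spec.split(":")
    assert fmt == "nemo", f"unsupported vocab format: {fmt}"
    with tarfile.open(nemo1_path, "r") as old_ckpt:
        member = old_ckpt.extractfile(f"./{vocab_tar_path}")  # path layout mirrors the diff above
        return json.loads(member.read())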
17 changes: 12 additions & 5 deletions (file path not shown)
@@ -116,7 +116,7 @@
 ]

 MODEL_PRECISION: str = "bf16-mixed"
-USE_TE: bool = False # TODO use this for high level decisions around whether we're ready to switch to TE
+USE_TE: bool = True # TODO use this for high level decisions around whether we're ready to switch to TE


 @pytest.fixture()
@@ -140,7 +140,7 @@ def geneformer_config():
         hidden_dropout=0.02,
         init_method_std=0.02,
         kv_channels=None,
-        apply_query_key_layer_scaling=True,
+        apply_query_key_layer_scaling=False,
         make_vocab_size_divisible_by=128,
         masked_softmax_fusion=True, # TODO(@jstjohn) check this
         fp16_lm_cross_entropy=False,
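The apply_query_key_layer_scaling flip above changes how attention logits are scaled before the softmax. As a rough, hedged illustration (Megatron-style behavior, not the exact NeMo kernel): layer scaling divides the scores by the layer number on top of the usual 1/sqrt(d_k), a trick used mainly for fp16 numerical stability, and it is commonly disabled on bf16 and Transformer Engine paths.

import math

import torch


def attention_scores(q: torch.Tensor, k: torch.Tensor, layer_number: int, layer_scaling: bool) -> torch.Tensor:
    """Sketch of the pre-softmax score scaling; illustrative only."""
    scale = math.sqrt(q.size(-1))  # standard 1/sqrt(d_k) scaling
    if layer_scaling:
        scale *= layer_number  # extra per-layer damping, historically an fp16 stability aid
    return (q @ k.transpose(-2, -1)) / scale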
@@ -453,16 +453,23 @@ def test_nemo1_checkpoint_conversion(
     tmpdir: Path, geneformer_config: GeneformerConfig, cells: List[List[str]], seed: int = 42
 ):
     with megatron_parallel_state_utils.distributed_model_parallel_state(32):
-        converter = GeneformerNeMo1LightningModuleConnector(nemo1_checkpoint_path)
+        converter = GeneformerNeMo1LightningModuleConnector(nemo1_release_checkpoint_path)
         assert isinstance(converter.tokenizer, GeneTokenizer)
         diffs = compare_dataclasses(converter.config, geneformer_config)
         skip_fields = {"nemo1_ckpt_path", "return_only_hidden_states", "init_method", "output_layer_init_method"}
         filt = [d for d in diffs if d["field"] not in skip_fields]
         assert filt == []
         out_config = tmpdir / "out_config"
         converter.apply(out_config)  # currently crashes in here during self.nemo_save(out_path, trainer)
-        assert io.is_distributed_ckpt(out_config / "checkpoint")
-        assert False  # TODO test weights are the right ones.
+        assert io.is_distributed_ckpt(out_config / "weights")
+        geneformer_config_logit = deepcopy(geneformer_config)
+        # Set up the model to return logits and switch to the released 10M checkpoint
+        geneformer_config_logit.set_hparam("return_only_hidden_states", False)  # return logits
+        geneformer_config_logit.set_hparam("initial_ckpt_path", str(out_config))  # release checkpoint is important
+
+        mean_loss = _get_loss_from_model(geneformer_config_logit, seed)
+        target: float = 2.368649959564209
+        assert mean_loss < target or mean_loss == pytest.approx(target, abs=1e-2, rel=None)


 @pytest.mark.skipif(USE_TE, reason="This per-layer test does not yet support TE mapping.")
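The rewritten assertion above accepts a loss that is either strictly better than the recorded target or equal to it within an absolute tolerance. A small sketch of that pattern, with the target value copied from the diff:

import pytest


def assert_loss_close_or_better(mean_loss: float, target: float, abs_tol: float = 1e-2) -> None:
    """Pass if the loss improved on the target, or matches it within abs_tol."""
    assert mean_loss < target or mean_loss == pytest.approx(target, abs=abs_tol, rel=None)


# Example with the target recorded in the test above:
assert_loss_close_or_better(2.3680, target=2.368649959564209)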
3 changes: 3 additions & 0 deletions (file path not shown)
@@ -66,6 +66,7 @@ def apply(self, output_path: Path) -> Path:
         old_weights = torch.load(ckpt_file)
         target = self.init()
         trainer = self.nemo_setup(target)
+        target.trainer = trainer
         self.convert_state(old_weights, target)
         self.nemo_save(output_path, trainer)

@@ -87,6 +88,7 @@ def is_te_mapping(model: BioBertLightningModule) -> bool:
     def convert_state(self, source: Dict[str, torch.Tensor], target: BioBertLightningModule) -> BioBertLightningModule:
         """Convert the input state_dict keys from nemo1 biobert to nemo2 biobert."""
         te_mapping = self.is_te_mapping(target) # check for TE layers.
+        target.module.cpu()
         new_state_dict_from_old = {}
         for k, v in source.items():
             new_key = nemo1_to_nemo2_biobert_key_mapping(k, new_model_prefix="", te_mapping=te_mapping)
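convert_state above renames every key in the NeMo-1 state dict to its NeMo-2 counterpart (and now moves the target module to CPU first so the copy does not need extra GPU memory). A generic, hedged sketch of that remap-and-load pattern; rename_key is a stand-in for the project's nemo1_to_nemo2_biobert_key_mapping, and the commented load call is illustrative rather than the library's exact API.

from typing import Callable, Dict

import torch


def remap_state_dict(source: Dict[str, torch.Tensor], rename_key: Callable[[str], str]) -> Dict[str, torch.Tensor]:
    """Return a new state dict with every key renamed; tensor values are reused as-is."""
    return {rename_key(k): v for k, v in source.items()}


# Usage sketch (hypothetical): rename_key encapsulates the nemo1 -> nemo2 naming rules.
# new_sd = remap_state_dict(old_weights["state_dict"], rename_key=my_rename_fn)
# target.module.load_state_dict(new_sd, strict=True)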
@@ -124,6 +126,7 @@ def config(self) -> BioBertGenericConfig[MegatronBioBertModelT]:
             "fp32_residual_connection": False,
             "bias_activation_fusion": True,
             "bias_dropout_fusion": True,
+            "apply_query_key_layer_scaling": False,
             "share_embeddings_and_output_weights": True,
             "fp16": autocast_dtype == torch.float16,
             "bf16": autocast_dtype == torch.bfloat16,
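The config dict in that hunk derives the fp16/bf16 flags from a single autocast dtype, so only one of them can be true at a time. A tiny sketch of that mapping (illustrative helper, not the project's API):

import torch


def precision_flags(autocast_dtype: torch.dtype) -> dict:
    """Map one autocast dtype to the pair of mutually exclusive Megatron-style precision flags."""
    return {
        "fp16": autocast_dtype == torch.float16,
        "bf16": autocast_dtype == torch.bfloat16,
    }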
2 changes: 1 addition & 1 deletion (file path not shown)
@@ -406,7 +406,7 @@ class BioBertGenericConfig(
     num_attention_heads: int = 8
     num_layers: int = 6
     init_method_std: float = 0.02
-    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.bert_layer_local_spec
+    biobert_spec_option: BiobertSpecOption = BiobertSpecOption.bert_layer_with_transformer_engine_spec

     # TODO: Move this to better places?
     get_attention_mask_from_fusion: bool = False
