Generation tutorial for Gemma model #829
Open
pggPL wants to merge 264 commits into NVIDIA:main from pggPL:Gemma-generation
Conversation
pggPL force-pushed the Gemma-generation branch 2 times, most recently from 3ff297e to 685827a on May 9, 2024 at 21:40.
- Fix 0 grid size
- Add LN margin to inference; fix symbolic func registration and grads
- Don't use autograd hook for bwd reduction
- FP8 DPA/MHA integration with the cuDNN frontend v1.x: FP8 fprop/bprop, sbhd/bshd layouts, qkvpacked/kvpacked support, MQA/GQA fixes, the NVTE_FP8_DPA_BWD flag, FP8 MHA support, RMSE-based unit tests, and docs marked as beta
- Use torch function as a class method
- Change TE checkpoint passthrough logic to recursively look for TE submodules
- Fixes and docs; check for FP8; fix LoRa-like use cases
- Improve error reporting and hang detection for potential hangs (NVIDIA#757): verbose UB hang reporting, a CE hang detector, a configurable timeout, and post-GDRCOPY-removal cleanup
- Fix type checking in checkpointing to assume there must be TE modules in custom callables
- Fix static argnums for jax.jit in the single-GPU encoder test and the JAX MNIST example; adjust pytest warning filtering (NVIDIA#785)
- Add NVRTC kernels for cast-transpose, with a noop flag
- Support noop concat without providing the full tensor; stop storing fused buffers in linear modules; use plain PyTorch concat when exporting to ONNX
- Allow multiple dims for dgamma and dbeta in the LN descriptor; fix the jit error in examples/jax (NVIDIA#780)
- Remove unnecessary Pylint overrides; lint fixes
- Combine layernorm_geglu and layernorm_gelu into fused_layernorm; implement partially fused layernorm and the dbias_1 calculation
- Use a global buffer for cu_seqlens; avoid functools.lru_cache
- Add HF Nanotron to integrations; update the GTC 24 video to the on-demand link
- Implement swiglu and silu; rename nvte-*silu to nvte-*swish and generalize the GetDBiasDact functions
- Make FusedAttn with CP support bias; assert that ALiBi cannot work with CP; support the sbhd qkv format in AttnFuncWithCP; add attn bias tests
- Add support for MoE with FP8; fix the unit test and a linear backward error
- Add a module-level filter for the deprecation warning in common
- Remove tp_size/tp_group as amax reduction is handled by fp8_group()
- Restrict context-parallel tests to sm80+ as the fused/flash attn backends require sm80+ (NVIDIA#799)
/te-ci pytorch

/te-ci pytorch
cyanguwa reviewed on Jun 5, 2024
/te-ci pytorch
max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
if self.attention_type == "self":
Check whether this line and the 5 lines after it are still useful.
Hi @pggPL, any update on the Gemma tutorial?
Description
I added tutorials for finetuning and for generation with the Gemma model. Moreover, I added a few features that were necessary to make the tutorials work.
Type of change
Changes
- `InferenceParams`, which is responsible for caching k and v,
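For orientation, here is a minimal sketch of how an `InferenceParams`-backed KV cache can be driven during token-by-token generation with a `TransformerLayer`. It assumes the forward pass accepts an `inference_params` argument, that `layer_number` keys the cache entry, and that `sequence_len_offset` is advanced manually after each step (see the backwards-compatibility note under Future work); the shapes and the decode loop are illustrative, not code from the tutorial.

```python
# Minimal sketch (not the tutorial code): drive the KV cache with InferenceParams.
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.attention import InferenceParams

layer = te.TransformerLayer(
    hidden_size=2048,
    ffn_hidden_size=8192,
    num_attention_heads=16,
    layer_number=1,  # keys the KV-cache entry for this layer
).cuda().eval()

batch, max_seq = 2, 128
inference_params = InferenceParams(max_batch_size=batch, max_sequence_length=max_seq)

hidden = torch.randn(16, batch, 2048, device="cuda")  # [s, b, h] prompt, sbhd layout
with torch.no_grad():
    for _ in range(8):
        out = layer(hidden, inference_params=inference_params)
        # k/v of `hidden` are now cached; advance the offset by the tokens just processed.
        inference_params.sequence_len_offset += hidden.size(0)
        # A real loop would apply the LM head, sample a token and embed it;
        # here we simply feed the last position back in.
        hidden = out[-1:, :, :]
```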
Checklist:

Future work:

`TransformerLayer` does not support thd and that is a problem. The solution right now works this way:
- `setup_before_new_input` is called before the forward pass to indicate the sequence lengths,
- `self_attn_format='thd'` is used together with padded sequences of shape `bshd`,
- `InferenceParams.retrieve_from_kv_cache` retrieves the key_layer in bshd or thd format depending on `inference_params.qkv_format`.

As can be seen, this is quite a messy workaround. How I think it should be done in the future:
- no `setup_before_new_input` at all,
- `InferenceParams` stores the lengths of the cached sequences for each layer.

To do this, one will need to remove the `save_to_kv_cache()` kernel and write `save_to_kv_cache_sbhd()` and `save_to_kv_cache_thd()` (no bshd version, because the cache has shape sbhd for both bshd and sbhd). The logic of updating sequence lengths needs to be moved from `setup_before_new_input` into `save_to_kv_cache`.

It is worth noting that we need to take care of backwards compatibility. Right now generation works only for bshd/sbhd and one needs to manually update `self.sequence_len_offset`. I think we can write a setter which will update the statistic for each layer when `sequence_len_offset` is changed, as sketched below.
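One possible shape for that setter, as a rough sketch only: the class and the per-layer bookkeeping below (`InferenceParamsSketch`, `cached_seq_lens`) are hypothetical stand-ins, not the existing TE API.

```python
# Hypothetical sketch of the proposed setter; names are illustrative, not TE's API.
class InferenceParamsSketch:
    def __init__(self, max_batch_size, max_sequence_length):
        self.max_batch_size = max_batch_size
        self.max_sequence_length = max_sequence_length
        self._sequence_len_offset = 0
        # lengths of the cached sequences, kept per layer (keyed by layer number)
        self.cached_seq_lens = {}

    @property
    def sequence_len_offset(self):
        return self._sequence_len_offset

    @sequence_len_offset.setter
    def sequence_len_offset(self, value):
        # Existing user code that assigns sequence_len_offset keeps working,
        # while the per-layer statistics are updated automatically.
        self._sequence_len_offset = value
        for layer_number in self.cached_seq_lens:
            self.cached_seq_lens[layer_number] = value
```

With something like this, existing code that does `inference_params.sequence_len_offset += n` stays valid while each layer's cached length stays in sync.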
If `TransformerLayer` support for thd is not added in the near future, I propose to write the sequence lengths into `inference_params.cu_seqlens`; note that this is beta (in the future cu_seqlens will probably be added as an argument to `TransformerLayer`). Then use `TransformerLayer` with bshd. If `MultiHeadAttention` gets `inference_params.cu_seqlens != None`, it converts `bshd` with padding into `thd`, calls `save_to_kv_cache` etc., runs `DotProductAttention` with `thd`, and then converts the output back to `bshd`.
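To make the proposed conversion concrete, here is a small, self-contained sketch of the padded-`bshd` to `thd` round trip driven by `cu_seqlens`, written in plain PyTorch. The helper names are hypothetical; in the proposal, `save_to_kv_cache` and `DotProductAttention` would run between the two conversions.

```python
# Hypothetical helpers illustrating the bshd(+padding) <-> thd conversion via cu_seqlens.
import torch

def bshd_padded_to_thd(x, seqlens):
    """x: [b, s, h, d], zero-padded on the right; returns packed [t, h, d] and cu_seqlens."""
    b, s, _, _ = x.shape
    valid = torch.arange(s, device=x.device)[None, :] < seqlens[:, None]  # [b, s]
    thd = x[valid]                                                        # [t, h, d]
    cu_seqlens = torch.zeros(b + 1, dtype=torch.int32, device=x.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    return thd, cu_seqlens

def thd_to_bshd_padded(thd, cu_seqlens, max_s):
    """Inverse: scatter the packed tokens back into a zero-padded [b, s, h, d] tensor."""
    b = cu_seqlens.numel() - 1
    h, d = thd.shape[1:]
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    out = thd.new_zeros(b, max_s, h, d)
    valid = torch.arange(max_s, device=thd.device)[None, :] < seqlens[:, None]
    out[valid] = thd
    return out

# Round trip on toy data: only the valid (non-padding) positions are compared.
x = torch.randn(2, 8, 4, 16)
seqlens = torch.tensor([5, 8])
thd, cu = bshd_padded_to_thd(x, seqlens)
restored = thd_to_bshd_padded(thd, cu, max_s=8)
valid = torch.arange(8)[None, :] < seqlens[:, None]
assert torch.equal(restored[valid], x[valid])
```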