Generation tutorial for Gemma model #829

Open · wants to merge 264 commits into main from Gemma-generation

Conversation

@pggPL (Collaborator) commented May 1, 2024

Description

I added two tutorials for the Gemma model: one on finetuning and one on generation. I also added a few features that were necessary to make these tutorials work.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)

Changes

  • Two new notebooks in the docs: one on finetuning Gemma (analogous to the Llama tutorial) and one on generation with Gemma,
  • Generalized the rotary positional encoding kernel so that sequences in a batch can start at different encoding positions (see the sketch after this list),
  • Added a kernel that efficiently saves keys and values to the kv_cache,
  • Expanded the InferenceParams class, which is responsible for caching k and v,
  • Changed DotProductAttention to run THD attention when the kv_cache contains ragged tensors.
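
To illustrate what "different encoding positions" means here, below is a minimal PyTorch sketch (not the fused kernel from this PR) of rotary embeddings applied with a per-sequence start offset; the `start_positions` argument and helper names are illustrative.

```python
# Minimal sketch: apply RoPE when each sequence in the batch starts at its own
# position offset, as happens when generating on top of kv_caches of different lengths.
import torch

def rope_frequencies(dim: int, max_pos: int, base: float = 10000.0) -> torch.Tensor:
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    positions = torch.arange(max_pos).float()
    angles = torch.outer(positions, inv_freq)       # [max_pos, dim // 2]
    return torch.cat([angles, angles], dim=-1)      # [max_pos, dim]

def apply_rope_with_offsets(x: torch.Tensor, start_positions: torch.Tensor,
                            angles: torch.Tensor) -> torch.Tensor:
    """x: [b, s, h, d]; start_positions: [b] first RoPE position of each sequence."""
    b, s, _, d = x.shape
    # Sequence i looks up angles[start_positions[i] : start_positions[i] + s].
    pos = start_positions[:, None] + torch.arange(s, device=x.device)  # [b, s]
    a = angles[pos]                                  # [b, s, d]
    cos, sin = a.cos()[:, :, None, :], a.sin()[:, :, None, :]
    x1, x2 = x[..., : d // 2], x[..., d // 2:]
    rotated = torch.cat([-x2, x1], dim=-1)
    return x * cos + rotated * sin

# Example: two cached sequences of different lengths each append one new token.
q = torch.randn(2, 1, 8, 64)
angles = rope_frequencies(64, max_pos=4096)
q_rot = apply_rope_with_offsets(q, start_positions=torch.tensor([17, 5]), angles=angles)
```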

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Future work:

TransformerLayer does not support thd, which is a problem. The current solution works as follows (a usage sketch is given after this list):

  • one needs to call setup_before_new_input before forward to indicate the sequence lengths,
  • then one calls forward on TransformerLayer with self_attn_format='thd' and padded sequences of shape bshd,
  • all layers receive input in [b, s, *] format, not in [t, *] (including attention),
  • InferenceParams.retrieve_from_kv_cache retrieves key_layer in bshd or thd format depending on inference_params.qkv_format.
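
For concreteness, here is a rough sketch of one generation step under this workaround. The setup_before_new_input call and the self_attn_format argument come from this PR's draft API; the exact signatures, module paths, and remaining arguments are assumptions, so treat this as illustrative pseudocode rather than a working recipe.

```python
# Illustrative pseudocode only; TransformerLayer/InferenceParams details are
# assumed from this PR's draft API and may change.
import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(hidden_size=2048, ffn_hidden_size=8192, num_attention_heads=16)
inference_params = te.InferenceParams(max_batch_size=2, max_sequence_length=4096)

hidden = torch.randn(2, 8, 2048).cuda()   # padded [b, s, hidden] input, two sequences
seq_lens = torch.tensor([8, 3]).cuda()    # true lengths of those sequences

# 1) Tell the kv-cache how long the incoming padded sequences really are.
inference_params.setup_before_new_input(seq_lens)
# 2) Run the layer on padded bshd input; the cache internally stores ragged (thd) data.
out = layer(hidden, inference_params=inference_params, self_attn_format="thd")
```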

As can be seen, this is quite a messy workaround. How I think it should be done in the future (a sketch of the per-layer bookkeeping follows this list):

  • TransformerLayer supports thd and we do not need setup_before_new_input at all,
  • InferenceParams stores the lengths of the cached sequences for each layer,
  • on each TransformerLayer invocation, the provided sequences are copied into the cache and the cached sequence lengths for that layer are updated.
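
As a rough illustration of the per-layer bookkeeping proposed above, here is a hypothetical sketch; the class and method names are not the existing InferenceParams API.

```python
# Hypothetical bookkeeping sketch: the cache keeps, per layer, how many tokens of
# each sequence are already stored, so setup_before_new_input is no longer needed.
import torch

class PerLayerInferenceParams:
    def __init__(self, max_batch_size: int, num_layers: int, device: str = "cuda"):
        self.cached_seqlens = [torch.zeros(max_batch_size, dtype=torch.int32, device=device)
                               for _ in range(num_layers)]

    def advance(self, layer_idx: int, new_seqlens: torch.Tensor) -> torch.Tensor:
        """Called from save_to_kv_cache after copying new keys/values for this layer."""
        self.cached_seqlens[layer_idx] += new_seqlens.to(torch.int32)
        return self.cached_seqlens[layer_idx]
```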

To do this, one will need to remove the save_to_kv_cache() kernel and write save_to_kv_cache_sbhd() and save_to_kv_cache_thd() (no bshd variant, because the cache has shape sbhd for both the bshd and sbhd input formats). The logic for updating sequence lengths needs to move from setup_before_new_input into save_to_kv_cache (a thd-flavored sketch follows).
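
For reference, here is a plain-PyTorch sketch of what save_to_kv_cache_thd would have to do; the real version would be a fused CUDA kernel, and the cu_seqlens/cached_seqlens conventions here are assumptions.

```python
# Reference-only sketch of the indexing a fused save_to_kv_cache_thd kernel would implement.
import torch

def save_to_kv_cache_thd(k_cache, v_cache, key, value, cu_seqlens, cached_seqlens):
    """
    k_cache/v_cache: [max_s, b, h, d] caches (sbhd layout).
    key/value:       [t, h, d] packed tokens of all sequences (thd).
    cu_seqlens:      [b + 1] cumulative token offsets into key/value.
    cached_seqlens:  [b] tokens already present in the cache for each sequence.
    """
    batch_size = cu_seqlens.numel() - 1
    for b in range(batch_size):                      # naive loop; the kernel fuses this
        start, end = int(cu_seqlens[b]), int(cu_seqlens[b + 1])
        n = end - start
        dst = int(cached_seqlens[b])
        k_cache[dst:dst + n, b] = key[start:end]
        v_cache[dst:dst + n, b] = value[start:end]
    # Lengths are updated here, not in setup_before_new_input.
    cached_seqlens += (cu_seqlens[1:] - cu_seqlens[:-1]).to(cached_seqlens.dtype)
    return cached_seqlens
```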

It is worth noting that we need to take care of backwards compatibility. Right now generation works only for bshd/sbhd and one needs to manually update self.sequence_len_offset. I think we can write a setter that updates the per-layer statistics whenever sequence_len_offset is changed (see the sketch below).
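
A minimal sketch of such a setter, assuming a hypothetical per-layer cached_seqlens field; only the sequence_len_offset name comes from the existing API.

```python
# Backwards-compatibility sketch: old code keeps setting a single scalar offset,
# and the setter mirrors it into every layer's per-sequence lengths.
import torch

class InferenceParamsCompat:
    def __init__(self, max_batch_size: int, num_layers: int):
        self.cached_seqlens = [torch.zeros(max_batch_size, dtype=torch.int32)
                               for _ in range(num_layers)]
        self._sequence_len_offset = 0

    @property
    def sequence_len_offset(self):
        return self._sequence_len_offset

    @sequence_len_offset.setter
    def sequence_len_offset(self, value: int):
        self._sequence_len_offset = value
        for lengths in self.cached_seqlens:
            lengths.fill_(value)
```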

If TransformerLayer support for thd is not added in the near future, I propose writing the sequence lengths into inference_params.cu_seqlens; note that this is a beta mechanism (in the future cu_seqlens will probably be added as an argument to TransformerLayer). Then TransformerLayer would be used with bshd. If MultiHeadAttention sees inference_params.cu_seqlens != None, it converts the padded bshd input into thd, calls save_to_kv_cache etc., runs DotProductAttention with thd, and then converts the output back to bshd (see the conversion sketch below).
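
To make the proposed padded-bshd to thd round trip concrete, here is a naive PyTorch sketch; the helper names are illustrative, not existing TE functions.

```python
# Sketch of the conversion MultiHeadAttention would perform when
# inference_params.cu_seqlens is set (padding positions come back as zeros).
import torch

def bshd_to_thd(x, cu_seqlens):
    """x: [b, s, h, d] padded; returns [t, h, d] with t = cu_seqlens[-1]."""
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
    return torch.cat([x[b, : int(seqlens[b])] for b in range(x.size(0))], dim=0)

def thd_to_bshd(x, cu_seqlens, max_seqlen):
    """Inverse of bshd_to_thd; pads every sequence back to max_seqlen."""
    b = cu_seqlens.numel() - 1
    out = x.new_zeros(b, max_seqlen, *x.shape[1:])
    for i in range(b):
        start, end = int(cu_seqlens[i]), int(cu_seqlens[i + 1])
        out[i, : end - start] = x[start:end]
    return out

# Example round trip with two sequences of lengths 3 and 5.
cu_seqlens = torch.tensor([0, 3, 8])
padded = torch.randn(2, 5, 16, 64)
packed = bshd_to_thd(padded, cu_seqlens)       # [8, 16, 64]
restored = thd_to_bshd(packed, cu_seqlens, 5)  # [2, 5, 16, 64]
```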

@pggPL marked this pull request as a draft on May 1, 2024 18:14
@pggPL force-pushed the Gemma-generation branch 2 times, most recently from 3ff297e to 685827a on May 9, 2024 21:40
ksivaman and others added 26 commits May 22, 2024 17:05
• Fix 0 grid size (Kirthi Shankar Sivamani)
• Add LN margin to inference; cleanup; fix symbolic func registration; fix grads (Sangkug Lym, Kirthi Shankar Sivamani)
• Don't use autograd hook for bwd reduction (Kirthi Shankar Sivamani)
• FP8 attention (DPA/MHA) integration (Charlene Yang, Kirthi Shankar Sivamani):
  * WIP: fp8 v1 fprop integration; minor fixes; add debug info
  * fprop working for h1, with debug info
  * WIP: add bprop; cleanup, bprop running but with mismatches
  * add gitlab frontend as submodule
  * clean up and add back v0.9.2 FE support; fprop/bprop passing with 5e-2 tols
  * fix after merge; add bias_b/h to caching descriptor
  * distinguish fwd/bwd tensor types for bprop
  * minor fix for F16 cases; include added dqkv_type and d_scale_dp
  * adjust out shape for bwd in test
  * add casting from/to FP8 to the DPA module
  * WIP: bshd_bshd_bshd layout; support all sbhd/bshd layouts
  * add qkvpacked and kvpacked support at both FusedAttnFunc and C levels
  * remove qkvpacked/kvpacked calls in the DPA module (used for testing)
  * remove tp setup; add allow_non_contiguous; update FE; revert to sbh3d in tests; clean up
  * add NVTE_FP8_DPA_BWD to control whether to use FP8 bwd or F16 bwd
  * fix MQA; fix MQA/GQA in the FP8 v1 API
  * update FE to 705d8e3, with API change
  * test causal mask
  * restrict mha_fill for THD format
  * fix fused attn with CP and comment out is_alibi code
  * clean up FE 0.9 vs FE 1.0 FP8 implementations, and related unit tests
  * change the NVTE_FP8_DPA_BWD default to 1, and fix its use in the qkvpacked/kvpacked APIs
  * fix lint and self.tp_size/group in FusedAttention()
  * update FE to 6902c94
  * add FP8 MHA support; update to FE v1.3.0
  * minor fixes for FP8 MHA with different configs
  * emit stats regardless of is_training
  * fix linear when the input is not a Float8Tensor
  * fix d_out type for F16 bprop
  * fix user buffers for layernorm_linear/linear and revert two FP8 casts in MHA
  * add docstrings for fp8_dpa/mha in the recipe
  * fix backend selection to avoid FA
  * replace transpose with transpose_2d (here and in two more places)
  * use RMSE for FP8 unit tests
  * add FP8 initialization to FusedAttention; rm docs; revert "add FP8 initialization to FusedAttention" (reverts 15fffd8)
  * change order of ctxs; fixes; minor fixes
  * add back docs and mark the feature as beta; minor fixes for tests and docs
• Use torch function as a class method (Kirthi Shankar Sivamani)
• Changed TE checkpoint passthrough logic to also recursively look for TE submodules; simplified the search for TE modules in the checkpointed network (Alp Dener)
• Untitled commit (Przemek Tredak)
• fixes; docs; check for FP8; fix LoRa-like use cases; address reviews (Kirthi Shankar Sivamani)
• …eporting for potential hangs (NVIDIA#757) (Pasha (Pavel) Shamis):
  * improving error reporting and hang detection logic: verbose error reporting in case of a UB hang, a CE hang detector, and a configurable timeout instead of the hard-coded one
  * cleaning up warnings in the code; removing unused code
  * fixing styling issues reported on GitHub; addressing lint newline and casting warnings and the warning about the usage of `unsigned long long`
  * removing an unused case causing build issues on multi-arch setups
  * post GDRCOPY removal cleanup: remove cmake check and unused includes
• Fix type checking in checkpointing to assume that there must be TE modules in custom callables (Alp Dener)
• …ax.jit (NVIDIA#785) (Alp Dener):
  * fixed static argnums for jax.jit in the single-GPU encoder test; changed warning filtering for pytest
  * propagated the fix to the JAX MNIST example
  * fixed a missing space between flags in the QAA scripts
  * added TE warnings to the ignore list
• Add NVRTC kernels for cast-transpose (Tim Moon): update copyright year; add noop flag to the NVRTC cast-transpose kernel; apply suggestions from code review
• Support noop concat without providing the full tensor (Tim Moon, Kirthi Shankar Sivamani): stop storing fused buffers in linear modules; debug noop cat func; construct TE modules in tests with correct dtypes; add tolerances to numerical tests; use plain PyTorch concat when exporting to ONNX
• Allow multiple dims for dgamma and dbeta in the LN descriptor (NVIDIA#780) (Ming Huang): also fix the jit error in examples/jax
• Remove unnecessary Pylint overrides; fixes to lint (Tim Moon, Kirthi Shankar Sivamani)
• Combined layernorm_geglu with layernorm_gelu into fused_layernorm (Phuong Nguyen, Alp Dener):
  * fixes to pass all unit tests in test_custom_call_compute.py, test_layer.py, and test_praxis_layer.py
  * cleaning and formatting; renaming based on reviewer suggestions
  * implemented partial fused layernorm; geglu + bias passed tests
  * added partial fused calculation for dbias_1; clean up
• Try using a global buffer for cu_seqlens (Kirthi Shankar Sivamani, Vasudevan Rengasamy): avoid using functools.lru_cache; fixes
• Added HF Nanotron to integrations and updated the GTC 24 video to the on-demand link (Santosh Bhavani)
• Implemented swiglu and silu; renamed nvte-*silu to nvte-*swish and generalized the GetDBiasDact functions (Phuong Nguyen)
• Make FusedAttn with CP support bias (Xiaowei Ren, cyanguwa):
  * assert that ALiBi cannot work with CP
  * syntax fix; fix variable name; fix tensor shapes; a typo fix
  * fix bias indexing for CP; bug fix
  * add attn bias tests; change dbias update location
  * fix CP test model configs; change CP test sequence length
  * make AttnFuncWithCP support the sbhd qkv format
  * make sure qkv are contiguous for CP in cuDNN fused attn
  * change assert message; fix code format
• Add support for MoE with FP8 (Dennis Liu, Przemyslaw Tredak): fix unit test; fix error in linear backward
• Add module-level filter for deprecation warning in common; fix module (Kirthi Shankar Sivamani)
• Remove tp_size/tp_group as amax reduction is handled by fp8_group() (Charlene Yang)
• Restrict context parallel tests to sm80+ as the fused/flash attn backends require sm80+ (NVIDIA#799) (Charlene Yang)
@sudhakarsingh27 (Collaborator) commented:

/te-ci pytorch
@phu0ngng (Collaborator) commented Jun 5, 2024

/te-ci pytorch

@sudhakarsingh27 (Collaborator) commented:

/te-ci pytorch

pggPL and others added 7 commits June 7, 2024 10:02
Review comment from a Collaborator on this diff hunk:

max_seqlen_q=max_seqlen_q,
max_seqlen_kv=max_seqlen_kv,
)
if self.attention_type == "self":

check if this line+5 are useful

@pluiez commented Jan 23, 2025

Hi @pggPL, any update on the Gemma tutorial?
