
Commit 7e593c3

Authored by: cyanguwa, Oleg-Goncharov, phu0ngng, root, wdykas
Add num_splits support for FA3 backend (#2380)
* [Common] Deleted unused header (#2324) Deleted unused header Signed-off-by: Oleg Goncharov <[email protected]> Signed-off-by: Peter Dykas <[email protected]>
* [JAX] L1_jax_distributed_test suit with individual executions (#2321)
  * L1 rework Signed-off-by: Phuong Nguyen <[email protected]>
  * comment out test_multi_process_grouped_gemm for now Signed-off-by: Phuong Nguyen <[email protected]>
  * rm e5m2 from test norm + MXFP8 Signed-off-by: Phuong Nguyen <[email protected]>
  --------- Signed-off-by: Phuong Nguyen <[email protected]> Signed-off-by: Peter Dykas <[email protected]>
* for branch Signed-off-by: Peter Dykas <[email protected]>
* clean up and tests Signed-off-by: Peter Dykas <[email protected]>
* change tests Signed-off-by: Peter Dykas <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Peter Dykas <[email protected]>
* [PyTorch debug] Fixes to debug tests failures (#2268)
  * code drop Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * fix: Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  --------- Signed-off-by: Pawel Gadzinski <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Peter Dykas <[email protected]>
* [PyTorch Debug] Add max_blockwise_dynamic_range stats (#2137)
  * code drop Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fixes Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  --------- Signed-off-by: Pawel Gadzinski <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Peter Dykas <[email protected]>
* [JAX] Fix bug with pre scale bias (#2300)
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  --------- Signed-off-by: Pawel Gadzinski <[email protected]> Signed-off-by: Peter Dykas <[email protected]>
* [JAX] Try to use pre-downloaded dataset artifacts first (#2345)
  * Try to use pre-downloaded dataset artifacts first Signed-off-by: Jeremy Berchtold <[email protected]>
  * Set HF_HUB_OFFLINE to disable any network calls to HF when the pre-downloaded dataset is available Signed-off-by: Jeremy Berchtold <[email protected]>
  --------- Signed-off-by: Jeremy Berchtold <[email protected]> Signed-off-by: Peter Dykas <[email protected]>
* Fix out of bounds access in the FP4 dequantize kernel (#2346) Signed-off-by: Przemek Tredak <[email protected]> Signed-off-by: Peter Dykas <[email protected]>
* Make FP8 weights compatible with older MCore version (#2342)
  * Make cast_master_weights_to_fp8 compatible with older MCore version Signed-off-by: kunlunl <[email protected]>
  * Rename keep_columnwise to manual_post_all_gather_processing & Optimize unit test Signed-off-by: kunlunl <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * Remove redundant _test_mini_optimizer() Signed-off-by: kunlunl <[email protected]>
  --------- Signed-off-by: kunlunl <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <[email protected]> Signed-off-by: Peter Dykas <[email protected]>
* [JAX] Add test to check jaxpr that amax is reused for nvfp4 recipe (#2348)
  * Add test to check jaxpr that amax is reused for nvfp4 recipe Signed-off-by: Jeremy Berchtold <[email protected]>
  * Move test to test_helper.py and rename file Signed-off-by: Jeremy Berchtold <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  --------- Signed-off-by: Jeremy Berchtold <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Peter Dykas <[email protected]>
* Fix sharding of segment position to match id in ring attention. (#2349) Signed-off-by: Peter Dykas <[email protected]>
* Disable cuDNN attention for known IMA and NaNs (#2344)
  * Fix cuDNN backend selection for more case. Add CG as a option as well Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
  * fix logic Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
  * Fix cuDNN checks Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
  * Add more checks Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
  * Fix cuddn version Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
  * Fix error message Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
  * Add check for window size Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
  --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Signed-off-by: Peter Dykas <[email protected]>
* [JAX] Default to fused attention in JAX DPA (#2363)
  * Default to fused attention in JAX DPA Signed-off-by: Kshitij Lakhani <[email protected]>
  * Consolidate documentation for DPA in JAX Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kshitij Lakhani <[email protected]>
  * Correctly update the documentation for defaults in JAX DPA Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kshitij Lakhani <[email protected]>
  --------- Signed-off-by: Kshitij Lakhani <[email protected]> Signed-off-by: Kshitij Lakhani <[email protected]> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Peter Dykas <[email protected]>
* Update cudnn frontend to v1.16.0 (#2362) Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Signed-off-by: Peter Dykas <[email protected]>
* [common] Remove kvpacked and qkvpacked attention functions for every kernel type. (#2287)
  * code drop Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * fix Signed-off-by: Pawel Gadzinski <[email protected]>
  * depracted compile time warning + \warning -> \deprecated Signed-off-by: Pawel Gadzinski <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  --------- Signed-off-by: Pawel Gadzinski <[email protected]> Signed-off-by: Charlene Yang <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <[email protected]> Signed-off-by: Peter Dykas <[email protected]>
* Move Triton to common (#2359)
  * move triton to common and change paths Signed-off-by: tdophung <[email protected]>
  * Formatting Signed-off-by: tdophung <[email protected]>
  --------- Signed-off-by: tdophung <[email protected]> Signed-off-by: Peter Dykas <[email protected]>
* [JAX] Fused layers argument default values changed (#2347)
  * Changing default activations in MLP, TransformerLayer, dropout rate after FC1 to 0, and return_layernorm_output to False Signed-off-by: tdophung <[email protected]>
  * Fixing the failing tests by hard coding arguments to the previous values instead of relying on newer default values Signed-off-by: tdophung <[email protected]>
  * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
  --------- Signed-off-by: tdophung <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Peter Dykas <[email protected]>
* remove comment from gpt Signed-off-by: Peter Dykas <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* minor changes for num_splits logic Signed-off-by: Charlene Yang <[email protected]>
* replace None with 1 as default Signed-off-by: Charlene Yang <[email protected]>
* [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci
* fix last commit Signed-off-by: Charlene Yang <[email protected]>
* fix docstring Signed-off-by: Charlene Yang <[email protected]>
* fix dtype in pack/unpack when FP8 Signed-off-by: Charlene Yang <[email protected]>
* add fused_attn_supported constraint for some tests Signed-off-by: Charlene Yang <[email protected]>
* update FA3 installation commands Signed-off-by: Charlene Yang <[email protected]>
* update FA3 installation commands in DPA Signed-off-by: Charlene Yang <[email protected]>
* separate fused fp8 and f16 flags in tests Signed-off-by: Charlene Yang <[email protected]>
* initialize fused_attn_supported_f16 Signed-off-by: Charlene Yang <[email protected]>
* fix FA installation in L3 tests Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Peter Dykas <[email protected]>
Signed-off-by: Phuong Nguyen <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: Przemek Tredak <[email protected]>
Signed-off-by: kunlunl <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Kshitij Lakhani <[email protected]>
Signed-off-by: Kshitij Lakhani <[email protected]>
Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: tdophung <[email protected]>
Co-authored-by: Oleg Goncharov <[email protected]>
Co-authored-by: Phuong Nguyen <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: Peter Dykas <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Paweł Gadziński <[email protected]>
Co-authored-by: jberchtold-nvidia <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Co-authored-by: Kunlun Li <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Michael Goldfarb <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Kshitij Lakhani <[email protected]>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: Teddy Do <[email protected]>
Co-authored-by: wdykas <[email protected]>
1 parent 1df4a69 commit 7e593c3

File tree

6 files changed (+126, -48 lines)


qa/L3_pytorch_FA_versions_test/test.sh

Lines changed: 3 additions & 3 deletions
@@ -30,13 +30,13 @@ do
   # Build Flash Attention
   if [ "${fa_version}" \< "3.0.0" ]
   then
-    pip3 install flash-attn==${fa_version}
+    pip3 install flash-attn==${fa_version} --no-build-isolation
   else
     git clone https://github.com/Dao-AILab/flash-attention.git
-    cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install
+    cd flash-attention/hopper && python setup.py install
     python_path=`python -c "import site; print(site.getsitepackages()[0])"`
     mkdir -p $python_path/flash_attn_3
-    wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py
+    cp flash_attn_interface.py $python_path/flash_attn_3/
     cd ../../
   fi
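As a quick, hypothetical sanity check (not part of this commit), the interface module that the updated script copies into site-packages should be importable under the flash_attn_3 package name:

# Hypothetical check that the copied FA3 interface is importable (assumes the
# test.sh steps above have been run on a Hopper machine with FA3 built).
try:
    from flash_attn_3 import flash_attn_interface  # noqa: F401
    print("FlashAttention-3 interface is importable")
except ImportError as exc:
    print("FlashAttention-3 interface not available:", exc)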

tests/pytorch/attention/test_attention.py

Lines changed: 83 additions & 42 deletions
@@ -117,7 +117,14 @@ def reset_global_fp8_state():
 @pytest.mark.parametrize("swa", [False])
 @pytest.mark.parametrize("pad_between_seqs", [False])
 def test_dot_product_attention(
-    dtype, model_configs, model, ckpt_attn, workspace_opt, qkv_layout, swa, pad_between_seqs
+    dtype,
+    model_configs,
+    model,
+    ckpt_attn,
+    workspace_opt,
+    qkv_layout,
+    swa,
+    pad_between_seqs,
 ):
     """Test DotProductAttention module"""
 
@@ -308,6 +315,31 @@ def test_dpa_max_logit(dtype, model_configs, model, qkv_layout):
     test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False)
 
 
+model_configs_num_splits = {
+    # test: ModelConfig(b, sq, hq, dqk)
+    "num_splits_1_0": ModelConfig(2, 2048, 24, 128, num_splits=2),
+    "num_splits_1_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, num_splits=4),
+}
+
+
+@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
+@pytest.mark.parametrize("dtype", param_types)
+@pytest.mark.parametrize("model_configs", [model_configs_num_splits])
+@pytest.mark.parametrize("model", model_configs_num_splits.keys())
+def test_dpa_num_splits(dtype, model_configs, model):
+    """Test DotProductAttention with FlashAttention-3 num_splits enabled"""
+    test_dot_product_attention(
+        dtype,
+        model_configs,
+        model,
+        False,
+        True,
+        None,
+        False,
+        False,
+    )
+
+
 model_configs_softmax = {
     # test: ModelConfig(b, sq, hq, dqk)
     "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8),
@@ -1152,6 +1184,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
             core_attention_bias=bias,
             alibi_slopes=alibi_slopes,
             fast_zero_fill=True,
+            # Only pass num_splits when exercising the FlashAttention path
+            num_splits=config.num_splits if backend == "FlashAttention" else 1,
         )
     max_logit = None
     if config.return_max_logit:
@@ -1786,18 +1820,19 @@ def test_mha_fp8_vs_f16(
         fp8_meta=fp8_meta,
         is_training=is_training,
     )
-    flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
-    if flash_attn_supported + fused_attn_supported < 1:
+    flash_attn_supported, fused_attn_supported_fp8, unfused_attn_supported = available_backends
+    if flash_attn_supported + fused_attn_supported_fp8 < 1:
         pytest.skip("No FP8 attention backend available.")
+    fused_attn_supported_f16 = False
     if not fp8_dpa_bwd:
         available_backends, _, fused_attn_backends = get_available_attention_backends(
             config,
             qkv_dtype=dtype,
             qkv_layout=qkv_format.replace("hd", "h3d"),
             is_training=is_training,
         )
-        _, fused_attn_supported, _ = available_backends
-        if not fused_attn_supported:
+        _, fused_attn_supported_f16, _ = available_backends
+        if not fused_attn_supported_f16:
             pytest.skip("No attention backend available.")
 
     if flash_attn_supported:
@@ -1809,23 +1844,28 @@ def test_mha_fp8_vs_f16(
         dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
     )
 
-    os.environ["NVTE_FLASH_ATTN"] = "0"
-    os.environ["NVTE_FUSED_ATTN"] = "1"
-    _attention_backends["backend_selection_requires_update"] = True
-    logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
-    fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
-        dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
-    )
+    if fused_attn_supported_fp8:
+        os.environ["NVTE_FLASH_ATTN"] = "0"
+        os.environ["NVTE_FUSED_ATTN"] = "1"
+        _attention_backends["backend_selection_requires_update"] = True
+        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True")
+        fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16(
+            dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
+        )
 
-    logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
-    fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
-        dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
-    )
+    if fused_attn_supported_f16:
+        os.environ["NVTE_FLASH_ATTN"] = "0"
+        os.environ["NVTE_FUSED_ATTN"] = "1"
+        _attention_backends["backend_selection_requires_update"] = True
+        logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False")
+        fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16(
+            dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe
+        )
 
     atol = 5e-1
     rtol = 5e-1
     rmse_tol = 0.15
-    if flash_attn_supported:
+    if flash_attn_supported and fused_attn_supported_f16:
         logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:"))
         logging.debug("========== {:^25s} ==========".format("forward output"))
         compare_and_assert(
@@ -1838,32 +1878,33 @@ def test_mha_fp8_vs_f16(
             rmse_tol,
             True,
         )
-    logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
-    logging.debug("========== {:^25s} ==========".format("forward output"))
-    compare_and_assert(
-        fused_attn_fwd_fp8,
-        fused_attn_fwd_f16,
-        "fused_attn_fwd_fp8",
-        "fused_attn_fwd_f16",
-        atol,
-        rtol,
-        rmse_tol,
-        True,
-    )
+    if fused_attn_supported_fp8 and fused_attn_supported_f16:
+        logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:"))
+        logging.debug("========== {:^25s} ==========".format("forward output"))
+        compare_and_assert(
+            fused_attn_fwd_fp8,
+            fused_attn_fwd_f16,
+            "fused_attn_fwd_fp8",
+            "fused_attn_fwd_f16",
+            atol,
+            rtol,
+            rmse_tol,
+            True,
+        )
 
-    if is_training:
-        for i in range(len(param_names[:1])):
-            logging.debug("========== {:^25s} ==========".format(param_names[i]))
-            compare_and_assert(
-                fused_attn_bwd_fp8[i],
-                fused_attn_bwd_f16[i],
-                f"fused_attn_bwd_fp8[{i}]",
-                f"fused_attn_bwd_f16[{i}]",
-                atol,
-                rtol,
-                rmse_tol,
-                True,
-            )
+        if is_training:
+            for i in range(len(param_names[:1])):
+                logging.debug("========== {:^25s} ==========".format(param_names[i]))
+                compare_and_assert(
+                    fused_attn_bwd_fp8[i],
+                    fused_attn_bwd_f16[i],
+                    f"fused_attn_bwd_fp8[{i}]",
+                    f"fused_attn_bwd_f16[{i}]",
+                    atol,
+                    rtol,
+                    rmse_tol,
+                    True,
+                )
 
 
 def _run_mha_fp8_vs_f16(
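To run only the new num_splits cases locally, an invocation like the one below can be used from the repository root; the exact command is an assumption and not part of the commit, and without an FA3 (Hopper) build the availability check added in tests/pytorch/utils.py is expected to filter these cases out.

# Hypothetical local invocation of the new tests added above.
import pytest

pytest.main(["-v", "-k", "test_dpa_num_splits", "tests/pytorch/attention/test_attention.py"])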

tests/pytorch/utils.py

Lines changed: 10 additions & 0 deletions
@@ -8,6 +8,7 @@
 import os
 from contextlib import contextmanager
 from typing import Optional, Tuple, Dict, Any, List
+from packaging.version import Version as PkgVersion
 
 import torch
 
@@ -210,6 +211,7 @@ def __init__(
         max_ctx_len: int = None,
         num_layers: int = 1,
         eps: float = 1e-5,
+        num_splits=1,
     ):
         self.batch_size = batch_size
         self.max_seqlen_q = max_seqlen_q
@@ -239,6 +241,7 @@ def __init__(
         self.max_ctx_len = max_ctx_len
         self.num_layers = num_layers
         self.eps = eps
+        self.num_splits = num_splits
 
 
 @contextmanager
@@ -321,6 +324,9 @@ def test():
         inference_params=inference_params,
         softmax_type=config.softmax_type,
         return_max_logit=config.return_max_logit,
+        # allow all backends to pass so they can be used for testing;
+        # check for FA3 availability later
+        num_splits=1,
     )
     (
         use_flash_attention,
@@ -330,6 +336,10 @@ def test():
         use_unfused_attention,
         available_backends,
     ) = get_attention_backend(attention_params)
+    # Check if FA3 is an available backend when num_splits != 1
+    if available_backends[0]:
+        if config.num_splits != 1 and not flash_attention_backend > PkgVersion("3.0.0b"):
+            available_backends[0] = False
     # Set attention.py _attention_backends var using return value
     # from get_attention_backend()
     _attention_backends["use_flash_attention"] = use_flash_attention
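The guard above only honors num_splits when the detected FlashAttention backend is newer than 3.0.0b. A minimal sketch of the same version check in isolation (function and variable names here are illustrative, not the test helper's actual API):

from packaging.version import Version as PkgVersion

def fa3_supports_num_splits(fa_version, num_splits):
    """num_splits != 1 is only honored by a FlashAttention-3 build; fa_version may be None."""
    if num_splits == 1:
        return True
    return fa_version is not None and fa_version > PkgVersion("3.0.0b")

assert fa3_supports_num_splits(PkgVersion("3.0.0b3"), 4)       # FA3 beta -> allowed
assert not fa3_supports_num_splits(PkgVersion("2.7.4"), 4)     # FA2 -> filtered out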

transformer_engine/pytorch/attention/dot_product_attention/backends.py

Lines changed: 2 additions & 0 deletions
@@ -681,6 +681,7 @@ def forward(
         inference_params: Optional[InferenceParams] = None,
         flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"),
         fp8_output: bool = False,
+        num_splits: Optional[int] = 1,
     ) -> torch.Tensor:
         """flash-attn fprop"""
 
@@ -957,6 +958,7 @@ def forward(
             else:
                 fa_3_optional_forward_kwargs = {}
                 fa_3_optional_forward_kwargs["window_size"] = window_size
+                fa_3_optional_forward_kwargs["num_splits"] = num_splits
                 if inference_params is None:
                     fa_3_optional_forward_kwargs["deterministic"] = self.deterministic
                 else:

transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py

Lines changed: 7 additions & 0 deletions
@@ -799,6 +799,7 @@ def forward(
         inference_params: Optional[InferenceParams] = None,
         pad_between_seqs: Optional[bool] = None,
         fp8_output: Optional[bool] = False,
+        num_splits: Optional[int] = 1,
     ) -> torch.Tensor:
         """
         Dot Product Attention Layer.
@@ -973,6 +974,10 @@ def forward(
             If true, there are padding tokens between individual sequences in a packed batch.
         fp8_output: Optional[bool], default = `False`
             Whether to enforce output to be in FP8 or not.
+        num_splits: Optional[int], default = 1
+            Optional split control for FlashAttention-3 only. When set, this value is forwarded
+            to the FA3 backend to control internal kernel splitting behavior for non-context-parallel
+            cases. It is ignored for other backends and when context parallelism is enabled.
         """
 
         with self.prepare_forward(
@@ -1315,6 +1320,7 @@ def forward(
                 softmax_type=self.softmax_type,
                 return_max_logit=self.return_max_logit,
                 cuda_graph=is_graph_capturing(),
+                num_splits=num_splits,
             )
             global _attention_backends
             if is_in_onnx_export_mode():
@@ -1413,6 +1419,7 @@ def forward(
                 inference_params=inference_params,
                 flash_attention_backend=flash_attention_backend,
                 fp8_output=fp8_output,
+                num_splits=num_splits,
             )
 
         if use_fused_attention:
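For reference, a minimal usage sketch of the new argument; the module configuration, tensor shapes, and environment-variable override below are illustrative assumptions, and only the num_splits keyword itself comes from this commit. The value is forwarded to FlashAttention-3; for num_splits != 1 the other backends are filtered out (see utils.py below).

# Hedged usage sketch: forward num_splits through DotProductAttention.
import os
import torch
import transformer_engine.pytorch as te

os.environ["NVTE_FLASH_ATTN"] = "1"  # prefer FlashAttention when available

dpa = te.DotProductAttention(num_attention_heads=24, kv_channels=128)
# Default qkv_format is "sbhd": (seq, batch, heads, head_dim).
q = torch.randn(2048, 2, 24, 128, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2048, 2, 24, 128, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2048, 2, 24, 128, dtype=torch.bfloat16, device="cuda")

out = dpa(q, k, v, num_splits=4)  # split control passed to the FA3 kernel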

transformer_engine/pytorch/attention/dot_product_attention/utils.py

Lines changed: 21 additions & 3 deletions
@@ -135,7 +135,7 @@ class FlashAttentionUtils:
     # Please follow these instructions to install FA3
     v3_installation_steps = """\
     (1) git clone https://github.com/Dao-AILab/flash-attention.git
-    (2) cd flash-attention/ && git checkout 3ba6f82 && git submodule update --init && cd hopper/ && python setup.py install
+    (2) cd flash-attention/hopper && python setup.py install
     (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"`
     (4) mkdir -p $python_path/flash_attn_3
     (5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py"""
@@ -233,6 +233,8 @@ class AttentionParams:
         Whether to output max_logit.
     cuda_graph: bool, default = `False`
         Whether support for cuda graph capture is needed or not.
+    num_splits: int, default = 1
+        The number of kernels to split attention to.
     """
 
     qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor
@@ -263,6 +265,7 @@ class AttentionParams:
     softmax_type: str = "vanilla"
     return_max_logit: bool = False
     cuda_graph: bool = False
+    num_splits: int = 1
 
     def __eq__(self, other):
         """
@@ -338,6 +341,7 @@ def get_attention_backend(
     softmax_type = attention_params.softmax_type
     return_max_logit = attention_params.return_max_logit
     cuda_graph = attention_params.cuda_graph
+    num_splits = attention_params.num_splits
 
     # Run config
     logger = logging.getLogger("DotProductAttention")
@@ -511,6 +515,18 @@ def get_attention_backend(
         use_flash_attention = False
         use_fused_attention = False
 
+    # Filter: num_splits
+    if num_splits != 1:
+        if use_flash_attention_2 and FlashAttentionUtils.is_installed:
+            logger.debug("Disabling FlashAttention 2 for num_splits")
+            use_flash_attention_2 = False
+        if use_fused_attention:
+            logger.debug("Disabling FusedAttention for num_splits")
+            use_fused_attention = False
+        if use_unfused_attention:
+            logger.debug("Disabling UnfusedDotProductAttention for num_splits")
+            use_unfused_attention = False
+
     # Filter: Return max_logit
     if return_max_logit:
         if use_flash_attention:
@@ -1566,8 +1582,9 @@ def _pack_tensor(
     """
     Packs the given tensor using the `indices`.
     """
+    dtype = tensor.dtype if not isinstance(tensor, Float8Tensor) else torch.uint8
     padding_indice = torch.zeros(
-        1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
+        1, tensor.shape[1], tensor.shape[2], dtype=dtype, device=tensor.device
     )
     indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
     if isinstance(tensor, Float8Tensor):
@@ -1622,8 +1639,9 @@
     Inverse of `_pack_tensor`.
     """
     indices = indices.repeat(1, tensor.shape[1], tensor.shape[2])
+    dtype = tensor.dtype if not isinstance(tensor, Float8Tensor) else torch.uint8
     unpacked = torch.zeros(
-        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=tensor.dtype, device=tensor.device
+        dim0 + 1, tensor.shape[1], tensor.shape[2], dtype=dtype, device=tensor.device
     )
     if isinstance(tensor, Float8Tensor):
         unpacked.scatter_(0, indices, tensor._data)
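The dtype change in _pack_tensor/_unpack_tensor matters because Float8Tensor stores its payload as raw uint8 bytes in `_data`, so the zero-padding row and the scatter/gather destination must also be uint8 rather than the tensor's nominal dtype. A standalone sketch of the same idea on a plain uint8 buffer (hypothetical helper, standard PyTorch ops only):

import torch

def pack_rows(data_u8: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """Gather selected rows of a uint8 payload; row 0 of the source acts as zero padding."""
    # The padding row must match the byte payload's dtype (uint8), mirroring the fix above.
    pad = torch.zeros(1, *data_u8.shape[1:], dtype=torch.uint8, device=data_u8.device)
    src = torch.cat([pad, data_u8], dim=0)
    idx = indices.view(-1, 1, 1).expand(-1, *data_u8.shape[1:])
    return torch.gather(src, 0, idx)

data = torch.randint(0, 255, (6, 2, 4), dtype=torch.uint8)
packed = pack_rows(data, torch.tensor([1, 3, 5]))  # selects rows 0, 2, 4 of `data` (offset by the pad row)
print(packed.shape)  # torch.Size([3, 2, 4])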
