From 4bdb0dddb74c8276a420eafd85a8a010d9508617 Mon Sep 17 00:00:00 2001
From: Jianqiao Huang
Date: Mon, 1 Dec 2025 09:54:19 -0800
Subject: [PATCH 1/2] Add support for sage attention 3 in comfyui, enable via new cli arg --use-sage-attention3

---
 comfy/cli_args.py              |  1 +
 comfy/ldm/modules/attention.py | 96 ++++++++++++++++++++++++++++++++++
 comfy/model_management.py      |  3 ++
 3 files changed, 100 insertions(+)

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 5f0dfaa10799..a3c4a6bc6eba 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -112,6 +112,7 @@ class LatentPreviewMethod(enum.Enum):
 attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
 attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
 attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
+attn_group.add_argument("--use-sage-attention3", action="store_true", help="Use sage attention 3.")
 attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")
 
 parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 7437e05675b5..2dfc55bfaf39 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -30,6 +30,18 @@
             raise e
         exit(-1)
 
+SAGE_ATTENTION3_IS_AVAILABLE = False
+try:
+    from sageattn3 import sageattn3_blackwell
+    SAGE_ATTENTION3_IS_AVAILABLE = True
+except ImportError as e:
+    if model_management.sage_attention3_enabled():
+        if e.name == "sageattn3":
+            logging.error(f"\n\nTo use the `--use-sage-attention3` feature, the `sageattn3` package must be installed first.\nPlease check https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell")
+        else:
+            raise e
+        exit(-1)
+
 FLASH_ATTENTION_IS_AVAILABLE = False
 try:
     from flash_attn import flash_attn_func
@@ -560,6 +572,85 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
         out = out.reshape(b, -1, heads * dim_head)
     return out
 
+@wrap_attn
+def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    if (q.device.type != "cuda" or
+        q.dtype not in (torch.float16, torch.bfloat16) or
+        mask is not None):
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=skip_reshape,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if skip_reshape:
+        B, H, L, D = q.shape
+        if H != heads:
+            return attention_pytorch(
+                q, k, v, heads,
+                mask=mask,
+                attn_precision=attn_precision,
+                skip_reshape=True,
+                skip_output_reshape=skip_output_reshape,
+                **kwargs
+            )
+        q_s, k_s, v_s = q, k, v
+        N = q.shape[2]
+    else:
+        B, N, inner_dim = q.shape
+        if inner_dim % heads != 0:
+            return attention_pytorch(
+                q, k, v, heads,
+                mask=mask,
+                attn_precision=attn_precision,
+                skip_reshape=False,
+                skip_output_reshape=skip_output_reshape,
+                **kwargs
+            )
+        dim_head = inner_dim // heads
+
+        q_s, k_s, v_s = map(
+            lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
+            (q, k, v),
+        )
+        B, H, L, D = q_s.shape
+
+    if dim_head >= 256 or N <= 2048:
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=False,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    try:
+        out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
+    except Exception as e:
+        logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
+        return attention_pytorch(
+            q, k, v, heads,
+            mask=mask,
+            attn_precision=attn_precision,
+            skip_reshape=False,
+            skip_output_reshape=skip_output_reshape,
+            **kwargs
+        )
+
+    if skip_reshape:
+        if not skip_output_reshape:
+            out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
+    else:
+        if skip_output_reshape:
+            pass
+        else:
+            out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)
+
+    return out
 
 try:
     @torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
@@ -624,6 +715,9 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 if model_management.sage_attention_enabled():
     logging.info("Using sage attention")
     optimized_attention = attention_sage
+elif model_management.sage_attention3_enabled():
+    logging.info("Using sage attention 3")
+    optimized_attention = attention3_sage
 elif model_management.xformers_enabled():
     logging.info("Using xformers attention")
     optimized_attention = attention_xformers
@@ -647,6 +741,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
 # register core-supported attention functions
 if SAGE_ATTENTION_IS_AVAILABLE:
     register_attention_function("sage", attention_sage)
+if SAGE_ATTENTION3_IS_AVAILABLE:
+    register_attention_function("sage3", attention3_sage)
 if FLASH_ATTENTION_IS_AVAILABLE:
     register_attention_function("flash", attention_flash)
 if model_management.xformers_enabled():
diff --git a/comfy/model_management.py b/comfy/model_management.py
index aeddbaefe4b9..c971dd95f74a 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1189,6 +1189,9 @@ def unpin_memory(tensor):
 def sage_attention_enabled():
     return args.use_sage_attention
 
+def sage_attention3_enabled():
+    return args.use_sage_attention3
+
 def flash_attention_enabled():
     return args.use_flash_attention
 

From 648814b7516b369bf5b57a75887c4a3be43228b7 Mon Sep 17 00:00:00 2001
From: Jianqiao Huang
Date: Tue, 2 Dec 2025 08:19:40 -0800
Subject: [PATCH 2/2] Fix some bugs found in PR review.

The sequence length (N) threshold above which Sage Attention 3 takes effect is lowered from 2048 to 1024 (although the improvement is not significant at that scale).
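For reference, the eligibility gate enforced after this change can be summarized by the standalone sketch below (the helper name sage3_eligible is made up for illustration and does not appear in the diff):

    import torch

    def sage3_eligible(q, mask, dim_head, seq_len):
        # Mirrors the checks in attention3_sage after this patch: the
        # SageAttention 3 kernel is only attempted for CUDA fp16/bf16
        # tensors without a mask, with head_dim < 256 and N > 1024;
        # every other case falls back to attention_pytorch.
        return (
            q.device.type == "cuda"
            and q.dtype in (torch.float16, torch.bfloat16)
            and mask is None
            and dim_head < 256
            and seq_len > 1024
        )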
---
 comfy/ldm/modules/attention.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py
index 2dfc55bfaf39..d51e49da2280 100644
--- a/comfy/ldm/modules/attention.py
+++ b/comfy/ldm/modules/attention.py
@@ -574,6 +574,7 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
 
 @wrap_attn
 def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    exception_fallback = False
     if (q.device.type != "cuda" or
         q.dtype not in (torch.float16, torch.bfloat16) or
         mask is not None):
@@ -599,6 +600,7 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
             )
         q_s, k_s, v_s = q, k, v
         N = q.shape[2]
+        dim_head = D
     else:
         B, N, inner_dim = q.shape
         if inner_dim % heads != 0:
@@ -611,27 +613,33 @@ def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
                 **kwargs
             )
         dim_head = inner_dim // heads
-
-        q_s, k_s, v_s = map(
-            lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
-            (q, k, v),
-        )
-        B, H, L, D = q_s.shape
 
-    if dim_head >= 256 or N <= 2048:
+    if dim_head >= 256 or N <= 1024:
         return attention_pytorch(
             q, k, v, heads,
             mask=mask,
             attn_precision=attn_precision,
-            skip_reshape=False,
+            skip_reshape=skip_reshape,
             skip_output_reshape=skip_output_reshape,
             **kwargs
         )
 
+    if not skip_reshape:
+        q_s, k_s, v_s = map(
+            lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
+            (q, k, v),
+        )
+        B, H, L, D = q_s.shape
+
     try:
         out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
     except Exception as e:
+        exception_fallback = True
         logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)
+
+    if exception_fallback:
+        if not skip_reshape:
+            del q_s, k_s, v_s
         return attention_pytorch(
             q, k, v, heads,
             mask=mask,
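Outside of the two patches above, a quick way to exercise the new backend is to launch ComfyUI with the new flag (for example `python main.py --use-sage-attention3`) and, assuming the `sageattn3` package and a supported Blackwell GPU are available, compare the kernel against PyTorch scaled_dot_product_attention on a shape that passes the gate (head_dim < 256, N > 1024). A rough parity-check sketch, not part of the PR:

    import torch
    import torch.nn.functional as F
    from sageattn3 import sageattn3_blackwell  # same import the patch uses

    B, H, N, D = 1, 8, 4096, 64  # passes the gate: D < 256 and N > 1024
    q, k, v = (torch.randn(B, H, N, D, device="cuda", dtype=torch.float16) for _ in range(3))

    out_sage = sageattn3_blackwell(q, k, v, is_causal=False)  # call signature taken from the patch
    out_ref = F.scaled_dot_product_attention(q, k, v)
    # Expect a small difference from the low-precision quantization, not exact equality.
    print((out_sage.float() - out_ref.float()).abs().max())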