diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 209fc185b296..6becebcb540b 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -112,6 +112,7 @@ class LatentPreviewMethod(enum.Enum): attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.") attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.") attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.") +attn_group.add_argument("--use-sage-attention3", action="store_true", help="Use sage attention 3.") attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.") parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.") diff --git a/comfy/ldm/modules/attention.py b/comfy/ldm/modules/attention.py index a8800ded0ca6..d23a753f99ad 100644 --- a/comfy/ldm/modules/attention.py +++ b/comfy/ldm/modules/attention.py @@ -30,6 +30,18 @@ raise e exit(-1) +SAGE_ATTENTION3_IS_AVAILABLE = False +try: + from sageattn3 import sageattn3_blackwell + SAGE_ATTENTION3_IS_AVAILABLE = True +except ImportError as e: + if model_management.sage_attention3_enabled(): + if e.name == "sageattn3": + logging.error(f"\n\nTo use the `--use-sage-attention3` feature, the `sageattn3` package must be installed first.\nPlease check https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell") + else: + raise e + exit(-1) + FLASH_ATTENTION_IS_AVAILABLE = False try: from flash_attn import flash_attn_func @@ -563,6 +575,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape= out = out.reshape(b, -1, heads * dim_head) return out +@wrap_attn +def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs): + exception_fallback = False + if (q.device.type != "cuda" or + q.dtype not in (torch.float16, torch.bfloat16) or + mask is not None): + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if skip_reshape: + B, H, L, D = q.shape + if H != heads: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=True, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + q_s, k_s, v_s = q, k, v + N = q.shape[2] + dim_head = D + else: + B, N, inner_dim = q.shape + if inner_dim % heads != 0: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=False, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + dim_head = inner_dim // heads + + if dim_head >= 256 or N <= 1024: + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=skip_reshape, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if not skip_reshape: + q_s, k_s, v_s = map( + lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(), + (q, k, v), + ) + B, H, L, D = q_s.shape + + try: + out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False) + except Exception as e: + exception_fallback = True + logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e) + + if exception_fallback: + if not skip_reshape: + del q_s, k_s, v_s + return attention_pytorch( + q, k, v, heads, + mask=mask, + attn_precision=attn_precision, + skip_reshape=False, + skip_output_reshape=skip_output_reshape, + **kwargs + ) + + if skip_reshape: + if not skip_output_reshape: + out = out.permute(0, 2, 1, 3).reshape(B, L, H * D) + else: + if skip_output_reshape: + pass + else: + out = out.permute(0, 2, 1, 3).reshape(B, L, H * D) + + return out try: @torch.library.custom_op("flash_attention::flash_attn", mutates_args=()) @@ -627,6 +726,9 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape if model_management.sage_attention_enabled(): logging.info("Using sage attention") optimized_attention = attention_sage +if model_management.sage_attention3_enabled(): + logging.info("Using sage attention 3") + optimized_attention = attention3_sage elif model_management.xformers_enabled(): logging.info("Using xformers attention") optimized_attention = attention_xformers @@ -650,6 +752,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape # register core-supported attention functions if SAGE_ATTENTION_IS_AVAILABLE: register_attention_function("sage", attention_sage) +if SAGE_ATTENTION3_IS_AVAILABLE: + register_attention_function("sage3", attention3_sage) if FLASH_ATTENTION_IS_AVAILABLE: register_attention_function("flash", attention_flash) if model_management.xformers_enabled(): diff --git a/comfy/model_management.py b/comfy/model_management.py index aeddbaefe4b9..c971dd95f74a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1189,6 +1189,9 @@ def unpin_memory(tensor): def sage_attention_enabled(): return args.use_sage_attention +def sage_attention3_enabled(): + return args.use_sage_attention3 + def flash_attention_enabled(): return args.use_flash_attention