1 change: 1 addition & 0 deletions comfy/cli_args.py
@@ -112,6 +112,7 @@ class LatentPreviewMethod(enum.Enum):
attn_group.add_argument("--use-quad-cross-attention", action="store_true", help="Use the sub-quadratic cross attention optimization . Ignored when xformers is used.")
attn_group.add_argument("--use-pytorch-cross-attention", action="store_true", help="Use the new pytorch 2.0 cross attention function.")
attn_group.add_argument("--use-sage-attention", action="store_true", help="Use sage attention.")
attn_group.add_argument("--use-sage-attention3", action="store_true", help="Use sage attention 3.")
attn_group.add_argument("--use-flash-attention", action="store_true", help="Use FlashAttention.")

parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
104 changes: 104 additions & 0 deletions comfy/ldm/modules/attention.py
@@ -30,6 +30,18 @@
raise e
exit(-1)

SAGE_ATTENTION3_IS_AVAILABLE = False
try:
from sageattn3 import sageattn3_blackwell
SAGE_ATTENTION3_IS_AVAILABLE = True
except ImportError as e:
if model_management.sage_attention3_enabled():
if e.name == "sageattn3":
logging.error(f"\n\nTo use the `--use-sage-attention3` feature, the `sageattn3` package must be installed first.\nPlease check https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell")
else:
raise e
exit(-1)
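# Note the guard above: a missing package is only fatal when --use-sage-attention3 was
# actually requested; otherwise availability is simply recorded as False. The same probe
# can be reproduced standalone (package and entry-point names as used above; build
# instructions: https://github.com/thu-ml/SageAttention/tree/main/sageattention3_blackwell):
def _sageattn3_available() -> bool:
    try:
        from sageattn3 import sageattn3_blackwell  # noqa: F401
        return True
    except ImportError:
        return False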

FLASH_ATTENTION_IS_AVAILABLE = False
try:
from flash_attn import flash_attn_func
@@ -560,6 +572,93 @@ def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=
out = out.reshape(b, -1, heads * dim_head)
return out

@wrap_attn
def attention3_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
exception_fallback = False
    # The sage3 kernel only handles fp16/bf16 CUDA tensors without an attention mask;
    # anything else goes through the PyTorch attention fallback.
    if (q.device.type != "cuda" or
            q.dtype not in (torch.float16, torch.bfloat16) or
            mask is not None):
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)

if skip_reshape:
B, H, L, D = q.shape
if H != heads:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=True,
skip_output_reshape=skip_output_reshape,
**kwargs
)
q_s, k_s, v_s = q, k, v
N = q.shape[2]
dim_head = D
else:
B, N, inner_dim = q.shape
if inner_dim % heads != 0:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=False,
skip_output_reshape=skip_output_reshape,
**kwargs
)
dim_head = inner_dim // heads

if dim_head >= 256 or N <= 1024:
return attention_pytorch(
q, k, v, heads,
mask=mask,
attn_precision=attn_precision,
skip_reshape=skip_reshape,
skip_output_reshape=skip_output_reshape,
**kwargs
)

if not skip_reshape:
q_s, k_s, v_s = map(
lambda t: t.view(B, -1, heads, dim_head).permute(0, 2, 1, 3).contiguous(),
(q, k, v),
)
B, H, L, D = q_s.shape

try:
out = sageattn3_blackwell(q_s, k_s, v_s, is_causal=False)
except Exception as e:
exception_fallback = True
logging.error("Error running SageAttention3: %s, falling back to pytorch attention.", e)

    if exception_fallback:
        # Free the contiguous copies (if any), then retry with the original q/k/v through
        # the PyTorch path, preserving the caller's layout flags.
        if not skip_reshape:
            del q_s, k_s, v_s
        return attention_pytorch(
            q, k, v, heads,
            mask=mask,
            attn_precision=attn_precision,
            skip_reshape=skip_reshape,
            skip_output_reshape=skip_output_reshape,
            **kwargs
        )

    # sageattn3_blackwell returns (B, H, L, D); flatten the heads back into the last
    # dimension unless the caller asked to keep the multi-head layout.
    if not skip_output_reshape:
        out = out.permute(0, 2, 1, 3).reshape(B, L, H * D)

return out
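# A minimal usage sketch, assuming a CUDA GPU with the sageattn3 kernel installed: the
# fast path above requires fp16/bf16 CUDA tensors, no mask, dim_head < 256 and a sequence
# length above 1024; any other input silently takes the attention_pytorch fallback.
def _attention3_sage_example():
    B, heads, N, dim_head = 2, 8, 4096, 64
    q = torch.randn(B, N, heads * dim_head, device="cuda", dtype=torch.bfloat16)
    k, v = torch.randn_like(q), torch.randn_like(q)
    return attention3_sage(q, k, v, heads)  # -> (B, N, heads * dim_head)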

try:
@torch.library.custom_op("flash_attention::flash_attn", mutates_args=())
@@ -624,6 +723,9 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
if model_management.sage_attention_enabled():
logging.info("Using sage attention")
optimized_attention = attention_sage
elif model_management.sage_attention3_enabled():
logging.info("Using sage attention 3")
optimized_attention = attention3_sage
elif model_management.xformers_enabled():
logging.info("Using xformers attention")
optimized_attention = attention_xformers
@@ -647,6 +749,8 @@ def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
# register core-supported attention functions
if SAGE_ATTENTION_IS_AVAILABLE:
register_attention_function("sage", attention_sage)
if SAGE_ATTENTION3_IS_AVAILABLE:
register_attention_function("sage3", attention3_sage)
if FLASH_ATTENTION_IS_AVAILABLE:
register_attention_function("flash", attention_flash)
if model_management.xformers_enabled():
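With the "sage3" name registered, the kernel can also be resolved through the module's attention-function registry. A sketch of that lookup, assuming a companion `get_attention_function` helper that pairs with `register_attention_function` (the helper's name and signature are assumptions here):

from comfy.ldm.modules import attention

# "sage3" resolves to attention3_sage only if SAGE_ATTENTION3_IS_AVAILABLE was True at import time.
sage3_fn = attention.get_attention_function("sage3", default=None)
print("sage3 registered:", sage3_fn is not None)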
3 changes: 3 additions & 0 deletions comfy/model_management.py
@@ -1189,6 +1189,9 @@ def unpin_memory(tensor):
def sage_attention_enabled():
return args.use_sage_attention

def sage_attention3_enabled():
return args.use_sage_attention3

def flash_attention_enabled():
return args.use_flash_attention

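Taken together, the three files wire the feature end to end: the CLI flag, the `sage_attention3_enabled()` helper, and the `attention3_sage` selection in `attention.py`. A small smoke check of that wiring, assuming ComfyUI's modules are importable and the process was started with `--use-sage-attention3`:

from comfy import model_management
from comfy.ldm.modules import attention

if model_management.sage_attention3_enabled():
    # optimized_attention was bound to attention3_sage when attention.py was imported
    print("optimized_attention:", attention.optimized_attention)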