From 863d4b2159f562cb26714b7a7c62b48f05830dcb Mon Sep 17 00:00:00 2001
From: 宣源
Date: Thu, 3 Apr 2025 16:58:43 +0800
Subject: [PATCH 1/5] add swa

---
 videox_fun/models/wan_transformer3d.py | 38 +++++++++++++++++++++-----
 1 file changed, 31 insertions(+), 7 deletions(-)

diff --git a/videox_fun/models/wan_transformer3d.py b/videox_fun/models/wan_transformer3d.py
index e9f7445c..8c3f7678 100755
--- a/videox_fun/models/wan_transformer3d.py
+++ b/videox_fun/models/wan_transformer3d.py
@@ -37,6 +37,8 @@
 except ModuleNotFoundError:
     FLASH_ATTN_2_AVAILABLE = False
 
+from einops import rearrange
+
 
 def flash_attention(
     q,
@@ -290,7 +292,8 @@ def __init__(self,
                  num_heads,
                  window_size=(-1, -1),
                  qk_norm=True,
-                 eps=1e-6):
+                 eps=1e-6,
+                 bidx=0):
         assert dim % num_heads == 0
         super().__init__()
         self.dim = dim
@@ -299,6 +302,7 @@ def __init__(self,
         self.window_size = window_size
         self.qk_norm = qk_norm
         self.eps = eps
+        self.bidx = bidx
 
         # layers
         self.q = nn.Linear(dim, dim)
@@ -326,13 +330,30 @@ def qkv_fn(x):
             return q, k, v
 
         q, k, v = qkv_fn(x)
+        f, h, w = grid_sizes.tolist()[0]
+        q = rope_apply(q, grid_sizes, freqs).to(dtype)
+        k=rope_apply(k, grid_sizes, freqs).to(dtype)
+        v = v.to(dtype)
+
+        q = rearrange(q, 'b (f h w) n d -> b (h w f) n d', f=f, h=h, w=w)
+        k = rearrange(k, 'b (f h w) n d -> b (h w f) n d', f=f, h=h, w=w)
+        v = rearrange(v, 'b (f h w) n d -> b (h w f) n d', f=f, h=h, w=w)
+
         x = attention(
-            q=rope_apply(q, grid_sizes, freqs).to(dtype),
-            k=rope_apply(k, grid_sizes, freqs).to(dtype),
-            v=v.to(dtype),
+            q=q,
+            k=k,
+            v=v,
             k_lens=seq_lens,
             window_size=self.window_size)
+        x = rearrange(x, 'b (h w f) n d -> b (f h w) n d', f=f, h=h, w=w)
+
+        # x = attention(
+        #     q=rope_apply(q, grid_sizes, freqs).to(dtype),
+        #     k=rope_apply(k, grid_sizes, freqs).to(dtype),
+        #     v=v.to(dtype),
+        #     k_lens=seq_lens,
+        #     window_size=self.window_size)
 
         x = x.to(dtype)
 
         # output
@@ -426,7 +447,8 @@ def __init__(self,
                  window_size=(-1, -1),
                  qk_norm=True,
                  cross_attn_norm=False,
-                 eps=1e-6):
+                 eps=1e-6,
+                 bidx=0):
         super().__init__()
         self.dim = dim
         self.ffn_dim = ffn_dim
@@ -435,11 +457,13 @@ def __init__(self,
         self.qk_norm = qk_norm
         self.cross_attn_norm = cross_attn_norm
         self.eps = eps
+        if (bidx + 1)%5!=0:
+            window_size = (4096, 4096)
 
         # layers
         self.norm1 = WanLayerNorm(dim, eps)
         self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
-                                          eps)
+                                          eps, bidx=bidx)
         self.norm3 = WanLayerNorm(
             dim, eps,
             elementwise_affine=True) if cross_attn_norm else nn.Identity()
@@ -645,7 +669,7 @@ def __init__(
         cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         self.blocks = nn.ModuleList([
             WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
-                              window_size, qk_norm, cross_attn_norm, eps)
+                              window_size, qk_norm, cross_attn_norm, eps, bidx=_)
             for _ in range(num_layers)
         ])
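
A note on the reordering in [PATCH 1/5]: the sequence arrives flattened in (f h w) order, so permuting it to (h w f) makes the frame index the fastest-varying axis, and a sliding window over the flattened sequence then spans complete temporal columns over a contiguous run of spatial sites rather than a few whole frames. The per-block rule (bidx + 1) % 5 != 0 gives four out of every five blocks the (4096, 4096) window, while every fifth block keeps the constructor default (-1, -1), i.e. full attention. Below is a minimal sketch of the permutation round trip; all sizes and names are toy values for illustration, not the model's real dimensions.

    import torch
    from einops import rearrange

    # Toy sizes: batch, frames, height, width, heads, head dim.
    b, f, h, w, n, d = 1, 4, 2, 3, 8, 16
    q = torch.randn(b, f * h * w, n, d)

    # After the permutation, consecutive sequence positions share one spatial
    # site (h, w) and differ only in frame index, so a window along the
    # sequence axis acts like a temporal window per spatial location.
    q_perm = rearrange(q, 'b (f h w) n d -> b (h w f) n d', f=f, h=h, w=w)

    # The inverse pattern restores the original layout exactly.
    q_back = rearrange(q_perm, 'b (h w f) n d -> b (f h w) n d', f=f, h=h, w=w)
    assert torch.equal(q, q_back)
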
From ab8b273aeb0dbfd39ed6752fb44dd087b04fb0a0 Mon Sep 17 00:00:00 2001
From: 宣源
Date: Tue, 8 Apr 2025 08:25:59 +0800
Subject: [PATCH 2/5] add swa

---
 videox_fun/models/wan_transformer3d.py | 55 ++++++++++++++++++--------
 1 file changed, 39 insertions(+), 16 deletions(-)

diff --git a/videox_fun/models/wan_transformer3d.py b/videox_fun/models/wan_transformer3d.py
index 8c3f7678..74a37518 100755
--- a/videox_fun/models/wan_transformer3d.py
+++ b/videox_fun/models/wan_transformer3d.py
@@ -335,25 +335,46 @@ def qkv_fn(x):
         k=rope_apply(k, grid_sizes, freqs).to(dtype)
         v = v.to(dtype)
 
-        q = rearrange(q, 'b (f h w) n d -> b (h w f) n d', f=f, h=h, w=w)
-        k = rearrange(k, 'b (f h w) n d -> b (h w f) n d', f=f, h=h, w=w)
-        v = rearrange(v, 'b (f h w) n d -> b (h w f) n d', f=f, h=h, w=w)
-
+        qs = torch.tensor_split(q.to(torch.bfloat16), 2, 2)
+        ks = torch.tensor_split(k.to(torch.bfloat16), 2, 2)
+        vs = torch.tensor_split(v.to(torch.bfloat16), 2, 2)
+
+        new_querys = []
+        new_keys = []
+        new_values = []
+        for index, mode in enumerate(
+            [
+                "bs (f h w) hn hd -> bs (h w f) hn hd",
+                "bs (f h w) hn hd -> bs (w h f) hn hd"
+            ]
+        ):
+
+            new_querys.append(rearrange(qs[index], mode, f=f, h=h, w=w))
+            new_keys.append(rearrange(ks[index], mode, f=f, h=h, w=w))
+            new_values.append(rearrange(vs[index], mode, f=f, h=h, w=w))
+        q = torch.cat(new_querys, dim=2)
+        k = torch.cat(new_keys, dim=2)
+        v = torch.cat(new_values, dim=2)
         x = attention(
             q=q,
             k=k,
             v=v,
             k_lens=seq_lens,
-            window_size=self.window_size)
-        x = rearrange(x, 'b (h w f) n d -> b (f h w) n d', f=f, h=h, w=w)
-
-        # x = attention(
-        #     q=rope_apply(q, grid_sizes, freqs).to(dtype),
-        #     k=rope_apply(k, grid_sizes, freqs).to(dtype),
-        #     v=v.to(dtype),
-        #     k_lens=seq_lens,
-        #     window_size=self.window_size)
+            window_size=self.window_size
+        )
+
+        hidden_states = torch.tensor_split(x, 2, 2)
+        new_hidden_states = []
+        for index, mode in enumerate(
+            [
+                "bs (h w f) hn hd -> bs (f h w) hn hd",
+                "bs (w h f) hn hd -> bs (f h w) hn hd"
+            ]
+        ):
+            new_hidden_states.append(rearrange(hidden_states[index], mode, f=f, h=h, w=w))
+        x = torch.cat(new_hidden_states, dim=2)
+
 
         x = x.to(dtype)
 
         # output
@@ -448,7 +469,8 @@ def __init__(self,
                  qk_norm=True,
                  cross_attn_norm=False,
                  eps=1e-6,
-                 bidx=0):
+                 bidx=0,
+                 swa=False):
         super().__init__()
         self.dim = dim
         self.ffn_dim = ffn_dim
@@ -457,7 +479,7 @@ def __init__(self,
         self.qk_norm = qk_norm
         self.cross_attn_norm = cross_attn_norm
         self.eps = eps
-        if (bidx + 1)%5!=0:
+        if (bidx + 1)%5!=0 and swa:
             window_size = (4096, 4096)
 
         # layers
@@ -597,6 +619,7 @@ def __init__(
         eps=1e-6,
         in_channels=16,
         hidden_size=2048,
+        swa=False,
     ):
         r"""
         Initialize the diffusion model backbone.
@@ -669,7 +692,7 @@ def __init__(
         cross_attn_type = 't2v_cross_attn' if model_type == 't2v' else 'i2v_cross_attn'
         self.blocks = nn.ModuleList([
             WanAttentionBlock(cross_attn_type, dim, ffn_dim, num_heads,
-                              window_size, qk_norm, cross_attn_norm, eps, bidx=_)
+                              window_size, qk_norm, cross_attn_norm, eps, bidx=_, swa=swa)
             for _ in range(num_layers)
         ])
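
A note on [PATCH 2/5]: the heads are now split into two groups along dim 2, and each group gets its own scan order before the shared windowed-attention call. (h w f) visits spatial sites row-major and (w h f) column-major, so the two head groups see differently shaped neighbourhoods near spatial boundaries. The cast to torch.bfloat16 happens before the split, and the new swa flag threads the window override from the top-level constructor down to every block. A runnable sketch of the split/permute/restore cycle under toy sizes (all names here are illustrative):

    import torch
    from einops import rearrange

    b, f, h, w, hn, hd = 1, 4, 2, 3, 8, 16   # toy sizes
    x = torch.randn(b, f * h * w, hn, hd)

    # Forward/inverse pattern pairs, one per head group, as in the patch.
    modes = [
        ("bs (f h w) hn hd -> bs (h w f) hn hd", "bs (h w f) hn hd -> bs (f h w) hn hd"),
        ("bs (f h w) hn hd -> bs (w h f) hn hd", "bs (w h f) hn hd -> bs (f h w) hn hd"),
    ]

    groups = torch.tensor_split(x, 2, dim=2)   # two head groups
    permuted = [rearrange(g, fwd, f=f, h=h, w=w) for g, (fwd, _) in zip(groups, modes)]

    # ... windowed attention would run here on torch.cat(permuted, dim=2) ...

    restored = [rearrange(g, inv, f=f, h=h, w=w) for g, (_, inv) in zip(permuted, modes)]
    assert torch.equal(torch.cat(restored, dim=2), x)
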
From 678affdc3bbb585f2d142466230717100b94f5c5 Mon Sep 17 00:00:00 2001
From: 宣源
Date: Tue, 8 Apr 2025 08:36:02 +0800
Subject: [PATCH 3/5] add swa

---
 videox_fun/models/wan_transformer3d.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/videox_fun/models/wan_transformer3d.py b/videox_fun/models/wan_transformer3d.py
index 74a37518..b9ea51cd 100755
--- a/videox_fun/models/wan_transformer3d.py
+++ b/videox_fun/models/wan_transformer3d.py
@@ -332,7 +332,7 @@ def qkv_fn(x):
 
         q, k, v = qkv_fn(x)
         f, h, w = grid_sizes.tolist()[0]
         q = rope_apply(q, grid_sizes, freqs).to(dtype)
-        k=rope_apply(k, grid_sizes, freqs).to(dtype)
+        k = rope_apply(k, grid_sizes, freqs).to(dtype)
         v = v.to(dtype)
 

From 2793a77337c6a870b9405474c195479c5184f15b Mon Sep 17 00:00:00 2001
From: 宣源
Date: Thu, 24 Apr 2025 16:09:44 +0800
Subject: [PATCH 4/5] using multi direction swa

---
 videox_fun/models/wan_transformer3d.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/videox_fun/models/wan_transformer3d.py b/videox_fun/models/wan_transformer3d.py
index f1779f3f..adfddf5b 100755
--- a/videox_fun/models/wan_transformer3d.py
+++ b/videox_fun/models/wan_transformer3d.py
@@ -404,7 +404,11 @@ def qkv_fn(x):
         for index, mode in enumerate(
             [
                 "bs (f h w) hn hd -> bs (h w f) hn hd",
-                "bs (f h w) hn hd -> bs (w h f) hn hd"
+                "bs (f h w) hn hd -> bs (w h f) hn hd",
+                "bs (f h w) hn hd -> bs (h f w) hn hd",
+                "bs (f h w) hn hd -> bs (w f h) hn hd",
+                "bs (f h w) hn hd -> bs (f h w) hn hd",
+                "bs (f h w) hn hd -> bs (f w h) hn hd",
             ]
         ):
 
@@ -428,7 +432,11 @@ def qkv_fn(x):
         for index, mode in enumerate(
             [
                 "bs (h w f) hn hd -> bs (f h w) hn hd",
-                "bs (w h f) hn hd -> bs (f h w) hn hd"
+                "bs (w h f) hn hd -> bs (f h w) hn hd",
+                "bs (h f w) hn hd -> bs (f h w) hn hd",
+                "bs (w f h) hn hd -> bs (f h w) hn hd",
+                "bs (f h w) hn hd -> bs (f h w) hn hd",
+                "bs (f w h) hn hd -> bs (f h w) hn hd",
             ]
         ):
             new_hidden_states.append(rearrange(hidden_states[index], mode, f=f, h=h, w=w))
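
A note on [PATCH 4/5]: the mode lists now cover six scan orders, namely (h w f), (w h f), (h f w), (w f h), the untouched (f h w), and (f w h), so different head groups window along different axis orderings. As of this patch, though, the tensor_split calls still produce only two head groups, so qs[index] would raise an IndexError once index reaches 2; the next patch widens the split to six to match. A quick toy-size check that each forward pattern in the first list is inverted exactly by its counterpart in the second:

    import torch
    from einops import rearrange

    pairs = [
        ("bs (f h w) hn hd -> bs (h w f) hn hd", "bs (h w f) hn hd -> bs (f h w) hn hd"),
        ("bs (f h w) hn hd -> bs (w h f) hn hd", "bs (w h f) hn hd -> bs (f h w) hn hd"),
        ("bs (f h w) hn hd -> bs (h f w) hn hd", "bs (h f w) hn hd -> bs (f h w) hn hd"),
        ("bs (f h w) hn hd -> bs (w f h) hn hd", "bs (w f h) hn hd -> bs (f h w) hn hd"),
        ("bs (f h w) hn hd -> bs (f h w) hn hd", "bs (f h w) hn hd -> bs (f h w) hn hd"),
        ("bs (f h w) hn hd -> bs (f w h) hn hd", "bs (f w h) hn hd -> bs (f h w) hn hd"),
    ]
    b, f, h, w, hn, hd = 1, 4, 2, 3, 6, 8   # toy sizes
    x = torch.randn(b, f * h * w, hn, hd)
    for fwd, inv in pairs:
        y = rearrange(x, fwd, f=f, h=h, w=w)
        assert torch.equal(rearrange(y, inv, f=f, h=h, w=w), x)
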
From f9cdb624163fa5d9639f02a83daba971109b1a22 Mon Sep 17 00:00:00 2001
From: 宣源
Date: Thu, 24 Apr 2025 16:10:18 +0800
Subject: [PATCH 5/5] using multi direction swa

---
 videox_fun/models/wan_transformer3d.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/videox_fun/models/wan_transformer3d.py b/videox_fun/models/wan_transformer3d.py
index adfddf5b..e98a5ceb 100755
--- a/videox_fun/models/wan_transformer3d.py
+++ b/videox_fun/models/wan_transformer3d.py
@@ -394,9 +394,9 @@ def qkv_fn(x):
         k = rope_apply(k, grid_sizes, freqs).to(dtype)
         v = v.to(dtype)
 
-        qs = torch.tensor_split(q.to(torch.bfloat16), 2, 2)
-        ks = torch.tensor_split(k.to(torch.bfloat16), 2, 2)
-        vs = torch.tensor_split(v.to(torch.bfloat16), 2, 2)
+        qs = torch.tensor_split(q.to(torch.bfloat16), 6, 2)
+        ks = torch.tensor_split(k.to(torch.bfloat16), 6, 2)
+        vs = torch.tensor_split(v.to(torch.bfloat16), 6, 2)
         new_querys = []
         new_keys = []
         new_values = []
@@ -427,7 +427,7 @@ def qkv_fn(x):
             window_size=self.window_size
         )
 
-        hidden_states = torch.tensor_split(x, 2, 2)
+        hidden_states = torch.tensor_split(x, 6, 2)
         new_hidden_states = []
         for index, mode in enumerate(
             [
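
A note on the end state after [PATCH 5/5]: queries, keys, and values are split into six head groups, each group is permuted into its own scan order, a single windowed-attention call runs over the concatenated heads, and the inverse loop restores the canonical (f h w) layout. The sketch below reproduces that data flow under toy sizes, with PyTorch's scaled_dot_product_attention standing in for the repository's attention helper; no sliding window is applied in the sketch, and all names and shapes are illustrative.

    import torch
    import torch.nn.functional as F
    from einops import rearrange

    b, f, h, w, hn, hd = 1, 4, 4, 6, 12, 8   # hn divisible by 6 keeps head groups equal
    x = torch.randn(b, f * h * w, hn, hd)
    q = k = v = x

    orders = ["h w f", "w h f", "h f w", "w f h", "f h w", "f w h"]
    fwd = ["bs (f h w) hn hd -> bs (%s) hn hd" % o for o in orders]
    inv = ["bs (%s) hn hd -> bs (f h w) hn hd" % o for o in orders]

    qs = torch.tensor_split(q, 6, 2)   # one head group per scan order
    ks = torch.tensor_split(k, 6, 2)
    vs = torch.tensor_split(v, 6, 2)
    q6 = torch.cat([rearrange(t, m, f=f, h=h, w=w) for t, m in zip(qs, fwd)], dim=2)
    k6 = torch.cat([rearrange(t, m, f=f, h=h, w=w) for t, m in zip(ks, fwd)], dim=2)
    v6 = torch.cat([rearrange(t, m, f=f, h=h, w=w) for t, m in zip(vs, fwd)], dim=2)

    # Stand-in for attention(q, k, v, k_lens, window_size): SDPA expects
    # (batch, heads, seq, dim), so transpose around the call.
    out = F.scaled_dot_product_attention(
        q6.transpose(1, 2), k6.transpose(1, 2), v6.transpose(1, 2)).transpose(1, 2)

    groups = torch.tensor_split(out, 6, 2)
    x_out = torch.cat([rearrange(t, m, f=f, h=h, w=w) for t, m in zip(groups, inv)], dim=2)
    assert x_out.shape == x.shape
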