
Is it necessary to perform layer replacement with te.xx? If not, is it effective to use te.fp8_autocast directly? #1556

Open
wangli68 opened this issue Mar 11, 2025 · 2 comments


@wangli68

I did not replace any layers with te.xx modules; I only replaced the mixed-precision context (amp.autocast) with te.fp8_autocast. Does this actually enable FP8 computation? If I want to enable FP8 mixed-precision computation, how should I modify the code?

for example:
class WanAttentionBlock(nn.Module):

    def __init__(self,
                 cross_attn_type,
                 dim,
                 ffn_dim,
                 num_heads,
                 window_size=(-1, -1),
                 qk_norm=True,
                 cross_attn_norm=False,
                 eps=1e-6):
        super().__init__()
        self.dim = dim
        self.ffn_dim = ffn_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.qk_norm = qk_norm
        self.cross_attn_norm = cross_attn_norm
        self.eps = eps

        # layers
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                          eps)
        self.norm3 = WanLayerNorm(
            dim, eps,
            elementwise_affine=True) if cross_attn_norm else nn.Identity()
        self.cross_attn = WANX_CROSSATTENTION_CLASSES[cross_attn_type](
            dim, num_heads, (-1, -1), qk_norm, eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(
            torch.nn.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
            torch.nn.Linear(ffn_dim, dim))

        # modulation
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        x,
        e,
        seq_lens,
        grid_sizes,
        freqs,
        context,
        context_lens,
    ):
        # assert e.dtype == torch.bfloat16
        # with amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
        # assert e[0].dtype == torch.bfloat16

        # self-attention
        y = self.self_attn(
            self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
            freqs)
        # with amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            x = x + y * e[2]

        # cross-attention & ffn function
        def cross_attn_ffn(x, context, context_lens, e):
            x = x + self.cross_attn(self.norm3(x), context, context_lens)
            y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
            # with amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
                x = x + y * e[5]
            return x

        x = cross_attn_ffn(x, context, context_lens, e)
        return x
@ksivaman
Member

ksivaman commented Mar 11, 2025

te.fp8_autocast and torch AMP work differently and aren't exactly interchangeable (although they can be used together), so te.fp8_autocast is being used incorrectly in this script. te.fp8_autocast should wrap the forward pass, and within that context it enables FP8 execution only for Transformer Engine layers. More details are in our getting started with FP8 docs. There are a few fixes needed in your script that you could try:

  1. Replace torch.nn.Linear with TE's te.Linear.
  2. Wrap the entire cross_attn_ffn function in te.fp8_autocast.

Something like this:

def __init__(self,
             cross_attn_type,
             dim,
             ffn_dim,
             num_heads,
             window_size=(-1, -1),
             qk_norm=True,
             cross_attn_norm=False,
             eps=1e-6):
    super().__init__()
    self.dim = dim
    self.ffn_dim = ffn_dim
    self.num_heads = num_heads
    self.window_size = window_size
    self.qk_norm = qk_norm
    self.cross_attn_norm = cross_attn_norm
    self.eps = eps

    # layers
    self.norm1 = WanLayerNorm(dim, eps)
    self.self_attn = WanSelfAttention(dim, num_heads, window_size, qk_norm,
                                       eps)
    self.norm3 = WanLayerNorm(
        dim, eps,
        elementwise_affine=True) if cross_attn_norm else nn.Identity()
    self.cross_attn = WANX_CROSSATTENTION_CLASSES[cross_attn_type](
        dim, num_heads, (-1, -1), qk_norm, eps)
    self.norm2 = WanLayerNorm(dim, eps)
    self.ffn = nn.Sequential(
        te.Linear(dim, ffn_dim), nn.GELU(approximate='tanh'),
        te.Linear(ffn_dim, dim))

    # modulation
    self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

def forward(
    self,
    x,
    e,
    seq_lens,
    grid_sizes,
    freqs,
    context,
    context_lens,
):
    # assert e.dtype == torch.bfloat16
    # with amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
    with te.fp8_autocast(enabled=True,fp8_recipe=fp8_recipe):
        e = (self.modulation.to(dtype=e.dtype, device=e.device) + e).chunk(6, dim=1)
    # assert e[0].dtype == torch.bfloat16

    # self-attention
    y = self.self_attn(
        self.norm1(x).float() * (1 + e[1]) + e[0], seq_lens, grid_sizes,
        freqs)
    # with amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
    x = x + y * e[2]

    # cross-attention & ffn function
    def cross_attn_ffn(x, context, context_lens, e):
        x = x + self.cross_attn(self.norm3(x), context, context_lens)
        y = self.ffn(self.norm2(x).float() * (1 + e[4]) + e[3])
        # with amp.autocast(dtype=torch.bfloat16, device_type="cuda"):
        x = x + y * e[5]
        return x
    
    with te.fp8_autocast(enabled=True,fp8_recipe=fp8_recipe):
        x = cross_attn_ffn(x, context, context_lens, e)
    return x

Note: non-TE modules are unaffected by te.fp8_autocast.
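For example, here is a minimal sketch of what that note means in practice (assuming an FP8-capable GPU, transformer_engine imported as te, and the standard DelayedScaling recipe; the layer sizes are arbitrary): inside the same te.fp8_autocast region, a plain torch.nn.Linear keeps running ordinary bf16 GEMMs, while te.Linear executes its GEMM in FP8.

import torch
import torch.nn as nn
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

fp8_recipe = DelayedScaling()  # default delayed-scaling FP8 recipe

# One plain PyTorch layer and one Transformer Engine layer.
torch_linear = nn.Linear(1024, 1024, device="cuda", dtype=torch.bfloat16)
te_linear = te.Linear(1024, 1024, params_dtype=torch.bfloat16)  # TE layers live on the GPU by default

x = torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16)

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    y1 = torch_linear(x)  # unaffected by the FP8 context: regular bf16 matmul
    y2 = te_linear(x)     # GEMM runs in FP8 under the recipe's scaling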

@wangli68
Author

Because I am doing LoRA fine-tuning on someone else's base model: if I change torch.nn.Linear to te.Linear, will that make the model structure inconsistent with the base checkpoint?
I made the modifications the way you suggested, and I get the following error:
File "/output/Test/DiffSynth-Studio/examples/wanvideo/train_wan_t2v.py", line 204, in add_lora_to_model
model = inject_adapter_in_model(lora_config, model)
File "/openbayes/home/Test/env/wan211/lib/python3.10/site-packages/peft/mapping.py", line 260, in inject_adapter_in_model
peft_model = tuner_cls(model, peft_config, adapter_name=adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
File "/openbayes/home/Test/env/wan211/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 141, in init
super().init(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
File "/openbayes/home/Test/env/wan211/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 184, in init
self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
File "/openbayes/home/Test/env/wan211/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 501, in inject_adapter
self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
File "/openbayes/home/Test/env/wan211/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 235, in _create_and_replace
new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
File "/openbayes/home/Test/env/wan211/lib/python3.10/site-packages/peft/tuners/lora/model.py", line 360, in _create_new_module
raise ValueError(
ValueError: Target module Linear() is not supported. Currently, only the following modules are supported: torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, torch.nn.Conv3d, transformers.pytorch_utils.Conv1D.

Does this mean that this approach cannot be used with LoRA fine-tuning?
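For context, the failure can be reproduced with a minimal sketch (assuming recent peft and transformer_engine versions and a CUDA device; the layer size is arbitrary): te.Linear is an ordinary torch.nn.Module but not a subclass of torch.nn.Linear, and, as the error message above suggests, PEFT's LoRA injection only matches the listed torch/transformers module types.

import torch.nn as nn
import transformer_engine.pytorch as te

layer = te.Linear(128, 128)            # TE layers allocate parameters on the GPU by default
print(isinstance(layer, nn.Module))    # True  -- a regular PyTorch module
print(isinstance(layer, nn.Linear))    # False -- so peft's LoRA dispatch rejects it as a target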
