qwen1.5-0.5B failed to save model with huggingface transformers #1482

Open
xinpengzz opened this issue Feb 13, 2025 · 2 comments
Labels
bug Something isn't working

Comments

@xinpengzz

Environment

python 3.11
transformer_engine       1.13.0
transformer_engine_cu12  1.13.0
transformer_engine_torch 1.13.0
transformers             4.46.3
torch 2.5.1

Reproduction code

import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM

def convert_model(model, to_transformer_engine=True, _convert_linear=True, _convert_ln=True):
    """
    Recursively converts the linear and layernorm layers of a model to their `transformer_engine` counterparts.
    """
    import transformer_engine.pytorch as te

    for name, module in model.named_children():
        if isinstance(module, nn.Linear) and to_transformer_engine and _convert_linear:
            has_bias = module.bias is not None

            if any(p % 16 != 0 for p in module.weight.shape):
                return
            te_module = te.Linear(
                module.in_features, module.out_features, bias=has_bias, params_dtype=module.weight.dtype
            )
            te_module.weight.copy_(module.weight)
            if has_bias:
                te_module.bias.copy_(module.bias)

            setattr(model, name, te_module)
        # Note: @xrsrke (Phuc) found that te.LayerNorm doesn't have any real memory savings or speedups over nn.LayerNorm
        elif isinstance(module, nn.LayerNorm) and to_transformer_engine and _convert_ln:
            te_module = te.LayerNorm(module.normalized_shape[0], eps=module.eps, params_dtype=module.weight.dtype)
            te_module.weight.copy_(module.weight)
            te_module.bias.copy_(module.bias)

            setattr(model, name, te_module)
        elif isinstance(module, te.Linear) and not to_transformer_engine and _convert_linear:
            has_bias = module.bias is not None
            new_module = nn.Linear(
                module.in_features, module.out_features, bias=has_bias, dtype=module.weight.dtype
            )
            new_module.weight.copy_(module.weight)
            if has_bias:
                new_module.bias.copy_(module.bias)

            setattr(model, name, new_module)
        elif isinstance(module, te.LayerNorm) and not to_transformer_engine and _convert_ln:
            new_module = nn.LayerNorm(module.normalized_shape[0], eps=module.eps, dtype=module.weight.dtype)
            new_module.weight.copy_(module.weight)
            new_module.bias.copy_(module.bias)

            setattr(model, name, new_module)
        else:
            convert_model(
                module,
                to_transformer_engine=to_transformer_engine,
                _convert_linear=_convert_linear,
                _convert_ln=_convert_ln,
            )

model_path = "Qwen/Qwen1.5-0.5B"

model = AutoModelForCausalLM.from_pretrained(model_path, 
                                             torch_dtype=torch.bfloat16)

with torch.no_grad():
    convert_model(model)

model = model.cuda()
model.train()

model.eval()

state_dict = model.state_dict()

model.save_pretrained("tmp", state_dict=state_dict, safe_serialization=False)
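
To check which layers the conversion actually replaced (an editor's sketch, not part of the original report; it assumes the reproduction script above has already run), one can walk the converted model and list the Transformer Engine modules:

import transformer_engine.pytorch as te

# Hedged check (not in the original report): list the submodules that
# convert_model() replaced with Transformer Engine equivalents.
for name, module in model.named_modules():
    if isinstance(module, (te.Linear, te.LayerNorm)):
        print(name, type(module).__name__)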

Error output

[2025-02-13 14:21:14,729] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)
Traceback (most recent call last):
  File "/home/xinpeng/workspace/spui/test/TE/demo.py", line 71, in <module>
    model.save_pretrained("tmp", state_dict=state_dict, safe_serialization=False)
  File "/home/xinpeng/miniforge3/envs/torch/lib/python3.11/site-packages/transformers/modeling_utils.py", line 2971, in save_pretrained
    state_dict_split = split_torch_state_dict_into_shards(
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xinpeng/miniforge3/envs/torch/lib/python3.11/site-packages/huggingface_hub/serialization/_torch.py", line 351, in split_torch_state_dict_into_shards
    return split_state_dict_into_shards_factory(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xinpeng/miniforge3/envs/torch/lib/python3.11/site-packages/huggingface_hub/serialization/_base.py", line 108, in split_state_dict_into_shards_factory
    storage_id = get_storage_id(tensor)
                 ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/xinpeng/miniforge3/envs/torch/lib/python3.11/site-packages/huggingface_hub/serialization/_torch.py", line 403, in get_torch_storage_id
    if tensor.device.type == "meta":
       ^^^^^^^^^^^^^
AttributeError: '_io.BytesIO' object has no attribute 'device'
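
The `_io.BytesIO` in the traceback indicates that the state dict handed to save_pretrained contains non-tensor values, most likely the `_extra_state` entries that Transformer Engine 1.x modules use to serialize FP8 metadata (an inference from the traceback, not something confirmed in this thread). A quick, hedged way to see which keys are affected:

# Hedged diagnostic: print every state-dict entry that is not a plain tensor.
# With TE 1.x these are typically the `*._extra_state` keys.
for key, value in model.state_dict().items():
    if not isinstance(value, torch.Tensor):
        print(key, type(value))
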
@timmoon10
Collaborator

This bug should be fixed with #1335, which is included in Transformer Engine 2.0.

@timmoon10 timmoon10 added the bug Something isn't working label Feb 19, 2025
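
For readers who cannot move to Transformer Engine 2.0 yet, a possible workaround (an editor's sketch, not a maintainer recommendation, and not verified in this thread) is to drop the non-tensor entries from the state dict before saving:

# Workaround sketch for TE 1.x (assumption, not an official recommendation):
# pass save_pretrained() a state dict that contains only plain tensors.
tensor_only = {
    k: v for k, v in model.state_dict().items() if isinstance(v, torch.Tensor)
}
model.save_pretrained("tmp", state_dict=tensor_only, safe_serialization=False)
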
@xinpengzz
Author

> This bug should be fixed with #1335, which is included in Transformer Engine 2.0.

Transformer Engine 2.0 has solved this issue!
