Commit 0ed3049

Merge pull request #196 from huggingface/main
Merge changes
2 parents d18e6ae + 233dffd commit 0ed3049

20 files changed: +1174 −25 lines

.github/workflows/push_tests_mps.yml

Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ jobs:
         shell: arch -arch arm64 bash {0}
         run: |
           ${CONDA_RUN} python -m pip install --upgrade pip uv
-          ${CONDA_RUN} python -m uv pip install -e [quality,test]
+          ${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
           ${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
           ${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
           ${CONDA_RUN} python -m uv pip install transformers --upgrade
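
Note: two things change in that line. The editable-install target gains the missing "." (the path to install from), and the extras spec is quoted so a shell can never treat the square brackets as a glob pattern. For illustration only, the same install expressed from Python sidesteps shell quoting entirely, because a list-form argv is never shell-parsed (assumes uv is installed):

    import subprocess

    # Same install as the workflow line: editable install of the current
    # directory with the "quality" and "test" extras. No quoting needed,
    # since the argument list bypasses the shell.
    subprocess.check_call(["python", "-m", "uv", "pip", "install", "-e", ".[quality,test]"])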

docs/source/en/quantization/gguf.md

Lines changed: 1 addition & 2 deletions

@@ -45,12 +45,11 @@ transformer = FluxTransformer2DModel.from_single_file(
 pipe = FluxPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
     transformer=transformer,
-    generator=torch.manual_seed(0),
     torch_dtype=torch.bfloat16,
 )
 pipe.enable_model_cpu_offload()
 prompt = "A cat holding a sign that says hello world"
-image = pipe(prompt).images[0]
+image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
 image.save("flux-gguf.png")
 ```

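The fix moves the `generator` out of `FluxPipeline.from_pretrained`, where it has no effect on sampling, and into the pipeline call, which is what actually consumes the seed. For context, the corrected example reads roughly as follows; the `from_single_file` portion and the GGUF checkpoint URL are reconstructed from the surrounding document, not shown in this hunk, and should be treated as assumptions:

    import torch

    from diffusers import FluxPipeline, FluxTransformer2DModel, GGUFQuantizationConfig

    # Assumed checkpoint path; any Flux GGUF file works here.
    ckpt_path = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"
    transformer = FluxTransformer2DModel.from_single_file(
        ckpt_path,
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    pipe.enable_model_cpu_offload()
    prompt = "A cat holding a sign that says hello world"
    image = pipe(prompt, generator=torch.manual_seed(0)).images[0]
    image.save("flux-gguf.png")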
docs/source/en/quantization/overview.md

Lines changed: 3 additions & 3 deletions

@@ -33,8 +33,8 @@ If you are new to the quantization field, we recommend you to check out these be
 ## When to use what?
 
 Diffusers currently supports the following quantization methods.
-- [BitsandBytes](./bitsandbytes.md)
-- [TorchAO](./torchao.md)
-- [GGUF](./gguf.md)
+- [BitsandBytes](./bitsandbytes)
+- [TorchAO](./torchao)
+- [GGUF](./gguf)
 
 [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
scripts/convert_flux_xlabs_ipadapter_to_diffusers.py

Lines changed: 97 additions & 0 deletions

@@ -0,0 +1,97 @@
+import argparse
+from contextlib import nullcontext
+
+import safetensors.torch
+from accelerate import init_empty_weights
+from huggingface_hub import hf_hub_download
+
+from diffusers.utils.import_utils import is_accelerate_available, is_transformers_available
+
+
+if is_transformers_available():
+    from transformers import CLIPVisionModelWithProjection
+
+    vision = True
+else:
+    vision = False
+
+"""
+python scripts/convert_flux_xlabs_ipadapter_to_diffusers.py \
+--original_state_dict_repo_id "XLabs-AI/flux-ip-adapter" \
+--filename "flux-ip-adapter.safetensors" \
+--output_path "flux-ip-adapter-hf/"
+"""
+
+
+CTX = init_empty_weights if is_accelerate_available() else nullcontext
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
+parser.add_argument("--filename", default="flux.safetensors", type=str)
+parser.add_argument("--checkpoint_path", default=None, type=str)
+parser.add_argument("--output_path", type=str)
+parser.add_argument("--vision_pretrained_or_path", default="openai/clip-vit-large-patch14", type=str)
+
+args = parser.parse_args()
+
+
+def load_original_checkpoint(args):
+    if args.original_state_dict_repo_id is not None:
+        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=args.filename)
+    elif args.checkpoint_path is not None:
+        ckpt_path = args.checkpoint_path
+    else:
+        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`.")
+
+    original_state_dict = safetensors.torch.load_file(ckpt_path)
+    return original_state_dict
+
+
+def convert_flux_ipadapter_checkpoint_to_diffusers(original_state_dict, num_layers):
+    converted_state_dict = {}
+
+    # image_proj
+    ## norm
+    converted_state_dict["image_proj.norm.weight"] = original_state_dict.pop("ip_adapter_proj_model.norm.weight")
+    converted_state_dict["image_proj.norm.bias"] = original_state_dict.pop("ip_adapter_proj_model.norm.bias")
+    ## proj
+    converted_state_dict["image_proj.proj.weight"] = original_state_dict.pop("ip_adapter_proj_model.proj.weight")
+    converted_state_dict["image_proj.proj.bias"] = original_state_dict.pop("ip_adapter_proj_model.proj.bias")
+
+    # double transformer blocks
+    for i in range(num_layers):
+        block_prefix = f"ip_adapter.{i}."
+        # to_k_ip
+        converted_state_dict[f"{block_prefix}to_k_ip.bias"] = original_state_dict.pop(
+            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}to_k_ip.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.processor.ip_adapter_double_stream_k_proj.weight"
+        )
+        # to_v_ip
+        converted_state_dict[f"{block_prefix}to_v_ip.bias"] = original_state_dict.pop(
+            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.bias"
+        )
+        converted_state_dict[f"{block_prefix}to_v_ip.weight"] = original_state_dict.pop(
+            f"double_blocks.{i}.processor.ip_adapter_double_stream_v_proj.weight"
+        )
+
+    return converted_state_dict
+
+
+def main(args):
+    original_ckpt = load_original_checkpoint(args)
+
+    num_layers = 19
+    converted_ip_adapter_state_dict = convert_flux_ipadapter_checkpoint_to_diffusers(original_ckpt, num_layers)
+
+    print("Saving Flux IP-Adapter in Diffusers format.")
+    safetensors.torch.save_file(converted_ip_adapter_state_dict, f"{args.output_path}/model.safetensors")
+
+    if vision:
+        model = CLIPVisionModelWithProjection.from_pretrained(args.vision_pretrained_or_path)
+        model.save_pretrained(f"{args.output_path}/image_encoder")
+
+
+if __name__ == "__main__":
+    main(args)
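After conversion, a quick sanity check on the saved file can catch key-remapping slips like the `proj` and `to_v_ip` mix-ups corrected above. A minimal sketch, assuming the output path from the docstring example:

    import safetensors.torch

    state_dict = safetensors.torch.load_file("flux-ip-adapter-hf/model.safetensors")

    # Expect the projection head plus 19 double-stream blocks, each
    # contributing K and V projections with a weight and a bias.
    assert "image_proj.proj.weight" in state_dict
    assert sum(k.startswith("ip_adapter.") for k in state_dict) == 19 * 4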
scripts/convert_sana_to_diffusers.py

Lines changed: 3 additions & 2 deletions

@@ -25,6 +25,7 @@
 CTX = init_empty_weights if is_accelerate_available else nullcontext
 
 ckpt_ids = [
+    "Efficient-Large-Model/Sana_1600M_2Kpx_BF16/checkpoints/Sana_1600M_2Kpx_BF16.pth",
     "Efficient-Large-Model/Sana_1600M_1024px_MultiLing/checkpoints/Sana_1600M_1024px_MultiLing.pth",
     "Efficient-Large-Model/Sana_1600M_1024px_BF16/checkpoints/Sana_1600M_1024px_BF16.pth",
     "Efficient-Large-Model/Sana_1600M_512px_MultiLing/checkpoints/Sana_1600M_512px_MultiLing.pth",
@@ -265,9 +266,9 @@ def main(args):
         "--image_size",
         default=1024,
         type=int,
-        choices=[512, 1024],
+        choices=[512, 1024, 2048],
         required=False,
-        help="Image size of pretrained model, 512 or 1024.",
+        help="Image size of pretrained model, 512, 1024 or 2048.",
     )
     parser.add_argument(
         "--model_type", default="SanaMS_1600M_P1_D20", type=str, choices=["SanaMS_1600M_P1_D20", "SanaMS_600M_P1_D28"]

src/diffusers/loaders/__init__.py

Lines changed: 4 additions & 1 deletion

@@ -55,7 +55,7 @@ def text_encoder_attn_modules(text_encoder):
 
 if is_torch_available():
     _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
-
+    _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
     _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
     _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
     _import_structure["utils"] = ["AttnProcsLayers"]

@@ -77,6 +77,7 @@ def text_encoder_attn_modules(text_encoder):
     _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
     _import_structure["ip_adapter"] = [
         "IPAdapterMixin",
+        "FluxIPAdapterMixin",
         "SD3IPAdapterMixin",
     ]

@@ -86,12 +87,14 @@ def text_encoder_attn_modules(text_encoder):
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
         from .single_file_model import FromOriginalModelMixin
+        from .transformer_flux import FluxTransformer2DLoadersMixin
         from .transformer_sd3 import SD3Transformer2DLoadersMixin
         from .unet import UNet2DConditionLoadersMixin
         from .utils import AttnProcsLayers
 
     if is_transformers_available():
         from .ip_adapter import (
+            FluxIPAdapterMixin,
             IPAdapterMixin,
             SD3IPAdapterMixin,
         )
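
Both registrations feed the lazy-import machinery, so the new mixins become importable from diffusers.loaders without adding eager import cost. A quick smoke test, assuming torch and transformers are installed:

    # Resolves through _import_structure on first access.
    from diffusers.loaders import FluxIPAdapterMixin, FluxTransformer2DLoadersMixin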
