
🎨 Training-free Regional Prompting for Diffusion Transformers 🔥


creative-graphic-design/flux-regional-prompting


Training-free Regional Prompting for Diffusion Transformers

Anthony Chen¹,² · Jianjin Xu³ · Wenzhao Zheng⁴ · Gaole Dai¹ · Yida Wang⁵ · Renrui Zhang⁶ · Haofan Wang² · Shanghang Zhang¹*

¹Peking University · ²InstantX Team · ³Carnegie Mellon University · ⁴UC Berkeley · ⁵Li Auto Inc. · ⁶CUHK

Installation

pip install git+https://github.com/creative-graphic-design/flux-regional-prompting

How to Use

import torch

from flux_regional_prompting.models.attention_processor import (
    RegionalFluxAttnProcessor2_0,
)
from flux_regional_prompting.models.transformers import RegionalFluxTransformer2DModel
from flux_regional_prompting.pipelines import (
    RegionalFluxPipeline,
)

model_id = "black-forest-labs/FLUX.1-dev"
torch_dtype = torch.bfloat16
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#
# Load module and pipeline
#
transformer = RegionalFluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
)
transformer = transformer.to(torch_dtype)

pipe = RegionalFluxPipeline.from_pretrained(
    model_id,
    transformer=transformer,
    torch_dtype=torch_dtype,
)
pipe = pipe.to(device)

#
# Replace `FluxAttnProcessor2_0` with `RegionalFluxAttnProcessor2_0` in the transformer blocks
#
attn_procs = {}
for name in pipe.transformer.attn_processors.keys():
    if "transformer_blocks" in name and name.endswith("attn.processor"):
        attn_procs[name] = RegionalFluxAttnProcessor2_0()
    else:
        attn_procs[name] = pipe.transformer.attn_processors[name]
pipe.transformer.set_attn_processor(attn_procs)

#
# Set hyperparameters
#

# General settings
image_width, image_height = 1280, 768
num_inference_steps = 24
guidance_scale = 3.5
seed = 124
base_prompt = "An ancient woman stands solemnly holding a blazing torch, while a fierce battle rages in the background, capturing both strength and tragedy in a historical war scene."
background_prompt = "a photo"  # default; set a more descriptive prompt to enrich the background
regional_prompt_mask_pairs = {
    "0": {
        "description": "A dignified woman in ancient robes stands in the foreground, her face illuminated by the torch she holds high. Her expression is one of determination and sorrow, her clothing and appearance reflecting the historical period. The torch casts dramatic shadows across her features, its flames dancing vibrantly against the darkness.",
        "mask": [128, 128, 640, 768],
    }
}

## region control factor settings
mask_inject_steps = 10  # larger means stronger control, recommended between 5-10
double_inject_blocks_interval = 1  # 1 means strongest control
single_inject_blocks_interval = 1  # 1 means strongest control
base_ratio = 0.2  # smaller means stronger control

#
# Prepare regional masks
#
regional_prompts = []
regional_masks = []
background_mask = torch.ones((image_height, image_width))
for region_name, region in regional_prompt_mask_pairs.items():
    description = region["description"]
    mask = region["mask"]
    x1, y1, x2, y2 = mask
    mask = torch.zeros((image_height, image_width))
    mask[y1:y2, x1:x2] = 1.0
    background_mask -= mask
    regional_prompts.append(description)
    regional_masks.append(mask)

# if regional masks don't cover the whole image, append background prompt and mask
if background_mask.sum() > 0:
    regional_prompts.append(background_prompt)
    regional_masks.append(background_mask)

#
# Generate image with `RegionalFluxPipeline`
#
image = pipe(
    prompt=base_prompt,
    width=image_width,
    height=image_height,
    mask_inject_steps=mask_inject_steps,
    guidance_scale=guidance_scale,
    num_inference_steps=num_inference_steps,
    generator=torch.Generator(device).manual_seed(seed),
    joint_attention_kwargs={
        "regional_prompts": regional_prompts,
        "regional_masks": regional_masks,
        "double_inject_blocks_interval": double_inject_blocks_interval,
        "single_inject_blocks_interval": single_inject_blocks_interval,
        "base_ratio": base_ratio,
    },
).images[0]

filename = f'{"-".join(base_prompt.lower().split())}.png'
image.save(filename)
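The mask-preparation loop above subtracts every region from `background_mask`, so two overlapping boxes would push part of the background below zero, and a box that extends past the canvas is silently clipped by slicing. The sketch below is a hypothetical helper (`build_regional_masks` is not part of `flux_regional_prompting`) that keeps the same `regional_prompt_mask_pairs` format but validates each box and derives the background from the union of all regions. It uses NumPy for illustration; `torch.from_numpy` on each mask should give float32 tensors like the `torch.zeros` masks in the script.

```python
import numpy as np


def build_regional_masks(regional_prompt_mask_pairs, width, height, background_prompt="a photo"):
    """Turn {"id": {"description": str, "mask": [x1, y1, x2, y2]}} into prompt and mask lists.

    Hypothetical helper: rejects boxes outside the canvas and builds the background
    from the union of the regions, so overlapping boxes cannot make it negative.
    """
    prompts, masks = [], []
    covered = np.zeros((height, width), dtype=np.float32)
    for region in regional_prompt_mask_pairs.values():
        x1, y1, x2, y2 = region["mask"]
        if not (0 <= x1 < x2 <= width and 0 <= y1 < y2 <= height):
            raise ValueError(f"box {region['mask']} does not fit a {width}x{height} image")
        mask = np.zeros((height, width), dtype=np.float32)
        mask[y1:y2, x1:x2] = 1.0
        covered = np.maximum(covered, mask)  # union of all regions so far
        prompts.append(region["description"])
        masks.append(mask)
    background = 1.0 - covered  # stays within [0, 1] even when boxes overlap
    # As in the original script, add a background region only if something is uncovered.
    if background.sum() > 0:
        prompts.append(background_prompt)
        masks.append(background)
    return prompts, masks
```

Overlapping regions are still passed on as two full masks; how the attention processor treats pixels claimed by two prompts is not described in this README.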

Training-free Regional Prompting for Diffusion Transformers (Regional-Prompting-FLUX) gives Diffusion Transformers (specifically FLUX) fine-grained compositional text-to-image generation capability without any training. Empirically, we show that our method is highly effective and compatible with LoRA and ControlNet.

Inference is much faster than with the RPG-based implementation, and it uses less GPU memory.

Release

  • [2024/11/05] 🔥 We released the code; feel free to try it out!
  • [2024/11/05] 🔥 We released the technical report!

Demos

Custom Regional Control

[Figure: regional masks and generated results]

Red: Cocktail region (xyxy: [450, 560, 960, 900])
Green: Table region (xyxy: [320, 900, 1280, 1280])
Blue: Background
Base Prompt:
"A tropical cocktail on a wooden table at a beach during sunset."

Background Prompt:
"A photo"

Regional Prompts:
  • Region 0: "A colorful cocktail in a glass with tropical fruits and a paper umbrella, with ice cubes and condensation."
  • Region 1: "Weathered wooden table with seashells and a napkin."
Settings:
  • Image Size: 1280x1280
  • Seed: 124
  • Mask Inject Steps: 10
  • Double Inject Interval: 1
  • Single Inject Interval: 2
  • Base Ratio: 0.1

Red: Rainbow region (xyxy: [0, 0, 1280, 256])
Green: Ship region (xyxy: [0, 256, 1280, 520])
Yellow: Fish region (xyxy: [0, 520, 640, 768])
Blue: Treasure region (xyxy: [640, 520, 1280, 768])
Base Prompt:
"A majestic ship sails under a rainbow as vibrant marine creatures glide through crystal waters below, embodying nature's wonder, while an ancient, rusty treasure chest lies hidden on the ocean floor."

Regional Prompts:
  • Region 0: "A massive, radiant rainbow arches across the vast sky, glowing in vivid colors and blending with ethereal clouds that drift gently, casting a magical light across the scene and creating a surreal, dreamlike atmosphere."
  • Region 1: "The majestic ship, with grand sails billowing against the crystal blue waters, glides forward as birds soar overhead. Its hull and sails mirror the vivid hues of the sea, embodying a sense of adventure and mystery as it journeys through this enchanted world."
  • Region 2: "Beneath the sparkling water, schools of colorful fish dart playfully, their scales flashing in shades of yellow, blue, and orange. Tiny seahorses drift by, while gentle turtles paddle along, creating a lively, enchanting underwater scene."
  • Region 3: "On the ocean floor lies an ancient, rusty treasure chest, heavily encrusted with barnacles and seaweed. The chest's corroded metal and weathered wood hint at centuries spent underwater. Its lid is slightly ajar, revealing a faint glow within, as small fish dart around, adding an air of mystery to the forgotten relic."
Settings:
  • Image Size: 1280x768
  • Seed: 124
  • Mask Inject Steps: 10
  • Double Inject Interval: 1
  • Single Inject Interval: 1
  • Base Ratio: 0.2
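This demo lists no background prompt. Under the rule in the How to Use script (a background region is appended only when `background_mask.sum() > 0`), the four boxes tile the 1280x768 canvas exactly, so no background region is needed. A quick check, with `uncovered_pixels` as a hypothetical helper:

```python
import numpy as np


def uncovered_pixels(boxes, width, height):
    """Count pixels that no xyxy box covers, i.e. what a background mask would contain."""
    covered = np.zeros((height, width), dtype=bool)
    for x1, y1, x2, y2 in boxes:
        covered[y1:y2, x1:x2] = True
    return int((~covered).sum())


# Boxes from the rainbow/ship/fish/treasure demo on a 1280x768 canvas.
ship_demo = [
    [0, 0, 1280, 256],      # rainbow
    [0, 256, 1280, 520],    # ship
    [0, 520, 640, 768],     # fish
    [640, 520, 1280, 768],  # treasure
]
print(uncovered_pixels(ship_demo, 1280, 768))  # 0, so no background region

# For comparison, the woman-with-torch demo leaves most of the canvas to the background prompt.
print(uncovered_pixels([[128, 128, 640, 768]], 1280, 768))  # 655360
```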

Red: Woman with torch region (xyxy: [128, 128, 640, 768])
Green: Background
Base Prompt:
"An ancient woman stands solemnly holding a blazing torch, while a fierce battle rages in the background, capturing both strength and tragedy in a historical war scene."

Background Prompt:
"A photo."

Regional Prompts:
  • Region 0: "A dignified woman in ancient robes stands in the foreground, her face illuminated by the torch she holds high. Her expression is one of determination and sorrow, her clothing and appearance reflecting the historical period. The torch casts dramatic shadows across her features, its flames dancing vibrantly against the darkness."
Settings:
  • Image Size: 1280x768
  • Seed: 124
  • Mask Inject Steps: 10
  • Double Inject Interval: 1
  • Single Inject Interval: 1
  • Base Ratio: 0.3

Red: Dog region (assets/demo_custom_0_mask_0.png)
Green: Cat region (assets/demo_custom_0_mask_1.png)
Blue: Background
Base Prompt:
"dog and cat sitting on lush green grass, in a sunny outdoor setting."

Background Prompt:
"A photo"

Regional Prompts:
  • Region 0: "A friendly golden retriever with a luxurious golden coat, floppy ears, and warm expression sitting on vibrant green grass."
  • Region 1: "A golden british shorthair cat with round face, plush coat, and copper eyes sitting regally"
Settings:
  • Image Size: 1280x768
  • Seed: 124
  • Mask Inject Steps: 10
  • Double Inject Interval: 1
  • Single Inject Interval: 2
  • Base Ratio: 0.1

Note: Generation with segmentation masks is an experimental feature. The generated image is not perfectly constrained to the regions; we suspect this is because the masks degrade during downsampling.
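A rough illustration of that degradation, assuming the pixel mask is reduced to FLUX's latent token grid, which is 16x coarser than the image in each dimension (8x VAE downsampling plus 2x2 patch packing). The average-pool-and-threshold step is an assumption chosen for illustration; the package may resize masks differently, but any reduction of a 1280x768 mask to a 48x80 grid discards detail finer than one 16-pixel cell.

```python
import numpy as np

# Assumption: FLUX's token grid is 16x coarser than the image in each dimension
# (8x VAE downsampling, then 2x2 patch packing).
TOKEN_STRIDE = 16


def mask_to_token_grid(mask, stride=TOKEN_STRIDE, threshold=0.5):
    """Average-pool an HxW {0,1} pixel mask to (H/stride)x(W/stride) cells, then re-binarize.

    The pooling and threshold are illustrative choices, not the package's resize method.
    """
    h, w = mask.shape
    if h % stride or w % stride:
        raise ValueError("image height and width must be multiples of the stride")
    pooled = mask.reshape(h // stride, stride, w // stride, stride).mean(axis=(1, 3))
    return pooled >= threshold


mask = np.zeros((768, 1280), dtype=np.float32)
mask[100:400, 200:600] = 1.0  # a large region (e.g. the dog's body)
mask[:, 1000:1006] = 1.0      # a thin 6-pixel strip (e.g. a tail or whisker)

tokens = mask_to_token_grid(mask)
print(tokens.shape)                        # (48, 80)
print(bool(tokens[:, 1000 // 16].any()))   # False: the strip fills only 6/16 of each cell
```

A PNG mask such as assets/demo_custom_0_mask_0.png could be turned into such an array with Pillow, e.g. `np.asarray(Image.open(path).convert("L")) > 127`, before pooling.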

LoRA Compatibility

[Figure: regional masks and generated results]

Red: Dinosaur region (xyxy: [0, 0, 640, 1280])
Blue: City region (xyxy: [640, 0, 1280, 1280])
Base Prompt:
"Sketched style: A cute dinosaur playfully blowing tiny fire puffs over a cartoon city in a cheerful scene."

Regional Prompts:
  • Region 0: "Sketched style, dinosaur with round eyes and a mischievous smile, puffing small flames over the city."
  • Region 1: "Sketched style, city with colorful buildings and tiny flames gently floating above, adding a playful touch."
Settings:
  • Image Size: 1280x1280
  • Seed: 1298
  • Mask Inject Steps: 10
  • Double Inject Interval: 1
  • Single Inject Interval: 1
  • Base Ratio: 0.1
LoRA:
  • Path: Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch
  • Scale: 1.5
  • Trigger Words: "sketched style"

Red: UFO region (xyxy: [320, 320, 640, 640])
Base Prompt:
"A cute cartoon-style UFO floating above a sunny city street, artistic style blends reality and illustration elements"

Background Prompt:
"A photo"

Regional Prompts:
  • Region 0: "A cartoon-style silver UFO with blinking lights hovering in the air, artistic style blends reality and illustration elements"
Settings:
  • Image Size: 1280x1280
  • Seed: 1298
  • Mask Inject Steps: 10
  • Double Inject Interval: 1
  • Single Inject Interval: 2
  • Base Ratio: 0.2
LoRA:
  • Path: Shakker-Labs/FLUX.1-dev-LoRA-Vector-Journey
  • Scale: 1.0
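The LoRA settings above (path, scale, trigger words) could be applied with the standard diffusers LoRA loaders. This sketch assumes `RegionalFluxPipeline` inherits diffusers' `load_lora_weights` and `fuse_lora` from `FluxPipeline`, which this README does not confirm; `infer_flux_regional.py` in the original repository holds the authors' own LoRA code. If a LoRA repository contains more than one weights file, `load_lora_weights` also needs a `weight_name=` argument. The second helper checks that every prompt carries the trigger words, as the dinosaur demo does with "sketched style". Both helpers are hypothetical.

```python
def apply_lora(pipe, lora_path, scale=1.0):
    """Load a LoRA into `pipe` and fuse it at `scale`.

    Assumes `pipe` provides diffusers' `load_lora_weights` and `fuse_lora`
    (FluxPipeline does; not confirmed for RegionalFluxPipeline).
    """
    pipe.load_lora_weights(lora_path)
    pipe.fuse_lora(lora_scale=scale)
    return pipe


def prompts_missing_trigger(prompts, trigger):
    """Return the prompts that do not contain the LoRA trigger words (case-insensitive)."""
    return [p for p in prompts if trigger.lower() not in p.lower()]


# Usage with the dinosaur demo's settings (loading requires the model weights):
# pipe = apply_lora(pipe, "Shakker-Labs/FLUX.1-dev-LoRA-Children-Simple-Sketch", scale=1.5)
prompts = [
    "Sketched style: A cute dinosaur playfully blowing tiny fire puffs over a cartoon city in a cheerful scene.",
    "Sketched style, dinosaur with round eyes and a mischievous smile, puffing small flames over the city.",
    "Sketched style, city with colorful buildings and tiny flames gently floating above, adding a playful touch.",
]
print(prompts_missing_trigger(prompts, "sketched style"))  # []
```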

ControlNet Compatibility

[Figure: regional masks and generated results]

Red: First car region (xyxy: [0, 0, 426, 968])
Green: Second car region (xyxy: [426, 0, 853, 968])
Blue: Third car region (xyxy: [853, 0, 1280, 968])


Base Prompt:
"Three high-performance sports cars, red, blue, and yellow, are racing side by side on a city street"

Regional Prompts:
  • Region 0: "A sleek red sports car in the lead position, with aggressive aerodynamic styling and gleaming paint that catches the light. The car appears to be moving at high speed with motion blur effects."
  • Region 1: "A powerful blue sports car in the middle position, neck-and-neck with its competitors. Its metallic paint shimmers as it races forward, with visible speed lines and dynamic movement."
  • Region 2: "A striking yellow sports car in the third position, its bold color standing out against the street. The car's aggressive stance and aerodynamic profile emphasize its racing performance."
Settings:
  • Image Size: 1280x968
  • Seed: 124
  • Mask Inject Steps: 10
  • Double Inject Blocks Interval: 1
  • Single Inject Blocks Interval: 2
  • Base Ratio: 0.2
ControlNet:
  • Control Mode: 2
  • ControlNet Conditioning Scale: 0.7

Red: Woman region (xyxy: [0, 0, 640, 968])
Green: Beach region (xyxy: [640, 0, 1280, 968])


Base Prompt:
"A woman walking along a beautiful beach with a scenic coastal view."

Regional Prompts:
  • Region 0: "A woman in a flowing summer dress with delicate pink and blue flower patterns walking barefoot on the sandy beach. Her floral-patterned dress billows gracefully in the ocean breeze as she strolls casually along the shoreline, with a peaceful expression on her face and her hair gently tousled by the wind."
  • Region 1: "A stunning coastal landscape with crystal clear turquoise waters meeting the horizon. Rhythmic waves roll in with white foamy crests, creating a mesmerizing pattern as they crash onto the shore. The waves vary in size, some gently lapping at the sand while others surge forward with more force. White sandy beach stretches into the distance, with gentle waves leaving intricate patterns on the wet sand and scattered palm trees swaying in the breeze."
Settings:
  • Image Size: 1280x968
  • Seed: 124
  • Mask Inject Steps: 10
  • Double Inject Blocks Interval: 1
  • Single Inject Blocks Interval: 2
  • Base Ratio: 0.2
ControlNet:
  • Control Mode: 4
  • ControlNet Conditioning Scale: 0.7
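The README does not name the ControlNet checkpoint behind "Control Mode". Assuming it is a union ControlNet such as InstantX's FLUX.1-dev ControlNet Union, whose model card maps mode indices to condition types as in the table below, mode 2 would be depth (sports cars) and mode 4 pose (beach walk). In diffusers' `FluxControlNetPipeline` this index is passed as `control_mode`. Treat the table as an assumption and check it against the model card of the checkpoint you actually load.

```python
# Assumption: mode indices from the InstantX FLUX.1-dev ControlNet Union model card.
# Verify against the card of the checkpoint you actually load.
UNION_CONTROL_MODES = {
    0: "canny",
    1: "tile",
    2: "depth",
    3: "blur",
    4: "pose",
    5: "gray",
    6: "lq",  # low-quality image
}


def control_mode_name(mode: int) -> str:
    """Map a union-ControlNet mode index to its condition type."""
    if mode not in UNION_CONTROL_MODES:
        raise ValueError(f"unknown control mode: {mode}")
    return UNION_CONTROL_MODES[mode]


print(control_mode_name(2))  # depth (sports-car demo)
print(control_mode_name(4))  # pose (beach demo)
```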

Installation (original Regional-Prompting-FLUX repository)

We pin diffusers to an earlier commit for reproducibility, because we found that newer diffusers versions can produce different results.

# install diffusers locally
git clone https://github.com/huggingface/diffusers.git
cd diffusers

# pin diffusers to the 0.31.dev commit on which Regional-Prompting-FLUX was developed; other versions may give different results
git reset --hard d13b0d63c0208f2c4c078c4261caf8bf587beb3b
pip install -e ".[torch]"
cd ..

# install other dependencies
pip install -U transformers sentencepiece protobuf peft

# clone this repo
git clone https://github.com/antonioo-c/Regional-Prompting-FLUX.git

# replace file in diffusers
cd Regional-Prompting-FLUX
cp transformer_flux.py ../diffusers/src/diffusers/models/transformers/transformer_flux.py
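Because results depend on the pinned diffusers commit, it can help to confirm which version is active before generating. A small sketch using the standard library's `importlib.metadata`; the expected prefix `0.31.0.dev` is an assumption based on the "0.31.dev" comment above, so adjust it if your checkout reports a different string. Both helper names are hypothetical.

```python
from importlib import metadata
from typing import Optional

# Assumption: the pinned commit reports a 0.31 development version such as "0.31.0.dev0".
EXPECTED_PREFIX = "0.31.0.dev"


def installed_diffusers_version() -> Optional[str]:
    """Return the installed diffusers version, or None if diffusers is not installed."""
    try:
        return metadata.version("diffusers")
    except metadata.PackageNotFoundError:
        return None


def is_pinned_version(version: Optional[str], prefix: str = EXPECTED_PREFIX) -> bool:
    """True if `version` looks like the development build the README pins."""
    return version is not None and version.startswith(prefix)


if __name__ == "__main__":
    version = installed_diffusers_version()
    if not is_pinned_version(version):
        print(f"warning: diffusers {version} is not the pinned 0.31.dev build; results may differ")
```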

Quick Start

See infer_flux_regional.py for a detailed example (including LoRAs and ControlNets). Below is a quick-start example.

import torch
from pipeline_flux_regional import RegionalFluxPipeline, RegionalFluxAttnProcessor2_0

pipeline = RegionalFluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
attn_procs = {}
for name in pipeline.transformer.attn_processors.keys():
    if 'transformer_blocks' in name and name.endswith("attn.processor"):
        attn_procs[name] = RegionalFluxAttnProcessor2_0()
    else:
        attn_procs[name] = pipeline.transformer.attn_processors[name]
pipeline.transformer.set_attn_processor(attn_procs)

## general settings
image_width = 1280
image_height = 768
num_inference_steps = 24
seed = 124
base_prompt = "An ancient woman stands solemnly holding a blazing torch, while a fierce battle rages in the background, capturing both strength and tragedy in a historical war scene."
background_prompt = "a photo" # default; set a more descriptive prompt to enrich the background
regional_prompt_mask_pairs = {
    "0": {
        "description": "A dignified woman in ancient robes stands in the foreground, her face illuminated by the torch she holds high. Her expression is one of determination and sorrow, her clothing and appearance reflecting the historical period. The torch casts dramatic shadows across her features, its flames dancing vibrantly against the darkness.",
        "mask": [128, 128, 640, 768]
    }
}
## region control factor settings
mask_inject_steps = 10 # larger means stronger control, recommended between 5-10
double_inject_blocks_interval = 1 # 1 means strongest control
single_inject_blocks_interval = 1 # 1 means strongest control
base_ratio = 0.2 # smaller means stronger control

regional_prompts = []
regional_masks = []
background_mask = torch.ones((image_height, image_width))
for region_idx, region in regional_prompt_mask_pairs.items():
    description = region['description']
    mask = region['mask']
    x1, y1, x2, y2 = mask
    mask = torch.zeros((image_height, image_width))
    mask[y1:y2, x1:x2] = 1.0
    background_mask -= mask
    regional_prompts.append(description)
    regional_masks.append(mask)
# if regional masks don't cover the whole image, append background prompt and mask
if background_mask.sum() > 0:
    regional_prompts.append(background_prompt)
    regional_masks.append(background_mask)

image = pipeline(
    prompt=base_prompt,
    width=image_width, height=image_height,
    mask_inject_steps=mask_inject_steps,
    num_inference_steps=num_inference_steps,
    generator=torch.Generator("cuda").manual_seed(seed),
    joint_attention_kwargs={
        "regional_prompts": regional_prompts,
        "regional_masks": regional_masks,
        "double_inject_blocks_interval": double_inject_blocks_interval,
        "single_inject_blocks_interval": single_inject_blocks_interval,
        "base_ratio": base_ratio
    },
).images[0]

image.save("output.jpg")
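The comments on the region control factors describe their effect ("larger means stronger control", "1 means strongest control") but not their mechanics. One reading consistent with those comments, offered as an assumption rather than a description of the package internals: the regional masks are applied during the first `mask_inject_steps` denoising steps, and in every `interval`-th double-stream and single-stream block (FLUX.1-dev has 19 double-stream and 38 single-stream transformer blocks). The hypothetical `injection_plan` helper just enumerates that schedule.

```python
# Assumption: FLUX.1-dev block counts (19 double-stream, 38 single-stream).
NUM_DOUBLE_BLOCKS = 19
NUM_SINGLE_BLOCKS = 38


def injection_plan(num_inference_steps, mask_inject_steps,
                   double_inject_blocks_interval, single_inject_blocks_interval):
    """Hypothetical schedule: which steps and blocks would receive the regional masks."""
    steps = list(range(min(mask_inject_steps, num_inference_steps)))
    double_blocks = list(range(0, NUM_DOUBLE_BLOCKS, double_inject_blocks_interval))
    single_blocks = list(range(0, NUM_SINGLE_BLOCKS, single_inject_blocks_interval))
    return steps, double_blocks, single_blocks


# Settings of the cocktail demo: 10 inject steps, double interval 1, single interval 2.
steps, double_blocks, single_blocks = injection_plan(24, 10, 1, 2)
print(len(steps), len(double_blocks), len(single_blocks))  # 10 19 19
```

Under this reading, raising an interval to 2 halves the number of blocks that see the regional masks, which matches the README's note that 1 gives the strongest control.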

πŸ‘ Acknowledgment

Our work is sponsored by HuggingFace and fal.ai. Thanks!

Cite

If you find Regional-Prompting-FLUX useful for your research and applications, please cite us using this BibTeX:

@article{chen2024training,
  title={Training-free Regional Prompting for Diffusion Transformers},
  author={Chen, Anthony and Xu, Jianjin and Zheng, Wenzhao and Dai, Gaole and Wang, Yida and Zhang, Renrui and Wang, Haofan and Zhang, Shanghang},
  journal={arXiv preprint arXiv:2411.02395},
  year={2024}
}

For any questions, feel free to contact us by email.
