Added flux demo #3418

Open · wants to merge 3 commits into main

2 changes: 1 addition & 1 deletion MODULE.bazel
@@ -37,7 +37,7 @@ new_local_repository = use_repo_rule("@bazel_tools//tools/build_defs/repo:local.
new_local_repository(
name = "cuda",
build_file = "@//third_party/cuda:BUILD",
path = "/usr/local/cuda-12.8/",
path = "/usr/local/cuda-12.9/",
)

# for Jetson
275 changes: 275 additions & 0 deletions examples/apps/flux_demo.py
@@ -0,0 +1,275 @@
import argparse
import os
import re
import sys
import time

import gradio as gr
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from accelerate.hooks import remove_hook_from_module
from diffusers import FluxPipeline
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
from register_sdpa import *

DEVICE = "cuda:0"


def compile_model(
args,
) -> tuple[
FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule
]:

if args.dtype == "fp8":
enabled_precisions = {torch.float8_e4m3fn, torch.float16}
ptq_config = mtq.FP8_DEFAULT_CFG

elif args.dtype == "int8":
enabled_precisions = {torch.int8, torch.float16}
ptq_config = mtq.INT8_DEFAULT_CFG
ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None

elif args.dtype == "fp16":
enabled_precisions = {torch.float16}

print(f"\nUsing {args.dtype}")

pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.float16,
).to(torch.float16)

if args.debug:
pipe.transformer = FluxTransformer2DModel(
num_layers=1, num_single_layers=1, guidance_embeds=True
).to(torch.float16)
Review comment (Collaborator) on lines +47 to +50: "remove this"

if args.low_vram_mode:
pipe.enable_model_cpu_offload()
else:
pipe.to(DEVICE)

backbone = pipe.transformer
backbone.eval()

def filter_func(name):
pattern = re.compile(
r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
)
return pattern.match(name) is not None

def do_calibrate(
pipe,
prompt: str,
) -> None:
"""
Run calibration steps on the pipeline using the given prompts.
"""
image = pipe(
prompt,
output_type="pil",
num_inference_steps=20,
generator=torch.Generator("cuda").manual_seed(0),
).images[0]

def forward_loop(mod):
# Switch the pipeline's backbone, run calibration
pipe.transformer = mod
do_calibrate(
pipe=pipe,
prompt="a dog running in a park",
)

if args.dtype != "fp16":
backbone = mtq.quantize(backbone, ptq_config, forward_loop)
mtq.disable_quantizer(backbone, filter_func)

batch_size = 2 if args.dynamic_shapes else 1
if args.dynamic_shapes:
BATCH = torch.export.Dim("batch", min=1, max=8)
dynamic_shapes = {
"hidden_states": {0: BATCH},
"encoder_hidden_states": {0: BATCH},
"pooled_projections": {0: BATCH},
"timestep": {0: BATCH},
"txt_ids": {},
"img_ids": {},
"guidance": {0: BATCH},
"joint_attention_kwargs": {},
"return_dict": None,
}
else:
dynamic_shapes = None

settings = {
"strict": False,
"allow_complex_guards_as_runtime_asserts": True,
"enabled_precisions": enabled_precisions,
"truncate_double": True,
"min_block_size": 1,
"debug": False,
"use_python_runtime": True,
"immutable_weights": False,
"offload_module_to_cpu": True,
}
if args.low_vram_mode:
pipe.remove_all_hooks()
pipe.enable_sequential_cpu_offload()
remove_hook_from_module(pipe.transformer, recurse=True)
pipe.transformer.to(DEVICE)
trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
if dynamic_shapes:
trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
pipe.transformer = trt_gm

image = pipe(
"Test",
output_type="pil",
num_inference_steps=2,
num_images_per_prompt=batch_size,
).images

torch.cuda.empty_cache()

if args.low_vram_mode:
pipe.remove_all_hooks()
pipe.to(DEVICE)

return pipe, backbone, trt_gm


def launch_gradio(pipeline, backbone, trt_gm):

def generate_image(prompt, inference_step, batch_size=2):
start_time = time.time()
image = pipeline(
prompt,
output_type="pil",
num_inference_steps=inference_step,
num_images_per_prompt=batch_size,
).images
end_time = time.time()
return image, end_time - start_time

def model_change(model):
if model == "Torch Model":
pipeline.transformer = backbone
backbone.to(DEVICE)
else:
backbone.to("cpu")
pipeline.transformer = trt_gm
torch.cuda.empty_cache()

def load_lora(path):
pipeline.load_lora_weights(
path,
adapter_name="lora1",
)
pipeline.set_adapters(["lora1"], adapter_weights=[1])
pipeline.fuse_lora()
pipeline.unload_lora_weights()
print("LoRA loaded! Begin refitting")
generate_image("Test", 2)
print("Refitting Finished!")

# Create Gradio interface
with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")

with gr.Row():
with gr.Column():
# Input components
prompt_input = gr.Textbox(
label="Prompt", placeholder="Enter your prompt here...", lines=3
)
model_dropdown = gr.Dropdown(
choices=["Torch Model", "Torch-TensorRT Accelerated Model"],
value="Torch-TensorRT Accelerated Model",
label="Model Variant",
)

lora_upload_path = gr.Textbox(
label="LoRA Path",
placeholder="Enter the LoRA checkpoint path here. It could be a local path or a Hugging Face URL.",
value="gokaygokay/Flux-Engrave-LoRA",
lines=2,
)
num_steps = gr.Slider(
minimum=20, maximum=100, value=20, step=1, label="Inference Steps"
)
batch_size = gr.Slider(
minimum=1, maximum=8, value=1, step=1, label="Batch Size"
)

generate_btn = gr.Button("Generate Image")
load_lora_btn = gr.Button("Load LoRA")

with gr.Column():
# Output component
output_image = gr.Gallery(label="Generated Image")
time_taken = gr.Textbox(
label="Generation Time (seconds)", interactive=False
)

# Connect the button to the generation function
model_dropdown.change(model_change, inputs=[model_dropdown])
load_lora_btn.click(
fn=load_lora,
inputs=[
lora_upload_path,
],
)

# Update generate button click to include time output
generate_btn.click(
fn=generate_image,
inputs=[
prompt_input,
num_steps,
batch_size,
],
outputs=[output_image, time_taken],
)
demo.launch()


def main(args):
pipe, backbone, trt_gm = compile_model(args)
launch_gradio(pipe, backbone, trt_gm)


# Parse command-line arguments and launch the demo
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run Flux quantization with different dtypes"
)

parser.add_argument(
"--dtype",
choices=["fp8", "int8", "fp16"],
default="fp16",
help="Select the data type to use (fp8 or int8 or fp16)",
)
parser.add_argument(
"--low_vram_mode",
action="store_true",
help="Use low VRAM mode when you have a small GPU (<=32GB)",
)
parser.add_argument(
"--dynamic_shapes",
"-d",
action="store_true",
help="Use dynamic shapes",
)
parser.add_argument(
"--debug",
action="store_true",
help="Use debug mode",
)
args = parser.parse_args()
main(args)
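
For orientation, here is a minimal sketch of driving the new demo without the Gradio UI. It assumes flux_demo.py is importable (for example, when run from examples/apps); the Namespace fields mirror the argparse options defined above, and the prompt and output filename are placeholders.

import argparse

import flux_demo  # the module added by this PR, assumed to be on sys.path

args = argparse.Namespace(
    dtype="fp16", debug=False, low_vram_mode=False, dynamic_shapes=False
)
pipe, backbone, trt_gm = flux_demo.compile_model(args)

# Run one generation through the TensorRT-accelerated backbone.
images = pipe(
    "a beach at sunset",
    output_type="pil",
    num_inference_steps=20,
    num_images_per_prompt=1,
).images
images[0].save("flux_trt_demo.png")
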
9 changes: 4 additions & 5 deletions examples/dynamo/mutable_torchtrt_module_example.py
@@ -22,6 +22,7 @@
import torch
import torch_tensorrt as torch_trt
import torchvision.models as models
from diffusers import DiffusionPipeline

np.random.seed(5)
torch.manual_seed(5)
@@ -31,7 +32,7 @@
# Initialize the Mutable Torch TensorRT Module with settings.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
settings = {
"use_python": False,
"use_python_runtime": False,
"enabled_precisions": {torch.float32},
"immutable_weights": False,
}
@@ -40,7 +41,6 @@
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
# You can use the mutable module just like the original pytorch module. The compilation happens while you first call the mutable module.
mutable_module(*inputs)

# %%
# Make modifications to the mutable module.
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -73,13 +73,12 @@
# Stable Diffusion with Huggingface
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

from diffusers import DiffusionPipeline

with torch.no_grad():
settings = {
"use_python_runtime": True,
"enabled_precisions": {torch.float16},
"debug": True,
"debug": False,
"immutable_weights": False,
}

@@ -106,7 +105,7 @@
"text_embeds": {0: BATCH},
"time_ids": {0: BATCH},
},
"return_dict": False,
"return_dict": None,
}
pipe.unet.set_expected_dynamic_shape_range(
args_dynamic_shapes, kwargs_dynamic_shapes
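
For context on the settings touched above (use_python_runtime, immutable_weights, debug), the weight-update flow this example exercises is roughly the following sketch; it assumes the unchanged portion of mutable_torchtrt_module_example.py, so the names model, inputs, settings, torch_trt, and models come from that file.

# Sketch only; the first call compiles, later weight swaps trigger a refit.
model = models.resnet18(pretrained=True).eval().to("cuda")
mutable_module = torch_trt.MutableTorchTensorRTModule(model, **settings)
mutable_module(*inputs)  # compiles the TensorRT engine on first use

# Load new weights; because immutable_weights is False, the next call
# refits the existing engine instead of recompiling from scratch.
model2 = models.resnet18(pretrained=True).eval().to("cuda")
mutable_module.load_state_dict(model2.state_dict())
mutable_module(*inputs)
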
1 change: 1 addition & 0 deletions examples/dynamo/refit_engine_example.py
@@ -101,6 +101,7 @@
)

# Check the output
model2.to("cuda")
expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
assert torch.allclose(
11 changes: 6 additions & 5 deletions examples/dynamo/torch_export_flux_dev.py
@@ -114,21 +114,22 @@
min_block_size=1,
use_fp32_acc=True,
use_explicit_typing=True,
immutable_weights=False,
offload_module_to_cpu=True,
)

# %%
# Post Processing
# ---------------------------
# Release the GPU memory occupied by the exported program and the pipe.transformer
# Set the transformer in the Flux pipeline to the Torch-TRT compiled model

del ep
backbone.to("cpu")
pipe.transformer = None
pipe.to(DEVICE)
torch.cuda.empty_cache()
pipe.transformer = trt_gm
del ep
torch.cuda.empty_cache()
pipe.transformer.config = config

trt_gm.device = torch.device("cuda")
# %%
# Image generation using prompt
# ---------------------------
3 changes: 2 additions & 1 deletion py/torch_tensorrt/dynamo/_compiler.py
@@ -693,6 +693,7 @@ def compile(
)

gm = exported_program.module()
# Move the weights in the state_dict to CPU
logger.debug("Input graph: " + str(gm.graph))

# Apply lowering on the graph module
@@ -914,7 +915,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
parse_graph_io(submodule, subgraph_data)
dryrun_tracker.tensorrt_graph_count += 1
dryrun_tracker.per_subgraph_data.append(subgraph_data)

torch.cuda.empty_cache()
# Create TRT engines from submodule
if not settings.dryrun:
trt_module = convert_module(