Added flux demo #3418
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
cehongwang
wants to merge
3
commits into
main
Choose a base branch
from
flux-demo
base: main
@@ -0,0 +1,275 @@
import argparse
import os
import re
import sys
import time

import gradio as gr
import modelopt.torch.quantization as mtq
import torch
import torch_tensorrt
from accelerate.hooks import remove_hook_from_module
from diffusers import FluxPipeline
from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
from register_sdpa import *
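# Note: the wildcard import above is load-bearing. Importing register_sdpa
# registers the SDPA converter and lowering pass with Torch-TensorRT as an
# import side effect, even though no name from it is referenced directly.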

DEVICE = "cuda:0"

def compile_model(
    args,
) -> tuple[
    FluxPipeline, FluxTransformer2DModel, torch_tensorrt.MutableTorchTensorRTModule
]:
    if args.dtype == "fp8":
        enabled_precisions = {torch.float8_e4m3fn, torch.float16}
        ptq_config = mtq.FP8_DEFAULT_CFG

    elif args.dtype == "int8":
        enabled_precisions = {torch.int8, torch.float16}
        ptq_config = mtq.INT8_DEFAULT_CFG
        ptq_config["quant_cfg"]["*weight_quantizer"]["axis"] = None

    elif args.dtype == "fp16":
        enabled_precisions = {torch.float16}

    print(f"\nUsing {args.dtype}")
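    # ptq_config is only defined on the fp8/int8 paths; the fp16 path skips
    # quantization entirely (see the args.dtype != "fp16" guard below), so
    # the name is never read in that case.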

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.float16,
    ).to(torch.float16)

    if args.debug:
        pipe.transformer = FluxTransformer2DModel(
            num_layers=1, num_single_layers=1, guidance_embeds=True
        ).to(torch.float16)

    if args.low_vram_mode:
        pipe.enable_model_cpu_offload()
    else:
        pipe.to(DEVICE)

    backbone = pipe.transformer
    backbone.eval()

    def filter_func(name):
        pattern = re.compile(
            r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding|pos_embed|time_text_embed|context_embedder|norm_out|x_embedder).*"
        )
        return pattern.match(name) is not None

    def do_calibrate(
        pipe,
        prompt: str,
    ) -> None:
        """
        Run calibration steps on the pipeline using the given prompt.
        """
        image = pipe(
            prompt,
            output_type="pil",
            num_inference_steps=20,
            generator=torch.Generator("cuda").manual_seed(0),
        ).images[0]

    def forward_loop(mod):
        # Switch the pipeline's backbone to the module under calibration, then run it once
        pipe.transformer = mod
        do_calibrate(
            pipe=pipe,
            prompt="a dog running in a park",
        )

    if args.dtype != "fp16":
        backbone = mtq.quantize(backbone, ptq_config, forward_loop)
        mtq.disable_quantizer(backbone, filter_func)
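    # How the PTQ flow above works: mtq.quantize() instruments the backbone
    # with quantizers and invokes forward_loop once so calibration can observe
    # activation ranges. filter_func then disables the quantizers on the
    # embedding/normalization layers it matches, which are commonly left at
    # higher precision to preserve output quality.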

    batch_size = 2 if args.dynamic_shapes else 1
    if args.dynamic_shapes:
        BATCH = torch.export.Dim("batch", min=1, max=8)
        dynamic_shapes = {
            "hidden_states": {0: BATCH},
            "encoder_hidden_states": {0: BATCH},
            "pooled_projections": {0: BATCH},
            "timestep": {0: BATCH},
            "txt_ids": {},
            "img_ids": {},
            "guidance": {0: BATCH},
            "joint_attention_kwargs": {},
            "return_dict": None,
        }
    else:
        dynamic_shapes = None

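    # torch.export.Dim("batch", min=1, max=8) marks dimension 0 of the listed
    # inputs as dynamic, so a single compiled engine can serve batch sizes
    # 1 through 8 (the range exposed by the batch-size slider below) without
    # recompiling.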
    settings = {
        "strict": False,
        "allow_complex_guards_as_runtime_asserts": True,
        "enabled_precisions": enabled_precisions,
        "truncate_double": True,
        "min_block_size": 1,
        "debug": False,
        "use_python_runtime": True,
        "immutable_weights": False,
        "offload_module_to_cpu": True,
    }
    if args.low_vram_mode:
        pipe.remove_all_hooks()
        pipe.enable_sequential_cpu_offload()
        remove_hook_from_module(pipe.transformer, recurse=True)
        pipe.transformer.to(DEVICE)
    trt_gm = torch_tensorrt.MutableTorchTensorRTModule(backbone, **settings)
    if dynamic_shapes:
        trt_gm.set_expected_dynamic_shape_range((), dynamic_shapes)
    pipe.transformer = trt_gm

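    # MutableTorchTensorRTModule wraps the backbone so that compilation
    # happens on the first forward pass, and, with immutable_weights=False,
    # later weight changes can be picked up by refitting the engine rather
    # than rebuilding it. The LoRA workflow below relies on this behavior.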
    image = pipe(
        "Test",
        output_type="pil",
        num_inference_steps=2,
        num_images_per_prompt=batch_size,
    ).images

    torch.cuda.empty_cache()

    if args.low_vram_mode:
        pipe.remove_all_hooks()
        pipe.to(DEVICE)

    return pipe, backbone, trt_gm

def launch_gradio(pipeline, backbone, trt_gm):
    def generate_image(prompt, inference_step, batch_size=2):
        start_time = time.time()
        image = pipeline(
            prompt,
            output_type="pil",
            num_inference_steps=inference_step,
            num_images_per_prompt=batch_size,
        ).images
        end_time = time.time()
        return image, end_time - start_time

    def model_change(model):
        if model == "Torch Model":
            pipeline.transformer = backbone
            backbone.to(DEVICE)
        else:
            backbone.to("cpu")
            pipeline.transformer = trt_gm
            torch.cuda.empty_cache()

    def load_lora(path):
        pipeline.load_lora_weights(
            path,
            adapter_name="lora1",
        )
        pipeline.set_adapters(["lora1"], adapter_weights=[1])
        pipeline.fuse_lora()
        pipeline.unload_lora_weights()
        print("LoRA loaded! Begin refitting")
        # Run a short generation so the refit completes before the next user request
        generate_image("Test", 2)
        print("Refitting Finished!")

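    # fuse_lora() folds the adapter into the base weights and
    # unload_lora_weights() then drops the adapter bookkeeping, leaving the
    # transformer with mutated weights. The short generation above pushes the
    # mutated backbone through the TensorRT module, which triggers the refit.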
    # Create Gradio interface
    with gr.Blocks(title="Flux Demo with Torch-TensorRT") as demo:
        gr.Markdown("# Flux Image Generation Demo Accelerated by Torch-TensorRT")

        with gr.Row():
            with gr.Column():
                # Input components
                prompt_input = gr.Textbox(
                    label="Prompt", placeholder="Enter your prompt here...", lines=3
                )
                model_dropdown = gr.Dropdown(
                    choices=["Torch Model", "Torch-TensorRT Accelerated Model"],
                    value="Torch-TensorRT Accelerated Model",
                    label="Model Variant",
                )

                lora_upload_path = gr.Textbox(
                    label="LoRA Path",
                    placeholder="Enter the LoRA checkpoint path here. It can be a local path or a Hugging Face URL.",
                    value="gokaygokay/Flux-Engrave-LoRA",
                    lines=2,
                )
                num_steps = gr.Slider(
                    minimum=20, maximum=100, value=20, step=1, label="Inference Steps"
                )
                batch_size = gr.Slider(
                    minimum=1, maximum=8, value=1, step=1, label="Batch Size"
                )

                generate_btn = gr.Button("Generate Image")
                load_lora_btn = gr.Button("Load LoRA")

            with gr.Column():
                # Output components
                output_image = gr.Gallery(label="Generated Image")
                time_taken = gr.Textbox(
                    label="Generation Time (seconds)", interactive=False
                )

        # Connect the dropdown and LoRA button to their handlers
        model_dropdown.change(model_change, inputs=[model_dropdown])
        load_lora_btn.click(
            fn=load_lora,
            inputs=[
                lora_upload_path,
            ],
        )

        # The generate button returns both the images and the elapsed time
        generate_btn.click(
            fn=generate_image,
            inputs=[
                prompt_input,
                num_steps,
                batch_size,
            ],
            outputs=[output_image, time_taken],
        )
    demo.launch()

def main(args):
    pipe, backbone, trt_gm = compile_model(args)
    launch_gradio(pipe, backbone, trt_gm)


# Launch the interface
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Run Flux quantization with different dtypes"
    )

    parser.add_argument(
        "--dtype",
        choices=["fp8", "int8", "fp16"],
        default="fp16",
        help="Select the data type to use (fp8, int8, or fp16)",
    )
    parser.add_argument(
        "--low_vram_mode",
        action="store_true",
        help="Use low VRAM mode when you have a small GPU (<=32GB)",
    )
    parser.add_argument(
        "--dynamic_shapes",
        "-d",
        action="store_true",
        help="Use dynamic shapes",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Use debug mode",
    )
    args = parser.parse_args()
    main(args)
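Assuming the file is saved as `flux_demo.py` (the diff header does not show the filename), the demo would be launched with the flags defined above, e.g. `python flux_demo.py --dtype fp8 --dynamic_shapes`. The fp8 and int8 paths run a 20-step calibration generation at startup, and `--low_vram_mode` targets GPUs with 32GB of memory or less.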
Review comment: remove this