Added flux demo #3418
Conversation
Can the app display the inference time? It might be nice to have some stats rendered live as you generate.
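One rough sketch of how that could look: time the pipeline call inside the Gradio handler and return the elapsed time to a stats component. All names here (`pipe`, the handler, the layout) are illustrative, not the demo's actual code.

```python
import time
import gradio as gr

def generate_with_stats(prompt: str, steps: int = 20):
    # Time the diffusion call and surface it next to the image
    # (pipe is assumed to be the demo's Flux pipeline handle).
    start = time.perf_counter()
    image = pipe(prompt, num_inference_steps=steps).images[0]
    elapsed = time.perf_counter() - start
    return image, f"Inference time: {elapsed:.2f} s"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    image = gr.Image(label="Generated image")
    stats = gr.Textbox(label="Stats", interactive=False)
    gr.Button("Generate").click(
        generate_with_stats, inputs=prompt, outputs=[image, stats]
    )
```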
```python
import gradio as gr
import modelopt.torch.quantization as mtq
import register_sdpa
```
We can avoid copying the `register_sdpa.py` and `sdpa_converter.py` files by doing this:
```python
import sys
import os

# Register SDPA as a standalone operator. Converter and lowering pass are defined in register_sdpa.py
sys.path.append(os.path.join(os.path.dirname(__file__), "../dynamo"))
from register_sdpa import *
```
What does this file do?
```diff
@@ -112,6 +112,8 @@
     min_block_size=1,
     use_fp32_acc=True,
     use_explicit_typing=True,
+    use_python_runtime=True,
```
Can we default to using the C++ runtime (`use_python_runtime=False`)?
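For reference, a minimal sketch of the compile call with that change (`ep` and `example_inputs` stand in for the demo's exported program and inputs):

```python
import torch_tensorrt

trt_gm = torch_tensorrt.dynamo.compile(
    ep,                        # exported FLUX backbone (placeholder name)
    inputs=example_inputs,     # placeholder for the demo's real inputs
    min_block_size=1,
    use_fp32_acc=True,
    use_explicit_typing=True,
    use_python_runtime=False,  # default to the C++ runtime, as suggested
)
```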
```python
backbone.to("cpu")
pipe.transformer = trt_gm
del ep
torch.cuda.empty_cache()
pipe.transformer.config = config

trt_gm.device = torch.device("cuda")
```
Can we use `offload_module_to_cpu=True` to handle this block of code?
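A sketch of the suggested replacement, assuming `offload_module_to_cpu` is passed through the demo's compile call:

```python
import torch_tensorrt

# With offload_module_to_cpu=True the compiler offloads the original PyTorch
# module to CPU once the TRT engine is built, which should make the manual
# backbone.to("cpu") / torch.cuda.empty_cache() block above unnecessary.
trt_gm = torch_tensorrt.dynamo.compile(
    ep,
    inputs=example_inputs,  # placeholder
    offload_module_to_cpu=True,
)
pipe.transformer = trt_gm
pipe.transformer.config = config
```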
```diff
@@ -912,7 +913,7 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
     parse_graph_io(submodule, subgraph_data)
     dryrun_tracker.tensorrt_graph_count += 1
     dryrun_tracker.per_subgraph_data.append(subgraph_data)

+    torch.cuda.empty_cache()
```
Is this needed here?
```diff
@@ -341,7 +370,7 @@ def refit_module_weights(

     # Iterate over all components that can be accelerated
     # Generate the corresponding TRT Module for those

+    new_weight_module.module().to(CPU_DEVICE)
```
Is `new_weight_module` the updated-weights module that the user provides?
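For context, a rough sketch of the refit flow as I understand it (signature assumed from the public `refit_module_weights` API; `updated_model` and `example_args` are placeholders):

```python
import torch
import torch_tensorrt

# The user re-exports the model with updated weights and hands it to refit.
exp_program = torch.export.export(updated_model, args=example_args)
refitted_gm = torch_tensorrt.dynamo.refit_module_weights(
    compiled_module,   # the previously compiled TRT module
    exp_program,       # new_weight_module: carries the updated weights
    arg_inputs=example_args,
    verify_output=True,
)
```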
```python
if verify_output and arg_inputs is not None:
    new_gm.to(torch.cuda.current_device())
```
We should ensure we use the device that's passed in via args/compilation settings (or the default device), and not rely on torch.cuda calls unless needed.
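A sketch of that pattern, assuming the `to_torch_device` helper in `torch_tensorrt.dynamo.utils` and a `settings` object carrying the compilation settings:

```python
from torch_tensorrt.dynamo.utils import to_torch_device

# Resolve the target device from the compilation settings (or the configured
# default) rather than hard-coding torch.cuda.current_device().
device = to_torch_device(settings.device)
new_gm.to(device)
```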
```python
self.original_model.to("cpu")
torch.cuda.empty_cache()
```
Use the deallocate module helper instead.
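Roughly, assuming the `deallocate_module` helper in `torch_tensorrt.dynamo.utils`:

```python
from torch_tensorrt.dynamo.utils import deallocate_module

# deallocate_module releases the module's storage and frees GPU memory in one
# call, replacing the manual .to("cpu") + torch.cuda.empty_cache() pair.
deallocate_module(self.original_model, delete_module=False)
```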
```python
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
    clean_up_graph_after_modifications,
)
```
Use them from the examples instead.
I think we should avoid copying the whole model script for measuring perf. Try using the sys.path approach, importing the model, and keeping just a perf loop. Something like:

```python
import sys
import os

sys.path.append(torchtrt_root + "examples/dynamo/apps")
from flux_demo import *

model = <insert FLUX model (fp16 or fp8)>
results = measure_flux_perf(....)
```
## Description

Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change.

Fixes # (issue)

## Type of change

Please delete options that are not relevant and/or add your own.

## Checklist: