11 changes: 6 additions & 5 deletions py/torch_tensorrt/dynamo/_refit.py
@@ -52,7 +52,7 @@
logger = logging.getLogger(__name__)


@needs_refit
@needs_refit # type: ignore[misc]
def construct_refit_mapping(
module: torch.fx.GraphModule,
inputs: Sequence[Input],
@@ -85,7 +85,7 @@ def construct_refit_mapping(
return weight_refit_map


@needs_refit
@needs_refit # type: ignore[misc]
def construct_refit_mapping_from_weight_name_map(
weight_name_map: dict[Any, Any],
state_dict: dict[Any, Any],
@@ -128,7 +128,7 @@ def construct_refit_mapping_from_weight_name_map(
return engine_weight_map


@needs_refit
@needs_refit # type: ignore[misc]
def _refit_single_trt_engine_with_gm(
new_gm: torch.fx.GraphModule,
old_engine: trt.ICudaEngine,
@@ -211,7 +211,7 @@ def _refit_single_trt_engine_with_gm(
raise AssertionError("Refitting failed.")


@needs_refit
@needs_refit # type: ignore[misc]
def refit_module_weights(
compiled_module: torch.fx.GraphModule | ExportedProgram,
new_weight_module: ExportedProgram,
@@ -484,9 +484,10 @@ def refit_module_weights(
weight_name_map=None,
)

# clear EXCLUDE_WEIGHTS flag
# clear EXCLUDE_WEIGHTS flag and set INCLUDE_REFIT flag to make the engine refittable
serialization_config = engine.create_serialization_config()
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
serialized_engine = engine.serialize_with_config(serialization_config)

if isinstance(compiled_submodule, PythonTorchTensorRTModule):
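For context, the reserialization step added above boils down to the following sketch (the helper name is hypothetical; the TensorRT calls mirror the ones in the diff):

```python
import tensorrt as trt

def reserialize_refittable(engine: trt.ICudaEngine) -> bytes:
    # Clear EXCLUDE_WEIGHTS so the serialized plan carries its weights again,
    # and set INCLUDE_REFIT so the reserialized engine remains refittable
    # (as noted in the removed TODO below, TRT otherwise drops the refit
    # capability when EXCLUDE_WEIGHTS is cleared).
    config = engine.create_serialization_config()
    config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
    config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
    return bytes(engine.serialize_with_config(config))
```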
1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/_settings.py
@@ -167,7 +167,6 @@ def __setstate__(self, state: dict[str, Any]) -> None:
"engine_capability",
"hardware_compatible",
"refit_identical_engine_weights",
"strip_engine_weights", # TODO: @Evan to remove this after implementing caching weight-stripped engines as default?
"immutable_weights",
"enable_weight_streaming",
"tiling_optimization_level",
4 changes: 0 additions & 4 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -157,10 +157,6 @@ def _pretraced_backend(
logger.warning(
"require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt'"
)
if settings.strip_engine_weights:
logger.error(
"strip_engine_weights arg is not supported for torch.compile()"
)
trt_compiled = compile_module(
gm,
torchtrt_inputs,
95 changes: 1 addition & 94 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
@@ -31,7 +31,7 @@
from torch_tensorrt._utils import is_tensorrt_version_supported
from torch_tensorrt.dynamo import _defaults
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
DYNAMO_CONVERTERS as CONVERTERS,
@@ -594,79 +594,6 @@ def _save_weight_mapping(self) -> None:
gc.collect()
torch.cuda.empty_cache()

@needs_refit # type: ignore[misc]
def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
# query the cached TRT engine
cached_data = self.engine_cache.check(hash_val) # type: ignore[union-attr]
if cached_data is not None: # hit the cache
(
serialized_engine,
self._input_names,
self._output_names,
cached_engine_input_specs,
engine_compilation_settings,
self.weight_name_map,
self.ctx.requires_output_allocator,
) = cached_data

setting_compatiblity, incompattible_settings = settings_are_compatible(
self.compilation_settings, engine_compilation_settings
)
assert (
setting_compatiblity
), f"Attempted to refit a cached engine with incompatible settings: {incompattible_settings}, (old_settings: {engine_compilation_settings}, new_settings: {self.compilation_settings})"

for i, e in enumerate(
[
Input.equivalent_spec(c, i)
for c, i in zip(cached_engine_input_specs, self.input_specs)
]
):
assert (
e
), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_input_specs[i]}, new size: {self.input_specs[i]}"

_LOGGER.info(
"Found the cached engine that corresponds to this graph. It is directly loaded."
)

# refit the cached engine with the new graph module
if not self.compilation_settings.strip_engine_weights:
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(serialized_engine)

from torch_tensorrt.dynamo._refit import (
_refit_single_trt_engine_with_gm,
)

_refit_single_trt_engine_with_gm(
new_gm=self.module,
old_engine=engine,
input_list=self.input_specs,
settings=self.compilation_settings,
weight_name_map=self.weight_name_map,
)

# TODO: @Evan is waiting for TRT's feature to load the weight-stripped engine
# # EXCLUDE_WEIGHTS flag must be cleared
# serialization_config = engine.create_serialization_config()
# serialization_config.clear_flag(
# trt.SerializationFlag.EXCLUDE_WEIGHTS
# )
# serialized_engine = engine.serialize_with_config(
# serialization_config
# )
# # As of now, the engine becomes non-refittable because when EXCLUDE_WEIGHTS flag is cleared, the REFIT flag is also cleared by TRT to make the plan file smaller

return TRTInterpreterResult(
engine,
self._input_names,
self._output_names,
self.weight_name_map,
self.ctx.requires_output_allocator,
)
return None

def run(
self,
strict_type_constraints: bool = False,
@@ -682,26 +609,6 @@ def run(
Return:
TRTInterpreterResult
"""
# self.engine_cache could be None if:
# 1) engine_cache is not passed in when calling this function like convert_exported_program_to_serialized_trt_engine etc., or
# 2) both cache_built_engines and reuse_cached_engines are False
if (
self.engine_cache is not None
and not self.compilation_settings.immutable_weights
):
if (
self.compilation_settings.cache_built_engines
or self.compilation_settings.reuse_cached_engines
):
hash_val = self.engine_cache.get_hash(
self.module, self.input_specs, self.compilation_settings
)

if self.compilation_settings.reuse_cached_engines:
interpreter_result = self._pull_cached_engine(hash_val)
if interpreter_result is not None: # hit the cache
return interpreter_result # type: ignore[no-any-return]

self._construct_trt_network_def()
_LOGGER.debug(
f"CPU memory usage after network construction: {get_cpu_memory_usage()} MB"
163 changes: 139 additions & 24 deletions py/torch_tensorrt/dynamo/conversion/_conversion.py
@@ -4,19 +4,24 @@
import logging
from typing import Any, List, NamedTuple, Optional, Sequence

import tensorrt as trt
import torch
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
from torch_tensorrt._features import ENABLED_FEATURES, needs_refit
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._TRTInterpreter import TRTInterpreter
from torch_tensorrt.dynamo._settings import CompilationSettings, settings_are_compatible
from torch_tensorrt.dynamo.conversion._TRTInterpreter import (
TRTInterpreter,
TRTInterpreterResult,
)
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
from torch_tensorrt.dynamo.utils import (
get_cpu_memory_usage,
get_output_dtypes,
release_host_and_device_memory,
)
from torch_tensorrt.logging import TRT_LOGGER

logger = logging.getLogger(__name__)

@@ -63,6 +68,128 @@ def interpret_module_to_result(
SerializedInterpreterResult
"""

def _insert_engine_to_cache(
hash_val: str, interpreter_result: TRTInterpreterResult
Collaborator: Nice, I like this

Collaborator: Is there a reason that the function needs to be in the interpret function's scope?

Collaborator (Author): Not for any specific reason, but I don't know where engine_cache would be used other than in interpret_module_to_result(). To make it safe and self-contained, I picked the smallest scope. Are there any other cases that might use engine_cache?

) -> None: # type: ignore[unused-ignore]
# Cache the weight-stripped engine regardless of the `strip_engine_weights` setting
if engine_cache.check(hash_val) is not None: # type: ignore[union-attr]
logger.info(f"Engine already exists in cache for hash: {hash_val}")
return
if not settings.strip_engine_weights:
# set EXCLUDE_WEIGHTS flag to strip weights
serialization_config = (
interpreter_result.engine.create_serialization_config()
)
serialization_config.set_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
weight_stripped_serialized_engine = (
interpreter_result.engine.serialize_with_config(serialization_config)
)
else:
weight_stripped_serialized_engine = interpreter_result.engine.serialize()

# Insert weight-stripped engine to cache
engine_cache.insert( # type: ignore[union-attr]
hash_val,
(
weight_stripped_serialized_engine,
interpreter_result.input_names,
interpreter_result.output_names,
inputs,
settings,
interpreter_result.weight_name_map,
interpreter_result.requires_output_allocator,
),
)
logger.info(f"Engine was successfully inserted into cache for hash: {hash_val}")

@needs_refit # type: ignore[misc]
Collaborator: Should the insert and extract both be needs_refit? Also, shouldn't this gracefully pass through vs. the typical unimplemented error?

Collaborator (Author): > Should the insert and extract both be needs_refit?
insert doesn't seem to involve any refitting. It supports the scenario where users insert engines on machine A, which doesn't support refit, but pull engines on machine B, which does. Please correct me if wrong.
> Also, shouldn't this gracefully pass through vs. the typical unimplemented error?
I'm not sure I understand the question correctly. The reason pull needs refit is that this implementation caches a weight-stripped engine, which must be refitted with the correct weights before use.

def _pull_cached_engine(hash_val: str) -> Optional[SerializedInterpreterResult]:
Collaborator: 👍

# query the cached TRT engine
cached_data = engine_cache.check(hash_val) # type: ignore[union-attr]
if cached_data is not None: # hit the cache
(
serialized_engine, # weight-stripped engine
input_names,
output_names,
cached_engine_inputs,
cached_engine_compilation_settings,
weight_name_map,
requires_output_allocator,
) = cached_data

settings_compatible, incompatible_settings = settings_are_compatible(
settings, cached_engine_compilation_settings
)
assert (
settings_compatible
), f"Attempted to refit a cached engine with incompatible settings: {incompatible_settings}, (old_settings: {cached_engine_compilation_settings}, new_settings: {settings})"

for i, e in enumerate(
[
Input.equivalent_spec(c, i)
for c, i in zip(cached_engine_inputs, inputs)
]
):
assert (
e
), f"Attempted to refit a cached engine built for a different input size (input: {i}, cached size: {cached_engine_inputs[i]}, new size: {inputs[i]})"

logger.info(
"Found a cached engine that corresponds to this graph; loading it directly."
)

# refit the cached engine with the new graph module
if not settings.strip_engine_weights:
runtime = trt.Runtime(TRT_LOGGER)
engine = runtime.deserialize_cuda_engine(
serialized_engine
) # weight-stripped engine

from torch_tensorrt.dynamo._refit import (
_refit_single_trt_engine_with_gm,
)

# weight-stripped engine --in place--> weight-included engine
_refit_single_trt_engine_with_gm(
new_gm=module,
old_engine=engine,
input_list=inputs,
settings=settings,
weight_name_map=weight_name_map,
)

# EXCLUDE_WEIGHTS flag must be cleared and INCLUDE_REFIT flag must be set
serialization_config = engine.create_serialization_config()
serialization_config.clear_flag(trt.SerializationFlag.EXCLUDE_WEIGHTS)
serialization_config.set_flag(trt.SerializationFlag.INCLUDE_REFIT)
serialized_engine = engine.serialize_with_config(serialization_config)
# From here on, the engine is weight-included and refittable

with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
serialized_engine = engine_bytes.getvalue()

return SerializedInterpreterResult(
serialized_engine=serialized_engine,
input_names=input_names,
output_names=output_names,
weight_name_map=weight_name_map,
requires_output_allocator=requires_output_allocator,
)
return None
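Tying back to the review thread above, here is a minimal sketch of the underlying TensorRT mechanism that makes pulling a weight-stripped engine depend on refit support (this is not the repo's _refit_single_trt_engine_with_gm helper; weight_stripped_plan and new_weights are hypothetical stand-ins):

```python
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
runtime = trt.Runtime(TRT_LOGGER)

# A plan serialized with EXCLUDE_WEIGHTS deserializes fine, but its
# weights are placeholders until a refit pass supplies the real ones.
engine = runtime.deserialize_cuda_engine(weight_stripped_plan)

refitter = trt.Refitter(engine, TRT_LOGGER)
for name in refitter.get_all_weights():
    # new_weights: a dict mapping weight names to trt.Weights
    refitter.set_named_weights(name, new_weights[name])
assert refitter.refit_cuda_engine()  # engine is now usable for inference
```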

# engine_cache could be None if:
# 1) no engine_cache was passed in (e.g., when this is called via convert_exported_program_to_serialized_trt_engine), or
# 2) both cache_built_engines and reuse_cached_engines are False
if engine_cache is not None and not settings.immutable_weights:
if settings.cache_built_engines or settings.reuse_cached_engines:
hash_val = engine_cache.get_hash(module, inputs, settings)

if settings.reuse_cached_engines:
serialized_interpreter_result = _pull_cached_engine(hash_val)
if serialized_interpreter_result is not None: # hit the cache
return serialized_interpreter_result # type: ignore[no-any-return]

output_dtypes = infer_module_output_dtypes(
module, truncate_double=settings.truncate_double
)
@@ -86,32 +213,20 @@ def interpret_module_to_result(
f"CPU memory usage after clearing frozen parameters and building memory in conversion: {get_cpu_memory_usage()} MB"
)

serialized_engine = interpreter_result.engine.serialize()
with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
serialized_engine = engine_bytes.getvalue()
logger.debug(
f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB"
)

# Engine caching only for refittable engines
if (
not settings.immutable_weights
and settings.cache_built_engines
and engine_cache is not None
):
hash_val = engine_cache.get_hash(module, inputs, settings)
engine_cache.insert(
hash_val,
(
serialized_engine,
interpreter_result.input_names,
interpreter_result.output_names,
inputs,
settings,
interpreter_result.weight_name_map,
interpreter_result.requires_output_allocator,
),
_insert_engine_to_cache(hash_val, interpreter_result)

serialized_engine = interpreter_result.engine.serialize()
with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
serialized_engine = engine_bytes.getvalue()
logger.debug(
f"CPU memory usage after serializing engine: {get_cpu_memory_usage()} MB"
)

serialized_interpreter_result = SerializedInterpreterResult(
Expand All @@ -122,7 +237,7 @@ def interpret_module_to_result(
requires_output_allocator=interpreter_result.requires_output_allocator,
)

return serialized_interpreter_result
return serialized_interpreter_result # type: ignore[no-any-return]


def convert_module(
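End to end, the caching path reworked above would be exercised roughly as follows (a sketch: MyModel and the input shape are placeholders; the keyword names come from the settings referenced in the diff):

```python
import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

# The first compile builds the engine and caches it weight-stripped;
# a later compile of the same graph pulls the cached plan and refits
# it with the current weights instead of rebuilding from scratch.
trt_mod = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    cache_built_engines=True,
    reuse_cached_engines=True,
    immutable_weights=False,  # caching only applies to refittable engines
)
```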