From 06fdf8be6f4abb62aed245f70b744804d8e7ef71 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Sat, 28 Jun 2025 16:48:13 -0700 Subject: [PATCH 1/4] fix mutable decomposition error. --- .../dynamo/runtime/_MutableTorchTensorRTModule.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index 85a31b9736..14a49d8790 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -475,7 +475,9 @@ def _process_kwarg_inputs(inputs: Any) -> Any: + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" ) - def forward(self, *args: Any, **kwargs: Any) -> Any: + # Due to https://github.com/pytorch/pytorch/issues/157183, we cannot use forward as a workaround. + # This is a temporary fix. + def call_forward(self, *args: Any, **kwargs: Any) -> Any: # Step 1: Check whether the input shape has changed kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs) self._validate_inputs(*args, **kwargs) @@ -535,7 +537,7 @@ def __deepcopy__(self, memo: Any) -> Any: return result def __call__(self, *args: Any, **kwargs: Any) -> Any: - return self.forward(*args, **kwargs) + return self.call_forward(*args, **kwargs) def __getattr__(self, name: str) -> Any: if name in self.__dict__: From bcbee3aa15d4546fce5e5dcf406609063abede71 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 30 Jun 2025 10:20:10 -0700 Subject: [PATCH 2/4] address comments --- .../dynamo/runtime/_MutableTorchTensorRTModule.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index 14a49d8790..7c19006690 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -477,7 +477,7 @@ def _process_kwarg_inputs(inputs: Any) -> Any: # Due to https://github.com/pytorch/pytorch/issues/157183, we cannot use forward as a workaround. # This is a temporary fix. - def call_forward(self, *args: Any, **kwargs: Any) -> Any: + def _forward(self, *args: Any, **kwargs: Any) -> Any: # Step 1: Check whether the input shape has changed kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs) self._validate_inputs(*args, **kwargs) @@ -537,7 +537,7 @@ def __deepcopy__(self, memo: Any) -> Any: return result def __call__(self, *args: Any, **kwargs: Any) -> Any: - return self.call_forward(*args, **kwargs) + return self._forward(*args, **kwargs) def __getattr__(self, name: str) -> Any: if name in self.__dict__: From 83a0488580caee0848e6b0e45cc17cd615d4a7d9 Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 30 Jun 2025 13:41:50 -0700 Subject: [PATCH 3/4] address comments --- .../dynamo/runtime/_MutableTorchTensorRTModule.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index 7c19006690..2527e6087f 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -1,5 +1,6 @@ import inspect import logging +import warnings from copy import deepcopy from enum import Enum, auto from typing import Any, Dict, Iterator, Optional, Set, Union @@ -475,8 +476,10 @@ def _process_kwarg_inputs(inputs: Any) -> Any: + "Allowed input types: {torch_tensorrt.Input, torch.Tensor, list, tuple, dict}" ) - # Due to https://github.com/pytorch/pytorch/issues/157183, we cannot use forward as a workaround. - # This is a temporary fix. + def forward(self, *args: Any, **kwargs: Any) -> Any: + warnings.warn("Direct calls to {self.__class__}.forward() are currently broken by due to https://github.com/pytorch/pytorch/issues/157183. Either call {self.__class__}(...) directly or use {self.__class__}._forward as a work around") + return self._forward(*args, **kwargs) + def _forward(self, *args: Any, **kwargs: Any) -> Any: # Step 1: Check whether the input shape has changed kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs) @@ -537,6 +540,8 @@ def __deepcopy__(self, memo: Any) -> Any: return result def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Due to https://github.com/pytorch/pytorch/issues/157183, we cannot use forward call, use _forward as a workaround. + # This is a temporary fix. return self._forward(*args, **kwargs) def __getattr__(self, name: str) -> Any: From 153b1429023156ce050c2679e8389daadb44a9bd Mon Sep 17 00:00:00 2001 From: lanluo-nvidia Date: Mon, 30 Jun 2025 13:45:08 -0700 Subject: [PATCH 4/4] fix lint --- .../dynamo/runtime/_MutableTorchTensorRTModule.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index 2527e6087f..ba83c89dcd 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -477,7 +477,9 @@ def _process_kwarg_inputs(inputs: Any) -> Any: ) def forward(self, *args: Any, **kwargs: Any) -> Any: - warnings.warn("Direct calls to {self.__class__}.forward() are currently broken by due to https://github.com/pytorch/pytorch/issues/157183. Either call {self.__class__}(...) directly or use {self.__class__}._forward as a work around") + warnings.warn( + "Direct calls to {self.__class__}.forward() are currently broken by due to https://github.com/pytorch/pytorch/issues/157183. Either call {self.__class__}(...) directly or use {self.__class__}._forward as a work around" + ) return self._forward(*args, **kwargs) def _forward(self, *args: Any, **kwargs: Any) -> Any: