diff --git a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py index 85a31b9736..ba83c89dcd 100644 --- a/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py +++ b/py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py @@ -1,5 +1,6 @@ import inspect import logging +import warnings from copy import deepcopy from enum import Enum, auto from typing import Any, Dict, Iterator, Optional, Set, Union @@ -476,6 +477,12 @@ def _process_kwarg_inputs(inputs: Any) -> Any: ) def forward(self, *args: Any, **kwargs: Any) -> Any: + warnings.warn( + "Direct calls to {self.__class__}.forward() are currently broken by due to https://github.com/pytorch/pytorch/issues/157183. Either call {self.__class__}(...) directly or use {self.__class__}._forward as a work around" + ) + return self._forward(*args, **kwargs) + + def _forward(self, *args: Any, **kwargs: Any) -> Any: # Step 1: Check whether the input shape has changed kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs) self._validate_inputs(*args, **kwargs) @@ -535,7 +542,9 @@ def __deepcopy__(self, memo: Any) -> Any: return result def __call__(self, *args: Any, **kwargs: Any) -> Any: - return self.forward(*args, **kwargs) + # Due to https://github.com/pytorch/pytorch/issues/157183, we cannot use forward call, use _forward as a workaround. + # This is a temporary fix. + return self._forward(*args, **kwargs) def __getattr__(self, name: str) -> Any: if name in self.__dict__: