Skip to content

Commit 21cfb78

Browse files
Temporary fix to workaround the mutable decomposition error. (#3636)
1 parent b3b5f45 commit 21cfb78

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import inspect
22
import logging
3+
import warnings
34
from copy import deepcopy
45
from enum import Enum, auto
56
from typing import Any, Dict, Iterator, Optional, Set, Union
@@ -476,6 +477,12 @@ def _process_kwarg_inputs(inputs: Any) -> Any:
476477
)
477478

478479
def forward(self, *args: Any, **kwargs: Any) -> Any:
480+
warnings.warn(
481+
"Direct calls to {self.__class__}.forward() are currently broken by due to https://github.com/pytorch/pytorch/issues/157183. Either call {self.__class__}(...) directly or use {self.__class__}._forward as a work around"
482+
)
483+
return self._forward(*args, **kwargs)
484+
485+
def _forward(self, *args: Any, **kwargs: Any) -> Any:
479486
# Step 1: Check whether the input shape has changed
480487
kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs)
481488
self._validate_inputs(*args, **kwargs)
@@ -535,7 +542,9 @@ def __deepcopy__(self, memo: Any) -> Any:
535542
return result
536543

537544
def __call__(self, *args: Any, **kwargs: Any) -> Any:
538-
return self.forward(*args, **kwargs)
545+
# Due to https://github.com/pytorch/pytorch/issues/157183, we cannot use forward call, use _forward as a workaround.
546+
# This is a temporary fix.
547+
return self._forward(*args, **kwargs)
539548

540549
def __getattr__(self, name: str) -> Any:
541550
if name in self.__dict__:

0 commit comments

Comments
 (0)