|
1 | 1 | import inspect
|
2 | 2 | import logging
|
| 3 | +import warnings |
3 | 4 | from copy import deepcopy
|
4 | 5 | from enum import Enum, auto
|
5 | 6 | from typing import Any, Dict, Iterator, Optional, Set, Union
|
@@ -476,6 +477,12 @@ def _process_kwarg_inputs(inputs: Any) -> Any:
|
476 | 477 | )
|
477 | 478 |
|
478 | 479 | def forward(self, *args: Any, **kwargs: Any) -> Any:
|
| 480 | + warnings.warn( |
| 481 | + "Direct calls to {self.__class__}.forward() are currently broken by due to https://github.com/pytorch/pytorch/issues/157183. Either call {self.__class__}(...) directly or use {self.__class__}._forward as a work around" |
| 482 | + ) |
| 483 | + return self._forward(*args, **kwargs) |
| 484 | + |
| 485 | + def _forward(self, *args: Any, **kwargs: Any) -> Any: |
479 | 486 | # Step 1: Check whether the input shape has changed
|
480 | 487 | kwargs = MutableTorchTensorRTModule._process_kwarg_inputs(kwargs)
|
481 | 488 | self._validate_inputs(*args, **kwargs)
|
@@ -535,7 +542,9 @@ def __deepcopy__(self, memo: Any) -> Any:
|
535 | 542 | return result
|
536 | 543 |
|
537 | 544 | def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
538 |
| - return self.forward(*args, **kwargs) |
| 545 | + # Due to https://github.com/pytorch/pytorch/issues/157183, we cannot use forward call, use _forward as a workaround. |
| 546 | + # This is a temporary fix. |
| 547 | + return self._forward(*args, **kwargs) |
539 | 548 |
|
540 | 549 | def __getattr__(self, name: str) -> Any:
|
541 | 550 | if name in self.__dict__:
|
|
0 commit comments