Skip to content

Commit 1c0a8aa

Browse files
committed
Changed the stream of the Python runtime to the default stream
1 parent 52f7c48 commit 1c0a8aa

File tree

1 file changed

+11
-25
lines changed

1 file changed

+11
-25
lines changed

py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py

Lines changed: 11 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -171,8 +171,7 @@ def __init__(
171171
self._input_buffers: List[torch.Tensor] = []
172172
self._output_buffers: List[torch.Tensor] = []
173173
self.cudagraph: Optional[torch.cuda.CUDAGraph] = None
174-
self._caller_stream: Optional[torch.cuda.Stream] = None
175-
self._engine_stream: Optional[torch.cuda.Stream] = None
174+
self._engine_stream: torch.cuda.Stream = torch.cuda.current_stream()
176175
self.output_tensors: Optional[List[torch.Tensor]] = None
177176
self.sync_stream = True
178177

@@ -287,13 +286,7 @@ def setup_engine(self) -> None:
287286
), f"TensorRT engine was not built to target current platform (target: {self.target_platform}, current: {Platform.current_platform()})"
288287
# Stream handling: if the caller stream is the pytorch default stream, create a new engine stream
289288
# otherwise, use the caller stream and disable stream synchronization
290-
self._caller_stream = torch.cuda.current_stream()
291-
if self._caller_stream == torch.cuda.default_stream():
292-
self._engine_stream = torch.cuda.Stream()
293-
self.sync_stream = True
294-
else:
295-
self._engine_stream = self._caller_stream
296-
self.sync_stream = False
289+
self._engine_stream = torch.cuda.current_stream()
297290

298291
self.initialized = True
299292
runtime = trt.Runtime(TRT_LOGGER)
@@ -559,9 +552,6 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
559552
else nullcontext()
560553
):
561554

562-
if self.sync_stream:
563-
self._engine_stream.wait_stream(self._caller_stream)
564-
565555
if self.cudagraphs_enabled:
566556
if need_cudagraphs_record:
567557
self.cudagraph = torch.cuda.CUDAGraph()
@@ -587,10 +577,16 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
587577
self.cudagraph.replay() # type: ignore
588578

589579
else:
590-
self.context.execute_async_v3(self._engine_stream.cuda_stream)
580+
import warnings
591581

592-
if self.sync_stream:
593-
self._caller_stream.wait_stream(self._engine_stream)
582+
with warnings.catch_warnings():
583+
try:
584+
self.context.execute_async_v3(
585+
self._engine_stream.cuda_stream
586+
)
587+
except Warning as e:
588+
breakpoint()
589+
print("warning ignored")
594590

595591
if self.use_pre_allocated_outputs:
596592
self.pre_allocated_outputs = self.create_output_tensors()
@@ -645,22 +641,12 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
645641
if self.profiling_enabled
646642
else nullcontext()
647643
):
648-
self._caller_stream = torch.cuda.current_stream()
649-
if (
650-
self._engine_stream == torch.cuda.default_stream()
651-
or self._engine_stream is None
652-
):
653-
self._engine_stream = torch.cuda.Stream()
654-
655-
self._engine_stream.wait_stream(self._caller_stream)
656644

657645
with torch.cuda.stream(self._engine_stream):
658646
self.context.execute_async_v3(
659647
self._engine_stream.cuda_stream
660648
) # The OutputAllocator is called by execute_async_v3()
661649

662-
self._caller_stream.wait_stream(self._engine_stream)
663-
664650
with (
665651
torch.autograd.profiler.record_function(
666652
"PythonTorchTensorRTModule:ProcessOutputs"

0 commit comments

Comments
 (0)