@@ -171,8 +171,7 @@ def __init__(
171171        self ._input_buffers : List [torch .Tensor ] =  []
172172        self ._output_buffers : List [torch .Tensor ] =  []
173173        self .cudagraph : Optional [torch .cuda .CUDAGraph ] =  None 
174-         self ._caller_stream : Optional [torch .cuda .Stream ] =  None 
175-         self ._engine_stream : Optional [torch .cuda .Stream ] =  None 
174+         self ._engine_stream : torch .cuda .Stream  =  torch .cuda .current_stream ()
176175        self .output_tensors : Optional [List [torch .Tensor ]] =  None 
177176        self .sync_stream  =  True 
178177
@@ -287,13 +286,7 @@ def setup_engine(self) -> None:
287286        ), f"TensorRT engine was not built to target current platform (target: { self .target_platform }  , current: { Platform .current_platform ()}  )" 
288287        # Stream handling: if the caller stream is the pytorch default stream, create a new engine stream 
289288        # otherwise, use the caller stream and disable stream synchronization 
290-         self ._caller_stream  =  torch .cuda .current_stream ()
291-         if  self ._caller_stream  ==  torch .cuda .default_stream ():
292-             self ._engine_stream  =  torch .cuda .Stream ()
293-             self .sync_stream  =  True 
294-         else :
295-             self ._engine_stream  =  self ._caller_stream 
296-             self .sync_stream  =  False 
289+         self ._engine_stream  =  torch .cuda .current_stream ()
297290
298291        self .initialized  =  True 
299292        runtime  =  trt .Runtime (TRT_LOGGER )
@@ -559,9 +552,6 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
559552                else  nullcontext ()
560553            ):
561554
562-                 if  self .sync_stream :
563-                     self ._engine_stream .wait_stream (self ._caller_stream )
564- 
565555                if  self .cudagraphs_enabled :
566556                    if  need_cudagraphs_record :
567557                        self .cudagraph  =  torch .cuda .CUDAGraph ()
@@ -587,10 +577,16 @@ def run_standard_execution() -> torch.Tensor | Tuple[torch.Tensor, ...]:
587577                    self .cudagraph .replay ()  # type: ignore 
588578
589579                else :
590-                     self . context . execute_async_v3 ( self . _engine_stream . cuda_stream ) 
580+                     import   warnings 
591581
592-                 if  self .sync_stream :
593-                     self ._caller_stream .wait_stream (self ._engine_stream )
582+                     with  warnings .catch_warnings ():
583+                         try :
584+                             self .context .execute_async_v3 (
585+                                 self ._engine_stream .cuda_stream 
586+                             )
587+                         except  Warning  as  e :
588+                             breakpoint ()
589+                             print ("warning ignored" )
594590
595591            if  self .use_pre_allocated_outputs :
596592                self .pre_allocated_outputs  =  self .create_output_tensors ()
@@ -645,22 +641,12 @@ def run_output_allocator() -> torch.Tensor | Tuple[torch.Tensor, ...]:
645641                if  self .profiling_enabled 
646642                else  nullcontext ()
647643            ):
648-                 self ._caller_stream  =  torch .cuda .current_stream ()
649-                 if  (
650-                     self ._engine_stream  ==  torch .cuda .default_stream ()
651-                     or  self ._engine_stream  is  None 
652-                 ):
653-                     self ._engine_stream  =  torch .cuda .Stream ()
654- 
655-                 self ._engine_stream .wait_stream (self ._caller_stream )
656644
657645                with  torch .cuda .stream (self ._engine_stream ):
658646                    self .context .execute_async_v3 (
659647                        self ._engine_stream .cuda_stream 
660648                    )  # The OutputAllocator is called by execute_async_v3() 
661649
662-                 self ._caller_stream .wait_stream (self ._engine_stream )
663- 
664650            with  (
665651                torch .autograd .profiler .record_function (
666652                    "PythonTorchTensorRTModule:ProcessOutputs" 