@@ -118,13 +118,13 @@ static nb::ndarray<> dlpack(nb::handle_t<ArrayBase> h, bool force_cpu, nb::handl
118118 // https://data-apis.org/array-api/latest/API_specification/generated/array_api.array.__dlpack__.html
119119 /*
120120 stream = -1 request producer to perform no synchronization
121- stream = 0 is ambiguous
121+ stream = 0 is ambiguous (could mean either None, 1, or 2)
122122 stream = 1 or None is the legacy default stream
123123 stream = 2 is the per-thread default stream
124124 stream > 2 is a CUDA handle to the consumer's stream
125125 */
126126 if (!stream.is_none () && !stream.equal (nb::int_ (-1 )) && !stream.equal (nb::int_ (1 ))) {
127- if (stream.equal (nb::int_ (0 )))
127+ if (stream.equal (nb::int_ (0 )) || stream. equal ( nb::int_ ( 2 )) )
128128 jit_sync_thread ();
129129 else {
130130 uintptr_t stream_handle;
@@ -133,6 +133,7 @@ static nb::ndarray<> dlpack(nb::handle_t<ArrayBase> h, bool force_cpu, nb::handl
133133 jit_cuda_sync_stream (stream_handle);
134134 }
135135 }
136+
136137 } else {
137138 jit_sync_thread ();
138139 }
@@ -265,6 +266,9 @@ void export_dlpack(nb::module_ &) {
265266 .def (" tf" ,
266267 [](nb::handle_t <ArrayBase> h) {
267268 nb::module_ tf = nb::module_::import_ (" tensorflow.experimental.dlpack" );
268- return tf.attr (" from_dlpack" )(dlpack (h, false ));
269+ // TensorFlow uses non-default streams for compute and data transfer, so
270+ // we must synchronize on the stream used by DrJit (producer) before
271+ // proceeding with TF.
272+ return tf.attr (" from_dlpack" )(dlpack (h, /* force_cpu */ false , /* stream */ nb::int_ (2 )));
269273 }, doc_tf);
270274}
0 commit comments