@@ -165,6 +165,8 @@ namespace nvexec {
165165 }
166166 };
167167
168+ struct stream_scheduler ;
169+
168170 struct context_state_t {
169171 std::pmr::memory_resource* pinned_resource_{nullptr };
170172 std::pmr::memory_resource* managed_resource_{nullptr };
@@ -195,9 +197,9 @@ namespace nvexec {
195197 void return_stream (cudaStream_t stream) {
196198 stream_pools_->return_stream (stream, priority_);
197199 }
198- };
199200
200- struct stream_scheduler ;
201+ stream_scheduler make_stream_scheduler () const noexcept ;
202+ };
201203
202204 struct stream_sender_base {
203205 using is_sender = void ;
@@ -265,6 +267,10 @@ namespace nvexec {
265267 stream_provider_t * operator ()(const Env& env) const noexcept {
266268 return tag_invoke (get_stream_provider_t {}, env);
267269 }
270+
271+ friend constexpr bool tag_invoke (forwarding_query_t , const get_stream_provider_t &) noexcept {
272+ return true ;
273+ }
268274 };
269275
270276 template <class ... Ts>
@@ -308,7 +314,10 @@ namespace nvexec {
308314 using variant_storage_t = //
309315 __minvoke< __minvoke<
310316 __mfold_right<
311- __mbind_front_q<stream_storage_impl::variant, ::cuda::std::tuple<set_noop>>,
317+ __mbind_front_q<
318+ stream_storage_impl::variant,
319+ ::cuda::std::tuple<set_noop>,
320+ ::cuda::std::tuple<set_error_t , cudaError_t>>,
312321 __mbind_front_q<stream_storage_impl::__bind_completions_t , _Sender, _Env>>,
313322 set_value_t ,
314323 set_error_t ,
@@ -330,7 +339,21 @@ namespace nvexec {
330339
331340 template <class BaseEnv >
332341 auto make_stream_env (BaseEnv&& base_env, stream_provider_t * stream_provider) noexcept {
333- return __join_env (__mkprop (get_stream_provider, stream_provider), (BaseEnv&&) base_env);
342+ return __join_env (
343+ __env::__env_fn{
344+ [stream_provider]<__one_of<get_stream_provider_t , get_scheduler_t , get_domain_t > Tag>(
345+ Tag) noexcept {
346+ __mfront<stream_provider_t , Tag>* str_provider = stream_provider;
347+ if constexpr (same_as<Tag, get_stream_provider_t >) {
348+ return str_provider;
349+ } else if constexpr (same_as<Tag, get_scheduler_t >) {
350+ return str_provider->context_ .make_stream_scheduler ();
351+ } else {
352+ return get_domain (str_provider->context_ .make_stream_scheduler ());
353+ }
354+ STDEXEC_UNREACHABLE ();
355+ }},
356+ (BaseEnv&&) base_env);
334357 }
335358
336359 template <class BaseEnv >
@@ -370,6 +393,10 @@ namespace nvexec {
370393 stream_sender_base,
371394 __decay_t <transform_sender_result_t <__env_domain_of_t <E>, S, E>>);
372395
396+ struct stream_scheduler ;
397+ template <class = stream_scheduler>
398+ struct stream_domain ;
399+
373400 template <class R >
374401 concept stream_receiver = //
375402 receiver<R> && //
@@ -427,8 +454,8 @@ namespace nvexec {
427454 };
428455 };
429456
430- template <class Receiver , class ... As , class Tag >
431- __launch_bounds__ (1 ) __global__ void continuation_kernel(Receiver rcvr, Tag, As... as) {
457+ template <class Receiver , class Tag , class ... As >
458+ __launch_bounds__ (1 ) __global__ void continuation_kernel(Receiver rcvr, As... as) {
432459 static_assert (trivially_copyable<Receiver, Tag, As...>);
433460 Tag ()(::cuda::std::move (rcvr), static_cast <As&&>(as)...);
434461 }
@@ -552,7 +579,7 @@ namespace nvexec {
552579 if (cudaCpuDeviceId == device_id) {
553580 ptr->~T ();
554581 } else {
555- destructor_kernel<<<1 , 1 , 0 , stream>>> (ptr);
582+ STDEXEC_STREAM_DETAIL_NS:: destructor_kernel<<<1 , 1 , 0 , stream>>> (ptr);
556583
557584 // TODO Bury all the memory associated with the stream provider and then
558585 // deallocate the memory
@@ -573,9 +600,9 @@ namespace nvexec {
573600 if constexpr (stream_receiver<outer_receiver_t >) {
574601 set_error ((outer_receiver_t &&) rcvr_, (cudaError_t&&) status);
575602 } else {
576- // pass a cudaError_t by value :
577- continuation_kernel<outer_receiver_t , Error>
578- <<<1 , 1 , 0 , get_stream()>>> ((outer_receiver_t &&) rcvr_, set_error_t (), status);
603+ STDEXEC_STREAM_DETAIL_NS: :
604+ continuation_kernel<outer_receiver_t , set_error_t , cudaError_t> // by value
605+ <<<1 , 1 , 0 , get_stream()>>> ((outer_receiver_t &&) rcvr_, status);
579606 }
580607 }
581608
@@ -584,8 +611,9 @@ namespace nvexec {
584611 if constexpr (stream_receiver<outer_receiver_t >) {
585612 Tag ()((outer_receiver_t &&) rcvr_, (As&&) as...);
586613 } else {
587- continuation_kernel<outer_receiver_t , As&&...> // by reference
588- <<<1 , 1 , 0 , get_stream()>>> ((outer_receiver_t &&) rcvr_, Tag (), (As&&) as...);
614+ STDEXEC_STREAM_DETAIL_NS::
615+ continuation_kernel<outer_receiver_t , Tag, As&&...> // by reference
616+ <<<1 , 1 , 0 , get_stream()>>> ((outer_receiver_t &&) rcvr_, (As&&) as...);
589617 }
590618 }
591619 };
0 commit comments