11use core:: cell:: UnsafeCell ;
22use core:: fmt;
3- use core:: task:: Waker ;
3+ use core:: task:: { RawWaker , Waker } ;
44
55#[ cfg( not( feature = "portable-atomic" ) ) ]
66use core:: sync:: atomic:: AtomicUsize ;
@@ -10,12 +10,20 @@ use portable_atomic::AtomicUsize;
1010
1111use crate :: raw:: TaskVTable ;
1212use crate :: state:: * ;
13+ use crate :: utils:: abort;
1314use crate :: utils:: abort_on_panic;
1415
15- /// The header of a task.
16- ///
17- /// This header is stored in memory at the beginning of the heap-allocated task.
18- pub ( crate ) struct Header < M > {
16+ /// Actions to take upon calling [`Header::drop_waker`].
17+ pub ( crate ) enum DropWakerAction {
18+ /// Re-schedule the task
19+ Schedule ,
20+ /// Destroy the task.
21+ Destroy ,
22+ /// Do nothing.
23+ None ,
24+ }
25+
26+ pub ( crate ) struct Header {
1927 /// Current state of the task.
2028 ///
2129 /// Contains flags representing the current state and the reference count.
@@ -32,17 +40,12 @@ pub(crate) struct Header<M> {
3240 /// methods necessary for bookkeeping the heap-allocated task.
3341 pub ( crate ) vtable : & ' static TaskVTable ,
3442
35- /// Metadata associated with the task.
36- ///
37- /// This metadata may be provided to the user.
38- pub ( crate ) metadata : M ,
39-
4043 /// Whether or not a panic that occurs in the task should be propagated.
4144 #[ cfg( feature = "std" ) ]
4245 pub ( crate ) propagate_panic : bool ,
4346}
4447
45- impl < M > Header < M > {
48+ impl Header {
4649 /// Notifies the awaiter blocked on this task.
4750 ///
4851 /// If the awaiter is the same as the current waker, it will not be notified.
@@ -157,11 +160,69 @@ impl<M> Header<M> {
157160 abort_on_panic ( || w. wake ( ) ) ;
158161 }
159162 }
163+
164+ /// Clones a waker.
165+ pub ( crate ) unsafe fn clone_waker ( ptr : * const ( ) ) -> RawWaker {
166+ let header = ptr as * const Header ;
167+
168+ // Increment the reference count. With any kind of reference-counted data structure,
169+ // relaxed ordering is appropriate when incrementing the counter.
170+ let state = ( * header) . state . fetch_add ( REFERENCE , Ordering :: Relaxed ) ;
171+
172+ // If the reference count overflowed, abort.
173+ if state > isize:: MAX as usize {
174+ abort ( ) ;
175+ }
176+
177+ RawWaker :: new ( ptr, ( * header) . vtable . raw_waker_vtable )
178+ }
179+
180+ #[ inline( never) ]
181+ pub ( crate ) unsafe fn drop_waker ( ptr : * const ( ) ) -> DropWakerAction {
182+ let header = ptr as * const Header ;
183+
184+ // Decrement the reference count.
185+ let new = ( * header) . state . fetch_sub ( REFERENCE , Ordering :: AcqRel ) - REFERENCE ;
186+
187+ // If this was the last reference to the task and the `Task` has been dropped too,
188+ // then we need to decide how to destroy the task.
189+ if new & !( REFERENCE - 1 ) == 0 && new & TASK == 0 {
190+ if new & ( COMPLETED | CLOSED ) == 0 {
191+ // If the task was not completed nor closed, close it and schedule one more time so
192+ // that its future gets dropped by the executor.
193+ ( * header)
194+ . state
195+ . store ( SCHEDULED | CLOSED | REFERENCE , Ordering :: Release ) ;
196+ DropWakerAction :: Schedule
197+ } else {
198+ // Otherwise, destroy the task right away.
199+ DropWakerAction :: Destroy
200+ }
201+ } else {
202+ DropWakerAction :: None
203+ }
204+ }
205+ }
206+
207+ // SAFETY: repr(C) is explicitly used here so that casts between `Header` and `HeaderWithMetadata`
208+ // can be done safely without additional offsets.
209+ //
210+ /// The header of a task.
211+ ///
212+ /// This header is stored in memory at the beginning of the heap-allocated task.
213+ #[ repr( C ) ]
214+ pub ( crate ) struct HeaderWithMetadata < M > {
215+ pub ( crate ) header : Header ,
216+
217+ /// Metadata associated with the task.
218+ ///
219+ /// This metadata may be provided to the user.
220+ pub ( crate ) metadata : M ,
160221}
161222
162- impl < M : fmt:: Debug > fmt:: Debug for Header < M > {
223+ impl < M : fmt:: Debug > fmt:: Debug for HeaderWithMetadata < M > {
163224 fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
164- let state = self . state . load ( Ordering :: SeqCst ) ;
225+ let state = self . header . state . load ( Ordering :: SeqCst ) ;
165226
166227 f. debug_struct ( "Header" )
167228 . field ( "scheduled" , & ( state & SCHEDULED != 0 ) )
0 commit comments