diff --git a/glommio/src/sync/gate.rs b/glommio/src/sync/gate.rs index 704ee7a3b..e72cfdf6e 100644 --- a/glommio/src/sync/gate.rs +++ b/glommio/src/sync/gate.rs @@ -12,7 +12,7 @@ use crate::{ #[derive(Debug)] enum State { - Closing(LocalSender), + Closing(Vec>), Closed, Open, } @@ -109,7 +109,6 @@ impl Gate { } } -type PreviousWaiter = Option>; type CurrentClosure = LocalReceiver; #[derive(Debug)] @@ -143,14 +142,10 @@ impl GateInner { } async fn wait_for_closure( - waiter: Result, GlommioError<()>>, + waiter: Result, GlommioError<()>>, ) -> Result<(), GlommioError<()>> { - if let Some((waiter, previous_closer)) = waiter? { + if let Some(waiter) = waiter? { waiter.recv().await; - if let Some(previous_closer) = previous_closer { - // Previous channel may be dropped so ignore the result. - let _ = previous_closer.try_send(true); - } } Ok(()) @@ -161,26 +156,31 @@ impl GateInner { State::Open => { if self.count.get() != 0 { let (sender, receiver) = local_channel::new_bounded(1); - self.state.replace(State::Closing(sender)); - Self::wait_for_closure(Ok(Some((receiver, None)))) + self.state.replace(State::Closing(vec![sender])); + Self::wait_for_closure(Ok(Some(receiver))) } else { Self::wait_for_closure(Ok(None)) } } - State::Closing(previous_closer) => { + State::Closing(mut previous_closers) => { assert!( self.count.get() != 0, "If count is 0 then the state should have been marked as closed" ); assert!( - !previous_closer.is_full(), + !previous_closers.is_empty(), + "Should have at least one closer" + ); + assert!( + !previous_closers[0].is_full(), "Already notified that the gate is closed!" ); let (sender, receiver) = local_channel::new_bounded(1); - self.state.replace(State::Closing(sender)); + previous_closers.push(sender); + self.state.replace(State::Closing(previous_closers)); - Self::wait_for_closure(Ok(Some((receiver, Some(previous_closer))))) + Self::wait_for_closure(Ok(Some(receiver))) } State::Closed => Self::wait_for_closure(Err(GlommioError::Closed(ResourceType::Gate))), } @@ -195,8 +195,10 @@ impl GateInner { } pub fn notify_closed(&self) { - if let State::Closing(sender) = self.state.replace(State::Closed) { - sender.try_send(true).unwrap(); + if let State::Closing(senders) = self.state.replace(State::Closed) { + for sender in senders { + let _ = sender.try_send(true); + } } else { unreachable!("It should not happen!"); } @@ -208,7 +210,7 @@ mod tests { use super::*; use crate::sync::Semaphore; use crate::{enclose, timer::timeout, LocalExecutor}; - use futures::join; + use futures::{join, FutureExt}; use std::time::Duration; #[test] @@ -371,4 +373,43 @@ mod tests { ); }) } + + #[test] + fn test_closure_poll_order_irrelevant() { + LocalExecutor::default().run(async { + for poll_close1_first in [true, false] { + let gate = Gate::new(); + let pass = gate.enter().unwrap(); + let close1 = gate.close(); + let close2 = gate.close(); + std::mem::drop(pass); + let (close1, close2) = if poll_close1_first { + (close1, close2) + } else { + (close2, close1) + }; + close1 + .now_or_never() + .expect("Close future should be ready") + .expect("Close signal should have arrived successfully"); + close2 + .now_or_never() + .expect("Close future should be ready") + .expect("Close signal should have arrived successfully"); + } + }) + } + + #[test] + fn test_dropped_closure_still_unblocks() { + LocalExecutor::default().run(async { + let gate = Gate::new(); + let pass = gate.enter().unwrap(); + let close1 = gate.close(); + let close2 = gate.close(); + std::mem::drop(close2); + std::mem::drop(pass); + close1.await.expect("Closure signal should still arrive"); + }) + } }