Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 58 additions & 17 deletions glommio/src/sync/gate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{

#[derive(Debug)]
enum State {
Closing(LocalSender<bool>),
Closing(Vec<LocalSender<bool>>),
Closed,
Open,
}
Expand Down Expand Up @@ -109,7 +109,6 @@ impl Gate {
}
}

type PreviousWaiter = Option<LocalSender<bool>>;
type CurrentClosure = LocalReceiver<bool>;

#[derive(Debug)]
Expand Down Expand Up @@ -143,14 +142,10 @@ impl GateInner {
}

async fn wait_for_closure(
waiter: Result<Option<(CurrentClosure, PreviousWaiter)>, GlommioError<()>>,
waiter: Result<Option<CurrentClosure>, GlommioError<()>>,
) -> Result<(), GlommioError<()>> {
if let Some((waiter, previous_closer)) = waiter? {
if let Some(waiter) = waiter? {
waiter.recv().await;
if let Some(previous_closer) = previous_closer {
// Previous channel may be dropped so ignore the result.
let _ = previous_closer.try_send(true);
}
}

Ok(())
Expand All @@ -161,26 +156,31 @@ impl GateInner {
State::Open => {
if self.count.get() != 0 {
let (sender, receiver) = local_channel::new_bounded(1);
self.state.replace(State::Closing(sender));
Self::wait_for_closure(Ok(Some((receiver, None))))
self.state.replace(State::Closing(vec![sender]));
Self::wait_for_closure(Ok(Some(receiver)))
} else {
Self::wait_for_closure(Ok(None))
}
}
State::Closing(previous_closer) => {
State::Closing(mut previous_closers) => {
assert!(
self.count.get() != 0,
"If count is 0 then the state should have been marked as closed"
);
assert!(
!previous_closer.is_full(),
!previous_closers.is_empty(),
"Should have at least one closer"
);
assert!(
!previous_closers[0].is_full(),
"Already notified that the gate is closed!"
);

let (sender, receiver) = local_channel::new_bounded(1);
self.state.replace(State::Closing(sender));
previous_closers.push(sender);
self.state.replace(State::Closing(previous_closers));

Self::wait_for_closure(Ok(Some((receiver, Some(previous_closer)))))
Self::wait_for_closure(Ok(Some(receiver)))
}
State::Closed => Self::wait_for_closure(Err(GlommioError::Closed(ResourceType::Gate))),
}
Expand All @@ -195,8 +195,10 @@ impl GateInner {
}

pub fn notify_closed(&self) {
if let State::Closing(sender) = self.state.replace(State::Closed) {
sender.try_send(true).unwrap();
if let State::Closing(senders) = self.state.replace(State::Closed) {
for sender in senders {
let _ = sender.try_send(true);
}
} else {
unreachable!("It should not happen!");
}
Expand All @@ -208,7 +210,7 @@ mod tests {
use super::*;
use crate::sync::Semaphore;
use crate::{enclose, timer::timeout, LocalExecutor};
use futures::join;
use futures::{join, FutureExt};
use std::time::Duration;

#[test]
Expand Down Expand Up @@ -371,4 +373,43 @@ mod tests {
);
})
}

#[test]
fn test_closure_poll_order_irrelevant() {
LocalExecutor::default().run(async {
for poll_close1_first in [true, false] {
let gate = Gate::new();
let pass = gate.enter().unwrap();
let close1 = gate.close();
let close2 = gate.close();
std::mem::drop(pass);
let (close1, close2) = if poll_close1_first {
(close1, close2)
} else {
(close2, close1)
};
close1
.now_or_never()
.expect("Close future should be ready")
.expect("Close signal should have arrived successfully");
close2
.now_or_never()
.expect("Close future should be ready")
.expect("Close signal should have arrived successfully");
}
})
}

#[test]
fn test_dropped_closure_still_unblocks() {
LocalExecutor::default().run(async {
let gate = Gate::new();
let pass = gate.enter().unwrap();
let close1 = gate.close();
let close2 = gate.close();
std::mem::drop(close2);
std::mem::drop(pass);
close1.await.expect("Closure signal should still arrive");
})
}
}
Loading