Skip to content

Commit 28d2db5

Browse files
committed
Moved state ownership into and
1 parent 3210230 commit 28d2db5

File tree

11 files changed

+233
-170
lines changed

11 files changed

+233
-170
lines changed

concurrency/src/tasks/gen_server.rs

Lines changed: 52 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -81,23 +81,23 @@ impl<G: GenServer> GenServerHandle<G> {
8181
}
8282
}
8383

84-
pub enum GenServerInMsg<A: GenServer> {
84+
pub enum GenServerInMsg<G: GenServer> {
8585
Call {
86-
sender: oneshot::Sender<Result<A::OutMsg, GenServerError>>,
87-
message: A::CallMsg,
86+
sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
87+
message: G::CallMsg,
8888
},
8989
Cast {
90-
message: A::CastMsg,
90+
message: G::CastMsg,
9191
},
9292
}
9393

94-
pub enum CallResponse<U> {
95-
Reply(U),
96-
Stop(U),
94+
pub enum CallResponse<G: GenServer> {
95+
Reply(G::State, G::OutMsg),
96+
Stop(G::OutMsg),
9797
}
9898

99-
pub enum CastResponse {
100-
NoReply,
99+
pub enum CastResponse<G: GenServer> {
100+
NoReply(G::State),
101101
Stop,
102102
}
103103

@@ -177,74 +177,81 @@ where
177177
&mut self,
178178
handle: &GenServerHandle<Self>,
179179
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
180-
mut state: Self::State,
180+
state: Self::State,
181181
) -> impl std::future::Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
182182
async move {
183183
let message = rx.recv().await;
184184

185185
// Save current state in case of a rollback
186186
let state_clone = state.clone();
187187

188-
let (keep_running, error) = match message {
188+
let (keep_running, new_state) = match message {
189189
Some(GenServerInMsg::Call { sender, message }) => {
190-
let (keep_running, error, response) =
191-
match AssertUnwindSafe(self.handle_call(message, handle, &mut state))
190+
let (keep_running, new_state, response) =
191+
match AssertUnwindSafe(self.handle_call(message, handle, state))
192192
.catch_unwind()
193193
.await
194194
{
195195
Ok(response) => match response {
196-
CallResponse::Reply(response) => (true, None, Ok(response)),
197-
CallResponse::Stop(response) => (false, None, Ok(response)),
196+
CallResponse::Reply(new_state, response) => {
197+
(true, new_state, Ok(response))
198+
}
199+
CallResponse::Stop(response) => (false, state_clone, Ok(response)),
198200
},
199-
Err(error) => (true, Some(error), Err(GenServerError::Callback)),
201+
Err(error) => {
202+
tracing::trace!(
203+
"Error in callback, reverting state - Error: '{error:?}'"
204+
);
205+
(true, state_clone, Err(GenServerError::Callback))
206+
}
200207
};
201208
// Send response back
202209
if sender.send(response).is_err() {
203210
tracing::trace!(
204211
"GenServer failed to send response back, client must have died"
205212
)
206213
};
207-
(keep_running, error)
214+
(keep_running, new_state)
208215
}
209216
Some(GenServerInMsg::Cast { message }) => {
210-
match AssertUnwindSafe(self.handle_cast(message, handle, &mut state))
217+
match AssertUnwindSafe(self.handle_cast(message, handle, state))
211218
.catch_unwind()
212219
.await
213220
{
214221
Ok(response) => match response {
215-
CastResponse::NoReply => (true, None),
216-
CastResponse::Stop => (false, None),
222+
CastResponse::NoReply(new_state) => (true, new_state),
223+
CastResponse::Stop => (false, state_clone),
217224
},
218-
Err(error) => (true, Some(error)),
225+
Err(error) => {
226+
tracing::trace!(
227+
"Error in callback, reverting state - Error: '{error:?}'"
228+
);
229+
(true, state_clone)
230+
}
219231
}
220232
}
221233
None => {
222234
// Channel has been closed; won't receive further messages. Stop the server.
223-
(false, None)
235+
(false, state)
224236
}
225237
};
226-
if let Some(error) = error {
227-
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
228-
// Restore initial state (ie. dismiss any change)
229-
state = state_clone;
230-
};
231-
Ok((state, keep_running))
238+
Ok((new_state, keep_running))
232239
}
233240
}
234241

235242
fn handle_call(
236243
&mut self,
237244
message: Self::CallMsg,
238245
handle: &GenServerHandle<Self>,
239-
state: &mut Self::State,
240-
) -> impl std::future::Future<Output = CallResponse<Self::OutMsg>> + Send;
246+
state: Self::State,
247+
) -> impl std::future::Future<Output = CallResponse<Self>> + Send;
241248

242249
fn handle_cast(
243250
&mut self,
244251
message: Self::CastMsg,
245252
handle: &GenServerHandle<Self>,
246-
state: &mut Self::State,
247-
) -> impl std::future::Future<Output = CastResponse> + Send;
253+
state: Self::State,
254+
) -> impl std::future::Future<Output = CastResponse<Self>> + Send;
248255
}
249256

250257
#[cfg(test)]
@@ -279,17 +286,17 @@ mod tests {
279286
&mut self,
280287
_: Self::CallMsg,
281288
_: &GenServerHandle<Self>,
282-
_: &mut Self::State,
283-
) -> CallResponse<Self::OutMsg> {
289+
_: Self::State,
290+
) -> CallResponse<Self> {
284291
CallResponse::Stop(())
285292
}
286293

287294
async fn handle_cast(
288295
&mut self,
289296
_: Self::CastMsg,
290297
_: &GenServerHandle<Self>,
291-
_: &mut Self::State,
292-
) -> CastResponse {
298+
_: Self::State,
299+
) -> CastResponse<Self> {
293300
rt::sleep(Duration::from_millis(20)).await;
294301
thread::sleep(Duration::from_secs(2));
295302
CastResponse::Stop
@@ -318,10 +325,13 @@ mod tests {
318325
&mut self,
319326
message: Self::CallMsg,
320327
_: &GenServerHandle<Self>,
321-
state: &mut Self::State,
322-
) -> CallResponse<Self::OutMsg> {
328+
state: Self::State,
329+
) -> CallResponse<Self> {
323330
match message {
324-
InMessage::GetCount => CallResponse::Reply(OutMsg::Count(state.count)),
331+
InMessage::GetCount => {
332+
let count = state.count;
333+
CallResponse::Reply(state, OutMsg::Count(count))
334+
}
325335
InMessage::Stop => CallResponse::Stop(OutMsg::Count(state.count)),
326336
}
327337
}
@@ -330,12 +340,12 @@ mod tests {
330340
&mut self,
331341
_: Self::CastMsg,
332342
handle: &GenServerHandle<Self>,
333-
state: &mut Self::State,
334-
) -> CastResponse {
343+
mut state: Self::State,
344+
) -> CastResponse<Self> {
335345
state.count += 1;
336346
println!("{:?}: good still alive", thread::current().id());
337347
send_after(Duration::from_millis(100), handle.to_owned(), ());
338-
CastResponse::NoReply
348+
CastResponse::NoReply(state)
339349
}
340350
}
341351

concurrency/src/threads/gen_server.rs

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -59,23 +59,23 @@ impl<G: GenServer> GenServerHandle<G> {
5959
}
6060
}
6161

62-
pub enum GenServerInMsg<A: GenServer> {
62+
pub enum GenServerInMsg<G: GenServer> {
6363
Call {
64-
sender: oneshot::Sender<Result<A::OutMsg, GenServerError>>,
65-
message: A::CallMsg,
64+
sender: oneshot::Sender<Result<G::OutMsg, GenServerError>>,
65+
message: G::CallMsg,
6666
},
6767
Cast {
68-
message: A::CastMsg,
68+
message: G::CastMsg,
6969
},
7070
}
7171

72-
pub enum CallResponse<U> {
73-
Reply(U),
74-
Stop(U),
72+
pub enum CallResponse<G: GenServer> {
73+
Reply(G::State, G::OutMsg),
74+
Stop(G::OutMsg),
7575
}
7676

77-
pub enum CastResponse {
78-
NoReply,
77+
pub enum CastResponse<G: GenServer> {
78+
NoReply(G::State),
7979
Stop,
8080
}
8181

@@ -148,65 +148,71 @@ where
148148
&mut self,
149149
handle: &GenServerHandle<Self>,
150150
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
151-
mut state: Self::State,
151+
state: Self::State,
152152
) -> Result<(Self::State, bool), GenServerError> {
153153
let message = rx.recv().ok();
154154

155155
// Save current state in case of a rollback
156156
let state_clone = state.clone();
157157

158-
let (keep_running, error) = match message {
158+
let (keep_running, new_state) = match message {
159159
Some(GenServerInMsg::Call { sender, message }) => {
160-
let (keep_running, error, response) = match catch_unwind(AssertUnwindSafe(|| {
161-
self.handle_call(message, handle, &mut state)
162-
})) {
163-
Ok(response) => match response {
164-
CallResponse::Reply(response) => (true, None, Ok(response)),
165-
CallResponse::Stop(response) => (false, None, Ok(response)),
166-
},
167-
Err(error) => (true, Some(error), Err(GenServerError::Callback)),
168-
};
160+
let (keep_running, new_state, response) =
161+
match catch_unwind(AssertUnwindSafe(|| {
162+
self.handle_call(message, handle, state)
163+
})) {
164+
Ok(response) => match response {
165+
CallResponse::Reply(new_state, response) => {
166+
(true, new_state, Ok(response))
167+
}
168+
CallResponse::Stop(response) => (false, state_clone, Ok(response)),
169+
},
170+
Err(error) => {
171+
tracing::trace!(
172+
"Error in callback, reverting state - Error: '{error:?}'"
173+
);
174+
(true, state_clone, Err(GenServerError::Callback))
175+
}
176+
};
169177
// Send response back
170178
if sender.send(response).is_err() {
171179
tracing::trace!("GenServer failed to send response back, client must have died")
172180
};
173-
(keep_running, error)
181+
(keep_running, new_state)
174182
}
175183
Some(GenServerInMsg::Cast { message }) => {
176184
match catch_unwind(AssertUnwindSafe(|| {
177-
self.handle_cast(message, handle, &mut state)
185+
self.handle_cast(message, handle, state)
178186
})) {
179187
Ok(response) => match response {
180-
CastResponse::NoReply => (true, None),
181-
CastResponse::Stop => (false, None),
188+
CastResponse::NoReply(new_state) => (true, new_state),
189+
CastResponse::Stop => (false, state_clone),
182190
},
183-
Err(error) => (true, Some(error)),
191+
Err(error) => {
192+
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
193+
(true, state_clone)
194+
}
184195
}
185196
}
186197
None => {
187198
// Channel has been closed; won't receive further messages. Stop the server.
188-
(false, None)
199+
(false, state)
189200
}
190201
};
191-
if let Some(error) = error {
192-
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
193-
// Restore initial state (ie. dismiss any change)
194-
state = state_clone;
195-
};
196-
Ok((state, keep_running))
202+
Ok((new_state, keep_running))
197203
}
198204

199205
fn handle_call(
200206
&mut self,
201207
message: Self::CallMsg,
202208
handle: &GenServerHandle<Self>,
203-
state: &mut Self::State,
204-
) -> CallResponse<Self::OutMsg>;
209+
state: Self::State,
210+
) -> CallResponse<Self>;
205211

206212
fn handle_cast(
207213
&mut self,
208214
message: Self::CastMsg,
209215
handle: &GenServerHandle<Self>,
210-
state: &mut Self::State,
211-
) -> CastResponse;
216+
state: Self::State,
217+
) -> CastResponse<Self>;
212218
}

examples/bank/src/main.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,21 @@ fn main() {
7878
assert_eq!(
7979
result,
8080
Err(BankError::InsufficientBalance {
81-
who: joe,
81+
who: joe.clone(),
8282
amount: 25
8383
})
8484
);
8585

86+
let result = Bank::withdraw(&mut name_server, "Joe".to_string(), 25).await;
87+
tracing::info!("Withdraw result {result:?}");
88+
assert_eq!(
89+
result,
90+
Ok(BankOutMessage::WidrawOk {
91+
who: joe,
92+
amount: 0
93+
})
94+
);
95+
8696
let result = Bank::stop(&mut name_server).await;
8797
tracing::info!("Stop result {result:?}");
8898
assert_eq!(result, Ok(BankOutMessage::Stopped));

0 commit comments

Comments
 (0)