Skip to content

Commit 3210230

Browse files
committed
Improve state ownership along GenServer code
1 parent 675c8b7 commit 3210230

File tree

2 files changed

+25
-21
lines changed

2 files changed

+25
-21
lines changed

concurrency/src/tasks/gen_server.rs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,8 @@ where
134134
) -> impl Future<Output = Result<(), GenServerError>> + Send {
135135
async {
136136
match self.init(handle, state).await {
137-
Ok(mut new_state) => {
138-
self.main_loop(handle, rx, &mut new_state).await?;
137+
Ok(new_state) => {
138+
self.main_loop(handle, rx, new_state).await?;
139139
Ok(())
140140
}
141141
Err(err) => {
@@ -158,13 +158,15 @@ where
158158
&mut self,
159159
handle: &GenServerHandle<Self>,
160160
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
161-
state: &mut Self::State,
161+
mut state: Self::State,
162162
) -> impl Future<Output = Result<(), GenServerError>> + Send {
163163
async {
164164
loop {
165-
if !self.receive(handle, rx, state).await? {
165+
let (new_state, cont) = self.receive(handle, rx, state).await?;
166+
if !cont {
166167
break;
167168
}
169+
state = new_state;
168170
}
169171
tracing::trace!("Stopping GenServer");
170172
Ok(())
@@ -175,9 +177,9 @@ where
175177
&mut self,
176178
handle: &GenServerHandle<Self>,
177179
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
178-
state: &mut Self::State,
179-
) -> impl std::future::Future<Output = Result<bool, GenServerError>> + Send {
180-
async {
180+
mut state: Self::State,
181+
) -> impl std::future::Future<Output = Result<(Self::State, bool), GenServerError>> + Send {
182+
async move {
181183
let message = rx.recv().await;
182184

183185
// Save current state in case of a rollback
@@ -186,7 +188,7 @@ where
186188
let (keep_running, error) = match message {
187189
Some(GenServerInMsg::Call { sender, message }) => {
188190
let (keep_running, error, response) =
189-
match AssertUnwindSafe(self.handle_call(message, handle, state))
191+
match AssertUnwindSafe(self.handle_call(message, handle, &mut state))
190192
.catch_unwind()
191193
.await
192194
{
@@ -205,7 +207,7 @@ where
205207
(keep_running, error)
206208
}
207209
Some(GenServerInMsg::Cast { message }) => {
208-
match AssertUnwindSafe(self.handle_cast(message, handle, state))
210+
match AssertUnwindSafe(self.handle_cast(message, handle, &mut state))
209211
.catch_unwind()
210212
.await
211213
{
@@ -224,9 +226,9 @@ where
224226
if let Some(error) = error {
225227
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
226228
// Restore initial state (ie. dismiss any change)
227-
*state = state_clone;
229+
state = state_clone;
228230
};
229-
Ok(keep_running)
231+
Ok((state, keep_running))
230232
}
231233
}
232234

concurrency/src/threads/gen_server.rs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ where
108108
state: Self::State,
109109
) -> Result<(), GenServerError> {
110110
match self.init(handle, state) {
111-
Ok(mut new_state) => {
112-
self.main_loop(handle, rx, &mut new_state)?;
111+
Ok(new_state) => {
112+
self.main_loop(handle, rx, new_state)?;
113113
Ok(())
114114
}
115115
Err(err) => {
@@ -131,12 +131,14 @@ where
131131
&mut self,
132132
handle: &GenServerHandle<Self>,
133133
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
134-
state: &mut Self::State,
134+
mut state: Self::State,
135135
) -> Result<(), GenServerError> {
136136
loop {
137-
if !self.receive(handle, rx, state)? {
137+
let (new_state, cont) = self.receive(handle, rx, state)?;
138+
if !cont {
138139
break;
139140
}
141+
state = new_state;
140142
}
141143
tracing::trace!("Stopping GenServer");
142144
Ok(())
@@ -146,8 +148,8 @@ where
146148
&mut self,
147149
handle: &GenServerHandle<Self>,
148150
rx: &mut mpsc::Receiver<GenServerInMsg<Self>>,
149-
state: &mut Self::State,
150-
) -> Result<bool, GenServerError> {
151+
mut state: Self::State,
152+
) -> Result<(Self::State, bool), GenServerError> {
151153
let message = rx.recv().ok();
152154

153155
// Save current state in case of a rollback
@@ -156,7 +158,7 @@ where
156158
let (keep_running, error) = match message {
157159
Some(GenServerInMsg::Call { sender, message }) => {
158160
let (keep_running, error, response) = match catch_unwind(AssertUnwindSafe(|| {
159-
self.handle_call(message, handle, state)
161+
self.handle_call(message, handle, &mut state)
160162
})) {
161163
Ok(response) => match response {
162164
CallResponse::Reply(response) => (true, None, Ok(response)),
@@ -172,7 +174,7 @@ where
172174
}
173175
Some(GenServerInMsg::Cast { message }) => {
174176
match catch_unwind(AssertUnwindSafe(|| {
175-
self.handle_cast(message, handle, state)
177+
self.handle_cast(message, handle, &mut state)
176178
})) {
177179
Ok(response) => match response {
178180
CastResponse::NoReply => (true, None),
@@ -189,9 +191,9 @@ where
189191
if let Some(error) = error {
190192
tracing::trace!("Error in callback, reverting state - Error: '{error:?}'");
191193
// Restore initial state (ie. dismiss any change)
192-
*state = state_clone;
194+
state = state_clone;
193195
};
194-
Ok(keep_running)
196+
Ok((state, keep_running))
195197
}
196198

197199
fn handle_call(

0 commit comments

Comments
 (0)