@@ -81,23 +81,23 @@ impl<G: GenServer> GenServerHandle<G> {
8181 }
8282}
8383
84- pub enum GenServerInMsg < A : GenServer > {
84+ pub enum GenServerInMsg < G : GenServer > {
8585 Call {
86- sender : oneshot:: Sender < Result < A :: OutMsg , GenServerError > > ,
87- message : A :: CallMsg ,
86+ sender : oneshot:: Sender < Result < G :: OutMsg , GenServerError > > ,
87+ message : G :: CallMsg ,
8888 } ,
8989 Cast {
90- message : A :: CastMsg ,
90+ message : G :: CastMsg ,
9191 } ,
9292}
9393
94- pub enum CallResponse < U > {
95- Reply ( U ) ,
96- Stop ( U ) ,
94+ pub enum CallResponse < G : GenServer > {
95+ Reply ( G :: State , G :: OutMsg ) ,
96+ Stop ( G :: OutMsg ) ,
9797}
9898
99- pub enum CastResponse {
100- NoReply ,
99+ pub enum CastResponse < G : GenServer > {
100+ NoReply ( G :: State ) ,
101101 Stop ,
102102}
103103
@@ -177,74 +177,81 @@ where
177177 & mut self ,
178178 handle : & GenServerHandle < Self > ,
179179 rx : & mut mpsc:: Receiver < GenServerInMsg < Self > > ,
180- mut state : Self :: State ,
180+ state : Self :: State ,
181181 ) -> impl std:: future:: Future < Output = Result < ( Self :: State , bool ) , GenServerError > > + Send {
182182 async move {
183183 let message = rx. recv ( ) . await ;
184184
185185 // Save current state in case of a rollback
186186 let state_clone = state. clone ( ) ;
187187
188- let ( keep_running, error ) = match message {
188+ let ( keep_running, new_state ) = match message {
189189 Some ( GenServerInMsg :: Call { sender, message } ) => {
190- let ( keep_running, error , response) =
191- match AssertUnwindSafe ( self . handle_call ( message, handle, & mut state) )
190+ let ( keep_running, new_state , response) =
191+ match AssertUnwindSafe ( self . handle_call ( message, handle, state) )
192192 . catch_unwind ( )
193193 . await
194194 {
195195 Ok ( response) => match response {
196- CallResponse :: Reply ( response) => ( true , None , Ok ( response) ) ,
197- CallResponse :: Stop ( response) => ( false , None , Ok ( response) ) ,
196+ CallResponse :: Reply ( new_state, response) => {
197+ ( true , new_state, Ok ( response) )
198+ }
199+ CallResponse :: Stop ( response) => ( false , state_clone, Ok ( response) ) ,
198200 } ,
199- Err ( error) => ( true , Some ( error) , Err ( GenServerError :: Callback ) ) ,
201+ Err ( error) => {
202+ tracing:: trace!(
203+ "Error in callback, reverting state - Error: '{error:?}'"
204+ ) ;
205+ ( true , state_clone, Err ( GenServerError :: Callback ) )
206+ }
200207 } ;
201208 // Send response back
202209 if sender. send ( response) . is_err ( ) {
203210 tracing:: trace!(
204211 "GenServer failed to send response back, client must have died"
205212 )
206213 } ;
207- ( keep_running, error )
214+ ( keep_running, new_state )
208215 }
209216 Some ( GenServerInMsg :: Cast { message } ) => {
210- match AssertUnwindSafe ( self . handle_cast ( message, handle, & mut state) )
217+ match AssertUnwindSafe ( self . handle_cast ( message, handle, state) )
211218 . catch_unwind ( )
212219 . await
213220 {
214221 Ok ( response) => match response {
215- CastResponse :: NoReply => ( true , None ) ,
216- CastResponse :: Stop => ( false , None ) ,
222+ CastResponse :: NoReply ( new_state ) => ( true , new_state ) ,
223+ CastResponse :: Stop => ( false , state_clone ) ,
217224 } ,
218- Err ( error) => ( true , Some ( error) ) ,
225+ Err ( error) => {
226+ tracing:: trace!(
227+ "Error in callback, reverting state - Error: '{error:?}'"
228+ ) ;
229+ ( true , state_clone)
230+ }
219231 }
220232 }
221233 None => {
222234 // Channel has been closed; won't receive further messages. Stop the server.
223- ( false , None )
235+ ( false , state )
224236 }
225237 } ;
226- if let Some ( error) = error {
227- tracing:: trace!( "Error in callback, reverting state - Error: '{error:?}'" ) ;
228- // Restore initial state (ie. dismiss any change)
229- state = state_clone;
230- } ;
231- Ok ( ( state, keep_running) )
238+ Ok ( ( new_state, keep_running) )
232239 }
233240 }
234241
235242 fn handle_call (
236243 & mut self ,
237244 message : Self :: CallMsg ,
238245 handle : & GenServerHandle < Self > ,
239- state : & mut Self :: State ,
240- ) -> impl std:: future:: Future < Output = CallResponse < Self :: OutMsg > > + Send ;
246+ state : Self :: State ,
247+ ) -> impl std:: future:: Future < Output = CallResponse < Self > > + Send ;
241248
242249 fn handle_cast (
243250 & mut self ,
244251 message : Self :: CastMsg ,
245252 handle : & GenServerHandle < Self > ,
246- state : & mut Self :: State ,
247- ) -> impl std:: future:: Future < Output = CastResponse > + Send ;
253+ state : Self :: State ,
254+ ) -> impl std:: future:: Future < Output = CastResponse < Self > > + Send ;
248255}
249256
250257#[ cfg( test) ]
@@ -279,17 +286,17 @@ mod tests {
279286 & mut self ,
280287 _: Self :: CallMsg ,
281288 _: & GenServerHandle < Self > ,
282- _: & mut Self :: State ,
283- ) -> CallResponse < Self :: OutMsg > {
289+ _: Self :: State ,
290+ ) -> CallResponse < Self > {
284291 CallResponse :: Stop ( ( ) )
285292 }
286293
287294 async fn handle_cast (
288295 & mut self ,
289296 _: Self :: CastMsg ,
290297 _: & GenServerHandle < Self > ,
291- _: & mut Self :: State ,
292- ) -> CastResponse {
298+ _: Self :: State ,
299+ ) -> CastResponse < Self > {
293300 rt:: sleep ( Duration :: from_millis ( 20 ) ) . await ;
294301 thread:: sleep ( Duration :: from_secs ( 2 ) ) ;
295302 CastResponse :: Stop
@@ -318,10 +325,13 @@ mod tests {
318325 & mut self ,
319326 message : Self :: CallMsg ,
320327 _: & GenServerHandle < Self > ,
321- state : & mut Self :: State ,
322- ) -> CallResponse < Self :: OutMsg > {
328+ state : Self :: State ,
329+ ) -> CallResponse < Self > {
323330 match message {
324- InMessage :: GetCount => CallResponse :: Reply ( OutMsg :: Count ( state. count ) ) ,
331+ InMessage :: GetCount => {
332+ let count = state. count ;
333+ CallResponse :: Reply ( state, OutMsg :: Count ( count) )
334+ }
325335 InMessage :: Stop => CallResponse :: Stop ( OutMsg :: Count ( state. count ) ) ,
326336 }
327337 }
@@ -330,12 +340,12 @@ mod tests {
330340 & mut self ,
331341 _: Self :: CastMsg ,
332342 handle : & GenServerHandle < Self > ,
333- state : & mut Self :: State ,
334- ) -> CastResponse {
343+ mut state : Self :: State ,
344+ ) -> CastResponse < Self > {
335345 state. count += 1 ;
336346 println ! ( "{:?}: good still alive" , thread:: current( ) . id( ) ) ;
337347 send_after ( Duration :: from_millis ( 100 ) , handle. to_owned ( ) , ( ) ) ;
338- CastResponse :: NoReply
348+ CastResponse :: NoReply ( state )
339349 }
340350 }
341351
0 commit comments