@@ -70,10 +70,38 @@ pub fn system(attr: TokenStream, item: TokenStream) -> TokenStream {
70
70
}
71
71
}
72
72
73
+ impl SystemTransform {
74
+ fn visit_stmts_mut ( & mut self , stmts : & mut Vec < Stmt > ) {
75
+ for stmt in stmts {
76
+ if let Stmt :: Expr ( ref mut expr) | Stmt :: Semi ( ref mut expr, _) = stmt {
77
+ self . visit_expr_mut ( expr) ;
78
+ }
79
+ }
80
+ }
81
+ }
82
+
73
83
/// Visits the AST and modifies the system function
74
84
impl VisitMut for SystemTransform {
75
85
// Modify the return instruction to return Result<Vec<u8>>
76
86
fn visit_expr_mut ( & mut self , expr : & mut Expr ) {
87
+ match expr {
88
+ Expr :: ForLoop ( for_loop_expr) => {
89
+ self . visit_stmts_mut ( & mut for_loop_expr. body . stmts ) ;
90
+ }
91
+ Expr :: Loop ( loop_expr) => {
92
+ self . visit_stmts_mut ( & mut loop_expr. body . stmts ) ;
93
+ }
94
+ Expr :: If ( if_expr) => {
95
+ self . visit_stmts_mut ( & mut if_expr. then_branch . stmts ) ;
96
+ if let Some ( ( _, else_expr) ) = & mut if_expr. else_branch {
97
+ self . visit_expr_mut ( else_expr) ;
98
+ }
99
+ }
100
+ Expr :: Block ( block_expr) => {
101
+ self . visit_stmts_mut ( & mut block_expr. block . stmts ) ;
102
+ }
103
+ _ => ( ) ,
104
+ }
77
105
if let Some ( inner_variable) = Self :: extract_inner_ok_expression ( expr) {
78
106
let new_return_expr: Expr = match inner_variable {
79
107
Expr :: Tuple ( tuple_expr) => {
@@ -88,7 +116,11 @@ impl VisitMut for SystemTransform {
88
116
}
89
117
}
90
118
} ;
91
- * expr = new_return_expr;
119
+ if let Expr :: Return ( return_expr) = expr {
120
+ return_expr. expr = Some ( Box :: new ( new_return_expr) ) ;
121
+ } else {
122
+ * expr = new_return_expr;
123
+ }
92
124
}
93
125
}
94
126
@@ -108,11 +140,7 @@ impl VisitMut for SystemTransform {
108
140
Self :: modify_fn_return_type ( item_fn, self . return_values ) ;
109
141
// Modify the return statement inside the function body
110
142
let block = & mut item_fn. block ;
111
- for stmt in & mut block. stmts {
112
- if let Stmt :: Expr ( ref mut expr) | Stmt :: Semi ( ref mut expr, _) = stmt {
113
- self . visit_expr_mut ( expr) ;
114
- }
115
- }
143
+ self . visit_stmts_mut ( & mut block. stmts ) ;
116
144
}
117
145
}
118
146
}
0 commit comments