@@ -181,6 +181,30 @@ def _get_all_first_order_variables(indict) -> Iterable[str]:
181181
182182 return variable_names
183183
184+ def symbol_appears_in_any_expr(param_name, solver_json) -> bool:
185+ if "update_expressions" in solver_json.keys():
186+ for sym, expr in solver_json["update_expressions"].items():
187+ if param_name in [str(sym) for sym in list(expr.atoms())]:
188+ return True
189+
190+ if "propagators" in solver_json.keys():
191+ for sym, expr in solver_json["propagators"].items():
192+ if param_name in [str(sym) for sym in list(expr.atoms())]:
193+ return True
194+
195+ if "conditions" in solver_json.keys():
196+ for conditional_solver_json in solver_json["conditions"].values():
197+ if "update_expressions" in conditional_solver_json.keys():
198+ for sym, expr in conditional_solver_json["update_expressions"].items():
199+ if param_name in [str(sym) for sym in list(expr.atoms())]:
200+ return True
201+
202+ if "propagators" in conditional_solver_json.keys():
203+ for sym, expr in solver_json["propagators"].items():
204+ if param_name in [str(sym) for sym in list(expr.atoms())]:
205+ return True
206+
207+ return False
184208
185209def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_solver: bool = False, disable_singularity_detection: bool = False, preserve_expressions: Union[bool, Iterable[str]] = False, log_level: Union[str, int] = logging.WARNING) -> Tuple[List[Dict], SystemOfShapes, List[Shape]]:
186210 r"""
@@ -320,20 +344,7 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
320344 solver_json["parameters"] = {}
321345 for param_name, param_expr in indict["parameters"].items():
322346 # only make parameters appear in a solver if they are actually used there
323- symbol_appears_in_any_expr = False
324- if "update_expressions" in solver_json.keys():
325- for sym, expr in solver_json["update_expressions"].items():
326- if param_name in [str(sym) for sym in list(expr.atoms())]:
327- symbol_appears_in_any_expr = True
328- break
329-
330- if "propagators" in solver_json.keys():
331- for sym, expr in solver_json["propagators"].items():
332- if param_name in [str(sym) for sym in list(expr.atoms())]:
333- symbol_appears_in_any_expr = True
334- break
335-
336- if symbol_appears_in_any_expr:
347+ if symbol_appears_in_any_expr(sym, solver_json):
337348 sympy_expr = sympy.parsing.sympy_parser.parse_expr(param_expr, global_dict=Shape._sympy_globals)
338349
339350 # validate output for numerical problems
@@ -388,6 +399,26 @@ def _analysis(indict, disable_stiffness_check: bool = False, disable_analytic_so
388399 for sym, expr in solver_json["propagators"].items():
389400 solver_json["propagators"][sym] = str(expr)
390401
402+ if "conditions" in solver_json.keys():
403+ for cond, cond_solver in solver_json["conditions"].items():
404+ if "update_expressions" in cond_solver:
405+ for sym, expr in cond_solver["update_expressions"].items():
406+ cond_solver["update_expressions"][sym] = str(expr)
407+
408+ if preserve_expressions and sym in preserve_expressions:
409+ if "analytic" in solver_json["solver"]:
410+ logging.warning("Not preserving expression for variable \"" + sym + "\" as it is solved by propagator solver")
411+ continue
412+
413+ logging.info("Preserving expression for variable \"" + sym + "\"")
414+ var_def_str = _find_variable_definition(indict, sym, order=1)
415+ assert var_def_str is not None
416+ cond_solver["update_expressions"][sym] = var_def_str.replace("'", Config().differential_order_symbol)
417+
418+ if "propagators" in cond_solver:
419+ for sym, expr in cond_solver["propagators"].items():
420+ cond_solver["propagators"][sym] = str(expr)
421+
391422 logging.info("In ode-toolbox: returning outdict = ")
392423 logging.info(json.dumps(solvers_json, indent=4, sort_keys=True))
393424
0 commit comments