diff --git a/petab/v2/C.py b/petab/v2/C.py index 5bb73980..e680450e 100644 --- a/petab/v2/C.py +++ b/petab/v2/C.py @@ -285,6 +285,9 @@ #: separator for multiple parameter values (bounds, observableParameters, ...) PARAMETER_SEPARATOR = ";" +#: The time symbol for use in any PEtab-specific mathematical expressions +TIME_SYMBOL = "time" + __all__ = [ x diff --git a/petab/v2/lint.py b/petab/v2/lint.py index 0780b340..2f11cc61 100644 --- a/petab/v2/lint.py +++ b/petab/v2/lint.py @@ -42,6 +42,7 @@ "CheckObservablesDoNotShadowModelEntities", "CheckUnusedConditions", "CheckPriorDistribution", + "CheckInitialChangeSymbols", "lint_problem", "default_validation_tasks", ] @@ -713,6 +714,62 @@ def run(self, problem: Problem) -> ValidationIssue | None: return None +class CheckInitialChangeSymbols(ValidationTask): + """ + Check that changes of any first period of any experiment only refers to + allowed symbols. + + The only allowed symbols are those that are present in the parameter table. + """ + + def run(self, problem: Problem) -> ValidationIssue | None: + if not problem.experiments: + return None + + if not problem.conditions: + return None + + allowed_symbols = {p.id for p in problem.parameters} + allowed_symbols.add(TIME_SYMBOL) + # IDs of conditions that have already been checked + valid_conditions = set() + id_to_condition = {c.id: c for c in problem.conditions} + + messages = [] + for experiment in problem.experiments: + if not experiment.periods: + continue + + first_period = experiment.sorted_periods[0] + for condition_id in first_period.condition_ids: + if condition_id in valid_conditions: + continue + + # we assume that all referenced condition IDs are valid + condition = id_to_condition[condition_id] + + used_symbols = { + str(sym) + for change in condition.changes + for sym in change.target_value.free_symbols + } + invalid_symbols = used_symbols - allowed_symbols + if invalid_symbols: + messages.append( + f"Condition {condition.id} is applied at the start of " + f"experiment {experiment.id}, and thus, its " + f"target value expressions must only contain " + f"symbols from the parameter table, or `time`. " + "However, it contains additional symbols: " + f"{invalid_symbols}. " + ) + + if messages: + return ValidationError("\n".join(messages)) + + return None + + class CheckPriorDistribution(ValidationTask): """A task to validate the prior distribution of a PEtab problem.""" @@ -1058,10 +1115,7 @@ def get_placeholders( CheckValidParameterInConditionOrParameterTable(), CheckUnusedExperiments(), CheckUnusedConditions(), - # TODO: atomize checks, update to long condition table, re-enable - # TODO validate mapping table - CheckValidParameterInConditionOrParameterTable(), - CheckAllParametersPresentInParameterTable(), - CheckValidConditionTargets(), CheckPriorDistribution(), + CheckInitialChangeSymbols(), + # TODO validate mapping table ] diff --git a/tests/v2/test_lint.py b/tests/v2/test_lint.py index 12973d86..82917902 100644 --- a/tests/v2/test_lint.py +++ b/tests/v2/test_lint.py @@ -39,6 +39,31 @@ def test_check_incompatible_targets(): assert "overlapping targets {'p1'}" in error.message +def test_validate_initial_change_symbols(): + """Test validation of symbols in target value expressions for changes + applied at the start of an experiment.""" + problem = Problem() + problem.model = SbmlModel.from_antimony("p1 = 1; p2 = 2") + problem.add_experiment("e1", 0, "c1", 1, "c2") + problem.add_condition("c1", p1="p2 + time") + problem.add_condition("c2", p1="p2", p2="p1") + problem.add_parameter("p1", nominal_value=1, estimate=False) + problem.add_parameter("p2", nominal_value=2, estimate=False) + + check = CheckInitialChangeSymbols() + assert check.run(problem) is None + + # removing `p1` from the parameter table is okay, as `c2` is never + # used at the start of an experiment + problem.parameter_tables[0].parameters.remove(problem["p1"]) + assert check.run(problem) is None + + # removing `p2` is not okay, as it is used at the start of an experiment + problem.parameter_tables[0].parameters.remove(problem["p2"]) + assert (error := check.run(problem)) is not None + assert "contains additional symbols: {'p2'}" in error.message + + def test_invalid_model_id_in_measurements(): """Test that measurements with an invalid model ID are caught.""" problem = Problem()