diff --git a/petab/v2/converters.py b/petab/v2/converters.py index f0736087..67c7efda 100644 --- a/petab/v2/converters.py +++ b/petab/v2/converters.py @@ -8,7 +8,13 @@ import libsbml from sbmlmath import sbml_math_to_sympy, set_math -from .core import Change, Condition, Experiment, ExperimentPeriod +from .core import ( + Change, + Condition, + ConditionTable, + Experiment, + ExperimentPeriod, +) from .models._sbml_utils import add_sbml_parameter, check from .models.sbml_model import SbmlModel from .problem import Problem @@ -176,7 +182,7 @@ def convert(self) -> Problem: self._add_preequilibration_indicator() - for experiment in self._new_problem.experiment_table.experiments: + for experiment in self._new_problem.experiments: self._convert_experiment(experiment) self._add_indicators_to_conditions() @@ -226,7 +232,7 @@ def _convert_experiment(self, experiment: Experiment) -> None: self._create_event_assignments_for_period( ev, [ - self._new_problem.condition_table[condition_id] + self._new_problem[condition_id] for condition_id in period.condition_ids ], ) @@ -365,24 +371,18 @@ def _add_indicators_to_conditions(self) -> None: problem = self._new_problem # create conditions for indicator parameters - problem.condition_table.conditions.append( - Condition( - id=self.CONDITION_ID_PREEQ_ON, - changes=[ - Change(target_id=self._preeq_indicator, target_value=1) - ], - ) + problem += Condition( + id=self.CONDITION_ID_PREEQ_ON, + changes=[Change(target_id=self._preeq_indicator, target_value=1)], ) - problem.condition_table.conditions.append( - Condition( - id=self.CONDITION_ID_PREEQ_OFF, - changes=[ - Change(target_id=self._preeq_indicator, target_value=0) - ], - ) + + problem += Condition( + id=self.CONDITION_ID_PREEQ_OFF, + changes=[Change(target_id=self._preeq_indicator, target_value=0)], ) + # add conditions for the experiment indicators - for experiment in problem.experiment_table.experiments: + for experiment in problem.experiments: cond_id = self._get_experiment_indicator_condition_id( experiment.id ) @@ -392,17 +392,19 @@ def _add_indicators_to_conditions(self) -> None: target_value=1, ) ] - problem.condition_table.conditions.append( - Condition( - id=cond_id, - changes=changes, - ) + problem += Condition( + id=cond_id, + changes=changes, ) # All changes have been encoded in event assignments and can be # removed. Only keep the conditions setting our indicators. - problem.condition_table.conditions = [ - condition - for condition in problem.condition_table.conditions - if condition.id.startswith("_petab") + problem.condition_tables = [ + ConditionTable( + conditions=[ + condition + for condition in problem.conditions + if condition.id.startswith("_petab") + ] + ) ] diff --git a/petab/v2/lint.py b/petab/v2/lint.py index 2810841a..a8ea848e 100644 --- a/petab/v2/lint.py +++ b/petab/v2/lint.py @@ -252,12 +252,8 @@ class CheckMeasuredObservablesDefined(ValidationTask): are defined.""" def run(self, problem: Problem) -> ValidationIssue | None: - used_observables = { - m.observable_id for m in problem.measurement_table.measurements - } - defined_observables = { - o.id for o in problem.observable_table.observables - } + used_observables = {m.observable_id for m in problem.measurements} + defined_observables = {o.id for o in problem.observables} if undefined_observables := (used_observables - defined_observables): return ValidationError( f"Observable(s) {undefined_observables} are used in the " @@ -275,15 +271,14 @@ class CheckOverridesMatchPlaceholders(ValidationTask): def run(self, problem: Problem) -> ValidationIssue | None: observable_parameters_count = { - o.id: len(o.observable_placeholders) - for o in problem.observable_table.observables + o.id: len(o.observable_placeholders) for o in problem.observables } noise_parameters_count = { - o.id: len(o.noise_placeholders) - for o in problem.observable_table.observables + o.id: len(o.noise_placeholders) for o in problem.observables } messages = [] - for m in problem.measurement_table.measurements: + observables = {o.id: o for o in problem.observables} + for m in problem.measurements: # check observable parameters try: expected = observable_parameters_count[m.observable_id] @@ -297,7 +292,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: actual = len(m.observable_parameters) if actual != expected: - formula = problem.observable_table[m.observable_id].formula + formula = observables[m.observable_id].formula messages.append( f"Mismatch of observable parameter overrides for " f"{m.observable_id} ({formula})" @@ -323,9 +318,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: "noiseParameters column." ) else: - formula = problem.observable_table[ - m.observable_id - ].noise_formula + formula = observables[m.observable_id].noise_formula messages.append( f"Mismatch of noise parameter overrides for " f"{m.observable_id} ({formula})" @@ -348,11 +341,11 @@ def run(self, problem: Problem) -> ValidationIssue | None: log_observables = { o.id - for o in problem.observable_table.observables + for o in problem.observables if o.noise_distribution in [ND.LOG_NORMAL, ND.LOG_LAPLACE] } if log_observables: - for m in problem.measurement_table.measurements: + for m in problem.measurements: if m.measurement <= 0 and m.observable_id in log_observables: return ValidationError( "Measurements with observable " @@ -374,14 +367,12 @@ def run(self, problem: Problem) -> ValidationIssue | None: # to conditions, otherwise it should maximally be a warning used_experiments = { m.experiment_id - for m in problem.measurement_table.measurements + for m in problem.measurements if m.experiment_id is not None } # check that measured experiments exist - available_experiments = { - e.id for e in problem.experiment_table.experiments - } + available_experiments = {e.id for e in problem.experiments} if missing_experiments := (used_experiments - available_experiments): return ValidationError( "Measurement table references experiments that " @@ -403,14 +394,12 @@ def run(self, problem: Problem) -> ValidationIssue | None: ) allowed_targets |= set(get_output_parameters(problem)) allowed_targets |= { - m.petab_id - for m in problem.mapping_table.mappings - if m.model_id is not None + m.petab_id for m in problem.mappings if m.model_id is not None } used_targets = { change.target_id - for cond in problem.condition_table.conditions + for cond in problem.conditions for change in cond.changes } @@ -421,7 +410,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: # Check that changes of simultaneously applied conditions don't # intersect - for experiment in problem.experiment_table.experiments: + for experiment in problem.experiments: for period in experiment.periods: if not period.condition_ids: continue @@ -429,7 +418,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: for condition_id in period.condition_ids: condition_targets = { change.target_id - for cond in problem.condition_table.conditions + for cond in problem.conditions if cond.id == condition_id for change in cond.changes } @@ -451,7 +440,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: # -- replaces CheckObservablesDoNotShadowModelEntities # check for uniqueness of all primary keys - counter = Counter(c.id for c in problem.condition_table.conditions) + counter = Counter(c.id for c in problem.conditions) duplicates = {id_ for id_, count in counter.items() if count > 1} if duplicates: @@ -459,7 +448,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: f"Condition table contains duplicate IDs: {duplicates}" ) - counter = Counter(o.id for o in problem.observable_table.observables) + counter = Counter(o.id for o in problem.observables) duplicates = {id_ for id_, count in counter.items() if count > 1} if duplicates: @@ -467,7 +456,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: f"Observable table contains duplicate IDs: {duplicates}" ) - counter = Counter(e.id for e in problem.experiment_table.experiments) + counter = Counter(e.id for e in problem.experiments) duplicates = {id_ for id_, count in counter.items() if count > 1} if duplicates: @@ -475,7 +464,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: f"Experiment table contains duplicate IDs: {duplicates}" ) - counter = Counter(p.id for p in problem.parameter_table.parameters) + counter = Counter(p.id for p in problem.parameters) duplicates = {id_ for id_, count in counter.items() if count > 1} if duplicates: @@ -491,12 +480,12 @@ class CheckObservablesDoNotShadowModelEntities(ValidationTask): # TODO: all PEtab entity IDs must be disjoint from the model entity IDs def run(self, problem: Problem) -> ValidationIssue | None: - if not problem.observable_table.observables or problem.model is None: + if not problem.observables or problem.model is None: return None shadowed_entities = [ o.id - for o in problem.observable_table.observables + for o in problem.observables if problem.model.has_entity_with_id(o.id) ] if shadowed_entities: @@ -512,7 +501,7 @@ class CheckExperimentTable(ValidationTask): def run(self, problem: Problem) -> ValidationIssue | None: messages = [] - for experiment in problem.experiment_table.experiments: + for experiment in problem.experiments: # Check that there are no duplicate timepoints counter = Counter(period.time for period in experiment.periods) duplicates = {time for time, count in counter.items() if count > 1} @@ -534,10 +523,8 @@ class CheckExperimentConditionsExist(ValidationTask): def run(self, problem: Problem) -> ValidationIssue | None: messages = [] - available_conditions = { - c.id for c in problem.condition_table.conditions - } - for experiment in problem.experiment_table.experiments: + available_conditions = {c.id for c in problem.conditions} + for experiment in problem.experiments: missing_conditions = ( set( chain.from_iterable( @@ -569,7 +556,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: required = get_required_parameters_for_parameter_table(problem) allowed = get_valid_parameters_for_parameter_table(problem) - actual = {p.id for p in problem.parameter_table.parameters} + actual = {p.id for p in problem.parameters} missing = required - actual extraneous = actual - allowed @@ -577,7 +564,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: # the mapping table if missing: model_to_petab_mapping = {} - for m in problem.mapping_table.mappings: + for m in problem.mappings: if m.model_id in model_to_petab_mapping: model_to_petab_mapping[m.model_id].append(m.petab_id) else: @@ -620,7 +607,7 @@ def run(self, problem: Problem) -> ValidationIssue | None: ) allowed_in_condition_cols |= { m.petab_id - for m in problem.mapping_table.mappings + for m in problem.mappings if not pd.isna(m.model_id) and ( # mapping table entities mapping to already allowed parameters @@ -636,12 +623,10 @@ def run(self, problem: Problem) -> ValidationIssue | None: entities_in_condition_table = { change.target_id - for cond in problem.condition_table.conditions + for cond in problem.conditions for change in cond.changes } - entities_in_parameter_table = { - p.id for p in problem.parameter_table.parameters - } + entities_in_parameter_table = {p.id for p in problem.parameters} disallowed_in_condition = { x @@ -689,12 +674,10 @@ class CheckUnusedExperiments(ValidationTask): def run(self, problem: Problem) -> ValidationIssue | None: used_experiments = { m.experiment_id - for m in problem.measurement_table.measurements + for m in problem.measurements if m.experiment_id is not None } - available_experiments = { - e.id for e in problem.experiment_table.experiments - } + available_experiments = {e.id for e in problem.experiments} unused_experiments = available_experiments - used_experiments if unused_experiments: @@ -713,14 +696,10 @@ class CheckUnusedConditions(ValidationTask): def run(self, problem: Problem) -> ValidationIssue | None: used_conditions = set( chain.from_iterable( - p.condition_ids - for e in problem.experiment_table.experiments - for p in e.periods + p.condition_ids for e in problem.experiments for p in e.periods ) ) - available_conditions = { - c.id for c in problem.condition_table.conditions - } + available_conditions = {c.id for c in problem.conditions} unused_conditions = available_conditions - used_conditions if unused_conditions: @@ -770,7 +749,7 @@ class CheckPriorDistribution(ValidationTask): def run(self, problem: Problem) -> ValidationIssue | None: messages = [] - for parameter in problem.parameter_table.parameters: + for parameter in problem.parameters: if parameter.prior_distribution is None: continue @@ -837,7 +816,7 @@ def get_valid_parameters_for_parameter_table( # condition table targets invalid |= { change.target_id - for cond in problem.condition_table.conditions + for cond in problem.conditions for change in cond.changes } @@ -849,7 +828,7 @@ def get_valid_parameters_for_parameter_table( if p not in invalid ) - for mapping in problem.mapping_table.mappings: + for mapping in problem.mappings: if mapping.model_id and mapping.model_id in parameter_ids.keys(): parameter_ids[mapping.petab_id] = None @@ -866,14 +845,15 @@ def append_overrides(overrides): if isinstance(p, sp.Symbol) and (str_p := str(p)) not in invalid: parameter_ids[str_p] = None - for measurement in problem.measurement_table.measurements: + for measurement in problem.measurements: # we trust that the number of overrides matches append_overrides(measurement.observable_parameters) append_overrides(measurement.noise_parameters) # Append parameter overrides from condition table - for p in problem.condition_table.free_symbols: - parameter_ids[str(p)] = None + for ct in problem.condition_tables: + for p in ct.free_symbols: + parameter_ids[str(p)] = None return set(parameter_ids.keys()) @@ -895,7 +875,7 @@ def get_required_parameters_for_parameter_table( parameter_ids = set() condition_targets = { change.target_id - for cond in problem.condition_table.conditions + for cond in problem.conditions for change in cond.changes } @@ -908,7 +888,7 @@ def append_overrides(overrides): and (str_p := str(p)) not in condition_targets ) - for m in problem.measurement_table.measurements: + for m in problem.measurements: # we trust that the number of overrides matches append_overrides(m.observable_parameters) append_overrides(m.noise_parameters) @@ -916,7 +896,7 @@ def append_overrides(overrides): # TODO remove `observable_ids` when # `get_output_parameters` is updated for PEtab v2/v1.1, where # observable IDs are allowed in observable formulae - observable_ids = {o.id for o in problem.observable_table.observables} + observable_ids = {o.id for o in problem.observables} # Add output parameters except for placeholders for formula_type, placeholder_sources in ( @@ -951,7 +931,8 @@ def append_overrides(overrides): # model parameter_ids.update( str(p) - for p in problem.condition_table.free_symbols + for ct in problem.condition_tables + for p in ct.free_symbols if not problem.model.has_entity_with_id(str(p)) ) @@ -981,13 +962,9 @@ def get_output_parameters( """ formulas = [] if observables: - formulas.extend( - o.formula for o in problem.observable_table.observables - ) + formulas.extend(o.formula for o in problem.observables) if noise: - formulas.extend( - o.noise_formula for o in problem.observable_table.observables - ) + formulas.extend(o.noise_formula for o in problem.observables) output_parameters = OrderedDict() for formula in formulas: @@ -1001,17 +978,15 @@ def get_output_parameters( continue # does it map to a model entity? - - if ( - (mapped := problem.mapping_table.get(sym)) is not None - and mapped.model_id is not None - and problem.model.symbol_allowed_in_observable_formula( - mapped.model_id - ) - ): - continue - - output_parameters[sym] = None + for mapping in problem.mappings: + if mapping.petab_id == sym and mapping.model_id is not None: + if problem.model.symbol_allowed_in_observable_formula( + mapping.model_id + ): + break + else: + # no mapping to a model entity, so it is an output parameter + output_parameters[sym] = None return list(output_parameters.keys()) @@ -1036,7 +1011,7 @@ def get_placeholders( # collect placeholder parameters overwritten by # {observable,noise}Parameters placeholders = [] - for o in problem.observable_table.observables: + for o in problem.observables: if observables: placeholders.extend(map(str, o.observable_placeholders)) if noise: diff --git a/petab/v2/problem.py b/petab/v2/problem.py index 97684241..a41a53b5 100644 --- a/petab/v2/problem.py +++ b/petab/v2/problem.py @@ -7,6 +7,7 @@ import tempfile import traceback from collections.abc import Sequence +from itertools import chain from math import nan from numbers import Number from pathlib import Path @@ -18,9 +19,6 @@ from pydantic import AnyUrl, BaseModel, Field, field_validator from ..v1 import ( - mapping, - measurements, - observables, parameter_mapping, parameters, validate_yaml_syntax, @@ -32,7 +30,7 @@ from ..v1.yaml import get_path_prefix from ..v2.C import * # noqa: F403 from ..versions import parse_version -from . import conditions, core, experiments +from . import core if TYPE_CHECKING: from ..v2.lint import ValidationResultList, ValidationTask @@ -63,12 +61,13 @@ class Problem: def __init__( self, model: Model = None, - condition_table: core.ConditionTable = None, - experiment_table: core.ExperimentTable = None, - observable_table: core.ObservableTable = None, - measurement_table: core.MeasurementTable = None, - parameter_table: core.ParameterTable = None, - mapping_table: core.MappingTable = None, + condition_tables: list[core.ConditionTable] = None, + experiment_tables: list[core.ExperimentTable] = None, + observable_tables: list[core.ObservableTable] = None, + measurement_tables: list[core.MeasurementTable] = None, + parameter_tables: list[core.ParameterTable] = None, + mapping_tables: list[core.MappingTable] = None, + # TODO: remove visualization_df: pd.DataFrame = None, config: ProblemConfig = None, ): @@ -80,41 +79,43 @@ def __init__( default_validation_tasks.copy() ) - self.observable_table = observable_table or core.ObservableTable( - observables=[] - ) - self.condition_table = condition_table or core.ConditionTable( - conditions=[] - ) - self.experiment_table = experiment_table or core.ExperimentTable( - experiments=[] - ) - self.measurement_table = measurement_table or core.MeasurementTable( - measurements=[] - ) - self.mapping_table = mapping_table or core.MappingTable(mappings=[]) - self.parameter_table = parameter_table or core.ParameterTable( - parameters=[] - ) + self.observable_tables = observable_tables or [ + core.ObservableTable(observables=[]) + ] + self.condition_tables = condition_tables or [ + core.ConditionTable(conditions=[]) + ] + self.experiment_tables = experiment_tables or [ + core.ExperimentTable(experiments=[]) + ] + self.measurement_tables = measurement_tables or [ + core.MeasurementTable(measurements=[]) + ] + self.mapping_tables = mapping_tables or [ + core.MappingTable(mappings=[]) + ] + self.parameter_tables = parameter_tables or [ + core.ParameterTable(parameters=[]) + ] self.visualization_df = visualization_df def __str__(self): model = f"with model ({self.model})" if self.model else "without model" - ne = len(self.experiment_table.experiments) + ne = len(self.experiments) experiments = f"{ne} experiments" - nc = len(self.condition_table.conditions) + nc = len(self.conditions) conditions = f"{nc} conditions" - no = len(self.observable_table.observables) + no = len(self.observables) observables = f"{no} observables" - nm = len(self.measurement_table.measurements) + nm = len(self.measurements) measurements = f"{nm} measurements" - nest = self.parameter_table.n_estimated + nest = sum(pt.n_estimated for pt in self.parameter_tables) parameters = f"{nest} estimated parameters" return ( @@ -130,15 +131,15 @@ def __getitem__(self, key): Accessing model entities is not currently not supported. """ - for table in ( - self.condition_table, - self.experiment_table, - self.observable_table, - self.measurement_table, - self.parameter_table, - self.mapping_table, + for table_list in ( + self.condition_tables, + self.experiment_tables, + self.observable_tables, + self.measurement_tables, + self.parameter_tables, + self.mapping_tables, ): - if table is not None: + for table in table_list: try: return table[key] except KeyError: @@ -215,9 +216,10 @@ def get_path(filename): config = ProblemConfig( **yaml_config, base_path=base_path, filepath=yaml_file ) - parameter_df = parameters.get_parameter_df( - [get_path(f) for f in config.parameter_files] - ) + parameter_tables = [ + core.ParameterTable.from_tsv(get_path(f)) + for f in config.parameter_files + ] if len(config.model_files or []) > 1: # TODO https://github.com/PEtab-dev/libpetab-python/issues/6 @@ -233,27 +235,30 @@ def get_path(filename): model_id=model_id, ) - measurement_files = [get_path(f) for f in config.measurement_files] - # If there are multiple tables, we will merge them - measurement_df = ( - concat_tables(measurement_files, measurements.get_measurement_df) - if measurement_files + measurement_tables = ( + [ + core.MeasurementTable.from_tsv(get_path(f)) + for f in config.measurement_files + ] + if config.measurement_files else None ) - condition_files = [get_path(f) for f in config.condition_files] - # If there are multiple tables, we will merge them - condition_df = ( - concat_tables(condition_files, conditions.get_condition_df) - if condition_files + condition_tables = ( + [ + core.ConditionTable.from_tsv(get_path(f)) + for f in config.condition_files + ] + if config.condition_files else None ) - experiment_files = [get_path(f) for f in config.experiment_files] - # If there are multiple tables, we will merge them - experiment_df = ( - concat_tables(experiment_files, experiments.get_experiment_df) - if experiment_files + experiment_tables = ( + [ + core.ExperimentTable.from_tsv(get_path(f)) + for f in config.experiment_files + ] + if config.experiment_files else None ) @@ -266,32 +271,34 @@ def get_path(filename): else None ) - observable_files = [get_path(f) for f in config.observable_files] - # If there are multiple tables, we will merge them - observable_df = ( - concat_tables(observable_files, observables.get_observable_df) - if observable_files + observable_tables = ( + [ + core.ObservableTable.from_tsv(get_path(f)) + for f in config.observable_files + ] + if config.observable_files else None ) - mapping_files = [get_path(f) for f in config.mapping_files] - # If there are multiple tables, we will merge them - mapping_df = ( - concat_tables(mapping_files, mapping.get_mapping_df) - if mapping_files + mapping_tables = ( + [ + core.MappingTable.from_tsv(get_path(f)) + for f in config.mapping_files + ] + if config.mapping_files else None ) - return Problem.from_dfs( - condition_df=condition_df, - experiment_df=experiment_df, - measurement_df=measurement_df, - parameter_df=parameter_df, - observable_df=observable_df, + return Problem( + config=config, model=model, + condition_tables=condition_tables, + experiment_tables=experiment_tables, + observable_tables=observable_tables, + measurement_tables=measurement_tables, + parameter_tables=parameter_tables, + mapping_tables=mapping_tables, visualization_df=visualization_df, - mapping_df=mapping_df, - config=config, ) @staticmethod @@ -330,12 +337,12 @@ def from_dfs( return Problem( model=model, - condition_table=condition_table, - experiment_table=experiment_table, - observable_table=observable_table, - measurement_table=measurement_table, - parameter_table=parameter_table, - mapping_table=mapping_table, + condition_tables=[condition_table], + experiment_tables=[experiment_table], + observable_tables=[observable_table], + measurement_tables=[measurement_table], + parameter_tables=[parameter_table], + mapping_tables=[mapping_table], visualization_df=visualization_df, config=config, ) @@ -398,73 +405,142 @@ def get_problem(problem: str | Path | Problem) -> Problem: @property def condition_df(self) -> pd.DataFrame | None: - """Condition table as DataFrame.""" - # TODO: return empty df? - return self.condition_table.to_df() if self.condition_table else None + """Combined condition tables as DataFrame.""" + conditions = self.conditions + return ( + core.ConditionTable(conditions=conditions).to_df() + if conditions + else None + ) @condition_df.setter def condition_df(self, value: pd.DataFrame): - self.condition_table = core.ConditionTable.from_df(value) + self.condition_tables = [core.ConditionTable.from_df(value)] @property def experiment_df(self) -> pd.DataFrame | None: """Experiment table as DataFrame.""" - return self.experiment_table.to_df() if self.experiment_table else None + return ( + core.ExperimentTable(experiments=experiments).to_df() + if (experiments := self.experiments) + else None + ) @experiment_df.setter def experiment_df(self, value: pd.DataFrame): - self.experiment_table = core.ExperimentTable.from_df(value) + self.experiment_tables = [core.ExperimentTable.from_df(value)] @property def measurement_df(self) -> pd.DataFrame | None: - """Measurement table as DataFrame.""" + """Combined measurement tables as DataFrame.""" + measurements = self.measurements return ( - self.measurement_table.to_df() if self.measurement_table else None + core.MeasurementTable(measurements=measurements).to_df() + if measurements + else None ) @measurement_df.setter def measurement_df(self, value: pd.DataFrame): - self.measurement_table = core.MeasurementTable.from_df(value) + self.measurement_tables = [core.MeasurementTable.from_df(value)] @property def parameter_df(self) -> pd.DataFrame | None: - """Parameter table as DataFrame.""" - return self.parameter_table.to_df() if self.parameter_table else None + """Combined parameter tables as DataFrame.""" + parameters = self.parameters + return ( + core.ParameterTable(parameters=parameters).to_df() + if parameters + else None + ) @parameter_df.setter def parameter_df(self, value: pd.DataFrame): - self.parameter_table = core.ParameterTable.from_df(value) + self.parameter_tables = [core.ParameterTable.from_df(value)] @property def observable_df(self) -> pd.DataFrame | None: - """Observable table as DataFrame.""" - return self.observable_table.to_df() if self.observable_table else None + """Combined observable tables as DataFrame.""" + observables = self.observables + return ( + core.ObservableTable(observables=observables).to_df() + if observables + else None + ) @observable_df.setter def observable_df(self, value: pd.DataFrame): - self.observable_table = core.ObservableTable.from_df(value) + self.observable_tables = [core.ObservableTable.from_df(value)] @property def mapping_df(self) -> pd.DataFrame | None: - """Mapping table as DataFrame.""" - return self.mapping_table.to_df() if self.mapping_table else None + """Combined mapping tables as DataFrame.""" + mappings = self.mappings + return ( + core.MappingTable(mappings=mappings).to_df() if mappings else None + ) @mapping_df.setter def mapping_df(self, value: pd.DataFrame): - self.mapping_table = core.MappingTable.from_df(value) + self.mapping_tables = [core.MappingTable.from_df(value)] + + @property + def conditions(self) -> list[core.Condition]: + """List of conditions in the condition table(s).""" + return list( + chain.from_iterable(ct.conditions for ct in self.condition_tables) + ) + + @property + def experiments(self) -> list[core.Experiment]: + """List of experiments in the experiment table(s).""" + return list( + chain.from_iterable( + et.experiments for et in self.experiment_tables + ) + ) + + @property + def observables(self) -> list[core.Observable]: + """List of observables in the observable table(s).""" + return list( + chain.from_iterable( + ot.observables for ot in self.observable_tables + ) + ) + + @property + def measurements(self) -> list[core.Measurement]: + """List of measurements in the measurement table(s).""" + return list( + chain.from_iterable( + mt.measurements for mt in self.measurement_tables + ) + ) + + @property + def parameters(self) -> list[core.Parameter]: + """List of parameters in the parameter table(s).""" + return list( + chain.from_iterable(pt.parameters for pt in self.parameter_tables) + ) + + @property + def mappings(self) -> list[core.Mapping]: + """List of mappings in the mapping table(s).""" + return list( + chain.from_iterable(mt.mappings for mt in self.mapping_tables) + ) def get_optimization_parameters(self) -> list[str]: """ Get the list of optimization parameter IDs from parameter table. - Arguments: - parameter_df: PEtab parameter DataFrame - Returns: A list of IDs of parameters selected for optimization (i.e., those with estimate = True). """ - return [p.id for p in self.parameter_table.parameters if p.estimate] + return [p.id for p in self.parameters if p.estimate] def get_optimization_parameter_scales(self) -> dict[str, str]: """ @@ -479,7 +555,7 @@ def get_observable_ids(self) -> list[str]: """ Returns dictionary of observable ids. """ - return [o.id for o in self.observable_table.observables] + return [o.id for o in self.observables] def _apply_mask(self, v: list, free: bool = True, fixed: bool = True): """Apply mask of only free or only fixed values. @@ -521,7 +597,7 @@ def get_x_ids(self, free: bool = True, fixed: bool = True): ------- The parameter IDs. """ - v = [p.id for p in self.parameter_table.parameters] + v = [p.id for p in self.parameters] return self._apply_mask(v, free=free, fixed=fixed) @property @@ -561,7 +637,7 @@ def get_x_nominal( """ v = [ p.nominal_value if p.nominal_value is not None else nan - for p in self.parameter_table.parameters + for p in self.parameters ] if scaled: @@ -624,10 +700,7 @@ def get_lb( ------- The lower parameter bounds. """ - v = [ - p.lb if p.lb is not None else nan - for p in self.parameter_table.parameters - ] + v = [p.lb if p.lb is not None else nan for p in self.parameters] if scaled: v = list( parameters.map_scale(v, self.parameter_df[PARAMETER_SCALE]) @@ -664,10 +737,7 @@ def get_ub( ------- The upper parameter bounds. """ - v = [ - p.ub if p.ub is not None else nan - for p in self.parameter_table.parameters - ] + v = [p.ub if p.ub is not None else nan for p in self.parameters] if scaled: v = list( parameters.map_scale(v, self.parameter_df[PARAMETER_SCALE]) @@ -687,20 +757,12 @@ def ub_scaled(self) -> list: @property def x_free_indices(self) -> list[int]: """Parameter table estimated parameter indices.""" - return [ - i - for i, p in enumerate(self.parameter_table.parameters) - if p.estimate - ] + return [i for i, p in enumerate(self.parameters) if p.estimate] @property def x_fixed_indices(self) -> list[int]: """Parameter table non-estimated parameter indices.""" - return [ - i - for i, p in enumerate(self.parameter_table.parameters) - if not p.estimate - ] + return [i for i, p in enumerate(self.parameters) if not p.estimate] # TODO remove in v2? def get_optimization_to_simulation_parameter_mapping(self, **kwargs): @@ -725,11 +787,7 @@ def get_priors(self) -> dict[str, Distribution]: :returns: The prior distributions for the estimated parameters. """ - return { - p.id: p.prior_dist - for p in self.parameter_table.parameters - if p.estimate - } + return {p.id: p.prior_dist for p in self.parameters if p.estimate} def sample_parameter_startpoints(self, n_starts: int = 100, **kwargs): """Create 2D array with starting points for optimization""" @@ -810,15 +868,12 @@ def n_estimated(self) -> int: @property def n_measurements(self) -> int: """Number of measurements.""" - return len(self.measurement_table.measurements) + return sum(len(mt.measurements) for mt in self.measurement_tables) @property def n_priors(self) -> int: """Number of priors.""" - return sum( - p.prior_distribution is not None - for p in self.parameter_table.parameters - ) + return sum(p.prior_distribution is not None for p in self.parameters) def validate( self, validation_tasks: list[ValidationTask] = None @@ -872,9 +927,14 @@ def add_condition( ): """Add a simulation condition to the problem. + If there are more than one condition tables, the condition + is added to the last one. + Arguments: id_: The condition id - name: The condition name + name: The condition name. If given, this will be added to the + last mapping table. If no mapping table exists, + a new mapping table will be created. kwargs: Entities to be added to the condition table in the form `target_id=target_value`. """ @@ -885,16 +945,13 @@ def add_condition( core.Change(target_id=target_id, target_value=target_value) for target_id, target_value in kwargs.items() ] - self.condition_table.conditions.append( + if not self.condition_tables: + self.condition_tables.append(core.ConditionTable(conditions=[])) + self.condition_tables[-1].conditions.append( core.Condition(id=id_, changes=changes) ) if name is not None: - self.mapping_table.mappings.append( - core.Mapping( - petab_id=id_, - name=name, - ) - ) + self.add_mapping(petab_id=id_, name=name) def add_observable( self, @@ -909,6 +966,9 @@ def add_observable( ): """Add an observable to the problem. + If there are more than one observable tables, the observable + is added to the last one. + Arguments: id_: The observable id formula: The observable formula @@ -936,7 +996,10 @@ def add_observable( record[NOISE_PLACEHOLDERS] = noise_placeholders record.update(kwargs) - self.observable_table += core.Observable(**record) + if not self.observable_tables: + self.observable_tables.append(core.ObservableTable(observables=[])) + + self.observable_tables[-1] += core.Observable(**record) def add_parameter( self, @@ -952,6 +1015,9 @@ def add_parameter( ): """Add a parameter to the problem. + If there are more than one parameter tables, the parameter + is added to the last one. + Arguments: id_: The parameter id estimate: Whether the parameter is estimated @@ -986,7 +1052,10 @@ def add_parameter( record[PRIOR_PARAMETERS] = prior_pars record.update(kwargs) - self.parameter_table += core.Parameter(**record) + if not self.parameter_tables: + self.parameter_tables.append(core.ParameterTable(parameters=[])) + + self.parameter_tables[-1] += core.Parameter(**record) def add_measurement( self, @@ -999,6 +1068,9 @@ def add_measurement( ): """Add a measurement to the problem. + If there are more than one measurement tables, the measurement + is added to the last one. + Arguments: obs_id: The observable ID experiment_id: The experiment ID @@ -1016,7 +1088,12 @@ def add_measurement( ): noise_parameters = [noise_parameters] - self.measurement_table.measurements.append( + if not self.measurement_tables: + self.measurement_tables.append( + core.MeasurementTable(measurements=[]) + ) + + self.measurement_tables[-1].measurements.append( core.Measurement( observable_id=obs_id, experiment_id=experiment_id, @@ -1027,20 +1104,31 @@ def add_measurement( ) ) - def add_mapping(self, petab_id: str, model_id: str, name: str = None): + def add_mapping( + self, petab_id: str, model_id: str = None, name: str = None + ): """Add a mapping table entry to the problem. + If there are more than one mapping tables, the mapping + is added to the last one. + Arguments: petab_id: The new PEtab-compatible ID mapping to `model_id` model_id: The ID of some entity in the model + name: A name (any string) for the entity referenced by `petab_id`. """ - self.mapping_table.mappings.append( + if not self.mapping_tables: + self.mapping_tables.append(core.MappingTable(mappings=[])) + self.mapping_tables[-1].mappings.append( core.Mapping(petab_id=petab_id, model_id=model_id, name=name) ) def add_experiment(self, id_: str, *args): """Add an experiment to the problem. + If there are more than one experiment tables, the experiment + is added to the last one. + :param id_: The experiment ID. :param args: Timepoints and associated conditions: ``time_1, condition_id_1, time_2, condition_id_2, ...``. @@ -1060,7 +1148,9 @@ def add_experiment(self, id_: str, *args): for i in range(0, len(args), 2) ] - self.experiment_table.experiments.append( + if not self.experiment_tables: + self.experiment_tables.append(core.ExperimentTable(experiments=[])) + self.experiment_tables[-1].experiments.append( core.Experiment(id=id_, periods=periods) ) @@ -1075,15 +1165,35 @@ def __iadd__(self, other): ) if isinstance(other, Observable): - self.observable_table += other + if not self.observable_tables: + self.observable_tables.append( + core.ObservableTable(observables=[]) + ) + self.observable_tables[-1] += other elif isinstance(other, Parameter): - self.parameter_table += other + if not self.parameter_tables: + self.parameter_tables.append( + core.ParameterTable(parameters=[]) + ) + self.parameter_tables[-1] += other elif isinstance(other, Measurement): - self.measurement_table += other + if not self.measurement_tables: + self.measurement_tables.append( + core.MeasurementTable(measurements=[]) + ) + self.measurement_tables[-1] += other elif isinstance(other, Condition): - self.condition_table += other + if not self.condition_tables: + self.condition_tables.append( + core.ConditionTable(conditions=[]) + ) + self.condition_tables[-1] += other elif isinstance(other, Experiment): - self.experiment_table += other + if not self.experiment_tables: + self.experiment_tables.append( + core.ExperimentTable(experiments=[]) + ) + self.experiment_tables[-1] += other else: raise ValueError( f"Cannot add object of type {type(other)} to Problem." @@ -1136,13 +1246,19 @@ def model_dump(self, **kwargs) -> dict[str, Any]: **kwargs, by_alias=True ), } - res |= self.mapping_table.model_dump(**kwargs) - res |= self.condition_table.model_dump(**kwargs) - res |= self.experiment_table.model_dump(**kwargs) - res |= self.observable_table.model_dump(**kwargs) - res |= self.measurement_table.model_dump(**kwargs) - res |= self.parameter_table.model_dump(**kwargs) - + for field, table_list in ( + ("conditions", self.condition_tables), + ("experiments", self.experiment_tables), + ("observables", self.observable_tables), + ("measurements", self.measurement_tables), + ("parameters", self.parameter_tables), + ("mappings", self.mapping_tables), + ): + res[field] = ( + [table.model_dump(**kwargs) for table in table_list] + if table_list + else [] + ) return res diff --git a/tests/v2/test_conversion.py b/tests/v2/test_conversion.py index 43e14662..6bcbb22c 100644 --- a/tests/v2/test_conversion.py +++ b/tests/v2/test_conversion.py @@ -15,7 +15,7 @@ def test_petab1to2_remote(): problem = petab1to2(yaml_url) assert isinstance(problem, Problem) - assert len(problem.measurement_table.measurements) + assert len(problem.measurements) try: @@ -45,4 +45,4 @@ def test_benchmark_collection(problem_id): except NotImplementedError as e: pytest.skip(str(e)) assert isinstance(problem, Problem) - assert len(problem.measurement_table.measurements) + assert len(problem.measurements) diff --git a/tests/v2/test_converters.py b/tests/v2/test_converters.py index 76ba6a86..8cdbaddf 100644 --- a/tests/v2/test_converters.py +++ b/tests/v2/test_converters.py @@ -25,7 +25,7 @@ def test_experiments_to_events_converter(): sbml_model = converted.model.sbml_model assert sbml_model.getNumEvents() == 2 - assert converted.condition_table.conditions == [ + assert converted.conditions == [ Condition( id="_petab_preequilibration_on", changes=[ @@ -53,7 +53,7 @@ def test_experiments_to_events_converter(): ], ), ] - assert converted.experiment_table.experiments == [ + assert converted.experiments == [ Experiment( id="e1", periods=[