diff --git a/src/sirocco/core/workflow.py b/src/sirocco/core/workflow.py index 06ee46ee..ffc6e012 100644 --- a/src/sirocco/core/workflow.py +++ b/src/sirocco/core/workflow.py @@ -1,12 +1,11 @@ from __future__ import annotations from itertools import product -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Self -from sirocco.core import _tasks # noqa [F401] from sirocco.core.graph_items import Cycle, Data, Store, Task from sirocco.parsing._yaml_data_models import ( - ConfigWorkflow, + CanonicalWorkflow, load_workflow_config, ) @@ -20,7 +19,7 @@ class Workflow: """Internal representation of a workflow""" - def __init__(self, workflow_config: ConfigWorkflow) -> None: + def __init__(self, workflow_config: CanonicalWorkflow) -> None: self.name = workflow_config.name self.tasks = Store() self.data = Store() @@ -68,7 +67,11 @@ def iter_coordinates(param_refs: list, date: datetime | None = None) -> Iterator self.tasks.add(task) cycle_tasks.append(task) self.cycles.add( - Cycle(name=cycle_name, tasks=cycle_tasks, coordinates={} if date is None else {"date": date}) + Cycle( + name=cycle_name, + tasks=cycle_tasks, + coordinates={} if date is None else {"date": date}, + ) ) # 4 - Link wait on tasks @@ -83,5 +86,5 @@ def cycle_dates(cycle_config: ConfigCycle) -> Iterator[datetime]: yield date @classmethod - def from_yaml(cls, config_path: str): + def from_yaml(cls: type[Self], config_path: str) -> Self: return cls(load_workflow_config(config_path)) diff --git a/src/sirocco/parsing/_yaml_data_models.py b/src/sirocco/parsing/_yaml_data_models.py index 70642a20..9315757b 100644 --- a/src/sirocco/parsing/_yaml_data_models.py +++ b/src/sirocco/parsing/_yaml_data_models.py @@ -1,6 +1,9 @@ from __future__ import annotations +import functools +import itertools import time +import typing from dataclasses import dataclass from datetime import datetime from pathlib import Path @@ -9,6 +12,7 @@ from isoduration import parse_duration from isoduration.types import 
Duration # pydantic needs type # noqa: TCH002
 from pydantic import (
+    AfterValidator,
     BaseModel,
     ConfigDict,
     Discriminator,
@@ -375,7 +379,11 @@ class ConfigData(BaseModel):
     generated: list[ConfigGeneratedData] = []
 
 
-def get_plugin_from_named_base_model(data: dict) -> str:
+def get_plugin_from_named_base_model(
+    data: dict | ConfigRootTask | ConfigShellTask | ConfigIconTask,
+) -> str:
+    if isinstance(data, (ConfigRootTask, ConfigShellTask, ConfigIconTask)):
+        return data.plugin
     name_and_specs = _NamedBaseModel.merge_name_and_specs(data)
     if name_and_specs.get("name", None) == "ROOT":
         return ConfigRootTask.plugin
@@ -434,8 +442,6 @@ class ConfigWorkflow(BaseModel):
     tasks: list[ConfigTask]
     data: ConfigData
     parameters: dict[str, list] = {}
-    data_dict: dict = {}
-    task_dict: dict = {}
 
     @field_validator("parameters", mode="before")
     @classmethod
@@ -450,19 +456,9 @@ def check_parameters_lists(cls, data) -> dict[str, list]:
                 raise TypeError(msg)
         return data
 
-    @model_validator(mode="after")
-    def build_internal_dicts(self) -> ConfigWorkflow:
-        self.data_dict = {data.name: data for data in self.data.available} | {
-            data.name: data for data in self.data.generated
-        }
-        self.task_dict = {task.name: task for task in self.tasks}
-        return self
-
     @model_validator(mode="after")
     def check_parameters(self) -> ConfigWorkflow:
-        task_data_list = self.tasks + self.data.generated
-        if self.data.available:
-            task_data_list.extend(self.data.available)
+        task_data_list = itertools.chain(self.tasks, self.data.generated, self.data.available)
         for item in task_data_list:
             for param_name in item.parameters:
                 if param_name not in self.parameters:
@@ -471,7 +467,69 @@ def check_parameters(self) -> ConfigWorkflow:
         return self
 
 
-def load_workflow_config(workflow_config: str) -> ConfigWorkflow:
+ITEM_T = typing.TypeVar("ITEM_T")
+
+
+def list_not_empty(value: list[ITEM_T]) -> list[ITEM_T]:
+    """Check that ``value`` holds at least one element and return it unchanged.
+
+    Raises:
+        ValueError: if ``value`` is empty.
+    """
+    if not value:
+        msg = "At least one element is required."
+        raise ValueError(msg)
+    return value
+
+
+class CanonicalWorkflow(BaseModel):
+    """Workflow config in which every field is required and cycles / tasks must be non-empty."""
+
+    name: str
+    rootdir: Path
+    cycles: Annotated[list[ConfigCycle], AfterValidator(list_not_empty)]
+    tasks: Annotated[list[ConfigTask], AfterValidator(list_not_empty)]
+    data: ConfigData
+    parameters: dict[str, list[Any]]
+    # name -> item lookup tables
+    data_dict: dict[str, ConfigAvailableData | ConfigGeneratedData]
+    task_dict: dict[str, ConfigTask]
+
+
+@functools.singledispatch
+def canonicalize(value: Any) -> Any:
+    """Fallback of the ``canonicalize`` dispatcher, reached when no implementation is registered.
+
+    Raises:
+        NotImplementedError: always; the message names the unsupported type.
+    """
+    msg = f"canonicalize is not implemented for values of type {type(value).__name__!r}."
+    raise NotImplementedError(msg)
+
+
+@canonicalize.register
+def canonicalize_workflow(value: ConfigWorkflow) -> CanonicalWorkflow:
+    """Build a ``CanonicalWorkflow`` from ``value``, adding name-keyed data and task lookups.
+
+    Raises:
+        ValueError: if the workflow name or root dir is missing (falsy).
+    """
+    if not value.name or not value.rootdir:
+        msg = "Workflow name and root dir required for canonicalization."
+        raise ValueError(msg)
+    return CanonicalWorkflow(
+        name=value.name,
+        rootdir=value.rootdir,
+        cycles=value.cycles,
+        tasks=value.tasks,
+        data=value.data,
+        parameters=value.parameters,
+        data_dict={data.name: data for data in value.data.available + value.data.generated},
+        task_dict={task.name: task for task in value.tasks},
+    )
+
+
+def load_workflow_config(workflow_config: str) -> CanonicalWorkflow:
     """
     Loads a python representation of a workflow config file.
 
@@ -491,4 +530,4 @@ def load_workflow_config(workflow_config: str) -> ConfigWorkflow:
 
     parsed_workflow.rootdir = config_path.resolve().parent
 
-    return parsed_workflow
+    return canonicalize_workflow(parsed_workflow)
diff --git a/tests/unit_tests/parsing/test_yaml_data_models.py b/tests/unit_tests/parsing/test_yaml_data_models.py
index 27464f60..84dc5c0e 100644
--- a/tests/unit_tests/parsing/test_yaml_data_models.py
+++ b/tests/unit_tests/parsing/test_yaml_data_models.py
@@ -1,17 +1,22 @@
+import pathlib
 import textwrap
 
 from sirocco.parsing import _yaml_data_models as models
 
 
-def test_workflow_test_internal_dicts():
-    testee = models.ConfigWorkflow(
-        cycles=[],
+def test_workflow_canonicalization():
+    config = models.ConfigWorkflow(
+        name="testee",
+        rootdir=pathlib.Path("foo"),
+        cycles=[models.ConfigCycle(minimal={"tasks": [models.ConfigCycleTask(a={})]})],
         tasks=[{"some_task": {"plugin": "shell"}}],
         data=models.ConfigData(
             available=[models.ConfigAvailableData(foo={})],
             generated=[models.ConfigGeneratedData(bar={})],
         ),
     )
+
+    testee = models.canonicalize(config)
     assert testee.data_dict["foo"].name == "foo"
     assert testee.data_dict["bar"].name == "bar"
     assert testee.task_dict["some_task"].name == "some_task"