Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework command line specification #125

Merged
merged 22 commits into from
Mar 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/sirocco/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from ._tasks import IconTask, ShellTask
from .graph_items import Cycle, Data, GraphItem, Task
from .graph_items import AvailableData, Cycle, Data, GeneratedData, GraphItem, Task
from .workflow import Workflow

__all__ = ["Workflow", "GraphItem", "Data", "Task", "Cycle", "ShellTask", "IconTask"]
__all__ = ["Workflow", "GraphItem", "Data", "AvailableData", "GeneratedData", "Task", "Cycle", "ShellTask", "IconTask"]
8 changes: 2 additions & 6 deletions src/sirocco/core/_tasks/icon_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,13 +59,9 @@ def update_core_namelists_from_workflow(self):
"experimentStopDate": self.cycle_point.stop_date.isoformat() + "Z",
}
)
self.core_namelists["icon_master.namelist"]["master_nml"]["lrestart"] = any(
# NOTE: in_data[0] contains the actual data node and in_data[1] the port name
in_data[1] == "restart"
for in_data in self.inputs
)
self.core_namelists["icon_master.namelist"]["master_nml"]["lrestart"] = bool(self.inputs["restart"])

def dump_core_namelists(self, folder=None):
def dump_core_namelists(self, folder: str | Path | None = None):
if folder is not None:
folder = Path(folder)
folder.mkdir(parents=True, exist_ok=True)
Expand Down
33 changes: 19 additions & 14 deletions src/sirocco/core/graph_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass, field
from itertools import chain, product
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeAlias, TypeVar, cast
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeVar, cast

from sirocco.parsing.target_cycle import DateList, LagList, NoTargetCycle
from sirocco.parsing.yaml_data_models import (
Expand Down Expand Up @@ -46,21 +46,23 @@ class Data(ConfigBaseDataSpecs, GraphItem):

color: ClassVar[Color] = field(default="light_blue", repr=False)

available: bool

@classmethod
def from_config(cls, config: ConfigBaseData, coordinates: dict) -> Self:
return cls(
def from_config(cls, config: ConfigBaseData, coordinates: dict) -> AvailableData | GeneratedData:
data_class = AvailableData if isinstance(config, ConfigAvailableData) else GeneratedData
return data_class(
name=config.name,
type=config.type,
src=config.src,
available=isinstance(config, ConfigAvailableData),
coordinates=coordinates,
)


# contains the input data and its potential associated port
BoundData: TypeAlias = tuple[Data, str | None]
class AvailableData(Data):
    """Data node class selected by ``Data.from_config`` for ``ConfigAvailableData`` configs."""


class GeneratedData(Data):
    """Data node class selected by ``Data.from_config`` for configs that are not ``ConfigAvailableData``."""


@dataclass(kw_only=True)
Expand All @@ -70,7 +72,7 @@ class Task(ConfigBaseTaskSpecs, GraphItem):
plugin_classes: ClassVar[dict[str, type[Self]]] = field(default={}, repr=False)
color: ClassVar[Color] = field(default="light_red", repr=False)

inputs: list[BoundData] = field(default_factory=list)
inputs: dict[str, list[Data]] = field(default_factory=dict)
outputs: list[Data] = field(default_factory=list)
wait_on: list[Task] = field(default_factory=list)
config_rootdir: Path
Expand All @@ -85,6 +87,9 @@ def __init_subclass__(cls, **kwargs):
raise ValueError(msg)
Task.plugin_classes[cls.plugin] = cls

def input_data_nodes(self) -> Iterator[Data]:
    """Yield every input data node of this task, one port after the other.

    ``self.inputs`` maps a port name to the list of data nodes bound to that
    port; the nodes are yielded in the insertion order of the ports, and in
    list order within each port.
    """
    # Snapshot the per-port lists up front, as unpacking into chain() did.
    per_port_nodes = tuple(self.inputs.values())
    for nodes in per_port_nodes:
        yield from nodes

@classmethod
def from_config(
cls: type[Self],
Expand All @@ -95,11 +100,11 @@ def from_config(
datastore: Store,
graph_spec: ConfigCycleTask,
) -> Task:
inputs = [
(data_node, input_spec.port)
for input_spec in graph_spec.inputs
for data_node in datastore.iter_from_cycle_spec(input_spec, coordinates)
]
inputs: dict[str, list[Data]] = {}
for input_spec in graph_spec.inputs:
if input_spec.port not in inputs:
inputs[input_spec.port] = []
inputs[input_spec.port].extend(datastore.iter_from_cycle_spec(input_spec, coordinates))
outputs = [datastore[output_spec.name, coordinates] for output_spec in graph_spec.outputs]
if (plugin_cls := Task.plugin_classes.get(type(config).plugin, None)) is None:
msg = f"Plugin {type(config).plugin!r} is not supported."
Expand Down
29 changes: 15 additions & 14 deletions src/sirocco/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ def __init__(
self,
name: str,
config_rootdir: Path,
cycles: list[ConfigCycle],
tasks: list[ConfigTask],
data: ConfigData,
config_cycles: list[ConfigCycle],
config_tasks: list[ConfigTask],
config_data: ConfigData,
parameters: dict[str, list],
) -> None:
self.name: str = name
Expand All @@ -41,8 +41,10 @@ def __init__(
self.data: Store[Data] = Store()
self.cycles: Store[Cycle] = Store()

data_dict: dict[str, ConfigBaseData] = {data.name: data for data in chain(data.available, data.generated)}
task_dict: dict[str, ConfigTask] = {task.name: task for task in tasks}
config_data_dict: dict[str, ConfigBaseData] = {
data.name: data for data in chain(config_data.available, config_data.generated)
}
config_task_dict: dict[str, ConfigTask] = {task.name: task for task in config_tasks}

# Function to iterate over date and parameter combinations
def iter_coordinates(cycle_point: CyclePoint, param_refs: list[str]) -> Iterator[dict]:
Expand All @@ -52,28 +54,27 @@ def iter_coordinates(cycle_point: CyclePoint, param_refs: list[str]) -> Iterator
yield from (dict(zip(axes.keys(), x, strict=False)) for x in product(*axes.values()))

# 1 - create available data nodes
for available_data_config in data.available:
for available_data_config in config_data.available:
for coordinates in iter_coordinates(OneOffPoint(), available_data_config.parameters):
self.data.add(Data.from_config(config=available_data_config, coordinates=coordinates))

# 2 - create output data nodes
for cycle_config in cycles:
for cycle_config in config_cycles:
for cycle_point in cycle_config.cycling.iter_cycle_points():
for task_ref in cycle_config.tasks:
for data_ref in task_ref.outputs:
data_name = data_ref.name
data_config = data_dict[data_name]
data_config = config_data_dict[data_ref.name]
for coordinates in iter_coordinates(cycle_point, data_config.parameters):
self.data.add(Data.from_config(config=data_config, coordinates=coordinates))

# 3 - create cycles and tasks
for cycle_config in cycles:
for cycle_config in config_cycles:
cycle_name = cycle_config.name
for cycle_point in cycle_config.cycling.iter_cycle_points():
cycle_tasks = []
for task_graph_spec in cycle_config.tasks:
task_name = task_graph_spec.name
task_config = task_dict[task_name]
task_config = config_task_dict[task_name]
for coordinates in iter_coordinates(cycle_point, task_config.parameters):
task = Task.from_config(
config=task_config,
Expand Down Expand Up @@ -113,8 +114,8 @@ def from_config_workflow(cls: type[Self], config_workflow: ConfigWorkflow) -> Se
return cls(
name=config_workflow.name,
config_rootdir=config_workflow.rootdir,
cycles=config_workflow.cycles,
tasks=config_workflow.tasks,
data=config_workflow.data,
config_cycles=config_workflow.cycles,
config_tasks=config_workflow.tasks,
config_data=config_workflow.data,
parameters=config_workflow.parameters,
)
161 changes: 62 additions & 99 deletions src/sirocco/parsing/yaml_data_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import enum
import itertools
import re
import time
import typing
from dataclasses import dataclass, field
Expand Down Expand Up @@ -160,7 +161,7 @@ class TargetNodesBaseModel(_NamedBaseModel):


class ConfigCycleTaskInput(TargetNodesBaseModel):
port: str | None = None
port: str


class ConfigCycleTaskWaitOn(TargetNodesBaseModel):
Expand Down Expand Up @@ -273,58 +274,66 @@ class ConfigRootTask(ConfigBaseTask):
plugin: ClassVar[Literal["_root"]] = "_root"


# By using a frozen class we only need to validate on initialization
@dataclass(frozen=True)
class ShellCliArgument:
"""A holder for a CLI argument to simplify access.

Stores CLI arguments of the form "file", "--init", "{file}" or "{--init file}". These examples translate into
ShellCliArguments ShellCliArgument(name="file", references_data_item=False, cli_option_of_data_item=None),
ShellCliArgument(name="--init", references_data_item=False, cli_option_of_data_item=None),
ShellCliArgument(name="file", references_data_item=True, cli_option_of_data_item=None),
ShellCliArgument(name="file", references_data_item=True, cli_option_of_data_item="--init")
@dataclass(kw_only=True)
class ConfigShellTaskSpecs:
plugin: ClassVar[Literal["shell"]] = "shell"
port_pattern: ClassVar[re.Pattern] = field(default=re.compile(r"{PORT(\[sep=.+\])?::(.+?)}"), repr=False)
sep_pattern: ClassVar[re.Pattern] = field(default=re.compile(r"\[sep=(.+)\]"), repr=False)
src: str | None = None
command: str
env_source_files: list[str] = field(default_factory=list)

Attributes:
name: Name of the argument. For the examples it is "file", "--init", "file" and "file"
references_data_item: Specifies if the argument references a data item signified by enclosing it by curly
brackets.
cli_option_of_data_item: The CLI option associated to the data item.
"""
def resolve_ports(self, input_labels: dict[str, list[str]]) -> str:
"""Replace port placeholders in command string with provided input labels.

name: str
references_data_item: bool
cli_option_of_data_item: str | None = None
Returns a string corresponding to self.command with "{PORT::port_name}"
placeholders replaced by the content provided in the input_labels dict.
When multiple input nodes are linked to a single port (e.g. with
parameterized data or if the `when` keyword specifies a list of lags or
dates), the provided input labels are inserted with a separator
defaulting to a " ". Specifying an alternative separator, e.g. a comma,
is done via "{PORT[sep=,]::port_name}"

def __post_init__(self):
if self.cli_option_of_data_item is not None and not self.references_data_item:
msg = "data_item_option cannot be not None if cli_option_of_data_item is False"
raise ValueError(msg)
Examples:

@classmethod
def from_cli_argument(cls, arg: str) -> ShellCliArgument:
len_arg_with_option = 2
len_arg_no_option = 1
references_data_item = arg.startswith("{") and arg.endswith("}")
# remove curly brackets "{--init file}" -> "--init file"
arg_unwrapped = arg[1:-1] if arg.startswith("{") and arg.endswith("}") else arg

# "--init file" -> ["--init", "file"]
input_arg = arg_unwrapped.split()
if len(input_arg) != len_arg_with_option and len(input_arg) != len_arg_no_option:
msg = f"Expected argument of format {{data}} or {{option data}} but found {arg}"
raise ValueError(msg)
name = input_arg[0] if len(input_arg) == len_arg_no_option else input_arg[1]
cli_option_of_data_item = input_arg[0] if len(input_arg) == len_arg_with_option else None
return cls(name, references_data_item, cli_option_of_data_item)
>>> task_specs = ConfigShellTaskSpecs(
... command="./my_script {PORT::positionals} -l -c --verbose 2 --arg {PORT::my_arg}"
... )
>>> task_specs.resolve_ports(
... {"positionals": ["input_1", "input_2"], "my_arg": ["input_3"]}
... )
'./my_script input_1 input_2 -l -c --verbose 2 --arg input_3'

>>> task_specs = ConfigShellTaskSpecs(
... command="./my_script {PORT::positionals} --multi_arg {PORT[sep=,]::multi_arg}"
... )
>>> task_specs.resolve_ports(
... {"positionals": ["input_1", "input_2"], "multi_arg": ["input_3", "input_4"]}
... )
'./my_script input_1 input_2 --multi_arg input_3,input_4'

@dataclass(kw_only=True)
class ConfigShellTaskSpecs:
plugin: ClassVar[Literal["shell"]] = "shell"
command: str = ""
cli_arguments: list[ShellCliArgument] = field(default_factory=list)
env_source_files: list[str] = field(default_factory=list)
src: str | None = None
>>> task_specs = ConfigShellTaskSpecs(
... command="./my_script --input {PORT[sep= --input ]::repeat_input}"
... )
>>> task_specs.resolve_ports({"repeat_input": ["input_1", "input_2", "input_3"]})
'./my_script --input input_1 --input input_2 --input input_3'
"""
cmd = self.command
for port_match in self.port_pattern.finditer(cmd):
if (port_name := port_match.group(2)) is None:
msg = f"Wrong port specification: {port_match.group(0)}"
raise ValueError(msg)
if (sep := port_match.group(1)) is None:
arg_sep = " "
else:
if (sep_match := self.sep_pattern.match(sep)) is None:
msg = "Wrong separator specification: sep"
raise ValueError(msg)
if (arg_sep := sep_match.group(1)) is None:
msg = "Wrong separator specification: sep"
raise ValueError(msg)
cmd = cmd.replace(port_match.group(0), arg_sep.join(input_labels[port_name]))
return cmd


class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs):
Expand All @@ -340,75 +349,26 @@ class ConfigShellTask(ConfigBaseTask, ConfigShellTaskSpecs):
... '''
... my_task:
... plugin: shell
... command: my_script.sh
... src: post_run_scripts
... cli_arguments: "-n 1024 {current_sim_output}"
... command: "my_script.sh -n 1024 {PORT::current_sim_output}"
... src: post_run_scripts/my_script.sh
... env_source_files: "env.sh"
... walltime: 00:01:00
... '''
... ),
... )
>>> my_task.cli_arguments[0]
ShellCliArgument(name='-n', references_data_item=False, cli_option_of_data_item=None)
>>> my_task.cli_arguments[1]
ShellCliArgument(name='1024', references_data_item=False, cli_option_of_data_item=None)
>>> my_task.cli_arguments[2]
ShellCliArgument(name='current_sim_output', references_data_item=True, cli_option_of_data_item=None)
>>> my_task.env_source_files
['env.sh']
>>> my_task.walltime.tm_min
1
"""

command: str = ""
cli_arguments: list[ShellCliArgument] = Field(default_factory=list)
env_source_files: list[str] = Field(default_factory=list)

@field_validator("cli_arguments", mode="before")
@classmethod
def validate_cli_arguments(cls, value: str) -> list[ShellCliArgument]:
return cls.parse_cli_arguments(value)

@field_validator("env_source_files", mode="before")
@classmethod
def validate_env_source_files(cls, value: str | list[str]) -> list[str]:
    """Normalize ``env_source_files`` so a single string becomes a one-element list.

    Any value that is not a ``str`` is returned unchanged.
    """
    if isinstance(value, str):
        return [value]
    return value

@staticmethod
def split_cli_arguments(cli_arguments: str) -> list[str]:
"""Splits the CLI arguments into a list of separate entities.

Splits the CLI arguments by whitespaces except if the whitespace is contained within curly brackets. For example
the string
"-D --CMAKE_CXX_COMPILER=${CXX_COMPILER} {--init file}"
will be splitted into the list
["-D", "--CMAKE_CXX_COMPILER=${CXX_COMPILER}", "{--init file}"]
"""

nb_open_curly_brackets = 0
last_split_idx = 0
splits = []
for i, char in enumerate(cli_arguments):
if char == " " and not nb_open_curly_brackets:
# we ommit the space in the splitting therefore we only store up to i but move the last_split_idx to i+1
splits.append(cli_arguments[last_split_idx:i])
last_split_idx = i + 1
elif char == "{":
nb_open_curly_brackets += 1
elif char == "}":
if nb_open_curly_brackets == 0:
msg = f"Invalid input for cli_arguments. Found a closing curly bracket before an opening in {cli_arguments!r}"
raise ValueError(msg)
nb_open_curly_brackets -= 1

if last_split_idx != len(cli_arguments):
splits.append(cli_arguments[last_split_idx : len(cli_arguments)])
return splits

@staticmethod
def parse_cli_arguments(cli_arguments: str) -> list[ShellCliArgument]:
    """Convert a raw CLI argument string into ``ShellCliArgument`` objects.

    The string is split with ``ConfigShellTask.split_cli_arguments`` and every
    resulting piece is passed to ``ShellCliArgument.from_cli_argument``.
    """
    raw_args = ConfigShellTask.split_cli_arguments(cli_arguments)
    return list(map(ShellCliArgument.from_cli_argument, raw_args))


@dataclass(kw_only=True)
class NamelistSpec:
Expand Down Expand Up @@ -662,6 +622,7 @@ class ConfigWorkflow(BaseModel):
... tasks:
... - task_a:
... plugin: shell
... command: "some_command"
... data:
... available:
... - foo:
Expand All @@ -681,7 +642,9 @@ class ConfigWorkflow(BaseModel):
... name="minimal",
... rootdir=Path("/location/of/config/file"),
... cycles=[ConfigCycle(minimal_cycle={"tasks": [ConfigCycleTask(task_a={})]})],
... tasks=[ConfigShellTask(task_a={"plugin": "shell"})],
... tasks=[
... ConfigShellTask(task_a={"plugin": "shell", "command": "some_command"})
... ],
... data=ConfigData(
... available=[
... ConfigAvailableData(name="foo", type=DataType.FILE, src="foo.txt")
Expand Down
Loading