Skip to content

Commit 68d7b07

Browse files
committed
feat(pipeline): add built-in list of wait keys
1 parent 9f47cf2 commit 68d7b07

File tree

1 file changed

+46
-10
lines changed

1 file changed

+46
-10
lines changed

caput/pipeline.py

+46-10
Original file line numberDiff line numberDiff line change
@@ -953,8 +953,9 @@ def _validate_task_inputs(self):
953953
for task_spec in self.task_specs:
954954
in_ = task_spec.get("in", None)
955955
requires = task_spec.get("requires", None)
956+
wait = task_spec.get("wait", None)
956957

957-
for key, value in (["in", in_], ["requires", requires]):
958+
for key, value in (["in", in_], ["requires", requires], ["wait", wait]):
958959
if value is None:
959960
continue
960961
if not isinstance(value, list):
@@ -970,7 +971,7 @@ def _get_task_from_spec(self, task_spec: dict):
970971
"""Set up a pipeline task from the spec given in the tasks list."""
971972
# Check that only the expected keys are in the task spec.
972973
for key in task_spec.keys():
973-
if key not in ["type", "params", "requires", "in", "out"]:
974+
if key not in ["type", "params", "requires", "in", "out", "wait"]:
974975
raise config.CaputConfigError(
975976
f"Task got an unexpected key '{key}' in 'tasks' list."
976977
)
@@ -1073,9 +1074,10 @@ def _check_duplicate(key0: str, key1: str, d0: dict, d1: dict):
10731074
requires = _check_duplicate("requires", "requires", task_spec, kwargs)
10741075
in_ = _check_duplicate("in", "in_", task_spec, kwargs)
10751076
out = _check_duplicate("out", "out", task_spec, kwargs)
1077+
wait = _check_duplicate("wait", "wait", task_spec, kwargs)
10761078

10771079
try:
1078-
task._setup_keys(in_, out, requires)
1080+
task._setup_keys(in_, out, requires, wait)
10791081
# Want to blindly catch errors
10801082
except Exception as e:
10811083
raise config.CaputConfigError(
@@ -1126,12 +1128,17 @@ class TaskBase(config.Reader):
11261128
If true, signals to the pipeline runner to make a call to `breakpoint` each
11271129
time this task is run. This will drop the interpreter into pdb, allowing for
11281130
interactive debugging of the current pipeline and task state. Default is False.
1131+
single_wait : bool
1132+
If true, keys in the wait queue only have to be received once, even if `next`
1133+
iterates multiple times. Otherwise, `wait` keys must be received prior to
1134+
each iteration of `next`. Default is False.
11291135
"""
11301136

11311137
broadcast_inputs = config.Property(proptype=bool, default=False)
11321138
limit_outputs = config.Property(proptype=int, default=None)
11331139
base_priority = config.Property(proptype=int, default=0)
11341140
breakpoint = config.Property(proptype=bool, default=False)
1141+
single_wait = config.Property(proptype=bool, default=False)
11351142

11361143
# Overridable Attributes
11371144
# -----------------------
@@ -1231,6 +1238,13 @@ def _pipeline_is_available(self):
12311238
# This task hasn't been initialized
12321239
return False
12331240

1241+
if self._wait is not None and not bool(
1242+
min((q.qsize() for q in self._wait), default=1)
1243+
):
1244+
# If wait flags are required and have not been received,
1245+
# this task can't be run
1246+
return False
1247+
12341248
if self._pipeline_state == "setup":
12351249
# True if all `requires` items have been provided
12361250
# This also returns True is `self._requires` is empty
@@ -1311,12 +1325,13 @@ def _from_config(cls, config):
13111325

13121326
return self
13131327

1314-
def _setup_keys(self, in_=None, out=None, requires=None):
1315-
"""Setup the 'requires', 'in' and 'out' keys for this task."""
1328+
def _setup_keys(self, in_=None, out=None, requires=None, wait=None):
1329+
"""Setup the 'in', 'out', 'requires', and 'wait' keys for this task."""
13161330
# Parse the task spec.
13171331
requires = _format_product_keys(requires)
13181332
in_ = _format_product_keys(in_)
13191333
out = _format_product_keys(out)
1334+
wait = _format_product_keys(wait)
13201335

13211336
# Inspect the `setup` method to see how many arguments it takes.
13221337
setup_argspec = inspect.getfullargspec(self.setup)
@@ -1380,6 +1395,11 @@ def _setup_keys(self, in_=None, out=None, requires=None):
13801395
# produce multiple values, queue up items which may be used in the
13811396
# future
13821397
self._in = [queue.Queue() for _ in range(n_in)]
1398+
# Store wait keys
1399+
self._wait_keys = wait
1400+
# Make a list with a queue for each wait key. Use queue because this can
1401+
# be buffered similarly to the inputs
1402+
self._wait = [queue.Queue() for _ in range(len(wait))]
13831403
# Store output keys
13841404
self._out_keys = out
13851405
# Keep track of the number of times this task has produced output
@@ -1434,6 +1454,7 @@ def _pipeline_advance_state(self):
14341454
)
14351455

14361456
self._in = None
1457+
self._wait = None
14371458
self._pipeline_state = "finish"
14381459

14391460
elif self._pipeline_state == "finish":
@@ -1476,6 +1497,10 @@ def _pipeline_next(self):
14761497
else: # noqa RET506
14771498
# Get the next set of data to be run.
14781499
args = tuple(in_.get() for in_ in self._in)
1500+
# If `wait` flags are not pinned, remove them
1501+
# from the queue
1502+
if not self.single_wait:
1503+
_ = [w.get() for w in self._wait]
14791504

14801505
# Call the next iteration of `next`. If it is done running,
14811506
# advance the task state and continue
@@ -1548,13 +1573,24 @@ def _pipeline_queue_product(self, key, product):
15481573
logger.debug(f"{self!s} stowing data product with key {key} for `in`.")
15491574
if self._in is None:
15501575
raise PipelineRuntimeError(
1551-
"Tried to queue 'in' data product, but `next()` already run."
1576+
f"Tried to queue 'in' data product, but task state is `{self._pipeline_state}`."
15521577
)
15531578

15541579
self._in[ii].put(product)
15551580

15561581
result = True
15571582

1583+
if key in self._wait_keys:
1584+
ii = self._wait_keys.index(key)
1585+
logger.debug(f"{self!s} setting wait flag with key {key}.")
1586+
if self._wait is None:
1587+
raise PipelineRuntimeError(
1588+
f"Tried to queue `wait` flag, but task state is `{self._pipeline_state}`."
1589+
)
1590+
# This data product isn't needed here - just have to record
1591+
# that it was received
1592+
self._wait[ii].put(True)
1593+
15581594
return result
15591595

15601596

@@ -2089,10 +2125,10 @@ def next(self, in_):
20892125
def _format_product_keys(keys):
20902126
"""Formats the pipeline task product keys.
20912127
2092-
In the pipeline config task list, the values of 'requires', 'in' and 'out'
2093-
are keys representing data products. This function gets that key from the
2094-
task's entry of the task list, defaults to zero, and ensures it's formated
2095-
as a sequence of strings.
2128+
In the pipeline config task list, the values of 'requires', 'in', 'out' and
2129+
'wait' are keys representing data products. This function gets that key
2130+
from the task's entry of the task list, defaults to zero, and ensures it's
2131+
formated as a sequence of strings.
20962132
"""
20972133
if keys is None:
20982134
return []

0 commit comments

Comments
 (0)