Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions gaps/cli/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,16 @@ def preprocess_collect_config(
files = collect_pattern
if files == "PIPELINE":
files = parse_previous_status(project_dir, command_name)
files = [re.sub(f"{TAG}\\d+", "*", fname) for fname in files]
files = [re.sub(f"{TAG}\\d+", f"{TAG}*", fname) for fname in files]

if isinstance(files, str):
files = [files]

if isinstance(files, abc.Sequence):
files = {pattern.replace("*", ""): pattern for pattern in files}
files = {
pattern.replace(f"{TAG}*", "").replace("*", ""): pattern
for pattern in files
}

files = [
(
Expand Down
4 changes: 2 additions & 2 deletions tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,7 @@ def test_cli_monitor(
assert "collect-run_collect_pattern" in collected_outputs.attrs
assert (
Path(collected_outputs.attrs["collect-run_collect_pattern"])
== tmp_cwd / file_pattern
== tmp_cwd / file_pattern.replace("*", f"{TAG}*")
)

profiles = manual_collect(data_dir / file_pattern, "cf_profile")
Expand Down Expand Up @@ -487,7 +487,7 @@ def test_cli_background(
assert "collect-run_collect_pattern" in collected_outputs.attrs
assert (
Path(collected_outputs.attrs["collect-run_collect_pattern"])
== tmp_cwd / file_pattern
== tmp_cwd / file_pattern.replace("*", f"{TAG}*")
)

profiles = manual_collect(data_dir / file_pattern, "cf_profile")
Expand Down
35 changes: 34 additions & 1 deletion tests/cli/test_cli_preprocesing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""

import json
import glob
from pathlib import Path

import pytest
Expand All @@ -16,6 +17,7 @@
preprocess_collect_config,
split_project_points_into_ranges,
)
from gaps.cli.config import TAG
from gaps.exceptions import gapsConfigError
from gaps.warn import gapsWarning

Expand Down Expand Up @@ -111,7 +113,38 @@ def test_preprocess_collect_config_pipeline_input(tmp_path):
assert len(config["_pattern"]) == 2
for out_fp, pattern in zip(config["_out_path"], config["_pattern"]):
assert any(name in out_fp for name in allowed_out_fn)
assert out_fp == pattern.replace("*", "")
assert out_fp == pattern.replace(f"{TAG}*", "")


def test_preprocess_collect_config_pipeline_input_ignores_untagged_file(
tmp_path,
):
"""Test that PIPELINE collection patterns do not match untagged files."""
config_fp = tmp_path / "pipe_config.json"
with open(config_fp, "w") as file_:
json.dump(SAMPLE_CONFIG, file_)

(tmp_path / "config.json").touch()
(tmp_path / "collect_config.json").touch()

Pipeline(config_fp)

job_file = tmp_path / "output_file_j0.h5"
job_file.touch()
(tmp_path / "output_file.h5").touch()
Status.make_single_job_file(
tmp_path,
pipeline_step="run",
job_name="test_0",
attrs={StatusField.OUT_FILE: job_file.as_posix()},
)

config = preprocess_collect_config({}, tmp_path, "collect-run")

matched_files = sorted(
Path(path) for path in glob.glob(config["_pattern"][0])
)
assert matched_files == [job_file]


def test_split_project_points_into_ranges():
Expand Down
Loading