Skip to content

Commit

Permalink
fix: Allow less restrictive values for parameters in Pipeline configu…
Browse files Browse the repository at this point in the history
…rations (#3345)

* fix: Allow arbitrary values for parameters in Pipeline configurations

* Add test

* Adapt expected error message in tests

* Fix bug

* Fix bug on checking JSON

* Remove test cases that previously tested if error was thrown

* Change encoding in test

* Restrict possible values

* Re-add tests

* Re-add tests

* Add value flag to list elements
  • Loading branch information
bogdankostic authored and masci committed Oct 10, 2022
1 parent ce36be8 commit 6012da4
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 6 deletions.
13 changes: 7 additions & 6 deletions haystack/pipelines/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
logger = logging.getLogger(__name__)


# Matches one or more ASCII letters, digits, or any of: - _ / \ . : *
# NOTE(review): in this diff this looks like the pre-change pattern that the
# key/value patterns below replace -- confirm before relying on it.
VALID_INPUT_REGEX = re.compile(r"^[-a-zA-Z0-9_/\\.:*]+$")
# Pattern for configuration keys: one or more word characters (`\w`, which for
# str patterns also covers non-ASCII letters such as umlauts) or any of: - / \ . : *
VALID_KEY_REGEX = re.compile(r"^[-\w/\\.:*]+$")
# Pattern for configuration values: everything a key allows, plus spaces and
# square brackets.
VALID_VALUE_REGEX = re.compile(r"^[-\w/\\.:* \[\]]+$")
# NOTE(review): presumably the node names accepted as pipeline roots -- the
# usage is not visible in this excerpt.
VALID_ROOT_NODES = ["Query", "File"]


Expand Down Expand Up @@ -100,15 +101,14 @@ def read_pipeline_config_from_yaml(path: Path) -> Dict[str, Any]:
JSON_FIELDS = ["custom_query"] # ElasticsearchDocumentStore.custom_query


def validate_config_strings(pipeline_config: Any):
def validate_config_strings(pipeline_config: Any, is_value: bool = False):
"""
Ensures that strings used in the pipelines configuration
contain only alphanumeric characters and basic punctuation.
"""
try:
if isinstance(pipeline_config, dict):
for key, value in pipeline_config.items():

# FIXME find a better solution
# Some nodes take parameters that expect JSON input,
# like `ElasticsearchDocumentStore.custom_query`
Expand All @@ -125,14 +125,15 @@ def validate_config_strings(pipeline_config: Any):
raise PipelineConfigError(f"'{pipeline_config}' does not contain valid JSON.")
else:
validate_config_strings(key)
validate_config_strings(value)
validate_config_strings(value, is_value=True)

elif isinstance(pipeline_config, list):
for value in pipeline_config:
validate_config_strings(value)
validate_config_strings(value, is_value=True)

else:
if not VALID_INPUT_REGEX.match(str(pipeline_config)):
valid_regex = VALID_VALUE_REGEX if is_value else VALID_KEY_REGEX
if not valid_regex.match(str(pipeline_config)):
raise PipelineConfigError(
f"'{pipeline_config}' is not a valid variable name or value. "
"Use alphanumeric characters or dash, underscore and colon only."
Expand Down
39 changes: 39 additions & 0 deletions test/pipelines/test_pipeline_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,45 @@ def test_load_yaml_disconnected_component(tmp_path):
assert not pipeline.get_node("retriever")


def test_load_yaml_unusual_chars_in_values(tmp_path):
    """
    Parameter values containing spaces, square brackets, backslashes and
    non-ASCII letters must survive loading a pipeline from YAML unchanged.
    """

    class DummyNode(BaseComponent):
        # Minimal component that only stores the two parameters under test.
        outgoing_edges = 1

        def __init__(self, space_param, non_alphanumeric_param):
            super().__init__()
            self.space_param = space_param
            self.non_alphanumeric_param = non_alphanumeric_param

        def run(self):
            raise NotImplementedError

        def run_batch(self):
            raise NotImplementedError

    config_path = tmp_path / "tmp_config.yml"
    with open(config_path, "w", encoding="utf-8") as tmp_file:
        # Plain string (no placeholders, so no f-prefix). The backslashes are
        # doubled so that a literal `\[ümlaut\]` reaches the file without
        # relying on Python's deprecated handling of the invalid escape
        # sequences `\[` and `\]`.
        tmp_file.write(
            """
            version: '1.9.0'
            components:
            - name: DummyNode
              type: DummyNode
              params:
                space_param: with space
                non_alphanumeric_param: \\[ümlaut\\]
            pipelines:
            - name: indexing
              nodes:
              - name: DummyNode
                inputs: [File]
        """
        )
    pipeline = Pipeline.load_from_yaml(path=config_path)
    dummy_node = pipeline.components["DummyNode"]
    assert dummy_node.space_param == "with space"
    assert dummy_node.non_alphanumeric_param == "\\[ümlaut\\]"


def test_save_yaml(tmp_path):
pipeline = Pipeline()
pipeline.add_node(MockRetriever(), name="retriever", inputs=["Query"])
Expand Down

0 comments on commit 6012da4

Please sign in to comment.