diff --git a/promptlens/models/config.py b/promptlens/models/config.py index 7c85798..2201dcf 100644 --- a/promptlens/models/config.py +++ b/promptlens/models/config.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator class ProviderConfig(BaseModel): @@ -83,6 +83,9 @@ class ExecutionConfig(BaseModel): timeout_seconds: int = 60 +SUPPORTED_OUTPUT_FORMATS = {"html", "json", "csv", "md"} + + class OutputConfig(BaseModel): """Configuration for output settings. @@ -96,6 +99,26 @@ class OutputConfig(BaseModel): formats: List[str] = Field(default_factory=lambda: ["html", "json"]) run_name: Optional[str] = None + @field_validator("formats") + @classmethod + def validate_formats(cls, formats: List[str]) -> List[str]: + """Ensure output formats are supported and normalized.""" + normalized_formats = [format_name.lower().strip() for format_name in formats] + invalid_formats = [ + format_name for format_name in normalized_formats + if format_name not in SUPPORTED_OUTPUT_FORMATS + ] + + if invalid_formats: + supported = ", ".join(sorted(SUPPORTED_OUTPUT_FORMATS)) + invalid = ", ".join(sorted(set(invalid_formats))) + raise ValueError( + f"Unsupported output formats: {invalid}. " + f"Supported formats: {supported}" + ) + + return normalized_formats + class RunConfig(BaseModel): """Complete run configuration. diff --git a/tests/test_output_format_validation.py b/tests/test_output_format_validation.py new file mode 100644 index 0000000..e343248 --- /dev/null +++ b/tests/test_output_format_validation.py @@ -0,0 +1,17 @@ +"""Tests for output format validation hardening.""" + +import pytest +from pydantic import ValidationError + +from promptlens.models.config import OutputConfig + + +def test_output_formats_are_normalized() -> None: + config = OutputConfig(formats=[" JSON ", "Html", "md"]) + + assert config.formats == ["json", "html", "md"] + + +def test_output_formats_reject_unsupported_values() -> None: + with pytest.raises(ValidationError, match="Unsupported output formats"): + OutputConfig(formats=["json", "xml"])