diff --git a/server/common/config/config_model.py b/server/common/config/config_model.py
index 5cb7f6284..810e75b40 100644
--- a/server/common/config/config_model.py
+++ b/server/common/config/config_model.py
@@ -13,7 +13,7 @@
 from typing import Dict, List, Optional, Union
 from urllib.parse import quote_plus
 
-from pydantic import BaseModel, Extra, Field, model_validator, validator
+from pydantic.v1 import BaseModel, Extra, Field, root_validator, validator
 
 from server.common.utils.data_locator import discover_s3_region_name
 from server.common.utils.utils import custom_format_warning, find_available_port, is_port_available
@@ -24,13 +24,11 @@ class CspDirectives(BaseModel):
     script_src: Union[str, List[str]] = Field(default_factory=list, alias="script-src")
     connect_src: Union[str, List[str]] = Field(default_factory=list, alias="connect-src")
 
-    @model_validator(mode="before")
-    @classmethod
+    @root_validator(skip_on_failure=True)
     def string_to_list(cls, values):
-        if isinstance(values, dict):
-            for key, value in values.items():
-                if isinstance(value, str):
-                    values[key] = [value]
+        for key, value in values.items():
+            if isinstance(value, str):
+                values[key] = [value]
         return values
 
 
@@ -44,48 +42,52 @@ class ServerApp(BaseModel):
     flask_secret_key: Optional[str]
     generate_cache_control_headers: bool
     server_timing_headers: bool
-    csp_directives: Optional[CspDirectives] = Field(
-        default_factory=lambda: {"img-src": [], "script-src": [], "connect-src": []}
-    )
+    csp_directives: Optional[CspDirectives]
     api_base_url: Optional[str]
     web_base_url: Optional[str]
 
-    @model_validator(mode="after")
-    def check_port(self) -> "ServerApp":
-        if self.port:
-            if not is_port_available(self.host, self.port):
-                raise ValueError(f"The port selected {self.port} is in use or invalid, please configure an open port.")
+    @root_validator(skip_on_failure=True)
+    def check_port(cls, values):
+        host = values["host"]
+        port = values.get("port")
+        if port:
+            if not is_port_available(host, port):
+                raise ValueError(f"The port selected {port} is in use or invalid, please configure an open port.")
         else:
-            self.port = find_available_port(self.host)
-        return self
-
-    @model_validator(mode="after")
-    def check_debug(self) -> "ServerApp":
-        if self.debug:
-            self.verbose = True
-            self.open_browser = False
+            values["port"] = find_available_port(host)
+        return values
+
+    @root_validator(skip_on_failure=True)
+    def check_debug(cls, values):
+        if values["debug"]:
+            values["verbose"] = True
+            values["open_browser"] = False
         else:
             warnings.formatwarning = custom_format_warning
-        return self
-
-    @model_validator(mode="after")
-    def check_api_base_url(self) -> "ServerApp":
-        if self.api_base_url == "local":
-            self.api_base_url = f"http://{self.host}:{self.port}"
-        if self.api_base_url and self.api_base_url.endswith("/"):
-            self.api_base_url = self.api_base_url[:-1]
-        return self
-
-    @model_validator(mode="after")
-    def check_web_base_url(self) -> "ServerApp":
-        if self.web_base_url is None:
-            self.web_base_url = self.api_base_url
-        if self.web_base_url:
-            if self.web_base_url == "local":
-                self.web_base_url = f"http://{self.host}:{self.port}"
-            elif self.web_base_url.endswith("/"):
-                self.web_base_url = self.web_base_url[:-1]
-        return self
+        return values
+
+    @root_validator(skip_on_failure=True)
+    def check_api_base_url(cls, values):
+        api_base_url = values.get("api_base_url")
+        if api_base_url == "local":
+            api_base_url = f"http://{values['host']}:{values['port']}"
+        if api_base_url and api_base_url.endswith("/"):
+            api_base_url = api_base_url[:-1]
+        values["api_base_url"] = api_base_url
+        return values
+
+    @root_validator(skip_on_failure=True)
+    def check_web_base_url(cls, values):
+        web_base_url = values["web_base_url"]
+        if web_base_url is None:
+            web_base_url = values["api_base_url"]
+        if web_base_url:
+            if web_base_url == "local":
+                web_base_url = f"http://{values['host']}:{values['port']}"
+            elif web_base_url.endswith("/"):
+                web_base_url = web_base_url[:-1]
+        values["web_base_url"] = web_base_url
+        return values
 
     @validator("verbose")
     def check_verbose(cls, value):
@@ -95,6 +97,10 @@ def check_verbose(cls, value):
             sys.tracebacklimit = 1000
         return value
 
+    @validator("csp_directives")
+    def check_csp_directives(cls, value):
+        return value if value else {}
+
 
 class DatarootValue(BaseModel):
     base_url: str
@@ -122,21 +128,21 @@ class MultiDataset(BaseModel):
     dataroots: Optional[Dict[str, DatarootValue]] = {}
     index: Union[bool, str] = Field(default=False)
 
-    @model_validator(mode="after")
-    def check_dataroot(self) -> "MultiDataset":
-        if all([self.dataroot, self.dataroots]):
+    @root_validator(skip_on_failure=True)
+    def check_dataroot(cls, values):
+        if all([values["dataroot"], values["dataroots"]]):
             raise ValueError("Must set dataroot or dataroots.")
-        elif self.dataroot:
-            default = dict(base_url="d", dataroot=self.dataroot)
-            self.dataroots["d"] = DatarootValue(**default)
-            self.dataroot = None
+        elif values["dataroot"]:
+            default = dict(base_url="d", dataroot=values["dataroot"])
+            values["dataroots"]["d"] = DatarootValue(**default)
+            values["dataroot"] = None
 
         # verify all the base_urls are unique
-        base_urls = [d.base_url for d in self.dataroots.values()]
+        base_urls = [d.base_url for d in values["dataroots"].values()]
         if len(base_urls) > len(set(base_urls)):
             raise ValueError("error in multi_dataset__dataroot: base_urls must be unique")
         # TODO check that at least one dataroot is set. Then we can remove AppConfig.handle_data_source.
-        return self
+        return values
 
 
 class DataLocator(BaseModel):
@@ -175,10 +181,10 @@ class Server(BaseModel):
     adaptor: Adaptor
     limits: Limits
 
-    @model_validator(mode="after")
-    def check_data_locator(self) -> "Server":
-        if self.data_locator.s3_region_name is True:
-            path = self.multi_dataset.dataroots or self.multi_dataset.dataroot
+    @root_validator(skip_on_failure=True)
+    def check_data_locator(cls, values):
+        if values["data_locator"].s3_region_name is True:
+            path = values["multi_dataset"].dataroots or values["multi_dataset"].dataroot
             # except KeyError as ex:
             #     return values
             if isinstance(path, dict):
@@ -193,16 +199,18 @@ def check_data_locator(self) -> "Server":
             region_name = discover_s3_region_name(path)
             if region_name is None:
                 raise ValueError(f"Unable to discover s3 region name from {path}")
-            self.data_locator.s3_region_name = region_name
+            values["data_locator"].s3_region_name = region_name
         else:
-            self.data_locator.s3_region_name = None
-        return self
-
-    @model_validator(mode="after")
-    def check_cxg_adaptor(self) -> "Server":
-        if not self.adaptor.cxg_adaptor.tiledb_ctx.vfs_s3_region and isinstance(self.data_locator.s3_region_name, str):
-            self.adaptor.cxg_adaptor.tiledb_ctx.vfs_s3_region = self.data_locator.s3_region_name
-        return self
+            values["data_locator"].s3_region_name = None
+        return values
+
+    @root_validator(skip_on_failure=True)
+    def check_cxg_adaptor(cls, values):
+        if not values["adaptor"].cxg_adaptor.tiledb_ctx.vfs_s3_region and isinstance(
+            values["data_locator"].s3_region_name, str
+        ):
+            values["adaptor"].cxg_adaptor.tiledb_ctx.vfs_s3_region = values["data_locator"].s3_region_name
+        return values
 
 
 class ScriptsItem(BaseModel):
diff --git a/server/tests/unit/common/config/test_app_config.py b/server/tests/unit/common/config/test_app_config.py
index 2ec638796..08fd57076 100644
--- a/server/tests/unit/common/config/test_app_config.py
+++ b/server/tests/unit/common/config/test_app_config.py
@@ -113,7 +113,6 @@ def test_configfile_no_server_section(self):
         changes = self.compare_configs(app_config, default_config)
         self.assertCountEqual(changes, [("default_dataset__app__about_legal_tos", "expected_value", None)])
 
-    @unittest.skip("Configuration needs to be updated to satisfy v2 pydantic validation criteria")
     def test_csp_directives(self):
         default_config = AppConfig()
         with tempfile.TemporaryDirectory() as tempdir:
diff --git a/server/tests/unit/common/config/test_base_config.py b/server/tests/unit/common/config/test_base_config.py
index 806b88b3b..70960c492 100644
--- a/server/tests/unit/common/config/test_base_config.py
+++ b/server/tests/unit/common/config/test_base_config.py
@@ -31,7 +31,6 @@ def test_mapping_creation_returns_map_of_server_and_dataset_config(self):
         self.assertIsNotNone(mapping["server__app__verbose"])
         self.assertIsNotNone(mapping["default_dataset__presentation__max_categories"])
 
-    @unittest.skip("Configuration needs to be updated to satisfy v2 pydantic validation criteria")
     def test_changes_from_default_returns_list_of_nondefault_config_values(self):
        config = self.get_config(verbose="true", lfc_cutoff=0.05)
        changes = config.changes_from_default()
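
Note for reviewers: the standalone sketch below is not part of the diff; ExampleDirectives and ExampleApp are hypothetical stand-ins for CspDirectives and ServerApp. It illustrates the pydantic.v1 pattern this change migrates to: a root_validator is a classmethod that receives the already-parsed fields as a values dict and must return that dict, which is why the validators above read with values["host"] / values.get("port") and write results back before returning, instead of mutating self as with pydantic v2's model_validator(mode="after").

# Standalone sketch (assumed names, not part of this change) of the
# pydantic.v1 root_validator pattern adopted in config_model.py.
from typing import List, Optional, Union

from pydantic.v1 import BaseModel, Field, root_validator


class ExampleDirectives(BaseModel):
    # Hypothetical stand-in for CspDirectives: bare strings are wrapped in lists.
    img_src: Union[str, List[str]] = Field(default_factory=list, alias="img-src")

    @root_validator(skip_on_failure=True)
    def string_to_list(cls, values):
        # Runs after per-field validation; every entry in `values` is a
        # validated field value keyed by field name.
        for key, value in values.items():
            if isinstance(value, str):
                values[key] = [value]
        return values


class ExampleApp(BaseModel):
    # Hypothetical stand-in for ServerApp's URL defaulting and normalisation.
    api_base_url: Optional[str] = None
    web_base_url: Optional[str] = None

    @root_validator(skip_on_failure=True)
    def default_web_base_url(cls, values):
        # Fall back to api_base_url and strip a trailing slash, mirroring
        # the check_web_base_url validator above.
        web = values.get("web_base_url") or values.get("api_base_url")
        if web and web.endswith("/"):
            web = web[:-1]
        values["web_base_url"] = web
        return values


if __name__ == "__main__":
    print(ExampleDirectives(**{"img-src": "'self'"}).img_src)  # ["'self'"]
    print(ExampleApp(api_base_url="http://localhost:5005/").web_base_url)  # http://localhost:5005

With skip_on_failure=True the root validator is skipped when any per-field validation has already failed, so it never sees a partially populated values dict.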