Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
atarashansky committed Dec 12, 2024
1 parent f4d042d commit 9b3fc04
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 67 deletions.
138 changes: 73 additions & 65 deletions server/common/config/config_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from typing import Dict, List, Optional, Union
from urllib.parse import quote_plus

from pydantic import BaseModel, Extra, Field, model_validator, validator
from pydantic.v1 import BaseModel, Extra, Field, root_validator, validator

from server.common.utils.data_locator import discover_s3_region_name
from server.common.utils.utils import custom_format_warning, find_available_port, is_port_available
Expand All @@ -24,13 +24,11 @@ class CspDirectives(BaseModel):
script_src: Union[str, List[str]] = Field(default_factory=list, alias="script-src")
connect_src: Union[str, List[str]] = Field(default_factory=list, alias="connect-src")

@model_validator(mode="before")
@classmethod
@root_validator(skip_on_failure=True)
def string_to_list(cls, values):
if isinstance(values, dict):
for key, value in values.items():
if isinstance(value, str):
values[key] = [value]
for key, value in values.items():
if isinstance(value, str):
values[key] = [value]
return values


Expand All @@ -44,48 +42,52 @@ class ServerApp(BaseModel):
flask_secret_key: Optional[str]
generate_cache_control_headers: bool
server_timing_headers: bool
csp_directives: Optional[CspDirectives] = Field(
default_factory=lambda: {"img-src": [], "script-src": [], "connect-src": []}
)
csp_directives: Optional[CspDirectives]
api_base_url: Optional[str]
web_base_url: Optional[str]

@model_validator(mode="after")
def check_port(self) -> "ServerApp":
if self.port:
if not is_port_available(self.host, self.port):
raise ValueError(f"The port selected {self.port} is in use or invalid, please configure an open port.")
@root_validator(skip_on_failure=True)
def check_port(cls, values):
host = values["host"]
port = values.get("port")
if port:
if not is_port_available(host, port):
raise ValueError(f"The port selected {port} is in use or invalid, please configure an open port.")
else:
self.port = find_available_port(self.host)
return self

@model_validator(mode="after")
def check_debug(self) -> "ServerApp":
if self.debug:
self.verbose = True
self.open_browser = False
values["port"] = find_available_port(host)
return values

@root_validator(skip_on_failure=True)
def check_debug(cls, values):
if values["debug"]:
values["verbose"] = True
values["open_brower"] = False
else:
warnings.formatwarning = custom_format_warning
return self

@model_validator(mode="after")
def check_api_base_url(self) -> "ServerApp":
if self.api_base_url == "local":
self.api_base_url = f"http://{self.host}:{self.port}"
if self.api_base_url and self.api_base_url.endswith("/"):
self.api_base_url = self.api_base_url[:-1]
return self

@model_validator(mode="after")
def check_web_base_url(self) -> "ServerApp":
if self.web_base_url is None:
self.web_base_url = self.api_base_url
if self.web_base_url:
if self.web_base_url == "local":
self.web_base_url = f"http://{self.host}:{self.port}"
elif self.web_base_url.endswith("/"):
self.web_base_url = self.web_base_url[:-1]
return self
return values

@root_validator(skip_on_failure=True)
def check_api_base_url(cls, values):
api_base_url = values.get("api_base_url")
if api_base_url == "local":
api_base_url = f"http://{values['host']}:{values['port']}"
if api_base_url and api_base_url.endswith("/"):
api_base_url = api_base_url[:-1]
values["api_base_url"] = api_base_url
return values

@root_validator(skip_on_failure=True)
def check_web_base_url(cls, values):
web_base_url = values["web_base_url"]
if web_base_url is None:
web_base_url = values["api_base_url"]
if web_base_url:
if web_base_url == "local":
web_base_url = f"http://{values['host']}:{values['port']}"
elif web_base_url.endswith("/"):
web_base_url = web_base_url[:-1]
values["web_base_url"] = web_base_url
return values

@validator("verbose")
def check_verbose(cls, value):
Expand All @@ -95,6 +97,10 @@ def check_verbose(cls, value):
sys.tracebacklimit = 1000
return value

@validator("csp_directives")
def check_csp_directives(cls, value):
return value if value else {}


class DatarootValue(BaseModel):
base_url: str
Expand Down Expand Up @@ -122,21 +128,21 @@ class MultiDataset(BaseModel):
dataroots: Optional[Dict[str, DatarootValue]] = {}
index: Union[bool, str] = Field(default=False)

@model_validator(mode="after")
def check_dataroot(self) -> "MultiDataset":
if all([self.dataroot, self.dataroots]):
@root_validator(skip_on_failure=True)
def check_dataroot(cls, values):
if all([values["dataroot"], values["dataroots"]]):
raise ValueError("Must set dataroot or dataroots.")
elif self.dataroot:
default = dict(base_url="d", dataroot=self.dataroot)
self.dataroots["d"] = DatarootValue(**default)
self.dataroot = None
elif values["dataroot"]:
default = dict(base_url="d", dataroot=values["dataroot"])
values["dataroots"]["d"] = DatarootValue(**default)
values["dataroot"] = None

# verify all the base_urls are unique
base_urls = [d.base_url for d in self.dataroots.values()]
base_urls = [d.base_url for d in values["dataroots"].values()]
if len(base_urls) > len(set(base_urls)):
raise ValueError("error in multi_dataset__dataroot: base_urls must be unique")
# TODO check that at least one dataroot is set. Then we can remove AppConfig.handle_data_source.
return self
return values


class DataLocator(BaseModel):
Expand Down Expand Up @@ -175,10 +181,10 @@ class Server(BaseModel):
adaptor: Adaptor
limits: Limits

@model_validator(mode="after")
def check_data_locator(self) -> "Server":
if self.data_locator.s3_region_name is True:
path = self.multi_dataset.dataroots or self.multi_dataset.dataroot
@root_validator(skip_on_failure=True)
def check_data_locator(cls, values):
if values["data_locator"].s3_region_name is True:
path = values["multi_dataset"].dataroots or values["multi_dataset"].dataroot
# except KeyError as ex:
# return values
if isinstance(path, dict):
Expand All @@ -193,16 +199,18 @@ def check_data_locator(self) -> "Server":
region_name = discover_s3_region_name(path)
if region_name is None:
raise ValueError(f"Unable to discover s3 region name from {path}")
self.data_locator.s3_region_name = region_name
values["data_locator"].s3_region_name = region_name
else:
self.data_locator.s3_region_name = None
return self

@model_validator(mode="after")
def check_cxg_adaptor(self) -> "Server":
if not self.adaptor.cxg_adaptor.tiledb_ctx.vfs_s3_region and isinstance(self.data_locator.s3_region_name, str):
self.adaptor.cxg_adaptor.tiledb_ctx.vfs_s3_region = self.data_locator.s3_region_name
return self
values["data_locator"].s3_region_name = None
return values

@root_validator(skip_on_failure=True)
def check_cxg_adaptor(cls, values):
if not values["adaptor"].cxg_adaptor.tiledb_ctx.vfs_s3_region and isinstance(
values["data_locator"].s3_region_name, str
):
values["adaptor"].cxg_adaptor.tiledb_ctx.vfs_s3_region = values["data_locator"].s3_region_name
return values


class ScriptsItem(BaseModel):
Expand Down
1 change: 0 additions & 1 deletion server/tests/unit/common/config/test_app_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ def test_configfile_no_server_section(self):
changes = self.compare_configs(app_config, default_config)
self.assertCountEqual(changes, [("default_dataset__app__about_legal_tos", "expected_value", None)])

@unittest.skip("Configuration needs to be updated to satisfy v2 pydantic validation criteria")
def test_csp_directives(self):
default_config = AppConfig()
with tempfile.TemporaryDirectory() as tempdir:
Expand Down
1 change: 0 additions & 1 deletion server/tests/unit/common/config/test_base_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def test_mapping_creation_returns_map_of_server_and_dataset_config(self):
self.assertIsNotNone(mapping["server__app__verbose"])
self.assertIsNotNone(mapping["default_dataset__presentation__max_categories"])

@unittest.skip("Configuration needs to be updated to satisfy v2 pydantic validation criteria")
def test_changes_from_default_returns_list_of_nondefault_config_values(self):
config = self.get_config(verbose="true", lfc_cutoff=0.05)
changes = config.changes_from_default()
Expand Down

0 comments on commit 9b3fc04

Please sign in to comment.