From aa9a49f5319e4a3a934a2efa9b85f7b52a8bc9c2 Mon Sep 17 00:00:00 2001 From: dave Date: Mon, 3 Feb 2025 17:52:19 +0100 Subject: [PATCH] allow for more flexibility when resolving toml files --- dlt/common/configuration/providers/toml.py | 88 +++++++++---------- .../configuration/test_toml_provider.py | 11 +-- 2 files changed, 49 insertions(+), 50 deletions(-) diff --git a/dlt/common/configuration/providers/toml.py b/dlt/common/configuration/providers/toml.py index e586fef225..5ae10f79e8 100644 --- a/dlt/common/configuration/providers/toml.py +++ b/dlt/common/configuration/providers/toml.py @@ -2,7 +2,7 @@ import tomlkit import tomlkit.exceptions import tomlkit.items -from typing import Any, Optional +from typing import Any, Optional, List from dlt.common.utils import update_dict_nested @@ -45,8 +45,7 @@ def __init__( name: str, supports_secrets: bool, file_name: str, - settings_dir: str, - global_dir: str = None, + resolvable_dirs: List[str], ) -> None: """Creates config provider from a `toml` file @@ -64,8 +63,7 @@ def __init__( name(str): name of the provider when registering in context supports_secrets(bool): allows to store secret values in this provider file_name (str): The name of `toml` file to load - settings_dir (str, optional): The location of `file_name`. If not specified, defaults to $cwd/.dlt - global_dir (bool, optional): Looks for `file_name` in global_dir (defaults to `dlt` home directory which in most cases is $HOME/.dlt) + resolvable_dirs (List[str], optional): A list of directories to resolve the file from, files will be merged into each other in the order the directories are specified. If not specified, defaults to [$cwd/.dlt, $HOME/.dlt] Raises: TomlProviderReadException: File could not be read, most probably `toml` parsing error @@ -73,22 +71,27 @@ def __init__( # set supports_secrets early, we need this flag to read config self._supports_secrets = supports_secrets # read toml file from local or from various environments - - self._toml_path = os.path.join(settings_dir, file_name) - self._global_dir = os.path.join(global_dir, file_name) if global_dir else None - self._config_toml = self._read_toml_files( - name, file_name, self._toml_path, self._global_dir + self._toml_paths = self._resolve_toml_paths( + file_name, [d for d in resolvable_dirs if d is not None] ) + self._config_toml = self._read_toml_files(name, file_name, self._toml_paths) + super().__init__( name, self._config_toml.unwrap, supports_secrets, ) + def _resolve_toml_paths(self, file_name: str, resolvable_dirs: List[str]) -> List[str]: + return [os.path.join(dir, file_name) for dir in resolvable_dirs] + def write_toml(self) -> None: - assert not self._global_dir, "Will not write configs when `global_dir` was set" - with open(self._toml_path, "w", encoding="utf-8") as f: + assert len(self._toml_paths) == 1, ( + f"Will not write configs when more than one toml path was resolved. Found paths: " + + str(self._toml_paths) + ) + with open(self._toml_paths[0], "w", encoding="utf-8") as f: tomlkit.dump(self._config_toml, f) def set_value(self, key: str, value: Any, pipeline_name: Optional[str], *sections: str) -> None: @@ -179,36 +182,37 @@ def _read_toml_file(self, toml_path: str) -> tomlkit.TOMLDocument: return None def _read_toml_files( - self, name: str, file_name: str, toml_path: str, global_path: str + self, name: str, file_name: str, toml_paths: List[str] ) -> tomlkit.TOMLDocument: + """Merge all toml files into one""" + try: - if (project_toml := self._read_toml_file(toml_path)) is not None: - pass - elif (project_toml := self._read_google_colab_secrets(name, file_name)) is not None: - pass - elif (project_toml := self._read_streamlit_secrets(name, file_name)) is not None: - pass - else: - # empty doc - project_toml = tomlkit.document() - if global_path: - global_toml = self._read_toml_file(global_path) - if global_toml is not None: - project_toml = update_dict_nested(global_toml, project_toml) - return project_toml + # merge all toml files into one + result_toml: Optional[tomlkit.TOMLDocument] = None + for path in toml_paths: + if (loaded_toml := self._read_toml_file(path)) is not None: + if result_toml is None: + result_toml = loaded_toml + else: + result_toml = update_dict_nested(loaded_toml, result_toml) + + # if nothing was found, try to load from google colab or streamlit + if not result_toml: + if (result_toml := self._read_google_colab_secrets(name, file_name)) is not None: + pass + elif (result_toml := self._read_streamlit_secrets(name, file_name)) is not None: + pass + else: + result_toml = tomlkit.document() + + return result_toml except Exception as ex: - raise TomlProviderReadException(name, file_name, toml_path, str(ex)) + raise TomlProviderReadException(name, file_name, toml_paths, str(ex)) class ConfigTomlProvider(SettingsTomlProvider): def __init__(self, settings_dir: str, global_dir: str = None) -> None: - super().__init__( - CONFIG_TOML, - False, - CONFIG_TOML, - settings_dir=settings_dir, - global_dir=global_dir, - ) + super().__init__(CONFIG_TOML, False, CONFIG_TOML, [settings_dir, global_dir]) @property def is_writable(self) -> bool: @@ -217,13 +221,7 @@ def is_writable(self) -> bool: class SecretsTomlProvider(SettingsTomlProvider): def __init__(self, settings_dir: str, global_dir: str = None) -> None: - super().__init__( - SECRETS_TOML, - True, - SECRETS_TOML, - settings_dir=settings_dir, - global_dir=global_dir, - ) + super().__init__(SECRETS_TOML, True, SECRETS_TOML, [settings_dir, global_dir]) @property def is_writable(self) -> bool: @@ -232,10 +230,10 @@ def is_writable(self) -> bool: class TomlProviderReadException(ConfigProviderException): def __init__( - self, provider_name: str, file_name: str, full_path: str, toml_exception: str + self, provider_name: str, file_name: str, full_paths: List[str], toml_exception: str ) -> None: self.file_name = file_name - self.full_path = full_path - msg = f"A problem encountered when loading {provider_name} from {full_path}:\n" + self.full_paths = full_paths + msg = f"A problem encountered when loading {provider_name} from paths {full_paths}:\n" msg += toml_exception super().__init__(provider_name, msg) diff --git a/tests/common/configuration/test_toml_provider.py b/tests/common/configuration/test_toml_provider.py index 9538849976..c261a37aa7 100644 --- a/tests/common/configuration/test_toml_provider.py +++ b/tests/common/configuration/test_toml_provider.py @@ -261,16 +261,17 @@ def test_toml_global_config() -> None: providers = Container()[PluggableRunContext].providers secrets = providers[SECRETS_TOML] config = providers[CONFIG_TOML] - # in pytest should be false - assert secrets._global_dir is None # type: ignore[attr-defined] - assert config._global_dir is None # type: ignore[attr-defined] + + # in pytest should be false, no global dir appended to resolved paths + assert len(secrets._toml_paths) == 1 # type: ignore[attr-defined] + assert len(config._toml_paths) == 1 # type: ignore[attr-defined] # set dlt data and settings dir global_dir = "./tests/common/cases/configuration/dlt_home" settings_dir = "./tests/common/cases/configuration/.dlt" # create instance with global toml enabled config = ConfigTomlProvider(settings_dir=settings_dir, global_dir=global_dir) - assert config._global_dir == os.path.join(global_dir, CONFIG_TOML) + assert config._toml_paths[1] == os.path.join(global_dir, CONFIG_TOML) assert isinstance(config._config_doc, dict) assert len(config._config_doc) > 0 # kept from global @@ -287,7 +288,7 @@ def test_toml_global_config() -> None: assert v == "a" secrets = SecretsTomlProvider(settings_dir=settings_dir, global_dir=global_dir) - assert secrets._global_dir == os.path.join(global_dir, SECRETS_TOML) + assert secrets._toml_paths[1] == os.path.join(global_dir, SECRETS_TOML) # check if values from project exist secrets_project = SecretsTomlProvider(settings_dir=settings_dir) assert secrets._config_doc == secrets_project._config_doc