Skip to content

Commit

Permalink
allow for more flexibility when resolving toml files
Browse files Browse the repository at this point in the history
  • Loading branch information
sh-rp committed Feb 3, 2025
1 parent d4282d6 commit 5e70118
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 50 deletions.
89 changes: 44 additions & 45 deletions dlt/common/configuration/providers/toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import tomlkit
import tomlkit.exceptions
import tomlkit.items
from typing import Any, Optional
from typing import Any, Optional, List

from dlt.common.utils import update_dict_nested

Expand Down Expand Up @@ -45,8 +45,7 @@ def __init__(
name: str,
supports_secrets: bool,
file_name: str,
settings_dir: str,
global_dir: str = None,
resolvable_dirs: List[str],
) -> None:
"""Creates config provider from a `toml` file
Expand All @@ -64,31 +63,36 @@ def __init__(
name(str): name of the provider when registering in context
supports_secrets(bool): allows to store secret values in this provider
file_name (str): The name of `toml` file to load
settings_dir (str, optional): The location of `file_name`. If not specified, defaults to $cwd/.dlt
global_dir (bool, optional): Looks for `file_name` in global_dir (defaults to `dlt` home directory which in most cases is $HOME/.dlt)
resolvable_dirs (List[str], optional): A list of directories to resolve the file from, files will be merged into each other in the order the directories are specified. If not specified, defaults to [$cwd/.dlt, $HOME/.dlt]
Raises:
TomlProviderReadException: File could not be read, most probably `toml` parsing error
"""
# set supports_secrets early, we need this flag to read config
self._supports_secrets = supports_secrets
# read toml file from local or from various environments

self._toml_path = os.path.join(settings_dir, file_name)
self._global_dir = os.path.join(global_dir, file_name) if global_dir else None
self._config_toml = self._read_toml_files(
name, file_name, self._toml_path, self._global_dir
self._toml_paths = self._resolve_toml_paths(
file_name, [d for d in resolvable_dirs if d is not None]
)

self._config_toml = self._read_toml_files(name, file_name, self._toml_paths)

super().__init__(
name,
self._config_toml.unwrap,
supports_secrets,
)

def _resolve_toml_paths(self, file_name: str, resolvable_dirs: List[str]) -> List[str]:
return [os.path.join(d, file_name) for d in resolvable_dirs]

def write_toml(self) -> None:
assert not self._global_dir, "Will not write configs when `global_dir` was set"
with open(self._toml_path, "w", encoding="utf-8") as f:
assert (
len(self._toml_paths) == 1
), "Will not write configs when more than one toml path was resolved. Found paths: " + str(
self._toml_paths
)
with open(self._toml_paths[0], "w", encoding="utf-8") as f:
tomlkit.dump(self._config_toml, f)

def set_value(self, key: str, value: Any, pipeline_name: Optional[str], *sections: str) -> None:
Expand Down Expand Up @@ -179,36 +183,37 @@ def _read_toml_file(self, toml_path: str) -> tomlkit.TOMLDocument:
return None

def _read_toml_files(
self, name: str, file_name: str, toml_path: str, global_path: str
self, name: str, file_name: str, toml_paths: List[str]
) -> tomlkit.TOMLDocument:
"""Merge all toml files into one"""

try:
if (project_toml := self._read_toml_file(toml_path)) is not None:
pass
elif (project_toml := self._read_google_colab_secrets(name, file_name)) is not None:
pass
elif (project_toml := self._read_streamlit_secrets(name, file_name)) is not None:
pass
else:
# empty doc
project_toml = tomlkit.document()
if global_path:
global_toml = self._read_toml_file(global_path)
if global_toml is not None:
project_toml = update_dict_nested(global_toml, project_toml)
return project_toml
# merge all toml files into one
result_toml: Optional[tomlkit.TOMLDocument] = None
for path in toml_paths:
if (loaded_toml := self._read_toml_file(path)) is not None:
if result_toml is None:
result_toml = loaded_toml
else:
result_toml = update_dict_nested(loaded_toml, result_toml)

# if nothing was found, try to load from google colab or streamlit
if not result_toml:
if (result_toml := self._read_google_colab_secrets(name, file_name)) is not None:
pass
elif (result_toml := self._read_streamlit_secrets(name, file_name)) is not None:
pass
else:
result_toml = tomlkit.document()

return result_toml
except Exception as ex:
raise TomlProviderReadException(name, file_name, toml_path, str(ex))
raise TomlProviderReadException(name, file_name, toml_paths, str(ex))


class ConfigTomlProvider(SettingsTomlProvider):
def __init__(self, settings_dir: str, global_dir: str = None) -> None:
super().__init__(
CONFIG_TOML,
False,
CONFIG_TOML,
settings_dir=settings_dir,
global_dir=global_dir,
)
super().__init__(CONFIG_TOML, False, CONFIG_TOML, [settings_dir, global_dir])

@property
def is_writable(self) -> bool:
Expand All @@ -217,13 +222,7 @@ def is_writable(self) -> bool:

class SecretsTomlProvider(SettingsTomlProvider):
def __init__(self, settings_dir: str, global_dir: str = None) -> None:
super().__init__(
SECRETS_TOML,
True,
SECRETS_TOML,
settings_dir=settings_dir,
global_dir=global_dir,
)
super().__init__(SECRETS_TOML, True, SECRETS_TOML, [settings_dir, global_dir])

@property
def is_writable(self) -> bool:
Expand All @@ -232,10 +231,10 @@ def is_writable(self) -> bool:

class TomlProviderReadException(ConfigProviderException):
def __init__(
self, provider_name: str, file_name: str, full_path: str, toml_exception: str
self, provider_name: str, file_name: str, full_paths: List[str], toml_exception: str
) -> None:
self.file_name = file_name
self.full_path = full_path
msg = f"A problem encountered when loading {provider_name} from {full_path}:\n"
self.full_paths = full_paths
msg = f"A problem encountered when loading {provider_name} from paths {full_paths}:\n"
msg += toml_exception
super().__init__(provider_name, msg)
11 changes: 6 additions & 5 deletions tests/common/configuration/test_toml_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,16 +261,17 @@ def test_toml_global_config() -> None:
providers = Container()[PluggableRunContext].providers
secrets = providers[SECRETS_TOML]
config = providers[CONFIG_TOML]
# in pytest should be false
assert secrets._global_dir is None # type: ignore[attr-defined]
assert config._global_dir is None # type: ignore[attr-defined]

# in pytest should be false, no global dir appended to resolved paths
assert len(secrets._toml_paths) == 1 # type: ignore[attr-defined]
assert len(config._toml_paths) == 1 # type: ignore[attr-defined]

# set dlt data and settings dir
global_dir = "./tests/common/cases/configuration/dlt_home"
settings_dir = "./tests/common/cases/configuration/.dlt"
# create instance with global toml enabled
config = ConfigTomlProvider(settings_dir=settings_dir, global_dir=global_dir)
assert config._global_dir == os.path.join(global_dir, CONFIG_TOML)
assert config._toml_paths[1] == os.path.join(global_dir, CONFIG_TOML)
assert isinstance(config._config_doc, dict)
assert len(config._config_doc) > 0
# kept from global
Expand All @@ -287,7 +288,7 @@ def test_toml_global_config() -> None:
assert v == "a"

secrets = SecretsTomlProvider(settings_dir=settings_dir, global_dir=global_dir)
assert secrets._global_dir == os.path.join(global_dir, SECRETS_TOML)
assert secrets._toml_paths[1] == os.path.join(global_dir, SECRETS_TOML)
# check if values from project exist
secrets_project = SecretsTomlProvider(settings_dir=settings_dir)
assert secrets._config_doc == secrets_project._config_doc
Expand Down

0 comments on commit 5e70118

Please sign in to comment.