Simplify load connection in LocalFilesystemBackend (apache#10638)
mik-laj authored Sep 6, 2020
1 parent ebb0a97 · commit ddee0aa
Showing 2 changed files with 107 additions and 78 deletions.
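Taken together, the diff below deprecates load_connections in favour of a new load_connections_dict helper that maps each connection ID to a single Connection instead of a one-element list, enforces uniqueness up front, and updates LocalFilesystemBackend to consume the new shape. A minimal sketch of the resulting API, not part of the commit (the file name and connection ID are illustrative, in the CONN_ID=mysql://host_1/ format the tests use):

    # Sketch only: assumes a local secrets file in the .env format exercised by the tests below.
    from airflow.secrets.local_filesystem import load_connections, load_connections_dict

    new_style = load_connections_dict("connections.env")   # {"CONN_ID": Connection}
    old_style = load_connections("connections.env")        # {"CONN_ID": [Connection]}, emits a DeprecationWarning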
38 changes: 26 additions & 12 deletions airflow/secrets/local_filesystem.py
@@ -21,6 +21,7 @@
 import json
 import logging
 import os
+import warnings
 from collections import defaultdict
 from inspect import signature
 from json import JSONDecodeError
@@ -235,33 +236,44 @@ def load_variables(file_path: str) -> Dict[str, str]:
     return variables


-def load_connections(file_path: str):
+def load_connections(file_path) -> Dict[str, List[Any]]:
+    """
+    This function is deprecated. Please use `airflow.secrets.local_filesystem.load_connections_dict`.
+    """
+    warnings.warn(
+        "This function is deprecated. Please use `airflow.secrets.local_filesystem.load_connections_dict`.",
+        DeprecationWarning, stacklevel=2
+    )
+    return {k: [v] for k, v in load_connections_dict(file_path).items()}
+
+
+def load_connections_dict(file_path: str) -> Dict[str, Any]:
     """
     Load connection from text file.
     Both ``JSON`` and ``.env`` files are supported.
     :return: A dictionary where the key contains a connection ID and the value contains a list of connections.
-    :rtype: Dict[str, List[airflow.models.connection.Connection]]
+    :rtype: Dict[str, airflow.models.connection.Connection]
     """
     log.debug("Loading connection")

     secrets: Dict[str, Any] = _parse_secret_file(file_path)
-    connections_by_conn_id = defaultdict(list)
+    connection_by_conn_id = {}
     for key, secret_values in list(secrets.items()):
         if isinstance(secret_values, list):
+            if len(secret_values) > 1:
+                raise ConnectionNotUnique(f"Found multiple values for {key} in {file_path}.")
+
             for secret_value in secret_values:
-                connections_by_conn_id[key].append(_create_connection(key, secret_value))
+                connection_by_conn_id[key] = _create_connection(key, secret_value)
         else:
-            connections_by_conn_id[key].append(_create_connection(key, secret_values))
-
-        if len(connections_by_conn_id[key]) > 1:
-            raise ConnectionNotUnique(f"Found multiple values for {key} in {file_path}")
+            connection_by_conn_id[key] = _create_connection(key, secret_values)

-    num_conn = sum(map(len, connections_by_conn_id.values()))
+    num_conn = len(connection_by_conn_id)
     log.debug("Loaded %d connections", num_conn)

-    return connections_by_conn_id
+    return connection_by_conn_id


 class LocalFilesystemBackend(BaseSecretsBackend, LoggingMixin):
@@ -298,10 +310,12 @@ def _local_connections(self) -> Dict[str, List[Any]]:
self.log.debug("The file for connection is not specified. Skipping")
# The user may not specify any file.
return {}
return load_connections(self.connections_file)
return load_connections_dict(self.connections_file)

def get_connections(self, conn_id: str) -> List[Any]:
return self._local_connections.get(conn_id) or []
if conn_id in self._local_connections:
return [self._local_connections[conn_id]]
return []

def get_variable(self, key: str) -> Optional[str]:
return self._local_variables.get(key)
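A behavioral detail worth noting from the hunk above: uniqueness is now checked before any Connection object is built, so a secrets file that lists more than one value for the same connection ID fails fast with ConnectionNotUnique. A hedged sketch, not part of the commit; the temporary file and its contents are made up, and the import path for ConnectionNotUnique is assumed to be airflow.exceptions:

    # Sketch only: write a throwaway JSON secrets file with a duplicated
    # connection ID to trigger the new early ConnectionNotUnique check.
    import json
    import tempfile

    from airflow.exceptions import ConnectionNotUnique  # assumed import path
    from airflow.secrets.local_filesystem import load_connections_dict

    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump({"CONN_ID": ["mysql://host_1/", "mysql://host_2/"]}, f)

    try:
        load_connections_dict(f.name)
    except ConnectionNotUnique as err:
        print(err)  # Found multiple values for CONN_ID in <path>.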
147 changes: 81 additions & 66 deletions tests/secrets/test_local_filesystem.py
@@ -122,27 +122,27 @@ def test_yaml_file_should_load_variables(self, file_content, expected_variables)
 class TestLoadConnection(unittest.TestCase):
     @parameterized.expand(
         (
-            ("CONN_ID=mysql://host_1/", {"CONN_ID": ["mysql://host_1"]}),
+            ("CONN_ID=mysql://host_1/", {"CONN_ID": "mysql://host_1"}),
             (
                 "CONN_ID1=mysql://host_1/\nCONN_ID2=mysql://host_2/",
-                {"CONN_ID1": ["mysql://host_1"], "CONN_ID2": ["mysql://host_2"]},
+                {"CONN_ID1": "mysql://host_1", "CONN_ID2": "mysql://host_2"},
             ),
             (
                 "CONN_ID1=mysql://host_1/\n # AAAA\nCONN_ID2=mysql://host_2/",
-                {"CONN_ID1": ["mysql://host_1"], "CONN_ID2": ["mysql://host_2"]},
+                {"CONN_ID1": "mysql://host_1", "CONN_ID2": "mysql://host_2"},
             ),
             (
                 "\n\n\n\nCONN_ID1=mysql://host_1/\n\n\n\n\nCONN_ID2=mysql://host_2/\n\n\n",
-                {"CONN_ID1": ["mysql://host_1"], "CONN_ID2": ["mysql://host_2"]},
+                {"CONN_ID1": "mysql://host_1", "CONN_ID2": "mysql://host_2"},
             ),
         )
     )
     def test_env_file_should_load_connection(self, file_content, expected_connection_uris):
         with mock_local_file(file_content):
-            connections_by_conn_id = local_filesystem.load_connections("a.env")
+            connection_by_conn_id = local_filesystem.load_connections_dict("a.env")
             connection_uris_by_conn_id = {
-                conn_id: [connection.get_uri() for connection in connections]
-                for conn_id, connections in connections_by_conn_id.items()
+                conn_id: connection.get_uri()
+                for conn_id, connection in connection_by_conn_id.items()
             }

             self.assertEqual(expected_connection_uris, connection_uris_by_conn_id)
@@ -156,22 +156,22 @@ def test_env_file_should_load_connection(self, file_content, expected_connection
     def test_env_file_invalid_format(self, content, expected_message):
         with mock_local_file(content):
             with self.assertRaisesRegex(AirflowFileParseException, re.escape(expected_message)):
-                local_filesystem.load_connections("a.env")
+                local_filesystem.load_connections_dict("a.env")

     @parameterized.expand(
         (
-            ({"CONN_ID": "mysql://host_1"}, {"CONN_ID": ["mysql://host_1"]}),
-            ({"CONN_ID": ["mysql://host_1"]}, {"CONN_ID": ["mysql://host_1"]}),
-            ({"CONN_ID": {"uri": "mysql://host_1"}}, {"CONN_ID": ["mysql://host_1"]}),
-            ({"CONN_ID": [{"uri": "mysql://host_1"}]}, {"CONN_ID": ["mysql://host_1"]}),
+            ({"CONN_ID": "mysql://host_1"}, {"CONN_ID": "mysql://host_1"}),
+            ({"CONN_ID": ["mysql://host_1"]}, {"CONN_ID": "mysql://host_1"}),
+            ({"CONN_ID": {"uri": "mysql://host_1"}}, {"CONN_ID": "mysql://host_1"}),
+            ({"CONN_ID": [{"uri": "mysql://host_1"}]}, {"CONN_ID": "mysql://host_1"}),
         )
     )
     def test_json_file_should_load_connection(self, file_content, expected_connection_uris):
         with mock_local_file(json.dumps(file_content)):
-            connections_by_conn_id = local_filesystem.load_connections("a.json")
+            connections_by_conn_id = local_filesystem.load_connections_dict("a.json")
             connection_uris_by_conn_id = {
-                conn_id: [connection.get_uri() for connection in connections]
-                for conn_id, connections in connections_by_conn_id.items()
+                conn_id: connection.get_uri()
+                for conn_id, connection in connections_by_conn_id.items()
             }

             self.assertEqual(expected_connection_uris, connection_uris_by_conn_id)
@@ -181,27 +181,28 @@ def test_json_file_should_load_connection(self, file_content, expected_connectio
({"CONN_ID": None}, "Unexpected value type: <class 'NoneType'>."),
({"CONN_ID": 1}, "Unexpected value type: <class 'int'>."),
({"CONN_ID": [2]}, "Unexpected value type: <class 'int'>."),
({"CONN_ID": ["mysql://host_1", None]}, "Unexpected value type: <class 'NoneType'>."),
({"CONN_ID": [None]}, "Unexpected value type: <class 'NoneType'>."),
({"CONN_ID": {"AAA": "mysql://host_1"}}, "The object have illegal keys: AAA."),
({"CONN_ID": {"conn_id": "BBBB"}}, "Mismatch conn_id."),
({"CONN_ID": ["mysql://", "mysql://"]}, "Found multiple values for CONN_ID in a.json."),
)
)
def test_env_file_invalid_input(self, file_content, expected_connection_uris):
with mock_local_file(json.dumps(file_content)):
with self.assertRaisesRegex(AirflowException, re.escape(expected_connection_uris)):
local_filesystem.load_connections("a.json")
local_filesystem.load_connections_dict("a.json")

@mock.patch("airflow.secrets.local_filesystem.os.path.exists", return_value=False)
def test_missing_file(self, mock_exists):
with self.assertRaisesRegex(
AirflowException,
re.escape("File a.json was not found. Check the configuration of your Secrets backend."),
):
local_filesystem.load_connections("a.json")
local_filesystem.load_connections_dict("a.json")

@parameterized.expand(
(
("""CONN_A: 'mysql://host_a'""", {"CONN_A": ["mysql://host_a"]}),
("""CONN_A: 'mysql://host_a'""", {"CONN_A": "mysql://host_a"}),
("""
conn_a: mysql://hosta
conn_b:
@@ -215,66 +216,80 @@ def test_missing_file(self, mock_exists):
                  extra__google_cloud_platform__keyfile_dict:
                    a: b
                  extra__google_cloud_platform__keyfile_path: asaa""",
-             {"conn_a": ["mysql://hosta"],
-              "conn_b": [''.join("""scheme://Login:None@host:1234/lschema?
+             {"conn_a": "mysql://hosta",
+              "conn_b": ''.join("""scheme://Login:None@host:1234/lschema?
                  extra__google_cloud_platform__keyfile_dict=%7B%27a%27%3A+%27b%27%7D
-                 &extra__google_cloud_platform__keyfile_path=asaa""".split())]}),
+                 &extra__google_cloud_platform__keyfile_path=asaa""".split())}),
         )
     )
     def test_yaml_file_should_load_connection(self, file_content, expected_connection_uris):
         with mock_local_file(file_content):
-            connections_by_conn_id = local_filesystem.load_connections("a.yaml")
+            connections_by_conn_id = local_filesystem.load_connections_dict("a.yaml")
             connection_uris_by_conn_id = {
-                conn_id: [connection.get_uri() for connection in connections]
-                for conn_id, connections in connections_by_conn_id.items()
+                conn_id: connection.get_uri()
+                for conn_id, connection in connections_by_conn_id.items()
             }

             self.assertEqual(expected_connection_uris, connection_uris_by_conn_id)

     @parameterized.expand(
         (
-            ("""conn_c:
-               conn_type: scheme
-               host: host
-               schema: lschema
-               login: Login
-               password: None
-               port: 1234
-               extra_dejson:
-                 aws_conn_id: bbb
-                 region_name: ccc
-             """, {"conn_c": [{"aws_conn_id": "bbb", "region_name": "ccc"}]}),
-            ("""conn_d:
-               conn_type: scheme
-               host: host
-               schema: lschema
-               login: Login
-               password: None
-               port: 1234
-               extra_dejson:
-                 extra__google_cloud_platform__keyfile_dict:
-                   a: b
-                 extra__google_cloud_platform__key_path: xxx
-             """, {"conn_d": [{"extra__google_cloud_platform__keyfile_dict": {"a": "b"},
-                               "extra__google_cloud_platform__key_path": "xxx"}]}),
-            ("""conn_d:
-               conn_type: scheme
-               host: host
-               schema: lschema
-               login: Login
-               password: None
-               port: 1234
-               extra: '{\"extra__google_cloud_platform__keyfile_dict\": {\"a\": \"b\"}}'""", {"conn_d": [
-                {"extra__google_cloud_platform__keyfile_dict": {"a": "b"}}]})
+            (
+                """
+                conn_c:
+                   conn_type: scheme
+                   host: host
+                   schema: lschema
+                   login: Login
+                   password: None
+                   port: 1234
+                   extra_dejson:
+                     aws_conn_id: bbb
+                     region_name: ccc
+                """,
+                {"conn_c": {"aws_conn_id": "bbb", "region_name": "ccc"}},
+            ),
+            (
+                """
+                conn_d:
+                   conn_type: scheme
+                   host: host
+                   schema: lschema
+                   login: Login
+                   password: None
+                   port: 1234
+                   extra_dejson:
+                     extra__google_cloud_platform__keyfile_dict:
+                       a: b
+                     extra__google_cloud_platform__key_path: xxx
+                """,
+                {
+                    "conn_d": {
+                        "extra__google_cloud_platform__keyfile_dict": {"a": "b"},
+                        "extra__google_cloud_platform__key_path": "xxx",
+                    }
+                },
+            ),
+            (
+                """
+                conn_d:
+                   conn_type: scheme
+                   host: host
+                   schema: lschema
+                   login: Login
+                   password: None
+                   port: 1234
+                   extra: '{\"extra__google_cloud_platform__keyfile_dict\": {\"a\": \"b\"}}'
+                """,
+                {"conn_d": {"extra__google_cloud_platform__keyfile_dict": {"a": "b"}}},
+            ),
         )
     )
     def test_yaml_file_should_load_connection_extras(self, file_content, expected_extras):
         with mock_local_file(file_content):
-            connections_by_conn_id = local_filesystem.load_connections("a.yaml")
+            connections_by_conn_id = local_filesystem.load_connections_dict("a.yaml")
             connection_uris_by_conn_id = {
-                conn_id: [connection.extra_dejson for connection in connections]
-                for conn_id, connections in connections_by_conn_id.items()
+                conn_id: connection.extra_dejson for conn_id, connection in connections_by_conn_id.items()
             }
             self.assertEqual(expected_extras, connection_uris_by_conn_id)

@@ -298,7 +313,7 @@ def test_yaml_file_should_load_connection_extras(self, file_content, expected_ex
     def test_yaml_invalid_extra(self, file_content, expected_message):
         with mock_local_file(file_content):
             with self.assertRaisesRegex(AirflowException, re.escape(expected_message)):
-                local_filesystem.load_connections("a.yaml")
+                local_filesystem.load_connections_dict("a.yaml")

     @parameterized.expand(
         (
@@ -308,7 +323,7 @@ def test_yaml_invalid_extra(self, file_content, expected_message):
     def test_ensure_unique_connection_env(self, file_content):
         with mock_local_file(file_content):
             with self.assertRaises(ConnectionNotUnique):
-                local_filesystem.load_connections("a.env")
+                local_filesystem.load_connections_dict("a.env")

     @parameterized.expand(
         (
@@ -323,7 +338,7 @@ def test_ensure_unique_connection_env(self, file_content):
     def test_ensure_unique_connection_json(self, file_content):
         with mock_local_file(json.dumps(file_content)):
             with self.assertRaises(ConnectionNotUnique):
-                local_filesystem.load_connections("a.json")
+                local_filesystem.load_connections_dict("a.json")

     @parameterized.expand(
         (
@@ -336,7 +351,7 @@ def test_ensure_unique_connection_json(self, file_content):
     def test_ensure_unique_connection_yaml(self, file_content):
         with mock_local_file(file_content):
             with self.assertRaises(ConnectionNotUnique):
-                local_filesystem.load_connections("a.yaml")
+                local_filesystem.load_connections_dict("a.yaml")


 class TestLocalFileBackend(unittest.TestCase):
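At the backend level, get_connections keeps returning a list for its callers, now built by wrapping the single cached Connection (or an empty list for an unknown ID). A hedged usage sketch, not part of the commit; the constructor keyword connections_file_path is an assumption based on the connections_file attribute used above, and the path is made up:

    # Sketch only: exercise LocalFilesystemBackend directly.
    from airflow.secrets.local_filesystem import LocalFilesystemBackend

    backend = LocalFilesystemBackend(connections_file_path="/files/connections.json")  # kwarg name assumed
    conns = backend.get_connections("CONN_ID")  # [] or a one-element [Connection]
    if conns:
        print(conns[0].get_uri())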
