Skip to content

Commit

Permalink
Improve behavior around adding nested TaskDependencies (#23)
Browse files Browse the repository at this point in the history
* cicd: rm black, use only ruff

* feat: modify task_group.tasks to dict[str, task], like OrbiterDAG

* feat: add task dependencies for nested tasks

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: add `translate` test, typo fixes

* docs: add tests and docs

* fix: have includes create parent folder structure if needed

* fix: allow python callable to be a string callable reference or a callable (and it gets inlined)

* fix: render `_callable` args as py variable, handle recursion into task_groups

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
fritz-astronomer and pre-commit-ci[bot] authored Nov 8, 2024
1 parent b3ee237 commit 964661e
Show file tree
Hide file tree
Showing 27 changed files with 485 additions and 534 deletions.
5 changes: 1 addition & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ repos:
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]

- repo: https://github.com/psf/black
rev: 24.10.0
hooks: [ { id: black, args: [ --config=pyproject.toml ] } ]
- id: ruff-format

- repo: https://github.com/PyCQA/bandit/
rev: 1.7.10
Expand Down
16 changes: 5 additions & 11 deletions orbiter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from typing import Any, Tuple

__version__ = "1.2.3"
__version__ = "1.3.0"

version = __version__

Expand Down Expand Up @@ -44,16 +44,14 @@ def import_from_qualname(qualname) -> Tuple[str, Any]:
"""Import a function or module from a qualified name
:param qualname: The qualified name of the function or module to import (e.g. a.b.d.MyOperator or json)
:return Tuple[str, Any]: The name of the function or module, and the function or module itself
>>> import_from_qualname('json.loads')
>>> import_from_qualname("json.loads")
('loads', <function loads at ...>)
>>> import_from_qualname('json')
>>> import_from_qualname("json")
('json', <module 'json' from '...'>)
"""
from importlib import import_module

[module, name] = (
qualname.rsplit(".", 1) if "." in qualname else [qualname, qualname]
)
[module, name] = qualname.rsplit(".", 1) if "." in qualname else [qualname, qualname]
imported_module = import_module(module)
return (
name,
Expand All @@ -64,8 +62,4 @@ def import_from_qualname(qualname) -> Tuple[str, Any]:
if __name__ == "__main__":
import doctest

doctest.testmod(
optionflags=doctest.ELLIPSIS
| doctest.NORMALIZE_WHITESPACE
| doctest.IGNORE_EXCEPTION_DETAIL
)
doctest.testmod(optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.IGNORE_EXCEPTION_DETAIL)
68 changes: 19 additions & 49 deletions orbiter/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,7 @@ def formatter(r):
return (
"<lvl>"
+ ( # add [time] WARN, etc. if it's not INFO
"[{time:HH:mm:ss}|{level}] "
if r["level"].no != logging.INFO
else "[{time:HH:mm:ss}] "
"[{time:HH:mm:ss}|{level}] " if r["level"].no != logging.INFO else "[{time:HH:mm:ss}] "
)
+ "{message}</>\n{exception}" # add exception, if there is one
)
Expand Down Expand Up @@ -86,9 +84,7 @@ def import_ruleset(ruleset: str) -> TranslationRuleset:
logger.debug(f"Importing ruleset: {ruleset}")
(_, translation_ruleset) = import_from_qualname(ruleset)
if not isinstance(translation_ruleset, TranslationRuleset):
raise RuntimeError(
f"translation_ruleset={translation_ruleset} is not a TranslationRuleset"
)
raise RuntimeError(f"translation_ruleset={translation_ruleset} is not a TranslationRuleset")
return translation_ruleset


Expand All @@ -115,20 +111,14 @@ def run_ruff_formatter(output_dir: Path):
changed_files = " ".join(
(
file
for file in git.Repo(output_dir)
.git.diff(output_dir, name_only=True)
.split("\n")
for file in git.Repo(output_dir).git.diff(output_dir, name_only=True).split("\n")
if file.endswith(".py")
)
)
except ImportError:
logger.debug(
"Unable to acquire list of changed files in output directory, reformatting output directory..."
)
logger.debug("Unable to acquire list of changed files in output directory, reformatting output directory...")
except Exception:
logger.debug(
"Unable to acquire list of changed files in output directory, reformatting output directory..."
)
logger.debug("Unable to acquire list of changed files in output directory, reformatting output directory...")

output = run(
f"ruff check --select E,F,UP,B,SIM,I --ignore E501 --fix {changed_files}",
Expand Down Expand Up @@ -205,9 +195,9 @@ def translate(

translation_ruleset = import_ruleset(ruleset)
try:
translation_ruleset.translate_fn(
translation_ruleset=translation_ruleset, input_dir=input_dir
).render(output_dir)
translation_ruleset.translate_fn(translation_ruleset=translation_ruleset, input_dir=input_dir).render(
output_dir
)
except RuntimeError as e:
logger.error(f"Error encountered during translation: {e}")
raise click.Abort()
Expand Down Expand Up @@ -252,9 +242,9 @@ def analyze(
output_file = output_file.open("w", newline="")
translation_ruleset = import_ruleset(ruleset)
try:
translation_ruleset.translate_fn(
translation_ruleset=translation_ruleset, input_dir=input_dir
).analyze(_format, output_file)
translation_ruleset.translate_fn(translation_ruleset=translation_ruleset, input_dir=input_dir).analyze(
_format, output_file
)
except RuntimeError as e:
logger.exception(f"Error encountered during translation: {e}")
raise click.Abort()
Expand All @@ -266,9 +256,7 @@ def _pip_install(repo: str, key: str):
_exec += f"=={TRANSLATION_VERSION}" if TRANSLATION_VERSION != "latest" else ""
if repo == "astronomer-orbiter-translations":
if not key:
raise ValueError(
"License key is required for 'astronomer-orbiter-translations'!"
)
raise ValueError("License key is required for 'astronomer-orbiter-translations'!")
extra = f' --index-url "https://license:{key}@api.keygen.sh/v1/accounts/{KG_ACCOUNT_ID}/engines/pypi/simple"'
_exec = f"{_exec}{extra}"
logger.debug(_exec.replace(key or "<nothing>", "****"))
Expand All @@ -287,8 +275,7 @@ def _get_keygen_pyz(key):
latest_orbiter_translations_pyz_id = next(
artifact["id"]
for artifact in r.json().get("data", [])
if artifact.get("attributes", {}).get("filename")
== "orbiter_translations.pyz"
if artifact.get("attributes", {}).get("filename") == "orbiter_translations.pyz"
)
except StopIteration:
raise ValueError("No Artifact found with filename='orbiter_translations.pyz'")
Expand Down Expand Up @@ -321,9 +308,7 @@ def _add_pyz():
logger.debug(f"Adding current directory {os.getcwd()} to sys.path")
sys.path.insert(0, os.getcwd())

local_pyz = [
str(_path.resolve()) for _path in Path(".").iterdir() if _path.suffix == ".pyz"
]
local_pyz = [str(_path.resolve()) for _path in Path(".").iterdir() if _path.suffix == ".pyz"]
logger.debug(f"Adding local .pyz files {local_pyz} to sys.path")
sys.path += local_pyz

Expand All @@ -332,9 +317,7 @@ def _bin_install(repo: str, key: str):
"""If we are running via a PyInstaller binary, we need to download a .pyz"""
if "astronomer-orbiter-translations" in repo:
if not key:
raise ValueError(
"License key is required for 'astronomer-orbiter-translations'!"
)
raise ValueError("License key is required for 'astronomer-orbiter-translations'!")
_get_keygen_pyz(key)
else:
_get_gh_pyz()
Expand All @@ -348,9 +331,7 @@ def _bin_install(repo: str, key: str):
@click.option(
"-r",
"--repo",
type=click.Choice(
["astronomer-orbiter-translations", "orbiter-community-translations"]
),
type=click.Choice(["astronomer-orbiter-translations", "orbiter-community-translations"]),
required=False,
allow_from_autoenv=True,
show_envvar=True,
Expand All @@ -367,10 +348,7 @@ def _bin_install(repo: str, key: str):
show_envvar=True,
)
def install(
repo: (
Literal["astronomer-orbiter-translations", "orbiter-community-translations"]
| None
),
repo: (Literal["astronomer-orbiter-translations", "orbiter-community-translations"] | None),
key: str | None,
):
"""Install a new Translation Ruleset from a repository"""
Expand All @@ -392,9 +370,7 @@ def install(
) == "Other":
repo = None
while not repo:
repo = Prompt.ask(
"Package Name or Repository URL (e.g. git+https://github.com/my/repo.git )"
)
repo = Prompt.ask("Package Name or Repository URL (e.g. git+https://github.com/my/repo.git )")

if RUNNING_AS_BINARY:
_bin_install(repo, key)
Expand All @@ -408,13 +384,7 @@ def list_rulesets():
console = Console()

table = tabulate(
list(
DictReader(
pkgutil.get_data("orbiter.assets", "supported_origins.csv")
.decode()
.splitlines()
)
),
list(DictReader(pkgutil.get_data("orbiter.assets", "supported_origins.csv").decode().splitlines())),
headers="keys",
tablefmt="pipe",
# https://github.com/Textualize/rich/issues/3027
Expand Down
60 changes: 20 additions & 40 deletions orbiter/ast_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@
from typing import List, Callable


def py_bitshift(
left: str | List[str], right: str | List[str], is_downstream: bool = True
):
def py_bitshift(left: str | List[str], right: str | List[str], is_downstream: bool = True):
"""
>>> render_ast(py_bitshift("foo", "bar", is_downstream=False))
'foo << bar'
Expand All @@ -18,26 +16,14 @@ def py_bitshift(
>>> render_ast(py_bitshift(["foo", "bar"], "baz"))
'[foo, bar] >> baz'
"""
left = (
ast.Name(id=left)
if isinstance(left, str)
else ast.List(elts=[ast.Name(id=elt) for elt in left])
)
right = (
ast.Name(id=right)
if isinstance(right, str)
else ast.List(elts=[ast.Name(id=elt) for elt in right])
)
return ast.Expr(
value=ast.BinOp(
left=left, op=ast.RShift() if is_downstream else ast.LShift(), right=right
)
)
left = ast.Name(id=left) if isinstance(left, str) else ast.List(elts=[ast.Name(id=elt) for elt in left])
right = ast.Name(id=right) if isinstance(right, str) else ast.List(elts=[ast.Name(id=elt) for elt in right])
return ast.Expr(value=ast.BinOp(left=left, op=ast.RShift() if is_downstream else ast.LShift(), right=right))


def py_assigned_object(ast_name: str, obj: str, **kwargs) -> ast.Assign:
"""
>>> render_ast(py_assigned_object("foo","Bar",baz="bop"))
>>> render_ast(py_assigned_object("foo", "Bar", baz="bop"))
"foo = Bar(baz='bop')"
"""
return ast.Assign(
Expand All @@ -49,11 +35,7 @@ def py_assigned_object(ast_name: str, obj: str, **kwargs) -> ast.Assign:
keywords=[
ast.keyword(
arg=arg,
value=(
ast.Constant(value=value)
if not isinstance(value, ast.AST)
else value
),
value=(ast.Constant(value=value) if not isinstance(value, ast.AST) else value),
)
for arg, value in kwargs.items()
],
Expand All @@ -70,10 +52,7 @@ def py_object(name: str, **kwargs) -> ast.Expr:
value=ast.Call(
func=ast.Name(id=name),
args=[],
keywords=[
ast.keyword(arg=arg, value=ast.Constant(value=value))
for arg, value in kwargs.items()
],
keywords=[ast.keyword(arg=arg, value=ast.Constant(value=value)) for arg, value in kwargs.items()],
)
)

Expand All @@ -86,9 +65,7 @@ def py_root(*args) -> ast.Module:
return ast.Module(body=args, type_ignores=[])


def py_import(
names: List[str], module: str = None
) -> ast.ImportFrom | ast.Import | list:
def py_import(names: List[str], module: str = None) -> ast.ImportFrom | ast.Import | list:
"""
:param module: e.g. `airflow.operators.bash` for `from airflow.operators.bash import BashOperator`
:param names: e.g. `BashOperator` for `from airflow.operators.bash import BashOperator`
Expand All @@ -100,28 +77,31 @@ def py_import(
'import json'
"""
if module is not None:
return ast.ImportFrom(
module=module, names=[ast.alias(name=name) for name in names], level=0
)
return ast.ImportFrom(module=module, names=[ast.alias(name=name) for name in names], level=0)
elif module is None and names:
return ast.Import(names=[ast.alias(name=name) for name in names], level=0)
else:
return []


def py_with(
item: ast.expr, body: List[ast.stmt], assignment: str | None = None
) -> ast.With:
def py_with(item: ast.expr, body: List[ast.stmt], assignment: str | None = None) -> ast.With:
# noinspection PyTypeChecker
"""
>>> render_ast(py_with(py_object("Bar"), [ast.Pass()]))
'with Bar():\\n pass'
>>> render_ast(
... py_with(py_object("DAG", dag_id="foo").value, [py_object("Operator", task_id="foo")])
... py_with(
... py_object("DAG", dag_id="foo").value,
... [py_object("Operator", task_id="foo")],
... )
... )
"with DAG(dag_id='foo'):\\n Operator(task_id='foo')"
>>> render_ast(
... py_with(py_object("DAG", dag_id="foo").value, [py_object("Operator", task_id="foo")], "dag")
... py_with(
... py_object("DAG", dag_id="foo").value,
... [py_object("Operator", task_id="foo")],
... "dag",
... )
... )
"with DAG(dag_id='foo') as dag:\\n Operator(task_id='foo')"
"""
Expand All @@ -142,7 +122,7 @@ def py_with(
def py_function(c: Callable):
"""
>>> def foo(a, b):
... print(a + b)
... print(a + b)
>>> render_ast(py_function(foo))
'def foo(a, b):\\n print(a + b)'
"""
Expand Down
6 changes: 1 addition & 5 deletions orbiter/file_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,4 @@ class FileTypeYAML(FileType):
if __name__ == "__main__":
import doctest

doctest.testmod(
optionflags=doctest.ELLIPSIS
| doctest.NORMALIZE_WHITESPACE
| doctest.IGNORE_EXCEPTION_DETAIL
)
doctest.testmod(optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.IGNORE_EXCEPTION_DETAIL)
14 changes: 3 additions & 11 deletions orbiter/objects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@ def conn_id(conn_id: str, prefix: str = "", conn_type: str = "generic") -> dict:
Usage:
```python
OrbiterBashOperator(
**conn_id("my_conn_id")
)
OrbiterBashOperator(**conn_id("my_conn_id"))
```
:param conn_id: The connection id
:type conn_id: str
Expand All @@ -80,11 +78,7 @@ def conn_id(conn_id: str, prefix: str = "", conn_type: str = "generic") -> dict:

return {
f"{prefix + '_' if prefix else ''}conn_id": conn_id,
"orbiter_conns": {
OrbiterConnection(
conn_id=conn_id, **({"conn_type": conn_type} if conn_type else {})
)
},
"orbiter_conns": {OrbiterConnection(conn_id=conn_id, **({"conn_type": conn_type} if conn_type else {}))},
}


Expand All @@ -93,9 +87,7 @@ def pool(name: str) -> dict:
Usage:
```python
OrbiterBashOperator(
**pool("my_pool")
)
OrbiterBashOperator(**pool("my_pool"))
```
:param name: The pool name
:type name: str
Expand Down
Loading

0 comments on commit 964661e

Please sign in to comment.