Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 84 additions & 12 deletions fastetl/operators/db_to_csv_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from pathlib import Path
from typing import Optional

import pandas as pd

from airflow.models.baseoperator import BaseOperator
from airflow.utils.context import Context

from fastetl.custom_functions.utils.get_table_cols_name import get_table_cols_name
from fastetl.custom_functions.utils.db_connection import get_hook_and_engine_by_provider
Expand Down Expand Up @@ -84,7 +87,18 @@ def __init__(
self.columns_to_remove: Optional[list[str]] = columns_to_remove
self.int_columns: Optional[list[str]] = int_columns

def select_all_sql(self):
def _select_all_sql(self) -> str:
"""Generate a SELECT statement to fetch all columns from
a table.

Returns:
str: a SELECT statement for all columns
"""
if any(argument is None for argument in (self.table_scheme, self.table_name)):
raise ValueError(
"table_scheme and table_name are required "
"when select_sql is not provided."
)
cols = get_table_cols_name(self.conn_id, self.table_scheme, self.table_name)
if self.columns_to_remove:
cols = [c for c in cols if c not in self.columns_to_remove]
Expand All @@ -95,26 +109,84 @@ def select_all_sql(self):
FROM {self.table_scheme}.{self.table_name};
"""

def execute(self, context):
def _starts_with_sql_keyword(self, value: str) -> bool:
    """Check if a string starts with a SQL keyword.

    The keyword must appear as a whole word: after leading whitespace is
    ignored, the character that follows the keyword (if any) must not be
    a letter, digit or underscore. This keeps file names that merely
    begin with a keyword, such as ``select_users.sql`` or
    ``updates.sql``, from being mistaken for inline SQL.

    Args:
        value (str): The string to inspect.

    Returns:
        bool: True if the string begins with one of SELECT, WITH,
            INSERT, UPDATE or DELETE (case-insensitive) as a whole word.
    """
    sql_keywords = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE")
    text = value.lstrip().upper()
    for keyword in sql_keywords:
        if not text.startswith(keyword):
            continue
        # Empty when the keyword is the entire string; "".isalnum() is
        # False, so a bare keyword still counts as SQL.
        following = text[len(keyword):len(keyword) + 1]
        if not (following.isalnum() or following == "_"):
            return True
    return False

def _resolve_select_sql(self, select_sql: Optional[str | Path]) -> str:
"""Resolve select_sql to a query string, handling files and inline
queries.

Args:
select_sql (Optional[str | Path]): The SQL query or file path to
resolve. If a string, it can be either a file path or an inline
SQL query. If a file path, the file must exist.

Raises:
FileNotFoundError: if a file path is specified but the file does not
exist.
TypeError: if select_sql is neither a string nor a Path object.
ValueError: if select_sql is None or empty and table_scheme or
table_name are not provided.

Returns:
str: the resolved SQL query string.
"""

if not select_sql:
return self._select_all_sql()

# Handle Path objects
if isinstance(select_sql, Path):
if select_sql.is_file():
return select_sql.read_text(encoding="utf-8")
raise FileNotFoundError(f"File not found: {select_sql}")

# Handle strings
if isinstance(select_sql, str):
# Quick check: SQL keywords indicate it's a query, not a path
if self._starts_with_sql_keyword(select_sql):
return select_sql

# Attempt to load as a file path
try:
path = Path(select_sql)
if path.is_file():
return path.read_text(encoding="utf-8")
except OSError:
# Catches "File name too long" and other path-related OS errors
# Treat as a query string instead
pass

# Default: treat as a query string
return select_sql

raise TypeError(f"select_sql must be a string or Path, got {type(select_sql)}")

def execute(self, context: Context):
"""Executes the SQL query and saves the CSV file. In the process,
converts data types to integers and removes specified characters
if the options have been specified.

Args:
context (Context): The Airflow context object.

Returns:
str: The name of the written file.
"""
_ = context # left unused
db_hook, _ = get_hook_and_engine_by_provider(self.conn_id)

if self.select_sql:
path_sql = Path(self.select_sql)
if path_sql.is_file():
query = path_sql.read_text(encoding="utf-8")
else:
query = self.select_sql
df_select = query
else:
df_select = self.select_all_sql()
df_select = self._resolve_select_sql(self.select_sql)

self.log.info(f"Executing SQL check: {df_select}")
df = db_hook.get_pandas_df(df_select)

# Convert columns data types to int
if self.int_columns:
for col in self.int_columns:
df[col] = df[col].astype("Int64")
df[col] = pd.to_numeric(df[col], errors="coerce").astype("Int64")

# Remove specified characters
if self.characters_to_remove:
Expand Down
Loading