diff --git a/fastetl/operators/db_to_csv_operator.py b/fastetl/operators/db_to_csv_operator.py index 4547424..1516f9e 100644 --- a/fastetl/operators/db_to_csv_operator.py +++ b/fastetl/operators/db_to_csv_operator.py @@ -10,7 +10,10 @@ from pathlib import Path from typing import Optional +import pandas as pd + from airflow.models.baseoperator import BaseOperator +from airflow.utils.context import Context from fastetl.custom_functions.utils.get_table_cols_name import get_table_cols_name from fastetl.custom_functions.utils.db_connection import get_hook_and_engine_by_provider @@ -84,7 +87,18 @@ def __init__( self.columns_to_remove: Optional[list[str]] = columns_to_remove self.int_columns: Optional[list[str]] = int_columns - def select_all_sql(self): + def _select_all_sql(self) -> str: + """Generate a SELECT statement to fetch all columns from + a table. + + Returns: + str: a SELECT statement for all columns + """ + if any(argument is None for argument in (self.table_scheme, self.table_name)): + raise ValueError( + "table_scheme and table_name are required " + "when select_sql is not provided." + ) cols = get_table_cols_name(self.conn_id, self.table_scheme, self.table_name) if self.columns_to_remove: cols = [c for c in cols if c not in self.columns_to_remove] @@ -95,18 +109,76 @@ def select_all_sql(self): FROM {self.table_scheme}.{self.table_name}; """ - def execute(self, context): + def _starts_with_sql_keyword(self, value: str) -> bool: + """Check if a string starts with a SQL keyword.""" + sql_keywords = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE") + return any(value.strip().upper().startswith(kw) for kw in sql_keywords) + + def _resolve_select_sql(self, select_sql: Optional[str | Path]) -> str: + """Resolve select_sql to a query string, handling files and inline + queries. + + Args: + select_sql (Optional[str | Path]): The SQL query or file path to + resolve. If a string, it can be either a file path or an inline + SQL query. If a file path, the file must exist. + + Raises: + FileNotFoundError: if a file path is specified but the file does not + exist. + TypeError: if select_sql is neither a string nor a Path object. + ValueError: if select_sql is None or empty and table_scheme or + table_name are not provided. + + Returns: + str: the resolved SQL query string. + """ + + if not select_sql: + return self._select_all_sql() + + # Handle Path objects + if isinstance(select_sql, Path): + if select_sql.is_file(): + return select_sql.read_text(encoding="utf-8") + raise FileNotFoundError(f"File not found: {select_sql}") + + # Handle strings + if isinstance(select_sql, str): + # Quick check: SQL keywords indicate it's a query, not a path + if self._starts_with_sql_keyword(select_sql): + return select_sql + + # Attempt to load as a file path + try: + path = Path(select_sql) + if path.is_file(): + return path.read_text(encoding="utf-8") + except OSError: + # Catches "File name too long" and other path-related OS errors + # Treat as a query string instead + pass + + # Default: treat as a query string + return select_sql + + raise TypeError(f"select_sql must be a string or Path, got {type(select_sql)}") + + def execute(self, context: Context): + """Executes the SQL query and saves the CSV file. In the process, + converts data types to integers and removes specified characters + if the options have been specified. + + Args: + context (Context): The Airflow context object. + + Returns: + str: The name of the written file. + """ + _ = context # left unused db_hook, _ = get_hook_and_engine_by_provider(self.conn_id) - if self.select_sql: - path_sql = Path(self.select_sql) - if path_sql.is_file(): - query = path_sql.read_text(encoding="utf-8") - else: - query = self.select_sql - df_select = query - else: - df_select = self.select_all_sql() + df_select = self._resolve_select_sql(self.select_sql) self.log.info(f"Executing SQL check: {df_select}") df = db_hook.get_pandas_df(df_select) @@ -114,7 +186,7 @@ def execute(self, context): # Convert columns data types to int if self.int_columns: for col in self.int_columns: - df[col] = df[col].astype("Int64") + df[col] = pd.to_numeric(df[col], errors="coerce").astype("Int64") # Remove specified characters if self.characters_to_remove: