diff --git a/pandasai/__init__.py b/pandasai/__init__.py
index 51c878a7c..a65ce95e0 100644
--- a/pandasai/__init__.py
+++ b/pandasai/__init__.py
@@ -118,7 +118,8 @@ def create(
     if df is not None:
         schema = df.schema
         schema.name = sanitize_sql_table_name(dataset_name)
-        df.to_parquet(parquet_file_path, index=False)
+        parquet_file_path_abs_path = file_manager.abs_path(parquet_file_path)
+        df.to_parquet(parquet_file_path_abs_path, index=False)
     elif view:
         _relation = [Relation(**relation) for relation in relations or ()]
         schema: SemanticLayerSchema = SemanticLayerSchema(
diff --git a/pandasai/agent/base.py b/pandasai/agent/base.py
index bb22ee0a9..7ade9e0c9 100644
--- a/pandasai/agent/base.py
+++ b/pandasai/agent/base.py
@@ -123,7 +123,7 @@ def _execute_local_sql_query(self, query: str) -> pd.DataFrame:
         with duckdb.connect() as con:
             # Register all DataFrames in the state
             for df in self._state.dfs:
-                con.register(df.schema.source.table, df)
+                con.register(df.schema.name, df)
 
             # Execute the query and fetch the result as a pandas DataFrame
             result = con.sql(query).df()
@@ -145,10 +145,13 @@ def _execute_sql_query(self, query: str) -> pd.DataFrame:
         if not self._state.dfs:
             raise ValueError("No DataFrames available to register for query execution.")
 
-        if self._state.dfs[0].schema.source.type in LOCAL_SOURCE_TYPES:
+        df0 = self._state.dfs[0]
+        source = df0.schema.source or None
+
+        if source and source.type in LOCAL_SOURCE_TYPES:
             return self._execute_local_sql_query(query)
         else:
-            return self._state.dfs[0].execute_sql_query(query)
+            return df0.execute_sql_query(query)
 
     def execute_with_retries(self, code: str) -> Any:
         """Execute the code with retry logic."""
diff --git a/pandasai/core/code_generation/code_cleaning.py b/pandasai/core/code_generation/code_cleaning.py
index a182e27ef..fe30b483d 100644
--- a/pandasai/core/code_generation/code_cleaning.py
+++ b/pandasai/core/code_generation/code_cleaning.py
@@ -55,11 +55,8 @@ def _clean_sql_query(self, sql_query: str) -> str:
         sql_query = sql_query.rstrip(";")
         table_names = extract_table_names(sql_query)
         allowed_table_names = {
-            df.schema.source.table: df.schema.source.table for df in self.context.dfs
-        } | {
-            f'"{df.schema.source.table}"': df.schema.source.table
-            for df in self.context.dfs
-        }
+            df.schema.name: df.schema.name for df in self.context.dfs
+        } | {f'"{df.schema.name}"': df.schema.name for df in self.context.dfs}
         return self._replace_table_names(sql_query, table_names, allowed_table_names)
 
     def _validate_and_make_table_name_case_sensitive(self, node: ast.AST) -> ast.AST:
diff --git a/pandasai/data_loader/loader.py b/pandasai/data_loader/loader.py
index fc50a580b..57d0d29f3 100644
--- a/pandasai/data_loader/loader.py
+++ b/pandasai/data_loader/loader.py
@@ -5,22 +5,22 @@
 
 from pandasai.dataframe.base import DataFrame
 from pandasai.exceptions import MethodNotImplementedError
+from pandasai.helpers.path import get_validated_dataset_path
 from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name
 
 from .. import ConfigManager
 from ..constants import (
     LOCAL_SOURCE_TYPES,
 )
-from .query_builder import QueryBuilder
 from .semantic_layer_schema import SemanticLayerSchema
 from .transformation_manager import TransformationManager
-from .view_query_builder import ViewQueryBuilder
 
 
 class DatasetLoader:
     def __init__(self, schema: SemanticLayerSchema, dataset_path: str):
         self.schema = schema
         self.dataset_path = dataset_path
+        self.org_name, self.dataset_name = get_validated_dataset_path(self.dataset_path)
 
     @classmethod
     def create_loader_from_schema(
diff --git a/pandasai/data_loader/local_loader.py b/pandasai/data_loader/local_loader.py
index 69dc298a5..3587b23b5 100644
--- a/pandasai/data_loader/local_loader.py
+++ b/pandasai/data_loader/local_loader.py
@@ -5,11 +5,11 @@
 from pandasai.dataframe.base import DataFrame
 from pandasai.exceptions import InvalidDataSourceType
 
+from ..config import ConfigManager
 from ..constants import (
     LOCAL_SOURCE_TYPES,
 )
 from .loader import DatasetLoader
-from .transformation_manager import TransformationManager
 
 
 class LocalDatasetLoader(DatasetLoader):
@@ -44,10 +44,11 @@ def _load_from_local_source(self) -> pd.DataFrame:
         return self._read_csv_or_parquet(filepath, source_type)
 
     def _read_csv_or_parquet(self, file_path: str, file_format: str) -> pd.DataFrame:
+        file_manager = ConfigManager.get().file_manager
         if file_format == "parquet":
-            df = pd.read_parquet(file_path)
+            df = pd.read_parquet(file_manager.abs_path(file_path))
         elif file_format == "csv":
-            df = pd.read_csv(file_path)
+            df = pd.read_csv(file_manager.abs_path(file_path))
         else:
             raise ValueError(f"Unsupported file format: {file_format}")
 
diff --git a/pandasai/data_loader/query_builder.py b/pandasai/data_loader/query_builder.py
index de8635532..cc5a2f79c 100644
--- a/pandasai/data_loader/query_builder.py
+++ b/pandasai/data_loader/query_builder.py
@@ -1,3 +1,4 @@
+import re
 from typing import Any, Dict, List, Union
 
 from pandasai.data_loader.semantic_layer_schema import Relation, SemanticLayerSchema
@@ -8,11 +9,15 @@ def __init__(self, schema: SemanticLayerSchema):
         self.schema = schema
 
     def format_query(self, query):
-        return query
+        pattern = re.compile(
+            rf"\bFROM\s+{re.escape(self.schema.name)}\b", re.IGNORECASE
+        )
+        replacement = self._get_from_statement()
+        return pattern.sub(replacement, query)
 
     def build_query(self) -> str:
         columns = self._get_columns()
-        query = f"SELECT {columns}"
+        query = f"SELECT {columns} "
         query += self._get_from_statement()
         query += self._add_order_by()
         query += self._add_limit()
@@ -26,7 +31,7 @@ def _get_columns(self) -> str:
             return "*"
 
     def _get_from_statement(self):
-        return f" FROM {self.schema.source.table.lower()}"
+        return f"FROM {self.schema.source.table.lower()}"
 
     def _add_order_by(self) -> str:
         if not self.schema.order_by:
@@ -47,7 +52,7 @@ def _add_limit(self, n=None) -> str:
     def get_head_query(self, n=5):
         source_type = self.schema.source.type
         columns = self._get_columns()
-        query = f"SELECT {columns}"
+        query = f"SELECT {columns} "
         query += self._get_from_statement()
         order_by = "RANDOM()" if source_type in {"sqlite", "postgres"} else "RAND()"
         return f"{query} ORDER BY {order_by} LIMIT {n}"
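For context on the query_builder.py change above: format_query now rewrites a reference to the dataset name into the concrete FROM clause produced by _get_from_statement, instead of returning the query untouched. A minimal standalone sketch of that substitution, with invented dataset and table names that are not taken from this patch:

import re

# Invented names for illustration only.
dataset_name = "heart_data"        # plays the role of schema.name
from_clause = "FROM heart_table"   # plays the role of _get_from_statement()

query = "SELECT age, sex FROM heart_data WHERE age > 50"

# Case-insensitive, word-bounded replacement, mirroring the regex in the patch.
pattern = re.compile(rf"\bFROM\s+{re.escape(dataset_name)}\b", re.IGNORECASE)
print(pattern.sub(from_clause, query))
# SELECT age, sex FROM heart_table WHERE age > 50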
diff --git a/pandasai/data_loader/semantic_layer_schema.py b/pandasai/data_loader/semantic_layer_schema.py
index 1ef6a4595..319518f56 100644
--- a/pandasai/data_loader/semantic_layer_schema.py
+++ b/pandasai/data_loader/semantic_layer_schema.py
@@ -29,6 +29,15 @@ class SQLConnectionConfig(BaseModel):
     user: str = Field(..., description="Database username")
     password: str = Field(..., description="Database password")
 
+    def __eq__(self, other):
+        return (
+            self.host == other.host
+            and self.port == other.port
+            and self.database == other.database
+            and self.user == other.user
+            and self.password == other.password
+        )
+
 
 class Column(BaseModel):
     name: str = Field(..., description="Name of the column.")
@@ -174,6 +183,28 @@ class Source(BaseModel):
     )
     table: Optional[str] = Field(None, description="Table of the data source.")
 
+    def is_compatible_source(self, source2: "Source"):
+        """
+        Checks if two sources are compatible for combining in a view.
+
+        Two sources are considered compatible if:
+        - Both are local sources.
+        - Both are remote sources with the same connection.
+
+        Compatible sources can be used together within the same view.
+
+        Args:
+            source2 (Source): The source to compare against.
+
+        Returns:
+            bool: True if the sources can be combined in a view, False otherwise.
+        """
+        if self.type in LOCAL_SOURCE_TYPES and source2.type in LOCAL_SOURCE_TYPES:
+            return True
+        if self.type in REMOTE_SOURCE_TYPES and source2.type in REMOTE_SOURCE_TYPES:
+            return self.connection == source2.connection
+        return False
+
     @model_validator(mode="before")
     @classmethod
     def validate_type_and_fields(cls, values):
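The compatibility rule added to Source above reads as: two local sources can always be combined in a view, two remote sources only if they share the same connection (which is why SQLConnectionConfig gains a field-by-field __eq__), and mixing local with remote is rejected. A simplified, self-contained sketch of that rule with stand-in type sets (LOCAL_SOURCE_TYPES and REMOTE_SOURCE_TYPES are only approximated here):

# Stand-ins for LOCAL_SOURCE_TYPES and REMOTE_SOURCE_TYPES.
LOCAL = {"csv", "parquet"}
REMOTE = {"postgres", "mysql"}

def is_compatible(type_a, conn_a, type_b, conn_b):
    # Local + local: always combinable in one view.
    if type_a in LOCAL and type_b in LOCAL:
        return True
    # Remote + remote: only when both point at the same connection.
    if type_a in REMOTE and type_b in REMOTE:
        return conn_a == conn_b
    # Local mixed with remote is not supported.
    return False

print(is_compatible("csv", None, "parquet", None))                          # True
print(is_compatible("postgres", {"host": "a"}, "postgres", {"host": "a"}))  # True
print(is_compatible("postgres", {"host": "a"}, "mysql", {"host": "b"}))     # False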
description="Database password") + def __eq__(self, other): + return ( + self.host == other.host + and self.port == other.port + and self.database == other.database + and self.user == other.user + and self.password == other.password + ) + class Column(BaseModel): name: str = Field(..., description="Name of the column.") @@ -174,6 +183,28 @@ class Source(BaseModel): ) table: Optional[str] = Field(None, description="Table of the data source.") + def is_compatible_source(self, source2: "Source"): + """ + Checks if two sources are compatible for combining in a view. + + Two sources are considered compatible if: + - Both are local sources. + - Both are remote sources with the same connection. + + Compatible sources can be used together within the same view. + + Args: + source2 (Source): The source to compare against. + + Returns: + bool: True if the sources can be combined in a view, False otherwise. + """ + if self.type in LOCAL_SOURCE_TYPES and source2.type in LOCAL_SOURCE_TYPES: + return True + if self.type in REMOTE_SOURCE_TYPES and source2.type in REMOTE_SOURCE_TYPES: + return self.connection == source2.connection + return False + @model_validator(mode="before") @classmethod def validate_type_and_fields(cls, values): diff --git a/pandasai/data_loader/view_loader.py b/pandasai/data_loader/view_loader.py index df26c283f..8e26f8916 100644 --- a/pandasai/data_loader/view_loader.py +++ b/pandasai/data_loader/view_loader.py @@ -1,6 +1,14 @@ +from typing import Dict, Optional + +import pandas as pd + from pandasai.dataframe.virtual_dataframe import VirtualDataFrame -from .semantic_layer_schema import SemanticLayerSchema +from .. import InvalidConfigError +from ..exceptions import MaliciousQueryError +from ..helpers.sql_sanitizer import is_sql_query_safe +from .loader import DatasetLoader +from .semantic_layer_schema import SemanticLayerSchema, Source, is_schema_source_same from .sql_loader import SQLDatasetLoader from .view_query_builder import ViewQueryBuilder @@ -12,7 +20,40 @@ class ViewDatasetLoader(SQLDatasetLoader): def __init__(self, schema: SemanticLayerSchema, dataset_path: str): super().__init__(schema, dataset_path) - self.query_builder: ViewQueryBuilder = ViewQueryBuilder(schema) + self.dependencies_datasets = self._get_dependencies_datasets() + self.schema_dependencies_dict: dict[ + str, DatasetLoader + ] = self._get_dependencies_schemas() + self.source: Source = list(self.schema_dependencies_dict.values())[ + 0 + ].schema.source + self.query_builder: ViewQueryBuilder = ViewQueryBuilder( + schema, self.schema_dependencies_dict + ) + + def _get_dependencies_datasets(self) -> set[str]: + return { + table.split(".")[0] + for relation in self.schema.relations + for table in (relation.from_, relation.to) + } + + def _get_dependencies_schemas(self) -> dict[str, DatasetLoader]: + dependency_dict = { + dep: DatasetLoader.create_loader_from_path(f"{self.org_name}/{dep}") + for dep in self.dependencies_datasets + } + + loaders = list(dependency_dict.values()) + base_source = loaders[0].schema.source + + for loader in loaders[1:]: + if not base_source.is_compatible_source(loader.schema.source): + raise ValueError( + f"Source in loader with schema {loader.schema} is not compatible with the first loader's source." 
diff --git a/pandasai/data_loader/view_query_builder.py b/pandasai/data_loader/view_query_builder.py
index 4c1a606df..24484e0a4 100644
--- a/pandasai/data_loader/view_query_builder.py
+++ b/pandasai/data_loader/view_query_builder.py
@@ -1,20 +1,24 @@
-from typing import Any, Dict, List, Union
+import re
+from typing import Dict
 
+from pandasai.data_loader.loader import DatasetLoader
 from pandasai.data_loader.query_builder import QueryBuilder
-from pandasai.data_loader.semantic_layer_schema import Relation, SemanticLayerSchema
+from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
+from pandasai.data_loader.sql_loader import SQLDatasetLoader
 
 
 class ViewQueryBuilder(QueryBuilder):
-    def __init__(self, schema: SemanticLayerSchema):
+    def __init__(
+        self,
+        schema: SemanticLayerSchema,
+        schema_dependencies_dict: Dict[str, DatasetLoader],
+    ):
         super().__init__(schema)
-
-    def format_query(self, query):
-        return f"{self._get_with_statement()}{query}"
+        self.schema_dependencies_dict = schema_dependencies_dict
 
     def build_query(self) -> str:
         columns = self._get_columns()
-        query = self._get_with_statement()
-        query += f"SELECT {columns}"
+        query = f"SELECT {columns} "
         query += self._get_from_statement()
         query += self._add_order_by()
         query += self._add_limit()
@@ -28,13 +32,31 @@ def _get_columns(self) -> str:
         else:
             return super()._get_columns()
 
-    def _get_from_statement(self):
-        return f" FROM {self.schema.name}"
+    def _get_columns_for_table(self, query):
+        match = re.search(r"SELECT\s+(.*?)\s+FROM", query, re.IGNORECASE)
+        if not match:
+            return None
+
+        columns = match.group(1).split(",")
+        return [col.strip() for col in columns]
+
+    def _get_sub_query_from_loader(self, loader: SQLDatasetLoader) -> (str, str):
+        query = loader.query_builder.build_query()
+        return query, loader.schema.name
 
-    def _get_with_statement(self):
+    def _get_from_statement(self):
         relations = self.schema.relations
-        first_table = relations[0].from_.split(".")[0]
-        query = f"WITH {self.schema.name} AS ( SELECT\n"
+        first_dataset = relations[0].from_.split(".")[0]
+        first_loader = self.schema_dependencies_dict[first_dataset]
+
+        if isinstance(first_loader, SQLDatasetLoader):
+            first_query, first_name = self._get_sub_query_from_loader(first_loader)
+        else:
+            raise ValueError(
+                "Views for local datasets or nested views are currently not supported."
+            )
+
+        query = f"FROM ( SELECT\n"
 
         if self.schema.columns:
             query += ", ".join(
@@ -44,11 +66,21 @@
                 ]
             )
         else:
-            query += "*"
+            query += "* "
 
-        query += f"\nFROM {first_table}"
+        query += f"\nFROM ( {first_query} ) AS {first_name}"
         for relation in relations:
-            to_table = relation.to.split(".")[0]
-            query += f"\nJOIN {to_table} ON {relation.from_} = {relation.to}"
-        query += ")\n"
+            to_datasets = relation.to.split(".")[0]
+            loader = self.schema_dependencies_dict[to_datasets]
+            subquery, dataset_name = self._get_sub_query_from_loader(loader)
+            query += f"\nJOIN ( {subquery} ) AS {dataset_name}\n"
+            query += f"ON {relation.from_} = {relation.to}"
+        query += f") AS {self.schema.name}\n"
+
         return query
+
+    def get_head_query(self, n=5):
+        columns = self._get_columns()
+        query = f"SELECT {columns}"
+        query += self._get_from_statement()
+        return f"{query} LIMIT {n}"
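To make the new FROM-clause construction above concrete: every dependency is wrapped as a named subquery, joined on the relation columns, and the whole block is aliased with the view name. A rough illustration of the string being assembled, with invented dataset, column, and view names (the real subqueries come from each dependency's own query builder):

# Invented inputs standing in for what the builder derives from the schema.
view_name = "parents_children"
first_query, first_name = "SELECT * FROM parents", "parents"
join_query, join_name = "SELECT * FROM children", "children"

query = "FROM ( SELECT\n"
query += "parents_id AS parents_id, children_name AS children_name"
query += f"\nFROM ( {first_query} ) AS {first_name}"
query += f"\nJOIN ( {join_query} ) AS {join_name}\n"
query += "ON parents.id = children.parent_id"
query += f") AS {view_name}\n"
print(query)
# FROM ( SELECT
# parents_id AS parents_id, children_name AS children_name
# FROM ( SELECT * FROM parents ) AS parents
# JOIN ( SELECT * FROM children ) AS children
# ON parents.id = children.parent_id) AS parents_children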
diff --git a/pandasai/dataframe/base.py b/pandasai/dataframe/base.py
index 702b3a5cd..2a5196bcb 100644
--- a/pandasai/dataframe/base.py
+++ b/pandasai/dataframe/base.py
@@ -163,17 +163,16 @@ def push(self):
             "name": self.schema.name,
         }
 
-        dataset_directory = os.path.join("datasets", self.path)
         file_manager = ConfigManager.get().file_manager
         headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}
 
-        files = []
-        schema_file_path = os.path.join(dataset_directory, "schema.yaml")
-        data_file_path = os.path.join(dataset_directory, "data.parquet")
+        schema_file_path = os.path.join(self.path, "schema.yaml")
+        data_file_path = os.path.join(self.path, "data.parquet")
 
         # Open schema.yaml
         schema_file = file_manager.load_binary(schema_file_path)
-        files.append(("files", ("schema.yaml", schema_file, "application/x-yaml")))
+
+        files = [("files", ("schema.yaml", schema_file, "application/x-yaml"))]
 
         # Check if data.parquet exists and open it
         if file_manager.exists(data_file_path):
diff --git a/pandasai/helpers/dataframe_serializer.py b/pandasai/helpers/dataframe_serializer.py
index 4970e5bdd..7bef0b486 100644
--- a/pandasai/helpers/dataframe_serializer.py
+++ b/pandasai/helpers/dataframe_serializer.py
@@ -18,11 +18,7 @@ def serialize(df: "DataFrame") -> str:
         Returns:
             str: dataframe stringify
         """
-        dataframe_info = "