diff --git a/pandasai/__init__.py b/pandasai/__init__.py
index 51c878a7c..a65ce95e0 100644
--- a/pandasai/__init__.py
+++ b/pandasai/__init__.py
@@ -118,7 +118,8 @@ def create(
     if df is not None:
         schema = df.schema
         schema.name = sanitize_sql_table_name(dataset_name)
-        df.to_parquet(parquet_file_path, index=False)
+        parquet_file_path_abs_path = file_manager.abs_path(parquet_file_path)
+        df.to_parquet(parquet_file_path_abs_path, index=False)
     elif view:
         _relation = [Relation(**relation) for relation in relations or ()]
         schema: SemanticLayerSchema = SemanticLayerSchema(
diff --git a/pandasai/agent/base.py b/pandasai/agent/base.py
index bb22ee0a9..7ade9e0c9 100644
--- a/pandasai/agent/base.py
+++ b/pandasai/agent/base.py
@@ -123,7 +123,7 @@ def _execute_local_sql_query(self, query: str) -> pd.DataFrame:
         with duckdb.connect() as con:
             # Register all DataFrames in the state
             for df in self._state.dfs:
-                con.register(df.schema.source.table, df)
+                con.register(df.schema.name, df)

             # Execute the query and fetch the result as a pandas DataFrame
             result = con.sql(query).df()
@@ -145,10 +145,13 @@ def _execute_sql_query(self, query: str) -> pd.DataFrame:
         if not self._state.dfs:
             raise ValueError("No DataFrames available to register for query execution.")

-        if self._state.dfs[0].schema.source.type in LOCAL_SOURCE_TYPES:
+        df0 = self._state.dfs[0]
+        source = df0.schema.source or None
+
+        if source and source.type in LOCAL_SOURCE_TYPES:
             return self._execute_local_sql_query(query)
         else:
-            return self._state.dfs[0].execute_sql_query(query)
+            return df0.execute_sql_query(query)

     def execute_with_retries(self, code: str) -> Any:
         """Execute the code with retry logic."""
diff --git a/pandasai/core/code_generation/code_cleaning.py b/pandasai/core/code_generation/code_cleaning.py
index a182e27ef..fe30b483d 100644
--- a/pandasai/core/code_generation/code_cleaning.py
+++ b/pandasai/core/code_generation/code_cleaning.py
@@ -55,11 +55,8 @@ def _clean_sql_query(self, sql_query: str) -> str:
         sql_query = sql_query.rstrip(";")
         table_names = extract_table_names(sql_query)
         allowed_table_names = {
-            df.schema.source.table: df.schema.source.table for df in self.context.dfs
-        } | {
-            f'"{df.schema.source.table}"': df.schema.source.table
-            for df in self.context.dfs
-        }
+            df.schema.name: df.schema.name for df in self.context.dfs
+        } | {f'"{df.schema.name}"': df.schema.name for df in self.context.dfs}
         return self._replace_table_names(sql_query, table_names, allowed_table_names)

     def _validate_and_make_table_name_case_sensitive(self, node: ast.AST) -> ast.AST:
diff --git a/pandasai/data_loader/loader.py b/pandasai/data_loader/loader.py
index fc50a580b..57d0d29f3 100644
--- a/pandasai/data_loader/loader.py
+++ b/pandasai/data_loader/loader.py
@@ -5,22 +5,22 @@
 from pandasai.dataframe.base import DataFrame
 from pandasai.exceptions import MethodNotImplementedError
+from pandasai.helpers.path import get_validated_dataset_path
 from pandasai.helpers.sql_sanitizer import sanitize_sql_table_name

 from .. import ConfigManager
 from ..constants import (
     LOCAL_SOURCE_TYPES,
 )
-from .query_builder import QueryBuilder
 from .semantic_layer_schema import SemanticLayerSchema
 from .transformation_manager import TransformationManager
-from .view_query_builder import ViewQueryBuilder


 class DatasetLoader:
     def __init__(self, schema: SemanticLayerSchema, dataset_path: str):
         self.schema = schema
         self.dataset_path = dataset_path
+        self.org_name, self.dataset_name = get_validated_dataset_path(self.dataset_path)

     @classmethod
     def create_loader_from_schema(
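Note on the `agent/base.py` hunk above: `_execute_local_sql_query` now registers every in-memory DataFrame under its sanitized `schema.name` rather than the physical source table. A minimal sketch of the resulting DuckDB behavior, assuming `duckdb` and `pandas` are installed (the dataset name is illustrative; pandasai derives it via `sanitize_sql_table_name`):

```python
import duckdb
import pandas as pd

df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

with duckdb.connect() as con:
    # Register the frame under the schema name, not the physical table name.
    con.register("loans_payments", df)
    # Quoted identifiers keep sanitized names resolvable exactly as written.
    result = con.sql('SELECT count(*) AS total FROM "loans_payments"').df()

print(result)  # total == 3
```

The quoted identifier form is what the updated agent tests further down exercise.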
diff --git a/pandasai/data_loader/local_loader.py b/pandasai/data_loader/local_loader.py
index 69dc298a5..3587b23b5 100644
--- a/pandasai/data_loader/local_loader.py
+++ b/pandasai/data_loader/local_loader.py
@@ -5,11 +5,11 @@
 from pandasai.dataframe.base import DataFrame
 from pandasai.exceptions import InvalidDataSourceType

+from ..config import ConfigManager
 from ..constants import (
     LOCAL_SOURCE_TYPES,
 )
 from .loader import DatasetLoader
-from .transformation_manager import TransformationManager


 class LocalDatasetLoader(DatasetLoader):
@@ -44,10 +44,11 @@ def _load_from_local_source(self) -> pd.DataFrame:
         return self._read_csv_or_parquet(filepath, source_type)

     def _read_csv_or_parquet(self, file_path: str, file_format: str) -> pd.DataFrame:
+        file_manager = ConfigManager.get().file_manager
         if file_format == "parquet":
-            df = pd.read_parquet(file_path)
+            df = pd.read_parquet(file_manager.abs_path(file_path))
         elif file_format == "csv":
-            df = pd.read_csv(file_path)
+            df = pd.read_csv(file_manager.abs_path(file_path))
         else:
             raise ValueError(f"Unsupported file format: {file_format}")
diff --git a/pandasai/data_loader/query_builder.py b/pandasai/data_loader/query_builder.py
index de8635532..cc5a2f79c 100644
--- a/pandasai/data_loader/query_builder.py
+++ b/pandasai/data_loader/query_builder.py
@@ -1,3 +1,4 @@
+import re
 from typing import Any, Dict, List, Union

 from pandasai.data_loader.semantic_layer_schema import Relation, SemanticLayerSchema
@@ -8,11 +9,15 @@ def __init__(self, schema: SemanticLayerSchema):
         self.schema = schema

     def format_query(self, query):
-        return query
+        pattern = re.compile(
+            rf"\bFROM\s+{re.escape(self.schema.name)}\b", re.IGNORECASE
+        )
+        replacement = self._get_from_statement()
+        return pattern.sub(replacement, query)

     def build_query(self) -> str:
         columns = self._get_columns()
-        query = f"SELECT {columns}"
+        query = f"SELECT {columns} "
         query += self._get_from_statement()
         query += self._add_order_by()
         query += self._add_limit()
@@ -26,7 +31,7 @@ def _get_columns(self) -> str:
         return "*"

     def _get_from_statement(self):
-        return f" FROM {self.schema.source.table.lower()}"
+        return f"FROM {self.schema.source.table.lower()}"

     def _add_order_by(self) -> str:
         if not self.schema.order_by:
@@ -47,7 +52,7 @@ def _add_limit(self, n=None) -> str:
     def get_head_query(self, n=5):
         source_type = self.schema.source.type
         columns = self._get_columns()
-        query = f"SELECT {columns}"
+        query = f"SELECT {columns} "
         query += self._get_from_statement()
         order_by = "RANDOM()" if source_type in {"sqlite", "postgres"} else "RAND()"
         return f"{query} ORDER BY {order_by} LIMIT {n}"
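The new `format_query` is easiest to see in isolation: it swaps the logical dataset name in a generated query for the concrete `FROM` clause. A standalone sketch with illustrative names (`heart_data` stands in for `schema.name`, `heart` for `schema.source.table`):

```python
import re

schema_name = "heart_data"      # logical name the LLM writes queries against
from_statement = "FROM heart"   # what _get_from_statement() would return

query = "SELECT age, sex FROM heart_data WHERE age > 50"
pattern = re.compile(rf"\bFROM\s+{re.escape(schema_name)}\b", re.IGNORECASE)
print(pattern.sub(from_statement, query))
# SELECT age, sex FROM heart WHERE age > 50
```

`re.escape` guards against regex-significant characters in sanitized names, and the word boundaries keep `FROM heart_data_v2` from matching.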
description="Database password") + def __eq__(self, other): + return ( + self.host == other.host + and self.port == other.port + and self.database == other.database + and self.user == other.user + and self.password == other.password + ) + class Column(BaseModel): name: str = Field(..., description="Name of the column.") @@ -174,6 +183,28 @@ class Source(BaseModel): ) table: Optional[str] = Field(None, description="Table of the data source.") + def is_compatible_source(self, source2: "Source"): + """ + Checks if two sources are compatible for combining in a view. + + Two sources are considered compatible if: + - Both are local sources. + - Both are remote sources with the same connection. + + Compatible sources can be used together within the same view. + + Args: + source2 (Source): The source to compare against. + + Returns: + bool: True if the sources can be combined in a view, False otherwise. + """ + if self.type in LOCAL_SOURCE_TYPES and source2.type in LOCAL_SOURCE_TYPES: + return True + if self.type in REMOTE_SOURCE_TYPES and source2.type in REMOTE_SOURCE_TYPES: + return self.connection == source2.connection + return False + @model_validator(mode="before") @classmethod def validate_type_and_fields(cls, values): diff --git a/pandasai/data_loader/view_loader.py b/pandasai/data_loader/view_loader.py index df26c283f..8e26f8916 100644 --- a/pandasai/data_loader/view_loader.py +++ b/pandasai/data_loader/view_loader.py @@ -1,6 +1,14 @@ +from typing import Dict, Optional + +import pandas as pd + from pandasai.dataframe.virtual_dataframe import VirtualDataFrame -from .semantic_layer_schema import SemanticLayerSchema +from .. import InvalidConfigError +from ..exceptions import MaliciousQueryError +from ..helpers.sql_sanitizer import is_sql_query_safe +from .loader import DatasetLoader +from .semantic_layer_schema import SemanticLayerSchema, Source, is_schema_source_same from .sql_loader import SQLDatasetLoader from .view_query_builder import ViewQueryBuilder @@ -12,7 +20,40 @@ class ViewDatasetLoader(SQLDatasetLoader): def __init__(self, schema: SemanticLayerSchema, dataset_path: str): super().__init__(schema, dataset_path) - self.query_builder: ViewQueryBuilder = ViewQueryBuilder(schema) + self.dependencies_datasets = self._get_dependencies_datasets() + self.schema_dependencies_dict: dict[ + str, DatasetLoader + ] = self._get_dependencies_schemas() + self.source: Source = list(self.schema_dependencies_dict.values())[ + 0 + ].schema.source + self.query_builder: ViewQueryBuilder = ViewQueryBuilder( + schema, self.schema_dependencies_dict + ) + + def _get_dependencies_datasets(self) -> set[str]: + return { + table.split(".")[0] + for relation in self.schema.relations + for table in (relation.from_, relation.to) + } + + def _get_dependencies_schemas(self) -> dict[str, DatasetLoader]: + dependency_dict = { + dep: DatasetLoader.create_loader_from_path(f"{self.org_name}/{dep}") + for dep in self.dependencies_datasets + } + + loaders = list(dependency_dict.values()) + base_source = loaders[0].schema.source + + for loader in loaders[1:]: + if not base_source.is_compatible_source(loader.schema.source): + raise ValueError( + f"Source in loader with schema {loader.schema} is not compatible with the first loader's source." 
diff --git a/pandasai/data_loader/view_loader.py b/pandasai/data_loader/view_loader.py
index df26c283f..8e26f8916 100644
--- a/pandasai/data_loader/view_loader.py
+++ b/pandasai/data_loader/view_loader.py
@@ -1,6 +1,14 @@
+from typing import Dict, Optional
+
+import pandas as pd
+
 from pandasai.dataframe.virtual_dataframe import VirtualDataFrame

-from .semantic_layer_schema import SemanticLayerSchema
+from .. import InvalidConfigError
+from ..exceptions import MaliciousQueryError
+from ..helpers.sql_sanitizer import is_sql_query_safe
+from .loader import DatasetLoader
+from .semantic_layer_schema import SemanticLayerSchema, Source, is_schema_source_same
 from .sql_loader import SQLDatasetLoader
 from .view_query_builder import ViewQueryBuilder
@@ -12,7 +20,40 @@ class ViewDatasetLoader(SQLDatasetLoader):
     def __init__(self, schema: SemanticLayerSchema, dataset_path: str):
         super().__init__(schema, dataset_path)
-        self.query_builder: ViewQueryBuilder = ViewQueryBuilder(schema)
+        self.dependencies_datasets = self._get_dependencies_datasets()
+        self.schema_dependencies_dict: dict[
+            str, DatasetLoader
+        ] = self._get_dependencies_schemas()
+        self.source: Source = list(self.schema_dependencies_dict.values())[
+            0
+        ].schema.source
+        self.query_builder: ViewQueryBuilder = ViewQueryBuilder(
+            schema, self.schema_dependencies_dict
+        )
+
+    def _get_dependencies_datasets(self) -> set[str]:
+        return {
+            table.split(".")[0]
+            for relation in self.schema.relations
+            for table in (relation.from_, relation.to)
+        }
+
+    def _get_dependencies_schemas(self) -> dict[str, DatasetLoader]:
+        dependency_dict = {
+            dep: DatasetLoader.create_loader_from_path(f"{self.org_name}/{dep}")
+            for dep in self.dependencies_datasets
+        }
+
+        loaders = list(dependency_dict.values())
+        base_source = loaders[0].schema.source
+
+        for loader in loaders[1:]:
+            if not base_source.is_compatible_source(loader.schema.source):
+                raise ValueError(
+                    f"Source in loader with schema {loader.schema} is not compatible with the first loader's source."
+                )
+
+        return dependency_dict

     def load(self) -> VirtualDataFrame:
         return VirtualDataFrame(
@@ -20,3 +61,30 @@
             data_loader=ViewDatasetLoader(self.schema, self.dataset_path),
             path=self.dataset_path,
         )
+
+    def execute_query(self, query: str, params: Optional[list] = None) -> pd.DataFrame:
+        source_type = self.source.type
+        connection_info = self.source.connection
+
+        formatted_query = self.query_builder.format_query(query)
+        load_function = self._get_loader_function(source_type)
+
+        if not is_sql_query_safe(formatted_query):
+            raise MaliciousQueryError(
+                "The SQL query is deemed unsafe and will not be executed."
+            )
+        try:
+            dataframe: pd.DataFrame = load_function(
+                connection_info, formatted_query, params
+            )
+            return dataframe
+
+        except ModuleNotFoundError as e:
+            raise ImportError(
+                f"{source_type.capitalize()} connector not found. Please install the pandasai_sql[{source_type}] library, e.g. `pip install pandasai_sql[{source_type}]`."
+            ) from e
+
+        except Exception as e:
+            raise RuntimeError(
+                f"Failed to execute query for '{source_type}' with: {formatted_query}"
+            ) from e
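The dependency discovery in `_get_dependencies_datasets` is a small set comprehension over relation endpoints. A sketch using the relation from the test fixtures below:

```python
# Each relation endpoint is "<dataset>.<column>"; splitting on the first
# "." collects the datasets the view depends on.
relations = [{"from": "parents.id", "to": "children.id"}]

dependencies = {
    endpoint.split(".")[0]
    for relation in relations
    for endpoint in (relation["from"], relation["to"])
}
print(dependencies)  # {'parents', 'children'}
```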
diff --git a/pandasai/data_loader/view_query_builder.py b/pandasai/data_loader/view_query_builder.py
index 4c1a606df..24484e0a4 100644
--- a/pandasai/data_loader/view_query_builder.py
+++ b/pandasai/data_loader/view_query_builder.py
@@ -1,20 +1,24 @@
-from typing import Any, Dict, List, Union
+import re
+from typing import Dict

+from pandasai.data_loader.loader import DatasetLoader
 from pandasai.data_loader.query_builder import QueryBuilder
-from pandasai.data_loader.semantic_layer_schema import Relation, SemanticLayerSchema
+from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
+from pandasai.data_loader.sql_loader import SQLDatasetLoader


 class ViewQueryBuilder(QueryBuilder):
-    def __init__(self, schema: SemanticLayerSchema):
+    def __init__(
+        self,
+        schema: SemanticLayerSchema,
+        schema_dependencies_dict: Dict[str, DatasetLoader],
+    ):
         super().__init__(schema)
-
-    def format_query(self, query):
-        return f"{self._get_with_statement()}{query}"
+        self.schema_dependencies_dict = schema_dependencies_dict

     def build_query(self) -> str:
         columns = self._get_columns()
-        query = self._get_with_statement()
-        query += f"SELECT {columns}"
+        query = f"SELECT {columns} "
         query += self._get_from_statement()
         query += self._add_order_by()
         query += self._add_limit()
@@ -28,13 +32,31 @@ def _get_columns(self) -> str:
         else:
             return super()._get_columns()

-    def _get_from_statement(self):
-        return f" FROM {self.schema.name}"
+    def _get_columns_for_table(self, query):
+        match = re.search(r"SELECT\s+(.*?)\s+FROM", query, re.IGNORECASE)
+        if not match:
+            return None
+
+        columns = match.group(1).split(",")
+        return [col.strip() for col in columns]
+
+    def _get_sub_query_from_loader(self, loader: SQLDatasetLoader) -> (str, str):
+        query = loader.query_builder.build_query()
+        return query, loader.schema.name

-    def _get_with_statement(self):
+    def _get_from_statement(self):
         relations = self.schema.relations
-        first_table = relations[0].from_.split(".")[0]
-        query = f"WITH {self.schema.name} AS ( SELECT\n"
+        first_dataset = relations[0].from_.split(".")[0]
+        first_loader = self.schema_dependencies_dict[first_dataset]
+
+        if isinstance(first_loader, SQLDatasetLoader):
+            first_query, first_name = self._get_sub_query_from_loader(first_loader)
+        else:
+            raise ValueError(
+                "Views for local datasets or nested views are currently not supported."
+            )
+
+        query = f"FROM ( SELECT\n"

         if self.schema.columns:
             query += ", ".join(
@@ -44,11 +66,21 @@
                 ]
             )
         else:
-            query += "*"
+            query += "* "

-        query += f"\nFROM {first_table}"
+        query += f"\nFROM ( {first_query} ) AS {first_name}"
         for relation in relations:
-            to_table = relation.to.split(".")[0]
-            query += f"\nJOIN {to_table} ON {relation.from_} = {relation.to}"
-        query += ")\n"
+            to_datasets = relation.to.split(".")[0]
+            loader = self.schema_dependencies_dict[to_datasets]
+            subquery, dataset_name = self._get_sub_query_from_loader(loader)
+            query += f"\nJOIN ( {subquery} ) AS {dataset_name}\n"
+            query += f"ON {relation.from_} = {relation.to}"
+        query += f") AS {self.schema.name}\n"
+
         return query
+
+    def get_head_query(self, n=5):
+        columns = self._get_columns()
+        query = f"SELECT {columns}"
+        query += self._get_from_statement()
+        return f"{query} LIMIT {n}"
diff --git a/pandasai/dataframe/base.py b/pandasai/dataframe/base.py
index 702b3a5cd..2a5196bcb 100644
--- a/pandasai/dataframe/base.py
+++ b/pandasai/dataframe/base.py
@@ -163,17 +163,16 @@ def push(self):
             "name": self.schema.name,
         }

-        dataset_directory = os.path.join("datasets", self.path)
         file_manager = ConfigManager.get().file_manager

         headers = {"accept": "application/json", "x-authorization": f"Bearer {api_key}"}

-        files = []
-        schema_file_path = os.path.join(dataset_directory, "schema.yaml")
-        data_file_path = os.path.join(dataset_directory, "data.parquet")
+        schema_file_path = os.path.join(self.path, "schema.yaml")
+        data_file_path = os.path.join(self.path, "data.parquet")

         # Open schema.yaml
         schema_file = file_manager.load_binary(schema_file_path)
-        files.append(("files", ("schema.yaml", schema_file, "application/x-yaml")))
+
+        files = [("files", ("schema.yaml", schema_file, "application/x-yaml"))]

         # Check if data.parquet exists and open it
         if file_manager.exists(data_file_path):
diff --git a/pandasai/helpers/dataframe_serializer.py b/pandasai/helpers/dataframe_serializer.py
index 4970e5bdd..7bef0b486 100644
--- a/pandasai/helpers/dataframe_serializer.py
+++ b/pandasai/helpers/dataframe_serializer.py
@@ -18,11 +18,7 @@ def serialize(df: "DataFrame") -> str:
         Returns:
             str: dataframe stringify
         """
-        dataframe_info = "<table"
-
-        # Add name attribute if available
-        if df.schema.source.table is not None:
-            dataframe_info += f' table_name="{df.schema.source.table}"'
+        dataframe_info = f'<table table_name="{df.schema.name}"'

         # Add description attribute if available
         if df.schema.description is not None:
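For context on the serializer hunk: this wrapper heads every prompt, so taking `table_name` from `schema.name` keeps the prompt aligned with the name the generated SQL must reference. A hedged sketch of the output shape, with the attribute layout assumed from the hunk above:

```python
# Assumed shape of DataframeSerializer.serialize() output: schema.name as
# the table_name attribute, followed by the head of the frame as CSV.
schema_name = "my_dataset"  # illustrative; normally df.schema.name
csv_head = "A,B\n1,4\n2,5\n3,6\n"
print(f'<table table_name="{schema_name}">\n{csv_head}</table>\n')
```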
diff --git a/pandasai/helpers/filemanager.py b/pandasai/helpers/filemanager.py
--- a/pandasai/helpers/filemanager.py
+++ b/pandasai/helpers/filemanager.py
@@ ... @@ class FileManager(ABC):
     @abstractmethod
     def mkdir(self, dir_path: str) -> None:
         """Creates a directory if it doesn't exist."""
         pass

+    @abstractmethod
+    def abs_path(self, file_path: str) -> str:
+        """Returns the absolute path of {file_path}"""
+        pass
+

 class DefaultFileManager(FileManager):
     """Local file system implementation of FileLoader."""

@@ -45,29 +50,26 @@ def __init__(self):
         self.base_path = os.path.join(find_project_root(), "datasets")

     def load(self, file_path: str) -> str:
-        full_path = os.path.join(self.base_path, file_path)
-        with open(full_path, "r", encoding="utf-8") as f:
+        with open(self.abs_path(file_path), "r", encoding="utf-8") as f:
             return f.read()

     def load_binary(self, file_path: str) -> bytes:
-        full_path = os.path.join(self.base_path, file_path)
-        with open(full_path, "rb") as f:
+        with open(self.abs_path(file_path), "rb") as f:
             return f.read()

     def write(self, file_path: str, content: str) -> None:
-        full_path = os.path.join(self.base_path, file_path)
-        with open(full_path, "w", encoding="utf-8") as f:
+        with open(self.abs_path(file_path), "w", encoding="utf-8") as f:
             f.write(content)

     def write_binary(self, file_path: str, content: bytes) -> None:
-        full_path = os.path.join(self.base_path, file_path)
-        with open(full_path, "wb") as f:
+        with open(self.abs_path(file_path), "wb") as f:
             f.write(content)

     def exists(self, file_path: str) -> bool:
-        full_path = os.path.join(self.base_path, file_path)
-        return os.path.exists(full_path)
+        return os.path.exists(self.abs_path(file_path))

     def mkdir(self, dir_path: str) -> None:
-        full_path = os.path.join(self.base_path, dir_path)
-        os.makedirs(full_path, exist_ok=True)
+        os.makedirs(self.abs_path(dir_path), exist_ok=True)
+
+    def abs_path(self, file_path: str) -> str:
+        return os.path.join(self.base_path, file_path)
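With `abs_path` on the `FileManager` interface, callers keep paths relative to the datasets root and resolve them once at the I/O boundary, as `create()` and `_read_csv_or_parquet` now do. A usage sketch (the dataset path is illustrative):

```python
from pandasai import ConfigManager

file_manager = ConfigManager.get().file_manager
# DefaultFileManager resolves against <project-root>/datasets.
print(file_manager.abs_path("my-org/loans/data.parquet"))
# e.g. /home/user/project/datasets/my-org/loans/data.parquet
```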
diff --git a/tests/unit_tests/agent/test_agent.py b/tests/unit_tests/agent/test_agent.py
index 87edf1889..c2952d222 100644
--- a/tests/unit_tests/agent/test_agent.py
+++ b/tests/unit_tests/agent/test_agent.py
@@ -79,7 +79,7 @@ def test_generate_code_with_cache_hit(
         self, mock_generate_code, agent: Agent, sample_df
     ):
         # Set up the cache to return a pre-cached response
-        cached_code = f"""execute_sql_query('SELECT A FROM {sample_df.schema.source.table}')
+        cached_code = f"""execute_sql_query('SELECT A FROM {sample_df.schema.name}')
print('Cached result: US has the highest GDP.')"""
         agent._state.config.enable_cache = True
         agent._state.cache.get = MagicMock(return_value=cached_code)
@@ -420,7 +420,7 @@ def test_train_method_with_code_but_no_queries(self, agent):
         agent.train(codes)

     def test_execute_local_sql_query_success(self, agent, sample_df):
-        query = f"SELECT count(*) as total from {sample_df.schema.source.table};"
+        query = f'SELECT count(*) as total from "{sample_df.schema.name}";'
         expected_result = pd.DataFrame({"total": [3]})
         result = agent._execute_local_sql_query(query)
         pd.testing.assert_frame_equal(result, expected_result)
@@ -430,7 +430,7 @@ def test_execute_local_sql_query_failure(self, agent):
             agent._execute_local_sql_query("wrong query;")

     def test_execute_sql_query_success_local(self, agent, sample_df):
-        query = f"SELECT count(*) as total from {sample_df.schema.source.table};"
+        query = f'SELECT count(*) as total from "{sample_df.schema.name}";'
         expected_result = pd.DataFrame({"total": [3]})
         result = agent._execute_sql_query(query)
         pd.testing.assert_frame_equal(result, expected_result)
diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py
index b56456aa8..f67f5703f 100644
--- a/tests/unit_tests/conftest.py
+++ b/tests/unit_tests/conftest.py
@@ -7,7 +7,9 @@
 from pandasai import ConfigManager
 from pandasai.data_loader.loader import DatasetLoader
+from pandasai.data_loader.query_builder import QueryBuilder
 from pandasai.data_loader.semantic_layer_schema import SemanticLayerSchema
+from pandasai.data_loader.sql_loader import SQLDatasetLoader
 from pandasai.dataframe.base import DataFrame
 from pandasai.helpers.filemanager import DefaultFileManager
 from pandasai.helpers.path import find_project_root
@@ -134,6 +136,77 @@ def mysql_schema(raw_mysql_schema):
     return SemanticLayerSchema(**raw_mysql_schema)

+@pytest.fixture
+def mock_view_loader_instance_parents(sample_df):
+    """Fixture to mock DatasetLoader and its methods."""
+    # Mock the create_loader_from_path method
+    mock_loader_instance = MagicMock(spec=SQLDatasetLoader)
+    mock_loader_instance.load.return_value = sample_df
+    schema = SemanticLayerSchema(
+        **{
+            "name": "parents",
+            "source": {
+                "type": "mysql",
+                "connection": {
+                    "host": "localhost",
+                    "port": 3306,
+                    "database": "test_db",
+                    "user": "test_user",
+                    "password": "test_password",
+                },
+                "table": "parents",
+            },
+        }
+    )
+    mock_query_builder = QueryBuilder(schema=schema)
+    mock_loader_instance.query_builder = mock_query_builder
+    mock_loader_instance.schema = schema
+    yield mock_loader_instance
+
+
+@pytest.fixture
+def mock_view_loader_instance_children(sample_df):
+    """Fixture to mock DatasetLoader and its methods."""
+    # Mock the create_loader_from_path method
+    mock_loader_instance = MagicMock(spec=SQLDatasetLoader)
+    mock_loader_instance.load.return_value = sample_df
+    schema = SemanticLayerSchema(
+        **{
+            "name": "children",
+            "source": {
+                "type": "mysql",
+                "connection": {
+                    "host": "localhost",
+                    "port": 3306,
+                    "database": "test_db",
+                    "user": "test_user",
+                    "password": "test_password",
+                },
+                "table": "children",
+            },
+        }
+    )
+    mock_query_builder = QueryBuilder(schema=schema)
+    mock_loader_instance.query_builder = mock_query_builder
+    mock_loader_instance.schema = schema
+    yield mock_loader_instance
+
+
+@pytest.fixture
+def mysql_view_schema(raw_mysql_view_schema):
+    return SemanticLayerSchema(**raw_mysql_view_schema)
+
+
+@pytest.fixture
+def mysql_view_dependencies_dict(
+    mock_view_loader_instance_parents, mock_view_loader_instance_children
+) -> dict[str, MagicMock]:
+    return {
+        "parents": mock_view_loader_instance_parents,
+        "children": mock_view_loader_instance_children,
+    }
+
+
 @pytest.fixture(scope="session")
 def mock_json_load():
     mock = MagicMock()
diff --git a/tests/unit_tests/core/code_generation/test_code_cleaning.py b/tests/unit_tests/core/code_generation/test_code_cleaning.py
index 01f7bdbd1..191ebe06b 100644
--- a/tests/unit_tests/core/code_generation/test_code_cleaning.py
+++ b/tests/unit_tests/core/code_generation/test_code_cleaning.py
@@ -64,14 +64,14 @@ def test_replace_table_names_invalid(self):
         )

     def test_clean_sql_query(self):
-        table = self.sample_df.schema.source.table
+        table = self.sample_df.schema.name
         sql_query = f"SELECT * FROM {table};"
         self.cleaner.context.dfs = [self.sample_df]
         result = self.cleaner._clean_sql_query(sql_query)
         self.assertEqual(result, f"SELECT * FROM {table}")

     def test_validate_and_make_table_name_case_sensitive(self):
-        table = self.sample_df.schema.source.table
+        table = self.sample_df.schema.name
         node = ast.Assign(
             targets=[ast.Name(id="query", ctx=ast.Store())],
             value=ast.Constant(value=f"SELECT * FROM {table}"),
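The whitelist that `_clean_sql_query` builds (see `code_cleaning.py` above) now keys both the bare and the double-quoted spelling of each `schema.name` to the same canonical name, which is why the cleaned query in `test_clean_sql_query` keeps whichever form the LLM produced. In isolation:

```python
names = ["parent_children"]  # stand-in for [df.schema.name for df in dfs]

allowed = {n: n for n in names} | {f'"{n}"': n for n in names}
print(allowed)
# {'parent_children': 'parent_children', '"parent_children"': 'parent_children'}
```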
diff --git a/tests/unit_tests/data_loader/test_loader.py b/tests/unit_tests/data_loader/test_loader.py
index db0521245..34eec3e67 100644
--- a/tests/unit_tests/data_loader/test_loader.py
+++ b/tests/unit_tests/data_loader/test_loader.py
@@ -15,7 +15,7 @@ def test_load_from_local_source_valid(self, sample_schema):
         with patch("os.path.exists", return_value=True), patch(
             "pandasai.data_loader.local_loader.LocalDatasetLoader._read_csv_or_parquet"
         ) as mock_read_csv_or_parquet:
-            loader = LocalDatasetLoader(sample_schema, "test")
+            loader = LocalDatasetLoader(sample_schema, "test/test")

             mock_read_csv_or_parquet.return_value = DataFrame(
                 {"email": ["test@example.com"]}
@@ -29,7 +29,7 @@ def test_load_from_local_source_invalid_source_type(self, sample_schema):
         sample_schema.source.type = "mysql"
-        loader = LocalDatasetLoader(sample_schema, "test")
+        loader = LocalDatasetLoader(sample_schema, "test/test")

         with pytest.raises(
             InvalidDataSourceType, match="Unsupported local source type"
@@ -65,7 +65,7 @@ def test_load_schema_file_not_found(self):
             DatasetLoader._read_schema_file("test/users")

     def test_read_parquet(self, sample_schema):
-        loader = LocalDatasetLoader(sample_schema, "test")
+        loader = LocalDatasetLoader(sample_schema, "test/test")

         mock_df = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
         with patch("pandas.read_parquet", return_value=mock_df) as mock_read_parquet:
@@ -75,7 +75,7 @@
     def test_read_csv(self, sample_schema):
-        loader = LocalDatasetLoader(sample_schema, "test")
+        loader = LocalDatasetLoader(sample_schema, "test/test")

         mock_df = pd.DataFrame({"col1": [1, 2, 3], "col2": ["a", "b", "c"]})
         with patch("pandas.read_csv", return_value=mock_df) as mock_read_csv:
@@ -85,14 +85,14 @@
     def test_read_csv_or_parquet_unsupported_format(self, sample_schema):
-        loader = LocalDatasetLoader(sample_schema, "test")
+        loader = LocalDatasetLoader(sample_schema, "test/test")

         with pytest.raises(ValueError, match="Unsupported file format: unsupported"):
             loader._read_csv_or_parquet("dummy_path", "unsupported")

     def test_apply_transformations(self, sample_schema):
         """Test that DatasetLoader correctly uses TransformationManager."""
-        loader = LocalDatasetLoader(sample_schema, "test")
+        loader = LocalDatasetLoader(sample_schema, "test/test")

         df = pd.DataFrame(
             {
diff --git a/tests/unit_tests/data_loader/test_sql_loader.py b/tests/unit_tests/data_loader/test_sql_loader.py
index 2a9bbc277..cf2ee4789 100644
--- a/tests/unit_tests/data_loader/test_sql_loader.py
+++ b/tests/unit_tests/data_loader/test_sql_loader.py
@@ -193,10 +193,10 @@ def test_mysql_safe_query(self, mysql_schema):
             mock_sql_query.return_value = True
             logging.debug("Loading schema from dataset path: %s", loader)

-            result = loader.execute_query("select * from users")
+            result = loader.execute_query("SELECT * FROM users")

             assert isinstance(result, DataFrame)
-            mock_sql_query.assert_called_once_with("select * from users")
+            mock_sql_query.assert_called_once_with("SELECT * FROM users")

     def test_mysql_malicious_with_no_import(self, mysql_schema):
         """Test loading data from a MySQL source creates a VirtualDataFrame and handles queries correctly."""
diff --git a/tests/unit_tests/dataframe/test_view_query_builder.py b/tests/unit_tests/dataframe/test_view_query_builder.py
index 75eb6ab2a..af0960a57 100644
--- a/tests/unit_tests/dataframe/test_view_query_builder.py
+++ b/tests/unit_tests/dataframe/test_view_query_builder.py
@@ -7,49 +7,31 @@

 class TestViewQueryBuilder:
     @pytest.fixture
-    def mysql_view_schema(self):
-        raw_schema = {
-            "name": "Users",
-            "columns": [
-                {"name": "parents.id"},
-                {"name": "parents.name"},
-                {"name": "children.name"},
-            ],
-            "relations": [{"from": "parents.id", "to": "children.id"}],
-            "view": "true",
-        }
-        return SemanticLayerSchema(**raw_schema)
+    def view_query_builder(self, mysql_view_schema, mysql_view_dependencies_dict):
+        return ViewQueryBuilder(mysql_view_schema, mysql_view_dependencies_dict)

-    @pytest.fixture
-    def view_query_builder(self, mysql_view_schema):
-        return ViewQueryBuilder(mysql_view_schema)
-
-    def test__init__(self, mysql_view_schema):
-        query_builder = ViewQueryBuilder(mysql_view_schema)
+    def test__init__(self, mysql_view_schema, mysql_view_dependencies_dict):
+        query_builder = ViewQueryBuilder(
+            mysql_view_schema, mysql_view_dependencies_dict
+        )
         assert isinstance(query_builder, ViewQueryBuilder)
         assert isinstance(query_builder, QueryBuilder)
         assert query_builder.schema == mysql_view_schema

     def test_format_query(self, view_query_builder):
-        query = "SELECT ALL"
+        query = "SELECT * FROM table_llm_friendly"
         formatted_query = view_query_builder.format_query(query)
-        assert (
-            formatted_query
-            == """WITH Users AS ( SELECT
-parents.id AS parents_id, parents.name AS parents_name, children.name AS children_name
-FROM parents
-JOIN children ON parents.id = children.id)
-SELECT ALL"""
-        )
+        assert formatted_query == "SELECT * FROM table_llm_friendly"

     def test_build_query(self, view_query_builder) -> str:
         assert (
             view_query_builder.build_query()
-            == """WITH Users AS ( SELECT
+            == """SELECT parents_id, parents_name, children_name FROM ( SELECT
 parents.id AS parents_id, parents.name AS parents_name, children.name AS children_name
-FROM parents
-JOIN children ON parents.id = children.id)
-SELECT parents_id, parents_name, children_name FROM Users"""
+FROM ( SELECT * FROM parents ) AS parents
+JOIN ( SELECT * FROM children ) AS children
+ON parents.id = children.id) AS parent-children
+"""
         )

     def test_get_columns(self, view_query_builder):
@@ -63,25 +45,12 @@ def test_get_columns_empty(self, view_query_builder):
         assert view_query_builder._get_columns() == "*"

     def test_get_from_statement(self, view_query_builder):
-        assert view_query_builder._get_from_statement() == " FROM Users"
-
-    def test_get_with_statement(self, view_query_builder):
         assert (
-            view_query_builder._get_with_statement()
-            == """WITH Users AS ( SELECT
+            view_query_builder._get_from_statement()
+            == """FROM ( SELECT
 parents.id AS parents_id, parents.name AS parents_name, children.name AS children_name
-FROM parents
-JOIN children ON parents.id = children.id)
-"""
-        )
-
-    def test_get_with_statement_no_columns(self, view_query_builder):
-        view_query_builder.schema.columns = None
-        assert (
-            view_query_builder._get_with_statement()
-            == """WITH Users AS ( SELECT
-*
-FROM parents
-JOIN children ON parents.id = children.id)
+FROM ( SELECT * FROM parents ) AS parents
+JOIN ( SELECT * FROM children ) AS children
+ON parents.id = children.id) AS parent-children
 """
         )
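The new `get_head_query` override pairs with the nested `FROM` statement asserted above: it appends a plain `LIMIT` instead of the base class's dialect-specific random ordering. A reduced sketch:

```python
def get_head_query(columns: str, from_statement: str, n: int = 5) -> str:
    # ViewQueryBuilder variant: no ORDER BY RANDOM()/RAND(), just LIMIT.
    return f"SELECT {columns}{from_statement} LIMIT {n}"

print(get_head_query("*", " FROM ( SELECT * FROM parents ) AS parents"))
# SELECT * FROM ( SELECT * FROM parents ) AS parents LIMIT 5
```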
diff --git a/tests/unit_tests/helpers/test_dataframe_serializer.py b/tests/unit_tests/helpers/test_dataframe_serializer.py
index 89e31b605..f03866ae6 100644
--- a/tests/unit_tests/helpers/test_dataframe_serializer.py
+++ b/tests/unit_tests/helpers/test_dataframe_serializer.py
@@ -9,7 +9,7 @@ def test_serialize_with_name_and_description(self, sample_df):
         """Test serialization with name and description attributes."""
         result = DataframeSerializer.serialize(sample_df)

-        expected = """<table ...>
+        expected = """<table ...>
 A,B
 1,4
 2,5
diff --git a/tests/unit_tests/prompts/test_sql_prompt.py b/tests/unit_tests/prompts/test_sql_prompt.py
index 46b1b1e80..3acdc3dce 100644
--- a/tests/unit_tests/prompts/test_sql_prompt.py
+++ b/tests/unit_tests/prompts/test_sql_prompt.py
@@ -64,7 +64,7 @@ def test_str_with_args(self, output_type, output_type_template):
             prompt_content
             == f'''
-<table ...>
+<table ...>

diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py
index 407c72de7..250b2f10a 100644
--- a/tests/unit_tests/test_config.py
+++ b/tests/unit_tests/test_config.py
@@ -50,7 +50,7 @@ def test_validate_llm_with_langchain(self):
             ConfigManager.validate_llm()

         assert isinstance(ConfigManager._config.llm, LangchainLLM)
-        assert ConfigManager._config.llm._llm == mock_langchain_llm
+        assert ConfigManager._config.llm.langchain_llm == mock_langchain_llm

     def test_update_config(self):
         """Test updating configuration with new values"""