Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split sql_plan.py into separate files #1626

Merged
merged 5 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions metricflow/dataset/convert_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.sql.sql_plan import (
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataset/semantic_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import override

from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.sql.sql_plan import SqlSelectStatementNode
from metricflow.sql.sql_select_node import SqlSelectStatementNode


class SemanticModelDataSet(SqlDataSet):
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from metricflow.dataset.dataset_classes import DataSet
from metricflow.sql.sql_plan import (
SqlPlanNode,
SqlSelectStatementNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode


class SqlDataSet(DataSet):
Expand Down
5 changes: 4 additions & 1 deletion metricflow/plan_conversion/to_sql_plan/dataflow_to_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.plan_conversion.instance_converters import CreateSelectColumnsForInstances
from metricflow.plan_conversion.to_sql_plan.dataflow_to_subquery import DataflowNodeToSqlSubqueryVisitor
from metricflow.sql.sql_plan import SqlCteNode, SqlSelectColumn, SqlSelectStatementNode, SqlTableNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import SqlSelectColumn
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from metricflow.sql.sql_plan import (
SqlPlan,
SqlPlanNode,
SqlSelectStatementNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode

logger = logging.getLogger(__name__)

Expand Down
10 changes: 4 additions & 6 deletions metricflow/plan_conversion/to_sql_plan/dataflow_to_subquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,13 @@
SelectOnlyLinkableSpecs,
)
from metricflow.plan_conversion.to_sql_plan.sql_join_builder import ColumnEqualityDescription, SqlPlanJoinBuilder
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode


class DataflowNodeToSqlSubqueryVisitor(DataflowPlanNodeVisitor[SqlDataSet]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataset.sql_dataset import AnnotatedSqlDataSet
from metricflow.plan_conversion.sql_expression_builders import make_coalesced_expr
from metricflow.sql.sql_plan import SqlExpressionNode, SqlJoinDescription, SqlSelectStatementNode
from metricflow.sql.sql_plan import SqlExpressionNode
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlSelectStatementNode


@dataclass(frozen=True)
Expand Down
13 changes: 6 additions & 7 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@
from metricflow.sql.optimizer.required_column_aliases import SqlMapRequiredColumnAliasesVisitor
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteAliasMapping, SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteAliasMapping,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,7 +86,7 @@ def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode:
"""There are no SELECT columns in this node, so pruning cannot apply."""
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode:
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode:
"""Pruning cannot be done here since this is an arbitrary user-provided SQL query."""
return node

Expand Down
3 changes: 2 additions & 1 deletion metricflow/sql/optimizer/cte_alias_to_cte_node_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat

from metricflow.sql.sql_plan import SqlCteAliasMapping, SqlSelectStatementNode
from metricflow.sql.sql_cte_node import SqlCteAliasMapping
from metricflow.sql.sql_select_node import SqlSelectStatementNode

logger = logging.getLogger(__name__)

Expand Down
13 changes: 6 additions & 7 deletions metricflow/sql/optimizer/cte_mapping_lookup_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@
from typing_extensions import override

from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteAliasMapping, SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteAliasMapping,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -81,7 +80,7 @@ def visit_table_node(self, node: SqlTableNode) -> None:
self._default_handler(node)

@override
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None:
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> None:
self._default_handler(node)

@override
Expand Down
13 changes: 6 additions & 7 deletions metricflow/sql/optimizer/required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@

from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteAliasMapping, SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteAliasMapping,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -264,7 +263,7 @@ def visit_table_node(self, node: SqlTableNode) -> None:
"""There are no SELECT columns in this node, so pruning cannot apply."""
return

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None:
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> None:
"""Pruning cannot be done here since this is an arbitrary user-provided SQL query."""
return

Expand Down
16 changes: 7 additions & 9 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@
from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -766,7 +764,7 @@ def _get_matching_column_for_order_by(
def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: # noqa: D102
return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102
Expand Down Expand Up @@ -841,7 +839,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNo
def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: # noqa: D102
return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102
Expand Down
14 changes: 6 additions & 8 deletions metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@
from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -85,7 +83,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNo
def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: # noqa: D102
return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102
Expand Down
14 changes: 6 additions & 8 deletions metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,17 @@
SqlExpressionRenderResult,
)
from metricflow.sql.render.rendering_constants import SqlRenderingConstants
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlPlan,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -349,7 +347,7 @@ def visit_table_node(self, node: SqlTableNode) -> SqlPlanRenderResult: # noqa:
bind_parameter_set=SqlBindParameterSet(),
)

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanRenderResult: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanRenderResult: # noqa: D102
return SqlPlanRenderResult(
sql=node.select_query.rstrip(),
bind_parameter_set=SqlBindParameterSet(),
Expand Down
72 changes: 72 additions & 0 deletions metricflow/sql/sql_ctas_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence

from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.visitor import VisitorOutputT
from typing_extensions import override

from metricflow.sql.sql_cte_node import SqlCteAliasMapping
from metricflow.sql.sql_plan import SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode


@dataclass(frozen=True, eq=False)
class SqlCreateTableAsNode(SqlPlanNode):
"""An SQL node representing a CREATE TABLE AS statement.

Attributes:
sql_table: The SQL table to create.
"""

sql_table: SqlTable

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
assert len(self.parent_nodes) == 1

@staticmethod
def create(sql_table: SqlTable, parent_node: SqlPlanNode) -> SqlCreateTableAsNode: # noqa: D102
return SqlCreateTableAsNode(
parent_nodes=(parent_node,),
sql_table=sql_table,
)

@override
def accept(self, visitor: SqlPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
return visitor.visit_create_table_as_node(self)

@property
@override
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
return None

@property
@override
def as_sql_table_node(self) -> Optional[SqlTableNode]:
return None

@property
@override
def description(self) -> str:
return f"Create table {repr(self.sql_table.sql)}"

@property
def parent_node(self) -> SqlPlanNode: # noqa: D102
return self.parent_nodes[0]

@classmethod
@override
def id_prefix(cls) -> IdPrefix:
return StaticIdPrefix.SQL_PLAN_CREATE_TABLE_AS_ID_PREFIX

@override
def nearest_select_columns(self, cte_source_mapping: SqlCteAliasMapping) -> Optional[Sequence[SqlSelectColumn]]:
return self.parent_node.nearest_select_columns(cte_source_mapping)

@override
def copy(self) -> SqlCreateTableAsNode:
return SqlCreateTableAsNode(parent_nodes=self.parent_nodes, sql_table=self.sql_table)
Loading
Loading