Skip to content

Commit

Permalink
Split sql_plan.py into separate files (#1626)
Browse files Browse the repository at this point in the history
As per title.
  • Loading branch information
plypaul authored Jan 25, 2025
1 parent db86f16 commit fbd2e83
Show file tree
Hide file tree
Showing 34 changed files with 572 additions and 509 deletions.
4 changes: 2 additions & 2 deletions metricflow/dataset/convert_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.sql.sql_plan import (
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataset/semantic_model_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import override

from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.sql.sql_plan import SqlSelectStatementNode
from metricflow.sql.sql_select_node import SqlSelectStatementNode


class SemanticModelDataSet(SqlDataSet):
Expand Down
2 changes: 1 addition & 1 deletion metricflow/dataset/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from metricflow.dataset.dataset_classes import DataSet
from metricflow.sql.sql_plan import (
SqlPlanNode,
SqlSelectStatementNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode


class SqlDataSet(DataSet):
Expand Down
5 changes: 4 additions & 1 deletion metricflow/plan_conversion/to_sql_plan/dataflow_to_cte.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@
from metricflow.dataset.sql_dataset import SqlDataSet
from metricflow.plan_conversion.instance_converters import CreateSelectColumnsForInstances
from metricflow.plan_conversion.to_sql_plan.dataflow_to_subquery import DataflowNodeToSqlSubqueryVisitor
from metricflow.sql.sql_plan import SqlCteNode, SqlSelectColumn, SqlSelectStatementNode, SqlTableNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import SqlSelectColumn
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion metricflow/plan_conversion/to_sql_plan/dataflow_to_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from metricflow.sql.sql_plan import (
SqlPlan,
SqlPlanNode,
SqlSelectStatementNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode

logger = logging.getLogger(__name__)

Expand Down
10 changes: 4 additions & 6 deletions metricflow/plan_conversion/to_sql_plan/dataflow_to_subquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,15 +117,13 @@
SelectOnlyLinkableSpecs,
)
from metricflow.plan_conversion.to_sql_plan.sql_join_builder import ColumnEqualityDescription, SqlPlanJoinBuilder
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlSelectColumn,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode


class DataflowNodeToSqlSubqueryVisitor(DataflowPlanNodeVisitor[SqlDataSet]):
Expand Down
3 changes: 2 additions & 1 deletion metricflow/plan_conversion/to_sql_plan/sql_join_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from metricflow.dataflow.nodes.join_to_time_spine import JoinToTimeSpineNode
from metricflow.dataset.sql_dataset import AnnotatedSqlDataSet
from metricflow.plan_conversion.sql_expression_builders import make_coalesced_expr
from metricflow.sql.sql_plan import SqlExpressionNode, SqlJoinDescription, SqlSelectStatementNode
from metricflow.sql.sql_plan import SqlExpressionNode
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlSelectStatementNode


@dataclass(frozen=True)
Expand Down
13 changes: 6 additions & 7 deletions metricflow/sql/optimizer/column_pruner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,15 @@
from metricflow.sql.optimizer.required_column_aliases import SqlMapRequiredColumnAliasesVisitor
from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteAliasMapping, SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteAliasMapping,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -87,7 +86,7 @@ def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode:
"""There are no SELECT columns in this node, so pruning cannot apply."""
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode:
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode:
"""Pruning cannot be done here since this is an arbitrary user-provided SQL query."""
return node

Expand Down
3 changes: 2 additions & 1 deletion metricflow/sql/optimizer/cte_alias_to_cte_node_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from metricflow_semantics.mf_logging.lazy_formattable import LazyFormat

from metricflow.sql.sql_plan import SqlCteAliasMapping, SqlSelectStatementNode
from metricflow.sql.sql_cte_node import SqlCteAliasMapping
from metricflow.sql.sql_select_node import SqlSelectStatementNode

logger = logging.getLogger(__name__)

Expand Down
13 changes: 6 additions & 7 deletions metricflow/sql/optimizer/cte_mapping_lookup_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,15 @@
from typing_extensions import override

from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteAliasMapping, SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteAliasMapping,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -81,7 +80,7 @@ def visit_table_node(self, node: SqlTableNode) -> None:
self._default_handler(node)

@override
def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None:
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> None:
self._default_handler(node)

@override
Expand Down
13 changes: 6 additions & 7 deletions metricflow/sql/optimizer/required_column_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,16 @@

from metricflow.sql.optimizer.cte_alias_to_cte_node_mapping import SqlCteAliasMappingLookup
from metricflow.sql.optimizer.tag_column_aliases import NodeToColumnAliasMapping
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteAliasMapping, SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteAliasMapping,
SqlCteNode,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -264,7 +263,7 @@ def visit_table_node(self, node: SqlTableNode) -> None:
"""There are no SELECT columns in this node, so pruning cannot apply."""
return

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> None:
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> None:
"""Pruning cannot be done here since this is an arbitrary user-provided SQL query."""
return

Expand Down
16 changes: 7 additions & 9 deletions metricflow/sql/optimizer/rewriting_sub_query_reducer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@
from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -766,7 +764,7 @@ def _get_matching_column_for_order_by(
def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: # noqa: D102
return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102
Expand Down Expand Up @@ -841,7 +839,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNo
def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: # noqa: D102
return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102
Expand Down
14 changes: 6 additions & 8 deletions metricflow/sql/optimizer/table_alias_simplifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,16 @@
from typing_extensions import override

from metricflow.sql.optimizer.sql_query_plan_optimizer import SqlPlanOptimizer
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -85,7 +83,7 @@ def visit_select_statement_node(self, node: SqlSelectStatementNode) -> SqlPlanNo
def visit_table_node(self, node: SqlTableNode) -> SqlPlanNode: # noqa: D102
return node

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanNode: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanNode: # noqa: D102
return node

def visit_create_table_as_node(self, node: SqlCreateTableAsNode) -> SqlPlanNode: # noqa: D102
Expand Down
14 changes: 6 additions & 8 deletions metricflow/sql/render/sql_plan_renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,17 @@
SqlExpressionRenderResult,
)
from metricflow.sql.render.rendering_constants import SqlRenderingConstants
from metricflow.sql.sql_ctas_node import SqlCreateTableAsNode
from metricflow.sql.sql_cte_node import SqlCteNode
from metricflow.sql.sql_plan import (
SqlCreateTableAsNode,
SqlCteNode,
SqlJoinDescription,
SqlOrderByDescription,
SqlPlan,
SqlPlanNode,
SqlPlanNodeVisitor,
SqlSelectColumn,
SqlSelectQueryFromClauseNode,
SqlSelectStatementNode,
SqlTableNode,
)
from metricflow.sql.sql_select_node import SqlJoinDescription, SqlOrderByDescription, SqlSelectStatementNode
from metricflow.sql.sql_select_text_node import SqlSelectTextNode
from metricflow.sql.sql_table_node import SqlTableNode

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -349,7 +347,7 @@ def visit_table_node(self, node: SqlTableNode) -> SqlPlanRenderResult: # noqa:
bind_parameter_set=SqlBindParameterSet(),
)

def visit_query_from_clause_node(self, node: SqlSelectQueryFromClauseNode) -> SqlPlanRenderResult: # noqa: D102
def visit_query_from_clause_node(self, node: SqlSelectTextNode) -> SqlPlanRenderResult: # noqa: D102
return SqlPlanRenderResult(
sql=node.select_query.rstrip(),
bind_parameter_set=SqlBindParameterSet(),
Expand Down
72 changes: 72 additions & 0 deletions metricflow/sql/sql_ctas_node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional, Sequence

from metricflow_semantics.dag.id_prefix import IdPrefix, StaticIdPrefix
from metricflow_semantics.sql.sql_table import SqlTable
from metricflow_semantics.visitor import VisitorOutputT
from typing_extensions import override

from metricflow.sql.sql_cte_node import SqlCteAliasMapping
from metricflow.sql.sql_plan import SqlPlanNode, SqlPlanNodeVisitor, SqlSelectColumn
from metricflow.sql.sql_select_node import SqlSelectStatementNode
from metricflow.sql.sql_table_node import SqlTableNode


@dataclass(frozen=True, eq=False)
class SqlCreateTableAsNode(SqlPlanNode):
"""An SQL node representing a CREATE TABLE AS statement.
Attributes:
sql_table: The SQL table to create.
"""

sql_table: SqlTable

def __post_init__(self) -> None: # noqa: D105
super().__post_init__()
assert len(self.parent_nodes) == 1

@staticmethod
def create(sql_table: SqlTable, parent_node: SqlPlanNode) -> SqlCreateTableAsNode: # noqa: D102
return SqlCreateTableAsNode(
parent_nodes=(parent_node,),
sql_table=sql_table,
)

@override
def accept(self, visitor: SqlPlanNodeVisitor[VisitorOutputT]) -> VisitorOutputT:
return visitor.visit_create_table_as_node(self)

@property
@override
def as_select_node(self) -> Optional[SqlSelectStatementNode]:
return None

@property
@override
def as_sql_table_node(self) -> Optional[SqlTableNode]:
return None

@property
@override
def description(self) -> str:
return f"Create table {repr(self.sql_table.sql)}"

@property
def parent_node(self) -> SqlPlanNode: # noqa: D102
return self.parent_nodes[0]

@classmethod
@override
def id_prefix(cls) -> IdPrefix:
return StaticIdPrefix.SQL_PLAN_CREATE_TABLE_AS_ID_PREFIX

@override
def nearest_select_columns(self, cte_source_mapping: SqlCteAliasMapping) -> Optional[Sequence[SqlSelectColumn]]:
return self.parent_node.nearest_select_columns(cte_source_mapping)

@override
def copy(self) -> SqlCreateTableAsNode:
return SqlCreateTableAsNode(parent_nodes=self.parent_nodes, sql_table=self.sql_table)
Loading

0 comments on commit fbd2e83

Please sign in to comment.