Skip to content
This repository was archived by the owner on Feb 6, 2025. It is now read-only.

Draft implementation for cross-db query plan execution. #551

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2019-present Kensho Technologies, LLC.
from copy import copy
from dataclasses import dataclass
from typing import FrozenSet, List, Optional, Tuple, cast
from typing import Any, Callable, Dict, FrozenSet, List, Mapping, Optional, Set, Tuple, cast

from graphql import print_ast
from graphql.language.ast import (
Expand All @@ -22,17 +22,21 @@
from ..exceptions import GraphQLValidationError
from ..schema import FilterDirective, OutputDirective
from .split_query import AstType, SubQueryNode
from .utils import get_query_runtime_arguments


@dataclass
class SubQueryPlan:
"""Query plan for a part of a larger query over a single schema."""

# Unique identifier for this sub-plan.
plan_id: int

# Representing a piece of the overall query with directives added.
query_ast: DocumentNode

# Identifier for the schema that this query piece targets.
schema_id: Optional[str]
schema_id: str

# The query that the current query depends on, or None if the current query does not
# depend on another.
Expand All @@ -51,6 +55,9 @@ class OutputJoinDescriptor:
# should be made.
output_names: Tuple[str, str]

# SubQueryPlan, the sub-plan node for which the join happens between it and its parent sub-plan.
child_query_plan: SubQueryPlan


@dataclass(frozen=True)
class QueryPlanDescriptor:
Expand Down Expand Up @@ -91,14 +98,20 @@ def make_query_plan(
"""
output_join_descriptors: List[OutputJoinDescriptor] = []

if root_sub_query_node.schema_id is None:
raise AssertionError(
"Unreachable code reached. The schema id of root_sub_query_node "
f'"{root_sub_query_node.query_ast}" has not been determined.'
)
root_sub_query_plan = SubQueryPlan(
plan_id=0,
query_ast=root_sub_query_node.query_ast,
schema_id=root_sub_query_node.schema_id,
parent_query_plan=None,
child_query_plans=[],
)

_make_query_plan_recursive(root_sub_query_node, root_sub_query_plan, output_join_descriptors)
_make_query_plan_recursive(root_sub_query_node, root_sub_query_plan, output_join_descriptors, 1)

return QueryPlanDescriptor(
root_sub_query_plan=root_sub_query_plan,
Expand All @@ -111,6 +124,7 @@ def _make_query_plan_recursive(
sub_query_node: SubQueryNode,
sub_query_plan: SubQueryPlan,
output_join_descriptors: List[OutputJoinDescriptor],
next_plan_id: int,
) -> None:
"""Recursively copy the structure of sub_query_node onto sub_query_plan.

Expand All @@ -124,7 +138,7 @@ def _make_query_plan_recursive(
sub_query_plan: SubQueryPlan, whose list of child query plans and query AST are
modified.
output_join_descriptors: describing which outputs should be joined and how.

next_plan_id: the next available plan ID to use. IDs at and above this number are free.
"""
# Iterate through child connections of query node
for child_query_connection in sub_query_node.child_query_connections:
Expand All @@ -149,26 +163,34 @@ def _make_query_plan_recursive(
else:
new_child_query_ast = DocumentNode(definitions=[child_query_type_with_filter])

if child_sub_query_node.schema_id is None:
raise AssertionError(
"Unreachable code reached. The schema id of the child_sub_query_node "
f'"{child_sub_query_node.query_ast}" has not been determined.'
)
# Create new SubQueryPlan for child
child_sub_query_plan = SubQueryPlan(
plan_id=next_plan_id,
query_ast=new_child_query_ast,
schema_id=child_sub_query_node.schema_id,
parent_query_plan=sub_query_plan,
child_query_plans=[],
)
next_plan_id += 1

# Add new SubQueryPlan to parent's child list
sub_query_plan.child_query_plans.append(child_sub_query_plan)

# Add information about this edge
new_output_join_descriptor = OutputJoinDescriptor(
output_names=(parent_out_name, child_out_name),
child_query_plan=child_sub_query_plan,
)
output_join_descriptors.append(new_output_join_descriptor)

# Recursively repeat on child SubQueryPlans
_make_query_plan_recursive(
child_sub_query_node, child_sub_query_plan, output_join_descriptors
child_sub_query_node, child_sub_query_plan, output_join_descriptors, next_plan_id
)


Expand Down Expand Up @@ -300,13 +322,38 @@ def print_query_plan(query_plan_descriptor: QueryPlanDescriptor, indentation_dep
line_separation = "\n" + " " * indentation_depth * depth
query_plan_strings.append(line_separation)

query_str = 'Execute in schema named "{}":\n'.format(query_plan.schema_id)
query_str = 'Execute subplan ID {} in schema named "{}":\n'.format(
query_plan.plan_id, query_plan.schema_id
)
query_str += print_ast(query_plan.query_ast)
query_str = query_str.replace("\n", line_separation)
query_plan_strings.append(query_str)

query_plan_strings.append("\n\nJoin together outputs as follows: ")
query_plan_strings.append(str(query_plan_descriptor.output_join_descriptors))
query_plan_strings.append("\n\nJoin together outputs as follows:\n")
for descriptor in query_plan_descriptor.output_join_descriptors:
if descriptor.child_query_plan.parent_query_plan is None:
raise AssertionError(
f"Invalid join descriptor. The parent plan was not set for child_query_plan "
f"{descriptor.child_query_plan}."
)
query_plan_strings.append(
"\n".join(
[
" ".join(
[
str(descriptor.output_names),
"between subplan IDs",
str(
[
descriptor.child_query_plan.parent_query_plan.plan_id,
descriptor.child_query_plan.plan_id,
]
),
]
)
]
)
)
query_plan_strings.append("\n\nRemove the following outputs at the end: ")
query_plan_strings.append(str(query_plan_descriptor.intermediate_output_names) + "\n")

Expand All @@ -316,7 +363,9 @@ def print_query_plan(query_plan_descriptor: QueryPlanDescriptor, indentation_dep
def _get_plan_and_depth_in_dfs_order(query_plan: SubQueryPlan) -> List[Tuple[SubQueryPlan, int]]:
"""Return a list of topologically sorted (query plan, depth) tuples."""

def _get_plan_and_depth_in_dfs_order_helper(query_plan, depth):
def _get_plan_and_depth_in_dfs_order_helper(
query_plan: SubQueryPlan, depth: int
) -> List[Tuple[SubQueryPlan, int]]:
plan_and_depth_in_dfs_order = [(query_plan, depth)]
for child_query_plan in query_plan.child_query_plans:
plan_and_depth_in_dfs_order.extend(
Expand All @@ -325,3 +374,190 @@ def _get_plan_and_depth_in_dfs_order_helper(query_plan, depth):
return plan_and_depth_in_dfs_order

return _get_plan_and_depth_in_dfs_order_helper(query_plan, 0)


def execute_query_plan(
schema_id_to_execution_func: Dict[
str, Callable[[str, Mapping[str, Any]], List[Dict[str, Any]]]
],
query_plan_descriptor: QueryPlanDescriptor,
query_args: Dict[str, Any],
) -> List[Dict[str, Any]]:
"""Execute the given query plan and return the produced results."""
result_components_by_plan_id = {}

stitching_output_names_by_parent_plan_id: Dict[int, List[Tuple[str, str]]] = dict()
for join_descriptor in query_plan_descriptor.output_join_descriptors:
if join_descriptor.child_query_plan.parent_query_plan is None:
raise AssertionError(
f"Invalid join descriptor. The parent plan was not set for child_query_plan "
f"{join_descriptor.child_query_plan}."
)
parent_plan_id = join_descriptor.child_query_plan.parent_query_plan.plan_id
stitching_output_names_by_parent_plan_id.setdefault(parent_plan_id, []).append(
join_descriptor.output_names
)

full_query_args = dict(query_args)

plan_and_depth = _get_plan_and_depth_in_dfs_order(query_plan_descriptor.root_sub_query_plan)
Copy link
Collaborator

@bojanserafimov bojanserafimov Sep 10, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does the code below depend on the order being a DFS order? In some very trivial cases we will have to start execution from the leaves. Just making sure we don't lock ourselves out from that execution plan with the code structure.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One thing we should document better than we currently have is the distinction between the output of split_query() and the output of make_query_plan() -- I think that would clarify the situation here. split_query() tells us where we cross schemas, but does not specify an execution order. make_query_plan() specifies an order between the subqueries, and is free to specify any valid order -- and may choose an optimized order if it has access to statistics etc.

Given that, once a query plan is made, it's always executed from the root onward. We may in the future allow different styles of executors (e.g. DFS, BFS, async + parallel across children, etc.), but this executor function is just a simple sync DFS.


for query_plan, _ in plan_and_depth:
plan_id = query_plan.plan_id
schema_id = query_plan.schema_id
if schema_id is None:
raise AssertionError(
f'Unreachable code reached. The schema id of query piece "{query_plan.query_ast}" '
f"has not been determined."
)

subquery_graphql = print_ast(query_plan.query_ast)

# HACK(predrag): Add proper error checking for missing arguments here.
# HACK(predrag): Don't bother running queries if the previous query's stitching outputs
# returned no values to pass to the next query.
subquery_args = {
argument_name: full_query_args[argument_name]
for argument_name in get_query_runtime_arguments(query_plan.query_ast)
}

# Run the query and save the results.
execution_func = schema_id_to_execution_func[schema_id]
subquery_result = execution_func(subquery_graphql, subquery_args)
result_components_by_plan_id[plan_id] = subquery_result

# Capture and record any values that will be used for stitching by other subqueries.
child_extra_output_names = {
# The .get() call is to handle the case of query plans with no children.
# They have no extra output values for their children, on account of having no children.
output_name
for output_name, _ in stitching_output_names_by_parent_plan_id.get(plan_id, [])
}
child_extra_output_values: Dict[str, Set[Any]] = {
# Make sure we deduplicate the values -- there's no point in running subqueries
# with duplicated runtime argument values.
output_name: set()
for output_name in child_extra_output_names
}
for subquery_row in subquery_result:
for output_name in child_extra_output_names:
# We intentionally discard None values -- None is never a foreign key value.
# This is standard in all relational systems as well.
output_value = subquery_row.get(output_name, None)
if output_value is not None:
child_extra_output_values[output_name].add(output_value)
# TODO(predrag): Use the "merge_disjoint_dicts" function here,
# there should never be any overlap here.
new_query_args = {
# Argument values cannot be sets, so we turn the sets back into lists.
output_argument_name: list(child_extra_output_values[output_argument_name])
for output_argument_name in child_extra_output_names
}
full_query_args.update(new_query_args)

join_indexes_by_plan_id = _make_join_indexes(
query_plan_descriptor, result_components_by_plan_id
)

joined_results = _join_results(
result_components_by_plan_id,
join_indexes_by_plan_id,
result_components_by_plan_id[query_plan_descriptor.root_sub_query_plan.plan_id],
query_plan_descriptor.output_join_descriptors,
)

return _drop_intermediate_outputs(
query_plan_descriptor.intermediate_output_names, joined_results
)


def _make_join_indexes(
    query_plan_descriptor: QueryPlanDescriptor,
    result_components_by_plan_id: Dict[int, List[Dict[str, Any]]],
) -> Dict[int, Dict[str, List[int]]]:
    """Build one join index per child sub-plan, keyed by the child's plan ID.

    Args:
        query_plan_descriptor: query plan whose output_join_descriptors list the joins to index.
        result_components_by_plan_id: result rows of every sub-plan, keyed by plan ID.

    Returns:
        dict mapping each child plan ID to the index produced by _make_join_index_for_output
        over that child's rows, using the child-side output name of the join.

    Raises:
        AssertionError if more than one join descriptor refers to the same child plan ID.
    """
    all_join_descriptors = query_plan_descriptor.output_join_descriptors
    indexes_by_child_id: Dict[int, Dict[str, List[int]]] = {}

    for descriptor in all_join_descriptors:
        child_id = descriptor.child_query_plan.plan_id
        # output_names is a (parent output name, child output name) pair; only the child
        # side is needed to index the child's rows.
        _parent_column, child_column = descriptor.output_names

        # Each child sub-plan is expected to be joined to its parent exactly once.
        if child_id in indexes_by_child_id:
            raise AssertionError(
                "Unreachable code reached: {} {} {}".format(
                    child_id,
                    indexes_by_child_id,
                    all_join_descriptors,
                )
            )

        child_rows = result_components_by_plan_id[child_id]
        indexes_by_child_id[child_id] = _make_join_index_for_output(child_rows, child_column)

    return indexes_by_child_id


def _make_join_index_for_output(
results: List[Dict[str, Any]], join_output_name: str
) -> Dict[str, List[int]]:
"""Return a dict of each value of the join column to a list of row indexes where it appears."""
join_index: Dict[str, List[int]] = {}
for row_index, row in enumerate(results):
join_value = row[join_output_name]
join_index.setdefault(join_value, []).append(row_index)

return join_index


def _join_results(
    result_components_by_plan_id: Dict[int, List[Dict[str, Any]]],
    join_indexes_by_plan_id: Dict[int, Dict[str, List[int]]],
    current_results: List[Dict[str, Any]],
    join_descriptors: List[OutputJoinDescriptor],
) -> List[Dict[str, Any]]:
    """Merge the rows of all sub-plans into one result set by applying each join in turn.

    Args:
        result_components_by_plan_id: result rows of every sub-plan, keyed by plan ID.
        join_indexes_by_plan_id: for each child plan ID, a dict from join value to the
                                 positions of the child rows holding that value.
        current_results: the rows to start joining from.
        join_descriptors: the joins to apply, in the order given.

    Returns:
        the rows obtained after applying every join. If there are no join descriptors,
        current_results is returned as-is.
    """
    merged_rows = current_results

    for descriptor in join_descriptors:
        child_plan_id = descriptor.child_query_plan.plan_id
        child_index = join_indexes_by_plan_id[child_plan_id]
        child_rows = result_components_by_plan_id[child_plan_id]
        parent_column, _ = descriptor.output_names

        # To get inner join semantics, rows without any match in the child index produce
        # no output. Supporting stitching across @optional edges would require also emitting
        # rows whose join value has no matches.
        merged_rows = [
            dict(parent_row, **child_rows[matched_position])
            for parent_row in merged_rows
            for matched_position in child_index.get(parent_row[parent_column], [])
        ]

    return merged_rows


def _drop_intermediate_outputs(
columns_to_drop: FrozenSet[str], results: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
"""Return the provided results with the specified column names dropped."""
processed_results = []

for row in results:
processed_results.append(
{key: value for key, value in row.items() if key not in columns_to_drop}
)

return processed_results
7 changes: 5 additions & 2 deletions graphql_compiler/schema_transformation/split_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _check_or_set_schema_id(self, type_name: str) -> None:


def split_query(
query_ast: DocumentNode, merged_schema_descriptor: MergedSchemaDescriptor
query_ast: DocumentNode, merged_schema_descriptor: MergedSchemaDescriptor, strict=True
) -> Tuple[SubQueryNode, FrozenSet[str]]:
"""Split input query AST into a tree of SubQueryNodes targeting each individual schema.

Expand All @@ -171,6 +171,9 @@ def split_query(
Args:
query_ast: representing a GraphQL query to split.
merged_schema_descriptor: description of the merged schema to split the query over.
strict: bool, if set to True then limits query splitting to queries that are guaranteed
to be safely splittable. If False, then some queries may be permitted to be split
even though they are illegal. Use with caution.

Returns:
Tuple of:
Expand All @@ -186,7 +189,7 @@ def split_query(
- SchemaStructureError if the input merged_schema_descriptor appears to be invalid
or inconsistent
"""
check_query_is_valid_to_split(merged_schema_descriptor.schema, query_ast)
check_query_is_valid_to_split(merged_schema_descriptor.schema, query_ast, strict=strict)

# If schema directives are correctly represented in the schema object, type_info is all
# that's needed to detect and address stitching fields. However, GraphQL currently ignores
Expand Down
Loading