diff --git a/graphql_compiler/schema_transformation/make_query_plan.py b/graphql_compiler/schema_transformation/query_plan.py similarity index 56% rename from graphql_compiler/schema_transformation/make_query_plan.py rename to graphql_compiler/schema_transformation/query_plan.py index 6547b08f2..22330d1b1 100644 --- a/graphql_compiler/schema_transformation/make_query_plan.py +++ b/graphql_compiler/schema_transformation/query_plan.py @@ -1,7 +1,7 @@ # Copyright 2019-present Kensho Technologies, LLC. from copy import copy from dataclasses import dataclass -from typing import FrozenSet, List, Optional, Tuple, cast +from typing import Any, Callable, Dict, FrozenSet, List, Mapping, Optional, Set, Tuple, cast from graphql import print_ast from graphql.language.ast import ( @@ -22,17 +22,21 @@ from ..exceptions import GraphQLValidationError from ..schema import FilterDirective, OutputDirective from .split_query import AstType, SubQueryNode +from .utils import get_query_runtime_arguments @dataclass class SubQueryPlan: """Query plan for a part of a larger query over a single schema.""" + # Unique identifier for this sub-plan. + plan_id: int + # Representing a piece of the overall query with directives added. query_ast: DocumentNode # Identifier for the schema that this query piece targets. - schema_id: Optional[str] + schema_id: str # The query that the current query depends on, or None if the current query does not # depend on another. @@ -51,6 +55,9 @@ class OutputJoinDescriptor: # should be made. output_names: Tuple[str, str] + # SubQueryPlan, the sub-plan node for which the join happens between it and its parent sub-plan. + child_query_plan: SubQueryPlan + @dataclass(frozen=True) class QueryPlanDescriptor: @@ -91,14 +98,20 @@ def make_query_plan( """ output_join_descriptors: List[OutputJoinDescriptor] = [] + if root_sub_query_node.schema_id is None: + raise AssertionError( + "Unreachable code reached. The schema id of root_sub_query_node " + f'"{root_sub_query_node.query_ast}" has not been determined.' + ) root_sub_query_plan = SubQueryPlan( + plan_id=0, query_ast=root_sub_query_node.query_ast, schema_id=root_sub_query_node.schema_id, parent_query_plan=None, child_query_plans=[], ) - _make_query_plan_recursive(root_sub_query_node, root_sub_query_plan, output_join_descriptors) + _make_query_plan_recursive(root_sub_query_node, root_sub_query_plan, output_join_descriptors, 1) return QueryPlanDescriptor( root_sub_query_plan=root_sub_query_plan, @@ -111,6 +124,7 @@ def _make_query_plan_recursive( sub_query_node: SubQueryNode, sub_query_plan: SubQueryPlan, output_join_descriptors: List[OutputJoinDescriptor], + next_plan_id: int, ) -> None: """Recursively copy the structure of sub_query_node onto sub_query_plan. @@ -124,7 +138,7 @@ def _make_query_plan_recursive( sub_query_plan: SubQueryPlan, whose list of child query plans and query AST are modified. output_join_descriptors: describing which outputs should be joined and how. - + next_plan_id: the next available plan ID to use. IDs at and above this number are free. """ # Iterate through child connections of query node for child_query_connection in sub_query_node.child_query_connections: @@ -149,13 +163,20 @@ def _make_query_plan_recursive( else: new_child_query_ast = DocumentNode(definitions=[child_query_type_with_filter]) + if child_sub_query_node.schema_id is None: + raise AssertionError( + "Unreachable code reached. The schema id of the child_sub_query_node " + f'"{child_sub_query_node.query_ast}" has not been determined.' + ) # Create new SubQueryPlan for child child_sub_query_plan = SubQueryPlan( + plan_id=next_plan_id, query_ast=new_child_query_ast, schema_id=child_sub_query_node.schema_id, parent_query_plan=sub_query_plan, child_query_plans=[], ) + next_plan_id += 1 # Add new SubQueryPlan to parent's child list sub_query_plan.child_query_plans.append(child_sub_query_plan) @@ -163,12 +184,13 @@ def _make_query_plan_recursive( # Add information about this edge new_output_join_descriptor = OutputJoinDescriptor( output_names=(parent_out_name, child_out_name), + child_query_plan=child_sub_query_plan, ) output_join_descriptors.append(new_output_join_descriptor) # Recursively repeat on child SubQueryPlans _make_query_plan_recursive( - child_sub_query_node, child_sub_query_plan, output_join_descriptors + child_sub_query_node, child_sub_query_plan, output_join_descriptors, next_plan_id ) @@ -300,13 +322,38 @@ def print_query_plan(query_plan_descriptor: QueryPlanDescriptor, indentation_dep line_separation = "\n" + " " * indentation_depth * depth query_plan_strings.append(line_separation) - query_str = 'Execute in schema named "{}":\n'.format(query_plan.schema_id) + query_str = 'Execute subplan ID {} in schema named "{}":\n'.format( + query_plan.plan_id, query_plan.schema_id + ) query_str += print_ast(query_plan.query_ast) query_str = query_str.replace("\n", line_separation) query_plan_strings.append(query_str) - query_plan_strings.append("\n\nJoin together outputs as follows: ") - query_plan_strings.append(str(query_plan_descriptor.output_join_descriptors)) + query_plan_strings.append("\n\nJoin together outputs as follows:\n") + for descriptor in query_plan_descriptor.output_join_descriptors: + if descriptor.child_query_plan.parent_query_plan is None: + raise AssertionError( + f"Invalid join descriptor. The parent plan was not set for child_query_plan " + f"{descriptor.child_query_plan}." + ) + query_plan_strings.append( + "\n".join( + [ + " ".join( + [ + str(descriptor.output_names), + "between subplan IDs", + str( + [ + descriptor.child_query_plan.parent_query_plan.plan_id, + descriptor.child_query_plan.plan_id, + ] + ), + ] + ) + ] + ) + ) query_plan_strings.append("\n\nRemove the following outputs at the end: ") query_plan_strings.append(str(query_plan_descriptor.intermediate_output_names) + "\n") @@ -316,7 +363,9 @@ def print_query_plan(query_plan_descriptor: QueryPlanDescriptor, indentation_dep def _get_plan_and_depth_in_dfs_order(query_plan: SubQueryPlan) -> List[Tuple[SubQueryPlan, int]]: """Return a list of topologically sorted (query plan, depth) tuples.""" - def _get_plan_and_depth_in_dfs_order_helper(query_plan, depth): + def _get_plan_and_depth_in_dfs_order_helper( + query_plan: SubQueryPlan, depth: int + ) -> List[Tuple[SubQueryPlan, int]]: plan_and_depth_in_dfs_order = [(query_plan, depth)] for child_query_plan in query_plan.child_query_plans: plan_and_depth_in_dfs_order.extend( @@ -325,3 +374,190 @@ def _get_plan_and_depth_in_dfs_order_helper(query_plan, depth): return plan_and_depth_in_dfs_order return _get_plan_and_depth_in_dfs_order_helper(query_plan, 0) + + +def execute_query_plan( + schema_id_to_execution_func: Dict[ + str, Callable[[str, Mapping[str, Any]], List[Dict[str, Any]]] + ], + query_plan_descriptor: QueryPlanDescriptor, + query_args: Dict[str, Any], +) -> List[Dict[str, Any]]: + """Execute the given query plan and return the produced results.""" + result_components_by_plan_id = {} + + stitching_output_names_by_parent_plan_id: Dict[int, List[Tuple[str, str]]] = dict() + for join_descriptor in query_plan_descriptor.output_join_descriptors: + if join_descriptor.child_query_plan.parent_query_plan is None: + raise AssertionError( + f"Invalid join descriptor. The parent plan was not set for child_query_plan " + f"{join_descriptor.child_query_plan}." + ) + parent_plan_id = join_descriptor.child_query_plan.parent_query_plan.plan_id + stitching_output_names_by_parent_plan_id.setdefault(parent_plan_id, []).append( + join_descriptor.output_names + ) + + full_query_args = dict(query_args) + + plan_and_depth = _get_plan_and_depth_in_dfs_order(query_plan_descriptor.root_sub_query_plan) + + for query_plan, _ in plan_and_depth: + plan_id = query_plan.plan_id + schema_id = query_plan.schema_id + if schema_id is None: + raise AssertionError( + f'Unreachable code reached. The schema id of query piece "{query_plan.query_ast}" ' + f"has not been determined." + ) + + subquery_graphql = print_ast(query_plan.query_ast) + + # HACK(predrag): Add proper error checking for missing arguments here. + # HACK(predrag): Don't bother running queries if the previous query's stitching outputs + # returned no values to pass to the next query. + subquery_args = { + argument_name: full_query_args[argument_name] + for argument_name in get_query_runtime_arguments(query_plan.query_ast) + } + + # Run the query and save the results. + execution_func = schema_id_to_execution_func[schema_id] + subquery_result = execution_func(subquery_graphql, subquery_args) + result_components_by_plan_id[plan_id] = subquery_result + + # Capture and record any values that will be used for stitching by other subqueries. + child_extra_output_names = { + # The .get() call is to handle the case of query plans with no children. + # They have no extra output values for their children, on account of having no children. + output_name + for output_name, _ in stitching_output_names_by_parent_plan_id.get(plan_id, []) + } + child_extra_output_values: Dict[str, Set[Any]] = { + # Make sure we deduplicate the values -- there's no point in running subqueries + # with duplicated runtime argument values. + output_name: set() + for output_name in child_extra_output_names + } + for subquery_row in subquery_result: + for output_name in child_extra_output_names: + # We intentionally discard None values -- None is never a foreign key value. + # This is standard in all relational systems as well. + output_value = subquery_row.get(output_name, None) + if output_value is not None: + child_extra_output_values[output_name].add(output_value) + # TODO(predrag): Use the "merge_disjoint_dicts" function here, + # there should never be any overlap here. + new_query_args = { + # Argument values cannot be sets, so we turn the sets back into lists. + output_argument_name: list(child_extra_output_values[output_argument_name]) + for output_argument_name in child_extra_output_names + } + full_query_args.update(new_query_args) + + join_indexes_by_plan_id = _make_join_indexes( + query_plan_descriptor, result_components_by_plan_id + ) + + joined_results = _join_results( + result_components_by_plan_id, + join_indexes_by_plan_id, + result_components_by_plan_id[query_plan_descriptor.root_sub_query_plan.plan_id], + query_plan_descriptor.output_join_descriptors, + ) + + return _drop_intermediate_outputs( + query_plan_descriptor.intermediate_output_names, joined_results + ) + + +def _make_join_indexes( + query_plan_descriptor: QueryPlanDescriptor, + result_components_by_plan_id: Dict[int, List[Dict[str, Any]]], +) -> Dict[int, Dict[str, List[int]]]: + """Return a dict from child plan id to a join index between its and its parents' rows.""" + join_indexes_by_plan_id: Dict[int, Dict[str, List[int]]] = dict() + + for join_descriptor in query_plan_descriptor.output_join_descriptors: + child_plan_id = join_descriptor.child_query_plan.plan_id + _, child_output_name = join_descriptor.output_names + + if child_plan_id in join_indexes_by_plan_id: + raise AssertionError( + "Unreachable code reached: {} {} {}".format( + child_plan_id, + join_indexes_by_plan_id, + query_plan_descriptor.output_join_descriptors, + ) + ) + + join_indexes_by_plan_id[child_plan_id] = _make_join_index_for_output( + result_components_by_plan_id[child_plan_id], child_output_name + ) + + return join_indexes_by_plan_id + + +def _make_join_index_for_output( + results: List[Dict[str, Any]], join_output_name: str +) -> Dict[str, List[int]]: + """Return a dict of each value of the join column to a list of row indexes where it appears.""" + join_index: Dict[str, List[int]] = {} + for row_index, row in enumerate(results): + join_value = row[join_output_name] + join_index.setdefault(join_value, []).append(row_index) + + return join_index + + +def _join_results( + result_components_by_plan_id: Dict[int, List[Dict[str, Any]]], + join_indexes_by_plan_id: Dict[int, Dict[str, List[int]]], + current_results: List[Dict[str, Any]], + join_descriptors: List[OutputJoinDescriptor], +) -> List[Dict[str, Any]]: + """Return the merged results across all subplans using the calculated join indexes.""" + if len(join_descriptors) == 0: + # No further joining to be done! + return current_results + + next_results = [] + + next_join_descriptor = join_descriptors[0] + remaining_join_descriptors = join_descriptors[1:] + + join_plan_id = next_join_descriptor.child_query_plan.plan_id + join_index = join_indexes_by_plan_id[join_plan_id] + joining_results = result_components_by_plan_id[join_plan_id] + join_from_key, _ = next_join_descriptor.output_names + + for current_row in current_results: + join_value = current_row[join_from_key] + + # To get inner join semantics, we don't output results that don't have matches. + # When we add support for stitching across @optional edges, we'll need to update this + # code to also output results even when the join index doesn't contain matches. + for join_matched_index in join_index.get(join_value, []): + joining_row = joining_results[join_matched_index] + next_results.append(dict(current_row, **joining_row)) + + return _join_results( + result_components_by_plan_id, + join_indexes_by_plan_id, + next_results, + remaining_join_descriptors, + ) + + +def _drop_intermediate_outputs( + columns_to_drop: FrozenSet[str], results: List[Dict[str, Any]] +) -> List[Dict[str, Any]]: + """Return the provided results with the specified column names dropped.""" + processed_results = [] + + for row in results: + processed_results.append( + {key: value for key, value in row.items() if key not in columns_to_drop} + ) + + return processed_results diff --git a/graphql_compiler/schema_transformation/split_query.py b/graphql_compiler/schema_transformation/split_query.py index b85e97483..647e88620 100644 --- a/graphql_compiler/schema_transformation/split_query.py +++ b/graphql_compiler/schema_transformation/split_query.py @@ -159,7 +159,7 @@ def _check_or_set_schema_id(self, type_name: str) -> None: def split_query( - query_ast: DocumentNode, merged_schema_descriptor: MergedSchemaDescriptor + query_ast: DocumentNode, merged_schema_descriptor: MergedSchemaDescriptor, strict=True ) -> Tuple[SubQueryNode, FrozenSet[str]]: """Split input query AST into a tree of SubQueryNodes targeting each individual schema. @@ -171,6 +171,9 @@ def split_query( Args: query_ast: representing a GraphQL query to split. merged_schema_descriptor: description of the merged schema to split the query over. + strict: bool, if set to True then limits query splitting to queries that are guaranteed + to be safely splittable. If False, then some queries may be permitted to be split + even though they are illegal. Use with caution. Returns: Tuple of: @@ -186,7 +189,7 @@ def split_query( - SchemaStructureError if the input merged_schema_descriptor appears to be invalid or inconsistent """ - check_query_is_valid_to_split(merged_schema_descriptor.schema, query_ast) + check_query_is_valid_to_split(merged_schema_descriptor.schema, query_ast, strict=strict) # If schema directives are correctly represented in the schema object, type_info is all # that's needed to detect and address stitching fields. However, GraphQL currently ignores diff --git a/graphql_compiler/schema_transformation/utils.py b/graphql_compiler/schema_transformation/utils.py index 01649408e..4ac40e967 100644 --- a/graphql_compiler/schema_transformation/utils.py +++ b/graphql_compiler/schema_transformation/utils.py @@ -13,6 +13,7 @@ FragmentSpreadNode, InlineFragmentNode, InterfaceTypeDefinitionNode, + ListValueNode, NamedTypeNode, NameNode, Node, @@ -20,6 +21,7 @@ ScalarTypeDefinitionNode, SelectionNode, SelectionSetNode, + StringValueNode, UnionTypeDefinitionNode, ) from graphql.language.visitor import Visitor, visit @@ -29,6 +31,11 @@ import six from ..ast_manipulation import get_ast_with_non_null_and_list_stripped +from ..compiler.helpers import ( + get_parameter_name, + get_uniquely_named_objects_by_name, + is_runtime_parameter, +) from ..exceptions import GraphQLError, GraphQLValidationError from ..schema import FilterDirective, OptionalDirective, OutputDirective @@ -692,10 +699,18 @@ class CheckQueryIsValidToSplitVisitor(Visitor): ) ) + def __init__(self, strict: bool): + """Initialize the visitor with the appropriate strictness setting.""" + super(CheckQueryIsValidToSplitVisitor, self).__init__() + self.strict = strict + def enter_directive( self, node: DirectiveNode, key: Any, parent: Any, path: List[Any], ancestors: List[Any] ) -> None: """Check that the directive is supported.""" + if not self.strict: + return + if node.name.value not in self.supported_directives: raise GraphQLValidationError( 'Directive "{}" is not yet supported, only "{}" are currently ' @@ -757,7 +772,9 @@ def enter_selection_set( seen_vertex_field = True -def check_query_is_valid_to_split(schema: GraphQLSchema, query_ast: DocumentNode) -> None: +def check_query_is_valid_to_split( + schema: GraphQLSchema, query_ast: DocumentNode, strict: bool = True +) -> None: """Check the query is valid for splitting. In particular, ensure that the query validates against the schema, does not contain @@ -767,6 +784,9 @@ def check_query_is_valid_to_split(schema: GraphQLSchema, query_ast: DocumentNode Args: schema: schema the query is written against. query_ast: query to split. + strict: bool, if set to True then limits query splitting to queries that are guaranteed + to be safely splittable. If False, then some queries may be permitted to be split + even though they are illegal. Use with caution. Raises: GraphQLValidationError if the query doesn't validate against the schema, contains @@ -778,5 +798,48 @@ def check_query_is_valid_to_split(schema: GraphQLSchema, query_ast: DocumentNode if len(built_in_validation_errors) > 0: raise GraphQLValidationError("AST does not validate: {}".format(built_in_validation_errors)) # Check no bad directives and fields are in order - visitor = CheckQueryIsValidToSplitVisitor() + visitor = CheckQueryIsValidToSplitVisitor(strict) + visit(query_ast, visitor) + + +class QueryRuntimeArgumentsVisitor(Visitor): + """Visitor that collects runtime argument names from @filter directives.""" + + def __init__(self) -> None: + """Initialize the visitor.""" + super(QueryRuntimeArgumentsVisitor, self).__init__() + self.runtime_arguments: Set[str] = set() + + def enter_directive( + self, node: DirectiveNode, key: Any, parent: Any, path: List[Any], ancestors: List[Any] + ) -> None: + """Check that the directive is supported.""" + if node.name.value != FilterDirective.name: + return + + directive_arguments = get_uniquely_named_objects_by_name(node.arguments) + directive_argument_list = directive_arguments["value"].value + if not isinstance(directive_argument_list, ListValueNode): + raise AssertionError( + "Unreachable code reached. Expected directive_argument_list to be of type " + f"ListValueNode, but was of type {type(directive_argument_list)}." + ) + entry_names: List[str] = [] + for list_element in directive_argument_list.values: + if not isinstance(list_element, StringValueNode): + raise AssertionError( + f"Expected directive arguments to be StringValueNode, but received " + f"{list_element} with type {type(list_element)}." + ) + entry_names.append(list_element.value) + + self.runtime_arguments.update( + get_parameter_name(name) for name in entry_names if is_runtime_parameter(name) + ) + + +def get_query_runtime_arguments(query_ast: DocumentNode) -> Set[str]: + """Return a set containing the names of the runtime arguments required by the query.""" + visitor = QueryRuntimeArgumentsVisitor() visit(query_ast, visitor) + return visitor.runtime_arguments diff --git a/graphql_compiler/tests/schema_transformation_tests/test_make_query_plan.py b/graphql_compiler/tests/schema_transformation_tests/test_make_query_plan.py index eb825fc90..4da2e7059 100644 --- a/graphql_compiler/tests/schema_transformation_tests/test_make_query_plan.py +++ b/graphql_compiler/tests/schema_transformation_tests/test_make_query_plan.py @@ -4,7 +4,7 @@ from graphql import parse, print_ast -from ...schema_transformation.make_query_plan import make_query_plan +from ...schema_transformation.query_plan import make_query_plan from ...schema_transformation.split_query import split_query from .example_schema import basic_merged_schema diff --git a/mypy.ini b/mypy.ini index ad086efcb..b04b52c58 100644 --- a/mypy.ini +++ b/mypy.ini @@ -190,9 +190,6 @@ check_untyped_defs = False [mypy-graphql_compiler.schema_transformation.*] disallow_untyped_calls = False -[mypy-graphql_compiler.schema_transformation.make_query_plan.*] -disallow_untyped_defs = False - [mypy-graphql_compiler.schema_transformation.split_query.*] disallow_incomplete_defs = False disallow_untyped_defs = False @@ -344,9 +341,6 @@ disallow_untyped_calls = False # Third-party module rule relaxations -[mypy-arrow.*] -ignore_missing_imports = True - [mypy-funcy.*] ignore_missing_imports = True