From 96d258b316c803d6bc3c34904acdca9f24790609 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski Date: Mon, 9 Sep 2019 22:45:16 -0400 Subject: [PATCH 1/8] Draft implementation for cross-db query plan execution. --- .../{make_query_plan.py => query_plan.py} | 195 +++++++++++++++++- .../schema_transformation/split_query.py | 7 +- .../schema_transformation/utils.py | 53 ++++- .../test_make_query_plan.py | 2 +- 4 files changed, 244 insertions(+), 13 deletions(-) rename graphql_compiler/schema_transformation/{make_query_plan.py => query_plan.py} (58%) diff --git a/graphql_compiler/schema_transformation/make_query_plan.py b/graphql_compiler/schema_transformation/query_plan.py similarity index 58% rename from graphql_compiler/schema_transformation/make_query_plan.py rename to graphql_compiler/schema_transformation/query_plan.py index e786c59bb..438d52539 100644 --- a/graphql_compiler/schema_transformation/make_query_plan.py +++ b/graphql_compiler/schema_transformation/query_plan.py @@ -11,10 +11,12 @@ from ..ast_manipulation import get_only_query_definition from ..exceptions import GraphQLValidationError from ..schema import FilterDirective, OutputDirective +from .utils import get_query_runtime_arguments SubQueryPlan = namedtuple( 'SubQueryPlan', ( + 'plan_id', # int, unique identifier for this sub-plan 'query_ast', # Document, representing a piece of the overall query with directives added 'schema_id', # str, identifying the schema that this query piece targets 'parent_query_plan', # SubQueryPlan, the query that the current query depends on @@ -26,6 +28,9 @@ OutputJoinDescriptor = namedtuple( 'OutputJoinDescriptor', ( 'output_names', # Tuple[str, str], (parent output name, child output name) + 'child_query_plan', # SubQueryPlan, the sub-plan node for which the join happens + # between it and its parent sub-plan + # May be expanded to have more attributes, e.g. is_optional, describing how the join # should be made ) @@ -67,13 +72,14 @@ def make_query_plan(root_sub_query_node, intermediate_output_names): output_join_descriptors = [] root_sub_query_plan = SubQueryPlan( + plan_id=0, query_ast=root_sub_query_node.query_ast, schema_id=root_sub_query_node.schema_id, parent_query_plan=None, child_query_plans=[], ) - _make_query_plan_recursive(root_sub_query_node, root_sub_query_plan, output_join_descriptors) + _make_query_plan_recursive(root_sub_query_node, root_sub_query_plan, output_join_descriptors, 1) return QueryPlanDescriptor( root_sub_query_plan=root_sub_query_plan, @@ -82,7 +88,8 @@ def make_query_plan(root_sub_query_node, intermediate_output_names): ) -def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descriptors): +def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descriptors, + next_plan_id): """Recursively copy the structure of sub_query_node onto sub_query_plan. For each child connection contained in sub_query_node, create a new SubQueryPlan for @@ -96,7 +103,7 @@ def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descr modified output_join_descriptors: List[OutputJoinDescriptor], describing which outputs should be joined and how - + next_plan_id: int, the next available plan ID to use. IDs at and above this number are free. """ # Iterate through child connections of query node for child_query_connection in sub_query_node.child_query_connections: @@ -123,11 +130,13 @@ def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descr # Create new SubQueryPlan for child child_sub_query_plan = SubQueryPlan( + plan_id=next_plan_id, query_ast=new_child_query_ast, schema_id=child_sub_query_node.schema_id, parent_query_plan=sub_query_plan, child_query_plans=[], ) + next_plan_id += 1 # Add new SubQueryPlan to parent's child list sub_query_plan.child_query_plans.append(child_sub_query_plan) @@ -135,12 +144,13 @@ def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descr # Add information about this edge new_output_join_descriptor = OutputJoinDescriptor( output_names=(parent_out_name, child_out_name), + child_query_plan=child_sub_query_plan, ) output_join_descriptors.append(new_output_join_descriptor) # Recursively repeat on child SubQueryPlans _make_query_plan_recursive( - child_sub_query_node, child_sub_query_plan, output_join_descriptors + child_sub_query_node, child_sub_query_plan, output_join_descriptors, next_plan_id ) @@ -253,15 +263,25 @@ def print_query_plan(query_plan_descriptor, indentation_depth=4): line_separation = u'\n' + u' ' * indentation_depth * depth query_plan_strings.append(line_separation) - query_str = u'Execute in schema named "{}":\n'.format(query_plan.schema_id) + query_str = u'Execute subplan ID {} in schema named "{}":\n'.format( + query_plan.plan_id, query_plan.schema_id) query_str += print_ast(query_plan.query_ast) query_str = query_str.replace(u'\n', line_separation) query_plan_strings.append(query_str) query_plan_strings.append(u'\n\nJoin together outputs as follows: ') - query_plan_strings.append(str(query_plan_descriptor.output_join_descriptors)) + query_plan_strings.append(str([ + ' '.join([ + str(descriptor.output_names), + 'between subplan IDs', + str([ + descriptor.child_query_plan.parent_query_plan.plan_id, + descriptor.child_query_plan.plan_id + ])]) + for descriptor in query_plan_descriptor.output_join_descriptors + ])) query_plan_strings.append(u'\n\nRemove the following outputs at the end: ') - query_plan_strings.append(str(query_plan_descriptor.intermediate_output_names) + u'\n') + query_plan_strings.append(str(set(query_plan_descriptor.intermediate_output_names)) + u'\n') return ''.join(query_plan_strings) @@ -276,3 +296,164 @@ def _get_plan_and_depth_in_dfs_order_helper(query_plan, depth): ) return plan_and_depth_in_dfs_order return _get_plan_and_depth_in_dfs_order_helper(query_plan, 0) + + +def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query_args): + """Execute the given query plan and return the produced results.""" + result_components_by_plan_id = {} + + stitching_output_names_by_parent_plan_id = dict() + for join_descriptor in query_plan_descriptor.output_join_descriptors: + parent_plan_id = join_descriptor.child_query_plan.parent_query_plan.plan_id + stitching_output_names_by_parent_plan_id.setdefault(parent_plan_id, []).append( + join_descriptor.output_names) + + full_query_args = dict(query_args) + + plan_and_depth = _get_plan_and_depth_in_dfs_order(query_plan_descriptor.root_sub_query_plan) + + for query_plan, _ in plan_and_depth: + plan_id = query_plan.plan_id + schema_id = query_plan.schema_id + + subquery_graphql = print_ast(query_plan.query_ast) + + print('\n\n********* BEGIN *********\n') + print(subquery_graphql) + + # HACK(predrag): Add proper error checking for missing arguments here. + # HACK(predrag): Don't bother running queries if the previous query's stitching outputs + # returned no values to pass to the next query. + subquery_args = { + argument_name: full_query_args[argument_name] + for argument_name in get_query_runtime_arguments(query_plan.query_ast) + } + + print(subquery_args) + + # Run the query and save the results. + execution_func = schema_id_to_execution_func[schema_id] + subquery_result = execution_func(subquery_graphql, subquery_args) + result_components_by_plan_id[plan_id] = subquery_result + + print(subquery_result) + + # Capture and record any values that will be used for stitching by other subqueries. + child_extra_output_names = { + # The .get() call is to handle the case of query plans with no children. + # They have no extra output values for their children, on account of having no children. + output_name + for output_name, _ in stitching_output_names_by_parent_plan_id.get(plan_id, []) + } + child_extra_output_values = { + # Make sure we deduplicate the values -- there's no point in running subqueries + # with duplicated runtime argument values. + output_name: set() + for output_name in child_extra_output_names + } + for subquery_row in subquery_result: + for output_name in child_extra_output_names: + # We intentionally discard None values -- None is never a foreign key value. + # This is standard in all relational systems as well. + output_value = subquery_row.get(output_name, None) + if output_value is not None: + child_extra_output_values[output_name].add(output_value) + # TODO(predrag): Use the "merge_disjoint_dicts" function here, + # there should never be any overlap here. + new_query_args = { + # Argument values cannot be sets, so we turn the sets back into lists. + output_argument_name: list(child_extra_output_values[output_argument_name]) + for output_argument_name in child_extra_output_names + } + full_query_args.update(new_query_args) + + print(new_query_args) + print('\n********** END ***********\n') + + join_indexes_by_plan_id = _make_join_indexes( + query_plan_descriptor, result_components_by_plan_id) + + joined_results = _join_results( + result_components_by_plan_id, join_indexes_by_plan_id, + result_components_by_plan_id[query_plan_descriptor.root_sub_query_plan.plan_id], + query_plan_descriptor.output_join_descriptors) + + return _drop_intermediate_outputs( + query_plan_descriptor.intermediate_output_names, joined_results) + + +def _make_join_indexes(query_plan_descriptor, result_components_by_plan_id): + """Return a dict from child plan id to a join index between its and its parents' rows.""" + join_indexes_by_plan_id = dict() + + for join_descriptor in query_plan_descriptor.output_join_descriptors: + child_plan_id = join_descriptor.child_query_plan.plan_id + _, child_output_name = join_descriptor.output_names + + if child_plan_id in join_indexes_by_plan_id: + raise AssertionError('Unreachable code reached: {} {} {}' + .format(child_plan_id, join_indexes_by_plan_id, + query_plan_descriptor.output_join_descriptors)) + + join_indexes_by_plan_id[child_plan_id] = _make_join_index_for_output( + result_components_by_plan_id[child_plan_id], child_output_name) + + return join_indexes_by_plan_id + + +def _make_join_index_for_output(results, join_output_name): + """Return a dict of each value of the join column to a list of row indexes where it appears.""" + print('making join index on column ', join_output_name) + print(results) + + join_index = {} + for row_index, row in enumerate(results): + join_value = row[join_output_name] + join_index.setdefault(join_value, []).append(row_index) + + return join_index + + +def _join_results(result_components_by_plan_id, join_indexes_by_plan_id, + current_results, join_descriptors): + """Return the merged results across all subplans using the calculated join indexes.""" + if len(join_descriptors) == 0: + # No further joining to be done! + return current_results + + next_results = [] + + next_join_descriptor = join_descriptors[0] + remaining_join_descriptors = join_descriptors[1:] + + join_plan_id = next_join_descriptor.child_query_plan.plan_id + join_index = join_indexes_by_plan_id[join_plan_id] + joining_results = result_components_by_plan_id[join_plan_id] + join_from_key, join_to_key = next_join_descriptor.output_names + + for current_row in current_results: + join_value = current_row[join_from_key] + + # To get inner join semantics, we don't output results that don't have matches. + # When we add support for stitching across @optional edges, we'll need to update this + # code to also output results even when the join index doesn't contain matches. + for join_matched_index in join_index.get(join_value, []): + joining_row = joining_results[join_matched_index] + next_results.append(dict(current_row, **joining_row)) + + return _join_results(result_components_by_plan_id, join_indexes_by_plan_id, + next_results, remaining_join_descriptors) + + +def _drop_intermediate_outputs(columns_to_drop, results): + """Return the provided results with the specified column names dropped.""" + processed_results = [] + + for row in results: + processed_results.append({ + key: value + for key, value in row.items() + if key not in columns_to_drop + }) + + return processed_results diff --git a/graphql_compiler/schema_transformation/split_query.py b/graphql_compiler/schema_transformation/split_query.py index 5e3198f18..5c67c3819 100644 --- a/graphql_compiler/schema_transformation/split_query.py +++ b/graphql_compiler/schema_transformation/split_query.py @@ -47,7 +47,7 @@ def __init__(self, query_ast): # List[SubQueryNode], the queries that depend on the current query -def split_query(query_ast, merged_schema_descriptor): +def split_query(query_ast, merged_schema_descriptor, strict=True): """Split input query AST into a tree of SubQueryNodes targeting each individual schema. Property fields used in the stitch will be added if not already present. @output directives @@ -62,6 +62,9 @@ def split_query(query_ast, merged_schema_descriptor): schema: GraphQLSchema representing the merged schema type_name_to_schema_id: Dict[str, str], mapping type names to the id of the schema it came from + strict: bool, if set to True then limits query splitting to queries that are guaranteed + to be safely splittable. If False, then some queries may be permitted to be split + even though they are illegal. Use with caution. Returns: Tuple[SubQueryNode, frozenset[str]]. The first element is the root of the tree of @@ -77,7 +80,7 @@ def split_query(query_ast, merged_schema_descriptor): - SchemaStructureError if the input merged_schema_descriptor appears to be invalid or inconsistent """ - check_query_is_valid_to_split(merged_schema_descriptor.schema, query_ast) + check_query_is_valid_to_split(merged_schema_descriptor.schema, query_ast, strict=strict) # If schema directives are correctly represented in the schema object, type_info is all # that's needed to detect and address stitching fields. However, GraphQL currently ignores diff --git a/graphql_compiler/schema_transformation/utils.py b/graphql_compiler/schema_transformation/utils.py index 9f7e8683d..64f16385f 100644 --- a/graphql_compiler/schema_transformation/utils.py +++ b/graphql_compiler/schema_transformation/utils.py @@ -11,8 +11,11 @@ import six from ..ast_manipulation import get_ast_with_non_null_and_list_stripped +from ..compiler.helpers import ( + get_parameter_name, get_uniquely_named_objects_by_name, is_runtime_parameter +) from ..exceptions import GraphQLError, GraphQLValidationError -from ..schema import FilterDirective, OptionalDirective, OutputDirective +from ..schema import FilterDirective, OptionalDirective, OutputDirective, RecurseDirective class SchemaTransformError(GraphQLError): @@ -397,8 +400,16 @@ class CheckQueryIsValidToSplitVisitor(Visitor): OptionalDirective.name, )) + def __init__(self, strict): + """Initialize the visitor with the appropriate strictness setting.""" + super(CheckQueryIsValidToSplitVisitor, self).__init__() + self.strict = strict + def enter_Directive(self, node, *args): """Check that the directive is supported.""" + if not self.strict: + return + if node.name.value not in self.supported_directives: raise GraphQLValidationError( u'Directive "{}" is not yet supported, only "{}" are currently ' @@ -443,7 +454,7 @@ def enter_SelectionSet(self, node, *args): seen_vertex_field = True -def check_query_is_valid_to_split(schema, query_ast): +def check_query_is_valid_to_split(schema, query_ast, strict=True): """Check the query is valid for splitting. In particular, ensure that the query validates against the schema, does not contain @@ -453,6 +464,9 @@ def check_query_is_valid_to_split(schema, query_ast): Args: schema: GraphQLSchema object query_ast: Document + strict: bool, if set to True then limits query splitting to queries that are guaranteed + to be safely splittable. If False, then some queries may be permitted to be split + even though they are illegal. Use with caution. Raises: GraphQLValidationError if the query doesn't validate against the schema, contains @@ -466,5 +480,38 @@ def check_query_is_valid_to_split(schema, query_ast): u'AST does not validate: {}'.format(built_in_validation_errors) ) # Check no bad directives and fields are in order - visitor = CheckQueryIsValidToSplitVisitor() + visitor = CheckQueryIsValidToSplitVisitor(strict) + visit(query_ast, visitor) + + +class QueryRuntimeArgumentsVisitor(Visitor): + """Visitor that collects runtime argument names from @filter directives.""" + + def __init__(self): + """Initialize the visitor.""" + super(QueryRuntimeArgumentsVisitor, self).__init__() + self.runtime_arguments = set() + + def enter_Directive(self, node, *args): + """Check that the directive is supported.""" + if node.name.value != FilterDirective.name: + return + + directive_arguments = get_uniquely_named_objects_by_name(node.arguments) + entry_names = [ + list_element.value + for list_element in directive_arguments['value'].value.values + ] + + self.runtime_arguments.update( + get_parameter_name(name) + for name in entry_names + if is_runtime_parameter(name) + ) + + +def get_query_runtime_arguments(query_ast): + """Return a set containing the names of the runtime arguments required by the query.""" + visitor = QueryRuntimeArgumentsVisitor() visit(query_ast, visitor) + return visitor.runtime_arguments diff --git a/graphql_compiler/tests/schema_transformation_tests/test_make_query_plan.py b/graphql_compiler/tests/schema_transformation_tests/test_make_query_plan.py index 233c6540a..95e752b89 100644 --- a/graphql_compiler/tests/schema_transformation_tests/test_make_query_plan.py +++ b/graphql_compiler/tests/schema_transformation_tests/test_make_query_plan.py @@ -4,7 +4,7 @@ from graphql import parse, print_ast -from ...schema_transformation.make_query_plan import make_query_plan +from ...schema_transformation.query_plan import make_query_plan from ...schema_transformation.split_query import split_query from .example_schema import basic_merged_schema From e43616285034abc733bc0f133e9eb0beed82c37a Mon Sep 17 00:00:00 2001 From: Predrag Gruevski Date: Fri, 13 Sep 2019 10:01:07 -0400 Subject: [PATCH 2/8] Remove print statements. --- .../schema_transformation/query_plan.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/graphql_compiler/schema_transformation/query_plan.py b/graphql_compiler/schema_transformation/query_plan.py index 438d52539..27427e29f 100644 --- a/graphql_compiler/schema_transformation/query_plan.py +++ b/graphql_compiler/schema_transformation/query_plan.py @@ -269,8 +269,8 @@ def print_query_plan(query_plan_descriptor, indentation_depth=4): query_str = query_str.replace(u'\n', line_separation) query_plan_strings.append(query_str) - query_plan_strings.append(u'\n\nJoin together outputs as follows: ') - query_plan_strings.append(str([ + query_plan_strings.append(u'\n\nJoin together outputs as follows:\n') + query_plan_strings.append('\n'.join([ ' '.join([ str(descriptor.output_names), 'between subplan IDs', @@ -318,9 +318,6 @@ def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query subquery_graphql = print_ast(query_plan.query_ast) - print('\n\n********* BEGIN *********\n') - print(subquery_graphql) - # HACK(predrag): Add proper error checking for missing arguments here. # HACK(predrag): Don't bother running queries if the previous query's stitching outputs # returned no values to pass to the next query. @@ -329,15 +326,11 @@ def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query for argument_name in get_query_runtime_arguments(query_plan.query_ast) } - print(subquery_args) - # Run the query and save the results. execution_func = schema_id_to_execution_func[schema_id] subquery_result = execution_func(subquery_graphql, subquery_args) result_components_by_plan_id[plan_id] = subquery_result - print(subquery_result) - # Capture and record any values that will be used for stitching by other subqueries. child_extra_output_names = { # The .get() call is to handle the case of query plans with no children. @@ -367,9 +360,6 @@ def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query } full_query_args.update(new_query_args) - print(new_query_args) - print('\n********** END ***********\n') - join_indexes_by_plan_id = _make_join_indexes( query_plan_descriptor, result_components_by_plan_id) @@ -403,9 +393,6 @@ def _make_join_indexes(query_plan_descriptor, result_components_by_plan_id): def _make_join_index_for_output(results, join_output_name): """Return a dict of each value of the join column to a list of row indexes where it appears.""" - print('making join index on column ', join_output_name) - print(results) - join_index = {} for row_index, row in enumerate(results): join_value = row[join_output_name] From e8b783cf329cd17c5a09fd28ef3a1c8450dde0a3 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski Date: Mon, 9 Sep 2019 22:45:16 -0400 Subject: [PATCH 3/8] Draft implementation for cross-db query plan execution. --- .../{make_query_plan.py => query_plan.py} | 195 +++++++++++++++++- .../schema_transformation/split_query.py | 7 +- .../schema_transformation/utils.py | 53 ++++- .../test_make_query_plan.py | 2 +- 4 files changed, 244 insertions(+), 13 deletions(-) rename graphql_compiler/schema_transformation/{make_query_plan.py => query_plan.py} (58%) diff --git a/graphql_compiler/schema_transformation/make_query_plan.py b/graphql_compiler/schema_transformation/query_plan.py similarity index 58% rename from graphql_compiler/schema_transformation/make_query_plan.py rename to graphql_compiler/schema_transformation/query_plan.py index e786c59bb..438d52539 100644 --- a/graphql_compiler/schema_transformation/make_query_plan.py +++ b/graphql_compiler/schema_transformation/query_plan.py @@ -11,10 +11,12 @@ from ..ast_manipulation import get_only_query_definition from ..exceptions import GraphQLValidationError from ..schema import FilterDirective, OutputDirective +from .utils import get_query_runtime_arguments SubQueryPlan = namedtuple( 'SubQueryPlan', ( + 'plan_id', # int, unique identifier for this sub-plan 'query_ast', # Document, representing a piece of the overall query with directives added 'schema_id', # str, identifying the schema that this query piece targets 'parent_query_plan', # SubQueryPlan, the query that the current query depends on @@ -26,6 +28,9 @@ OutputJoinDescriptor = namedtuple( 'OutputJoinDescriptor', ( 'output_names', # Tuple[str, str], (parent output name, child output name) + 'child_query_plan', # SubQueryPlan, the sub-plan node for which the join happens + # between it and its parent sub-plan + # May be expanded to have more attributes, e.g. is_optional, describing how the join # should be made ) @@ -67,13 +72,14 @@ def make_query_plan(root_sub_query_node, intermediate_output_names): output_join_descriptors = [] root_sub_query_plan = SubQueryPlan( + plan_id=0, query_ast=root_sub_query_node.query_ast, schema_id=root_sub_query_node.schema_id, parent_query_plan=None, child_query_plans=[], ) - _make_query_plan_recursive(root_sub_query_node, root_sub_query_plan, output_join_descriptors) + _make_query_plan_recursive(root_sub_query_node, root_sub_query_plan, output_join_descriptors, 1) return QueryPlanDescriptor( root_sub_query_plan=root_sub_query_plan, @@ -82,7 +88,8 @@ def make_query_plan(root_sub_query_node, intermediate_output_names): ) -def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descriptors): +def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descriptors, + next_plan_id): """Recursively copy the structure of sub_query_node onto sub_query_plan. For each child connection contained in sub_query_node, create a new SubQueryPlan for @@ -96,7 +103,7 @@ def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descr modified output_join_descriptors: List[OutputJoinDescriptor], describing which outputs should be joined and how - + next_plan_id: int, the next available plan ID to use. IDs at and above this number are free. """ # Iterate through child connections of query node for child_query_connection in sub_query_node.child_query_connections: @@ -123,11 +130,13 @@ def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descr # Create new SubQueryPlan for child child_sub_query_plan = SubQueryPlan( + plan_id=next_plan_id, query_ast=new_child_query_ast, schema_id=child_sub_query_node.schema_id, parent_query_plan=sub_query_plan, child_query_plans=[], ) + next_plan_id += 1 # Add new SubQueryPlan to parent's child list sub_query_plan.child_query_plans.append(child_sub_query_plan) @@ -135,12 +144,13 @@ def _make_query_plan_recursive(sub_query_node, sub_query_plan, output_join_descr # Add information about this edge new_output_join_descriptor = OutputJoinDescriptor( output_names=(parent_out_name, child_out_name), + child_query_plan=child_sub_query_plan, ) output_join_descriptors.append(new_output_join_descriptor) # Recursively repeat on child SubQueryPlans _make_query_plan_recursive( - child_sub_query_node, child_sub_query_plan, output_join_descriptors + child_sub_query_node, child_sub_query_plan, output_join_descriptors, next_plan_id ) @@ -253,15 +263,25 @@ def print_query_plan(query_plan_descriptor, indentation_depth=4): line_separation = u'\n' + u' ' * indentation_depth * depth query_plan_strings.append(line_separation) - query_str = u'Execute in schema named "{}":\n'.format(query_plan.schema_id) + query_str = u'Execute subplan ID {} in schema named "{}":\n'.format( + query_plan.plan_id, query_plan.schema_id) query_str += print_ast(query_plan.query_ast) query_str = query_str.replace(u'\n', line_separation) query_plan_strings.append(query_str) query_plan_strings.append(u'\n\nJoin together outputs as follows: ') - query_plan_strings.append(str(query_plan_descriptor.output_join_descriptors)) + query_plan_strings.append(str([ + ' '.join([ + str(descriptor.output_names), + 'between subplan IDs', + str([ + descriptor.child_query_plan.parent_query_plan.plan_id, + descriptor.child_query_plan.plan_id + ])]) + for descriptor in query_plan_descriptor.output_join_descriptors + ])) query_plan_strings.append(u'\n\nRemove the following outputs at the end: ') - query_plan_strings.append(str(query_plan_descriptor.intermediate_output_names) + u'\n') + query_plan_strings.append(str(set(query_plan_descriptor.intermediate_output_names)) + u'\n') return ''.join(query_plan_strings) @@ -276,3 +296,164 @@ def _get_plan_and_depth_in_dfs_order_helper(query_plan, depth): ) return plan_and_depth_in_dfs_order return _get_plan_and_depth_in_dfs_order_helper(query_plan, 0) + + +def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query_args): + """Execute the given query plan and return the produced results.""" + result_components_by_plan_id = {} + + stitching_output_names_by_parent_plan_id = dict() + for join_descriptor in query_plan_descriptor.output_join_descriptors: + parent_plan_id = join_descriptor.child_query_plan.parent_query_plan.plan_id + stitching_output_names_by_parent_plan_id.setdefault(parent_plan_id, []).append( + join_descriptor.output_names) + + full_query_args = dict(query_args) + + plan_and_depth = _get_plan_and_depth_in_dfs_order(query_plan_descriptor.root_sub_query_plan) + + for query_plan, _ in plan_and_depth: + plan_id = query_plan.plan_id + schema_id = query_plan.schema_id + + subquery_graphql = print_ast(query_plan.query_ast) + + print('\n\n********* BEGIN *********\n') + print(subquery_graphql) + + # HACK(predrag): Add proper error checking for missing arguments here. + # HACK(predrag): Don't bother running queries if the previous query's stitching outputs + # returned no values to pass to the next query. + subquery_args = { + argument_name: full_query_args[argument_name] + for argument_name in get_query_runtime_arguments(query_plan.query_ast) + } + + print(subquery_args) + + # Run the query and save the results. + execution_func = schema_id_to_execution_func[schema_id] + subquery_result = execution_func(subquery_graphql, subquery_args) + result_components_by_plan_id[plan_id] = subquery_result + + print(subquery_result) + + # Capture and record any values that will be used for stitching by other subqueries. + child_extra_output_names = { + # The .get() call is to handle the case of query plans with no children. + # They have no extra output values for their children, on account of having no children. + output_name + for output_name, _ in stitching_output_names_by_parent_plan_id.get(plan_id, []) + } + child_extra_output_values = { + # Make sure we deduplicate the values -- there's no point in running subqueries + # with duplicated runtime argument values. + output_name: set() + for output_name in child_extra_output_names + } + for subquery_row in subquery_result: + for output_name in child_extra_output_names: + # We intentionally discard None values -- None is never a foreign key value. + # This is standard in all relational systems as well. + output_value = subquery_row.get(output_name, None) + if output_value is not None: + child_extra_output_values[output_name].add(output_value) + # TODO(predrag): Use the "merge_disjoint_dicts" function here, + # there should never be any overlap here. + new_query_args = { + # Argument values cannot be sets, so we turn the sets back into lists. + output_argument_name: list(child_extra_output_values[output_argument_name]) + for output_argument_name in child_extra_output_names + } + full_query_args.update(new_query_args) + + print(new_query_args) + print('\n********** END ***********\n') + + join_indexes_by_plan_id = _make_join_indexes( + query_plan_descriptor, result_components_by_plan_id) + + joined_results = _join_results( + result_components_by_plan_id, join_indexes_by_plan_id, + result_components_by_plan_id[query_plan_descriptor.root_sub_query_plan.plan_id], + query_plan_descriptor.output_join_descriptors) + + return _drop_intermediate_outputs( + query_plan_descriptor.intermediate_output_names, joined_results) + + +def _make_join_indexes(query_plan_descriptor, result_components_by_plan_id): + """Return a dict from child plan id to a join index between its and its parents' rows.""" + join_indexes_by_plan_id = dict() + + for join_descriptor in query_plan_descriptor.output_join_descriptors: + child_plan_id = join_descriptor.child_query_plan.plan_id + _, child_output_name = join_descriptor.output_names + + if child_plan_id in join_indexes_by_plan_id: + raise AssertionError('Unreachable code reached: {} {} {}' + .format(child_plan_id, join_indexes_by_plan_id, + query_plan_descriptor.output_join_descriptors)) + + join_indexes_by_plan_id[child_plan_id] = _make_join_index_for_output( + result_components_by_plan_id[child_plan_id], child_output_name) + + return join_indexes_by_plan_id + + +def _make_join_index_for_output(results, join_output_name): + """Return a dict of each value of the join column to a list of row indexes where it appears.""" + print('making join index on column ', join_output_name) + print(results) + + join_index = {} + for row_index, row in enumerate(results): + join_value = row[join_output_name] + join_index.setdefault(join_value, []).append(row_index) + + return join_index + + +def _join_results(result_components_by_plan_id, join_indexes_by_plan_id, + current_results, join_descriptors): + """Return the merged results across all subplans using the calculated join indexes.""" + if len(join_descriptors) == 0: + # No further joining to be done! + return current_results + + next_results = [] + + next_join_descriptor = join_descriptors[0] + remaining_join_descriptors = join_descriptors[1:] + + join_plan_id = next_join_descriptor.child_query_plan.plan_id + join_index = join_indexes_by_plan_id[join_plan_id] + joining_results = result_components_by_plan_id[join_plan_id] + join_from_key, join_to_key = next_join_descriptor.output_names + + for current_row in current_results: + join_value = current_row[join_from_key] + + # To get inner join semantics, we don't output results that don't have matches. + # When we add support for stitching across @optional edges, we'll need to update this + # code to also output results even when the join index doesn't contain matches. + for join_matched_index in join_index.get(join_value, []): + joining_row = joining_results[join_matched_index] + next_results.append(dict(current_row, **joining_row)) + + return _join_results(result_components_by_plan_id, join_indexes_by_plan_id, + next_results, remaining_join_descriptors) + + +def _drop_intermediate_outputs(columns_to_drop, results): + """Return the provided results with the specified column names dropped.""" + processed_results = [] + + for row in results: + processed_results.append({ + key: value + for key, value in row.items() + if key not in columns_to_drop + }) + + return processed_results diff --git a/graphql_compiler/schema_transformation/split_query.py b/graphql_compiler/schema_transformation/split_query.py index 5e3198f18..5c67c3819 100644 --- a/graphql_compiler/schema_transformation/split_query.py +++ b/graphql_compiler/schema_transformation/split_query.py @@ -47,7 +47,7 @@ def __init__(self, query_ast): # List[SubQueryNode], the queries that depend on the current query -def split_query(query_ast, merged_schema_descriptor): +def split_query(query_ast, merged_schema_descriptor, strict=True): """Split input query AST into a tree of SubQueryNodes targeting each individual schema. Property fields used in the stitch will be added if not already present. @output directives @@ -62,6 +62,9 @@ def split_query(query_ast, merged_schema_descriptor): schema: GraphQLSchema representing the merged schema type_name_to_schema_id: Dict[str, str], mapping type names to the id of the schema it came from + strict: bool, if set to True then limits query splitting to queries that are guaranteed + to be safely splittable. If False, then some queries may be permitted to be split + even though they are illegal. Use with caution. Returns: Tuple[SubQueryNode, frozenset[str]]. The first element is the root of the tree of @@ -77,7 +80,7 @@ def split_query(query_ast, merged_schema_descriptor): - SchemaStructureError if the input merged_schema_descriptor appears to be invalid or inconsistent """ - check_query_is_valid_to_split(merged_schema_descriptor.schema, query_ast) + check_query_is_valid_to_split(merged_schema_descriptor.schema, query_ast, strict=strict) # If schema directives are correctly represented in the schema object, type_info is all # that's needed to detect and address stitching fields. However, GraphQL currently ignores diff --git a/graphql_compiler/schema_transformation/utils.py b/graphql_compiler/schema_transformation/utils.py index 9f7e8683d..64f16385f 100644 --- a/graphql_compiler/schema_transformation/utils.py +++ b/graphql_compiler/schema_transformation/utils.py @@ -11,8 +11,11 @@ import six from ..ast_manipulation import get_ast_with_non_null_and_list_stripped +from ..compiler.helpers import ( + get_parameter_name, get_uniquely_named_objects_by_name, is_runtime_parameter +) from ..exceptions import GraphQLError, GraphQLValidationError -from ..schema import FilterDirective, OptionalDirective, OutputDirective +from ..schema import FilterDirective, OptionalDirective, OutputDirective, RecurseDirective class SchemaTransformError(GraphQLError): @@ -397,8 +400,16 @@ class CheckQueryIsValidToSplitVisitor(Visitor): OptionalDirective.name, )) + def __init__(self, strict): + """Initialize the visitor with the appropriate strictness setting.""" + super(CheckQueryIsValidToSplitVisitor, self).__init__() + self.strict = strict + def enter_Directive(self, node, *args): """Check that the directive is supported.""" + if not self.strict: + return + if node.name.value not in self.supported_directives: raise GraphQLValidationError( u'Directive "{}" is not yet supported, only "{}" are currently ' @@ -443,7 +454,7 @@ def enter_SelectionSet(self, node, *args): seen_vertex_field = True -def check_query_is_valid_to_split(schema, query_ast): +def check_query_is_valid_to_split(schema, query_ast, strict=True): """Check the query is valid for splitting. In particular, ensure that the query validates against the schema, does not contain @@ -453,6 +464,9 @@ def check_query_is_valid_to_split(schema, query_ast): Args: schema: GraphQLSchema object query_ast: Document + strict: bool, if set to True then limits query splitting to queries that are guaranteed + to be safely splittable. If False, then some queries may be permitted to be split + even though they are illegal. Use with caution. Raises: GraphQLValidationError if the query doesn't validate against the schema, contains @@ -466,5 +480,38 @@ def check_query_is_valid_to_split(schema, query_ast): u'AST does not validate: {}'.format(built_in_validation_errors) ) # Check no bad directives and fields are in order - visitor = CheckQueryIsValidToSplitVisitor() + visitor = CheckQueryIsValidToSplitVisitor(strict) + visit(query_ast, visitor) + + +class QueryRuntimeArgumentsVisitor(Visitor): + """Visitor that collects runtime argument names from @filter directives.""" + + def __init__(self): + """Initialize the visitor.""" + super(QueryRuntimeArgumentsVisitor, self).__init__() + self.runtime_arguments = set() + + def enter_Directive(self, node, *args): + """Check that the directive is supported.""" + if node.name.value != FilterDirective.name: + return + + directive_arguments = get_uniquely_named_objects_by_name(node.arguments) + entry_names = [ + list_element.value + for list_element in directive_arguments['value'].value.values + ] + + self.runtime_arguments.update( + get_parameter_name(name) + for name in entry_names + if is_runtime_parameter(name) + ) + + +def get_query_runtime_arguments(query_ast): + """Return a set containing the names of the runtime arguments required by the query.""" + visitor = QueryRuntimeArgumentsVisitor() visit(query_ast, visitor) + return visitor.runtime_arguments diff --git a/graphql_compiler/tests/schema_transformation_tests/test_make_query_plan.py b/graphql_compiler/tests/schema_transformation_tests/test_make_query_plan.py index 233c6540a..95e752b89 100644 --- a/graphql_compiler/tests/schema_transformation_tests/test_make_query_plan.py +++ b/graphql_compiler/tests/schema_transformation_tests/test_make_query_plan.py @@ -4,7 +4,7 @@ from graphql import parse, print_ast -from ...schema_transformation.make_query_plan import make_query_plan +from ...schema_transformation.query_plan import make_query_plan from ...schema_transformation.split_query import split_query from .example_schema import basic_merged_schema From 4dc336280e01e3f55fc106b212ef21a2b0e0f271 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski Date: Fri, 13 Sep 2019 10:01:07 -0400 Subject: [PATCH 4/8] Remove print statements. --- .../schema_transformation/query_plan.py | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/graphql_compiler/schema_transformation/query_plan.py b/graphql_compiler/schema_transformation/query_plan.py index 438d52539..27427e29f 100644 --- a/graphql_compiler/schema_transformation/query_plan.py +++ b/graphql_compiler/schema_transformation/query_plan.py @@ -269,8 +269,8 @@ def print_query_plan(query_plan_descriptor, indentation_depth=4): query_str = query_str.replace(u'\n', line_separation) query_plan_strings.append(query_str) - query_plan_strings.append(u'\n\nJoin together outputs as follows: ') - query_plan_strings.append(str([ + query_plan_strings.append(u'\n\nJoin together outputs as follows:\n') + query_plan_strings.append('\n'.join([ ' '.join([ str(descriptor.output_names), 'between subplan IDs', @@ -318,9 +318,6 @@ def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query subquery_graphql = print_ast(query_plan.query_ast) - print('\n\n********* BEGIN *********\n') - print(subquery_graphql) - # HACK(predrag): Add proper error checking for missing arguments here. # HACK(predrag): Don't bother running queries if the previous query's stitching outputs # returned no values to pass to the next query. @@ -329,15 +326,11 @@ def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query for argument_name in get_query_runtime_arguments(query_plan.query_ast) } - print(subquery_args) - # Run the query and save the results. execution_func = schema_id_to_execution_func[schema_id] subquery_result = execution_func(subquery_graphql, subquery_args) result_components_by_plan_id[plan_id] = subquery_result - print(subquery_result) - # Capture and record any values that will be used for stitching by other subqueries. child_extra_output_names = { # The .get() call is to handle the case of query plans with no children. @@ -367,9 +360,6 @@ def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query } full_query_args.update(new_query_args) - print(new_query_args) - print('\n********** END ***********\n') - join_indexes_by_plan_id = _make_join_indexes( query_plan_descriptor, result_components_by_plan_id) @@ -403,9 +393,6 @@ def _make_join_indexes(query_plan_descriptor, result_components_by_plan_id): def _make_join_index_for_output(results, join_output_name): """Return a dict of each value of the join column to a list of row indexes where it appears.""" - print('making join index on column ', join_output_name) - print(results) - join_index = {} for row_index, row in enumerate(results): join_value = row[join_output_name] From 4c3f303cc7f9c387853d3f2134a26b63d7ca16b8 Mon Sep 17 00:00:00 2001 From: Selene Chew Date: Wed, 20 Jan 2021 17:15:31 -0500 Subject: [PATCH 5/8] add type hints --- .../schema_transformation/query_plan.py | 170 +++++++++++++----- .../schema_transformation/utils.py | 44 +++-- 2 files changed, 155 insertions(+), 59 deletions(-) diff --git a/graphql_compiler/schema_transformation/query_plan.py b/graphql_compiler/schema_transformation/query_plan.py index cbf979a66..f3c821e46 100644 --- a/graphql_compiler/schema_transformation/query_plan.py +++ b/graphql_compiler/schema_transformation/query_plan.py @@ -1,7 +1,18 @@ # Copyright 2019-present Kensho Technologies, LLC. from copy import copy from dataclasses import dataclass -from typing import FrozenSet, List, Optional, Tuple, cast +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + List, + Mapping, + Optional, + Set, + Tuple, + cast, +) from graphql import print_ast from graphql.language.ast import ( @@ -21,20 +32,22 @@ from ..ast_manipulation import get_only_query_definition from ..exceptions import GraphQLValidationError from ..schema import FilterDirective, OutputDirective -from .utils import get_query_runtime_arguments - from .split_query import AstType, SubQueryNode +from .utils import get_query_runtime_arguments @dataclass class SubQueryPlan: """Query plan for a part of a larger query over a single schema.""" + # Unique identifier for this sub-plan. + plan_id: int + # Representing a piece of the overall query with directives added. query_ast: DocumentNode # Identifier for the schema that this query piece targets. - schema_id: Optional[str] + schema_id: str # The query that the current query depends on, or None if the current query does not # depend on another. @@ -53,6 +66,9 @@ class OutputJoinDescriptor: # should be made. output_names: Tuple[str, str] + # SubQueryPlan, the sub-plan node for which the join happens between it and its parent sub-plan. + child_query_plan: SubQueryPlan + @dataclass(frozen=True) class QueryPlanDescriptor: @@ -93,6 +109,11 @@ def make_query_plan( """ output_join_descriptors: List[OutputJoinDescriptor] = [] + if root_sub_query_node.schema_id is None: + raise AssertionError( + "Unreachable code reached. The schema id of root_sub_query_node " + f'"{root_sub_query_node.query_ast}" has not been determined.' + ) root_sub_query_plan = SubQueryPlan( plan_id=0, query_ast=root_sub_query_node.query_ast, @@ -114,7 +135,7 @@ def _make_query_plan_recursive( sub_query_node: SubQueryNode, sub_query_plan: SubQueryPlan, output_join_descriptors: List[OutputJoinDescriptor], - next_plan_id: int + next_plan_id: int, ) -> None: """Recursively copy the structure of sub_query_node onto sub_query_plan. @@ -128,7 +149,7 @@ def _make_query_plan_recursive( sub_query_plan: SubQueryPlan, whose list of child query plans and query AST are modified. output_join_descriptors: describing which outputs should be joined and how. - next_plan_id: int, the next available plan ID to use. IDs at and above this number are free. + next_plan_id: the next available plan ID to use. IDs at and above this number are free. """ # Iterate through child connections of query node for child_query_connection in sub_query_node.child_query_connections: @@ -153,6 +174,11 @@ def _make_query_plan_recursive( else: new_child_query_ast = DocumentNode(definitions=[child_query_type_with_filter]) + if child_sub_query_node.schema_id is None: + raise AssertionError( + "Unreachable code reached. The schema id of the child_sub_query_node " + f'"{child_sub_query_node.query_ast}" has not been determined.' + ) # Create new SubQueryPlan for child child_sub_query_plan = SubQueryPlan( plan_id=next_plan_id, @@ -308,22 +334,37 @@ def print_query_plan(query_plan_descriptor: QueryPlanDescriptor, indentation_dep query_plan_strings.append(line_separation) query_str = 'Execute subplan ID {} in schema named "{}":\n'.format( - query_plan.plan_id, query_plan.schema_id) + query_plan.plan_id, query_plan.schema_id + ) query_str += print_ast(query_plan.query_ast) query_str = query_str.replace("\n", line_separation) query_plan_strings.append(query_str) query_plan_strings.append("\n\nJoin together outputs as follows:\n") - query_plan_strings.append('\n'.join([ - " ".join([ - str(descriptor.output_names), - "between subplan IDs", - str([ - descriptor.child_query_plan.parent_query_plan.plan_id, - descriptor.child_query_plan.plan_id - ])]) - for descriptor in query_plan_descriptor.output_join_descriptors - ])) + for descriptor in query_plan_descriptor.output_join_descriptors: + if descriptor.child_query_plan.parent_query_plan is None: + raise AssertionError( + f"Invalid join descriptor. The parent plan was not set for child_query_plan " + f"{descriptor.child_query_plan}." + ) + query_plan_strings.append( + "\n".join( + [ + " ".join( + [ + str(descriptor.output_names), + "between subplan IDs", + str( + [ + descriptor.child_query_plan.parent_query_plan.plan_id, + descriptor.child_query_plan.plan_id, + ] + ), + ] + ) + ] + ) + ) query_plan_strings.append("\n\nRemove the following outputs at the end: ") query_plan_strings.append(str(query_plan_descriptor.intermediate_output_names) + "\n") @@ -333,7 +374,9 @@ def print_query_plan(query_plan_descriptor: QueryPlanDescriptor, indentation_dep def _get_plan_and_depth_in_dfs_order(query_plan: SubQueryPlan) -> List[Tuple[SubQueryPlan, int]]: """Return a list of topologically sorted (query plan, depth) tuples.""" - def _get_plan_and_depth_in_dfs_order_helper(query_plan, depth): + def _get_plan_and_depth_in_dfs_order_helper( + query_plan: SubQueryPlan, depth: int + ) -> List[Tuple[SubQueryPlan, int]]: plan_and_depth_in_dfs_order = [(query_plan, depth)] for child_query_plan in query_plan.child_query_plans: plan_and_depth_in_dfs_order.extend( @@ -344,15 +387,27 @@ def _get_plan_and_depth_in_dfs_order_helper(query_plan, depth): return _get_plan_and_depth_in_dfs_order_helper(query_plan, 0) -def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query_args): +def execute_query_plan( + schema_id_to_execution_func: Dict[ + str, Callable[[str, Mapping[str, Any]], List[Dict[str, Any]]] + ], + query_plan_descriptor: QueryPlanDescriptor, + query_args: Dict[str, Any], +) -> List[Dict[str, Any]]: """Execute the given query plan and return the produced results.""" result_components_by_plan_id = {} - stitching_output_names_by_parent_plan_id = dict() + stitching_output_names_by_parent_plan_id: Dict[int, List[Tuple[str, str]]] = dict() for join_descriptor in query_plan_descriptor.output_join_descriptors: + if join_descriptor.child_query_plan.parent_query_plan is None: + raise AssertionError( + f"Invalid join descriptor. The parent plan was not set for child_query_plan " + f"{join_descriptor.child_query_plan}." + ) parent_plan_id = join_descriptor.child_query_plan.parent_query_plan.plan_id stitching_output_names_by_parent_plan_id.setdefault(parent_plan_id, []).append( - join_descriptor.output_names) + join_descriptor.output_names + ) full_query_args = dict(query_args) @@ -361,6 +416,11 @@ def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query for query_plan, _ in plan_and_depth: plan_id = query_plan.plan_id schema_id = query_plan.schema_id + if schema_id is None: + raise AssertionError( + f'Unreachable code reached. The schema id of query piece "{query_plan.query_ast}" ' + f"has not been determined." + ) subquery_graphql = print_ast(query_plan.query_ast) @@ -384,7 +444,7 @@ def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query output_name for output_name, _ in stitching_output_names_by_parent_plan_id.get(plan_id, []) } - child_extra_output_values = { + child_extra_output_values: Dict[str, Set[Any]] = { # Make sure we deduplicate the values -- there's no point in running subqueries # with duplicated runtime argument values. output_name: set() @@ -407,39 +467,53 @@ def execute_query_plan(schema_id_to_execution_func, query_plan_descriptor, query full_query_args.update(new_query_args) join_indexes_by_plan_id = _make_join_indexes( - query_plan_descriptor, result_components_by_plan_id) + query_plan_descriptor, result_components_by_plan_id + ) joined_results = _join_results( - result_components_by_plan_id, join_indexes_by_plan_id, + result_components_by_plan_id, + join_indexes_by_plan_id, result_components_by_plan_id[query_plan_descriptor.root_sub_query_plan.plan_id], - query_plan_descriptor.output_join_descriptors) + query_plan_descriptor.output_join_descriptors, + ) return _drop_intermediate_outputs( - query_plan_descriptor.intermediate_output_names, joined_results) + query_plan_descriptor.intermediate_output_names, joined_results + ) -def _make_join_indexes(query_plan_descriptor, result_components_by_plan_id): +def _make_join_indexes( + query_plan_descriptor: QueryPlanDescriptor, + result_components_by_plan_id: Dict[int, List[Dict[str, Any]]], +) -> Dict[int, Dict[str, List[int]]]: """Return a dict from child plan id to a join index between its and its parents' rows.""" - join_indexes_by_plan_id = dict() + join_indexes_by_plan_id: Dict[int, Dict[str, List[int]]] = dict() for join_descriptor in query_plan_descriptor.output_join_descriptors: child_plan_id = join_descriptor.child_query_plan.plan_id _, child_output_name = join_descriptor.output_names if child_plan_id in join_indexes_by_plan_id: - raise AssertionError('Unreachable code reached: {} {} {}' - .format(child_plan_id, join_indexes_by_plan_id, - query_plan_descriptor.output_join_descriptors)) + raise AssertionError( + "Unreachable code reached: {} {} {}".format( + child_plan_id, + join_indexes_by_plan_id, + query_plan_descriptor.output_join_descriptors, + ) + ) join_indexes_by_plan_id[child_plan_id] = _make_join_index_for_output( - result_components_by_plan_id[child_plan_id], child_output_name) + result_components_by_plan_id[child_plan_id], child_output_name + ) return join_indexes_by_plan_id -def _make_join_index_for_output(results, join_output_name): +def _make_join_index_for_output( + results: List[Dict[str, Any]], join_output_name: str +) -> Dict[str, List[int]]: """Return a dict of each value of the join column to a list of row indexes where it appears.""" - join_index = {} + join_index: Dict[str, List[int]] = {} for row_index, row in enumerate(results): join_value = row[join_output_name] join_index.setdefault(join_value, []).append(row_index) @@ -447,8 +521,12 @@ def _make_join_index_for_output(results, join_output_name): return join_index -def _join_results(result_components_by_plan_id, join_indexes_by_plan_id, - current_results, join_descriptors): +def _join_results( + result_components_by_plan_id: Dict[int, List[Dict[str, Any]]], + join_indexes_by_plan_id: Dict[int, Dict[str, List[int]]], + current_results: List[Dict[str, Any]], + join_descriptors: List[OutputJoinDescriptor], +) -> List[Dict[str, Any]]: """Return the merged results across all subplans using the calculated join indexes.""" if len(join_descriptors) == 0: # No further joining to be done! @@ -474,19 +552,23 @@ def _join_results(result_components_by_plan_id, join_indexes_by_plan_id, joining_row = joining_results[join_matched_index] next_results.append(dict(current_row, **joining_row)) - return _join_results(result_components_by_plan_id, join_indexes_by_plan_id, - next_results, remaining_join_descriptors) + return _join_results( + result_components_by_plan_id, + join_indexes_by_plan_id, + next_results, + remaining_join_descriptors, + ) -def _drop_intermediate_outputs(columns_to_drop, results): +def _drop_intermediate_outputs( + columns_to_drop: FrozenSet[str], results: List[Dict[str, Any]] +) -> List[Dict[str, Any]]: """Return the provided results with the specified column names dropped.""" processed_results = [] for row in results: - processed_results.append({ - key: value - for key, value in row.items() - if key not in columns_to_drop - }) + processed_results.append( + {key: value for key, value in row.items() if key not in columns_to_drop} + ) return processed_results diff --git a/graphql_compiler/schema_transformation/utils.py b/graphql_compiler/schema_transformation/utils.py index 3c874ad72..4ac40e967 100644 --- a/graphql_compiler/schema_transformation/utils.py +++ b/graphql_compiler/schema_transformation/utils.py @@ -13,6 +13,7 @@ FragmentSpreadNode, InlineFragmentNode, InterfaceTypeDefinitionNode, + ListValueNode, NamedTypeNode, NameNode, Node, @@ -20,6 +21,7 @@ ScalarTypeDefinitionNode, SelectionNode, SelectionSetNode, + StringValueNode, UnionTypeDefinitionNode, ) from graphql.language.visitor import Visitor, visit @@ -30,10 +32,12 @@ from ..ast_manipulation import get_ast_with_non_null_and_list_stripped from ..compiler.helpers import ( - get_parameter_name, get_uniquely_named_objects_by_name, is_runtime_parameter + get_parameter_name, + get_uniquely_named_objects_by_name, + is_runtime_parameter, ) from ..exceptions import GraphQLError, GraphQLValidationError -from ..schema import FilterDirective, OptionalDirective, OutputDirective, RecurseDirective +from ..schema import FilterDirective, OptionalDirective, OutputDirective class SchemaTransformError(GraphQLError): @@ -695,7 +699,7 @@ class CheckQueryIsValidToSplitVisitor(Visitor): ) ) - def __init__(self, strict): + def __init__(self, strict: bool): """Initialize the visitor with the appropriate strictness setting.""" super(CheckQueryIsValidToSplitVisitor, self).__init__() self.strict = strict @@ -769,7 +773,7 @@ def enter_selection_set( def check_query_is_valid_to_split( - schema: GraphQLSchema, query_ast: DocumentNode, strict=True + schema: GraphQLSchema, query_ast: DocumentNode, strict: bool = True ) -> None: """Check the query is valid for splitting. @@ -801,30 +805,40 @@ def check_query_is_valid_to_split( class QueryRuntimeArgumentsVisitor(Visitor): """Visitor that collects runtime argument names from @filter directives.""" - def __init__(self): + def __init__(self) -> None: """Initialize the visitor.""" super(QueryRuntimeArgumentsVisitor, self).__init__() - self.runtime_arguments = set() + self.runtime_arguments: Set[str] = set() - def enter_Directive(self, node, *args): + def enter_directive( + self, node: DirectiveNode, key: Any, parent: Any, path: List[Any], ancestors: List[Any] + ) -> None: """Check that the directive is supported.""" if node.name.value != FilterDirective.name: return directive_arguments = get_uniquely_named_objects_by_name(node.arguments) - entry_names = [ - list_element.value - for list_element in directive_arguments['value'].value.values - ] + directive_argument_list = directive_arguments["value"].value + if not isinstance(directive_argument_list, ListValueNode): + raise AssertionError( + "Unreachable code reached. Expected directive_argument_list to be of type " + f"ListValueNode, but was of type {type(directive_argument_list)}." + ) + entry_names: List[str] = [] + for list_element in directive_argument_list.values: + if not isinstance(list_element, StringValueNode): + raise AssertionError( + f"Expected directive arguments to be StringValueNode, but received " + f"{list_element} with type {type(list_element)}." + ) + entry_names.append(list_element.value) self.runtime_arguments.update( - get_parameter_name(name) - for name in entry_names - if is_runtime_parameter(name) + get_parameter_name(name) for name in entry_names if is_runtime_parameter(name) ) -def get_query_runtime_arguments(query_ast): +def get_query_runtime_arguments(query_ast: DocumentNode) -> Set[str]: """Return a set containing the names of the runtime arguments required by the query.""" visitor = QueryRuntimeArgumentsVisitor() visit(query_ast, visitor) From bb5c35564d472f3673ea8b19b1f8b3052b647f86 Mon Sep 17 00:00:00 2001 From: Selene Chew Date: Mon, 25 Jan 2021 09:19:17 -0500 Subject: [PATCH 6/8] lint --- .../schema_transformation/query_plan.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/graphql_compiler/schema_transformation/query_plan.py b/graphql_compiler/schema_transformation/query_plan.py index f3c821e46..4fe45956f 100644 --- a/graphql_compiler/schema_transformation/query_plan.py +++ b/graphql_compiler/schema_transformation/query_plan.py @@ -1,18 +1,7 @@ # Copyright 2019-present Kensho Technologies, LLC. from copy import copy from dataclasses import dataclass -from typing import ( - Any, - Callable, - Dict, - FrozenSet, - List, - Mapping, - Optional, - Set, - Tuple, - cast, -) +from typing import Any, Callable, Dict, FrozenSet, List, Mapping, Optional, Set, Tuple, cast from graphql import print_ast from graphql.language.ast import ( From ac5cc521a327d495fcbb2c567e10762033eca49c Mon Sep 17 00:00:00 2001 From: Selene Chew Date: Mon, 25 Jan 2021 09:51:11 -0500 Subject: [PATCH 7/8] remove unused variable --- graphql_compiler/schema_transformation/query_plan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphql_compiler/schema_transformation/query_plan.py b/graphql_compiler/schema_transformation/query_plan.py index 4fe45956f..22330d1b1 100644 --- a/graphql_compiler/schema_transformation/query_plan.py +++ b/graphql_compiler/schema_transformation/query_plan.py @@ -529,7 +529,7 @@ def _join_results( join_plan_id = next_join_descriptor.child_query_plan.plan_id join_index = join_indexes_by_plan_id[join_plan_id] joining_results = result_components_by_plan_id[join_plan_id] - join_from_key, join_to_key = next_join_descriptor.output_names + join_from_key, _ = next_join_descriptor.output_names for current_row in current_results: join_value = current_row[join_from_key] From 3551b46dfeb15023d51a07513ac7eadc8cfca6e8 Mon Sep 17 00:00:00 2001 From: Selene Chew Date: Mon, 25 Jan 2021 10:07:11 -0500 Subject: [PATCH 8/8] typing copilot tighten --- mypy.ini | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mypy.ini b/mypy.ini index ad086efcb..b04b52c58 100644 --- a/mypy.ini +++ b/mypy.ini @@ -190,9 +190,6 @@ check_untyped_defs = False [mypy-graphql_compiler.schema_transformation.*] disallow_untyped_calls = False -[mypy-graphql_compiler.schema_transformation.make_query_plan.*] -disallow_untyped_defs = False - [mypy-graphql_compiler.schema_transformation.split_query.*] disallow_incomplete_defs = False disallow_untyped_defs = False @@ -344,9 +341,6 @@ disallow_untyped_calls = False # Third-party module rule relaxations -[mypy-arrow.*] -ignore_missing_imports = True - [mypy-funcy.*] ignore_missing_imports = True