From ad1ad165ee0f631a3e3077d47e8b1efae7e17a68 Mon Sep 17 00:00:00 2001 From: xxlv Date: Wed, 3 Jul 2024 10:19:13 +0800 Subject: [PATCH 1/4] Add support for null values in input types --- language/ast/values.go | 34 ++++++++++++++++++++++++ language/kinds/kinds.go | 1 + language/lexer/lexer.go | 2 ++ language/parser/parser.go | 11 +++++++- language/parser/parser_test.go | 9 ------- language/printer/printer.go | 9 ++++++- language/printer/printer_test.go | 14 ++++++++++ rules.go | 7 ++++- rules_arguments_of_correct_type_test.go | 35 +++++++++++++++++++++++++ scalars.go | 4 +++ scalars_test.go | 12 +++++++++ testutil/rules_test_harness.go | 2 ++ validator_test.go | 1 + 13 files changed, 129 insertions(+), 12 deletions(-) diff --git a/language/ast/values.go b/language/ast/values.go index 6c3c88640..f02e1c160 100644 --- a/language/ast/values.go +++ b/language/ast/values.go @@ -19,6 +19,7 @@ var _ Value = (*BooleanValue)(nil) var _ Value = (*EnumValue)(nil) var _ Value = (*ListValue)(nil) var _ Value = (*ObjectValue)(nil) +var _ Value = (*NullValue)(nil) // Variable implements Node, Value type Variable struct { @@ -202,6 +203,39 @@ func (v *EnumValue) GetValue() interface{} { return v.Value } +// NullValue represents the GraphQL null value. +// +// It is used to support passing null as an input value. +// +// Reference: https://spec.graphql.org/October2021/#sec-Null-Value +type NullValue struct { + Kind string + Loc *Location + Value interface{} +} + +func NewNullValue(v *NullValue) *NullValue { + if v == nil { + v = &NullValue{} + } + return &NullValue{ + Kind: kinds.NullValue, + Loc: v.Loc, + Value: nil, + } +} +func (n *NullValue) GetKind() string { + return n.Kind +} + +func (n *NullValue) GetLoc() *Location { + return n.Loc +} + +func (n *NullValue) GetValue() interface{} { + return n.Value +} + // ListValue implements Node, Value type ListValue struct { Kind string diff --git a/language/kinds/kinds.go b/language/kinds/kinds.go index 40bc994eb..d5be9ed89 100644 --- a/language/kinds/kinds.go +++ b/language/kinds/kinds.go @@ -27,6 +27,7 @@ const ( ListValue = "ListValue" ObjectValue = "ObjectValue" ObjectField = "ObjectField" + NullValue = "NullValue" // Directives Directive = "Directive" diff --git a/language/lexer/lexer.go b/language/lexer/lexer.go index 1988c5fdc..a50c335fc 100644 --- a/language/lexer/lexer.go +++ b/language/lexer/lexer.go @@ -34,6 +34,7 @@ const ( STRING BLOCK_STRING AMP + NULL ) var tokenDescription = map[TokenKind]string{ @@ -57,6 +58,7 @@ var tokenDescription = map[TokenKind]string{ STRING: "String", BLOCK_STRING: "BlockString", AMP: "&", + NULL: "null", } func (kind TokenKind) String() string { diff --git a/language/parser/parser.go b/language/parser/parser.go index 4ae3dc335..0e8bc74ae 100644 --- a/language/parser/parser.go +++ b/language/parser/parser.go @@ -614,6 +614,14 @@ func parseValueLiteral(parser *Parser, isConst bool) (ast.Value, error) { Value: token.Value, Loc: loc(parser, token.Start), }), nil + } else { + // If the value literal in the GraphQL input is `null`, converts it into a NullValue AST node. + if err := advance(parser); err != nil { + return nil, err + } + return ast.NewNullValue(&ast.NullValue{ + Loc: loc(parser, token.Start), + }), nil } case lexer.DOLLAR: if !isConst { @@ -1562,7 +1570,8 @@ func unexpectedEmpty(parser *Parser, beginLoc int, openKind, closeKind lexer.Tok return gqlerrors.NewSyntaxError(parser.Source, beginLoc, description) } -// Returns list of parse nodes, determined by +// Returns list of parse nodes, determined by +// // the parseFn. This list begins with a lex token of openKind // and ends with a lex token of closeKind. Advances the parser // to the next lex token after the closing token. diff --git a/language/parser/parser_test.go b/language/parser/parser_test.go index 8f0e0715d..062dd5772 100644 --- a/language/parser/parser_test.go +++ b/language/parser/parser_test.go @@ -183,15 +183,6 @@ func TestDoesNotAcceptFragmentsSpreadOfOn(t *testing.T) { testErrorMessage(t, test) } -func TestDoesNotAllowNullAsValue(t *testing.T) { - test := errorMessageTest{ - `{ fieldWithNullableStringInput(input: null) }'`, - `Syntax Error GraphQL (1:39) Unexpected Name "null"`, - false, - } - testErrorMessage(t, test) -} - func TestParsesMultiByteCharacters_Unicode(t *testing.T) { doc := ` diff --git a/language/printer/printer.go b/language/printer/printer.go index ac771ba60..43ba45c40 100644 --- a/language/printer/printer.go +++ b/language/printer/printer.go @@ -8,6 +8,7 @@ import ( "reflect" "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/language/lexer" "github.com/graphql-go/graphql/language/visitor" ) @@ -472,7 +473,13 @@ var printDocASTReducer = map[string]visitor.VisitFunc{ } return visitor.ActionNoChange, nil }, - + "NullValue": func(p visitor.VisitFuncParams) (string, interface{}) { + switch p.Node.(type) { + case *ast.NullValue: + return visitor.ActionUpdate, lexer.NULL.String() + } + return visitor.ActionNoChange, nil + }, // Type System Definitions "SchemaDefinition": func(p visitor.VisitFuncParams) (string, interface{}) { switch node := p.Node.(type) { diff --git a/language/printer/printer_test.go b/language/printer/printer_test.go index b6d7de7d6..bf06c59f4 100644 --- a/language/printer/printer_test.go +++ b/language/printer/printer_test.go @@ -200,3 +200,17 @@ func TestPrinter_CorrectlyPrintsStringArgumentsWithProperQuoting(t *testing.T) { t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, results)) } } + +func TestPrinter_CorrectlyPrintsNullArgumentsWithProperQuoting(t *testing.T) { + queryAst := `query { foo(nullArg: null) }` + expected := `{ + foo(nullArg: null) +} +` + astDoc := parse(t, queryAst) + results := printer.Print(astDoc) + + if !reflect.DeepEqual(expected, results) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, results)) + } +} diff --git a/rules.go b/rules.go index ae0c75b9d..95b918a4f 100644 --- a/rules.go +++ b/rules.go @@ -1735,6 +1735,11 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { if valueAST.GetKind() == kinds.Variable { return true, nil } + // Supplying a nullable variable type to a non-null input type is considered invalid. + // nullValue is only valid for nullable input types. + if valueAST.GetKind() == kinds.NullValue { + return true, nil + } } switch ttype := ttype.(type) { case *NonNull: @@ -1742,7 +1747,7 @@ func isValidLiteralValue(ttype Input, valueAST ast.Value) (bool, []string) { if e := ttype.Error(); e != nil { return false, []string{e.Error()} } - if valueAST == nil { + if valueAST == nil || valueAST.GetKind() == kinds.NullValue { if ttype.OfType.Name() != "" { return false, []string{fmt.Sprintf(`Expected "%v!", found null.`, ttype.OfType.Name())} } diff --git a/rules_arguments_of_correct_type_test.go b/rules_arguments_of_correct_type_test.go index ecd4bea4f..b27b78c11 100644 --- a/rules_arguments_of_correct_type_test.go +++ b/rules_arguments_of_correct_type_test.go @@ -8,6 +8,41 @@ import ( "github.com/graphql-go/graphql/testutil" ) +func TestValidate_ArgValuesOfCorrectType_ValidValue_GoodNullValue(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.ArgumentsOfCorrectTypeRule, ` + { + complicatedArgs { + intArgField(intArg: null) + } + } + `) +} + +func TestValidator_NonNullArgsUsingNullValue(t *testing.T) { + testutil.ExpectFailsRule(t, graphql.ArgumentsOfCorrectTypeRule, ` + { + complicatedArgs { + nonNullIntArgField(nonNullIntArg: null) + } + } + `, []gqlerrors.FormattedError{ + testutil.RuleError( + "Argument \"nonNullIntArg\" has invalid value null.\nExpected \"Int!\", found null.", + 4, 47, + ), + }) +} + +func TestValidator_NullArgsUsingNullValue(t *testing.T) { + testutil.ExpectPassesRule(t, graphql.ArgumentsOfCorrectTypeRule, ` + { + complicatedArgs { + stringArgField(stringArg: null) + } + } + `) +} + func TestValidate_ArgValuesOfCorrectType_ValidValue_GoodIntValue(t *testing.T) { testutil.ExpectPassesRule(t, graphql.ArgumentsOfCorrectTypeRule, ` { diff --git a/scalars.go b/scalars.go index 45479b545..68fdbf394 100644 --- a/scalars.go +++ b/scalars.go @@ -163,6 +163,7 @@ var Int = NewScalar(ScalarConfig{ return intValue } } + return nil }, }) @@ -332,6 +333,9 @@ var String = NewScalar(ScalarConfig{ }) func coerceBool(value interface{}) interface{} { + if value == nil { + return nil + } switch value := value.(type) { case bool: return value diff --git a/scalars_test.go b/scalars_test.go index 26987e5ce..aec04ffdf 100644 --- a/scalars_test.go +++ b/scalars_test.go @@ -240,6 +240,10 @@ func TestCoerceInt(t *testing.T) { in: make(map[string]interface{}), want: nil, }, + { + in: nil, + want: nil, + }, } for i, tt := range tests { @@ -438,6 +442,10 @@ func TestCoerceFloat(t *testing.T) { in: make(map[string]interface{}), want: nil, }, + { + in: nil, + want: nil, + }, } for i, tt := range tests { @@ -740,6 +748,10 @@ func TestCoerceBool(t *testing.T) { in: make(map[string]interface{}), want: false, }, + { + in: nil, + want: nil, + }, } for i, tt := range tests { diff --git a/testutil/rules_test_harness.go b/testutil/rules_test_harness.go index 384f447e7..ba2d87e8a 100644 --- a/testutil/rules_test_harness.go +++ b/testutil/rules_test_harness.go @@ -563,6 +563,7 @@ func expectValidRule(t *testing.T, schema *graphql.Schema, rules []graphql.Valid } func expectInvalidRule(t *testing.T, schema *graphql.Schema, rules []graphql.ValidationRuleFn, queryString string, expectedErrors []gqlerrors.FormattedError) { + t.Helper() source := source.NewSource(&source.Source{ Body: []byte(queryString), }) @@ -595,6 +596,7 @@ func ExpectPassesRule(t *testing.T, rule graphql.ValidationRuleFn, queryString s expectValidRule(t, TestSchema, []graphql.ValidationRuleFn{rule}, queryString) } func ExpectFailsRule(t *testing.T, rule graphql.ValidationRuleFn, queryString string, expectedErrors []gqlerrors.FormattedError) { + t.Helper() expectInvalidRule(t, TestSchema, []graphql.ValidationRuleFn{rule}, queryString, expectedErrors) } func ExpectFailsRuleWithSchema(t *testing.T, schema *graphql.Schema, rule graphql.ValidationRuleFn, queryString string, expectedErrors []gqlerrors.FormattedError) { diff --git a/validator_test.go b/validator_test.go index 6eaf00052..b390cbef4 100644 --- a/validator_test.go +++ b/validator_test.go @@ -45,6 +45,7 @@ func TestValidator_SupportsFullValidation_ValidatesQueries(t *testing.T) { `) } + // NOTE: experimental func TestValidator_SupportsFullValidation_ValidatesUsingACustomTypeInfo(t *testing.T) { From d0c80278c4b8b6cb1ff63b6b3c5f67e67af5834b Mon Sep 17 00:00:00 2001 From: xxlv Date: Wed, 3 Jul 2024 10:41:57 +0800 Subject: [PATCH 2/4] Fix typo --- language/printer/printer_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/language/printer/printer_test.go b/language/printer/printer_test.go index bf06c59f4..5b8e064f8 100644 --- a/language/printer/printer_test.go +++ b/language/printer/printer_test.go @@ -201,7 +201,7 @@ func TestPrinter_CorrectlyPrintsStringArgumentsWithProperQuoting(t *testing.T) { } } -func TestPrinter_CorrectlyPrintsNullArgumentsWithProperQuoting(t *testing.T) { +func TestPrinter_CorrectlyPrintsNullArguments(t *testing.T) { queryAst := `query { foo(nullArg: null) }` expected := `{ foo(nullArg: null) From c5ca6e6f198f083125b0d51c117bb2dd0111b050 Mon Sep 17 00:00:00 2001 From: xxlv Date: Wed, 3 Jul 2024 10:50:14 +0800 Subject: [PATCH 3/4] Remove t.helper() for old go version --- testutil/rules_test_harness.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/testutil/rules_test_harness.go b/testutil/rules_test_harness.go index ba2d87e8a..384f447e7 100644 --- a/testutil/rules_test_harness.go +++ b/testutil/rules_test_harness.go @@ -563,7 +563,6 @@ func expectValidRule(t *testing.T, schema *graphql.Schema, rules []graphql.Valid } func expectInvalidRule(t *testing.T, schema *graphql.Schema, rules []graphql.ValidationRuleFn, queryString string, expectedErrors []gqlerrors.FormattedError) { - t.Helper() source := source.NewSource(&source.Source{ Body: []byte(queryString), }) @@ -596,7 +595,6 @@ func ExpectPassesRule(t *testing.T, rule graphql.ValidationRuleFn, queryString s expectValidRule(t, TestSchema, []graphql.ValidationRuleFn{rule}, queryString) } func ExpectFailsRule(t *testing.T, rule graphql.ValidationRuleFn, queryString string, expectedErrors []gqlerrors.FormattedError) { - t.Helper() expectInvalidRule(t, TestSchema, []graphql.ValidationRuleFn{rule}, queryString, expectedErrors) } func ExpectFailsRuleWithSchema(t *testing.T, schema *graphql.Schema, rule graphql.ValidationRuleFn, queryString string, expectedErrors []gqlerrors.FormattedError) { From 72d7caaa6197c0348c9455e4d75c8b94cb0ec198 Mon Sep 17 00:00:00 2001 From: xxlv Date: Fri, 5 Jul 2024 12:50:20 +0800 Subject: [PATCH 4/4] Fix list null --- executor_test.go | 76 +++++++++++++++++++++++++++++++++++++++++++++++ values.go | 30 +++++++++++++++---- variables_test.go | 33 ++++++++++++++++++++ 3 files changed, 133 insertions(+), 6 deletions(-) diff --git a/executor_test.go b/executor_test.go index 856aadf3e..236fa3bdd 100644 --- a/executor_test.go +++ b/executor_test.go @@ -403,6 +403,82 @@ func TestThreadsSourceCorrectly(t *testing.T) { } } +func TestCorrectlyListArgumentsWithNull(t *testing.T) { + query := ` + query Example { + b(listStringArg: null, listBoolArg: [true,false,null],listIntArg:[123,null,12],listStringNonNullArg:[null]) + } + ` + var resolvedArgs map[string]interface{} + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: graphql.NewObject(graphql.ObjectConfig{ + Name: "Type", + Fields: graphql.Fields{ + "b": &graphql.Field{ + Args: graphql.FieldConfigArgument{ + "listStringArg": &graphql.ArgumentConfig{ + Type: graphql.NewList(graphql.String), + }, + "listStringNonNullArg": &graphql.ArgumentConfig{ + Type: graphql.NewNonNull(graphql.NewList(graphql.String)), + }, + "listBoolArg": &graphql.ArgumentConfig{ + Type: graphql.NewList(graphql.Boolean), + }, + "listIntArg": &graphql.ArgumentConfig{ + Type: graphql.NewList(graphql.Int), + }, + }, + Type: graphql.String, + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + resolvedArgs = p.Args + return resolvedArgs, nil + }, + }, + }, + }), + }) + if err != nil { + t.Fatalf("Error in schema %v", err.Error()) + } + ast := testutil.TestParse(t, query) + + ep := graphql.ExecuteParams{ + Schema: schema, + AST: ast, + } + result := testutil.TestExecute(t, ep) + if len(result.Errors) > 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + tests := []struct { + key string + expected interface{} + }{ + { + "listStringArg", nil, + }, + + { + "listStringNonNullArg", []interface{}{nil}, + }, + + { + "listBoolArg", []interface{}{true, false, nil}, + }, + + { + "listIntArg", []interface{}{123, nil, 12}, + }, + } + for _, tt := range tests { + t.Run(fmt.Sprintf("TestCorrectlyListArgumentsWithNull_%s", tt.key), func(t *testing.T) { + if !reflect.DeepEqual(resolvedArgs[tt.key], tt.expected) { + t.Fatalf("Expected args.%s to equal `%v`, got `%v`", tt.key, tt.expected, resolvedArgs[tt.key]) + } + }) + } +} func TestCorrectlyThreadsArguments(t *testing.T) { query := ` diff --git a/values.go b/values.go index 06c08af6e..c1b20c7ec 100644 --- a/values.go +++ b/values.go @@ -57,16 +57,34 @@ func getArgumentValues( if tmpValue, ok := argASTMap[argDef.PrivateName]; ok { value = tmpValue.Value } - if tmp = valueFromAST(value, argDef.Type, variableValues); isNullish(tmp) { - tmp = argDef.DefaultValue - } - if !isNullish(tmp) { - results[argDef.PrivateName] = tmp + // if ast value is NullValue, and keep args's key + if value != nil && value.GetKind() == kinds.NullValue { + results[argDef.PrivateName] = nil + } else { + if tmp = valueFromAST(value, argDef.Type, variableValues); isNullish(tmp) { + tmp = argDef.DefaultValue + } + if !isNullish(tmp) { + results[argDef.PrivateName] = tmp + } else { + if nullValueWithVairableProvided(value, argDef.PrivateName, variableValues) { + results[argDef.PrivateName] = nil + } + } } } return results } +func nullValueWithVairableProvided(valueAST ast.Value, key string, variables map[string]interface{}) bool { + if valueAST != nil && valueAST.GetKind() == kinds.Variable { + if _, ok := variables[key]; ok { + return true + } + } + return false +} + // Given a variable definition, and any value of input, return a value which // adheres to the variable definition, or throw an error. func getVariableValue(schema Schema, definitionAST *ast.VariableDefinition, input interface{}) (interface{}, error) { @@ -349,7 +367,7 @@ func isIterable(src interface{}) bool { * */ func valueFromAST(valueAST ast.Value, ttype Input, variables map[string]interface{}) interface{} { - if valueAST == nil { + if valueAST == nil || valueAST.GetKind() == kinds.NullValue { return nil } // precedence: value > type diff --git a/variables_test.go b/variables_test.go index 9dc430df1..54ed6c368 100644 --- a/variables_test.go +++ b/variables_test.go @@ -70,6 +70,9 @@ func inputResolved(p graphql.ResolveParams) (interface{}, error) { if !ok { return nil, nil } + if input == nil { + return nil, nil + } b, err := json.Marshal(input) if err != nil { return nil, nil @@ -960,6 +963,36 @@ func TestVariables_ListsAndNullability_AllowsListsToBeNull(t *testing.T) { t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) } } + +func TestVariables_ListsAndNullability_AllowsListsToBeNullWithMoreListValues(t *testing.T) { + doc := ` + query q($input: [String]) { + list(input: $input) + } + ` + params := map[string]interface{}{ + "input": []interface{}{nil, "ok", nil}, + } + + expected := &graphql.Result{ + Data: map[string]interface{}{ + "list": `[null,"ok",null]`, + }, + } + ast := testutil.TestParse(t, doc) + ep := graphql.ExecuteParams{ + Schema: variablesTestSchema, + AST: ast, + Args: params, + } + result := testutil.TestExecute(t, ep) + if len(result.Errors) > 0 { + t.Fatalf("wrong result, unexpected errors: %v", result.Errors) + } + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Unexpected result, Diff: %v", testutil.Diff(expected, result)) + } +} func TestVariables_ListsAndNullability_AllowsListsToContainValues(t *testing.T) { doc := ` query q($input: [String]) {