diff --git a/go.mod b/go.mod index 399b200d..c40b4e84 100644 --- a/go.mod +++ b/go.mod @@ -1 +1,3 @@ module github.com/graphql-go/graphql + +go 1.14 diff --git a/graphql.go b/graphql.go index 2b1f6a29..def7c86c 100644 --- a/graphql.go +++ b/graphql.go @@ -28,6 +28,12 @@ type Params struct { // one operation. OperationName string + // ValidationRules is for overriding rules of document validation. Default + // SpecifiedRules are ignored when specified this option other than nil. + // So it would be better that combining your rules with SpecifiedRules to + // fill this. + ValidationRules []ValidationRuleFn + // Context may be provided to pass application-specific per-request // information to resolve functions. Context context.Context @@ -84,7 +90,7 @@ func Do(p Params) *Result { } // validate document - validationResult := ValidateDocument(&p.Schema, AST, nil) + validationResult := ValidateDocument(&p.Schema, AST, p.ValidationRules) if !validationResult.IsValid { // run validation finish functions for extensions diff --git a/graphql_test.go b/graphql_test.go index 8b06a7b1..f370cfc7 100644 --- a/graphql_test.go +++ b/graphql_test.go @@ -6,6 +6,9 @@ import ( "testing" "github.com/graphql-go/graphql" + "github.com/graphql-go/graphql/language/ast" + "github.com/graphql-go/graphql/language/kinds" + "github.com/graphql-go/graphql/language/visitor" "github.com/graphql-go/graphql/testutil" ) @@ -268,3 +271,59 @@ func TestEmptyStringIsNotNull(t *testing.T) { t.Errorf("wrong result, query: %v, graphql result diff: %v", query, testutil.Diff(expected, result)) } } + +func TestQueryWithCustomRule(t *testing.T) { + // Test graphql.Do() with custom rule, it extracts query name from each + // Tests. + ruleN := len(graphql.SpecifiedRules) + rules := make([]graphql.ValidationRuleFn, ruleN+1) + copy(rules[:ruleN], graphql.SpecifiedRules) + + var ( + queryFound bool + queryName string + ) + rules[ruleN] = func(context *graphql.ValidationContext) *graphql.ValidationRuleInstance { + return &graphql.ValidationRuleInstance{ + VisitorOpts: &visitor.VisitorOptions{ + KindFuncMap: map[string]visitor.NamedVisitFuncs{ + kinds.OperationDefinition: { + Enter: func(p visitor.VisitFuncParams) (string, interface{}) { + od, ok := p.Node.(*ast.OperationDefinition) + if ok && od.Operation == "query" { + queryFound = true + if od.Name != nil { + queryName = od.Name.Value + } + } + return visitor.ActionNoChange, nil + }, + }, + }, + }, + } + } + + expectedNames := []string{ + "HeroNameQuery", + "HeroNameAndFriendsQuery", + "HumanByIdQuery", + } + + for i, test := range Tests { + queryFound, queryName = false, "" + params := graphql.Params{ + Schema: test.Schema, + RequestString: test.Query, + VariableValues: test.Variables, + ValidationRules: rules, + } + testGraphql(test, params, t) + if !queryFound { + t.Fatal("can't detect \"query\" operation by validation rule") + } + if queryName != expectedNames[i] { + t.Fatalf("unexpected query name: want=%s got=%s", queryName, expectedNames) + } + } +}