diff --git a/executor.go b/executor.go index d36db111..a7a7535d 100644 --- a/executor.go +++ b/executor.go @@ -821,6 +821,12 @@ func defaultResolveTypeFn(p ResolveTypeParams, abstractType Abstract) *Object { return nil } +// FieldResolver is used in DefaultResolveFn when the the source value implements this interface. +type FieldResolver interface { + // Resolve resolves the value for the given ResolveParams. It has the same semantics as FieldResolveFn. + Resolve(p ResolveParams) (interface{}, error) +} + // defaultResolveFn If a resolve function is not given, then a default resolve behavior is used // which takes the property of the source object of the same name as the field // and returns it as the result, or if it's a function, returns the result @@ -834,6 +840,12 @@ func DefaultResolveFn(p ResolveParams) (interface{}, error) { if !sourceVal.IsValid() { return nil, nil } + + // Check if value implements 'Resolver' interface + if resolver, ok := sourceVal.Interface().(FieldResolver); ok { + return resolver.Resolve(p) + } + if sourceVal.Type().Kind() == reflect.Struct { for i := 0; i < sourceVal.NumField(); i++ { valueField := sourceVal.Field(i) diff --git a/executor_test.go b/executor_test.go index 2c898f88..954d6d30 100644 --- a/executor_test.go +++ b/executor_test.go @@ -1691,6 +1691,70 @@ func TestGraphqlTag(t *testing.T) { } } +func TestFieldResolver(t *testing.T) { + typeObjectType := graphql.NewObject(graphql.ObjectConfig{ + Name: "Type", + Fields: graphql.Fields{ + "fooBar": &graphql.Field{Type: graphql.String}, + }, + }) + var baz = &graphql.Field{ + Type: typeObjectType, + Description: "typeObjectType", + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return testCustomResolver{}, nil + }, + } + var bazPtr = &graphql.Field{ + Type: typeObjectType, + Description: "typeObjectType", + Resolve: func(p graphql.ResolveParams) (interface{}, error) { + return &testCustomResolver{}, nil + }, + } + q := graphql.NewObject(graphql.ObjectConfig{ + Name: "Query", + Fields: graphql.Fields{ + "baz": baz, + "bazPtr": bazPtr, + }, + }) + schema, err := graphql.NewSchema(graphql.SchemaConfig{ + Query: q, + }) + if err != nil { + t.Fatalf("unexpected error, got: %v", err) + } + query := "{ baz { fooBar }, bazPtr { fooBar } }" + result := graphql.Do(graphql.Params{ + Schema: schema, + RequestString: query, + }) + if len(result.Errors) != 0 { + t.Fatalf("wrong result, unexpected errors: %+v", result.Errors) + } + expectedData := map[string]interface{}{ + "baz": map[string]interface{}{ + "fooBar": "foo bar value", + }, + "bazPtr": map[string]interface{}{ + "fooBar": "foo bar value", + }, + } + if !reflect.DeepEqual(result.Data, expectedData) { + t.Fatalf("unexpected result, got: %+v, expected: %+v", result.Data, expectedData) + } +} + +type testCustomResolver struct{} + +func (r testCustomResolver) Resolve(p graphql.ResolveParams) (interface{}, error) { + if p.Info.FieldName == "fooBar" { + return "foo bar value", nil + } + return "", errors.New("invalid field " + p.Info.FieldName) +} + func TestContextDeadline(t *testing.T) { timeout := time.Millisecond * time.Duration(100) acceptableDelay := time.Millisecond * time.Duration(10)