diff --git a/bind.go b/bind.go new file mode 100644 index 00000000..d8c8dfcd --- /dev/null +++ b/bind.go @@ -0,0 +1,251 @@ +package graphql + +import ( + "context" + "encoding/json" + "fmt" + "reflect" +) + +var ctxType = reflect.TypeOf((*context.Context)(nil)).Elem() +var errType = reflect.TypeOf((*error)(nil)).Elem() + +/* + Bind will create a Field around a function formatted a certain way, or any value. + + The input parameters can be, in any order, + - context.Context, or *context.Context (optional) + - An input struct, or pointer (optional) + + The output parameters can be, in any order, + - A primitive, an output struct, or pointer (required for use in schema) + - error (optional) + + Input or output types provided will be automatically bound using BindType. +*/ +func Bind(bindTo interface{}, additionalFields ...Fields) *Field { + combinedAdditionalFields := MergeFields(additionalFields...) + val := reflect.ValueOf(bindTo) + tipe := reflect.TypeOf(bindTo) + if tipe.Kind() == reflect.Func { + in := tipe.NumIn() + out := tipe.NumOut() + + var ctxIn *int + var inputIn *int + + var errOut *int + var outputOut *int + + queryArgs := FieldConfigArgument{} + + if in > 2 { + panic(fmt.Sprintf("Mismatch on number of inputs. Expected 0, 1, or 2. got %d.", tipe.NumIn())) + } + + if out > 2 { + panic(fmt.Sprintf("Mismatch on number of outputs. Expected 0, 1, or 2, got %d.", tipe.NumOut())) + } + + // inTypes := make([]reflect.Type, in) + // outTypes := make([]reflect.Type, out) + + for i := 0; i < in; i++ { + t := tipe.In(i) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + switch t { + case ctxType: + if ctxIn != nil { + panic(fmt.Sprintf("Unexpected multiple *context.Context inputs.")) + } + ctxIn = intP(i) + default: + if inputIn != nil { + panic(fmt.Sprintf("Unexpected multiple inputs.")) + } + inputType := tipe.In(i) + if inputType.Kind() == reflect.Ptr { + inputType = inputType.Elem() + } + inputFields := BindFields(reflect.New(inputType).Interface()) + for key, inputField := range inputFields { + queryArgs[key] = &ArgumentConfig{ + Type: inputField.Type, + } + } + + inputIn = intP(i) + } + } + + for i := 0; i < out; i++ { + t := tipe.Out(i) + switch t.String() { + case errType.String(): + if errOut != nil { + panic(fmt.Sprintf("Unexpected multiple error outputs")) + } + errOut = intP(i) + default: + if outputOut != nil { + panic(fmt.Sprintf("Unexpected multiple outputs")) + } + outputOut = intP(i) + } + } + + resolve := func(p ResolveParams) (output interface{}, err error) { + inputs := make([]reflect.Value, in) + if ctxIn != nil { + isPtr := tipe.In(*ctxIn).Kind() == reflect.Ptr + if isPtr { + if p.Context == nil { + inputs[*ctxIn] = reflect.New(ctxType) + } else { + inputs[*ctxIn] = reflect.ValueOf(&p.Context) + } + } else { + if p.Context == nil { + inputs[*ctxIn] = reflect.New(ctxType).Elem() + } else { + inputs[*ctxIn] = reflect.ValueOf(p.Context).Convert(ctxType).Elem() + } + } + } + if inputIn != nil { + var inputType, inputBaseType, sourceType, sourceBaseType reflect.Type + sourceVal := reflect.ValueOf(p.Source) + sourceExists := !sourceVal.IsZero() + if sourceExists { + sourceType = sourceVal.Type() + if sourceType.Kind() == reflect.Ptr { + sourceBaseType = sourceType.Elem() + } else { + sourceBaseType = sourceType + } + } + inputType = tipe.In(*inputIn) + isPtr := tipe.In(*inputIn).Kind() == reflect.Ptr + if isPtr { + inputBaseType = inputType.Elem() + } else { + inputBaseType = inputType + } + var input interface{} + if sourceExists && sourceBaseType.AssignableTo(inputBaseType) { + input = sourceVal.Interface() + } else { + input = reflect.New(inputBaseType).Interface() + j, err := json.Marshal(p.Args) + if err == nil { + err = json.Unmarshal(j, &input) + } + if err != nil { + return nil, err + } + } + + inputs[*inputIn], err = convertValue(reflect.ValueOf(input), inputType) + if err != nil { + return nil, err + } + } + results := val.Call(inputs) + if errOut != nil { + val := results[*errOut].Interface() + if val != nil { + err = val.(error) + } + if err != nil { + return output, err + } + } + if outputOut != nil { + var val reflect.Value + val, err = convertValue(results[*outputOut], tipe.Out(*outputOut)) + if err != nil { + return nil, err + } + if !val.IsZero() { + output = val.Interface() + } + } + return output, err + } + + var outputType Output + if outputOut != nil { + outputType = BindType(tipe.Out(*outputOut)) + extendType(outputType, combinedAdditionalFields) + } + + field := &Field{ + Type: outputType, + Resolve: resolve, + Args: queryArgs, + } + + return field + } else if tipe.Kind() == reflect.Struct { + fieldType := BindType(reflect.TypeOf(bindTo)) + extendType(fieldType, combinedAdditionalFields) + field := &Field{ + Type: fieldType, + Resolve: func(p ResolveParams) (data interface{}, err error) { + return bindTo, nil + }, + } + return field + } else { + if len(additionalFields) > 0 { + panic("Cannot add field resolvers to a scalar type.") + } + return &Field{ + Type: getGraphType(tipe), + Resolve: func(p ResolveParams) (data interface{}, err error) { + return bindTo, nil + }, + } + } +} + +func extendType(t Type, fields Fields) { + switch t.(type) { + case *Object: + object := t.(*Object) + for fieldName, fieldConfig := range fields { + object.AddFieldConfig(fieldName, fieldConfig) + } + return + case *List: + list := t.(*List) + extendType(list.OfType, fields) + return + } +} + +func convertValue(value reflect.Value, targetType reflect.Type) (ret reflect.Value, err error) { + if !value.IsValid() || value.IsZero() { + return reflect.Zero(targetType), nil + } + if value.Type().Kind() == reflect.Ptr { + if targetType.Kind() == reflect.Ptr { + return value, nil + } else { + return value.Elem(), nil + } + } else { + if targetType.Kind() == reflect.Ptr { + // Will throw an informative error + return value.Convert(targetType), nil + } else { + return value, nil + } + } +} + +func intP(i int) *int { + return &i +} diff --git a/bind_test.go b/bind_test.go new file mode 100644 index 00000000..f2fa89cb --- /dev/null +++ b/bind_test.go @@ -0,0 +1,241 @@ +package graphql_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "strings" + "testing" + "time" + + "github.com/graphql-go/graphql" +) + +type HelloOutput struct { + Message string `json:"message"` +} + +func Hello(ctx *context.Context) (output *HelloOutput, err error) { + output = &HelloOutput{ + Message: "Hello World", + } + return output, nil +} + +func Hellos() []HelloOutput { + return []HelloOutput{ + { + Message: "Hello One", + }, + { + Message: "Hello Two", + }, + } +} + +func Upper(ctx *context.Context, source HelloOutput) string { + return strings.ToUpper(source.Message) +} + +type GreetingInput struct { + Name string `json:"name"` +} + +type GreetingOutput struct { + Message string `json:"message"` + Timestamp time.Time `json:"timestamp"` +} + +func GreetingPtr(ctx *context.Context, input *GreetingInput) (output *GreetingOutput, err error) { + return &GreetingOutput{ + Message: fmt.Sprintf("Hello %s.", input.Name), + Timestamp: time.Now(), + }, nil +} + +func Greeting(ctx context.Context, input GreetingInput) (output GreetingOutput, err error) { + return GreetingOutput{ + Message: fmt.Sprintf("Hello %s.", input.Name), + Timestamp: time.Now(), + }, nil +} + +type FriendRecur struct { + Name string `json:"name"` + Friends []FriendRecur `json:"friends"` +} + +func friends(ctx *context.Context) (output *FriendRecur) { + recursiveFriendRecur := FriendRecur{ + Name: "Recursion", + } + recursiveFriendRecur.Friends = make([]FriendRecur, 2) + recursiveFriendRecur.Friends[0] = recursiveFriendRecur + recursiveFriendRecur.Friends[1] = recursiveFriendRecur + + return &FriendRecur{ + Name: "Alan", + Friends: []FriendRecur{ + recursiveFriendRecur, + { + Name: "Samantha", + Friends: []FriendRecur{ + { + Name: "Olivia", + }, + { + Name: "Eric", + }, + }, + }, + { + Name: "Brian", + Friends: []FriendRecur{ + { + Name: "Windy", + }, + { + Name: "Kevin", + }, + }, + }, + { + Name: "Kevin", + Friends: []FriendRecur{ + { + Name: "Sergei", + }, + { + Name: "Michael", + }, + }, + }, + }, + } +} + +func TestBindHappyPath(t *testing.T) { + // Schema + fields := graphql.Fields{ + "hello": graphql.Bind(Hello), + "hellos": graphql.Bind(Hellos, graphql.Fields{ + "upper": graphql.Bind(Upper), + }), + "greeting": graphql.Bind(Greeting), + "greetingPtr": graphql.Bind(GreetingPtr), + "friends": graphql.Bind(friends), + "string": graphql.Bind("Hello World"), + "number": graphql.Bind(12345), + "float": graphql.Bind(123.45), + "anonymous": graphql.Bind(struct { + SomeField string `json:"someField"` + }{ + SomeField: "Some Value", + }), + "simpleFunc": graphql.Bind(func() string { + return "Hello World" + }), + } + rootQuery := graphql.ObjectConfig{Name: "RootQuery", Fields: fields} + schemaConfig := graphql.SchemaConfig{Query: graphql.NewObject(rootQuery)} + schema, err := graphql.NewSchema(schemaConfig) + if err != nil { + log.Fatalf("failed to create new schema, error: %v", err) + } + + // Query + query := ` + { + hello { + message + upper + } + hellos { + message + upper + } + greeting(name:"Alan") { + message + timestamp + } + greetingPtr(name:"Alan") { + message + timestamp + } + friends { + name + friends { + name + friends { + name + friends { + name + friends { + name + } + } + } + } + } + string + number + float + anonymous { + someField + } + simpleFunc + } + ` + params := graphql.Params{Schema: schema, RequestString: query} + r := graphql.Do(params) + if len(r.Errors) > 0 { + t.Errorf("failed to execute graphql operation, errors: %+v", r.Errors) + } + json, err := json.MarshalIndent(r.Data, "", " ") + fmt.Println(string(json)) +} + +func TestBindPanicImproperInput(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected Bind to panic due to improper function signature") + } + }() + graphql.Bind(func(a, b, c string) {}) +} + +func TestBindPanicImproperOutput(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Expected Bind to panic due to improper function signature") + } + }() + graphql.Bind(func() (string, string) { return "Hello", "World" }) +} + +func TestBindWithRuntimeError(t *testing.T) { + rootQuery := graphql.ObjectConfig{Name: "RootQuery", Fields: graphql.Fields{ + "throwError": graphql.Bind(func() (string, error) { + return "", errors.New("Some Error") + }), + }} + schemaConfig := graphql.SchemaConfig{Query: graphql.NewObject(rootQuery)} + schema, err := graphql.NewSchema(schemaConfig) + if err != nil { + log.Fatalf("failed to create new schema, error: %v", err) + } + + // Query + query := ` + { + throwError + } + ` + params := graphql.Params{Schema: schema, RequestString: query} + r := graphql.Do(params) + if len(r.Errors) == 0 { + t.Error("Expected error") + } +} diff --git a/examples/bind-complex/main.go b/examples/bind-complex/main.go new file mode 100644 index 00000000..0bed337b --- /dev/null +++ b/examples/bind-complex/main.go @@ -0,0 +1,85 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + + "github.com/graphql-go/graphql" +) + +var people = []Person{ + { + Name: "Alan", + Friends: []Person{ + { + Name: "Nadeem", + Friends: []Person{ + { + Name: "Heidi", + }, + }, + }, + }, + }, +} + +type Person struct { + Name string `json:"name"` + Friends []Person `json:"friends"` +} + +type GetPersonInput struct { + Name string `json:"name"` +} + +type GetPersonOutput struct { + Person +} + +func GetPerson(ctx context.Context, input GetPersonInput) (*GetPersonOutput, error) { + for _, person := range people { + if person.Name == input.Name { + return &GetPersonOutput{ + Person: person, + }, nil + } + } + return nil, errors.New("Could not find person.") +} + +func main() { + rootQuery := graphql.ObjectConfig{Name: "RootQuery", Fields: graphql.Fields{ + "person": graphql.Bind(GetPerson), + }} + + schemaConfig := graphql.SchemaConfig{Query: graphql.NewObject(rootQuery)} + schema, err := graphql.NewSchema(schemaConfig) + if err != nil { + log.Fatalf("failed to create new schema, error: %v", err) + } + + // Query + query := ` + { + person(name: "Alan") { + name + friends { + name + friends { + name + } + } + } + } + ` + params := graphql.Params{Schema: schema, RequestString: query} + r := graphql.Do(params) + if len(r.Errors) > 0 { + log.Fatalf("failed to execute graphql operation, errors: %+v", r.Errors) + } + rJSON, _ := json.Marshal(r) + fmt.Printf("%s \n", rJSON) +} diff --git a/examples/bind-simple/main.go b/examples/bind-simple/main.go new file mode 100644 index 00000000..ada9aee5 --- /dev/null +++ b/examples/bind-simple/main.go @@ -0,0 +1,43 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + + "github.com/graphql-go/graphql" +) + +type GreetingInput struct { + Name string `json:"name"` +} + +func Greeting(input GreetingInput) string { + return fmt.Sprintf("Hello %s", input.Name) +} + +func main() { + rootQuery := graphql.ObjectConfig{Name: "RootQuery", Fields: graphql.Fields{ + "greeting": graphql.Bind(Greeting), + }} + + schemaConfig := graphql.SchemaConfig{Query: graphql.NewObject(rootQuery)} + schema, err := graphql.NewSchema(schemaConfig) + if err != nil { + log.Fatalf("failed to create new schema, error: %v", err) + } + + // Query + query := ` + { + greeting(name: "Alan") + } + ` + params := graphql.Params{Schema: schema, RequestString: query} + r := graphql.Do(params) + if len(r.Errors) > 0 { + log.Fatalf("failed to execute graphql operation, errors: %+v", r.Errors) + } + rJSON, _ := json.Marshal(r) + fmt.Printf("%s \n", rJSON) +} diff --git a/executor.go b/executor.go index 7440ae21..096f6fbc 100644 --- a/executor.go +++ b/executor.go @@ -943,7 +943,7 @@ func DefaultResolveFn(p ResolveParams) (interface{}, error) { } // try to resolve p.Source as a struct - if sourceVal.IsValid() && sourceVal.Type().Kind() == reflect.Ptr { + if sourceVal.IsValid() && !sourceVal.IsZero() && sourceVal.Type().Kind() == reflect.Ptr { sourceVal = sourceVal.Elem() } if !sourceVal.IsValid() { diff --git a/util.go b/util.go index ae374c33..a48ebca2 100644 --- a/util.go +++ b/util.go @@ -8,13 +8,88 @@ import ( ) const TAG = "json" +const TYPETAG = "graphql" + +var boundTypes = map[string]*Object{} +var anonTypes = 0 + +func MergeFields(fieldses ...Fields) (ret Fields) { + ret = Fields{} + for _, fields := range fieldses { + for key, field := range fields { + if _, ok := ret[key]; ok { + panic(fmt.Sprintf("Dupliate field: %s", key)) + } + ret[key] = field + } + } + return ret +} + +func BindType(tipe reflect.Type) Type { + if tipe.Kind() == reflect.Ptr { + tipe = tipe.Elem() + } + + kind := tipe.Kind() + switch kind { + case reflect.String: + return String + case reflect.Int, reflect.Int8, reflect.Int32, reflect.Int64: + return Int + case reflect.Float32, reflect.Float64: + return Float + case reflect.Bool: + return Boolean + case reflect.Slice: + return getGraphList(tipe) + } + + typeName := safeName(tipe) + object, ok := boundTypes[typeName] + if !ok { + // Allows for recursion + object = &Object{} + boundTypes[typeName] = object + *object = *NewObject(ObjectConfig{ + Name: typeName, + Fields: BindFields(reflect.New(tipe).Interface()), + }) + } + + return object +} + +func safeName(tipe reflect.Type) string { + name := fmt.Sprint(tipe) + if strings.HasPrefix(name, "struct ") { + anonTypes++ + name = fmt.Sprintf("Anon%d", anonTypes) + } else { + name = strings.Replace(fmt.Sprint(tipe), ".", "_", -1) + } + return name +} + +func getType(typeTag string) Output { + switch strings.ToLower(typeTag) { + case "int": + return Int + case "float": + return Float + case "string": + return String + case "boolean": + return Boolean + case "id": + return ID + case "datetime": + return DateTime + default: + panic(fmt.Sprintf("Unsupported graphql type: %s", typeTag)) + } +} -// can't take recursive slice type -// e.g -// type Person struct{ -// Friends []Person -// } -// it will throw panic stack-overflow func BindFields(obj interface{}) Fields { t := reflect.TypeOf(obj) v := reflect.ValueOf(obj) @@ -33,14 +108,17 @@ func BindFields(obj interface{}) Fields { continue } + typeTag := field.Tag.Get(TYPETAG) + fieldType := field.Type if fieldType.Kind() == reflect.Ptr { fieldType = fieldType.Elem() } - var graphType Output - if fieldType.Kind() == reflect.Struct { + if typeTag != "" { + graphType = getType(typeTag) + } else if fieldType.Kind() == reflect.Struct { itf := v.Field(i).Interface() if _, ok := itf.(encoding.TextMarshaler); ok { fieldType = reflect.TypeOf("") @@ -53,10 +131,7 @@ func BindFields(obj interface{}) Fields { fields = appendFields(fields, structFields) continue } else { - graphType = NewObject(ObjectConfig{ - Name: tag, - Fields: structFields, - }) + graphType = BindType(fieldType) } } @@ -110,11 +185,7 @@ func getGraphList(tipe reflect.Type) *List { } // finally bind object t := reflect.New(tipe.Elem()) - name := strings.Replace(fmt.Sprint(tipe.Elem()), ".", "_", -1) - obj := NewObject(ObjectConfig{ - Name: name, - Fields: BindFields(t.Elem().Interface()), - }) + obj := BindType(t.Elem().Type()) return NewList(obj) } @@ -132,21 +203,27 @@ func extractValue(originTag string, obj interface{}) interface{} { field := val.Type().Field(j) found := originTag == extractTag(field.Tag) if field.Type.Kind() == reflect.Struct { - itf := val.Field(j).Interface() + fieldVal := val.Field(j) + if !fieldVal.IsZero() { + itf := fieldVal.Interface() - if str, ok := itf.(encoding.TextMarshaler); ok && found { - byt, _ := str.MarshalText() - return string(byt) - } + if str, ok := itf.(encoding.TextMarshaler); ok && found { + byt, _ := str.MarshalText() + return string(byt) + } - res := extractValue(originTag, itf) - if res != nil { - return res + res := extractValue(originTag, itf) + if res != nil { + return res + } } } if found { - return reflect.Indirect(val.Field(j)).Interface() + fieldVal := val.Field(j) + if !fieldVal.IsZero() { + return reflect.Indirect(fieldVal).Interface() + } } } return nil