Skip to content
This repository was archived by the owner on Mar 23, 2023. It is now read-only.

Commit c81a76b

Browse files
committed
Enable passing Python functions to Go for invocation.
1 parent 004b7a8 commit c81a76b

File tree

3 files changed

+218
-5
lines changed

3 files changed

+218
-5
lines changed

runtime/function.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ func functionGet(_ *Frame, desc, instance *Object, owner *Type) (*Object, *BaseE
125125
return NewMethod(toFunctionUnsafe(desc), instance, owner).ToObject(), nil
126126
}
127127

128+
func functionNative(f *Frame, o *Object) (reflect.Value, *BaseException) {
129+
return reflect.ValueOf(o.Call), nil
130+
}
131+
128132
func functionRepr(_ *Frame, o *Object) (*Object, *BaseException) {
129133
fun := toFunctionUnsafe(o)
130134
return NewStr(fmt.Sprintf("<%s %s at %p>", fun.typ.Name(), fun.Name(), fun)).ToObject(), nil
@@ -134,6 +138,7 @@ func initFunctionType(map[string]*Object) {
134138
FunctionType.flags &= ^(typeFlagInstantiable | typeFlagBasetype)
135139
FunctionType.slots.Call = &callSlot{functionCall}
136140
FunctionType.slots.Get = &getSlot{functionGet}
141+
FunctionType.slots.Native = &nativeSlot{functionNative}
137142
FunctionType.slots.Repr = &unaryOpSlot{functionRepr}
138143
}
139144

runtime/native.go

Lines changed: 97 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ var (
5656
}
5757
nativeTypesMutex = sync.Mutex{}
5858
sliceIteratorType = newBasisType("sliceiterator", reflect.TypeOf(sliceIterator{}), toSliceIteratorUnsafe, ObjectType)
59+
60+
baseExceptionReflectType = reflect.TypeOf((*BaseException)(nil))
61+
frameReflectType = reflect.TypeOf((*Frame)(nil))
5962
)
6063

6164
type nativeMetaclass struct {
@@ -489,22 +492,111 @@ func maybeConvertValue(f *Frame, o *Object, expectedRType reflect.Type) (reflect
489492
if raised != nil {
490493
return reflect.Value{}, raised
491494
}
492-
rtype := val.Type()
493495
for {
496+
rtype := val.Type()
494497
if rtype == expectedRType {
495498
return val, nil
496499
}
497500
if rtype.ConvertibleTo(expectedRType) {
498501
return val.Convert(expectedRType), nil
499502
}
500-
if rtype.Kind() == reflect.Ptr {
503+
switch rtype.Kind() {
504+
case reflect.Ptr:
501505
val = val.Elem()
502-
rtype = val.Type()
503506
continue
507+
508+
case reflect.Func:
509+
if fn, ok := val.Interface().(func(*Frame, Args, KWArgs) (*Object, *BaseException)); ok {
510+
val = nativeToPyFuncBridge(fn, expectedRType)
511+
continue
512+
}
504513
}
505-
break
514+
return val, f.RaiseType(TypeErrorType, fmt.Sprintf("cannot convert %s to %s", rtype, expectedRType))
515+
}
516+
}
517+
518+
// pyToNativeRaised supports pushing a `raised` exception from python code to
519+
// native calling code. If the raised exception can't be returned to native
520+
// code, then the raised exception is panic-ed.
521+
func pyToNativeRaised(outs []reflect.Type, raised *BaseException) []reflect.Value {
522+
last := len(outs) - 1
523+
if len(outs) == 0 || outs[last] != baseExceptionReflectType {
524+
panic(raised)
506525
}
507-
return reflect.Value{}, f.RaiseType(TypeErrorType, fmt.Sprintf("cannot convert %s to %s", rtype, expectedRType))
526+
ret := make([]reflect.Value, len(outs))
527+
for i, out := range outs[:last] {
528+
ret[i] = reflect.Zero(out)
529+
}
530+
ret[last] = reflect.ValueOf(raised)
531+
return ret
532+
}
533+
534+
func nativeToPyFuncBridge(fn func(*Frame, Args, KWArgs) (*Object, *BaseException), target reflect.Type) reflect.Value {
535+
firstInIsFrame := target.NumIn() > 0 && target.In(0) == frameReflectType
536+
537+
outs := make([]reflect.Type, target.NumOut())
538+
for i := range outs {
539+
outs[i] = target.Out(i)
540+
}
541+
542+
return reflect.MakeFunc(target, func(args []reflect.Value) []reflect.Value {
543+
var f *Frame
544+
if firstInIsFrame {
545+
f, args = args[0].Interface().(*Frame), args[1:]
546+
} else {
547+
f = NewRootFrame()
548+
}
549+
550+
pyArgs := f.MakeArgs(len(args))
551+
for i, arg := range args {
552+
var raised *BaseException
553+
pyArgs[i], raised = WrapNative(f, arg)
554+
if raised != nil {
555+
return pyToNativeRaised(outs, raised)
556+
}
557+
}
558+
559+
ret, raised := fn(f, pyArgs, nil)
560+
f.FreeArgs(pyArgs)
561+
if raised != nil {
562+
return pyToNativeRaised(outs, raised)
563+
}
564+
565+
switch len(outs) {
566+
case 0:
567+
if ret != nil && ret != None {
568+
return pyToNativeRaised(outs, f.RaiseType(TypeErrorType, fmt.Sprintf("unexpected return of %v when None expected", ret)))
569+
}
570+
return nil
571+
572+
case 1:
573+
v, raised := maybeConvertValue(f, ret, outs[0])
574+
if raised != nil {
575+
return pyToNativeRaised(outs, raised)
576+
}
577+
return []reflect.Value{v}
578+
579+
default:
580+
converted := make([]reflect.Value, 0, len(outs))
581+
if raised := seqForEach(f, ret, func(o *Object) *BaseException {
582+
i := len(converted)
583+
if i >= len(outs) {
584+
return f.RaiseType(TypeErrorType, fmt.Sprintf("return value too long, want %d items", len(outs)))
585+
}
586+
v, raised := maybeConvertValue(f, o, outs[i])
587+
converted = append(converted, v)
588+
return raised
589+
}); raised != nil {
590+
return pyToNativeRaised(outs, raised)
591+
}
592+
593+
if len(converted) != len(outs) {
594+
return pyToNativeRaised(outs, f.RaiseType(TypeErrorType, fmt.Sprintf("return value wrong size %d, want %d", len(converted), len(outs))))
595+
}
596+
597+
return converted
598+
}
599+
})
508600
}
509601

510602
func nativeFuncTypeName(rtype reflect.Type) string {

runtime/native_test.go

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,122 @@ func TestMaybeConvertValue(t *testing.T) {
422422
}
423423
}
424424

425+
func TestNativveToPyFuncBridge(t *testing.T) {
426+
tests := []struct {
427+
name string
428+
fn func(*testing.T, *Frame, Args, KWArgs) (*Object, *BaseException)
429+
typ reflect.Type
430+
args []interface{}
431+
ret []interface{}
432+
panc *BaseException
433+
}{
434+
{
435+
name: "no args",
436+
fn: func(t *testing.T, f *Frame, a Args, k KWArgs) (*Object, *BaseException) {
437+
if f == nil || len(a) != 0 || len(k) != 0 {
438+
t.Errorf("fn called with (%v, %v, %v), want (non-nil, %v, %v)", f, a, k, Args{}, KWArgs{})
439+
}
440+
return nil, nil
441+
},
442+
typ: reflect.TypeOf(func() {}),
443+
ret: []interface{}{},
444+
},
445+
{
446+
name: "return wrong size",
447+
fn: func(*testing.T, *Frame, Args, KWArgs) (*Object, *BaseException) {
448+
return NewInt(1).ToObject(), nil
449+
},
450+
typ: reflect.TypeOf(func() {}),
451+
panc: mustCreateException(TypeErrorType, "unexpected return of 1 when None expected"),
452+
},
453+
{
454+
name: "single return value",
455+
fn: func(*testing.T, *Frame, Args, KWArgs) (*Object, *BaseException) {
456+
return NewInt(1).ToObject(), nil
457+
},
458+
typ: reflect.TypeOf((*func() int)(nil)).Elem(),
459+
ret: []interface{}{1},
460+
},
461+
{
462+
name: "wrong size multiple return value",
463+
fn: func(*testing.T, *Frame, Args, KWArgs) (*Object, *BaseException) {
464+
return NewTuple(NewInt(1).ToObject(), NewInt(2).ToObject(), NewInt(3).ToObject()).ToObject(), nil
465+
},
466+
typ: reflect.TypeOf((*func() (int, int))(nil)).Elem(),
467+
panc: mustCreateException(TypeErrorType, "return value too long, want 2 items"),
468+
},
469+
{
470+
name: "multiple return value",
471+
fn: func(*testing.T, *Frame, Args, KWArgs) (*Object, *BaseException) {
472+
return NewTuple(NewInt(1).ToObject(), NewInt(2).ToObject(), NewInt(3).ToObject()).ToObject(), nil
473+
},
474+
typ: reflect.TypeOf((*func() (int, int, int))(nil)).Elem(),
475+
ret: []interface{}{1, 2, 3},
476+
},
477+
478+
{
479+
name: "func takes args",
480+
fn: func(t *testing.T, f *Frame, a Args, k KWArgs) (*Object, *BaseException) {
481+
want := Args{
482+
NewInt(1).ToObject(),
483+
NewInt(2).ToObject(),
484+
}
485+
if f == nil || !reflect.DeepEqual(a, want) || len(k) != 0 {
486+
t.Errorf("fn called with (%v, %v, %v), want (non-nil, %v, %v)", f, a, k, want, KWArgs{})
487+
}
488+
return nil, nil
489+
},
490+
typ: reflect.TypeOf(func(int, int) {}),
491+
args: []interface{}{1, 2},
492+
ret: []interface{}{},
493+
},
494+
}
495+
496+
for _, test := range tests {
497+
t.Run(test.name, func(t *testing.T) {
498+
called := false
499+
fn := func(f *Frame, a Args, k KWArgs) (*Object, *BaseException) {
500+
called = true
501+
return test.fn(t, f, a, k)
502+
}
503+
504+
args := make([]reflect.Value, len(test.args))
505+
for i, a := range test.args {
506+
args[i] = reflect.ValueOf(a)
507+
}
508+
509+
nativeFn := nativeToPyFuncBridge(fn, test.typ)
510+
ret := func() []reflect.Value {
511+
if test.panc != nil {
512+
defer func() {
513+
r := recover()
514+
raised, ok := r.(*BaseException)
515+
if r == nil || !ok || !exceptionsAreEquivalent(raised, test.panc) {
516+
t.Errorf("recover()=%v (type %T), want %v", r, r, test.panc)
517+
}
518+
}()
519+
}
520+
return nativeFn.Call(args)
521+
}()
522+
523+
if test.panc == nil {
524+
got := make([]interface{}, 0, len(test.ret))
525+
for _, v := range ret {
526+
got = append(got, v.Interface())
527+
}
528+
529+
if !reflect.DeepEqual(got, test.ret) {
530+
t.Errorf("fn returned %v, want %v", got, test.ret)
531+
}
532+
}
533+
534+
if !called {
535+
t.Errorf("fn not called, want to be called")
536+
}
537+
})
538+
}
539+
}
540+
425541
func TestNativeTypedefNative(t *testing.T) {
426542
fun := wrapFuncForTest(func(f *Frame, o *Object, wantType reflect.Type) (bool, *BaseException) {
427543
val, raised := ToNative(f, o)

0 commit comments

Comments
 (0)