Skip to content
This repository was archived by the owner on Mar 23, 2023. It is now read-only.

Commit 627e648

Browse files
committed
Enable passing Python functions to Go for invocation.
1 parent c0a76bc commit 627e648

File tree

2 files changed

+103
-5
lines changed

2 files changed

+103
-5
lines changed

runtime/function.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ func functionGet(_ *Frame, desc, instance *Object, owner *Type) (*Object, *BaseE
125125
return NewMethod(toFunctionUnsafe(desc), instance, owner).ToObject(), nil
126126
}
127127

128+
func functionNative(f *Frame, o *Object) (reflect.Value, *BaseException) {
129+
return reflect.ValueOf(o.Call), nil
130+
}
131+
128132
func functionRepr(_ *Frame, o *Object) (*Object, *BaseException) {
129133
fun := toFunctionUnsafe(o)
130134
return NewStr(fmt.Sprintf("<%s %s at %p>", fun.typ.Name(), fun.Name(), fun)).ToObject(), nil
@@ -134,6 +138,7 @@ func initFunctionType(map[string]*Object) {
134138
FunctionType.flags &= ^(typeFlagInstantiable | typeFlagBasetype)
135139
FunctionType.slots.Call = &callSlot{functionCall}
136140
FunctionType.slots.Get = &getSlot{functionGet}
141+
FunctionType.slots.Native = &nativeSlot{functionNative}
137142
FunctionType.slots.Repr = &unaryOpSlot{functionRepr}
138143
}
139144

runtime/native.go

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -479,22 +479,115 @@ func maybeConvertValue(f *Frame, o *Object, expectedRType reflect.Type) (reflect
479479
if raised != nil {
480480
return reflect.Value{}, raised
481481
}
482-
rtype := val.Type()
483482
for {
483+
rtype := val.Type()
484484
if rtype == expectedRType {
485485
return val, nil
486486
}
487487
if rtype.ConvertibleTo(expectedRType) {
488488
return val.Convert(expectedRType), nil
489489
}
490-
if rtype.Kind() == reflect.Ptr {
490+
switch rtype.Kind() {
491+
case reflect.Ptr:
491492
val = val.Elem()
492-
rtype = val.Type()
493493
continue
494+
495+
case reflect.Func:
496+
if fn, ok := val.Interface().(func(*Frame, Args, KWArgs) (*Object, *BaseException)); ok {
497+
val = nativeToPyFuncBridge(fn, expectedRType)
498+
continue
499+
}
494500
}
495-
break
501+
return val, f.RaiseType(TypeErrorType, fmt.Sprintf("cannot convert %s to %s", rtype, expectedRType))
502+
}
503+
}
504+
505+
var baseExceptionReflectType = reflect.TypeOf((*BaseException)(nil))
506+
507+
// pyToNativeRaised supports pushing a `raised` exception from python code to
508+
// native calling code. If the raised exception can't be returned to native
509+
// code, then the raised exception is panic-ed.
510+
func pyToNativeRaised(outs []reflect.Type, raised *BaseException) []reflect.Value {
511+
last := len(outs) - 1
512+
if len(outs) == 0 || outs[last] != baseExceptionReflectType {
513+
panic(raised)
514+
}
515+
ret := make([]reflect.Value, len(outs))
516+
for i, out := range outs[:last] {
517+
ret[i] = reflect.Zero(out)
496518
}
497-
return reflect.Value{}, f.RaiseType(TypeErrorType, fmt.Sprintf("cannot convert %s to %s", rtype, expectedRType))
519+
ret[last] = reflect.ValueOf(raised)
520+
return ret
521+
}
522+
523+
var frameReflectType = reflect.TypeOf((*Frame)(nil))
524+
525+
func nativeToPyFuncBridge(fn func(*Frame, Args, KWArgs) (*Object, *BaseException), target reflect.Type) reflect.Value {
526+
firstInIsFrame := target.NumIn() > 0 && target.In(0) == frameReflectType
527+
528+
outs := make([]reflect.Type, target.NumOut())
529+
for i := range outs {
530+
outs[i] = target.Out(i)
531+
}
532+
533+
return reflect.MakeFunc(target, func(args []reflect.Value) []reflect.Value {
534+
var f *Frame
535+
if firstInIsFrame {
536+
f, args = args[0].Interface().(*Frame), args[1:]
537+
} else {
538+
f = NewRootFrame()
539+
}
540+
541+
pyArgs := f.MakeArgs(len(args))
542+
for i, arg := range args {
543+
var raised *BaseException
544+
pyArgs[i], raised = WrapNative(f, arg)
545+
if raised != nil {
546+
return pyToNativeRaised(outs, raised)
547+
}
548+
}
549+
550+
ret, raised := fn(f, pyArgs, nil)
551+
f.FreeArgs(pyArgs)
552+
if raised != nil {
553+
return pyToNativeRaised(outs, raised)
554+
}
555+
556+
switch len(outs) {
557+
case 0:
558+
if ret != nil && ret != None {
559+
return pyToNativeRaised(outs, f.RaiseType(TypeErrorType, fmt.Sprintf("unexpected return of %v when None expected", ret)))
560+
}
561+
return nil
562+
563+
case 1:
564+
v, raised := maybeConvertValue(f, ret, outs[0])
565+
if raised != nil {
566+
return pyToNativeRaised(outs, raised)
567+
}
568+
return []reflect.Value{v}
569+
570+
default:
571+
converted := make([]reflect.Value, 0, len(outs))
572+
if raised := seqForEach(f, ret, func(o *Object) *BaseException {
573+
i := len(converted)
574+
if i >= len(outs) {
575+
return f.RaiseType(TypeErrorType, fmt.Sprintf("return value too long, want %d items", len(outs)))
576+
}
577+
v, raised := maybeConvertValue(f, o, outs[i])
578+
converted = append(converted, v)
579+
return raised
580+
}); raised != nil {
581+
return pyToNativeRaised(outs, raised)
582+
}
583+
584+
if len(converted) != len(outs) {
585+
return pyToNativeRaised(outs, f.RaiseType(TypeErrorType, fmt.Sprintf("return value wrong size %d, want %d", len(converted), len(outs))))
586+
}
587+
588+
return converted
589+
}
590+
})
498591
}
499592

500593
func nativeFuncTypeName(rtype reflect.Type) string {

0 commit comments

Comments
 (0)