diff --git a/dl.go b/dl.go index 85b2cd3..deee316 100644 --- a/dl.go +++ b/dl.go @@ -3,6 +3,7 @@ package dl // #include // #include // #cgo LDFLAGS: -ldl +// #cgo CFLAGS: -O0 import "C" import ( @@ -98,7 +99,8 @@ func (d *DL) Sym(symbol string, out interface{}) error { return errors.New("out can't be nil") } elem := val.Elem() - switch elem.Kind() { + typ := reflect.TypeOf(elem.Interface()) + switch typ.Kind() { case reflect.Int: // We treat Go's int as long, since it // varies depending on the platform bit size @@ -130,7 +132,6 @@ func (d *DL) Sym(symbol string, out interface{}) error { case reflect.Float64: elem.SetFloat(float64(*(*float64)(handle))) case reflect.Func: - typ := elem.Type() tr, err := makeTrampoline(typ, handle) if err != nil { return err diff --git a/dl_test.go b/dl_test.go index 85f2469..c432350 100644 --- a/dl_test.go +++ b/dl_test.go @@ -3,6 +3,7 @@ package dl import ( "os/exec" "path/filepath" + "reflect" "testing" "unsafe" ) @@ -175,6 +176,7 @@ func TestFunctions(t *testing.T) { if err := dl.Sym("fill42", &fill42); err != nil { t.Fatal(err) } + b := make([]byte, 42) fill42(b, int32(len(b))) for ii, v := range b { @@ -237,6 +239,24 @@ func TestReturnString(t *testing.T) { } } +func TestInterface(t *testing.T) { + dl := openTestLib(t) + defer dl.Close() + + var square struct{ Call func(float64) float64 } + + v := reflect.ValueOf(&square).Elem().Field(0) + sym := v.Interface() + if err := dl.Sym("square", &sym); err != nil { + t.Fatal(err) + } + v.Set(reflect.ValueOf(sym)) + + if r := square.Call(4); r != 16 { + t.Errorf("expecting square(4) = 16, got %v instead", r) + } +} + func init() { if err := exec.Command("make", "-C", "testdata").Run(); err != nil { panic(err) diff --git a/examples_test.go b/examples_test.go deleted file mode 100644 index d558862..0000000 --- a/examples_test.go +++ /dev/null @@ -1,25 +0,0 @@ -package dl_test - -import ( - "bytes" - "fmt" - - "dl" -) - -func ExampleOpen_snprintf() { - lib, err := dl.Open("libc", 0) - if err != nil { - panic(err) - } - defer lib.Close() - var snprintf func([]byte, uint, string, ...interface{}) int - if err := lib.Sym("snprintf", &snprintf); err != nil { - panic(err) - } - buf := make([]byte, 200) - snprintf(buf, uint(len(buf)), "hello %s!\n", "world") - s := string(buf[:bytes.IndexByte(buf, 0)]) - fmt.Println(s) - // Output: hello world! -} diff --git a/trampoline.go b/trampoline.go index 8f5c967..14ba7ca 100644 --- a/trampoline.go +++ b/trampoline.go @@ -44,7 +44,7 @@ func makeTrampoline(typ reflect.Type, handle unsafe.Pointer) (rFunc, error) { } } count := len(in) - args := make([]unsafe.Pointer, count) + args := make([]uintptr, count) flags := make([]C.int, count+1) flags[count] = outFlag for ii, v := range in { @@ -55,62 +55,62 @@ func makeTrampoline(typ reflect.Type, handle unsafe.Pointer) (rFunc, error) { case reflect.String: s := C.CString(v.String()) defer C.free(unsafe.Pointer(s)) - args[ii] = unsafe.Pointer(s) + args[ii] = uintptr(unsafe.Pointer(s)) flags[ii] |= C.ARG_FLAG_SIZE_PTR case reflect.Int: - args[ii] = unsafe.Pointer(uintptr(v.Int())) + args[ii] = uintptr(v.Int()) if v.Type().Size() == 4 { flags[ii] = C.ARG_FLAG_SIZE_32 } else { flags[ii] = C.ARG_FLAG_SIZE_64 } case reflect.Int8: - args[ii] = unsafe.Pointer(uintptr(v.Int())) + args[ii] = uintptr(v.Int()) flags[ii] = C.ARG_FLAG_SIZE_8 case reflect.Int16: - args[ii] = unsafe.Pointer(uintptr(v.Int())) + args[ii] = uintptr(v.Int()) flags[ii] = C.ARG_FLAG_SIZE_16 case reflect.Int32: - args[ii] = unsafe.Pointer(uintptr(v.Int())) + args[ii] = uintptr(v.Int()) flags[ii] = C.ARG_FLAG_SIZE_32 case reflect.Int64: - args[ii] = unsafe.Pointer(uintptr(v.Int())) + args[ii] = uintptr(v.Int()) flags[ii] = C.ARG_FLAG_SIZE_64 case reflect.Uint: - args[ii] = unsafe.Pointer(uintptr(v.Uint())) + args[ii] = uintptr(v.Uint()) if v.Type().Size() == 4 { flags[ii] = C.ARG_FLAG_SIZE_32 } else { flags[ii] = C.ARG_FLAG_SIZE_64 } case reflect.Uint8: - args[ii] = unsafe.Pointer(uintptr(v.Uint())) + args[ii] = uintptr(v.Uint()) flags[ii] = C.ARG_FLAG_SIZE_8 case reflect.Uint16: - args[ii] = unsafe.Pointer(uintptr(v.Uint())) + args[ii] = uintptr(v.Uint()) flags[ii] = C.ARG_FLAG_SIZE_16 case reflect.Uint32: - args[ii] = unsafe.Pointer(uintptr(v.Uint())) + args[ii] = uintptr(v.Uint()) flags[ii] = C.ARG_FLAG_SIZE_32 case reflect.Uint64: - args[ii] = unsafe.Pointer(uintptr(v.Uint())) + args[ii] = uintptr(v.Uint()) flags[ii] = C.ARG_FLAG_SIZE_64 case reflect.Float32: - args[ii] = unsafe.Pointer(uintptr(math.Float32bits(float32(v.Float())))) + args[ii] = uintptr(math.Float32bits(float32(v.Float()))) flags[ii] |= C.ARG_FLAG_FLOAT | C.ARG_FLAG_SIZE_32 case reflect.Float64: - args[ii] = unsafe.Pointer(uintptr(math.Float64bits(v.Float()))) + args[ii] = uintptr(math.Float64bits(v.Float())) flags[ii] |= C.ARG_FLAG_FLOAT | C.ARG_FLAG_SIZE_64 case reflect.Ptr: - args[ii] = unsafe.Pointer(v.Pointer()) + args[ii] = v.Pointer() flags[ii] |= C.ARG_FLAG_SIZE_PTR case reflect.Slice: if v.Len() > 0 { - args[ii] = unsafe.Pointer(v.Index(0).UnsafeAddr()) + args[ii] = v.Index(0).UnsafeAddr() } flags[ii] |= C.ARG_FLAG_SIZE_PTR case reflect.Uintptr: - args[ii] = unsafe.Pointer(uintptr(v.Uint())) + args[ii] = uintptr(v.Uint()) flags[ii] |= C.ARG_FLAG_SIZE_PTR default: panic(fmt.Errorf("can't bind value of type %s", v.Type())) @@ -118,7 +118,7 @@ func makeTrampoline(typ reflect.Type, handle unsafe.Pointer) (rFunc, error) { } var argp *unsafe.Pointer if count > 0 { - argp = &args[0] + argp = (*unsafe.Pointer)(unsafe.Pointer(&args[0])) } var ret unsafe.Pointer if C.call(handle, argp, &flags[0], C.int(count), &ret) != 0 {