Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions asm_amd64.S
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,28 @@
*/

SYMBOL(make_call):

pushq %rbp
movq %rsp, %rbp
pushq %rbp # save the base pointer
movq %rsp, %rbp # set new base pointer
pushq %rbx
pushq %r12
pushq %r13
pushq %r14
pushq %r15

// We're pushing 7 values, need to align
// %rsp to a 16 byte boundary.

// When this function was called, return address (8 bytes) was pushed on the stack.
// Therefore the stack is already un-aligned.
// We've pushed 6 values (6 * 8 bytes), so must subtract 8 more bytes to
// get %rsp aligned to a 16 byte boundary.
// We use this 8 byte space to store the %rcx value (stack arguments count)
// available with '-40(%rbp)'
sub $8, %rsp
movq %rcx, -40(%rbp) # Save the stack arguments count

// Save function pointer
movq %rdi, %r12
// Can be used to test if %rsp is properly aligned
// This alligns properly, so if %rsp changes, it was not already aligned
//andq $~0xF, %rsp

movq %rdi, %r12 # Save function pointer
// Save wheter we want to return from %xmm0
movq %r9, %r13

Expand Down Expand Up @@ -70,11 +76,18 @@ stack_done:
movq 32(%r14), %r8
movq 40(%r14), %r9

call *%r12
call *%r12 # Call the function
movq -40(%rbp), %rcx # Restore stack arguments count to %rcx to know how much we should pop from stack

// Restore %esp adjustment for 16 byte boundary
restore_stack:
test %rcx, %rcx
je stack_restored
add $8, %rsp
dec %rcx
jmp restore_stack

stack_restored:
add $8, %rsp # Restore %rsp stack adjustment for 16 byte boundary
test %r13, %r13
je restore_registers
movq %xmm0, %rax
Expand Down
4 changes: 2 additions & 2 deletions dl.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ func (d *DL) Sym(symbol string, out interface{}) error {
return errors.New("out can't be nil")
}
elem := val.Elem()
switch elem.Kind() {
typ := reflect.TypeOf(elem.Interface())
switch typ.Kind() {
case reflect.Int:
// We treat Go's int as long, since it
// varies depending on the platform bit size
Expand Down Expand Up @@ -130,7 +131,6 @@ func (d *DL) Sym(symbol string, out interface{}) error {
case reflect.Float64:
elem.SetFloat(float64(*(*float64)(handle)))
case reflect.Func:
typ := elem.Type()
tr, err := makeTrampoline(typ, handle)
if err != nil {
return err
Expand Down
20 changes: 20 additions & 0 deletions dl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dl
import (
"os/exec"
"path/filepath"
"reflect"
"testing"
"unsafe"
)
Expand Down Expand Up @@ -175,6 +176,7 @@ func TestFunctions(t *testing.T) {
if err := dl.Sym("fill42", &fill42); err != nil {
t.Fatal(err)
}

b := make([]byte, 42)
fill42(b, int32(len(b)))
for ii, v := range b {
Expand Down Expand Up @@ -237,6 +239,24 @@ func TestReturnString(t *testing.T) {
}
}

func TestInterface(t *testing.T) {
dl := openTestLib(t)
defer dl.Close()

var square struct{ Call func(float64) float64 }

v := reflect.ValueOf(&square).Elem().Field(0)
sym := v.Interface()
if err := dl.Sym("square", &sym); err != nil {
t.Fatal(err)
}
v.Set(reflect.ValueOf(sym))

if r := square.Call(4); r != 16 {
t.Errorf("expecting square(4) = 16, got %v instead", r)
}
}

func init() {
if err := exec.Command("make", "-C", "testdata").Run(); err != nil {
panic(err)
Expand Down
25 changes: 0 additions & 25 deletions examples_test.go

This file was deleted.

36 changes: 18 additions & 18 deletions trampoline.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ func makeTrampoline(typ reflect.Type, handle unsafe.Pointer) (rFunc, error) {
}
}
count := len(in)
args := make([]unsafe.Pointer, count)
args := make([]uintptr, count)
flags := make([]C.int, count+1)
flags[count] = outFlag
for ii, v := range in {
Expand All @@ -55,70 +55,70 @@ func makeTrampoline(typ reflect.Type, handle unsafe.Pointer) (rFunc, error) {
case reflect.String:
s := C.CString(v.String())
defer C.free(unsafe.Pointer(s))
args[ii] = unsafe.Pointer(s)
args[ii] = uintptr(unsafe.Pointer(s))
flags[ii] |= C.ARG_FLAG_SIZE_PTR
case reflect.Int:
args[ii] = unsafe.Pointer(uintptr(v.Int()))
args[ii] = uintptr(v.Int())
if v.Type().Size() == 4 {
flags[ii] = C.ARG_FLAG_SIZE_32
} else {
flags[ii] = C.ARG_FLAG_SIZE_64
}
case reflect.Int8:
args[ii] = unsafe.Pointer(uintptr(v.Int()))
args[ii] = uintptr(v.Int())
flags[ii] = C.ARG_FLAG_SIZE_8
case reflect.Int16:
args[ii] = unsafe.Pointer(uintptr(v.Int()))
args[ii] = uintptr(v.Int())
flags[ii] = C.ARG_FLAG_SIZE_16
case reflect.Int32:
args[ii] = unsafe.Pointer(uintptr(v.Int()))
args[ii] = uintptr(v.Int())
flags[ii] = C.ARG_FLAG_SIZE_32
case reflect.Int64:
args[ii] = unsafe.Pointer(uintptr(v.Int()))
args[ii] = uintptr(v.Int())
flags[ii] = C.ARG_FLAG_SIZE_64
case reflect.Uint:
args[ii] = unsafe.Pointer(uintptr(v.Uint()))
args[ii] = uintptr(v.Uint())
if v.Type().Size() == 4 {
flags[ii] = C.ARG_FLAG_SIZE_32
} else {
flags[ii] = C.ARG_FLAG_SIZE_64
}
case reflect.Uint8:
args[ii] = unsafe.Pointer(uintptr(v.Uint()))
args[ii] = uintptr(v.Uint())
flags[ii] = C.ARG_FLAG_SIZE_8
case reflect.Uint16:
args[ii] = unsafe.Pointer(uintptr(v.Uint()))
args[ii] = uintptr(v.Uint())
flags[ii] = C.ARG_FLAG_SIZE_16
case reflect.Uint32:
args[ii] = unsafe.Pointer(uintptr(v.Uint()))
args[ii] = uintptr(v.Uint())
flags[ii] = C.ARG_FLAG_SIZE_32
case reflect.Uint64:
args[ii] = unsafe.Pointer(uintptr(v.Uint()))
args[ii] = uintptr(v.Uint())
flags[ii] = C.ARG_FLAG_SIZE_64
case reflect.Float32:
args[ii] = unsafe.Pointer(uintptr(math.Float32bits(float32(v.Float()))))
args[ii] = uintptr(math.Float32bits(float32(v.Float())))
flags[ii] |= C.ARG_FLAG_FLOAT | C.ARG_FLAG_SIZE_32
case reflect.Float64:
args[ii] = unsafe.Pointer(uintptr(math.Float64bits(v.Float())))
args[ii] = uintptr(math.Float64bits(v.Float()))
flags[ii] |= C.ARG_FLAG_FLOAT | C.ARG_FLAG_SIZE_64
case reflect.Ptr:
args[ii] = unsafe.Pointer(v.Pointer())
args[ii] = v.Pointer()
flags[ii] |= C.ARG_FLAG_SIZE_PTR
case reflect.Slice:
if v.Len() > 0 {
args[ii] = unsafe.Pointer(v.Index(0).UnsafeAddr())
args[ii] = v.Index(0).UnsafeAddr()
}
flags[ii] |= C.ARG_FLAG_SIZE_PTR
case reflect.Uintptr:
args[ii] = unsafe.Pointer(uintptr(v.Uint()))
args[ii] = uintptr(v.Uint())
flags[ii] |= C.ARG_FLAG_SIZE_PTR
default:
panic(fmt.Errorf("can't bind value of type %s", v.Type()))
}
}
var argp *unsafe.Pointer
if count > 0 {
argp = &args[0]
argp = (*unsafe.Pointer)(unsafe.Pointer(&args[0]))
}
var ret unsafe.Pointer
if C.call(handle, argp, &flags[0], C.int(count), &ret) != 0 {
Expand Down