Skip to content

Commit

Permalink
fix: fix decode in mixed struct
Browse files Browse the repository at this point in the history
  • Loading branch information
bytemain committed Mar 6, 2024
1 parent bad1f1e commit a9f4f03
Show file tree
Hide file tree
Showing 4 changed files with 613 additions and 181 deletions.
315 changes: 315 additions & 0 deletions internal/luai/decode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,315 @@
// Lua Interface
// Marshal and Unmarshal Lua Table to Go Struct

package luai

import (
"errors"
"fmt"
"reflect"
"strconv"

"github.com/version-fox/vfox/internal/logger"
lua "github.com/yuin/gopher-lua"
)

// indirect walks down v allocating pointers as needed,
// until it gets to a non-pointer.
// If it encounters an Unmarshaler, indirect stops and returns that.
// If decodingNull is true, indirect stops at the first settable pointer so it
// can be set to nil.
func indirect(v reflect.Value, decodingNull bool) reflect.Value {
// Issue #24153 indicates that it is generally not a guaranteed property
// that you may round-trip a reflect.Value by calling Value.Addr().Elem()
// and expect the value to still be settable for values derived from
// unexported embedded struct fields.
//
// The logic below effectively does this when it first addresses the value
// (to satisfy possible pointer methods) and continues to dereference
// subsequent pointers as necessary.
//
// After the first round-trip, we set v back to the original value to
// preserve the original RW flags contained in reflect.Value.
v0 := v
haveAddr := false

// If v is a named type and is addressable,
// start with its address, so that if the type has pointer methods,
// we find them.
if v.Kind() != reflect.Pointer && v.Type().Name() != "" && v.CanAddr() {
haveAddr = true
v = v.Addr()
}
for {
// Load value from interface, but only if the result will be
// usefully addressable.
if v.Kind() == reflect.Interface && !v.IsNil() {
e := v.Elem()
if e.Kind() == reflect.Pointer && !e.IsNil() && (!decodingNull || e.Elem().Kind() == reflect.Pointer) {
haveAddr = false
v = e
continue
}
}

if v.Kind() != reflect.Pointer {
break
}

if decodingNull && v.CanSet() {
break
}

// Prevent infinite loop if v is an interface pointing to its own address:
// var v interface{}
// v = &v
if v.Elem().Kind() == reflect.Interface && v.Elem().Elem() == v {
v = v.Elem()
break
}
if v.IsNil() {
v.Set(reflect.New(v.Type().Elem()))
}

if haveAddr {
v = v0 // restore original value after round-trip Value.Addr().Elem()
haveAddr = false
} else {
v = v.Elem()
}
}
return v
}

func storeLiteral(value reflect.Value, lvalue lua.LValue) {
value = indirect(value, false)
switch lvalue.Type() {
case lua.LTString:
value.SetString(lvalue.String())
case lua.LTNumber:
value.SetInt(int64(lvalue.(lua.LNumber)))
case lua.LTBool:
value.SetBool(bool(lvalue.(lua.LBool)))
}
}

func objectInterface(lvalue *lua.LTable) any {
var v = make(map[string]any)
lvalue.ForEach(func(key, value lua.LValue) {
v[key.String()] = valueInterface(value)
})
return v
}

func valueInterface(lvalue lua.LValue) any {
switch lvalue.Type() {
case lua.LTTable:
isArray := lvalue.(*lua.LTable).RawGetInt(0) != lua.LNil
if isArray {
return arrayInterface(lvalue.(*lua.LTable))
}
return objectInterface(lvalue.(*lua.LTable))
case lua.LTString:
return lvalue.String()
case lua.LTNumber:
return int(lvalue.(lua.LNumber))
case lua.LTBool:
return bool(lvalue.(lua.LBool))
}
return nil

}

func arrayInterface(lvalue *lua.LTable) any {
var v = make([]any, 0)
lvalue.ForEach(func(key, value lua.LValue) {
v = append(v, valueInterface(value))
})

return v
}

func unmarshalWorker(value lua.LValue, reflected reflect.Value) error {
logger.Debugf("reflected: %+v, value type: %s, kind: %s\n", reflected, value.Type(), reflected.Kind())

switch value.Type() {
case lua.LTTable:
reflected = indirect(reflected, false)
tagMap := make(map[string]int)

switch reflected.Kind() {
case reflect.Interface:
// Decoding into nil interface? Switch to non-reflect code.
if reflected.NumMethod() == 0 {
result := valueInterface(value)
reflected.Set(reflect.ValueOf(result))
}
// map[T1]T2 where T1 is string, an integer type
case reflect.Map:
t := reflected.Type()
keyType := t.Key()
// Map key must either have string kind, have an integer kind
switch keyType.Kind() {
case reflect.String,
reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
default:
return errors.New("luai: unsupported map key type " + keyType.String())
}

if reflected.IsNil() {
reflected.Set(reflect.MakeMap(t))
}

var mapElem reflect.Value

value.(*lua.LTable).ForEach(func(key, value lua.LValue) {
// Figure out field corresponding to key.
var subv reflect.Value

elemType := t.Elem()
if !mapElem.IsValid() {
mapElem = reflect.New(elemType).Elem()
} else {
mapElem.SetZero()
}

subv = mapElem

unmarshalWorker(value, subv)

var kv reflect.Value
switch keyType.Kind() {
case reflect.String:
kv = reflect.New(keyType).Elem()
kv.SetString(key.String())
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
s := key.String()
n, err := strconv.ParseInt(s, 10, 64)
if err != nil {
logger.Errorf("unmarshal: %v\n", err)
break
}
kv = reflect.New(keyType).Elem()
kv.SetInt(n)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
s := key.String()
n, err := strconv.ParseUint(s, 10, 64)
if err != nil {
logger.Errorf("unmarshal: %v\n", err)
break
}
kv = reflect.New(keyType).Elem()
kv.SetUint(n)
default:
panic("luai: Unexpected key type") // should never occur
}
if kv.IsValid() {
reflected.SetMapIndex(kv, subv)
}

})
case reflect.Slice:
i := 0

value.(*lua.LTable).ForEach(func(key, value lua.LValue) {
// Expand slice length, growing the slice if necessary.
if i >= reflected.Cap() {
reflected.Grow(1)
}
if i >= reflected.Len() {
reflected.SetLen(i + 1)
}
if i < reflected.Len() {
// Decode into element.
unmarshalWorker(value, reflected.Index(i))
} else {
unmarshalWorker(value, reflect.Value{})
}
i++
})

// Truncate slice if necessary.
if i < reflected.Len() {
reflected.SetLen(i)
}

if i == 0 {
reflected.Set(reflect.MakeSlice(reflected.Type(), 0, 0))
}
case reflect.Struct:
for i := 0; i < reflected.NumField(); i++ {
fieldTypeField := reflected.Type().Field(i)
tag := fieldTypeField.Tag.Get("luai")
if tag != "" {
tagMap[tag] = i
}
}
logger.Debugf("reflected: %+v, kind: %s, tagMap: %+v\n", reflected, reflected.Kind(), tagMap)

(value.(*lua.LTable)).ForEach(func(key, value lua.LValue) {
fieldName := key.String()
logger.Debugf("fieldName: %s, value type: %s\n", fieldName, value.Type())

switch reflected.Kind() {
case reflect.Struct:
field := reflected.FieldByName(fieldName)

// if field is not found, try to find it by tag
if !field.IsValid() {
fieldIndex, ok := tagMap[fieldName]
if !ok {
logger.Debugf("unmarshal: field %s not found in tagMap\n", fieldName)
return
}
field = reflected.Field(fieldIndex)
}

if !field.IsValid() {
logger.Debugf("unmarshal: field %s not found in struct\n", fieldName)
return
}

unmarshalWorker(value, field)
}
})
}
default:
switch reflected.Kind() {
case reflect.Interface:
// Decoding into nil interface? Switch to non-reflect code.
if reflected.NumMethod() == 0 {
result := valueInterface(value)
fmt.Printf("value: %v\n", value)
fmt.Printf("result: %v\n", result)
reflected.Set(reflect.ValueOf(result))
}
default:
storeLiteral(reflected, value)
}
}
return nil
}

func Unmarshal(value lua.LValue, v any) error {
reflected := reflect.ValueOf(v)

if reflected.Kind() != reflect.Pointer || reflected.IsNil() {
return errors.New("unmarshal: value must be a pointer")
}

return unmarshalWorker(value, reflected)
}

func applyValue(value lua.LValue, field reflect.Value) {
switch value.Type() {
case lua.LTString:
field.SetString(value.String())
case lua.LTNumber:
field.SetInt(int64(value.(lua.LNumber)))
case lua.LTBool:
field.SetBool(bool(value.(lua.LBool)))
case lua.LTTable:
Unmarshal(value.(*lua.LTable), field.Interface())
}
}
79 changes: 79 additions & 0 deletions internal/luai/encode.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
// Lua Interface
// Marshal and Unmarshal Lua Table to Go Struct

package luai

import (
"errors"
"reflect"

"github.com/version-fox/vfox/internal/logger"
lua "github.com/yuin/gopher-lua"
)

func Marshal(state *lua.LState, v any) (lua.LValue, error) {
reflected := reflect.ValueOf(v)
if reflected.Kind() == reflect.Ptr {
reflected = reflected.Elem()
}
logger.Debug(v, reflected.Kind())

switch reflected.Kind() {
case reflect.Struct:
table := state.NewTable()
for i := 0; i < reflected.NumField(); i++ {
field := reflected.Field(i)
fieldType := reflected.Type().Field(i)
if field.Kind() == reflect.Ptr {
field = field.Elem()
}

tag := fieldType.Tag.Get("luai")
if tag == "" {
tag = fieldType.Name
}

sub, err := Marshal(state, field.Interface())
if err != nil {
return nil, err
}
logger.Debugf("field: %v, tag: %v, sub: %v, kind: %s\n", field, tag, sub, field.Kind())
table.RawSetString(tag, sub)
}
return table, nil
case reflect.String:
return lua.LString(reflected.String()), nil
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return lua.LNumber(reflected.Int()), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return lua.LNumber(reflected.Uint()), nil
case reflect.Float32, reflect.Float64:
return lua.LNumber(reflected.Float()), nil
case reflect.Bool:
return lua.LBool(reflected.Bool()), nil
case reflect.Array, reflect.Slice:
table := state.NewTable()
for i := 0; i < reflected.Len(); i++ {
value, err := Marshal(state, reflected.Index(i).Interface())
if err != nil {
return nil, err
}
table.RawSetInt(i+1, value)
}
return table, nil
case reflect.Map:
table := state.NewTable()
for _, key := range reflected.MapKeys() {
value, err := Marshal(state, reflected.MapIndex(key).Interface())
if err != nil {
return nil, err
}

table.RawSetString(key.String(), value)
}
return table, nil
default:
return nil, errors.New("marshal: unsupported type " + reflected.Kind().String() + " for reflected ")
}

}
Loading

0 comments on commit a9f4f03

Please sign in to comment.