Skip to content

Commit

Permalink
fix: move concat & concat message extra by ConcatItems (#66)
Browse files Browse the repository at this point in the history
  • Loading branch information
meguminnnnnnnnn authored Feb 13, 2025
1 parent 1345294 commit 7b64b27
Show file tree
Hide file tree
Showing 7 changed files with 330 additions and 274 deletions.
2 changes: 1 addition & 1 deletion .testcoverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ local-prefix: "github.com/cloudwego/eino"
threshold:
# (optional; default 0)
# Minimum overall project coverage percentage required.
total: 75
total: 83

package: 30

Expand Down
270 changes: 3 additions & 267 deletions compose/stream_concat.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,77 +19,11 @@ package compose
import (
"fmt"
"io"
"reflect"
"strings"

"github.com/cloudwego/eino/internal"
"github.com/cloudwego/eino/schema"
"github.com/cloudwego/eino/utils/generic"
)

// concatFuncs maps a chunk type to the function that concats a slice of chunks
// of that type. Values are stored as any and are expected to have the shape
// func([]T) (T, error); getConcatFunc invokes them through reflection.
var (
concatFuncs = map[reflect.Type]any{
// Built-in concat functions for the types handled out of the box.
generic.TypeOf[*schema.Message](): schema.ConcatMessages,
generic.TypeOf[string](): concatStrings,
generic.TypeOf[[]*schema.Message](): concatMessageArray,
}
)

// concatStrings joins all chunks, in order, into a single string.
//
// The returned error is always nil; it is part of the signature so that the
// function has the func([]T) (T, error) shape of a concat function.
func concatStrings(ss []string) (string, error) {
	return strings.Join(ss, ""), nil
}

// concatMessageArray merges several equally sized message slices into one
// slice of the same length, working position by position.
//
// For each index i, the non-nil messages found at position i across all input
// slices are collected in input order:
//   - none found: the result at i is nil;
//   - exactly one found: that message is used as is;
//   - several found: they are combined with schema.ConcatMessages.
//
// An error is returned when the input slices do not all have the same length
// as the first one, or when schema.ConcatMessages fails. An empty input yields
// (nil, nil) instead of panicking.
func concatMessageArray(mas [][]*schema.Message) ([]*schema.Message, error) {
	if len(mas) == 0 {
		// Nothing to merge; indexing mas[0] below would panic.
		return nil, nil
	}

	arrayLen := len(mas[0])
	slicesToConcat := make([][]*schema.Message, arrayLen)

	for _, ma := range mas {
		if len(ma) != arrayLen {
			return nil, fmt.Errorf("unexpected array length. "+
				"Got %d, expected %d", len(ma), arrayLen)
		}

		// Group the non-nil messages by their position.
		for i, m := range ma {
			if m != nil {
				slicesToConcat[i] = append(slicesToConcat[i], m)
			}
		}
	}

	ret := make([]*schema.Message, arrayLen)
	for i, slice := range slicesToConcat {
		switch len(slice) {
		case 0:
			// No message at this position: ret[i] stays nil.
		case 1:
			ret[i] = slice[0]
		default:
			cm, err := schema.ConcatMessages(slice)
			if err != nil {
				return nil, err
			}

			ret[i] = cm
		}
	}

	return ret, nil
}

// RegisterStreamChunkConcatFunc registers a function to concat stream chunks.
// It's required when you want to concat stream chunks of a specific type.
// For example, when you call Invoke() but the node only implements Stream().
Expand All @@ -109,41 +43,7 @@ func concatMessageArray(mas [][]*schema.Message) ([]*schema.Message, error) {
// }, nil
// })
func RegisterStreamChunkConcatFunc[T any](fn func([]T) (T, error)) {
// Store fn under T's type key; registering again for the same T replaces the
// previously stored function.
// NOTE(review): the map is written without any locking visible here;
// presumably this is meant to be called during initialization only — confirm.
concatFuncs[generic.TypeOf[T]()] = fn
}

// getConcatFunc looks up the concat function stored in concatFuncs for tpe and
// wraps it so it can be called with reflect values.
//
// The returned closure takes a reflect.Value holding the slice of chunks,
// calls the stored function through reflection, and returns its first result
// together with its second result converted to an error (nil when that result
// is nil). When nothing is stored for tpe, getConcatFunc returns nil.
func getConcatFunc(tpe reflect.Type) func(reflect.Value) (reflect.Value, error) {
	registered, found := concatFuncs[tpe]
	if !found {
		return nil
	}

	callable := reflect.ValueOf(registered)
	return func(a reflect.Value) (reflect.Value, error) {
		out := callable.Call([]reflect.Value{a})
		if errVal := out[1]; !errVal.IsNil() {
			return out[0], errVal.Interface().(error)
		}
		return out[0], nil
	}
}

// toSliceValue converts vs into a reflect.Value holding a typed slice whose
// element type is the dynamic type of vs[0].
//
// Every element must have exactly the same dynamic type as the first one;
// otherwise an error naming the offending ("Got") and the required
// ("expected") type is returned. An empty vs or a nil first element also
// yields an error, because no element type can be derived from them.
func toSliceValue(vs []any) (reflect.Value, error) {
	if len(vs) == 0 {
		return reflect.Value{}, fmt.Errorf("cannot build a typed slice from zero elements")
	}

	typ := reflect.TypeOf(vs[0])
	if typ == nil {
		// reflect.SliceOf would panic on a nil type.
		return reflect.Value{}, fmt.Errorf("cannot build a typed slice: first element is nil")
	}

	ret := reflect.MakeSlice(reflect.SliceOf(typ), len(vs), len(vs))
	for i, v := range vs {
		if vt := reflect.TypeOf(v); vt != typ {
			// vt is what we got, typ is what the first element requires.
			return reflect.Value{}, fmt.Errorf("unexpected slice element type. Got %v, expected %v", vt, typ)
		}

		ret.Index(i).Set(reflect.ValueOf(v))
	}

	return ret, nil
}
internal.RegisterStreamChunkConcatFunc(fn)
}

func concatStreamReader[T any](sr *schema.StreamReader[T]) (T, error) {
Expand Down Expand Up @@ -174,174 +74,10 @@ func concatStreamReader[T any](sr *schema.StreamReader[T]) (T, error) {
return items[0], nil
}

res, err := concatItems(items)
res, err := internal.ConcatItems(items)
if err != nil {
var t T
return t, err
}
return res, nil
}

// the caller should ensure len(items) > 1
func concatItems[T any](items []T) (T, error) {
typ := generic.TypeOf[T]()
v := reflect.ValueOf(items)

var cv reflect.Value
var err error

// handle map kind
if typ.Kind() == reflect.Map {
cv, err = concatMaps(v)
} else {
cv, err = concatSliceValue(v)
}

if err != nil {
var t T
return t, err
}

return cv.Interface().(T), nil
}

// concatMaps merges a slice of maps (passed as a reflect.Value) into one map of
// the same map type.
//
// All values found under the same key are gathered in input order, turned into
// a typed slice with toSliceValue, and then merged: recursively with concatMaps
// when the values are themselves maps, otherwise with concatSliceValue. The
// first error from toSliceValue or from the merge aborts the operation.
func concatMaps(ms reflect.Value) (reflect.Value, error) {
	mapType := ms.Type().Elem()

	// grouped holds, per key, every value seen for that key as a []any.
	grouped := reflect.MakeMap(reflect.MapOf(mapType.Key(), generic.TypeOf[[]any]()))
	for i, total := 0, ms.Len(); i < total; i++ {
		for it := ms.Index(i).MapRange(); it.Next(); {
			bucket := grouped.MapIndex(it.Key())
			if !bucket.IsValid() {
				// First value for this key: start from an empty []any.
				bucket = reflect.ValueOf([]any(nil))
			}

			grouped.SetMapIndex(it.Key(), reflect.Append(bucket, it.Value()))
		}
	}

	merged := reflect.MakeMap(mapType)
	for it := grouped.MapRange(); it.Next(); {
		typed, err := toSliceValue(it.Value().Interface().([]any))
		if err != nil {
			return reflect.Value{}, err
		}

		concat := concatSliceValue
		if typed.Type().Elem().Kind() == reflect.Map {
			concat = concatMaps
		}

		combined, err := concat(typed)
		if err != nil {
			return reflect.Value{}, err
		}

		merged.SetMapIndex(it.Key(), combined)
	}

	return merged, nil
}

// concatSliceValue merges the elements of the slice held by val into a single
// value of the slice's element type. The strategies are tried in this order:
//
//  1. A slice of length one yields its only element.
//  2. If getConcatFunc has a function for the element type, it is used.
//  3. For struct and pointer-to-struct elements, each element is converted to
//     a field map with structToMap, the maps are merged with concatMaps, and
//     the result is turned back into a struct (or pointer) with mapToStruct.
//  4. Otherwise at most one element may be non-zero; that element is the
//     result. More than one non-zero element is an error.
//
// When every element is zero, the zero value of the element type is returned.
// Returning an invalid reflect.Value here would make callers panic in
// Interface() or silently drop the key in SetMapIndex.
func concatSliceValue(val reflect.Value) (reflect.Value, error) {
	elmType := val.Type().Elem()

	if val.Len() == 1 {
		return val.Index(0), nil
	}

	if f := getConcatFunc(elmType); f != nil {
		return f(val)
	}

	var (
		structType  reflect.Type
		isStructPtr bool
	)

	switch {
	case elmType.Kind() == reflect.Struct:
		structType = elmType
	case elmType.Kind() == reflect.Pointer && elmType.Elem().Kind() == reflect.Struct:
		isStructPtr = true
		structType = elmType.Elem()
	}

	if structType != nil {
		maps := make([]map[string]any, 0, val.Len())
		for i := 0; i < val.Len(); i++ {
			m, err := structToMap(val.Index(i))
			if err != nil {
				return reflect.Value{}, err
			}

			maps = append(maps, m)
		}

		result, err := concatMaps(reflect.ValueOf(maps))
		if err != nil {
			return reflect.Value{}, err
		}

		return mapToStruct(result.Interface().(map[string]any), structType, isStructPtr), nil
	}

	var filtered reflect.Value
	for i := 0; i < val.Len(); i++ {
		oneVal := val.Index(i)
		if oneVal.IsZero() {
			continue
		}

		if filtered.IsValid() {
			return reflect.Value{}, fmt.Errorf("cannot concat multiple non-zero value of type %s", elmType)
		}

		filtered = oneVal
	}

	if !filtered.IsValid() {
		// Every element was zero: the merged value is the zero value.
		filtered = reflect.Zero(elmType)
	}

	return filtered, nil
}

// structToMap converts a struct, or a pointer to a struct, held in s into a map
// from field name to field value.
//
// An error is returned when the struct has an unexported field, or when s is a
// nil pointer (dereferencing it would otherwise panic in NumField).
func structToMap(s reflect.Value) (map[string]any, error) {
	if s.Kind() == reflect.Ptr {
		if s.IsNil() {
			return nil, fmt.Errorf("structToMap: nil pointer of type %s", s.Type())
		}

		s = s.Elem()
	}

	structType := s.Type()
	ret := make(map[string]any, s.NumField())
	for i := 0; i < s.NumField(); i++ {
		field := structType.Field(i)
		if !field.IsExported() {
			return nil, fmt.Errorf("structToMap: field %s is not exported", field.Name)
		}

		ret[field.Name] = s.Field(i).Interface()
	}

	return ret, nil
}

// mapToStruct builds a new value of struct type t and sets each field named by
// a key of m to the corresponding map value.
//
// When toPtr is true the result holds a pointer to the new struct; otherwise it
// holds the struct itself.
func mapToStruct(m map[string]any, t reflect.Type, toPtr bool) reflect.Value {
	ptr := reflect.New(t)
	target := ptr.Elem()

	for name, value := range m {
		target.FieldByName(name).Set(reflect.ValueOf(value))
	}

	if toPtr {
		return ptr
	}

	return target
}
23 changes: 20 additions & 3 deletions compose/stream_concat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import (

"github.com/stretchr/testify/assert"

"github.com/cloudwego/eino/internal"
"github.com/cloudwego/eino/schema"
)

Expand Down Expand Up @@ -112,7 +113,7 @@ func TestMessageConcat(t *testing.T) {
assert.Equal(t, "0123456789", lastVal.Content)
assert.Len(t, lastVal.Extra, 4)
assert.Equal(t, map[string]any{
"key_1": "8",
"key_1": "048",
"0": "0",
"4": "4",
"8": "8",
Expand Down Expand Up @@ -193,14 +194,30 @@ func TestConcatError(t *testing.T) {
"str": "string_02",
"x": 123,
}
_, err := concatItems([]map[string]any{a, b})
_, err := internal.ConcatItems([]map[string]any{a, b})
assert.NotNil(t, err)
})

t.Run("merge error", func(t *testing.T) {
RegisterStreamChunkConcatFunc(concatTStreamError)

_, err := concatItems([]tConcatErrForTest{{}, {}})
_, err := internal.ConcatItems([]tConcatErrForTest{{}, {}})
assert.NotNil(t, err)
})
}

// TestConcatSliceValue checks internal.ConcatItems on a slice of plain structs:
// one populated element among empty ones, and only empty elements.
func TestConcatSliceValue(t *testing.T) {
	type testStruct struct {
		A string
	}

	cases := []struct {
		items    []testStruct
		expected testStruct
	}{
		{
			items:    []testStruct{{}, {A: "123"}, {}},
			expected: testStruct{A: "123"},
		},
		{
			items:    []testStruct{{}, {}, {}},
			expected: testStruct{},
		},
	}

	for _, c := range cases {
		result, err := internal.ConcatItems(c.items)
		assert.Nil(t, err)
		assert.Equal(t, c.expected, result)
	}
}
3 changes: 2 additions & 1 deletion compose/tool_node_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/cloudwego/eino/components/model"
"github.com/cloudwego/eino/components/tool"
"github.com/cloudwego/eino/components/tool/utils"
"github.com/cloudwego/eino/internal"
"github.com/cloudwego/eino/schema"
)

Expand Down Expand Up @@ -410,7 +411,7 @@ func TestToolsNodeOptions(t *testing.T) {
}
outStream.Close()

msgs, err := concatMessageArray(outMessages)
msgs, err := internal.ConcatItems(outMessages)
assert.NoError(t, err)

assert.Len(t, msgs, 1)
Expand Down
11 changes: 11 additions & 0 deletions compose/workflow_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,17 @@ func TestWorkflow(t *testing.T) {
B int
StateTemp string
}
RegisterStreamChunkConcatFunc(func(ts []*structF) (*structF, error) {
ret := &structF{}
for _, tt := range ts {
ret.Field1 += tt.Field1
ret.Field2 += tt.Field2
ret.Field3 = append(ret.Field3, tt.Field3...)
ret.B += tt.B
ret.StateTemp += tt.StateTemp
}
return ret, nil
})

type state struct {
temp string
Expand Down
Loading

0 comments on commit 7b64b27

Please sign in to comment.