Skip to content

Commit

Permalink
fix: add state handler type validate (#86)
Browse files Browse the repository at this point in the history
* fix: add state handler type validate

* feat: update react & multi agent comment
  • Loading branch information
meguminnnnnnnnn authored Feb 27, 2025
1 parent b769ad2 commit 93cb521
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 0 deletions.
6 changes: 6 additions & 0 deletions compose/generic_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ package compose

import (
"context"
"reflect"

"github.com/cloudwego/eino/internal/generic"
)

type newGraphOptions struct {
withState func(ctx context.Context) any
stateType reflect.Type
}

type NewGraphOption func(ngo *newGraphOptions)
Expand All @@ -31,6 +35,7 @@ func WithGenLocalState[S any](gls GenLocalState[S]) NewGraphOption {
ngo.withState = func(ctx context.Context) any {
return gls(ctx)
}
ngo.stateType = generic.TypeOf[S]()
}
}

Expand Down Expand Up @@ -70,6 +75,7 @@ func NewGraph[I, O any](opts ...NewGraphOption) *Graph[I, O] {
newGraphFromGeneric[I, O](
ComponentOfGraph,
options.withState,
options.stateType,
),
}

Expand Down
33 changes: 33 additions & 0 deletions compose/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ type graph struct {
mappings []*FieldMapping
}

stateType reflect.Type
stateGenerator func(ctx context.Context) any

expectedInputType, expectedOutputType reflect.Type
Expand Down Expand Up @@ -187,12 +188,14 @@ type newGraphConfig struct {
inputFieldMappingConverter, outputFieldMappingConverter valueHandler
inputStreamFieldMappingConverter, outputStreamFieldMappingConverter streamHandler
cmp component
stateType reflect.Type
stateGenerator func(ctx context.Context) any
}

func newGraphFromGeneric[I, O any](
cmp component,
stateGenerator func(ctx context.Context) any,
stateType reflect.Type,
) *graph {
return newGraph(&newGraphConfig{
inputType: generic.TypeOf[I](),
Expand All @@ -207,6 +210,7 @@ func newGraphFromGeneric[I, O any](
inputStreamFieldMappingConverter: buildStreamFieldMappingConverter[I](),
outputStreamFieldMappingConverter: buildStreamFieldMappingConverter[O](),
cmp: cmp,
stateType: stateType,
stateGenerator: stateGenerator,
})
}
Expand Down Expand Up @@ -247,6 +251,7 @@ func newGraph(cfg *newGraphConfig) *graph {

cmp: cfg.cmp,

stateType: cfg.stateType,
stateGenerator: cfg.stateGenerator,
handlerOnEdges: make(map[string]map[string][]handlerPair),
handlerPreNode: make(map[string][]handlerPair),
Expand Down Expand Up @@ -306,6 +311,34 @@ func (g *graph) addNode(key string, node *graphNode, options *graphAddNodeOpts)
}
// end: check options

// check pre- / post-handler type
if options.processor != nil {
if options.processor.statePreHandler != nil {
// check state type
if g.stateType != options.processor.preStateType {
return fmt.Errorf("node[%s]'s pre handler state type[%v] is different from graph[%v]", key, options.processor.preStateType, g.stateType)
}
// check input type
if node.inputType() == nil && options.processor.statePreHandler.outputType != reflect.TypeOf((*any)(nil)).Elem() {
return fmt.Errorf("passthrough node[%s]'s pre handler type isn't any", key)
} else if node.inputType() != nil && node.inputType() != options.processor.statePreHandler.outputType {
return fmt.Errorf("node[%s]'s pre handler type[%v] is different from its input type[%v]", key, options.processor.statePreHandler.outputType, node.inputType())
}
}
if options.processor.statePostHandler != nil {
// check state type
if g.stateType != options.processor.postStateType {
return fmt.Errorf("node[%s]'s post handler state type[%v] is different from graph[%v]", key, options.processor.postStateType, g.stateType)
}
// check input type
if node.outputType() == nil && options.processor.statePostHandler.inputType != reflect.TypeOf((*any)(nil)).Elem() {
return fmt.Errorf("passthrough node[%s]'s post handler type isn't any", key)
} else if node.outputType() != nil && node.outputType() != options.processor.statePostHandler.inputType {
return fmt.Errorf("node[%s]'s post handler type[%v] is different from its output type[%v]", key, options.processor.statePostHandler.inputType, node.outputType())
}
}
}

g.nodes[key] = node

return nil
Expand Down
12 changes: 12 additions & 0 deletions compose/graph_add_node_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@

package compose

import (
"reflect"

"github.com/cloudwego/eino/internal/generic"
)

type graphAddNodeOpts struct {
nodeOptions *nodeOptions
processor *processorOpts
Expand Down Expand Up @@ -90,6 +96,7 @@ func WithGraphCompileOptions(opts ...GraphCompileOption) GraphAddNodeOpt {
func WithStatePreHandler[I, S any](pre StatePreHandler[I, S]) GraphAddNodeOpt {
return func(o *graphAddNodeOpts) {
o.processor.statePreHandler = convertPreHandler(pre)
o.processor.preStateType = generic.TypeOf[S]()
o.needState = true
}
}
Expand All @@ -101,6 +108,7 @@ func WithStatePreHandler[I, S any](pre StatePreHandler[I, S]) GraphAddNodeOpt {
func WithStatePostHandler[O, S any](post StatePostHandler[O, S]) GraphAddNodeOpt {
return func(o *graphAddNodeOpts) {
o.processor.statePostHandler = convertPostHandler(post)
o.processor.postStateType = generic.TypeOf[S]()
o.needState = true
}
}
Expand All @@ -114,6 +122,7 @@ func WithStatePostHandler[O, S any](post StatePostHandler[O, S]) GraphAddNodeOpt
func WithStreamStatePreHandler[I, S any](pre StreamStatePreHandler[I, S]) GraphAddNodeOpt {
return func(o *graphAddNodeOpts) {
o.processor.statePreHandler = streamConvertPreHandler(pre)
o.processor.preStateType = generic.TypeOf[S]()
o.needState = true
}
}
Expand All @@ -127,13 +136,16 @@ func WithStreamStatePreHandler[I, S any](pre StreamStatePreHandler[I, S]) GraphA
func WithStreamStatePostHandler[O, S any](post StreamStatePostHandler[O, S]) GraphAddNodeOpt {
return func(o *graphAddNodeOpts) {
o.processor.statePostHandler = streamConvertPostHandler(post)
o.processor.postStateType = generic.TypeOf[S]()
o.needState = true
}
}

type processorOpts struct {
statePreHandler *composableRunnable
preStateType reflect.Type // used for type validation
statePostHandler *composableRunnable
postStateType reflect.Type // used for type validation
}

func getGraphAddNodeOpts(opts ...GraphAddNodeOpt) *graphAddNodeOpts {
Expand Down
110 changes: 110 additions & 0 deletions compose/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1591,3 +1591,113 @@ func TestNestedDAGBranch(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, "hellohello24", result)
}

func TestHandlerTypeValidate(t *testing.T) {
g := NewGraph[string, string](WithGenLocalState(func(ctx context.Context) (state string) {
return ""
}))
// passthrough pre fail
err := g.AddPassthroughNode("1", WithStatePreHandler(func(ctx context.Context, in string, state string) (string, error) {
return "", nil
}))
assert.ErrorContains(t, err, "passthrough node[1]'s pre handler type isn't any")
g.buildError = nil
// passthrough pre fail with input key
err = g.AddPassthroughNode("1", WithStatePreHandler(func(ctx context.Context, in string, state string) (string, error) {
return "", nil
}), WithInputKey("input"))
assert.ErrorContains(t, err, "node[1]'s pre handler type[string] is different from its input type[map[string]interface {}]")
g.buildError = nil
// passthrough post fail
err = g.AddPassthroughNode("1", WithStatePostHandler(func(ctx context.Context, in string, state string) (string, error) {
return "", nil
}))
assert.ErrorContains(t, err, "passthrough node[1]'s post handler type isn't any")
g.buildError = nil
// passthrough post fail with input key
err = g.AddPassthroughNode("1", WithStatePostHandler(func(ctx context.Context, in string, state string) (string, error) {
return "", nil
}), WithInputKey("input"))
assert.ErrorContains(t, err, "passthrough node[1]'s post handler type isn't any")
g.buildError = nil
// passthrough pre success
err = g.AddPassthroughNode("1", WithStatePreHandler(func(ctx context.Context, in any, state string) (any, error) {
return "", nil
}))
assert.NoError(t, err)
// passthrough pre success with input key
err = g.AddPassthroughNode("2", WithStatePreHandler(func(ctx context.Context, in map[string]any, state string) (map[string]any, error) {
return nil, nil
}), WithInputKey("input"))
assert.NoError(t, err)
// passthrough post success
err = g.AddPassthroughNode("3", WithStatePostHandler(func(ctx context.Context, in any, state string) (any, error) {
return "", nil
}))
assert.NoError(t, err)
// passthrough post success with output key
err = g.AddPassthroughNode("4", WithStatePostHandler(func(ctx context.Context, in map[string]any, state string) (map[string]any, error) {
return nil, nil
}), WithOutputKey("output"))
assert.NoError(t, err)
// common node pre fail
err = g.AddLambdaNode("5", InvokableLambda(func(ctx context.Context, input int) (output int, err error) {
return 0, nil
}), WithStatePreHandler(func(ctx context.Context, in string, state string) (string, error) {
return "", nil
}))
assert.ErrorContains(t, err, "node[5]'s pre handler type[string] is different from its input type[int]")
g.buildError = nil
// common node post fail
err = g.AddLambdaNode("5", InvokableLambda(func(ctx context.Context, input int) (output int, err error) {
return 0, nil
}), WithStatePostHandler(func(ctx context.Context, in string, state string) (string, error) {
return "", nil
}))
assert.ErrorContains(t, err, "node[5]'s post handler type[string] is different from its output type[int]")
g.buildError = nil
// common node pre success
err = g.AddLambdaNode("5", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
return "", nil
}), WithStatePreHandler(func(ctx context.Context, in string, state string) (string, error) {
return "", nil
}))
assert.NoError(t, err)
// common node post success
err = g.AddLambdaNode("6", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
return "", nil
}), WithStatePostHandler(func(ctx context.Context, in string, state string) (string, error) {
return "", nil
}))
assert.NoError(t, err)
// pre state fail
err = g.AddLambdaNode("7", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
return "", nil
}), WithStatePreHandler(func(ctx context.Context, in string, state int) (string, error) {
return "", nil
}))
assert.ErrorContains(t, err, "node[7]'s pre handler state type[int] is different from graph[string]")
g.buildError = nil
// post state fail
err = g.AddLambdaNode("7", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
return "", nil
}), WithStatePostHandler(func(ctx context.Context, in string, state int) (string, error) {
return "", nil
}))
assert.ErrorContains(t, err, "node[7]'s post handler state type[int] is different from graph[string]")
g.buildError = nil
// common pre success with input key
err = g.AddLambdaNode("7", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
return "", nil
}), WithStatePreHandler(func(ctx context.Context, in map[string]any, state string) (map[string]any, error) {
return nil, nil
}), WithInputKey("input"))
assert.NoError(t, err)
// common post success with output key
err = g.AddLambdaNode("8", InvokableLambda(func(ctx context.Context, input string) (output string, err error) {
return "", nil
}), WithStatePostHandler(func(ctx context.Context, in map[string]any, state string) (map[string]any, error) {
return nil, nil
}), WithOutputKey("output"))
assert.NoError(t, err)
}
1 change: 1 addition & 0 deletions compose/workflow.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ func NewWorkflow[I, O any](opts ...NewGraphOption) *Workflow[I, O] {
g: newGraphFromGeneric[I, O](
ComponentOfWorkflow,
options.withState,
options.stateType,
),
}

Expand Down
1 change: 1 addition & 0 deletions flow/agent/multiagent/host/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ type MultiAgentConfig struct {
// Note: The handler MUST close the modelOutput stream before returning
// Optional. By default, it checks if the first chunk contains tool calls.
// Note: The default implementation does not work well with Claude, which typically outputs tool calls after text content.
// Note: If your ChatModel doesn't output tool calls first, you can try adding prompts to constrain the model from generating extra text during the tool call.
StreamToolCallChecker func(ctx context.Context, modelOutput *schema.StreamReader[*schema.Message]) (bool, error)
}

Expand Down
1 change: 1 addition & 0 deletions flow/agent/react/react.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type AgentConfig struct {
// Note: The handler MUST close the modelOutput stream before returning
// Optional. By default, it checks if the first chunk contains tool calls.
// Note: The default implementation does not work well with Claude, which typically outputs tool calls after text content.
// Note: If your ChatModel doesn't output tool calls first, you can try adding prompts to constrain the model from generating extra text during the tool call.
StreamToolCallChecker func(ctx context.Context, modelOutput *schema.StreamReader[*schema.Message]) (bool, error)
}

Expand Down

0 comments on commit 93cb521

Please sign in to comment.