Skip to content

Commit

Permalink
feat: react agent and host multi-agent can be composed directly as An…
Browse files Browse the repository at this point in the history
…yGraph (#72)

feat: export agent to graph

Change-Id: I3266a9f116ae1767b9a6293f02f1716d80bff03f
  • Loading branch information
shentongmartin authored Feb 19, 2025
1 parent 648ec7c commit 4bfe584
Show file tree
Hide file tree
Showing 9 changed files with 148 additions and 40 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/pr-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ jobs:

golangci-lint:
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
repository-projects: write
steps:
- uses: actions/checkout@v4
- name: Set up Go
Expand Down
16 changes: 16 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ jobs:
unit-test:
name: eino-unit-test
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
repository-projects: write
env:
COVERAGE_FILE: coverage.out
BREAKDOWN_FILE: main.breakdown
Expand Down Expand Up @@ -100,6 +104,10 @@ jobs:
run: echo "coverage check failed" && exit 1
benchmark-test:
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
repository-projects: write
steps:
- uses: actions/checkout@v4
- name: Set up Go
Expand All @@ -115,6 +123,10 @@ jobs:
matrix:
go: [ "1.19", "1.20", "1.21", "1.22" ]
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
repository-projects: write
steps:
- uses: actions/checkout@v4
- name: Set up Go
Expand All @@ -131,6 +143,10 @@ jobs:
api-compatibility:
name: api-compatibility-check
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
repository-projects: write
if: github.event_name == 'pull_request'

steps:
Expand Down
5 changes: 3 additions & 2 deletions .golangci.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# output configuration options
output:
# Format: colored-line-number|line-number|json|tab|checkstyle|code-climate|junit-xml|github-actions
formats: colored-line-number
formats:
- format: colored-line-number
# All available settings of specific linters.
# Refer to https://golangci-lint.run/usage/linters
linters-settings:
Expand Down Expand Up @@ -31,4 +32,4 @@ issues:
exclude-use-default: true
exclude-files:
- ".*\\.mock\\.go$"
exclude-dirs:
# exclude-dirs:
28 changes: 16 additions & 12 deletions flow/agent/multiagent/host/callback.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (
)

// MultiAgentCallback is the callback interface for host multi-agent.
type MultiAgentCallback interface { // nolint: byted_s_interface_name
type MultiAgentCallback interface {
OnHandOff(ctx context.Context, info *HandOffInfo) context.Context
}

Expand All @@ -38,13 +38,8 @@ type HandOffInfo struct {
Argument string
}

// convertCallbacks reads graph call options, extract host.MultiAgentCallback and convert it to callbacks.Handler.
func convertCallbacks(opts ...agent.AgentOption) callbacks.Handler {
agentOptions := agent.GetImplSpecificOptions(&options{}, opts...)
if len(agentOptions.agentCallbacks) == 0 {
return nil
}

// ConvertCallbackHandlers converts []host.MultiAgentCallback to callbacks.Handler.
func ConvertCallbackHandlers(handlers ...MultiAgentCallback) callbacks.Handler {
onChatModelEnd := func(ctx context.Context, info *callbacks.RunInfo, output *model.CallbackOutput) context.Context {
if output == nil || info == nil {
return ctx
Expand All @@ -58,8 +53,7 @@ func convertCallbacks(opts ...agent.AgentOption) callbacks.Handler {
agentName := msg.ToolCalls[0].Function.Name
argument := msg.ToolCalls[0].Function.Arguments

for i := range agentOptions.agentCallbacks {
cb := agentOptions.agentCallbacks[i]
for _, cb := range handlers {
ctx = cb.OnHandOff(ctx, &HandOffInfo{
ToAgentName: agentName,
Argument: argument,
Expand Down Expand Up @@ -103,8 +97,7 @@ func convertCallbacks(opts ...agent.AgentOption) callbacks.Handler {
return ctx
}

for i := range agentOptions.agentCallbacks {
cb := agentOptions.agentCallbacks[i]
for _, cb := range handlers {
ctx = cb.OnHandOff(ctx, &HandOffInfo{
ToAgentName: msg.ToolCalls[0].Function.Name,
Argument: msg.ToolCalls[0].Function.Arguments,
Expand All @@ -119,3 +112,14 @@ func convertCallbacks(opts ...agent.AgentOption) callbacks.Handler {
OnEndWithStreamOutput: onChatModelEndWithStreamOutput,
}).Handler()
}

// convertCallbacks reads graph call options, extract host.MultiAgentCallback and convert it to callbacks.Handler.
func convertCallbacks(opts ...agent.AgentOption) callbacks.Handler {
agentOptions := agent.GetImplSpecificOptions(&options{}, opts...)
if len(agentOptions.agentCallbacks) == 0 {
return nil
}

handlers := agentOptions.agentCallbacks
return ConvertCallbackHandlers(handlers...)
}
17 changes: 10 additions & 7 deletions flow/agent/multiagent/host/compose.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import (
)

const (
hostName = "host"
defaultHostPrompt = "decide which tool is best for the task and call only the best tool."
defaultHostNodeKey = "host" // the key of the host node in the graph
defaultHostPrompt = "decide which tool is best for the task and call only the best tool."
)

type state struct {
Expand Down Expand Up @@ -105,13 +105,16 @@ func NewMultiAgent(ctx context.Context, config *MultiAgentConfig) (*MultiAgent,
return nil, err
}

r, err := g.Compile(ctx, compose.WithNodeTriggerMode(compose.AnyPredecessor), compose.WithGraphName(name))
compileOpts := []compose.GraphCompileOption{compose.WithNodeTriggerMode(compose.AnyPredecessor), compose.WithGraphName(name)}
r, err := g.Compile(ctx, compileOpts...)
if err != nil {
return nil, err
}

return &MultiAgent{
runnable: r,
runnable: r,
graph: g,
graphAddNodeOpts: []compose.GraphAddNodeOpt{compose.WithGraphCompileOptions(compileOpts...)},
}, nil
}

Expand Down Expand Up @@ -161,11 +164,11 @@ func addHostAgent(model model.ChatModel, prompt string, agentTools []*schema.Too
Content: prompt,
}}, input...), nil
}
if err := g.AddChatModelNode(hostName, model, compose.WithStatePreHandler(preHandler), compose.WithNodeName(hostName)); err != nil {
if err := g.AddChatModelNode(defaultHostNodeKey, model, compose.WithStatePreHandler(preHandler), compose.WithNodeName(defaultHostNodeKey)); err != nil {
return err
}

return g.AddEdge(compose.START, hostName)
return g.AddEdge(compose.START, defaultHostNodeKey)
}

func addDirectAnswerBranch(convertorName string, g *compose.Graph[[]*schema.Message, *schema.Message],
Expand All @@ -182,7 +185,7 @@ func addDirectAnswerBranch(convertorName string, g *compose.Graph[[]*schema.Mess
return compose.END, nil
}, map[string]bool{convertorName: true, compose.END: true})

return g.AddBranch(hostName, branch)
return g.AddBranch(defaultHostNodeKey, branch)
}

func addSpecialistsBranch(convertorName string, agentMap map[string]bool, g *compose.Graph[[]*schema.Message, *schema.Message]) error {
Expand Down
57 changes: 57 additions & 0 deletions flow/agent/multiagent/host/compose_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"

"github.com/cloudwego/eino/components/prompt"
"github.com/cloudwego/eino/compose"
"github.com/cloudwego/eino/flow/agent"
"github.com/cloudwego/eino/internal/mock/components/model"
"github.com/cloudwego/eino/schema"
Expand Down Expand Up @@ -362,6 +364,61 @@ func TestHostMultiAgent(t *testing.T) {
},
}, mockCallback.infos)
})

t.Run("multi-agent within graph", func(t *testing.T) {
handOffMsg := &schema.Message{
Role: schema.Assistant,
ToolCalls: []schema.ToolCall{
{
Index: generic.PtrOf(0),
Function: schema.FunctionCall{
Name: specialist1.Name,
Arguments: `{"reason": "specialist 1 is the best"}`,
},
},
},
}

specialistMsg := &schema.Message{
Role: schema.Assistant,
Content: "Beijing",
}

mockHostLLM.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(handOffMsg, nil).Times(1)
mockSpecialistLLM1.EXPECT().Generate(gomock.Any(), gomock.Any()).Return(specialistMsg, nil).Times(1)

mockCallback := &mockAgentCallback{}

hostMA, err := NewMultiAgent(ctx, &MultiAgentConfig{
Host: Host{
ChatModel: mockHostLLM,
},
Specialists: []*Specialist{
specialist1,
specialist2,
},
})

assert.NoError(t, err)

maGraph, opts := hostMA.ExportGraph()

fullGraph, err := compose.NewChain[map[string]any, *schema.Message]().
AppendChatTemplate(prompt.FromMessages(schema.FString, schema.UserMessage("what's the capital city of {country_name}"))).
AppendGraph(maGraph, append(opts, compose.WithNodeKey("host_ma_node"))...).
Compile(ctx)
assert.NoError(t, err)

out, err := fullGraph.Invoke(ctx, map[string]any{"country_name": "China"}, compose.WithCallbacks(ConvertCallbackHandlers(mockCallback)).DesignateNodeWithPath(compose.NewNodePath("host_ma_node", hostMA.HostNodeKey())))
assert.NoError(t, err)
assert.Equal(t, "Beijing", out.Content)
assert.Equal(t, []*HandOffInfo{
{
ToAgentName: specialist1.Name,
Argument: `{"reason": "specialist 1 is the best"}`,
},
}, mockCallback.infos)
})
}

type mockAgentCallback struct {
Expand Down
23 changes: 17 additions & 6 deletions flow/agent/multiagent/host/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,17 @@ import (
// A host agent is responsible for deciding which specialist to 'hand off' the task to.
// One or more specialist agents are responsible for completing the task.
type MultiAgent struct {
runnable compose.Runnable[[]*schema.Message, *schema.Message]
runnable compose.Runnable[[]*schema.Message, *schema.Message]
graph *compose.Graph[[]*schema.Message, *schema.Message]
graphAddNodeOpts []compose.GraphAddNodeOpt
}

func (ma *MultiAgent) Generate(ctx context.Context, input []*schema.Message, opts ...agent.AgentOption) (*schema.Message, error) {
composeOptions := agent.GetComposeOptions(opts...)

handler := convertCallbacks(opts...)
if handler != nil {
composeOptions = append(composeOptions, compose.WithCallbacks(handler).DesignateNode(hostName))
composeOptions = append(composeOptions, compose.WithCallbacks(handler).DesignateNode(ma.HostNodeKey()))
}

return ma.runnable.Invoke(ctx, input, composeOptions...)
Expand All @@ -51,12 +53,21 @@ func (ma *MultiAgent) Stream(ctx context.Context, input []*schema.Message, opts

handler := convertCallbacks(opts...)
if handler != nil {
composeOptions = append(composeOptions, compose.WithCallbacks(handler).DesignateNode(hostName))
composeOptions = append(composeOptions, compose.WithCallbacks(handler).DesignateNode(ma.HostNodeKey()))
}

return ma.runnable.Stream(ctx, input, composeOptions...)
}

// ExportGraph exports the underlying graph from MultiAgent, along with the []compose.GraphAddNodeOpt to be used when adding this graph to another graph.
func (ma *MultiAgent) ExportGraph() (compose.AnyGraph, []compose.GraphAddNodeOpt) {
return ma.graph, ma.graphAddNodeOpts
}

func (ma *MultiAgent) HostNodeKey() string {
return defaultHostNodeKey
}

// MultiAgentConfig is the config for host multi-agent system.
type MultiAgentConfig struct {
Host Host
Expand Down Expand Up @@ -133,9 +144,9 @@ type Host struct {
// Specialist is a specialist agent within a host multi-agent system.
// It can be a model.ChatModel or any Invokable and/or Streamable, such as react.Agent.
// ChatModel and (Invokable / Streamable) are mutually exclusive, only one should be provided.
// If Invokable is provided but not Streamable, then the Specialist will be compose.InvokableLambda.
// If Streamable is provided but not Invokable, then the Specialist will be compose.StreamableLambda.
// if Both Invokable and Streamable is provided, then the Specialist will be compose.AnyLambda.
// If Invokable is provided but not Streamable, then the Specialist will be 'compose.InvokableLambda'.
// If Streamable is provided but not Invokable, then the Specialist will be 'compose.StreamableLambda'.
// if Both Invokable and Streamable is provided, then the Specialist will be 'compose.AnyLambda'.
type Specialist struct {
AgentMeta

Expand Down
18 changes: 15 additions & 3 deletions flow/agent/react/react.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ func firstChunkStreamToolCallChecker(_ context.Context, sr *schema.StreamReader[
// if err != nil {...}
// println(msg.Content)
type Agent struct {
runnable compose.Runnable[[]*schema.Message, *schema.Message]
runnable compose.Runnable[[]*schema.Message, *schema.Message]
graph *compose.Graph[[]*schema.Message, *schema.Message]
graphAddNodeOpts []compose.GraphAddNodeOpt
}

// NewAgent creates a ReAct agent that feeds tool response into next round of Chat Model generation.
Expand Down Expand Up @@ -210,12 +212,17 @@ func NewAgent(ctx context.Context, config *AgentConfig) (_ *Agent, err error) {
return nil, err
}

runnable, err := graph.Compile(ctx, compose.WithMaxRunSteps(config.MaxStep))
compileOpts := []compose.GraphCompileOption{compose.WithMaxRunSteps(config.MaxStep), compose.WithNodeTriggerMode(compose.AnyPredecessor)}
runnable, err := graph.Compile(ctx, compileOpts...)
if err != nil {
return nil, err
}

return &Agent{runnable: runnable}, nil
return &Agent{
runnable: runnable,
graph: graph,
graphAddNodeOpts: []compose.GraphAddNodeOpt{compose.WithGraphCompileOptions(compileOpts...)},
}, nil
}

func buildReturnDirectly(graph *compose.Graph[[]*schema.Message, *schema.Message]) (err error) {
Expand Down Expand Up @@ -318,3 +325,8 @@ func (r *Agent) Stream(ctx context.Context, input []*schema.Message, opts ...age

return res, nil
}

// ExportGraph exports the underlying graph from Agent, along with the []compose.GraphAddNodeOpt to be used when adding this graph to another graph.
func (r *Agent) ExportGraph() (compose.AnyGraph, []compose.GraphAddNodeOpt) {
return r.graph, r.graphAddNodeOpts
}
Loading

0 comments on commit 4bfe584

Please sign in to comment.