From 89195f855400b51ab33e35ac619af4e47072f2a1 Mon Sep 17 00:00:00 2001 From: Pavlo Paliychuk Date: Fri, 14 Jun 2024 23:19:11 -0400 Subject: [PATCH] memory: Add Zep Cloud integration (#746) * memory: Add Zep Cloud integration --- chains/conversation_test.go | 36 +++++ examples/zep-memory-chain-example/main.go | 48 +++++++ go.mod | 3 +- go.sum | 6 + memory/buffer.go | 6 +- memory/zep/zep_chat_history.go | 164 ++++++++++++++++++++++ memory/zep/zep_chat_history_options.go | 40 ++++++ memory/zep/zep_memory.go | 119 ++++++++++++++++ memory/zep/zep_memory_options.go | 74 ++++++++++ 9 files changed, 492 insertions(+), 4 deletions(-) create mode 100644 examples/zep-memory-chain-example/main.go create mode 100644 memory/zep/zep_chat_history.go create mode 100644 memory/zep/zep_chat_history_options.go create mode 100644 memory/zep/zep_memory.go create mode 100644 memory/zep/zep_memory_options.go diff --git a/chains/conversation_test.go b/chains/conversation_test.go index 0f3056f65..b1d86c9d5 100644 --- a/chains/conversation_test.go +++ b/chains/conversation_test.go @@ -6,9 +6,13 @@ import ( "strings" "testing" + z "github.com/getzep/zep-go" + zClient "github.com/getzep/zep-go/client" + zOption "github.com/getzep/zep-go/option" "github.com/stretchr/testify/require" "github.com/tmc/langchaingo/llms/openai" "github.com/tmc/langchaingo/memory" + "github.com/tmc/langchaingo/memory/zep" ) func TestConversation(t *testing.T) { @@ -29,6 +33,38 @@ func TestConversation(t *testing.T) { require.True(t, strings.Contains(res, "Jim"), `result does not contain the keyword 'Jim'`) } +func TestConversationWithZepMemory(t *testing.T) { + t.Parallel() + + if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey == "" { + t.Skip("OPENAI_API_KEY not set") + } + llm, err := openai.New() + require.NoError(t, err) + + zc := zClient.NewClient( + zOption.WithAPIKey(os.Getenv("ZEP_API_KEY")), + ) + sessionID := os.Getenv("ZEP_SESSION_ID") + + c := NewConversation( + llm, + zep.NewMemory( + zc, + sessionID, + zep.WithMemoryType(z.MemoryGetRequestMemoryTypePerpetual), + zep.WithHumanPrefix("Joe"), + zep.WithAIPrefix("Robot"), + ), + ) + _, err = Run(context.Background(), c, "Hi! I'm Jim") + require.NoError(t, err) + + res, err := Run(context.Background(), c, "What is my name?") + require.NoError(t, err) + require.True(t, strings.Contains(res, "Jim"), `result does not contain the keyword 'Jim'`) +} + func TestConversationWithChatLLM(t *testing.T) { t.Parallel() diff --git a/examples/zep-memory-chain-example/main.go b/examples/zep-memory-chain-example/main.go new file mode 100644 index 000000000..6d9b1a916 --- /dev/null +++ b/examples/zep-memory-chain-example/main.go @@ -0,0 +1,48 @@ +package main + +import ( + "context" + "fmt" + "github.com/getzep/zep-go" + zepClient "github.com/getzep/zep-go/client" + zepOption "github.com/getzep/zep-go/option" + "github.com/tmc/langchaingo/chains" + "github.com/tmc/langchaingo/llms/openai" + zepLangchainMemory "github.com/tmc/langchaingo/memory/zep" + "os" +) + +func main() { + ctx := context.Background() + + client := zepClient.NewClient(zepOption.WithAPIKey(os.Getenv("ZEP_API_KEY"))) + sessionID := os.Getenv("ZEP_SESSION_ID") + + llm, err := openai.New() + if err != nil { + fmt.Printf("Error: %s\n", err) + } + + c := chains.NewConversation( + llm, + zepLangchainMemory.NewMemory( + client, + sessionID, + zepLangchainMemory.WithMemoryType(zep.MemoryGetRequestMemoryTypePerpetual), + zepLangchainMemory.WithReturnMessages(true), + zepLangchainMemory.WithAIPrefix("Robot"), + zepLangchainMemory.WithHumanPrefix("Joe"), + ), + ) + res, err := chains.Run(ctx, c, "Hi! I'm John Doe") + if err != nil { + fmt.Printf("Error: %s\n", err) + } + fmt.Printf("Response: %s\n", res) + + res, err = chains.Run(ctx, c, "What is my name?") + if err != nil { + fmt.Printf("Error: %s\n", err) + } + fmt.Printf("Response: %s\n", res) +} diff --git a/go.mod b/go.mod index 71c18c892..9ea5eb97a 100644 --- a/go.mod +++ b/go.mod @@ -135,7 +135,6 @@ require ( github.com/pmezard/go-difflib v1.0.0 // indirect github.com/power-devops/perfstat v0.0.0-20210106213030-5aafc221ea8c // indirect github.com/rogpeppe/go-internal v1.11.0 // indirect - github.com/rs/zerolog v1.31.0 // indirect github.com/saintfish/chardet v0.0.0-20230101081208-5e3ef4b5456d // indirect github.com/shirou/gopsutil/v3 v3.23.12 // indirect github.com/shoenig/go-m1cpu v0.1.6 // indirect @@ -195,6 +194,7 @@ require ( github.com/cohere-ai/tokenizer v1.1.2 github.com/fatih/color v1.17.0 github.com/gage-technologies/mistral-go v1.0.0 + github.com/getzep/zep-go v1.0.4 github.com/go-openapi/strfmt v0.21.3 github.com/go-sql-driver/mysql v1.7.1 github.com/gocolly/colly v1.2.0 @@ -214,6 +214,7 @@ require ( github.com/pinecone-io/go-pinecone v0.4.1 github.com/pkoukk/tiktoken-go v0.1.6 github.com/redis/rueidis v1.0.34 + github.com/rs/zerolog v1.31.0 github.com/weaviate/weaviate v1.24.1 github.com/weaviate/weaviate-go-client/v4 v4.13.1 gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a diff --git a/go.sum b/go.sum index 619677da9..b11adb140 100644 --- a/go.sum +++ b/go.sum @@ -200,6 +200,11 @@ github.com/gavv/httpexpect v2.0.0+incompatible/go.mod h1:x+9tiU1YnrOvnB725RkpoLv github.com/getsentry/raven-go v0.2.0/go.mod h1:KungGk8q33+aIAZUIVWZDr2OfAEBsO49PX4NzFV5kcQ= github.com/getsentry/sentry-go v0.12.0 h1:era7g0re5iY13bHSdN/xMkyV+5zZppjRVQhZrXCaEIk= github.com/getsentry/sentry-go v0.12.0/go.mod h1:NSap0JBYWzHND8oMbyi0+XZhUalc1TBdRL1M71JZW2c= +github.com/getzep/zep-go v0.0.7 h1:7ZFTtXr6ZHXnFe/z3iK0RyhCVie6wPl2swC7alGlENs= +github.com/getzep/zep-go v0.0.7/go.mod h1:HC1Gz7oiyrzOTvzeKC4dQKUiUy87zpIJl0ZFXXdHuss= +github.com/getzep/zep-go v1.0.4 h1:09o26bPP2RAPKFjWuVWwUWLbtFDF/S8bfbilxzeZAAg= +github.com/getzep/zep-go v1.0.4/go.mod h1:HC1Gz7oiyrzOTvzeKC4dQKUiUy87zpIJl0ZFXXdHuss= +github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s= github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= @@ -367,6 +372,7 @@ github.com/google/s2a-go v0.1.7 h1:60BLSyTrOV4/haCDW4zb1guZItoSq8foHCXrAnjBo/o= github.com/google/s2a-go v0.1.7/go.mod h1:50CgR4k1jNlWBu4UfS4AcfhVe1r6pdZPygJ3R8F0Qdw= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/enterprise-certificate-proxy v0.3.2 h1:Vie5ybvEvT75RniqhfFxPRy3Bf7vr3h0cechB90XaQs= diff --git a/memory/buffer.go b/memory/buffer.go index 4b4c14206..5d83a8d2f 100644 --- a/memory/buffer.go +++ b/memory/buffer.go @@ -77,7 +77,7 @@ func (m *ConversationBuffer) SaveContext( inputValues map[string]any, outputValues map[string]any, ) error { - userInputValue, err := getInputValue(inputValues, m.InputKey) + userInputValue, err := GetInputValue(inputValues, m.InputKey) if err != nil { return err } @@ -86,7 +86,7 @@ func (m *ConversationBuffer) SaveContext( return err } - aiOutputValue, err := getInputValue(outputValues, m.OutputKey) + aiOutputValue, err := GetInputValue(outputValues, m.OutputKey) if err != nil { return err } @@ -107,7 +107,7 @@ func (m *ConversationBuffer) GetMemoryKey(context.Context) string { return m.MemoryKey } -func getInputValue(inputValues map[string]any, inputKey string) (string, error) { +func GetInputValue(inputValues map[string]any, inputKey string) (string, error) { // If the input key is set, return the value in the inputValues with the input key. if inputKey != "" { inputValue, ok := inputValues[inputKey] diff --git a/memory/zep/zep_chat_history.go b/memory/zep/zep_chat_history.go new file mode 100644 index 000000000..864dc7151 --- /dev/null +++ b/memory/zep/zep_chat_history.go @@ -0,0 +1,164 @@ +package zep + +import ( + "context" + "fmt" + + "github.com/getzep/zep-go" + zepClient "github.com/getzep/zep-go/client" + "github.com/rs/zerolog/log" + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/schema" +) + +// ChatMessageHistory is a struct that stores chat messages. +type ChatMessageHistory struct { + ZepClient *zepClient.Client + SessionID string + MemoryType zep.MemoryGetRequestMemoryType + HumanPrefix string + AIPrefix string +} + +// Statically assert that ZepChatMessageHistory implement the chat message history interface. +var _ schema.ChatMessageHistory = &ChatMessageHistory{} + +// NewZepChatMessageHistory creates a new ZepChatMessageHistory using chat message options. +func NewZepChatMessageHistory(zep *zepClient.Client, sessionID string, options ...ChatMessageHistoryOption) *ChatMessageHistory { + messageHistory := applyZepChatHistoryOptions(options...) + messageHistory.ZepClient = zep + messageHistory.SessionID = sessionID + return messageHistory +} + +func (h *ChatMessageHistory) messagesFromZepMessages(zepMessages []*zep.Message) []llms.ChatMessage { + var chatMessages []llms.ChatMessage + for _, zepMessage := range zepMessages { + switch *zepMessage.RoleType { // nolint We do not store other message types in zep memory + case zep.RoleTypeUserRole: + chatMessages = append(chatMessages, llms.HumanChatMessage{Content: *zepMessage.Content}) + case zep.RoleTypeAssistantRole: + chatMessages = append(chatMessages, llms.AIChatMessage{Content: *zepMessage.Content}) + case zep.RoleTypeToolRole: + case zep.RoleTypeFunctionRole: + chatMessages = append(chatMessages, llms.ToolChatMessage{Content: *zepMessage.Content}) + default: + log.Print(fmt.Errorf("unknown role: %s", *zepMessage.RoleType)) + continue + } + } + return chatMessages +} + +func (h *ChatMessageHistory) messagesToZepMessages(messages []llms.ChatMessage) []*zep.Message { + var zepMessages []*zep.Message //nolint We don't know the final size of the messages as some might be skipped due to unsupported role. + for _, m := range messages { + zepMessage := zep.Message{ + Content: zep.String(m.GetContent()), + } + switch m.GetType() { // nolint We only expect to bring these three types into chat history + case llms.ChatMessageTypeHuman: + zepMessage.RoleType = zep.RoleTypeUserRole.Ptr() + if h.HumanPrefix != "" { + zepMessage.Role = zep.String(h.HumanPrefix) + } + case llms.ChatMessageTypeAI: + zepMessage.RoleType = zep.RoleTypeAssistantRole.Ptr() + if h.AIPrefix != "" { + zepMessage.Role = zep.String(h.AIPrefix) + } + case llms.ChatMessageTypeFunction: + zepMessage.RoleType = zep.RoleTypeFunctionRole.Ptr() + case llms.ChatMessageTypeTool: + zepMessage.RoleType = zep.RoleTypeToolRole.Ptr() + default: + log.Print(fmt.Errorf("unknown role: %s", *zepMessage.RoleType)) + continue + } + zepMessages = append(zepMessages, &zepMessage) + } + return zepMessages +} + +// Messages returns all messages stored. +func (h *ChatMessageHistory) Messages(ctx context.Context) ([]llms.ChatMessage, error) { + memory, err := h.ZepClient.Memory.Get(ctx, h.SessionID, &zep.MemoryGetRequest{ + MemoryType: h.MemoryType.Ptr(), + }) + if err != nil { + return nil, err + } + messages := h.messagesFromZepMessages(memory.Messages) + zepFacts := memory.Facts + systemPromptContent := "" + for _, fact := range zepFacts { + systemPromptContent += fmt.Sprintf("%s\n", fact) + } + if memory.Summary != nil && memory.Summary.Content != nil { + systemPromptContent += fmt.Sprintf("%s\n", *memory.Summary.Content) + } + if systemPromptContent != "" { + // Add system prompt to the beginning of the messages. + messages = append( + []llms.ChatMessage{ + llms.SystemChatMessage{ + Content: systemPromptContent, + }, + }, + messages..., + ) + } + return messages, nil +} + +// AddAIMessage adds an AIMessage to the chat message history. +func (h *ChatMessageHistory) AddAIMessage(ctx context.Context, text string) error { + _, err := h.ZepClient.Memory.Add(ctx, h.SessionID, &zep.AddMemoryRequest{ + Messages: h.messagesToZepMessages( + []llms.ChatMessage{ + llms.AIChatMessage{Content: text}, + }, + ), + }) + if err != nil { + return err + } + return nil +} + +// AddUserMessage adds a user to the chat message history. +func (h *ChatMessageHistory) AddUserMessage(ctx context.Context, text string) error { + _, err := h.ZepClient.Memory.Add(ctx, h.SessionID, &zep.AddMemoryRequest{ + Messages: h.messagesToZepMessages( + []llms.ChatMessage{ + llms.HumanChatMessage{Content: text}, + }, + ), + }) + if err != nil { + return err + } + return nil +} + +func (h *ChatMessageHistory) Clear(ctx context.Context) error { + _, err := h.ZepClient.Memory.Delete(ctx, h.SessionID) + if err != nil { + return err + } + return nil +} + +func (h *ChatMessageHistory) AddMessage(ctx context.Context, message llms.ChatMessage) error { + _, err := h.ZepClient.Memory.Add(ctx, h.SessionID, &zep.AddMemoryRequest{ + Messages: h.messagesToZepMessages([]llms.ChatMessage{message}), + }) + if err != nil { + return err + } + return nil +} + +func (*ChatMessageHistory) SetMessages(_ context.Context, _ []llms.ChatMessage) error { + return nil +} diff --git a/memory/zep/zep_chat_history_options.go b/memory/zep/zep_chat_history_options.go new file mode 100644 index 000000000..27fa7aa23 --- /dev/null +++ b/memory/zep/zep_chat_history_options.go @@ -0,0 +1,40 @@ +package zep + +import "github.com/getzep/zep-go" + +// ChatMessageHistoryOption is a function for creating new chat message history +// with other than the default values. +type ChatMessageHistoryOption func(m *ChatMessageHistory) + +// WithChatHistoryMemoryType specifies zep memory type. +func WithChatHistoryMemoryType(memoryType zep.MemoryGetRequestMemoryType) ChatMessageHistoryOption { + return func(b *ChatMessageHistory) { + b.MemoryType = memoryType + } +} + +// WithChatHistoryHumanPrefix is an option for specifying the human prefix. Will be passed as role for the message to zep. +func WithChatHistoryHumanPrefix(humanPrefix string) ChatMessageHistoryOption { + return func(b *ChatMessageHistory) { + b.HumanPrefix = humanPrefix + } +} + +// WithChatHistoryAIPrefix is an option for specifying the AI prefix. Will be passed as role for the message to zep. +func WithChatHistoryAIPrefix(aiPrefix string) ChatMessageHistoryOption { + return func(b *ChatMessageHistory) { + b.AIPrefix = aiPrefix + } +} + +func applyZepChatHistoryOptions(options ...ChatMessageHistoryOption) *ChatMessageHistory { + h := &ChatMessageHistory{ + MemoryType: zep.MemoryGetRequestMemoryTypePerpetual, + } + + for _, option := range options { + option(h) + } + + return h +} diff --git a/memory/zep/zep_memory.go b/memory/zep/zep_memory.go new file mode 100644 index 000000000..0ff76fd9d --- /dev/null +++ b/memory/zep/zep_memory.go @@ -0,0 +1,119 @@ +package zep + +import ( + "context" + + "github.com/getzep/zep-go" + zepClient "github.com/getzep/zep-go/client" + "github.com/tmc/langchaingo/llms" + "github.com/tmc/langchaingo/memory" + "github.com/tmc/langchaingo/schema" +) + +// Memory is a simple form of memory that remembers previous conversational back and forth directly. +type Memory struct { + ChatHistory schema.ChatMessageHistory + ReturnMessages bool + InputKey string + OutputKey string + HumanPrefix string + AIPrefix string + MemoryKey string + MemoryType zep.MemoryGetRequestMemoryType + ZepClient *zepClient.Client + SessionID string +} + +// Statically assert that ZepMemory implement the memory interface. +var _ schema.Memory = &Memory{} + +// NewMemory is a function for crating a new buffer memory. +func NewMemory(client *zepClient.Client, sessionID string, options ...MemoryOption) *Memory { + m := applyZepMemoryOptions(options...) + m.ZepClient = client + m.SessionID = sessionID + m.ChatHistory = NewZepChatMessageHistory( + m.ZepClient, + m.SessionID, + WithChatHistoryMemoryType(m.MemoryType), + WithChatHistoryHumanPrefix(m.HumanPrefix), + WithChatHistoryAIPrefix(m.AIPrefix), + ) + return m +} + +// MemoryVariables gets the input key the buffer memory class will load dynamically. +func (m *Memory) MemoryVariables(context.Context) []string { + return []string{m.MemoryKey} +} + +// LoadMemoryVariables returns the previous chat messages stored in memory +// as well as a system message with conversation facts and most relevant summary. +// Previous chat messages are returned in a map with the key specified in the MemoryKey field. This key defaults to +// "history". If ReturnMessages is set to true the output is a slice of schema.ChatMessage. Otherwise, +// the output is a buffer string of the chat messages. +func (m *Memory) LoadMemoryVariables( + ctx context.Context, _ map[string]any, +) (map[string]any, error) { + messages, err := m.ChatHistory.Messages(ctx) + if err != nil { + return nil, err + } + + if m.ReturnMessages { + return map[string]any{ + m.MemoryKey: messages, + }, nil + } + + bufferString, err := llms.GetBufferString(messages, m.HumanPrefix, m.AIPrefix) + if err != nil { + return nil, err + } + + return map[string]any{ + m.MemoryKey: bufferString, + }, nil +} + +// SaveContext uses the input values to the llm to save a user message, and the output values +// of the llm to save an AI message. If the input or output key is not set, the input values or +// output values must contain only one key such that the function can know what string to +// add as a user and AI message. On the other hand, if the output key or input key is set, the +// input key must be a key in the input values and the output key must be a key in the output +// values. The values in the input and output values used to save a user and AI message must +// be strings. +func (m *Memory) SaveContext( + ctx context.Context, + inputValues map[string]any, + outputValues map[string]any, +) error { + userInputValue, err := memory.GetInputValue(inputValues, m.InputKey) + if err != nil { + return err + } + err = m.ChatHistory.AddUserMessage(ctx, userInputValue) + if err != nil { + return err + } + + aiOutputValue, err := memory.GetInputValue(outputValues, m.OutputKey) + if err != nil { + return err + } + err = m.ChatHistory.AddAIMessage(ctx, aiOutputValue) + if err != nil { + return err + } + + return nil +} + +// Clear sets the chat messages to a new and empty chat message history. +func (m *Memory) Clear(ctx context.Context) error { + return m.ChatHistory.Clear(ctx) +} + +func (m *Memory) GetMemoryKey(context.Context) string { + return m.MemoryKey +} diff --git a/memory/zep/zep_memory_options.go b/memory/zep/zep_memory_options.go new file mode 100644 index 000000000..e39464f61 --- /dev/null +++ b/memory/zep/zep_memory_options.go @@ -0,0 +1,74 @@ +package zep + +import "github.com/getzep/zep-go" + +// MemoryOption ZepMemoryOption is a function for creating new buffer +// with other than the default values. +type MemoryOption func(b *Memory) + +// WithReturnMessages is an option for specifying should it return messages. +func WithReturnMessages(returnMessages bool) MemoryOption { + return func(b *Memory) { + b.ReturnMessages = returnMessages + } +} + +// WithInputKey is an option for specifying the input key. +func WithInputKey(inputKey string) MemoryOption { + return func(b *Memory) { + b.InputKey = inputKey + } +} + +// WithOutputKey is an option for specifying the output key. +func WithOutputKey(outputKey string) MemoryOption { + return func(b *Memory) { + b.OutputKey = outputKey + } +} + +// WithHumanPrefix is an option for specifying the human prefix. Will be passed as role for the message to zep. +func WithHumanPrefix(humanPrefix string) MemoryOption { + return func(b *Memory) { + b.HumanPrefix = humanPrefix + } +} + +// WithAIPrefix is an option for specifying the AI prefix. Will be passed as role for the message to zep. +func WithAIPrefix(aiPrefix string) MemoryOption { + return func(b *Memory) { + b.AIPrefix = aiPrefix + } +} + +// WithMemoryKey is an option for specifying the memory key. +func WithMemoryKey(memoryKey string) MemoryOption { + return func(b *Memory) { + b.MemoryKey = memoryKey + } +} + +// WithMemoryType specifies zep memory type. +func WithMemoryType(memoryType zep.MemoryGetRequestMemoryType) MemoryOption { + return func(b *Memory) { + b.MemoryType = memoryType + } +} + +func applyZepMemoryOptions(opts ...MemoryOption) *Memory { + m := &Memory{ + ReturnMessages: true, + InputKey: "", + OutputKey: "", + HumanPrefix: "Human", + AIPrefix: "AI", + MemoryKey: "history", + MemoryType: zep.MemoryGetRequestMemoryTypePerpetual, + } + + for _, opt := range opts { + opt(m) + } + + return m +}