diff --git a/src/Infrastructure/BotSharp.Abstraction/Models/MessageState.cs b/src/Infrastructure/BotSharp.Abstraction/Models/MessageState.cs index ff26e981d..a87098105 100644 --- a/src/Infrastructure/BotSharp.Abstraction/Models/MessageState.cs +++ b/src/Infrastructure/BotSharp.Abstraction/Models/MessageState.cs @@ -19,4 +19,9 @@ public MessageState(string key, object value, int activeRounds = -1) Value = value; ActiveRounds = activeRounds; } + + public override string ToString() + { + return $"Key: {Key} => Value: {Value}, ActiveRounds: {ActiveRounds}"; + } } diff --git a/src/Infrastructure/BotSharp.Abstraction/SideCar/Attributes/SideCarAttribute.cs b/src/Infrastructure/BotSharp.Abstraction/SideCar/Attributes/SideCarAttribute.cs index c10ffd8df..833cb013c 100644 --- a/src/Infrastructure/BotSharp.Abstraction/SideCar/Attributes/SideCarAttribute.cs +++ b/src/Infrastructure/BotSharp.Abstraction/SideCar/Attributes/SideCarAttribute.cs @@ -54,16 +54,23 @@ private static MethodInfo GetMethod(string name) private (IConversationSideCar?, MethodInfo?) GetSideCarMethod(IServiceProvider serviceProvider, string methodName, Type retType, object[] args) { - var sidecar = serviceProvider.GetService(); - var argTypes = args.Select(x => x.GetType()).ToArray(); - var sidecarMethod = sidecar?.GetType()?.GetMethods(BindingFlags.Public | BindingFlags.Instance) - .FirstOrDefault(x => x.Name == methodName - && x.ReturnType == retType - && x.GetParameters().Length == argTypes.Length - && x.GetParameters().Select(p => p.ParameterType) - .Zip(argTypes, (paramType, argType) => paramType.IsAssignableFrom(argType)).All(y => y)); - - return (sidecar, sidecarMethod); + try + { + var sidecar = serviceProvider.GetService(); + var argTypes = args.Select(x => x.GetType()).ToArray(); + var sidecarMethod = sidecar?.GetType()?.GetMethods(BindingFlags.Public | BindingFlags.Instance) + .FirstOrDefault(x => x.Name == methodName + && x.ReturnType == retType + && x.GetParameters().Length == argTypes.Length + && x.GetParameters().Select(p => p.ParameterType) + .Zip(argTypes, (paramType, argType) => paramType.IsAssignableFrom(argType)).All(y => y)); + + return (sidecar, sidecarMethod); + } + catch + { + return (null, null); + } } private async Task<(bool, object?)> CallAsyncMethod(IConversationSideCar instance, MethodInfo method, Type retType, object[] args) diff --git a/src/Infrastructure/BotSharp.Abstraction/SideCar/IConversationSideCar.cs b/src/Infrastructure/BotSharp.Abstraction/SideCar/IConversationSideCar.cs index 84b655738..28162a7e2 100644 --- a/src/Infrastructure/BotSharp.Abstraction/SideCar/IConversationSideCar.cs +++ b/src/Infrastructure/BotSharp.Abstraction/SideCar/IConversationSideCar.cs @@ -9,6 +9,7 @@ public interface IConversationSideCar List GetConversationDialogs(string conversationId); void UpdateConversationBreakpoint(string conversationId, ConversationBreakpoint breakpoint); ConversationBreakpoint? GetConversationBreakpoint(string conversationId); + void UpdateConversationStates(string conversationId, List states); Task SendMessage(string agentId, string text, PostbackMessageModel? postback = null, List? states = null, List? dialogs = null); } diff --git a/src/Infrastructure/BotSharp.Core.SideCar/Services/BotSharpConversationSideCar.cs b/src/Infrastructure/BotSharp.Core.SideCar/Services/BotSharpConversationSideCar.cs index 34a980170..a948e5e28 100644 --- a/src/Infrastructure/BotSharp.Core.SideCar/Services/BotSharpConversationSideCar.cs +++ b/src/Infrastructure/BotSharp.Core.SideCar/Services/BotSharpConversationSideCar.cs @@ -23,9 +23,10 @@ public class BotSharpConversationSideCar : IConversationSideCar private readonly IServiceProvider _services; private readonly ILogger _logger; - private Stack contextStack = new(); + private Stack _contextStack = new(); - private bool enabled = false; + private bool _enabled = false; + private string _conversationId = string.Empty; public string Provider => "botsharp"; @@ -39,46 +40,63 @@ public BotSharpConversationSideCar( public bool IsEnabled() { - return enabled; + return _enabled; } public void AppendConversationDialogs(string conversationId, List messages) { - if (contextStack.IsNullOrEmpty()) return; + if (!IsValid(conversationId)) + { + return; + } - var top = contextStack.Peek(); + var top = _contextStack.Peek(); top.Dialogs.AddRange(messages); } public List GetConversationDialogs(string conversationId) { - if (contextStack.IsNullOrEmpty()) + if (!IsValid(conversationId)) { return new List(); } - return contextStack.Peek().Dialogs; + return _contextStack.Peek().Dialogs; } public void UpdateConversationBreakpoint(string conversationId, ConversationBreakpoint breakpoint) { - if (contextStack.IsNullOrEmpty()) return; + if (!IsValid(conversationId)) + { + return; + } - var top = contextStack.Peek().Breakpoints; + var top = _contextStack.Peek().Breakpoints; top.Add(breakpoint); } public ConversationBreakpoint? GetConversationBreakpoint(string conversationId) { - if (contextStack.IsNullOrEmpty()) + if (!IsValid(conversationId)) { return null; } - var top = contextStack.Peek().Breakpoints; + var top = _contextStack.Peek().Breakpoints; return top.LastOrDefault(); } + public void UpdateConversationStates(string conversationId, List states) + { + if (!IsValid(conversationId)) + { + return; + } + + var top = _contextStack.Peek(); + top.State = new ConversationState(states); + } + public async Task SendMessage(string agentId, string text, PostbackMessageModel? postback = null, List? states = null, List? dialogs = null) { @@ -94,6 +112,7 @@ private async Task InnerExecute(string agentId, string text, var conv = _services.GetRequiredService(); var routing = _services.GetRequiredService(); var state = _services.GetRequiredService(); + _conversationId = conv.ConversationId; var inputMsg = new RoleDialogModel(AgentRole.User, text); routing.Context.SetMessageId(conv.ConversationId, inputMsg.MessageId); @@ -116,7 +135,7 @@ await conv.SendMessage(agentId, inputMsg, private void BeforeExecute(List? dialogs) { - enabled = true; + _enabled = true; var state = _services.GetRequiredService(); var routing = _services.GetRequiredService(); @@ -129,7 +148,7 @@ private void BeforeExecute(List? dialogs) RecursiveCounter = routing.Context.GetRecursiveCounter(), RoutingStack = routing.Context.GetAgentStack() }; - contextStack.Push(node); + _contextStack.Push(node); // Reset state.ResetCurrentState(); @@ -144,7 +163,7 @@ private void AfterExecute() var state = _services.GetRequiredService(); var routing = _services.GetRequiredService(); - var node = contextStack.Pop(); + var node = _contextStack.Pop(); // Recover state.SetCurrentState(node.State); @@ -152,6 +171,14 @@ private void AfterExecute() routing.Context.SetAgentStack(node.RoutingStack); routing.Context.SetDialogs(node.RoutingDialogs); Utilities.ClearCache(); - enabled = false; + _enabled = false; + } + + private bool IsValid(string conversationId) + { + return !_contextStack.IsNullOrEmpty() + && _conversationId == conversationId + && !string.IsNullOrEmpty(conversationId) + && !string.IsNullOrEmpty(_conversationId); } } \ No newline at end of file diff --git a/src/Infrastructure/BotSharp.Core/Repository/FileRepository/FileRepository.Conversation.cs b/src/Infrastructure/BotSharp.Core/Repository/FileRepository/FileRepository.Conversation.cs index 3b60a61a6..3d03e816a 100644 --- a/src/Infrastructure/BotSharp.Core/Repository/FileRepository/FileRepository.Conversation.cs +++ b/src/Infrastructure/BotSharp.Core/Repository/FileRepository/FileRepository.Conversation.cs @@ -307,6 +307,7 @@ public ConversationState GetConversationStates(string conversationId) return new ConversationState(states); } + [SideCar] public void UpdateConversationStates(string conversationId, List states) { if (states.IsNullOrEmpty()) return; diff --git a/src/Plugins/BotSharp.Plugin.MongoStorage/Repository/MongoRepository.Conversation.cs b/src/Plugins/BotSharp.Plugin.MongoStorage/Repository/MongoRepository.Conversation.cs index 02f5a6951..edcf3a5f2 100644 --- a/src/Plugins/BotSharp.Plugin.MongoStorage/Repository/MongoRepository.Conversation.cs +++ b/src/Plugins/BotSharp.Plugin.MongoStorage/Repository/MongoRepository.Conversation.cs @@ -266,6 +266,7 @@ public ConversationState GetConversationStates(string conversationId) return new ConversationState(savedStates); } + [SideCar] public void UpdateConversationStates(string conversationId, List states) { if (string.IsNullOrEmpty(conversationId) || states == null) return; diff --git a/src/Plugins/BotSharp.Plugin.Twilio/OutboundPhoneCallHandler/Functions/OutboundPhoneCallFn.cs b/src/Plugins/BotSharp.Plugin.Twilio/OutboundPhoneCallHandler/Functions/OutboundPhoneCallFn.cs index 5ffd692e9..8d34dd22b 100644 --- a/src/Plugins/BotSharp.Plugin.Twilio/OutboundPhoneCallHandler/Functions/OutboundPhoneCallFn.cs +++ b/src/Plugins/BotSharp.Plugin.Twilio/OutboundPhoneCallHandler/Functions/OutboundPhoneCallFn.cs @@ -1,7 +1,9 @@ using BotSharp.Abstraction.Files; using BotSharp.Abstraction.Infrastructures.Enums; using BotSharp.Abstraction.Options; +using BotSharp.Abstraction.Repositories; using BotSharp.Abstraction.Routing; +using BotSharp.Abstraction.Utilities; using BotSharp.Core.Infrastructures; using BotSharp.Plugin.Twilio.Interfaces; using BotSharp.Plugin.Twilio.Models; @@ -134,9 +136,10 @@ public async Task Execute(RoleDialogModel message) } } - private async Task ForkConversation(LlmContextIn args, - string entryAgentId, - string originConversationId, + private async Task ForkConversation( + LlmContextIn args, + string entryAgentId, + string originConversationId, string newConversationId, CallResource call) { @@ -145,6 +148,8 @@ private async Task ForkConversation(LlmContextIn args, var services = scope.ServiceProvider; var convService = services.GetRequiredService(); var convStorage = services.GetRequiredService(); + var state = _services.GetRequiredService(); + var db = _services.GetRequiredService(); var newConv = await convService.NewConversation(new Conversation { @@ -170,15 +175,45 @@ private async Task ForkConversation(LlmContextIn args, } }); - convService.SetConversationId(newConversationId, - [ - new MessageState(StateConst.ORIGIN_CONVERSATION_ID, originConversationId), - new MessageState("channel", "phone"), - new MessageState("phone_from", call.From), - new MessageState("phone_direction", call.Direction), - new MessageState("phone_number", call.To), - new MessageState("twilio_call_sid", call.Sid) - ]); - convService.SaveStates(); + var utcNow = DateTime.UtcNow; + var excludStates = new List + { + "provider", + "model", + "prompt_total", + "completion_total", + "llm_total_cost" + }; + + var curStates = state.GetStates().Select(x => new MessageState(x.Key, x.Value)).ToList(); + var subConvStates = new List + { + new(StateConst.ORIGIN_CONVERSATION_ID, originConversationId), + new("channel", "phone"), + new("phone_from", call.From), + new("phone_direction", call.Direction), + new("phone_number", call.To), + new("twilio_call_sid", call.Sid) + }; + var subStateKeys = subConvStates.Select(x => x.Key).ToList(); + var included = curStates.Where(x => !subStateKeys.Contains(x.Key) && !excludStates.Contains(x.Key)); + var newStates = subConvStates.Concat(included).Select(x => new StateKeyValue + { + Key = x.Key, + Versioning = true, + Values = [ + new StateValue + { + Data = x.Value.ConvertToString(_options.JsonSerializerOptions), + MessageId = messageId, + Active = true, + ActiveRounds = x.ActiveRounds, + Source = StateSource.Application, + UpdateTime = utcNow + } + ] + }).ToList(); + + db.UpdateConversationStates(newConversationId, newStates); } }