Skip to content

fix fork conversation #1082

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,9 @@ public MessageState(string key, object value, int activeRounds = -1)
Value = value;
ActiveRounds = activeRounds;
}

public override string ToString()
{
return $"Key: {Key} => Value: {Value}, ActiveRounds: {ActiveRounds}";
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,23 @@ private static MethodInfo GetMethod(string name)

private (IConversationSideCar?, MethodInfo?) GetSideCarMethod(IServiceProvider serviceProvider, string methodName, Type retType, object[] args)
{
var sidecar = serviceProvider.GetService<IConversationSideCar>();
var argTypes = args.Select(x => x.GetType()).ToArray();
var sidecarMethod = sidecar?.GetType()?.GetMethods(BindingFlags.Public | BindingFlags.Instance)
.FirstOrDefault(x => x.Name == methodName
&& x.ReturnType == retType
&& x.GetParameters().Length == argTypes.Length
&& x.GetParameters().Select(p => p.ParameterType)
.Zip(argTypes, (paramType, argType) => paramType.IsAssignableFrom(argType)).All(y => y));

return (sidecar, sidecarMethod);
try
{
var sidecar = serviceProvider.GetService<IConversationSideCar>();
var argTypes = args.Select(x => x.GetType()).ToArray();
var sidecarMethod = sidecar?.GetType()?.GetMethods(BindingFlags.Public | BindingFlags.Instance)
.FirstOrDefault(x => x.Name == methodName
&& x.ReturnType == retType
&& x.GetParameters().Length == argTypes.Length
&& x.GetParameters().Select(p => p.ParameterType)
.Zip(argTypes, (paramType, argType) => paramType.IsAssignableFrom(argType)).All(y => y));

return (sidecar, sidecarMethod);
}
catch
{
return (null, null);
}
}

private async Task<(bool, object?)> CallAsyncMethod(IConversationSideCar instance, MethodInfo method, Type retType, object[] args)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public interface IConversationSideCar
List<DialogElement> GetConversationDialogs(string conversationId);
void UpdateConversationBreakpoint(string conversationId, ConversationBreakpoint breakpoint);
ConversationBreakpoint? GetConversationBreakpoint(string conversationId);
void UpdateConversationStates(string conversationId, List<StateKeyValue> states);
Task<RoleDialogModel> SendMessage(string agentId, string text,
PostbackMessageModel? postback = null, List<MessageState>? states = null, List<DialogElement>? dialogs = null);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@ public class BotSharpConversationSideCar : IConversationSideCar
private readonly IServiceProvider _services;
private readonly ILogger<BotSharpConversationSideCar> _logger;

private Stack<ConversationContext> contextStack = new();
private Stack<ConversationContext> _contextStack = new();

private bool enabled = false;
private bool _enabled = false;
private string _conversationId = string.Empty;

public string Provider => "botsharp";

Expand All @@ -39,46 +40,63 @@ public BotSharpConversationSideCar(

public bool IsEnabled()
{
return enabled;
return _enabled;
}

public void AppendConversationDialogs(string conversationId, List<DialogElement> messages)
{
if (contextStack.IsNullOrEmpty()) return;
if (!IsValid(conversationId))
{
return;
}

var top = contextStack.Peek();
var top = _contextStack.Peek();
top.Dialogs.AddRange(messages);
}

public List<DialogElement> GetConversationDialogs(string conversationId)
{
if (contextStack.IsNullOrEmpty())
if (!IsValid(conversationId))
{
return new List<DialogElement>();
}

return contextStack.Peek().Dialogs;
return _contextStack.Peek().Dialogs;
}

public void UpdateConversationBreakpoint(string conversationId, ConversationBreakpoint breakpoint)
{
if (contextStack.IsNullOrEmpty()) return;
if (!IsValid(conversationId))
{
return;
}

var top = contextStack.Peek().Breakpoints;
var top = _contextStack.Peek().Breakpoints;
top.Add(breakpoint);
}

public ConversationBreakpoint? GetConversationBreakpoint(string conversationId)
{
if (contextStack.IsNullOrEmpty())
if (!IsValid(conversationId))
{
return null;
}

var top = contextStack.Peek().Breakpoints;
var top = _contextStack.Peek().Breakpoints;
return top.LastOrDefault();
}

public void UpdateConversationStates(string conversationId, List<StateKeyValue> states)
{
if (!IsValid(conversationId))
{
return;
}

var top = _contextStack.Peek();
top.State = new ConversationState(states);
}

public async Task<RoleDialogModel> SendMessage(string agentId, string text,
PostbackMessageModel? postback = null, List<MessageState>? states = null, List<DialogElement>? dialogs = null)
{
Expand All @@ -94,6 +112,7 @@ private async Task<RoleDialogModel> InnerExecute(string agentId, string text,
var conv = _services.GetRequiredService<IConversationService>();
var routing = _services.GetRequiredService<IRoutingService>();
var state = _services.GetRequiredService<IConversationStateService>();
_conversationId = conv.ConversationId;

var inputMsg = new RoleDialogModel(AgentRole.User, text);
routing.Context.SetMessageId(conv.ConversationId, inputMsg.MessageId);
Expand All @@ -116,7 +135,7 @@ await conv.SendMessage(agentId, inputMsg,

private void BeforeExecute(List<DialogElement>? dialogs)
{
enabled = true;
_enabled = true;
var state = _services.GetRequiredService<IConversationStateService>();
var routing = _services.GetRequiredService<IRoutingService>();

Expand All @@ -129,7 +148,7 @@ private void BeforeExecute(List<DialogElement>? dialogs)
RecursiveCounter = routing.Context.GetRecursiveCounter(),
RoutingStack = routing.Context.GetAgentStack()
};
contextStack.Push(node);
_contextStack.Push(node);

// Reset
state.ResetCurrentState();
Expand All @@ -144,14 +163,22 @@ private void AfterExecute()
var state = _services.GetRequiredService<IConversationStateService>();
var routing = _services.GetRequiredService<IRoutingService>();

var node = contextStack.Pop();
var node = _contextStack.Pop();

// Recover
state.SetCurrentState(node.State);
routing.Context.SetRecursiveCounter(node.RecursiveCounter);
routing.Context.SetAgentStack(node.RoutingStack);
routing.Context.SetDialogs(node.RoutingDialogs);
Utilities.ClearCache();
enabled = false;
_enabled = false;
}

private bool IsValid(string conversationId)
{
return !_contextStack.IsNullOrEmpty()
&& _conversationId == conversationId
&& !string.IsNullOrEmpty(conversationId)
&& !string.IsNullOrEmpty(_conversationId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ public ConversationState GetConversationStates(string conversationId)
return new ConversationState(states);
}

[SideCar]
public void UpdateConversationStates(string conversationId, List<StateKeyValue> states)
{
if (states.IsNullOrEmpty()) return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ public ConversationState GetConversationStates(string conversationId)
return new ConversationState(savedStates);
}

[SideCar]
public void UpdateConversationStates(string conversationId, List<StateKeyValue> states)
{
if (string.IsNullOrEmpty(conversationId) || states == null) return;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using BotSharp.Abstraction.Files;
using BotSharp.Abstraction.Infrastructures.Enums;
using BotSharp.Abstraction.Options;
using BotSharp.Abstraction.Repositories;
using BotSharp.Abstraction.Routing;
using BotSharp.Abstraction.Utilities;
using BotSharp.Core.Infrastructures;
using BotSharp.Plugin.Twilio.Interfaces;
using BotSharp.Plugin.Twilio.Models;
Expand Down Expand Up @@ -134,9 +136,10 @@ public async Task<bool> Execute(RoleDialogModel message)
}
}

private async Task ForkConversation(LlmContextIn args,
string entryAgentId,
string originConversationId,
private async Task ForkConversation(
LlmContextIn args,
string entryAgentId,
string originConversationId,
string newConversationId,
CallResource call)
{
Expand All @@ -145,6 +148,8 @@ private async Task ForkConversation(LlmContextIn args,
var services = scope.ServiceProvider;
var convService = services.GetRequiredService<IConversationService>();
var convStorage = services.GetRequiredService<IConversationStorage>();
var state = _services.GetRequiredService<IConversationStateService>();
var db = _services.GetRequiredService<IBotSharpRepository>();

var newConv = await convService.NewConversation(new Conversation
{
Expand All @@ -170,15 +175,45 @@ private async Task ForkConversation(LlmContextIn args,
}
});

convService.SetConversationId(newConversationId,
[
new MessageState(StateConst.ORIGIN_CONVERSATION_ID, originConversationId),
new MessageState("channel", "phone"),
new MessageState("phone_from", call.From),
new MessageState("phone_direction", call.Direction),
new MessageState("phone_number", call.To),
new MessageState("twilio_call_sid", call.Sid)
]);
convService.SaveStates();
var utcNow = DateTime.UtcNow;
var excludStates = new List<string>
{
"provider",
"model",
"prompt_total",
"completion_total",
"llm_total_cost"
};

var curStates = state.GetStates().Select(x => new MessageState(x.Key, x.Value)).ToList();
var subConvStates = new List<MessageState>
{
new(StateConst.ORIGIN_CONVERSATION_ID, originConversationId),
new("channel", "phone"),
new("phone_from", call.From),
new("phone_direction", call.Direction),
new("phone_number", call.To),
new("twilio_call_sid", call.Sid)
};
var subStateKeys = subConvStates.Select(x => x.Key).ToList();
var included = curStates.Where(x => !subStateKeys.Contains(x.Key) && !excludStates.Contains(x.Key));
var newStates = subConvStates.Concat(included).Select(x => new StateKeyValue
{
Key = x.Key,
Versioning = true,
Values = [
new StateValue
{
Data = x.Value.ConvertToString(_options.JsonSerializerOptions),
MessageId = messageId,
Active = true,
ActiveRounds = x.ActiveRounds,
Source = StateSource.Application,
UpdateTime = utcNow
}
]
}).ToList();

db.UpdateConversationStates(newConversationId, newStates);
}
}
Loading