11 changes: 11 additions & 0 deletions BotSharp.sln
@@ -151,6 +151,8 @@ Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.ImageHandle
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.FuzzySharp", "src\Plugins\BotSharp.Plugin.FuzzySharp\BotSharp.Plugin.FuzzySharp.csproj", "{E7C243B9-E751-B3B4-8F16-95C76CA90D31}"
EndProject
Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "BotSharp.Plugin.MMPEmbedding", "src\Plugins\BotSharp.Plugin.MMPEmbedding\BotSharp.Plugin.MMPEmbedding.csproj", "{394B858B-9C26-B977-A2DA-8CC7BE5914CB}"
EndProject
Global
GlobalSection(SolutionConfigurationPlatforms) = preSolution
Debug|Any CPU = Debug|Any CPU
@@ -639,6 +641,14 @@ Global
{E7C243B9-E751-B3B4-8F16-95C76CA90D31}.Release|Any CPU.Build.0 = Release|Any CPU
{E7C243B9-E751-B3B4-8F16-95C76CA90D31}.Release|x64.ActiveCfg = Release|Any CPU
{E7C243B9-E751-B3B4-8F16-95C76CA90D31}.Release|x64.Build.0 = Release|Any CPU
{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|Any CPU.Build.0 = Debug|Any CPU
{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|x64.ActiveCfg = Debug|Any CPU
{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Debug|x64.Build.0 = Debug|Any CPU
{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|Any CPU.ActiveCfg = Release|Any CPU
{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|Any CPU.Build.0 = Release|Any CPU
{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|x64.ActiveCfg = Release|Any CPU
{394B858B-9C26-B977-A2DA-8CC7BE5914CB}.Release|x64.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
@@ -712,6 +722,7 @@ Global
{FC63C875-E880-D8BB-B8B5-978AB7B62983} = {51AFE054-AE99-497D-A593-69BAEFB5106F}
{242F2D93-FCCE-4982-8075-F3052ECCA92C} = {51AFE054-AE99-497D-A593-69BAEFB5106F}
{E7C243B9-E751-B3B4-8F16-95C76CA90D31} = {51AFE054-AE99-497D-A593-69BAEFB5106F}
{394B858B-9C26-B977-A2DA-8CC7BE5914CB} = {2635EC9B-2E5F-4313-AC21-0B847F31F36C}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {A9969D89-C98B-40A5-A12B-FC87E55B3A19}
18 changes: 18 additions & 0 deletions src/Plugins/BotSharp.Plugin.MMPEmbedding/BotSharp.Plugin.MMPEmbedding.csproj
@@ -0,0 +1,18 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>$(TargetFramework)</TargetFramework>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" />
<PackageReference Include="OpenAI" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="..\..\Infrastructure\BotSharp.Core\BotSharp.Core.csproj" />
</ItemGroup>

</Project>
19 changes: 19 additions & 0 deletions src/Plugins/BotSharp.Plugin.MMPEmbedding/MMPEmbeddingPlugin.cs
@@ -0,0 +1,19 @@
using BotSharp.Abstraction.Plugins;
using BotSharp.Plugin.MMPEmbedding.Providers;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;

namespace BotSharp.Plugin.MMPEmbedding
{
public class MMPEmbeddingPlugin : IBotSharpPlugin
{
public string Id => "54d04e10-fc84-493e-a8c9-39da1c83f45a";
public string Name => "MMPEmbedding";
public string Description => "MMP Embedding Service";

public void RegisterDI(IServiceCollection services, IConfiguration config)
{
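// Registers this provider as one of the available ITextEmbedding implementations;
// consumers are expected to select it by its Provider name ("mmp-embedding").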
services.AddScoped<ITextEmbedding, MMPEmbeddingProvider>();
}
}
}
70 changes: 70 additions & 0 deletions src/Plugins/BotSharp.Plugin.MMPEmbedding/ProviderHelper.cs
@@ -0,0 +1,70 @@
using OpenAI;
using Azure.AI.OpenAI;
using System.ClientModel;
using Microsoft.Extensions.DependencyInjection;

namespace BotSharp.Plugin.MMPEmbedding;

/// <summary>
/// Helper class to get the appropriate client based on provider type
/// Supports multiple providers: OpenAI, Azure OpenAI, DeepSeek, etc.
/// </summary>
public static class ProviderHelper
{
/// <summary>
/// Gets an OpenAI-compatible client based on the provider name
/// </summary>
/// <param name="provider">Provider name (e.g., "openai", "azure-openai")</param>
/// <param name="model">Model name</param>
/// <param name="services">Service provider for dependency injection</param>
/// <returns>OpenAIClient instance configured for the specified provider</returns>
public static OpenAIClient GetClient(string provider, string model, IServiceProvider services)
{
var settingsService = services.GetRequiredService<ILlmProviderService>();
var settings = settingsService.GetSetting(provider, model);

if (settings == null)
{
throw new InvalidOperationException($"Cannot find settings for provider '{provider}' and model '{model}'");
}

// Handle Azure OpenAI separately as it uses AzureOpenAIClient
if (provider.Equals("azure-openai", StringComparison.OrdinalIgnoreCase))
{
return GetAzureOpenAIClient(settings);
}

// For OpenAI, DeepSeek, and other OpenAI-compatible providers
return GetOpenAICompatibleClient(settings);
}
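// Illustrative usage sketch (the provider and model names below are assumptions, not values shipped with this plugin):
//   var client = ProviderHelper.GetClient("openai", "text-embedding-3-small", serviceProvider);
//   var embeddingClient = client.GetEmbeddingClient("text-embedding-3-small");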

/// <summary>
/// Gets an Azure OpenAI client
/// </summary>
private static OpenAIClient GetAzureOpenAIClient(LlmModelSetting settings)
{
if (string.IsNullOrEmpty(settings.Endpoint))
{
throw new InvalidOperationException("Azure OpenAI endpoint is required");
}

var client = new AzureOpenAIClient(
new Uri(settings.Endpoint),
new ApiKeyCredential(settings.ApiKey)
);
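// AzureOpenAIClient derives from OpenAIClient, so it can be returned through the shared base type.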

return client;
}

/// <summary>
/// Gets an OpenAI-compatible client (OpenAI, DeepSeek, etc.)
/// </summary>
private static OpenAIClient GetOpenAICompatibleClient(LlmModelSetting settings)
{
var options = !string.IsNullOrEmpty(settings.Endpoint)
? new OpenAIClientOptions { Endpoint = new Uri(settings.Endpoint) }
: null;
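// When no custom endpoint is configured, the SDK falls back to the default OpenAI endpoint.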

return new OpenAIClient(new ApiKeyCredential(settings.ApiKey), options);
}
}
@@ -0,0 +1,167 @@
using System.Collections.Generic;
using System.Text.RegularExpressions;
using BotSharp.Plugin.MMPEmbedding;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using OpenAI.Embeddings;

namespace BotSharp.Plugin.MMPEmbedding.Providers;

/// <summary>
/// Text embedding provider that uses a Mean-Max Pooling (MMP) strategy.
/// It embeds individual word tokens and combines the token embeddings through a weighted mix of mean and max pooling.
/// </summary>
public class MMPEmbeddingProvider : ITextEmbedding
{
protected readonly IServiceProvider _serviceProvider;
protected readonly ILogger<MMPEmbeddingProvider> _logger;

private const int DEFAULT_DIMENSION = 1536;
protected string _model = "text-embedding-3-small";
protected int _dimension = DEFAULT_DIMENSION;

// The underlying provider to use (e.g., "openai", "azure-openai", "deepseek-ai")
protected string _underlyingProvider = "openai";

public string Provider => "mmp-embedding";
public string Model => _model;

private static readonly Regex WordRegex = new(@"\b\w+\b", RegexOptions.Compiled);

public MMPEmbeddingProvider(IServiceProvider serviceProvider, ILogger<MMPEmbeddingProvider> logger)
{
_serviceProvider = serviceProvider;
_logger = logger;
}

/// <summary>
/// Gets a single embedding vector using mean-max pooling
/// </summary>
public async Task<float[]> GetVectorAsync(string text)
{
if (string.IsNullOrWhiteSpace(text))
{
return new float[_dimension];
}

var tokens = Tokenize(text).ToList();

if (tokens.Count == 0)
{
return new float[_dimension];
}

// Get embeddings for all tokens
var tokenEmbeddings = await GetTokenEmbeddingsAsync(tokens);

// Apply mean-max pooling
var pooledEmbedding = MeanMaxPooling(tokenEmbeddings);

return pooledEmbedding;
}

/// <summary>
/// Gets multiple embedding vectors using mean-max pooling
/// </summary>
public async Task<List<float[]>> GetVectorsAsync(List<string> texts)
{
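// Each text is embedded independently and sequentially; every call issues one batched request for that text's tokens.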
var results = new List<float[]>();

foreach (var text in texts)
{
var embedding = await GetVectorAsync(text);
results.Add(embedding);
}

return results;
}

/// <summary>
/// Gets embeddings for individual tokens using the underlying provider
/// </summary>
private async Task<List<float[]>> GetTokenEmbeddingsAsync(List<string> tokens)
{
try
{
// Get the appropriate client based on the underlying provider
var client = ProviderHelper.GetClient(_underlyingProvider, _model, _serviceProvider);
var embeddingClient = client.GetEmbeddingClient(_model);

// Prepare options
var options = new EmbeddingGenerationOptions
{
Dimensions = _dimension > 0 ? _dimension : null
};

// Get embeddings for all tokens in batch
var response = await embeddingClient.GenerateEmbeddingsAsync(tokens, options);
var embeddings = response.Value;

return embeddings.Select(e => e.ToFloats().ToArray()).ToList();
}
catch (Exception ex)
{
_logger.LogError(ex, "Error getting token embeddings from provider {Provider} with model {Model}",
_underlyingProvider, _model);
throw;
}
}

/// <summary>
/// Applies mean-max pooling to combine token embeddings
/// Mean pooling: average of all token embeddings
/// Max pooling: element-wise maximum of all token embeddings
/// Result: weighted element-wise combination of the mean- and max-pooled vectors (0.5/0.5 by default)
/// </summary>
private float[] MeanMaxPooling(IReadOnlyList<float[]> vectors, double meanWeight = 0.5, double maxWeight = 0.5)
{
var numTokens = vectors.Count;

if (numTokens == 0)
return [];

var meanPooled = Enumerable.Range(0, _dimension)
.Select(i => vectors.Average(v => v[i]))
.ToArray();
var maxPooled = Enumerable.Range(0, _dimension)
.Select(i => vectors.Max(v => v[i]))
.ToArray();
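// Worked example (illustrative): for token vectors [0.2, 0.4] and [0.6, 0.0],
// meanPooled = [0.4, 0.2] and maxPooled = [0.6, 0.4]; with the default 0.5/0.5
// weights the combined result is [0.5, 0.3].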

return Enumerable.Range(0, _dimension)
.Select(i => (float)meanWeight * meanPooled[i] + (float)maxWeight * maxPooled[i])
.ToArray();
}

public void SetDimension(int dimension)
{
_dimension = dimension > 0 ? dimension : DEFAULT_DIMENSION;
}

public int GetDimension()
{
return _dimension;
}

public void SetModelName(string model)
{
_model = model;
}

/// <summary>
/// Sets the underlying provider to use for getting token embeddings
/// </summary>
/// <param name="provider">Provider name (e.g., "openai", "azure-openai", "deepseek-ai")</param>
public void SetUnderlyingProvider(string provider)
{
_underlyingProvider = provider;
}

/// <summary>
/// Tokenizes text into individual word tokens
/// </summary>
/// <param name="text">Input text to tokenize</param>
/// <param name="pattern">Optional custom regex pattern; when null or empty, the default word pattern (\b\w+\b) is used</param>
public static IEnumerable<string> Tokenize(string text, string? pattern = null)
{
var patternRegex = string.IsNullOrEmpty(pattern) ? WordRegex : new(pattern, RegexOptions.Compiled);
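// With the default pattern, e.g. "mean-max pooling" yields ["mean", "max", "pooling"].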
return patternRegex.Matches(text).Cast<Match>().Select(m => m.Value);
}
}
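// Usage sketch (assumption: the provider is resolved from DI among the registered ITextEmbedding implementations):
//   var embedding = serviceProvider.GetServices<ITextEmbedding>().First(x => x.Provider == "mmp-embedding");
//   embedding.SetModelName("text-embedding-3-small");
//   embedding.SetDimension(1536);
//   var vector = await embedding.GetVectorAsync("mean-max pooling example");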
10 changes: 10 additions & 0 deletions src/Plugins/BotSharp.Plugin.MMPEmbedding/Using.cs
@@ -0,0 +1,10 @@
global using System;
global using System.Collections.Generic;
global using System.Linq;
global using System.Text;
global using System.Threading.Tasks;

global using BotSharp.Abstraction.MLTasks;
global using BotSharp.Abstraction.MLTasks.Settings;
global using Microsoft.Extensions.Logging;
