diff --git a/build.psm1 b/build.psm1 index 30efc9fb..702c41e4 100644 --- a/build.psm1 +++ b/build.psm1 @@ -91,6 +91,43 @@ function Start-Build $app_csproj = GetProjectFile $app_dir dotnet publish $app_csproj -c $Configuration -o $app_out_dir -r $RID --sc + ## Move the 'Modules' folder to the appbase folder. + if ($LASTEXITCODE -eq 0) { + ## Remove the artifacts that are not for the current platform, to reduce size. + $otherPlatDir = Join-Path $app_out_dir 'runtimes' ($IsWindows ? 'unix' : 'win') + if (Test-Path $otherPlatDir -PathType Container) { + Remove-Item $otherPlatDir -Recurse -Force -ErrorAction Stop + } + + ## Move the 'Modules' folder and if possible, remove the 'runtimes' folder all together afterward. + $platDir = Join-Path $app_out_dir 'runtimes' ($IsWindows ? 'win' : 'unix') 'lib' + if (Test-Path $platDir -PathType Container) { + $moduleDir = Get-ChildItem $platDir -Directory -Include 'Modules' -Recurse + if ($moduleDir) { + ## Remove the existing 'Modules' folder if it already exists. + $target = Join-Path $app_out_dir 'Modules' + if (Test-Path $target -PathType Container) { + Remove-Item $target -Recurse -Force -ErrorAction Stop + } + + ## Move 'Modules' folder. + Move-Item $moduleDir.FullName $app_out_dir -Force -ErrorAction Stop + + ## Remove the 'runtimes' folder if possible. + $parent = $moduleDir.Parent + while ($parent.FullName -ne $app_out_dir) { + $files = Get-ChildItem $parent.FullName -File -Recurse + if (-not $files) { + Remove-Item $parent.FullName -Recurse -Force -ErrorAction Stop + $parent = $parent.Parent + } else { + break + } + } + } + } + } + if ($LASTEXITCODE -eq 0 -and $AgentToInclude -contains 'openai-gpt') { Write-Host "`n[Build the OpenAI agent ...]`n" -ForegroundColor Green $openai_csproj = GetProjectFile $openai_agent_dir diff --git a/shell/AIShell.App/AIShell.App.csproj b/shell/AIShell.App/AIShell.App.csproj index 4bfbb25a..dd8f9592 100644 --- a/shell/AIShell.App/AIShell.App.csproj +++ b/shell/AIShell.App/AIShell.App.csproj @@ -13,6 +13,7 @@ + diff --git a/shell/AIShell.Kernel/Utility/LoadContext.cs b/shell/AIShell.Kernel/Utility/LoadContext.cs index 42d9960a..3cb70c08 100644 --- a/shell/AIShell.Kernel/Utility/LoadContext.cs +++ b/shell/AIShell.Kernel/Utility/LoadContext.cs @@ -1,4 +1,6 @@ -using System.Reflection; +using System.Collections.Concurrent; +using System.Reflection; +using System.Runtime.InteropServices; using System.Runtime.Loader; namespace AIShell.Kernel; @@ -6,6 +8,10 @@ namespace AIShell.Kernel; internal class AgentAssemblyLoadContext : AssemblyLoadContext { private readonly string _dependencyDir; + private readonly string _nativeLibExt; + private readonly List _runtimeLibDir; + private readonly List _runtimeNativeDir; + private readonly ConcurrentDictionary _cache; internal AgentAssemblyLoadContext(string name, string dependencyDir) : base($"{name.Replace(' ', '.')}-ALC", isCollectible: false) @@ -17,20 +23,112 @@ internal AgentAssemblyLoadContext(string name, string dependencyDir) // Save the full path to the dependencies directory when creating the context. _dependencyDir = dependencyDir; + _runtimeLibDir = []; + _runtimeNativeDir = []; + _cache = []; + + if (OperatingSystem.IsWindows()) + { + _nativeLibExt = ".dll"; + AddToList(_runtimeLibDir, Path.Combine(dependencyDir, "runtimes", "win", "lib")); + AddToList(_runtimeNativeDir, Path.Combine(dependencyDir, "runtimes", "win", "native")); + } + else if (OperatingSystem.IsLinux()) + { + _nativeLibExt = ".so"; + AddToList(_runtimeLibDir, Path.Combine(dependencyDir, "runtimes", "unix", "lib")); + AddToList(_runtimeLibDir, Path.Combine(dependencyDir, "runtimes", "linux", "lib")); + + AddToList(_runtimeNativeDir, Path.Combine(dependencyDir, "runtimes", "unix", "native")); + AddToList(_runtimeNativeDir, Path.Combine(dependencyDir, "runtimes", "linux", "native")); + } + else if (OperatingSystem.IsMacOS()) + { + _nativeLibExt = ".dylib"; + AddToList(_runtimeLibDir, Path.Combine(dependencyDir, "runtimes", "unix", "lib")); + AddToList(_runtimeLibDir, Path.Combine(dependencyDir, "runtimes", "osx", "lib")); + + AddToList(_runtimeNativeDir, Path.Combine(dependencyDir, "runtimes", "unix", "native")); + AddToList(_runtimeNativeDir, Path.Combine(dependencyDir, "runtimes", "osx", "native")); + } + + AddToList(_runtimeLibDir, Path.Combine(dependencyDir, "runtimes", RuntimeInformation.RuntimeIdentifier, "lib")); + AddToList(_runtimeNativeDir, Path.Combine(dependencyDir, "runtimes", RuntimeInformation.RuntimeIdentifier, "native")); + + ResolvingUnmanagedDll += ResolveUnmanagedDll; } protected override Assembly Load(AssemblyName assemblyName) { - // Create a path to the assembly in the dependencies directory. - string path = Path.Combine(_dependencyDir, $"{assemblyName.Name}.dll"); + if (_cache.TryGetValue(assemblyName.Name, out Assembly assembly)) + { + return assembly; + } + + lock (this) + { + if (_cache.TryGetValue(assemblyName.Name, out assembly)) + { + return assembly; + } + + // Create a path to the assembly in the dependencies directory. + string assemblyFile = $"{assemblyName.Name}.dll"; + string path = Path.Combine(_dependencyDir, assemblyFile); + + if (File.Exists(path)) + { + // If the assembly exists in our dependency directory, then load it into this load context. + assembly = LoadFromAssemblyPath(path); + } + else + { + foreach (string dir in _runtimeLibDir) + { + IEnumerable result = Directory.EnumerateFiles(dir, assemblyFile, SearchOption.AllDirectories); + path = result.FirstOrDefault(); + + if (path is not null) + { + assembly = LoadFromAssemblyPath(path); + break; + } + } + } + + // Add the probing result to cache, regardless of whether we found it. + // If we didn't find it, we will add 'null' to the cache so that we don't probe + // again in case another loading request comes for the same assembly. + _cache.TryAdd(assemblyName.Name, assembly); + + // Return the assembly if we found it, or return 'null' otherwise to depend on the default load context to resolve the request. + return assembly; + } + } + + private nint ResolveUnmanagedDll(Assembly assembly, string libraryName) + { + string libraryFile = $"{libraryName}{_nativeLibExt}"; - if (File.Exists(path)) + foreach (string dir in _runtimeNativeDir) { - // If the assembly exists in our dependency directory, then load it into this load context. - return LoadFromAssemblyPath(path); + IEnumerable result = Directory.EnumerateFiles(dir, libraryFile, SearchOption.AllDirectories); + string path = result.FirstOrDefault(); + + if (path is not null) + { + return NativeLibrary.Load(path); + } } - // Otherwise we will depend on the default load context to resolve the request. - return null; + return nint.Zero; + } + + private static void AddToList(List depList, string dirPath) + { + if (Directory.Exists(dirPath)) + { + depList.Add(dirPath); + } } } diff --git a/shell/agents/Microsoft.Azure.Agent/AzureAgent.cs b/shell/agents/Microsoft.Azure.Agent/AzureAgent.cs index 3a098ca7..e70a0c30 100644 --- a/shell/agents/Microsoft.Azure.Agent/AzureAgent.cs +++ b/shell/agents/Microsoft.Azure.Agent/AzureAgent.cs @@ -27,7 +27,7 @@ public sealed class AzureAgent : ILLMAgent 2. DO NOT include the command for creating a new resource group unless the query explicitly asks for it. Otherwise, assume a resource group already exists. 3. DO NOT include an additional example with made-up values unless it provides additional context or value beyond the initial command. 4. DO NOT use the line continuation operator (backslash `\` in Bash) in the generated commands. - 5. Always represent a placeholder in the form of ``. + 5. Always represent a placeholder in the form of `` and enclose it within double quotes. 6. Always use the consistent placeholder names across all your responses. For example, `` should be used for all the places where a resource group name value is needed. 7. When the commands contain placeholders, the placeholders should be summarized in markdown bullet points at the end of the response in the same order as they appear in the commands, following this format: ``` @@ -260,9 +260,9 @@ public async Task ChatAsync(string input, IShell shell) { // Process CLI handler response specially to support parameter injection. ResponseData data = null; - if (_copilotResponse.TopicName == CopilotActivity.CLIHandlerTopic) + if (_copilotResponse.TopicName is CopilotActivity.CLIHandlerTopic or CopilotActivity.PSHandlerTopic) { - data = ParseCLIHandlerResponse(shell); + data = ParseCodeResponse(shell); } if (data?.PlaceholderSet is not null) @@ -349,7 +349,7 @@ public async Task ChatAsync(string input, IShell shell) return true; } - private ResponseData ParseCLIHandlerResponse(IShell shell) + private ResponseData ParseCodeResponse(IShell shell) { string text = _copilotResponse.Text; List codeBlocks = shell.ExtractCodeBlocks(text, out List sourceInfos); @@ -402,6 +402,7 @@ private ResponseData ParseCLIHandlerResponse(IShell shell) ResponseData data = new() { Text = text, CommandSet = commands, + TopicName = _copilotResponse.TopicName, PlaceholderSet = placeholders, Locale = _copilotResponse.Locale, }; diff --git a/shell/agents/Microsoft.Azure.Agent/Command.cs b/shell/agents/Microsoft.Azure.Agent/Command.cs index 06784df6..427f0c07 100644 --- a/shell/agents/Microsoft.Azure.Agent/Command.cs +++ b/shell/agents/Microsoft.Azure.Agent/Command.cs @@ -33,9 +33,19 @@ private static string SyntaxHighlightAzCommand(string command, string parameter, const string vtReset = "\x1b[0m"; StringBuilder cStr = new(capacity: command.Length + parameter.Length + placeholder.Length + 50); - cStr.Append(vtItalic) - .Append(vtCommand).Append("az").Append(vtFgDefault).Append(command.AsSpan(2)).Append(' ') - .Append(vtParameter).Append(parameter).Append(vtFgDefault).Append(' ') + cStr.Append(vtItalic); + + int index = command.IndexOf(' '); + if (index is -1) + { + cStr.Append(vtCommand).Append(command).Append(vtFgDefault).Append(' '); + } + else + { + cStr.Append(vtCommand).Append(command.AsSpan(0, index)).Append(vtFgDefault).Append(command.AsSpan(index)).Append(' '); + } + + cStr.Append(vtParameter).Append(parameter).Append(vtFgDefault).Append(' ') .Append(vtVariable).Append(placeholder).Append(vtFgDefault) .Append(vtReset); @@ -125,15 +135,9 @@ private void ReplaceAction() // Prompt for argument without printing captions again. string value = host.PromptForArgument(argInfo, printCaption: false); + value = value?.Trim(); if (!string.IsNullOrEmpty(value)) { - // Add quotes for the value if needed. - value = value.Trim(); - if (value.StartsWith('-') || value.Contains(' ') || value.Contains('|')) - { - value = $"\"{value}\""; - } - _values.Add(item.Name, value); _agent.SaveUserValue(item.Name, value); diff --git a/shell/agents/Microsoft.Azure.Agent/DataRetriever.cs b/shell/agents/Microsoft.Azure.Agent/DataRetriever.cs index 7662f48c..7ac3eee9 100644 --- a/shell/agents/Microsoft.Azure.Agent/DataRetriever.cs +++ b/shell/agents/Microsoft.Azure.Agent/DataRetriever.cs @@ -1,23 +1,31 @@ using System.Collections.Concurrent; using System.ComponentModel; +using System.Collections.ObjectModel; using System.Diagnostics; using System.Text; using System.Text.Json; using System.Text.RegularExpressions; +using System.Management.Automation; +using System.Management.Automation.Language; +using System.Management.Automation.Runspaces; using AIShell.Abstraction; using Serilog; namespace Microsoft.Azure.Agent; +using PowerShell = System.Management.Automation.PowerShell; + internal class DataRetriever : IDisposable { + private const int WorkerCount = 3; private const string MetadataQueryTemplate = "{{\"command\":\"{0}\"}}"; private const string MetadataEndpoint = "https://cli-validation-tool-meta-qry.azurewebsites.net/api/command_metadata"; private static readonly string s_azCompleteCmd, s_azCompleteArg; private static readonly Dictionary s_azNamingRules; private static readonly ConcurrentDictionary s_azStaticDataCache; + private static readonly PowerShellPool s_psPool; private readonly Task _rootTask; private readonly HttpClient _httpClient; @@ -29,6 +37,7 @@ internal class DataRetriever : IDisposable static DataRetriever() { + s_psPool = new(size: WorkerCount); List rules = [ new("API Management Service", "apim", @@ -351,15 +360,111 @@ internal DataRetriever(ResponseData data, HttpClient httpClient) { _stop = false; _httpClient = httpClient; - _semaphore = new SemaphoreSlim(3, 3); + _semaphore = new SemaphoreSlim(WorkerCount, WorkerCount); _placeholders = new(capacity: data.PlaceholderSet.Count); _placeholderMap = new(capacity: data.PlaceholderSet.Count); - PairPlaceholders(data); + switch (data.TopicName) + { + case CopilotActivity.CLIHandlerTopic: + PairPlaceholdersForCLICode(data); + break; + case CopilotActivity.PSHandlerTopic: + PairPlaceholdersForPSCode(data); + break; + default: + throw new UnreachableException(); + } + _rootTask = Task.Run(StartProcessing); } - private void PairPlaceholders(ResponseData data) + private void PairPlaceholdersForPSCode(ResponseData data) + { + var asts = new Dictionary(data.CommandSet.Count); + + foreach (var item in data.PlaceholderSet) + { + string command = null, parameter = null; + VariableExpressionAst variable = null; + + foreach (var cmd in data.CommandSet) + { + if (!asts.TryGetValue(cmd.Script, out ScriptBlockAst sbAst)) + { + sbAst = Parser.ParseInput(cmd.Script, out _, out _); + asts.Add(cmd.Script, sbAst); + } + + if (variable is null) + { + Ast argAst = sbAst.Find( + predicate: a => a is StringConstantExpressionAst strConst && strConst.Value == item.Name, + searchNestedScriptBlocks: false); + + if (argAst?.Parent is CommandAst cmdAst) + { + (command, parameter) = GetCommandAndParameter(cmdAst, argAst); + break; + } + + // The generated powershell script may just assign the placeholders to some variables + // and use those variables in the command invocation. In that case, we search for those + // variables instead. + if (argAst?.Parent is CommandExpressionAst cmdExpr && + cmdExpr.Parent is AssignmentStatementAst assignment && + assignment.Left is VariableExpressionAst var) + { + variable = var; + } + } + + if (variable is not null) + { + Ast argAst = sbAst.Find( + predicate: a => a is VariableExpressionAst v && v.VariablePath.UserPath == variable.VariablePath.UserPath && v.Parent is CommandAst, + searchNestedScriptBlocks: false); + + if (argAst is not null) + { + var cmdAst = (CommandAst)argAst.Parent; + (command, parameter) = GetCommandAndParameter(cmdAst, argAst); + + break; + } + } + } + + if (command is null) + { + // This may happen if the generated PowerShell script assigns the placeholder values to a set of variables at the beginning. + Log.Debug("[DataRetriever] Failed to pair the placeholder '{0}' for PowerShell code.", item.Name); + } + + ArgumentPair pair = new(item, command, parameter, CopilotActivity.PSHandlerTopic); + _placeholders.Add(pair); + _placeholderMap.Add(item.Name, pair); + } + + static (string command, string parameter) GetCommandAndParameter(CommandAst cmdAst, Ast argAst) + { + string command = ((StringConstantExpressionAst)cmdAst.CommandElements[0]).Value; + string parameter = null; + + for (int i = 0; i < cmdAst.CommandElements.Count; i++) + { + if (cmdAst.CommandElements[i] == argAst) + { + parameter = cmdAst.CommandElements[i-1].Extent.Text; + break; + } + } + + return (command, parameter); + } + } + + private void PairPlaceholdersForCLICode(ResponseData data) { var cmds = new Dictionary(data.CommandSet.Count); @@ -391,7 +496,18 @@ private void PairPlaceholders(ResponseData data) continue; } - int paramIndex = script.LastIndexOf("--", argIndex); + // The placeholder value may be enclosed in double or single quotes. + if (script[argIndex - 1] is '\'' or '\"') + { + argIndex--; + } + + // The generated AzCLI command may contain both long (--xxx) and short (-x) flag forms + // for its parameters. So we need to properly handle it when looking for the parameter + // right before the placeholder value. + int paramIndex = 1 + Math.Max( + script.LastIndexOf(" --", argIndex), + script.LastIndexOf(" -", argIndex)); parameter = script.AsSpan(paramIndex, argIndex - paramIndex).Trim().ToString(); placeholderFound = true; @@ -406,6 +522,7 @@ private void PairPlaceholders(ResponseData data) command = script; parameter = null; + Log.Debug("[DataRetriever] Non-AzCLI command: '{0}'", command); placeholderFound = true; break; } @@ -417,7 +534,7 @@ private void PairPlaceholders(ResponseData data) } } - ArgumentPair pair = new(item, command, parameter); + ArgumentPair pair = new(item, command, parameter, CopilotActivity.CLIHandlerTopic); _placeholders.Add(pair); _placeholderMap.Add(item.Name, pair); } @@ -469,10 +586,10 @@ private ArgumentInfo CreateArgInfo(ArgumentPair pair) return new ArgumentInfo(item.Name, item.Desc, restriction: null, dataType, item.ValidValues); } - // Handle non-AzCLI command. - if (pair.Parameter is null) + // Pairing placeholder with command and parameter was not successful. + // This may happen when the generated code is unexpected or contains commands that is not AzPS or AzCLI. + if (pair.Command is null || pair.Parameter is null) { - Log.Debug("[DataRetriever] Non-AzCLI command: '{0}'", pair.Command); return new ArgumentInfo(item.Name, item.Desc, dataType); } @@ -483,8 +600,7 @@ private ArgumentInfo CreateArgInfo(ArgumentPair pair) return new ArgumentInfoWithNamingRule(item.Name, item.Desc, restriction, rule); } - if (string.Equals(pair.Parameter, "--name", StringComparison.OrdinalIgnoreCase) - && pair.Command.EndsWith(" create", StringComparison.OrdinalIgnoreCase)) + if (IsCreatingNewResourceWithAzCLI(pair) || IsCreatingNewResourceWithAzPS(pair)) { // Placeholder is for the name of a new resource to be created, but not in our cache. return new ArgumentInfo(item.Name, item.Desc, dataType); @@ -492,11 +608,165 @@ private ArgumentInfo CreateArgInfo(ArgumentPair pair) if (_stop) { return null; } - List suggestions = GetArgValues(pair); + IList suggestions = pair.TopicName is CopilotActivity.CLIHandlerTopic + ? GetArgValuesForAzCLI(pair) + : GetArgValuesForAzPS(pair); return new ArgumentInfo(item.Name, item.Desc, restriction: null, dataType, suggestions); + + + /* local helper methods */ + static bool IsCreatingNewResourceWithAzCLI(ArgumentPair pair) + { + return pair.TopicName is CopilotActivity.CLIHandlerTopic + && string.Equals(pair.Parameter, "--name", StringComparison.OrdinalIgnoreCase) + && pair.Command.EndsWith(" create", StringComparison.OrdinalIgnoreCase); + } + + static bool IsCreatingNewResourceWithAzPS(ArgumentPair pair) + { + return pair.TopicName is CopilotActivity.PSHandlerTopic + && string.Equals(pair.Parameter, "-Name", StringComparison.OrdinalIgnoreCase) + && pair.Command.StartsWith("New-", StringComparison.OrdinalIgnoreCase); + } + } + + private IList GetArgValuesForAzPS(ArgumentPair pair) + { + string command = pair.Command; + string parameter = pair.Parameter.TrimStart('-'); + + Runspace defaultRunspace = Runspace.DefaultRunspace; + PowerShell pwsh = s_psPool.Checkout(); + Runspace.DefaultRunspace = pwsh.Runspace; + + try + { + CommandInfo cmdInfo = null; + var r = pwsh.AddCommand("Get-Command").AddParameter("Name", command).Invoke(); + cmdInfo = r.FirstOrDefault(); + + if (cmdInfo is null) + { + Log.Debug("[DataRetriever] Cannot find the command '{0}'", command); + return null; + } + + if (!cmdInfo.Parameters.TryGetValue(parameter, out ParameterMetadata paramMetadata)) + { + Log.Debug("[DataRetriever] Cannot find the parameter '{0}' for command '{1}'", parameter, command); + return null; + } + + Log.Debug("[DataRetriever] Perform tab completion for '{0} {1} '", command, pair.Parameter); + + if (paramMetadata.ParameterType.IsEnum) + { + Log.Debug("[DataRetriever] - Enum values completion"); + return Enum.GetNames(paramMetadata.ParameterType); + } + + List returnValues = null; + foreach (var attribute in paramMetadata.Attributes) + { + if (_stop) + { + return null; + } + + if (attribute is ValidateSetAttribute setAtt) + { + Log.Debug("[DataRetriever] - ValidateSetAttribute completion"); + + returnValues = new(capacity: setAtt.ValidValues.Count); + foreach (string value in setAtt.ValidValues) + { + if (value != string.Empty) + { + returnValues.Add(value); + } + } + + return returnValues; + } + + if (attribute is ArgumentCompleterAttribute comAtt) + { + Log.Debug("[DataRetriever] - ArgumentCompleterAttribute completion"); + + if (comAtt.ScriptBlock is not null) + { + Log.Debug("[DataRetriever] - Invoke attr.ScriptBlock"); + + // Today, none of Azure PowerShell's argument completer attributes use the 'CommandAst' and + // the fake bound parameters. So we pass 'null' for both of them for simplicity. + Collection results = comAtt.ScriptBlock.Invoke(command, parameter, "", null, null); + if (results?.Count > 0) + { + returnValues = new(capacity: results.Count); + foreach (var result in results) + { + if (result.BaseObject is CompletionResult cr) + { + returnValues.Add(cr.CompletionText); + } + else + { + returnValues.Add(result.ToString()); + } + } + } + } + else + { + Log.Debug("[DataRetriever] - Invoke IArgumentCompleter"); + + IArgumentCompleter completer = comAtt.Type is not null + ? Activator.CreateInstance(comAtt.Type) as IArgumentCompleter + : comAtt is IArgumentCompleterFactory factory + ? factory.Create() + : null; + + // Today, Azure PowerShell's argument completer attributes don't use 'CommandAst' and the fake bound parameters. + // So we pass 'null' for both of them for simplicity. + IEnumerable results = completer?.CompleteArgument(command, parameter, "", null, null); + if (results is not null) + { + foreach (var result in results) + { + returnValues ??= []; + returnValues.Add(result.CompletionText); + } + } + } + + return returnValues; + } + } + } + catch (Exception e) + { + string commandLine = $"{command} {pair.Parameter}"; + Log.Error(e, "[DataRetriever] Exception while performing argument completion for '{0}'", commandLine); + if (Telemetry.Enabled) + { + Dictionary details = new() + { + ["Command"] = commandLine, + ["Message"] = "Argument completion for AzPS command raised an exception." + }; + Telemetry.Trace(AzTrace.Exception(details), e); + } + } + finally + { + Runspace.DefaultRunspace = defaultRunspace; + s_psPool.Return(pwsh); + } + + return null; } - private List GetArgValues(ArgumentPair pair) + private List GetArgValuesForAzCLI(ArgumentPair pair) { // First, try to get static argument values if they exist. bool hasCompleter = true; @@ -690,13 +960,15 @@ internal class ArgumentPair internal PlaceholderItem Placeholder { get; } internal string Command { get; } internal string Parameter { get; } + internal string TopicName { get; } internal Task ArgumentInfo { set; get; } - internal ArgumentPair(PlaceholderItem placeholder, string command, string parameter) + internal ArgumentPair(PlaceholderItem placeholder, string command, string parameter, string topicName) { Placeholder = placeholder; Command = command; Parameter = parameter; + TopicName = topicName; ArgumentInfo = null; } } diff --git a/shell/agents/Microsoft.Azure.Agent/Microsoft.Azure.Agent.csproj b/shell/agents/Microsoft.Azure.Agent/Microsoft.Azure.Agent.csproj index d484f156..8f5e7d26 100644 --- a/shell/agents/Microsoft.Azure.Agent/Microsoft.Azure.Agent.csproj +++ b/shell/agents/Microsoft.Azure.Agent/Microsoft.Azure.Agent.csproj @@ -20,6 +20,10 @@ + + contentFiles + All + diff --git a/shell/agents/Microsoft.Azure.Agent/Schema.cs b/shell/agents/Microsoft.Azure.Agent/Schema.cs index 53d7ccf0..9e33b436 100644 --- a/shell/agents/Microsoft.Azure.Agent/Schema.cs +++ b/shell/agents/Microsoft.Azure.Agent/Schema.cs @@ -121,6 +121,7 @@ internal class CopilotActivity public const string ConversationStateName = "azurecopilot/conversationstate"; public const string SuggestedResponseName = "azurecopilot/suggesteduserresponses"; public const string CLIHandlerTopic = "generate_azure_cli_scripts"; + public const string PSHandlerTopic = "generate_powershell_script"; public string Type { get; set; } public string Id { get; set; } @@ -297,6 +298,7 @@ internal class ResponseData { internal string Text { get; set; } internal string Locale { get; set; } + internal string TopicName { get; set; } internal List CommandSet { get; set; } internal List PlaceholderSet { get; set; } } diff --git a/shell/agents/Microsoft.Azure.Agent/Utils.cs b/shell/agents/Microsoft.Azure.Agent/Utils.cs index 58236f66..ecfab558 100644 --- a/shell/agents/Microsoft.Azure.Agent/Utils.cs +++ b/shell/agents/Microsoft.Azure.Agent/Utils.cs @@ -1,8 +1,13 @@ +using System.Collections.Concurrent; using System.Text.Encodings.Web; using System.Text.Json; +using System.Management.Automation.Runspaces; namespace Microsoft.Azure.Agent; +using PowerShell = System.Management.Automation.PowerShell; +using ExecutionPolicy = Microsoft.PowerShell.ExecutionPolicy; + internal static class Utils { internal const string JsonContentType = "application/json"; @@ -98,3 +103,61 @@ internal class ChatMessage public string Role { get; set; } public string Content { get; set; } } + +internal class PowerShellPool +{ + private readonly int _size; + private readonly BlockingCollection _pool; + + internal PowerShellPool(int size) + { + _size = size; + _pool = new(boundedCapacity: size); + + var iss = InitialSessionState.CreateDefault(); + iss.ImportPSModule("Az.Accounts"); + + if (OperatingSystem.IsWindows()) + { + iss.ExecutionPolicy = ExecutionPolicy.Bypass; + } + + // Pre-populate the pool on worker thread. + Task.Factory.StartNew( + CreatePowerShell, + iss, + CancellationToken.None, + TaskCreationOptions.DenyChildAttach, + TaskScheduler.Default); + } + + private void CreatePowerShell(object state) + { + var iss = (InitialSessionState)state; + + for (int i = 0; i < _size; i++) + { + var runspace = RunspaceFactory.CreateRunspace(iss); + runspace.Open(); + + var pwsh = PowerShell.Create(runspace); + _pool.Add(pwsh); + } + } + + internal PowerShell Checkout() + { + return _pool.Take(); + } + + internal void Return(PowerShell pwsh) + { + if (pwsh is not null) + { + pwsh.Commands.Clear(); + pwsh.Streams.ClearStreams(); + + _pool.Add(pwsh); + } + } +}