diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt index d930bd2e..c8a46ad2 100644 --- a/.github/actions/spelling/expect.txt +++ b/.github/actions/spelling/expect.txt @@ -58,6 +58,7 @@ daddr datacenter davidanson DDCE +DDTHH debbuild DEBHELPER debian @@ -145,6 +146,8 @@ INVM Ioctl iusr jetbrains +jqlang +JOBOBJECT jobsjob joutvhu JScript @@ -175,6 +178,7 @@ MFC microsoftcblmariner microsoftlosangeles microsoftwindowsdesktop +mmm mnt msasn msp @@ -270,6 +274,7 @@ spellright splitn SRPMS SSRF +SSZ stackoverflow stdbool stdint @@ -320,10 +325,12 @@ vflji vhd vmagentlog VMGA +vmhwm VMId vmlinux vmr vmrcs +vmrss vms vns VTeam @@ -351,6 +358,7 @@ wireserverand wireserverandimds WMI workarounds +WORKINGSET WORKDIR WScript wsf diff --git a/.github/workflows/reusable-build.yml b/.github/workflows/reusable-build.yml index 0c409b0d..750b320d 100644 --- a/.github/workflows/reusable-build.yml +++ b/.github/workflows/reusable-build.yml @@ -44,7 +44,7 @@ jobs: - name: rust-toolchain uses: actions-rs/toolchain@v1.0.6 with: - toolchain: 1.69.0 + toolchain: stable - name: Install llvm Code Coverage uses: taiki-e/install-action@cargo-llvm-cov @@ -198,7 +198,7 @@ jobs: - name: rust-toolchain uses: actions-rs/toolchain@v1.0.6 with: - toolchain: 1.69.0 + toolchain: stable - name: Run Build.cmd Debug arm64 run: .\build.cmd debug arm64 diff --git a/Cargo.lock b/Cargo.lock index 875c9875..1b252ea4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4,7 +4,7 @@ version = 4 [[package]] name = "ProxyAgentExt" -version = "1.0.38" +version = "1.0.39" dependencies = [ "clap", "ctor", @@ -172,7 +172,7 @@ dependencies = [ [[package]] name = "azure-proxy-agent" -version = "1.0.38" +version = "1.0.39" dependencies = [ "aya", "bitflags", @@ -925,7 +925,7 @@ dependencies = [ [[package]] name = "proxy_agent_setup" -version = "1.0.38" +version = "1.0.39" dependencies = [ "clap", "proxy_agent_shared", @@ -937,7 +937,7 @@ dependencies = [ [[package]] name = 
"proxy_agent_shared" -version = "1.0.38" +version = "1.0.39" dependencies = [ "chrono", "concurrent-queue", @@ -957,6 +957,7 @@ dependencies = [ "serde-xml-rs", "serde_derive", "serde_json", + "sysinfo", "thiserror", "thread-id", "time", diff --git a/build-linux.sh b/build-linux.sh index 5a3094c5..ae4cf0fd 100755 --- a/build-linux.sh +++ b/build-linux.sh @@ -60,7 +60,7 @@ echo "======= BuildEnvironment is $BuildEnvironment" echo "======= rustup update to a particular version" -rustup_version=1.88.0 +rustup_version=1.92.0 rustup update $rustup_version # This command sets a specific Rust toolchain version for the current directory. diff --git a/build.cmd b/build.cmd index ae93d5d7..57e2a790 100644 --- a/build.cmd +++ b/build.cmd @@ -55,7 +55,7 @@ if "%Target%"=="arm64" ( ) echo ======= rustup update to a particular version of the Rust toolchain -SET rustup_version=1.88.0 +SET rustup_version=1.92.0 call rustup update %rustup_version% REM This command sets a specific Rust toolchain version for the current directory. REM It means that whenever you are in this directory, Rust commands will use the specified toolchain version, regardless of the global default. 
diff --git a/e2etest/GuestProxyAgentTest/GuestProxyAgentTest.csproj b/e2etest/GuestProxyAgentTest/GuestProxyAgentTest.csproj index 1e27c879..9575c6bb 100644 --- a/e2etest/GuestProxyAgentTest/GuestProxyAgentTest.csproj +++ b/e2etest/GuestProxyAgentTest/GuestProxyAgentTest.csproj @@ -9,12 +9,13 @@ private preview feed for Azure SDK for .NET packages https://pkgs.dev.azure.com/azure-sdk/public/_packaging/azure-sdk-for-net/nuget/v3/index.json --> + false - - + + diff --git a/e2etest/GuestProxyAgentTest/LinuxScripts/GuestProxyAgentValidation.sh b/e2etest/GuestProxyAgentTest/LinuxScripts/GuestProxyAgentValidation.sh index 4b5d05fa..1ea12b8d 100755 --- a/e2etest/GuestProxyAgentTest/LinuxScripts/GuestProxyAgentValidation.sh +++ b/e2etest/GuestProxyAgentTest/LinuxScripts/GuestProxyAgentValidation.sh @@ -4,8 +4,11 @@ # SPDX-License-Identifier: MIT customOutputJsonUrl=$(echo $customOutputJsonSAS | base64 -d) +expectedSecureChannelState=$(echo $expectedSecureChannelState) echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - Start Guest Proxy Agent Validation" +echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - expectedSecureChannelState=$expectedSecureChannelState" + currentDir=$(pwd) customOutputJsonPath=$currentDir/proxyagentvalidation.json @@ -47,12 +50,104 @@ else echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - logdir does not exist" fi +echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - detecting os and installing jq" +os=$(hostnamectl | grep "Operating System") +echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - os=$os" +if [[ $os == *"Ubuntu"* ]]; then + for i in {1..3}; do + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - start installing jq via apt-get $i" + sudo apt update + sudo apt-get install -y jq + sleep 10 + install=$(apt list --installed jq) + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - install=$install" + if [[ $install == *"jq"* ]]; then + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - jq installed successfully" + break + fi + done +elif [[ $os == *"SUSE"* ]]; then + # SUSE repo does not have jq package, so we download jq 
binary directly + # Detect architecture + arch=$(uname -m) + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - detected architecture: $arch" + if [[ $arch == "x86_64" ]]; then + jq_binary="jq-linux-amd64" + elif [[ $arch == "aarch64" ]] || [[ $arch == "arm64" ]]; then + jq_binary="jq-linux-arm64" + else + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - unsupported architecture: $arch" + jq_binary="jq-linux64" + fi + # Download jq binary directly from GitHub + for i in {1..3}; do + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - downloading $jq_binary binary (attempt $i)" + sudo curl -L https://github.com/jqlang/jq/releases/download/jq-1.8.1/$jq_binary -o /usr/local/bin/jq + if [ $? -eq 0 ]; then + sudo chmod +x /usr/local/bin/jq + if command -v jq &> /dev/null; then + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - jq installed successfully" + jq --version + break + fi + fi + sleep 5 + done +else + for i in {1..3}; do + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - start installing jq via dnf $i" + sudo dnf -y install jq + sleep 10 + install=$(dnf list --installed jq) + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - install=$install" + if [[ $install == *"jq"* ]]; then + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - jq installed successfully" + break + fi + done +fi + +# check status.json file Content +## check timestamp of last entry in status.json file +## check the secure channel status +timeout=300 +elapsed=0 +statusFile=$logdir/status.json +secureChannelState="" + +# Current UTC time in epoch seconds +currentUtcTime=$(date -u +%s) +echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - Checking GPA status file $statusFile with 5 minute timeout" +while :; do + timestamp=$(cat "$statusFile" | jq -r '.timestamp') + # Convert timestamp to epoch seconds + timestampEpoch=$(date -u -d "$timestamp" +%s) + if ((timestampEpoch > currentUtcTime)); then + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - The last entry timestamp '$timestamp' is valid." 
+ ## check secure channel status + secureChannelState=$(cat "$statusFile" | jq -r '.proxyAgentStatus.keyLatchStatus.states.secureChannelState') + if [[ "$secureChannelState" == "$expectedSecureChannelState" ]]; then + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - The secure channel status '$secureChannelState' matches the expected state: '$expectedSecureChannelState'." + break + else + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - The secure channel status '$secureChannelState' does not match the expected state: '$expectedSecureChannelState'." + fi + fi + ((elapsed += 3)) + if [[ $elapsed -ge $timeout ]]; then + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - Timeout reached. Error, The secureChannelState is '$secureChannelState'." + break + fi + sleep 3 +done + echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - guestProxyAgentServiceExist=$guestProxyAgentServiceExist" echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - guestProxyAgentServiceStatus=$guestProxyAgentServiceStatus" echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - guestProxyProcessStarted=$guestProxyProcessStarted" echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - guestProxyAgentLogGenerated=$guestProxyAgentLogGenerated" +echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - secureChannelState=$secureChannelState" -jsonString='{"guestProxyAgentServiceInstalled": "'$guestProxyAgentServiceExist'", "guestProxyAgentServiceStatus": "'$guestProxyAgentServiceStatus'", "guestProxyProcessStarted": "'$guestProxyProcessStarted'", "guestProxyAgentLogGenerated": "'$guestProxyAgentLogGenerated'"}' +jsonString='{"guestProxyAgentServiceInstalled": "'$guestProxyAgentServiceExist'", "guestProxyAgentServiceStatus": "'$guestProxyAgentServiceStatus'", "guestProxyProcessStarted": "'$guestProxyProcessStarted'", "secureChannelState": "'$secureChannelState'", "guestProxyAgentLogGenerated": "'$guestProxyAgentLogGenerated'"}' echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - $jsonString" # write to $customOutputJsonPath diff --git 
a/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentExtensionValidation.ps1 b/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentExtensionValidation.ps1 index e3f7c474..e2362587 100644 --- a/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentExtensionValidation.ps1 +++ b/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentExtensionValidation.ps1 @@ -3,7 +3,7 @@ param ( [Parameter(Mandatory=$true, Position=0)] - [string]$customOutputJsonSAS, + [string]$customOutputJsonSAS, [string]$expectedProxyAgentVersion ) Write-Output "$((Get-Date).ToUniversalTime()) - expectedProxyAgentVersion=$expectedProxyAgentVersion" diff --git a/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentValidation.ps1 b/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentValidation.ps1 index a8037d01..22c9c90b 100644 --- a/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentValidation.ps1 +++ b/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentValidation.ps1 @@ -3,13 +3,15 @@ param ( [Parameter(Mandatory=$true, Position=0)] - [string]$customOutputJsonSAS + [string]$customOutputJsonSAS, + [string]$expectedSecureChannelState ) $decodedUrlBytes = [System.Convert]::FromBase64String($customOutputJsonSAS) $decodedUrlString = [System.Text.Encoding]::UTF8.GetString($decodedUrlBytes) Write-Output "$((Get-Date).ToUniversalTime()) - Start Guest Proxy Agent Validation" +Write-Output "$((Get-Date).ToUniversalTime()) - expectedSecureChannelState=$expectedSecureChannelState" $currentFolder = $PWD.Path $customOutputJsonPath = $currentFolder + "\proxyagentvalidation.json"; @@ -21,7 +23,7 @@ $guestProxyAgentServiceExist = $true $guestProxyAgentServiceStatus = "" $guestProxyAgentProcessExist = $true -if ($service -ne $null) { +if ($null -ne $service) { Write-Output "$((Get-Date).ToUniversalTime()) - The service $serviceName exists." 
$guestProxyAgentServiceStatus = $service.Status } else { @@ -34,7 +36,7 @@ $processName = "GuestProxyAgent" $process = Get-Process -Name $processName -ErrorAction SilentlyContinue -if ($process -ne $null) { +if ($null -ne $process) { Write-Output "$((Get-Date).ToUniversalTime()) - The process $processName exists." } else { $guestProxyAgentProcessExist = $false @@ -57,10 +59,52 @@ if (Test-Path -Path $folderPath -PathType Container) { Write-Output "$((Get-Date).ToUniversalTime()) - The folder $folderPath does not exist." } +# check status.json file Content +## check timestamp of last entry in status.json file +## check the secure channel status +$timeoutInSeconds = 300 +$statusFilePath = [IO.Path]::Combine($folderPath, "status.json") +Write-Output "$((Get-Date).ToUniversalTime()) - Checking GPA status file $statusFilePath with 5 minute timeout" +$secureChannelState = "" +$stopwatch = [System.Diagnostics.Stopwatch]::StartNew() +$currentUtcTime = (Get-Date).ToUniversalTime() +do { + $boolStatus = Test-Path -Path $statusFilePath + if ($boolStatus) { + $json = Get-Content $statusFilePath | Out-String | ConvertFrom-Json + $timestamp = $json.timestamp + if ($null -ne $timestamp -and $timestamp -ne "") { + Write-Output "$((Get-Date).ToUniversalTime()) - The status.json file contains a valid timestamp: $timestamp" + # parse the timestamp to UTC DateTime object, if must later than $currentUtcTime + $timestampDateTime = [DateTime]::Parse($timestamp).ToUniversalTime() + if ($timestampDateTime -gt $currentUtcTime) { + Write-Output "$((Get-Date).ToUniversalTime()) - The status.json timestamp $timestampDateTime is later than $currentUtcTime." + ## check secure channel status + $secureChannelState = $json.proxyAgentStatus.keyLatchStatus.states.secureChannelState + Write-Output "$((Get-Date).ToUniversalTime()) - The secure channel status is $secureChannelState." 
+ if ($secureChannelState -eq $expectedSecureChannelState) { + Write-Output "$((Get-Date).ToUniversalTime()) - The secure channel status '$secureChannelState' matches the expected state: '$expectedSecureChannelState'." + # break + } else { + Write-Output "$((Get-Date).ToUniversalTime()) - The secure channel status '$secureChannelState' does not match the expected state: '$expectedSecureChannelState'." + } + + if ($stopwatch.Elapsed.TotalSeconds -ge $timeoutInSeconds) { + Write-Output "$((Get-Date).ToUniversalTime()) - Timeout reached. Error, The secureChannelState is '$secureChannelState'." + break + } + } + } else { + Write-Output "$((Get-Date).ToUniversalTime()) - The status.json file does not contain a valid timestamp yet." + } + } + start-sleep -Seconds 3 +} until ($false) $jsonString = '{"guestProxyAgentServiceInstalled": ' + $guestProxyAgentServiceExist.ToString().ToLower() ` + ', "guestProxyProcessStarted": ' + $guestProxyAgentProcessExist.ToString().ToLower() ` + ', "guestProxyAgentServiceStatus": "' + $guestProxyAgentServiceStatus ` + + '", "secureChannelState": "' + $secureChannelState ` + '", "guestProxyAgentLogGenerated": ' + $guestProxyAgentLogGenerated.ToString().ToLower() + '}' Write-Output "$((Get-Date).ToUniversalTime()) - $jsonString" diff --git a/e2etest/GuestProxyAgentTest/Settings/TestSetting.cs b/e2etest/GuestProxyAgentTest/Settings/TestSetting.cs index 19e2e066..8c17594a 100644 --- a/e2etest/GuestProxyAgentTest/Settings/TestSetting.cs +++ b/e2etest/GuestProxyAgentTest/Settings/TestSetting.cs @@ -3,10 +3,7 @@ using Azure.Core; using GuestProxyAgentTest.Models; using GuestProxyAgentTest.Utilities; -using Newtonsoft.Json; using System.Reflection; -using System.Runtime.Serialization; -using System.Text.Json.Serialization; namespace GuestProxyAgentTest.Settings { diff --git a/e2etest/GuestProxyAgentTest/TestCases/EnableProxyAgentCase.cs b/e2etest/GuestProxyAgentTest/TestCases/EnableProxyAgentCase.cs index 87fb512e..1373465e 100644 --- 
a/e2etest/GuestProxyAgentTest/TestCases/EnableProxyAgentCase.cs +++ b/e2etest/GuestProxyAgentTest/TestCases/EnableProxyAgentCase.cs @@ -1,25 +1,33 @@ -using Azure.ResourceManager.Compute.Models; +// Copyright (c) Microsoft Corporation +// SPDX-License-Identifier: MIT +// +using Azure.ResourceManager.Compute.Models; using GuestProxyAgentTest.Models; using GuestProxyAgentTest.Settings; using GuestProxyAgentTest.TestScenarios; +using GuestProxyAgentTest.Utilities; using Newtonsoft.Json; +using System.Xml.Linq; namespace GuestProxyAgentTest.TestCases { internal class EnableProxyAgentCase : TestCaseBase { - public EnableProxyAgentCase() : this("EnableProxyAgentCase", true) + public EnableProxyAgentCase() : this("EnableProxyAgentCase", true, false) { } - public EnableProxyAgentCase(string testCaseName) : this(testCaseName, true) + public EnableProxyAgentCase(string testCaseName) : this(testCaseName, true, false) { } - public EnableProxyAgentCase(string testCaseName, bool enableProxyAgent) : base(testCaseName) + public EnableProxyAgentCase(string testCaseName, bool enableProxyAgent, bool addProxyAgentExtensionForLinuxVM) : base(testCaseName) { EnableProxyAgent = enableProxyAgent; + AddProxyAgentVMExtension = addProxyAgentExtensionForLinuxVM; } internal bool EnableProxyAgent { get; set; } + internal bool AddProxyAgentVMExtension { get; set; } = false; + public override async Task StartAsync(TestCaseExecutionContext context) { var vmr = context.VirtualMachineResource; @@ -34,16 +42,23 @@ public override async Task StartAsync(TestCaseExecutionContext context) } } }; + // Only Linux VMs support flag 'AddProxyAgentExtension', + // Windows VMs always have the GPA VM Extension installed when ProxyAgentSettings.Enabled is true. 
+ if (!Constants.IS_WINDOWS()) + { + patch.SecurityProfile.ProxyAgentSettings.AddProxyAgentExtension = AddProxyAgentVMExtension; + } if (EnableProxyAgent) { + // property 'inVMAccessControlProfileReferenceId' cannot be used together with property 'mode' patch.SecurityProfile.ProxyAgentSettings.WireServer = new HostEndpointSettings { - InVmAccessControlProfileReferenceId = TestSetting.Instance.InVmWireServerAccessControlProfileReferenceId + InVmAccessControlProfileReferenceId = TestSetting.Instance.InVmWireServerAccessControlProfileReferenceId, }; patch.SecurityProfile.ProxyAgentSettings.Imds = new HostEndpointSettings { - InVmAccessControlProfileReferenceId = TestSetting.Instance.InVmIMDSAccessControlProfileReferenceId + InVmAccessControlProfileReferenceId = TestSetting.Instance.InVmIMDSAccessControlProfileReferenceId, }; } diff --git a/e2etest/GuestProxyAgentTest/TestCases/GuestProxyAgentValidationCase.cs b/e2etest/GuestProxyAgentTest/TestCases/GuestProxyAgentValidationCase.cs index 40e02aac..acd55770 100644 --- a/e2etest/GuestProxyAgentTest/TestCases/GuestProxyAgentValidationCase.cs +++ b/e2etest/GuestProxyAgentTest/TestCases/GuestProxyAgentValidationCase.cs @@ -14,6 +14,7 @@ public class GuestProxyAgentValidationCase : TestCaseBase { private static readonly string EXPECTED_GUEST_PROXY_AGENT_SERVICE_STATUS; + private string expectedSecureChannelState = "disabled"; static GuestProxyAgentValidationCase() { if (Constants.IS_WINDOWS()) @@ -28,9 +29,16 @@ static GuestProxyAgentValidationCase() public GuestProxyAgentValidationCase() : base("GuestProxyAgentValidationCase") { } + public GuestProxyAgentValidationCase(string testCaseName, string expectedSecureChannelState) : base(testCaseName) + { + this.expectedSecureChannelState = expectedSecureChannelState; + } + public override async Task StartAsync(TestCaseExecutionContext context) { - context.TestResultDetails = (await RunScriptViaRunCommandV2Async(context, Constants.GUEST_PROXY_AGENT_VALIDATION_SCRIPT_NAME, 
null!)).ToTestResultDetails(ConsoleLog); + List<(string, string)> parameterList = new List<(string, string)>(); + parameterList.Add(("expectedSecureChannelState", expectedSecureChannelState)); + context.TestResultDetails = (await RunScriptViaRunCommandV2Async(context, Constants.GUEST_PROXY_AGENT_VALIDATION_SCRIPT_NAME, parameterList)).ToTestResultDetails(ConsoleLog); if (context.TestResultDetails.Succeed && context.TestResultDetails.CustomOut != null) { var validationDetails = context.TestResultDetails.SafeDeserializedCustomOutAs(); @@ -40,7 +48,8 @@ public override async Task StartAsync(TestCaseExecutionContext context) && validationDetails.GuestProxyAgentServiceInstalled && validationDetails.GuestProxyAgentServiceStatus.Equals(EXPECTED_GUEST_PROXY_AGENT_SERVICE_STATUS, StringComparison.OrdinalIgnoreCase) && validationDetails.GuestProxyProcessStarted - && validationDetails.GuestProxyAgentLogGenerated) + && validationDetails.GuestProxyAgentLogGenerated + && validationDetails.SecureChannelState.Equals(expectedSecureChannelState, StringComparison.OrdinalIgnoreCase)) { context.TestResultDetails.Succeed = true; } @@ -58,5 +67,6 @@ class GuestProxyAgentValidationDetails public bool GuestProxyProcessStarted { get; set; } public bool GuestProxyAgentLogGenerated { get; set; } public string GuestProxyAgentServiceStatus { get; set; } = null!; + public string SecureChannelState { get; set; } = null!; } } diff --git a/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Arm64-TestGroup.yml b/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Arm64-TestGroup.yml index d5dfdccb..c68fd4df 100644 --- a/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Arm64-TestGroup.yml +++ b/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Arm64-TestGroup.yml @@ -9,4 +9,6 @@ scenarios: - name: LinuxPackageScenario className: GuestProxyAgentTest.TestScenarios.LinuxPackageScenario - name: ProxyAgentExtension - className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension \ No newline at end of file + 
className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension + - name: LinuxImplicitExtension + className: GuestProxyAgentTest.TestScenarios.LinuxImplicitExtension \ No newline at end of file diff --git a/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Fips-TestGroup.yml b/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Fips-TestGroup.yml index ad34139f..68dfb487 100644 --- a/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Fips-TestGroup.yml +++ b/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Fips-TestGroup.yml @@ -9,4 +9,6 @@ scenarios: - name: LinuxPackageScenario className: GuestProxyAgentTest.TestScenarios.LinuxPackageScenario - name: ProxyAgentExtension - className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension \ No newline at end of file + className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension + - name: LinuxImplicitExtension + className: GuestProxyAgentTest.TestScenarios.LinuxImplicitExtension \ No newline at end of file diff --git a/e2etest/GuestProxyAgentTest/TestMap/Test-Map-Linux.yml b/e2etest/GuestProxyAgentTest/TestMap/Test-Map-Linux.yml index 0a0d1e6b..3cd46b95 100644 --- a/e2etest/GuestProxyAgentTest/TestMap/Test-Map-Linux.yml +++ b/e2etest/GuestProxyAgentTest/TestMap/Test-Map-Linux.yml @@ -1,6 +1,5 @@ testGroupList: - include: AzureLinux3-Fips-TestGroup.yml - - include: Mariner2-Fips-TestGroup.yml - include: Ubuntu24-TestGroup.yml - include: Ubuntu22-TestGroup.yml - include: Ubuntu20-TestGroup.yml diff --git a/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-Arm64-TestGroup.yml b/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-Arm64-TestGroup.yml index e10e26b1..5664d9f6 100644 --- a/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-Arm64-TestGroup.yml +++ b/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-Arm64-TestGroup.yml @@ -9,4 +9,6 @@ scenarios: - name: LinuxPackageScenario className: GuestProxyAgentTest.TestScenarios.LinuxPackageScenario - name: ProxyAgentExtension - className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension \ 
No newline at end of file + className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension + - name: LinuxImplicitExtension + className: GuestProxyAgentTest.TestScenarios.LinuxImplicitExtension \ No newline at end of file diff --git a/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-TestGroup.yml b/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-TestGroup.yml index 24c28233..bda47658 100644 --- a/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-TestGroup.yml +++ b/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-TestGroup.yml @@ -9,4 +9,6 @@ scenarios: - name: LinuxPackageScenario className: GuestProxyAgentTest.TestScenarios.LinuxPackageScenario - name: ProxyAgentExtension - className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension \ No newline at end of file + className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension + - name: LinuxImplicitExtension + className: GuestProxyAgentTest.TestScenarios.LinuxImplicitExtension \ No newline at end of file diff --git a/e2etest/GuestProxyAgentTest/TestScenarios/BVTScenario.cs b/e2etest/GuestProxyAgentTest/TestScenarios/BVTScenario.cs index 4226d896..3bcc2c9b 100644 --- a/e2etest/GuestProxyAgentTest/TestScenarios/BVTScenario.cs +++ b/e2etest/GuestProxyAgentTest/TestScenarios/BVTScenario.cs @@ -27,6 +27,7 @@ public override void TestScenarioSetup() // it will add GPA VM Extension and overwrite the private GPA package AddTestCase(new EnableProxyAgentCase()); secureChannelEnabled = true; + AddTestCase(new GuestProxyAgentValidationCase("GuestProxyAgentValidationWithSecureChannelEnabled", "WireServer Enforce - IMDS Enforce - HostGA Enforce")); } AddTestCase(new IMDSPingTestCase("IMDSPingTestBeforeReboot", secureChannelEnabled)); diff --git a/e2etest/GuestProxyAgentTest/TestScenarios/LinuxImplicitExtension.cs b/e2etest/GuestProxyAgentTest/TestScenarios/LinuxImplicitExtension.cs new file mode 100644 index 00000000..c535c097 --- /dev/null +++ b/e2etest/GuestProxyAgentTest/TestScenarios/LinuxImplicitExtension.cs @@ -0,0 +1,31 @@ +// 
Copyright (c) Microsoft Corporation +// SPDX-License-Identifier: MIT +using GuestProxyAgentTest.TestCases; +using GuestProxyAgentTest.Utilities; + +namespace GuestProxyAgentTest.TestScenarios +{ + public class LinuxImplicitExtension : TestScenarioBase + { + public override void TestScenarioSetup() + { + if (Constants.IS_WINDOWS()) + { + throw new InvalidOperationException("LinuxImplicitExtension scenario can only run on Linux VMs."); + } + + // Passing in 0 version number for the first validation case + string proxyAgentVersionBeforeUpdate = "0"; + string proxyAgentVersion = Settings.TestSetting.Instance.proxyAgentVersion; + ConsoleLog(string.Format("Received ProxyAgent Version:{0}", proxyAgentVersion)); + // implicitly enable the Guest Proxy Agent extension by setting EnableProxyAgent to true and AddProxyAgentVMExtension to true + AddTestCase(new EnableProxyAgentCase("EnableProxyAgentCase", true, true)); + AddTestCase(new GuestProxyAgentExtensionValidationCase("GuestProxyAgentExtensionValidationCaseBeforeUpdate", proxyAgentVersionBeforeUpdate)); + AddTestCase(new InstallOrUpdateGuestProxyAgentExtensionCase()); + AddTestCase(new GuestProxyAgentExtensionValidationCase("GuestProxyAgentExtensionValidationCaseAfterUpdate", proxyAgentVersion)); + AddTestCase(new IMDSPingTestCase("IMDSPingTestBeforeReboot", true)); + AddTestCase(new RebootVMCase("RebootVMCaseAfterUpdateGuestProxyAgentExtension")); + AddTestCase(new IMDSPingTestCase("IMDSPingTestAfterReboot", true)); + } + } +} diff --git a/e2etest/GuestProxyAgentTest/TestScenarios/LinuxPackageScenario.cs b/e2etest/GuestProxyAgentTest/TestScenarios/LinuxPackageScenario.cs index 82661abe..15ba7bc8 100644 --- a/e2etest/GuestProxyAgentTest/TestScenarios/LinuxPackageScenario.cs +++ b/e2etest/GuestProxyAgentTest/TestScenarios/LinuxPackageScenario.cs @@ -13,6 +13,7 @@ public override void TestScenarioSetup() AddTestCase(new InstallOrUpdateGuestProxyAgentPackageCase()); AddTestCase(new GuestProxyAgentValidationCase()); 
AddTestCase(new EnableProxyAgentCase()); + AddTestCase(new GuestProxyAgentValidationCase("GuestProxyAgentValidationWithSecureChannelEnabled", "WireServer Enforce - IMDS Enforce - HostGA Enforce")); AddTestCase(new IMDSPingTestCase("IMDSPingTestBeforeReboot", true)); AddTestCase(new RebootVMCase("RebootVMCaseAfterInstallOrUpdateGuestProxyAgent")); AddTestCase(new IMDSPingTestCase("IMDSPingTestAfterReboot", true)); diff --git a/e2etest/GuestProxyAgentTest/TestScenarios/ProxyAgentExtension.cs b/e2etest/GuestProxyAgentTest/TestScenarios/ProxyAgentExtension.cs index 5be330e7..7b5ae1c2 100644 --- a/e2etest/GuestProxyAgentTest/TestScenarios/ProxyAgentExtension.cs +++ b/e2etest/GuestProxyAgentTest/TestScenarios/ProxyAgentExtension.cs @@ -2,8 +2,6 @@ // SPDX-License-Identifier: MIT using GuestProxyAgentTest.TestCases; using GuestProxyAgentTest.Utilities; -using System.Diagnostics; -using System.IO.Compression; namespace GuestProxyAgentTest.TestScenarios { diff --git a/e2etest/GuestProxyAgentTest/Utilities/VMBuilder.cs b/e2etest/GuestProxyAgentTest/Utilities/VMBuilder.cs index 7dc8b460..f96d2518 100644 --- a/e2etest/GuestProxyAgentTest/Utilities/VMBuilder.cs +++ b/e2etest/GuestProxyAgentTest/Utilities/VMBuilder.cs @@ -140,6 +140,12 @@ private async Task DoCreateVMData(ResourceGroupResource rgr, }, } }; + if (!Constants.IS_WINDOWS()) + { + // Only Linux VMs support flag 'AddProxyAgentExtension', + // Windows VMs always have the GPA VM Extension installed when ProxyAgentSettings.Enabled is true. 
+ vmData.SecurityProfile.ProxyAgentSettings.AddProxyAgentExtension = true; + } } if (Constants.IS_WINDOWS()) diff --git a/proxy_agent/Cargo.toml b/proxy_agent/Cargo.toml index d128f2e0..8f279dda 100644 --- a/proxy_agent/Cargo.toml +++ b/proxy_agent/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "azure-proxy-agent" -version = "1.0.38" # always 3-number version +version = "1.0.39" # always 3-number version edition = "2021" build = "build.rs" readme = "README.md" diff --git a/proxy_agent/src/common/config.rs b/proxy_agent/src/common/config.rs index d9d1fdab..a9dae135 100644 --- a/proxy_agent/src/common/config.rs +++ b/proxy_agent/src/common/config.rs @@ -49,10 +49,6 @@ pub fn get_monitor_duration() -> Duration { pub fn get_poll_key_status_duration() -> Duration { Duration::from_secs(SYSTEM_CONFIG.get_poll_key_status_interval()) } -//TODO: remove this config/function once the contract is defined for HostGAPlugin -pub fn get_host_gaplugin_support() -> u8 { - SYSTEM_CONFIG.hostGAPluginSupport -} pub fn get_max_event_file_count() -> usize { SYSTEM_CONFIG.get_max_event_file_count() @@ -90,7 +86,6 @@ pub struct Config { latchKeyFolder: String, monitorIntervalInSeconds: u64, pollKeyStatusIntervalInSeconds: u64, - hostGAPluginSupport: u8, // 0 not support; 1 proxy only; 2 proxy + authentication check #[serde(skip_serializing_if = "Option::is_none")] maxEventFileCount: Option, #[serde(skip_serializing_if = "Option::is_none")] @@ -170,10 +165,6 @@ impl Config { self.pollKeyStatusIntervalInSeconds } - pub fn get_host_gaplugin_support(&self) -> u8 { - self.hostGAPluginSupport - } - pub fn get_max_event_file_count(&self) -> usize { self.maxEventFileCount .unwrap_or(constants::DEFAULT_MAX_EVENT_FILE_COUNT) @@ -269,12 +260,6 @@ mod tests { "get_poll_key_status_interval mismatch" ); - assert_eq!( - 1u8, - config.get_host_gaplugin_support(), - "get_host_gaplugin_support mismatch" - ); - assert_eq!( constants::DEFAULT_MAX_EVENT_FILE_COUNT, config.get_max_event_file_count(), diff --git 
a/proxy_agent/src/common/constants.rs b/proxy_agent/src/common/constants.rs index 22ae07d7..095aaf8f 100644 --- a/proxy_agent/src/common/constants.rs +++ b/proxy_agent/src/common/constants.rs @@ -17,8 +17,6 @@ pub const GA_PLUGIN_IP_NETWORK_BYTE_ORDER: u32 = 0x10813FA8; // 168.63.129.16 pub const IMDS_IP_NETWORK_BYTE_ORDER: u32 = 0xFEA9FEA9; //"169.254.169.254"; pub const PROXY_AGENT_IP_NETWORK_BYTE_ORDER: u32 = 0x100007F; //"127.0.0.1"; -pub const EMPTY_GUID: &str = "00000000-0000-0000-0000-000000000000"; - pub const KEY_DELIVERY_METHOD_HTTP: &str = "http"; pub const KEY_DELIVERY_METHOD_VTPM: &str = "vtpm"; diff --git a/proxy_agent/src/common/helpers.rs b/proxy_agent/src/common/helpers.rs index 0ee66384..518555cd 100644 --- a/proxy_agent/src/common/helpers.rs +++ b/proxy_agent/src/common/helpers.rs @@ -1,76 +1,9 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT -use super::logger; + use once_cell::sync::Lazy; -use proxy_agent_shared::misc_helpers; use proxy_agent_shared::telemetry::span::SimpleSpan; -#[cfg(not(windows))] -use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind, System}; - -#[cfg(windows)] -use super::windows; - -static CURRENT_SYS_INFO: Lazy<(u64, usize)> = Lazy::new(|| { - #[cfg(windows)] - { - let ram_in_mb = match windows::get_memory_in_mb() { - Ok(ram) => ram, - Err(e) => { - logger::write_error(format!("get_memory_in_mb failed: {e}")); - 0 - } - }; - let cpu_count = windows::get_processor_count(); - (ram_in_mb, cpu_count) - } - #[cfg(not(windows))] - { - let sys = System::new_with_specifics( - RefreshKind::new() - .with_memory(MemoryRefreshKind::everything()) - .with_cpu(CpuRefreshKind::everything()), - ); - let ram = sys.total_memory(); - let ram_in_mb = ram / 1024 / 1024; - let cpu_count = sys.cpus().len(); - (ram_in_mb, cpu_count) - } -}); - -static CURRENT_OS_INFO: Lazy<(String, String)> = Lazy::new(|| { - //arch - let arch = misc_helpers::get_processor_arch(); - // os - let os = 
misc_helpers::get_long_os_version(); - (arch, os) -}); - -pub fn get_ram_in_mb() -> u64 { - CURRENT_SYS_INFO.0 -} - -pub fn get_cpu_count() -> usize { - CURRENT_SYS_INFO.1 -} - -pub fn get_cpu_arch() -> String { - CURRENT_OS_INFO.0.to_string() -} - -pub fn get_long_os_version() -> String { - CURRENT_OS_INFO.1.to_string() -} - -// replace xml escape characters -pub fn xml_escape(s: String) -> String { - s.replace('&', "&") - .replace('\'', "'") - .replace('"', """) - .replace('<', "<") - .replace('>', ">") -} - static START: Lazy = Lazy::new(SimpleSpan::new); pub fn get_elapsed_time_in_millisec() -> u128 { @@ -85,22 +18,6 @@ pub fn write_startup_event( ) -> String { let message = START.write_event(task, method_name, module_name, logger_key); #[cfg(not(windows))] - logger::write_serial_console_log(message.clone()); + crate::common::logger::write_serial_console_log(message.clone()); message } - -#[cfg(test)] -mod tests { - #[test] - fn get_system_info_tests() { - let ram = super::get_ram_in_mb(); - assert!(ram > 100, "total ram must great than 100MB"); - let cpu_count = super::get_cpu_count(); - assert!( - cpu_count >= 1, - "total cpu count must great than or equal to 1" - ); - let cpu_arch = super::get_cpu_arch(); - assert_ne!("unknown", cpu_arch, "cpu arch cannot be 'unknown'"); - } -} diff --git a/proxy_agent/src/key_keeper.rs b/proxy_agent/src/key_keeper.rs index 1f444c19..fb527e9b 100644 --- a/proxy_agent/src/key_keeper.rs +++ b/proxy_agent/src/key_keeper.rs @@ -29,16 +29,19 @@ use self::key::Key; use crate::common::error::{Error, KeyErrorType}; use crate::common::result::Result; use crate::common::{constants, helpers, logger}; +use crate::key_keeper::key::KeyStatus; use crate::provision; use crate::proxy::authorization_rules::{AuthorizationRulesForLogging, ComputedAuthorizationRules}; +use crate::shared_state::access_control_wrapper::AccessControlSharedState; use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState}; +use 
crate::shared_state::connection_summary_wrapper::ConnectionSummarySharedState; use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; use crate::shared_state::provision_wrapper::ProvisionSharedState; use crate::shared_state::redirector_wrapper::RedirectorSharedState; -use crate::shared_state::telemetry_wrapper::TelemetrySharedState; -use crate::shared_state::SharedState; +use crate::shared_state::{EventThreadsSharedState, SharedState}; use crate::{acl, redirector}; use hyper::Uri; +use proxy_agent_shared::common_state::CommonState; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::misc_helpers; use proxy_agent_shared::proxy_agent_aggregate_status::ModuleState; @@ -56,7 +59,7 @@ pub const MUST_SIG_WIRESERVER_IMDS: &str = "wireserverandimds"; pub const UNKNOWN_STATE: &str = "Unknown"; static FREQUENT_PULL_INTERVAL: Duration = Duration::from_secs(1); // 1 second const FREQUENT_PULL_TIMEOUT_IN_MILLISECONDS: u128 = 300000; // 5 minutes -const PROVISION_TIMEUP_IN_MILLISECONDS: u128 = 120000; // 2 minute +const PROVISION_TIMEOUT_IN_MILLISECONDS: u128 = 120000; // 2 minute const DELAY_START_EVENT_THREADS_IN_MILLISECONDS: u128 = 60000; // 1 minute #[derive(Clone)] @@ -73,14 +76,27 @@ pub struct KeyKeeper { cancellation_token: CancellationToken, /// key_keeper_shared_state: the sender for the key details, secure channel state, access control rule key_keeper_shared_state: KeyKeeperSharedState, - /// telemetry_shared_state: the sender for the telemetry events - telemetry_shared_state: TelemetrySharedState, + /// common_state: the sender for the common states + common_state: CommonState, /// redirector_shared_state: the sender for the redirector/eBPF module redirector_shared_state: RedirectorSharedState, /// provision_shared_state: the sender for the provision state provision_shared_state: ProvisionSharedState, /// agent_status_shared_state: the sender for the agent status agent_status_shared_state: AgentStatusSharedState, + /// 
access_control_shared_state: the sender for the access control rules + access_control_shared_state: AccessControlSharedState, + /// connection_summary_shared_state: the sender for the connection summary module + connection_summary_shared_state: ConnectionSummarySharedState, +} + +/// Reason for waking up from sleep +/// Used in the pull_secure_channel_status loop +/// Notified: woke up due to notification +/// TimerElapsed: woke up due to sleep-timer elapsed +enum WakeReason { + Notified, + TimerElapsed, } impl KeyKeeper { @@ -98,10 +114,12 @@ impl KeyKeeper { interval, cancellation_token: shared_state.get_cancellation_token(), key_keeper_shared_state: shared_state.get_key_keeper_shared_state(), - telemetry_shared_state: shared_state.get_telemetry_shared_state(), + common_state: shared_state.get_common_state(), redirector_shared_state: shared_state.get_redirector_shared_state(), provision_shared_state: shared_state.get_provision_shared_state(), agent_status_shared_state: shared_state.get_agent_status_shared_state(), + access_control_shared_state: shared_state.get_access_control_shared_state(), + connection_summary_shared_state: shared_state.get_connection_summary_shared_state(), } } @@ -171,7 +189,7 @@ impl KeyKeeper { async fn loop_poll(&self) { let mut first_iteration: bool = true; let mut started_event_threads: bool = false; - let mut provision_timeup: bool = false; + let mut provision_timeout: bool = false; let notify = match self.key_keeper_shared_state.get_notify().await { Ok(notify) => notify, Err(e) => { @@ -191,11 +209,10 @@ impl KeyKeeper { )); } - let mut start = Instant::now(); + let mut provision_start_time = Instant::now(); + let mut redirect_policy_updated = false; loop { if !first_iteration { - // skip the sleep for the first loop - let current_state = match self .key_keeper_shared_state .get_current_secure_channel_state() @@ -210,376 +227,569 @@ impl KeyKeeper { } }; - let sleep = if current_state == UNKNOWN_STATE - && 
helpers::get_elapsed_time_in_millisec() - < FREQUENT_PULL_TIMEOUT_IN_MILLISECONDS - { - // frequent poll the secure channel status every second for the first 5 minutes - // until the secure channel state is known - FREQUENT_PULL_INTERVAL - } else { - self.interval - }; + let sleep_interval = self.calculate_sleep_duration(&current_state); + let (continue_loop, reset_timer) = self + .handle_notification(&notify, sleep_interval, &current_state) + .await; - let time = Instant::now(); - tokio::select! { - // notify to query the secure channel status immediately when the secure channel state is unknown or disabled - // this is to handle quicker response to the secure channel state change during VM provisioning. - _ = notify.notified() => { - if current_state == DISABLE_STATE || current_state == UNKNOWN_STATE { - logger::write_warning(format!("poll_secure_channel_status task notified and secure channel state is '{current_state}', reset states and start poll status now.")); - provision::key_latch_ready_state_reset(self.provision_shared_state.clone()).await; - if let Err(e) = self.key_keeper_shared_state.update_current_secure_channel_state(UNKNOWN_STATE.to_string()).await{ - logger::write_warning(format!("Failed to update secure channel state to 'Unknown': {e}")); - } - - if start.elapsed().as_millis() > PROVISION_TIMEUP_IN_MILLISECONDS { - // already timeup, reset the start timer - start = Instant::now(); - provision_timeup = false; - } - } else { - // report key latched ready to try update the provision finished time_tick - provision::key_latched( - self.cancellation_token.clone(), - self.key_keeper_shared_state.clone(), - self.telemetry_shared_state.clone(), - self.provision_shared_state.clone(), - self.agent_status_shared_state.clone(), - ).await; - let slept_time_in_millisec = time.elapsed().as_millis(); - let continue_sleep = sleep.as_millis() - slept_time_in_millisec; - if continue_sleep > 0 { - let continue_sleep = Duration::from_millis(continue_sleep as u64); - let message =
format!("poll_secure_channel_status task notified but secure channel state is '{current_state}', continue with sleep wait for {continue_sleep:?}."); - logger::write_warning(message); - tokio::time::sleep(continue_sleep).await; - } - } - }, - _ = tokio::time::sleep(sleep) => {} + if reset_timer { + provision_start_time = Instant::now(); + provision_timeout = false; } - } - first_iteration = false; - if !provision_timeup && start.elapsed().as_millis() > PROVISION_TIMEUP_IN_MILLISECONDS { - provision::provision_timeup( - None, - self.provision_shared_state.clone(), - self.agent_status_shared_state.clone(), - ) - .await; - provision_timeup = true; + if !continue_loop { + continue; + } } + first_iteration = false; - if !started_event_threads - && helpers::get_elapsed_time_in_millisec() - > DELAY_START_EVENT_THREADS_IN_MILLISECONDS - { - provision::start_event_threads( - self.cancellation_token.clone(), - self.key_keeper_shared_state.clone(), - self.telemetry_shared_state.clone(), - self.provision_shared_state.clone(), - self.agent_status_shared_state.clone(), - ) + self.handle_provision_timeout(&mut provision_start_time, &mut provision_timeout) .await; - started_event_threads = true; - } + started_event_threads = self.handle_event_threads_start(started_event_threads).await; let status = match key::get_status(&self.base_url).await { Ok(s) => s, Err(e) => { self.update_status_message(format!("Failed to get key status - {e}"), true) .await; + // failed to get status, skip to next iteration continue; } }; self.update_status_message(format!("Got key status successfully: {status}."), true) .await; - let mut access_control_rules_changed = false; - let wireserver_rule_id = status.get_wireserver_rule_id(); - let imds_rule_id: String = status.get_imds_rule_id(); - let hostga_rule_id: String = status.get_hostga_rule_id(); + self.update_access_control_rules(&status).await; - match self + let state = status.get_secure_channel_state(); + let secure_channel_state_updated = self 
.key_keeper_shared_state - .update_wireserver_rule_id(wireserver_rule_id.to_string()) + .update_current_secure_channel_state(state.to_string()) + .await; + + if !self.handle_key_acquisition(&status, &state).await { + // Handle key acquisition failed, skip to next iteration + continue; + } + + if self + .handle_secure_channel_state_change( + secure_channel_state_updated, + &state, + &status, + &mut redirect_policy_updated, + ) .await { - Ok((updated, old_wire_server_rule_id)) => { - if updated { - logger::write_warning(format!( - "Wireserver rule id changed from '{old_wire_server_rule_id}' to '{wireserver_rule_id}'." - )); - if let Err(e) = self - .key_keeper_shared_state - .set_wireserver_rules(status.get_wireserver_rules()) - .await - { - logger::write_error(format!("Failed to set wireserver rules: {e}")); - } - access_control_rules_changed = true; + // successfully handled secure channel state change, skip to next iteration + continue; + } + + // check and update the redirect policy if not updated successfully before, try again here + // this could happen when the eBPF/redirector module was not started yet before + if !redirect_policy_updated { + logger::write_warning( + "redirect policy was not update successfully before, retrying now".to_string(), + ); + redirect_policy_updated = self.update_redirector_policy(&status).await; + } + } + } + + /// Calculate sleep duration based on current secure channel state + fn calculate_sleep_duration(&self, current_state: &str) -> Duration { + if current_state == UNKNOWN_STATE + && helpers::get_elapsed_time_in_millisec() < FREQUENT_PULL_TIMEOUT_IN_MILLISECONDS + { + // frequent poll the secure channel status every second for the first 5 minutes + // until the secure channel state is known + FREQUENT_PULL_INTERVAL + } else { + self.interval + } + } + + async fn wait_for_wake(notify: &tokio::sync::Notify, sleep_interval: Duration) -> WakeReason { + tokio::select! 
{ + // if a notification arrives + _ = notify.notified() => WakeReason::Notified, + // if the sleep duration elapses + _ = tokio::time::sleep(sleep_interval) => WakeReason::TimerElapsed, + } + } + + /// Handle notification and sleep logic + /// Returns (continue_loop, reset_timer) + async fn handle_notification( + &self, + notify: &tokio::sync::Notify, + sleep_interval: Duration, + current_state: &str, + ) -> (bool, bool) { + let start_time = Instant::now(); + + match Self::wait_for_wake(notify, sleep_interval).await { + WakeReason::Notified => { + self.handle_notified_state(current_state, sleep_interval, start_time) + .await + } + WakeReason::TimerElapsed => (true, false), + } + } + + /// Handle the notified state - either reset or continue with key latched + /// Returns (continue_loop, reset_timer) + async fn handle_notified_state( + &self, + current_state: &str, + sleep_interval: Duration, + current_loop_iteration_start_time: Instant, + ) -> (bool, bool) { + // notify to query the secure channel status immediately when the secure channel state is unknown or disabled + // this is to handle quicker response to the secure channel state change during VM provisioning. + if self.should_reset_state(current_state) { + self.reset_state_on_notification(current_state).await + } else { + self.continue_with_key_latched( + current_state, + sleep_interval, + current_loop_iteration_start_time, + ) + .await + } + } + + /// Check if the state should be reset based on current secure channel state + fn should_reset_state(&self, current_state: &str) -> bool { + current_state == DISABLE_STATE || current_state == UNKNOWN_STATE + } + + /// Reset state when notified in disabled or unknown state + /// Returns (continue_loop, reset_timer) + async fn reset_state_on_notification(&self, current_state: &str) -> (bool, bool) { + logger::write_warning(format!( + "poll_secure_channel_status task notified and secure channel state is '{current_state}', reset states and start poll status now." 
+ )); + + provision::key_latch_ready_state_reset(self.provision_shared_state.clone()).await; + + if let Err(e) = self + .key_keeper_shared_state + .update_current_secure_channel_state(UNKNOWN_STATE.to_string()) + .await + { + logger::write_warning(format!( + "Failed to update secure channel state to 'Unknown': {e}" + )); + } + + (true, true) + } + + /// Continue with key latched when notified in a stable state + /// Returns (continue_loop, reset_timer) + async fn continue_with_key_latched( + &self, + current_state: &str, + sleep_interval: Duration, + current_loop_iteration_start_time: Instant, + ) -> (bool, bool) { + // report key latched ready to try update the provision finished time_tick + provision::key_latched(self.create_event_threads_shared_state()).await; + + self.handle_remaining_sleep( + current_state, + sleep_interval, + current_loop_iteration_start_time, + ) + .await; + + (true, false) + } + + /// Handle the remaining sleep time after the notification + async fn handle_remaining_sleep( + &self, + current_state: &str, + sleep_interval: Duration, + current_loop_iteration_start_time: Instant, + ) { + let slept_time_in_millisec = current_loop_iteration_start_time.elapsed().as_millis(); + let continue_sleep = sleep_interval + .as_millis() + .saturating_sub(slept_time_in_millisec); + if continue_sleep > 0 { + // continue sleep with the remaining time for current loop iteration + // it is to avoid too frequent polling when the secure channel state is stable + let continue_sleep = Duration::from_millis(continue_sleep as u64); + let message = format!( + "poll_secure_channel_status task notified but secure channel state is '{current_state}', continue with sleep wait for {continue_sleep:?}." 
+ ); + logger::write_warning(message); + tokio::time::sleep(continue_sleep).await; + } + } + + /// Handle provision timeout logic + async fn handle_provision_timeout( + &self, + provision_start_time: &mut Instant, + provision_timeout: &mut bool, + ) { + if !*provision_timeout + && provision_start_time.elapsed().as_millis() > PROVISION_TIMEOUT_IN_MILLISECONDS + { + provision::provision_timeout( + None, + self.provision_shared_state.clone(), + self.agent_status_shared_state.clone(), + ) + .await; + *provision_timeout = true; + } + } + + /// Handle starting event threads + async fn handle_event_threads_start(&self, started_event_threads: bool) -> bool { + if !started_event_threads + && helpers::get_elapsed_time_in_millisec() > DELAY_START_EVENT_THREADS_IN_MILLISECONDS + { + provision::start_event_threads(self.create_event_threads_shared_state()).await; + return true; + } + started_event_threads + } + + /// Update access control rules from the key status + /// Returns true if any rules changed + async fn update_access_control_rules(&self, status: &KeyStatus) -> bool { + let mut access_control_rules_changed = false; + let wireserver_rule_id = status.get_wireserver_rule_id(); + let imds_rule_id = status.get_imds_rule_id(); + let hostga_rule_id = status.get_hostga_rule_id(); + + // Update wireserver rules + match self + .key_keeper_shared_state + .update_wireserver_rule_id(wireserver_rule_id.to_string()) + .await + { + Ok((updated, old_wire_server_rule_id)) => { + if updated { + logger::write_warning(format!( + "Wireserver rule id changed from '{old_wire_server_rule_id}' to '{wireserver_rule_id}'." 
+ )); + if let Err(e) = self + .access_control_shared_state + .set_wireserver_rules(status.get_wireserver_rules()) + .await + { + logger::write_error(format!("Failed to set wireserver rules: {e}")); } - } - Err(e) => { - logger::write_warning(format!("Failed to update wireserver rule id: {e}")); + access_control_rules_changed = true; } } + Err(e) => { + logger::write_warning(format!("Failed to update wireserver rule id: {e}")); + } + } - match self - .key_keeper_shared_state - .update_imds_rule_id(imds_rule_id.to_string()) - .await - { - Ok((updated, old_imds_rule_id)) => { - if updated { - logger::write_warning(format!( - "IMDS rule id changed from '{old_imds_rule_id}' to '{imds_rule_id}'." - )); - if let Err(e) = self - .key_keeper_shared_state - .set_imds_rules(status.get_imds_rules()) - .await - { - logger::write_error(format!("Failed to set imds rules: {e}")); - } - access_control_rules_changed = true; + // Update IMDS rules + match self + .key_keeper_shared_state + .update_imds_rule_id(imds_rule_id.to_string()) + .await + { + Ok((updated, old_imds_rule_id)) => { + if updated { + logger::write_warning(format!( + "IMDS rule id changed from '{old_imds_rule_id}' to '{imds_rule_id}'." + )); + if let Err(e) = self + .access_control_shared_state + .set_imds_rules(status.get_imds_rules()) + .await + { + logger::write_error(format!("Failed to set imds rules: {e}")); } - } - Err(e) => { - logger::write_warning(format!("Failed to update imds rule id: {e}")); + access_control_rules_changed = true; } } + Err(e) => { + logger::write_warning(format!("Failed to update imds rule id: {e}")); + } + } - match self - .key_keeper_shared_state - .update_hostga_rule_id(hostga_rule_id.to_string()) - .await - { - Ok((updated, old_hostga_rule_id)) => { - if updated { - logger::write_warning(format!( - "HostGA rule id changed from '{old_hostga_rule_id}' to '{hostga_rule_id}'." 
- )); - if let Err(e) = self - .key_keeper_shared_state - .set_hostga_rules(status.get_hostga_rules()) - .await - { - logger::write_error(format!("Failed to set HostGA rules: {e}")); - } - access_control_rules_changed = true; + // Update HostGA rules + match self + .key_keeper_shared_state + .update_hostga_rule_id(hostga_rule_id.to_string()) + .await + { + Ok((updated, old_hostga_rule_id)) => { + if updated { + logger::write_warning(format!( + "HostGA rule id changed from '{old_hostga_rule_id}' to '{hostga_rule_id}'." + )); + if let Err(e) = self + .access_control_shared_state + .set_hostga_rules(status.get_hostga_rules()) + .await + { + logger::write_error(format!("Failed to set HostGA rules: {e}")); } + access_control_rules_changed = true; } - Err(e) => { - logger::write_warning(format!("Failed to update HostGA rule id: {e}")); - } } + Err(e) => { + logger::write_warning(format!("Failed to update HostGA rule id: {e}")); + } + } - if access_control_rules_changed { - if let (Ok(wireserver_rules), Ok(imds_rules), Ok(hostga_rules)) = ( - self.key_keeper_shared_state.get_wireserver_rules().await, - self.key_keeper_shared_state.get_imds_rules().await, - self.key_keeper_shared_state.get_hostga_rules().await, - ) { - let rules = AuthorizationRulesForLogging::new( - status.authorizationRules.clone(), - ComputedAuthorizationRules { - wireserver: wireserver_rules, - imds: imds_rules, - hostga: hostga_rules, - }, + // Write authorization rules to file if changed + if access_control_rules_changed { + if let (Ok(wireserver_rules), Ok(imds_rules), Ok(hostga_rules)) = ( + self.access_control_shared_state + .get_wireserver_rules() + .await, + self.access_control_shared_state.get_imds_rules().await, + self.access_control_shared_state.get_hostga_rules().await, + ) { + let rules = AuthorizationRulesForLogging::new( + status.authorizationRules.clone(), + ComputedAuthorizationRules { + wireserver: wireserver_rules, + imds: imds_rules, + hostga: hostga_rules, + }, + ); + 
rules.write_all(&self.status_dir, constants::MAX_LOG_FILE_COUNT); + } + } + + access_control_rules_changed + } + + /// Handle key acquisition from local or server + /// Returns true if successful, false if should continue to next iteration + async fn handle_key_acquisition(&self, status: &KeyStatus, state: &str) -> bool { + // check if need fetch the key + if state != DISABLE_STATE + && (status.keyGuid.is_none() // key has not latched yet + || status.keyGuid != self.key_keeper_shared_state.get_current_key_guid().await.unwrap_or(None)) + // key changed + { + if self.try_fetch_local_key(&status.keyGuid).await { + return true; + } + + // if key has not latched before, or not found, or could not read locally, + // try fetch from server + return self.acquire_key_from_server().await; + } + true + } + + /// Try to fetch key from local storage + /// Returns true if key was found and loaded successfully + async fn try_fetch_local_key(&self, key_guid: &Option<String>) -> bool { + if let Some(guid) = key_guid { + // key latched before and search the key locally first + match Self::fetch_key(&self.key_dir, guid) { + Ok(key) => { + if let Err(e) = self.update_key_to_shared_state(key.clone()).await { + logger::write_warning(format!("Failed to update key: {e}")); + } + + let message = helpers::write_startup_event( + "Found key details from local and ready to use.", + "poll_secure_channel_status", + "key_keeper", + logger::AGENT_LOGGER_KEY, + ); + self.update_status_message(message, false).await; + + provision::key_latched(self.create_event_threads_shared_state()).await; + return true; + } + Err(e) => { + event_logger::write_event( + LoggerLevel::Info, + format!("Failed to fetch local key details with error: {e:?}.
Will try acquire the key details from Server."), + "poll_secure_channel_status", + "key_keeper", + logger::AGENT_LOGGER_KEY, ); - rules.write_all(&self.status_dir, constants::MAX_LOG_FILE_COUNT); } + }; + } + false + } + + /// Acquire key from server, persist it, and attest it + /// Returns true if successful, false if should continue to next iteration + async fn acquire_key_from_server(&self) -> bool { + let key = match key::acquire_key(&self.base_url).await { + Ok(k) => k, + Err(e) => { + self.update_status_message(format!("Failed to acquire key details: {e:?}"), true) + .await; + return false; } + }; - let state = status.get_secure_channel_state(); - let secure_channel_state_updated = self - .key_keeper_shared_state - .update_current_secure_channel_state(state.to_string()) + // persist the new key to local disk + let guid = key.guid.to_string(); + if let Err(e) = Self::store_key(&self.key_dir, &key) { + self.update_status_message(format!("Failed to save key details to file: {e:?}"), true) .await; + return false; + } + logger::write_information(format!( + "Successfully acquired the key '{guid}' details from server and saved locally." 
+ )); - // check if need fetch the key - if state != DISABLE_STATE - && (status.keyGuid.is_none() // key has not latched yet - || status.keyGuid != self.key_keeper_shared_state.get_current_key_guid().await.unwrap_or(None)) - // key changed - { - let mut key_found = false; - if let Some(guid) = &status.keyGuid { - // key latched before and search the key locally first - match Self::fetch_key(&self.key_dir, guid) { - Ok(key) => { - if let Err(e) = - self.key_keeper_shared_state.update_key(key.clone()).await - { - logger::write_warning(format!("Failed to update key: {e}")); - } - - let message = helpers::write_startup_event( - "Found key details from local and ready to use.", - "poll_secure_channel_status", - "key_keeper", - logger::AGENT_LOGGER_KEY, - ); - self.update_status_message(message, false).await; - key_found = true; - - provision::key_latched( - self.cancellation_token.clone(), - self.key_keeper_shared_state.clone(), - self.telemetry_shared_state.clone(), - self.provision_shared_state.clone(), - self.agent_status_shared_state.clone(), - ) - .await; - } - Err(e) => { - event_logger::write_event( - LoggerLevel::Info, - format!("Failed to fetch local key details with error: {e:?}. 
Will try acquire the key details from Server."), - "poll_secure_channel_status", - "key_keeper", - logger::AGENT_LOGGER_KEY, - ); - } - }; + // double check the key details saved correctly to local disk + if let Err(e) = Self::check_key(&self.key_dir, &key) { + self.update_status_message( + format!("Failed to check the key '{guid}' details saved locally: {e:?}."), + true, + ) + .await; + return false; + } + + // attest the key + match key::attest_key(&self.base_url, &key).await { + Ok(()) => { + // update in memory + if let Err(e) = self.update_key_to_shared_state(key.clone()).await { + logger::write_warning(format!("Failed to update key: {e}")); } - // if key has not latched before, - // or not found - // or could not read locally, - // try fetch from server - if !key_found { - let key = match key::acquire_key(&self.base_url).await { - Ok(k) => k, - Err(e) => { - self.update_status_message( - format!("Failed to acquire key details: {e:?}"), - true, - ) - .await; - continue; - } - }; - - // persist the new key to local disk - let guid = key.guid.to_string(); - match Self::store_key(&self.key_dir, &key) { - Ok(()) => { - logger::write_information(format!( - "Successfully acquired the key '{guid}' details from server and saved locally.")); - } - Err(e) => { - self.update_status_message( - format!("Failed to save key details to file: {e:?}"), - true, - ) - .await; - continue; - } - } + let message = helpers::write_startup_event( + "Successfully attest the key and ready to use.", + "poll_secure_channel_status", + "key_keeper", + logger::AGENT_LOGGER_KEY, + ); + self.update_status_message(message, false).await; + provision::key_latched(self.create_event_threads_shared_state()).await; + true + } + Err(e) => { + logger::write_warning(format!("Failed to attest the key: {e:?}")); + false + } + } + } - // double check the key details saved correctly to local disk - if let Err(e) = Self::check_key(&self.key_dir, &key) { - self.update_status_message( - format!( - "Failed to 
check the key '{guid}' details saved locally: {e:?}." - ), - true, - ) - .await; - continue; - } else { - match key::attest_key(&self.base_url, &key).await { - Ok(()) => { - // update in memory - if let Err(e) = - self.key_keeper_shared_state.update_key(key.clone()).await - { - logger::write_warning(format!("Failed to update key: {e}")); - } - - let message = helpers::write_startup_event( - "Successfully attest the key and ready to use.", - "poll_secure_channel_status", - "key_keeper", - logger::AGENT_LOGGER_KEY, - ); - self.update_status_message(message, false).await; - provision::key_latched( - self.cancellation_token.clone(), - self.key_keeper_shared_state.clone(), - self.telemetry_shared_state.clone(), - self.provision_shared_state.clone(), - self.agent_status_shared_state.clone(), - ) - .await; - } - Err(e) => { - logger::write_warning(format!("Failed to attest the key: {e:?}")); - continue; - } + /// Handle secure channel state change + /// Returns true if should continue to next iteration + async fn handle_secure_channel_state_change( + &self, + secure_channel_state_updated: std::result::Result<bool, Error>, + state: &str, + status: &KeyStatus, + redirect_policy_updated: &mut bool, + ) -> bool { + match secure_channel_state_updated { + Ok(updated) => { + if updated { + // secure channel state changed, update the redirect policy + *redirect_policy_updated = self.update_redirector_policy(status).await; + + // customer has not enforce the secure channel state + if state == DISABLE_STATE { + let message = helpers::write_startup_event( + "Customer has not enforce the secure channel state.", + "poll_secure_channel_status", + "key_keeper", + logger::AGENT_LOGGER_KEY, + ); + // Update the status message and let the provision to continue + self.update_status_message(message, false).await; + + // clear key in memory for disabled state + if let Err(e) = self.key_keeper_shared_state.clear_key().await { + logger::write_warning(format!("Failed to clear key: {e}")); + }
provision::key_latched(self.create_event_threads_shared_state()).await; } + + return true; } } + Err(e) => { + logger::write_warning(format!("Failed to update secure channel state: {e}")); + } + } + false + } - // update redirect policy if current secure channel state updated - match secure_channel_state_updated { - Ok(updated) => { - if updated { - // update the redirector policy map - redirector::update_wire_server_redirect_policy( - status.get_wire_server_mode() != DISABLE_STATE, - self.redirector_shared_state.clone(), - ) - .await; - redirector::update_imds_redirect_policy( - status.get_imds_mode() != DISABLE_STATE, - self.redirector_shared_state.clone(), - ) - .await; - redirector::update_hostga_redirect_policy( - status.get_hostga_mode() != DISABLE_STATE, - self.redirector_shared_state.clone(), - ) - .await; + /// Create EventThreadsSharedState from current state + fn create_event_threads_shared_state(&self) -> EventThreadsSharedState { + EventThreadsSharedState { + cancellation_token: self.cancellation_token.clone(), + common_state: self.common_state.clone(), + access_control_shared_state: self.access_control_shared_state.clone(), + redirector_shared_state: self.redirector_shared_state.clone(), + key_keeper_shared_state: self.key_keeper_shared_state.clone(), + provision_shared_state: self.provision_shared_state.clone(), + agent_status_shared_state: self.agent_status_shared_state.clone(), + connection_summary_shared_state: self.connection_summary_shared_state.clone(), + } + } - // customer has not enforce the secure channel state - if state == DISABLE_STATE { - let message = helpers::write_startup_event( - "Customer has not enforce the secure channel state.", - "poll_secure_channel_status", - "key_keeper", - logger::AGENT_LOGGER_KEY, - ); - // Update the status message and let the provision to continue - self.update_status_message(message, false).await; - - // clear key in memory for disabled state - if let Err(e) = 
self.key_keeper_shared_state.clear_key().await { - logger::write_warning(format!("Failed to clear key: {e}")); - } - provision::key_latched( - self.cancellation_token.clone(), - self.key_keeper_shared_state.clone(), - self.telemetry_shared_state.clone(), - self.provision_shared_state.clone(), - self.agent_status_shared_state.clone(), - ) - .await; - } - } - } - Err(e) => { - logger::write_warning(format!("Failed to update secure channel state: {e}")); - } - } + async fn update_key_to_shared_state(&self, key: Key) -> Result<()> { + self.key_keeper_shared_state.update_key(key.clone()).await?; + + // update the current key guid and value to common states + self.common_state + .set_state( + proxy_agent_shared::common_state::SECURE_KEY_GUID.to_string(), + key.guid.to_string(), + ) + .await?; + self.common_state + .set_state( + proxy_agent_shared::common_state::SECURE_KEY_VALUE.to_string(), + key.key.to_string(), + ) + .await?; + Ok(()) + } + + /// update the redirector/eBPF policy based on the secure channel status + /// it should be called when the secure channel state is changed + async fn update_redirector_policy(&self, status: &KeyStatus) -> bool { + // update the redirector policy map + if !redirector::update_wire_server_redirect_policy( + status.get_wire_server_mode() != DISABLE_STATE, + self.redirector_shared_state.clone(), + ) + .await + { + return false; + } + if !redirector::update_imds_redirect_policy( + status.get_imds_mode() != DISABLE_STATE, + self.redirector_shared_state.clone(), + ) + .await + { + return false; } + if !redirector::update_hostga_redirect_policy( + status.get_hostga_mode() != DISABLE_STATE, + self.redirector_shared_state.clone(), + ) + .await + { + return false; + } + + true } async fn update_status_message(&self, message: String, log_to_file: bool) { @@ -849,10 +1059,12 @@ mod tests { interval: Duration::from_millis(10), cancellation_token: cancellation_token.clone(), key_keeper_shared_state: key_keeper::KeyKeeperSharedState::start_new(), 
- telemetry_shared_state: key_keeper::TelemetrySharedState::start_new(), + common_state: key_keeper::CommonState::start_new(), redirector_shared_state: key_keeper::RedirectorSharedState::start_new(), provision_shared_state: key_keeper::ProvisionSharedState::start_new(), agent_status_shared_state: key_keeper::AgentStatusSharedState::start_new(), + access_control_shared_state: key_keeper::AccessControlSharedState::start_new(), + connection_summary_shared_state: key_keeper::ConnectionSummarySharedState::start_new(), }; tokio::spawn({ diff --git a/proxy_agent/src/key_keeper/key.rs b/proxy_agent/src/key_keeper/key.rs index 7039b0e6..667b0932 100644 --- a/proxy_agent/src/key_keeper/key.rs +++ b/proxy_agent/src/key_keeper/key.rs @@ -580,12 +580,14 @@ impl KeyStatus { } pub fn get_hostga_rules(&self) -> Option { - // short-term: HostGA has no rules + // match &self.authorizationRules { + // Some(rules) => rules.hostga.clone(), + // None => None, + // } + + // short-term: HostGA uses wireserver rules // long-term: TBD - match &self.authorizationRules { - Some(rules) => rules.hostga.clone(), - None => None, - } + self.get_wireserver_rules() } pub fn get_wire_server_mode(&self) -> String { @@ -1261,95 +1263,105 @@ mod tests { "roleAssignment identities mismatch" ); - // Validate HostGA rules + // Validate HostGA rules to match WireServer rules let hostga_rules = status.get_hostga_rules().unwrap(); assert_eq!( - "allow", hostga_rules.defaultAccess, - "defaultAccess mismatch" + wireserver_rules.defaultAccess, hostga_rules.defaultAccess, + "hostga_rules defaultAccess mismatch" ); assert_eq!( - "sigid", + status.get_wireserver_rule_id(), status.get_hostga_rule_id(), "HostGA rule id mismatch" ); - assert_eq!("enforce", status.get_hostga_mode(), "HostGA mode mismatch"); - - // Validate HostGA rule details - // Retrieve and validate second privilege for HostGA - let privilege = &hostga_rules - .rules - .as_ref() - .unwrap() - .privileges - .as_ref() - .unwrap()[1]; - - 
assert_eq!("test2", privilege.name, "privilege name mismatch"); - assert_eq!("/test2", privilege.path, "privilege path mismatch"); - - assert_eq!( - "value3", - privilege.queryParameters.as_ref().unwrap()["key1"], - "privilege queryParameters mismatch" - ); - assert_eq!( - "value4", - privilege.queryParameters.as_ref().unwrap()["key2"], - "privilege queryParameters mismatch" - ); - - // Retrieve and validate second role for HostGA - let role = &hostga_rules.rules.as_ref().unwrap().roles.as_ref().unwrap()[1]; - assert_eq!("test6", role.name, "role name mismatch"); - assert_eq!("test4", role.privileges[0], "role privilege mismatch"); - assert_eq!("test5", role.privileges[1], "role privilege mismatch"); - - // Retrieve and validate first identity for HostGA - let identity = &hostga_rules - .rules - .as_ref() - .unwrap() - .identities - .as_ref() - .unwrap()[0]; - assert_eq!("test", identity.name, "identity name mismatch"); - assert_eq!( - "test", - identity.userName.as_ref().unwrap(), - "identity userName mismatch" - ); - assert_eq!( - "test", - identity.groupName.as_ref().unwrap(), - "identity groupName mismatch" - ); - assert_eq!( - "test", - identity.exePath.as_ref().unwrap(), - "identity exePath mismatch" - ); - assert_eq!( - "test", - identity.processName.as_ref().unwrap(), - "identity processName mismatch" - ); - - // Retrieve and validate first role assignment for HostGA - let role_assignment = &hostga_rules - .rules - .as_ref() - .unwrap() - .roleAssignments - .as_ref() - .unwrap()[0]; - assert_eq!( - "test4", role_assignment.role, - "roleAssignment role mismatch" - ); assert_eq!( - "test", role_assignment.identities[0], - "roleAssignment identities mismatch" - ); + status.get_wire_server_mode(), + status.get_hostga_mode(), + "HostGA mode mismatch" + ); + + /*** + * + * Validate HostGA rule details in future when HostGA has independent rules + + // Validate HostGA rule details + // Retrieve and validate second privilege for HostGA + let privilege = 
&hostga_rules + .rules + .as_ref() + .unwrap() + .privileges + .as_ref() + .unwrap()[1]; + + assert_eq!("test2", privilege.name, "privilege name mismatch"); + assert_eq!("/test2", privilege.path, "privilege path mismatch"); + + assert_eq!( + "value3", + privilege.queryParameters.as_ref().unwrap()["key1"], + "privilege queryParameters mismatch" + ); + assert_eq!( + "value4", + privilege.queryParameters.as_ref().unwrap()["key2"], + "privilege queryParameters mismatch" + ); + + // Retrieve and validate second role for HostGA + let role = &hostga_rules.rules.as_ref().unwrap().roles.as_ref().unwrap()[1]; + assert_eq!("test6", role.name, "role name mismatch"); + assert_eq!("test4", role.privileges[0], "role privilege mismatch"); + assert_eq!("test5", role.privileges[1], "role privilege mismatch"); + + // Retrieve and validate first identity for HostGA + let identity = &hostga_rules + .rules + .as_ref() + .unwrap() + .identities + .as_ref() + .unwrap()[0]; + assert_eq!("test", identity.name, "identity name mismatch"); + assert_eq!( + "test", + identity.userName.as_ref().unwrap(), + "identity userName mismatch" + ); + assert_eq!( + "test", + identity.groupName.as_ref().unwrap(), + "identity groupName mismatch" + ); + assert_eq!( + "test", + identity.exePath.as_ref().unwrap(), + "identity exePath mismatch" + ); + assert_eq!( + "test", + identity.processName.as_ref().unwrap(), + "identity processName mismatch" + ); + + // Retrieve and validate first role assignment for HostGA + let role_assignment = &hostga_rules + .rules + .as_ref() + .unwrap() + .roleAssignments + .as_ref() + .unwrap()[0]; + assert_eq!( + "test4", role_assignment.role, + "roleAssignment role mismatch" + ); + assert_eq!( + "test", role_assignment.identities[0], + "roleAssignment identities mismatch" + ); + * + */ } #[test] diff --git a/proxy_agent/src/main.rs b/proxy_agent/src/main.rs index 460425ca..7a73720c 100644 --- a/proxy_agent/src/main.rs +++ b/proxy_agent/src/main.rs @@ -10,7 +10,6 @@ pub mod 
proxy_agent_status; pub mod redirector; pub mod service; pub mod shared_state; -pub mod telemetry; use common::cli::{Commands, CLI}; use common::constants; diff --git a/proxy_agent/src/provision.rs b/proxy_agent/src/provision.rs index 4c4774c5..fd11fbea 100644 --- a/proxy_agent/src/provision.rs +++ b/proxy_agent/src/provision.rs @@ -5,94 +5,23 @@ //! It is used to track the provision state for each module and write the provision state to provisioned.tag and status.tag files. //! It also provides the http handler to query the provision status for GPA service. //! It is used to query the provision status from GPA service http listener. -//! Example for GPA service: -//! ```rust -//! use proxy_agent::provision; -//! use proxy_agent::shared_state::agent_status_wrapper::AgentStatusModule; -//! use proxy_agent::shared_state::agent_status_wrapper::AgentStatusSharedState; -//! use proxy_agent::shared_state::key_keeper_wrapper::KeyKeeperSharedState; -//! use proxy_agent::shared_state::provision_wrapper::ProvisionSharedState; -//! use proxy_agent::shared_state::SharedState; -//! use proxy_agent::shared_state::telemetry_wrapper::TelemetrySharedState; -//! -//! use std::time::Duration; -//! -//! let shared_state = SharedState::start_all(); -//! let cancellation_token = shared_state.get_cancellation_token(); -//! let key_keeper_shared_state = shared_state.get_key_keeper_shared_state(); -//! let telemetry_shared_state = shared_state.get_telemetry_shared_state(); -//! let provision_shared_state = shared_state.get_provision_shared_state(); -//! let agent_status_shared_state = shared_state.get_agent_status_shared_state(); -//! -//! let provision_state = provision::get_provision_state( -//! provision_shared_state.clone(), -//! agent_status_shared_state.clone(), -//! ).await; -//! assert_eq!(false, provision_state.finished); -//! assert_eq!(0, provision_state.errorMessage.len()); -//! -//! // update provision state when each provision finished -//! provision::redirector_ready( -//! 
cancellation_token.clone(), -//! key_keeper_shared_state.clone(), -//! telemetry_shared_state.clone(), -//! provision_shared_state.clone(), -//! agent_status_shared_state.clone(), -//! ).await; -//! provision::key_latched( -//! cancellation_token.clone(), -//! key_keeper_shared_state.clone(), -//! telemetry_shared_state.clone(), -//! provision_shared_state.clone(), -//! agent_status_shared_state.clone(), -//! ).await; -//! provision::listener_started( -//! cancellation_token.clone(), -//! key_keeper_shared_state.clone(), -//! telemetry_shared_state.clone(), -//! provision_shared_state.clone(), -//! agent_status_shared_state.clone(), -//! ).await; -//! -//! let provision_state = provision::get_provision_state( -//! provision_shared_state.clone(), -//! agent_status_shared_state.clone(), -//! ).await; -//! assert_eq!(true, provision_state.finished); -//! assert_eq!(0, provision_state.errorMessage.len()); -//! ``` -//! -//! Example for GPA command line option --status [--wait seconds]: -//! ```rust -//! use proxy_agent::provision::ProvisionQuery; -//! use std::time::Duration; -//! -//! let proxy_server_port = 8092; -//! let provision_query = ProvisionQuery::new(proxy_server_port, None); -//! let provision_not_finished_state = provision_query.get_provision_status_wait().await; -//! assert_eq!(false, provision_state.0); -//! assert_eq!(0, provision_state.1.len()); -//! -//! let provision_query = ProvisionQuery::new(proxy_server_port, Some(Duration::from_millis(5))); -//! let provision_finished_state = provision_query.get_provision_status_wait().await; -//! assert_eq!(true, provision_state.0); -//! assert_eq!(0, provision_state.1.len()); -//! 
``` - -use crate::common::{config, helpers, logger}; + +use crate::common::{config, logger}; use crate::key_keeper::{DISABLE_STATE, UNKNOWN_STATE}; -use crate::proxy_agent_status; +use crate::proxy::authorization_rules::AuthorizationMode; +use crate::shared_state::access_control_wrapper::AccessControlSharedState; use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState}; use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; use crate::shared_state::provision_wrapper::ProvisionSharedState; -use crate::shared_state::telemetry_wrapper::TelemetrySharedState; -use crate::telemetry::event_reader::EventReader; +use crate::shared_state::redirector_wrapper::RedirectorSharedState; +use crate::shared_state::EventThreadsSharedState; +use crate::{proxy_agent_status, redirector}; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::telemetry::event_logger; +use proxy_agent_shared::telemetry::event_reader::EventReader; use proxy_agent_shared::{misc_helpers, proxy_agent_aggregate_status}; use std::path::PathBuf; use std::time::Duration; -use tokio_util::sync::CancellationToken; const PROVISION_TAG_FILE_NAME: &str = "provisioned.tag"; const STATUS_TAG_TMP_FILE_NAME: &str = "status.tag.tmp"; @@ -153,63 +82,33 @@ impl ProvisionStateInternal { /// Update provision state when redirector provision finished /// It could be called by redirector module -pub async fn redirector_ready( - cancellation_token: CancellationToken, - key_keeper_shared_state: KeyKeeperSharedState, - telemetry_shared_state: TelemetrySharedState, - provision_shared_state: ProvisionSharedState, - agent_status_shared_state: AgentStatusSharedState, -) { +pub async fn redirector_ready(event_threads_shared_state: EventThreadsSharedState) { update_provision_state( ProvisionFlags::REDIRECTOR_READY, None, - cancellation_token, - key_keeper_shared_state, - telemetry_shared_state, - provision_shared_state, - agent_status_shared_state, + event_threads_shared_state, ) 
.await; } /// Update provision state when key latch provision finished /// It could be called by key latch module -pub async fn key_latched( - cancellation_token: CancellationToken, - key_keeper_shared_state: KeyKeeperSharedState, - telemetry_shared_state: TelemetrySharedState, - provision_shared_state: ProvisionSharedState, - agent_status_shared_state: AgentStatusSharedState, -) { +pub async fn key_latched(event_threads_shared_state: EventThreadsSharedState) { update_provision_state( ProvisionFlags::KEY_LATCH_READY, None, - cancellation_token, - key_keeper_shared_state, - telemetry_shared_state, - provision_shared_state, - agent_status_shared_state, + event_threads_shared_state, ) .await; } /// Update provision state when listener provision finished /// It could be called by listener module -pub async fn listener_started( - cancellation_token: CancellationToken, - key_keeper_shared_state: KeyKeeperSharedState, - telemetry_shared_state: TelemetrySharedState, - provision_shared_state: ProvisionSharedState, - agent_status_shared_state: AgentStatusSharedState, -) { +pub async fn listener_started(event_threads_shared_state: EventThreadsSharedState) { update_provision_state( ProvisionFlags::LISTENER_READY, None, - cancellation_token, - key_keeper_shared_state, - telemetry_shared_state, - provision_shared_state, - agent_status_shared_state, + event_threads_shared_state, ) .await; } @@ -218,15 +117,28 @@ pub async fn listener_started( async fn update_provision_state( state: ProvisionFlags, provision_dir: Option, - cancellation_token: CancellationToken, - key_keeper_shared_state: KeyKeeperSharedState, - telemetry_shared_state: TelemetrySharedState, - provision_shared_state: ProvisionSharedState, - agent_status_shared_state: AgentStatusSharedState, + event_threads_shared_state: EventThreadsSharedState, ) { - if let Ok(provision_state) = provision_shared_state.update_one_state(state).await { + if let Ok(provision_state) = event_threads_shared_state + .provision_shared_state 
+ .update_one_state(state) + .await + { if provision_state.contains(ProvisionFlags::ALL_READY) { - if let Err(e) = provision_shared_state.set_provision_finished(true).await { + // update redirector/eBPF policy based on access control status + update_redirector_policy( + event_threads_shared_state.redirector_shared_state.clone(), + event_threads_shared_state + .access_control_shared_state + .clone(), + ) + .await; + + if let Err(e) = event_threads_shared_state + .provision_shared_state + .set_provision_finished(true) + .await + { // log the error and continue logger::write_error(format!( "update_provision_state::Failed to set provision finished with error: {e}" @@ -236,24 +148,62 @@ async fn update_provision_state( // write provision success state here write_provision_state( provision_dir, - provision_shared_state.clone(), - agent_status_shared_state.clone(), + event_threads_shared_state.provision_shared_state.clone(), + event_threads_shared_state.agent_status_shared_state.clone(), ) .await; // start event threads right after provision successfully - start_event_threads( - cancellation_token, - key_keeper_shared_state, - telemetry_shared_state, - provision_shared_state, - agent_status_shared_state, - ) - .await; + start_event_threads(event_threads_shared_state).await; } } } +/// update the redirector/eBPF policy based on access control status +/// it should be called when provision finished +async fn update_redirector_policy( + redirector_shared_state: RedirectorSharedState, + access_control_shared_state: AccessControlSharedState, +) { + let wireserver_mode = + if let Ok(Some(rules)) = access_control_shared_state.get_wireserver_rules().await { + rules.mode + } else { + // default to disabled if the rules are not ready + AuthorizationMode::Disabled + }; + redirector::update_wire_server_redirect_policy( + wireserver_mode != AuthorizationMode::Disabled, + redirector_shared_state.clone(), + ) + .await; + + let imds_mode = if let Ok(Some(rules)) = 
access_control_shared_state.get_imds_rules().await { + rules.mode + } else { + // default to disabled if the rules are not ready + AuthorizationMode::Disabled + }; + redirector::update_imds_redirect_policy( + imds_mode != AuthorizationMode::Disabled, + redirector_shared_state.clone(), + ) + .await; + + let ga_plugin_mode = + if let Ok(Some(rules)) = access_control_shared_state.get_hostga_rules().await { + rules.mode + } else { + // default to disabled if the rules are not ready + AuthorizationMode::Disabled + }; + redirector::update_hostga_redirect_policy( + ga_plugin_mode != AuthorizationMode::Disabled, + redirector_shared_state.clone(), + ) + .await; +} + pub async fn key_latch_ready_state_reset(provision_shared_state: ProvisionSharedState) { reset_provision_state(ProvisionFlags::KEY_LATCH_READY, provision_shared_state).await; } @@ -287,9 +237,9 @@ async fn reset_provision_state( /// use std::sync::{Arc, Mutex}; /// /// let shared_state = Arc::new(Mutex::new(SharedState::new())); -/// provision::provision_timeup(None, shared_state.clone()); +/// provision::provision_timeout(None, shared_state.clone()); /// ``` -pub async fn provision_timeup( +pub async fn provision_timeout( provision_dir: Option, provision_shared_state: ProvisionSharedState, agent_status_shared_state: AgentStatusSharedState, @@ -313,17 +263,78 @@ pub async fn provision_timeup( } } +/// Set resource limits for Guest Proxy Agent service +/// It will be called when provision finished or timedout, +/// it is designed to delay set resource limits to give more cpu time to provision tasks +/// For Linux GPA service, it sets CPUQuota for azure-proxy-agent.service to limit the CPU usage +/// For Windows GPA service, it sets CPU and RAM limits for current process to limit the CPU and RAM usage +fn set_resource_limits() { + #[cfg(not(windows))] + { + // Set CPUQuota for azure-proxy-agent.service to 15% to limit the CPU usage for Linux azure-proxy-agent service + // Linux GPA VM Extension is not required for 
Linux GPA service, it should not have the resource limits set in HandlerManifest.json file + const SERVICE_NAME: &str = "azure-proxy-agent.service"; + const CPU_QUOTA: u16 = 15; + match proxy_agent_shared::linux::set_cpu_quota(SERVICE_NAME, CPU_QUOTA) { + Ok(_) => { + logger::write_warning(format!( + "Successfully set {SERVICE_NAME} CPU quota to {CPU_QUOTA}%" + )); + } + Err(e) => { + logger::write_error(format!( + "Failed to set {SERVICE_NAME} CPU quota with error: {e}" + )); + } + } + + // Do not set MemoryMax or MemoryHigh for azure-proxy-agent.service to limit the RAM usage + // As Linux GPA service is designed to be lightweight and use minimal memory footprint (~20MB), + // but its provisioning process may need more memory temporarily (e.g., up to 100MB) and then shrinks to ~20MB. + // If we set MemoryMax to 20MB, it may cause the provisioning process OOM kill unexpectedly. + // If we set MemoryHigh to 20MB, it may cause the provisioning process being throttled/hung unexpectedly. + } + + #[cfg(windows)] + { + // Set CPUQuota for GPA service process to limit the CPU usage for Windows GPA service + // As we need adjust the total CPU quota based on the number of CPU cores, + // Windows GPA VM Extension should not have the resource limits set in HandlerManifest.json file + let cpu_count = proxy_agent_shared::current_info::get_cpu_count(); + let percent = if cpu_count <= 4 { + 50 + } else if cpu_count <= 8 { + 30 + } else if cpu_count <= 16 { + 20 + } else { + 15 + }; + + const RAM_LIMIT_IN_MB: usize = 20; + match proxy_agent_shared::windows::set_resource_limits( + std::process::id(), + percent, + RAM_LIMIT_IN_MB, + ) { + Ok(_) => { + logger::write_warning(format!( + "Successfully set current process CPU quota to {percent}% and RAM limit to {RAM_LIMIT_IN_MB}MB" + )); + } + Err(e) => { + logger::write_error(format!("Failed to set CPU and RAM quota with error: {e}")); + } + } + } +} + /// Start event logger & reader tasks and status reporting task /// It will be 
called when provision finished or timedout, /// it is designed to delay start those tasks to give more cpu time to provision tasks -pub async fn start_event_threads( - cancellation_token: CancellationToken, - key_keeper_shared_state: KeyKeeperSharedState, - telemetry_shared_state: TelemetrySharedState, - provision_shared_state: ProvisionSharedState, - agent_status_shared_state: AgentStatusSharedState, -) { - if let Ok(logger_threads_initialized) = provision_shared_state +pub async fn start_event_threads(event_threads_shared_state: EventThreadsSharedState) { + if let Ok(logger_threads_initialized) = event_threads_shared_state + .provision_shared_state .get_event_log_threads_initialized() .await { @@ -332,7 +343,12 @@ pub async fn start_event_threads( } } - let cloned_agent_status_shared_state = agent_status_shared_state.clone(); + // set resource limits before launching lower priority tasks, + // those tasks starts to run after provision finished or provision timedout + set_resource_limits(); + + let cloned_agent_status_shared_state = + event_threads_shared_state.agent_status_shared_state.clone(); tokio::spawn({ async { event_logger::start( @@ -355,10 +371,10 @@ pub async fn start_event_threads( let event_reader = EventReader::new( config::get_events_dir(), true, - cancellation_token.clone(), - key_keeper_shared_state.clone(), - telemetry_shared_state.clone(), - agent_status_shared_state.clone(), + event_threads_shared_state.cancellation_token.clone(), + event_threads_shared_state.common_state.clone(), + "ProxyAgent".to_string(), + "MicrosoftAzureGuestProxyAgent".to_string(), ); async move { event_reader @@ -366,7 +382,8 @@ pub async fn start_event_threads( .await; } }); - if let Err(e) = provision_shared_state + if let Err(e) = event_threads_shared_state + .provision_shared_state .set_event_log_threads_initialized() .await { @@ -379,9 +396,12 @@ pub async fn start_event_threads( let agent_status_task = proxy_agent_status::ProxyAgentStatusTask::new( 
Duration::from_secs(60), proxy_agent_aggregate_status::get_proxy_agent_aggregate_status_folder(), - cancellation_token.clone(), - key_keeper_shared_state.clone(), - agent_status_shared_state.clone(), + event_threads_shared_state.cancellation_token.clone(), + event_threads_shared_state.key_keeper_shared_state.clone(), + event_threads_shared_state.agent_status_shared_state.clone(), + event_threads_shared_state + .connection_summary_shared_state + .clone(), ); async move { agent_status_task.start().await; @@ -429,7 +449,7 @@ async fn write_provision_state( if !failed_state_message.is_empty() { // escape xml characters to allow the message to able be composed into xml payload - failed_state_message = helpers::xml_escape(failed_state_message); + failed_state_message = misc_helpers::xml_escape(failed_state_message); // write provision failed error message to event event_logger::write_event( @@ -658,6 +678,7 @@ mod tests { use crate::provision::provision_query::ProvisionQuery; use crate::provision::ProvisionFlags; use crate::proxy::proxy_server; + use crate::shared_state::EventThreadsSharedState; use crate::shared_state::SharedState; use std::env; use std::fs; @@ -677,8 +698,10 @@ mod tests { let cancellation_token = shared_state.get_cancellation_token(); let provision_shared_state = shared_state.get_provision_shared_state(); let key_keeper_shared_state = shared_state.get_key_keeper_shared_state(); - let telemetry_shared_state = shared_state.get_telemetry_shared_state(); let agent_status_shared_state = shared_state.get_agent_status_shared_state(); + let event_threads_shared_state = EventThreadsSharedState::new(&shared_state); + + // initialize key keeper secure channel state to UNKNOWN let port: u16 = 8092; let proxy_server = proxy_server::ProxyServer::new(port, &shared_state); @@ -709,11 +732,7 @@ mod tests { _ = super::update_provision_state( ProvisionFlags::KEY_LATCH_READY, Some(temp_test_path.to_path_buf()), - cancellation_token.clone(), - 
key_keeper_shared_state.clone(), - telemetry_shared_state.clone(), - provision_shared_state.clone(), - agent_status_shared_state.clone(), + event_threads_shared_state.clone(), ) .await; _ = key_keeper_shared_state @@ -738,29 +757,17 @@ mod tests { super::update_provision_state( ProvisionFlags::REDIRECTOR_READY, Some(dir1), - cancellation_token.clone(), - key_keeper_shared_state.clone(), - telemetry_shared_state.clone(), - provision_shared_state.clone(), - agent_status_shared_state.clone(), + event_threads_shared_state.clone(), ), super::update_provision_state( ProvisionFlags::KEY_LATCH_READY, Some(dir2), - cancellation_token.clone(), - key_keeper_shared_state.clone(), - telemetry_shared_state.clone(), - provision_shared_state.clone(), - agent_status_shared_state.clone(), + event_threads_shared_state.clone(), ), super::update_provision_state( ProvisionFlags::LISTENER_READY, Some(dir3), - cancellation_token.clone(), - key_keeper_shared_state.clone(), - telemetry_shared_state.clone(), - provision_shared_state.clone(), - agent_status_shared_state.clone(), + event_threads_shared_state.clone(), ), ]; for handle in handles { @@ -812,14 +819,7 @@ mod tests { assert!(event_threads_initialized); // update provision finish time_tick - super::key_latched( - cancellation_token.clone(), - key_keeper_shared_state.clone(), - telemetry_shared_state.clone(), - provision_shared_state.clone(), - agent_status_shared_state.clone(), - ) - .await; + super::key_latched(event_threads_shared_state.clone()).await; let provision_state_internal = super::get_provision_state_internal( provision_shared_state.clone(), agent_status_shared_state.clone(), @@ -872,14 +872,7 @@ mod tests { ); // test key_latched ready again - super::key_latched( - cancellation_token.clone(), - key_keeper_shared_state.clone(), - telemetry_shared_state.clone(), - provision_shared_state.clone(), - agent_status_shared_state.clone(), - ) - .await; + super::key_latched(event_threads_shared_state.clone()).await; let 
provision_state = provision_shared_state.get_state().await.unwrap(); assert!( provision_state.contains(ProvisionFlags::ALL_READY), @@ -910,39 +903,26 @@ mod tests { let shared_state = SharedState::start_all(); let cancellation_token = shared_state.get_cancellation_token(); let provision_shared_state = shared_state.get_provision_shared_state(); - let key_keeper_shared_state = shared_state.get_key_keeper_shared_state(); - let telemetry_shared_state = shared_state.get_telemetry_shared_state(); let agent_status_shared_state = shared_state.get_agent_status_shared_state(); + let event_threads_shared_state = EventThreadsSharedState::new(&shared_state); // test all 3 provision states as ready super::update_provision_state( ProvisionFlags::LISTENER_READY, Some(temp_test_path.clone()), - cancellation_token.clone(), - key_keeper_shared_state.clone(), - telemetry_shared_state.clone(), - provision_shared_state.clone(), - agent_status_shared_state.clone(), + event_threads_shared_state.clone(), ) .await; super::update_provision_state( ProvisionFlags::KEY_LATCH_READY, Some(temp_test_path.clone()), - cancellation_token.clone(), - key_keeper_shared_state.clone(), - telemetry_shared_state.clone(), - provision_shared_state.clone(), - agent_status_shared_state.clone(), + event_threads_shared_state.clone(), ) .await; super::update_provision_state( ProvisionFlags::REDIRECTOR_READY, Some(temp_test_path.clone()), - cancellation_token.clone(), - key_keeper_shared_state.clone(), - telemetry_shared_state.clone(), - provision_shared_state.clone(), - agent_status_shared_state.clone(), + event_threads_shared_state.clone(), ) .await; @@ -964,7 +944,7 @@ mod tests { super::AgentStatusModule::KeyKeeper, ) .await; - super::provision_timeup( + super::provision_timeout( Some(temp_test_path.clone()), provision_shared_state.clone(), agent_status_shared_state.clone(), diff --git a/proxy_agent/src/proxy.rs b/proxy_agent/src/proxy.rs index a6a2dce5..dcc50b1a 100644 --- a/proxy_agent/src/proxy.rs +++ 
b/proxy_agent/src/proxy.rs @@ -165,10 +165,16 @@ impl Process { let (process_full_path, cmd); #[cfg(windows)] { - let handler = windows::get_process_handler(pid).unwrap_or_else(|e| { - println!("Failed to get process handler: {e}"); - 0 - }); + use windows_sys::Win32::System::Threading::{ + PROCESS_QUERY_INFORMATION, PROCESS_VM_READ, + }; + + let options = PROCESS_QUERY_INFORMATION | PROCESS_VM_READ; + let handler = proxy_agent_shared::windows::get_process_handler(pid, options) + .unwrap_or_else(|e| { + println!("Failed to get process handler: {e}"); + 0 + }); let base_info = windows::query_basic_process_info(handler); match base_info { Ok(_) => { @@ -182,7 +188,7 @@ impl Process { } } // close the handle - if let Err(e) = windows::close_process_handler(handler) { + if let Err(e) = proxy_agent_shared::windows::close_handler(handler) { println!("Failed to close process handler: {e}"); } } diff --git a/proxy_agent/src/proxy/proxy_authorizer.rs b/proxy_agent/src/proxy/proxy_authorizer.rs index 1707f55e..23d54dc5 100644 --- a/proxy_agent/src/proxy/proxy_authorizer.rs +++ b/proxy_agent/src/proxy/proxy_authorizer.rs @@ -8,12 +8,12 @@ //! ```rust //! use proxy_agent::proxy_authorizer; //! use proxy_agent::proxy::Claims; -//! use proxy_agent::shared_state::key_keeper_wrapper::KeyKeeperSharedState; +//! use crate::shared_state::access_control_wrapper::AccessControlSharedState; //! use proxy_agent::common::constants; //! use std::str::FromStr; //! -//! let key_keeper_shared_state = KeyKeeperSharedState::start_new(); -//! let vm_metadata = proxy_authorizer::get_access_control_rules(constants::WIRE_SERVER_IP.to_string(), constants::WIRE_SERVER_PORT, key_keeper_shared_state.clone()).await.unwrap(); +//! let access_control_shared_state = AccessControlSharedState::start_new(); +//! let vm_metadata = proxy_authorizer::get_access_control_rules(constants::WIRE_SERVER_IP.to_string(), constants::WIRE_SERVER_PORT, access_control_shared_state .clone()).await.unwrap(); //! 
let authorizer = proxy_authorizer::get_authorizer(constants::WIRE_SERVER_IP, constants::WIRE_SERVER_PORT, claims); //! let url = hyper::Uri::from_str("http://localhost/test?").unwrap(); //! authorizer.authorize(logger, url, vm_metadata); @@ -21,7 +21,7 @@ use super::authorization_rules::{AuthorizationMode, ComputedAuthorizationItem}; use super::proxy_connection::ConnectionLogger; -use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; +use crate::shared_state::access_control_wrapper::AccessControlSharedState; use crate::{common::constants, common::result::Result, proxy::Claims}; use proxy_agent_shared::logger::LoggerLevel; @@ -211,17 +211,17 @@ pub fn get_authorizer(ip: String, port: u16, claims: Claims) -> Box Result> { match (ip.as_str(), port) { (constants::WIRE_SERVER_IP, constants::WIRE_SERVER_PORT) => { - key_keeper_shared_state.get_wireserver_rules().await + access_control_shared_state.get_wireserver_rules().await } (constants::GA_PLUGIN_IP, constants::GA_PLUGIN_PORT) => { - key_keeper_shared_state.get_hostga_rules().await + access_control_shared_state.get_hostga_rules().await } (constants::IMDS_IP, constants::IMDS_PORT) => { - key_keeper_shared_state.get_imds_rules().await + access_control_shared_state.get_imds_rules().await } _ => Ok(None), } @@ -248,7 +248,7 @@ mod tests { use crate::{ key_keeper::key::AuthorizationItem, proxy::{proxy_authorizer::AuthorizeResult, proxy_connection::ConnectionLogger}, - shared_state::key_keeper_wrapper::KeyKeeperSharedState, + shared_state::access_control_wrapper::AccessControlSharedState, }; use std::{ffi::OsString, path::PathBuf, str::FromStr}; @@ -347,7 +347,7 @@ mod tests { claims.clone(), ); let url = hyper::Uri::from_str("http://localhost/test?").unwrap(); - let key_keeper_shared_state = KeyKeeperSharedState::start_new(); + let access_control_shared_state = AccessControlSharedState::start_new(); // validate disabled rules let disabled_rules = AuthorizationItem { @@ -356,11 +356,11 @@ mod tests { id: 
"id".to_string(), rules: None, }; - key_keeper_shared_state + access_control_shared_state .set_wireserver_rules(Some(disabled_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state + let access_control_rules = access_control_shared_state .get_wireserver_rules() .await .unwrap(); @@ -383,11 +383,11 @@ mod tests { id: "id".to_string(), rules: None, }; - key_keeper_shared_state + access_control_shared_state .set_wireserver_rules(Some(audit_allow_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state + let access_control_rules = access_control_shared_state .get_wireserver_rules() .await .unwrap(); @@ -396,11 +396,11 @@ mod tests { == AuthorizeResult::Ok, "WireServer authentication must be Ok with audit allow rules" ); - key_keeper_shared_state + access_control_shared_state .set_wireserver_rules(Some(audit_deny_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state + let access_control_rules = access_control_shared_state .get_wireserver_rules() .await .unwrap(); @@ -423,11 +423,11 @@ mod tests { id: "id".to_string(), rules: None, }; - key_keeper_shared_state + access_control_shared_state .set_wireserver_rules(Some(enforce_allow_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state + let access_control_rules = access_control_shared_state .get_wireserver_rules() .await .unwrap(); @@ -436,11 +436,11 @@ mod tests { == AuthorizeResult::Ok, "WireServer authentication must be Ok with enforce allow rules" ); - key_keeper_shared_state + access_control_shared_state .set_wireserver_rules(Some(enforce_deny_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state + let access_control_rules = access_control_shared_state .get_wireserver_rules() .await .unwrap(); @@ -472,7 +472,7 @@ mod tests { claims.clone(), ); let url = hyper::Uri::from_str("http://localhost/test?").unwrap(); - let key_keeper_shared_state = KeyKeeperSharedState::start_new(); + let 
access_control_shared_state = AccessControlSharedState::start_new(); // validate disabled rules let disabled_rules = AuthorizationItem { @@ -481,11 +481,11 @@ mod tests { id: "id".to_string(), rules: None, }; - key_keeper_shared_state + access_control_shared_state .set_imds_rules(Some(disabled_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state.get_imds_rules().await.unwrap(); + let access_control_rules = access_control_shared_state.get_imds_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules,) == AuthorizeResult::Ok, @@ -505,21 +505,21 @@ mod tests { id: "id".to_string(), rules: None, }; - key_keeper_shared_state + access_control_shared_state .set_imds_rules(Some(audit_allow_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state.get_imds_rules().await.unwrap(); + let access_control_rules = access_control_shared_state.get_imds_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules,) == AuthorizeResult::Ok, "IMDS authentication must be Ok with audit allow rules" ); - key_keeper_shared_state + access_control_shared_state .set_imds_rules(Some(audit_deny_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state.get_imds_rules().await.unwrap(); + let access_control_rules = access_control_shared_state.get_imds_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules,) == AuthorizeResult::OkWithAudit, @@ -539,21 +539,21 @@ mod tests { id: "id".to_string(), rules: None, }; - key_keeper_shared_state + access_control_shared_state .set_imds_rules(Some(enforce_allow_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state.get_imds_rules().await.unwrap(); + let access_control_rules = access_control_shared_state.get_imds_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules,) == AuthorizeResult::Ok, "IMDS 
authentication must be Ok with enforce allow rules" ); - key_keeper_shared_state + access_control_shared_state .set_imds_rules(Some(enforce_deny_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state.get_imds_rules().await.unwrap(); + let access_control_rules = access_control_shared_state.get_imds_rules().await.unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules,) == AuthorizeResult::Forbidden, @@ -582,7 +582,7 @@ mod tests { claims.clone(), ); let url = hyper::Uri::from_str("http://localhost/test?").unwrap(); - let key_keeper_shared_state = KeyKeeperSharedState::start_new(); + let access_control_shared_state = AccessControlSharedState::start_new(); // validate disabled rules let disabled_rules = AuthorizationItem { @@ -591,11 +591,14 @@ mod tests { id: "id".to_string(), rules: None, }; - key_keeper_shared_state + access_control_shared_state .set_hostga_rules(Some(disabled_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state.get_hostga_rules().await.unwrap(); + let access_control_rules = access_control_shared_state + .get_hostga_rules() + .await + .unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::Ok, @@ -615,21 +618,27 @@ mod tests { id: "id".to_string(), rules: None, }; - key_keeper_shared_state + access_control_shared_state .set_hostga_rules(Some(audit_allow_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state.get_hostga_rules().await.unwrap(); + let access_control_rules = access_control_shared_state + .get_hostga_rules() + .await + .unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::Ok, "HostGA authentication must be Ok with audit allow rules" ); - key_keeper_shared_state + access_control_shared_state .set_hostga_rules(Some(audit_deny_rules)) .await .unwrap(); - let access_control_rules = 
key_keeper_shared_state.get_hostga_rules().await.unwrap(); + let access_control_rules = access_control_shared_state + .get_hostga_rules() + .await + .unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::OkWithAudit, @@ -649,21 +658,27 @@ mod tests { id: "id".to_string(), rules: None, }; - key_keeper_shared_state + access_control_shared_state .set_hostga_rules(Some(enforce_allow_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state.get_hostga_rules().await.unwrap(); + let access_control_rules = access_control_shared_state + .get_hostga_rules() + .await + .unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::Ok, "HostGA authentication must be Ok with enforce allow rules" ); - key_keeper_shared_state + access_control_shared_state .set_hostga_rules(Some(enforce_deny_rules)) .await .unwrap(); - let access_control_rules = key_keeper_shared_state.get_hostga_rules().await.unwrap(); + let access_control_rules = access_control_shared_state + .get_hostga_rules() + .await + .unwrap(); assert!( auth.authorize(&mut test_logger, url.clone(), access_control_rules) == AuthorizeResult::Forbidden, diff --git a/proxy_agent/src/proxy/proxy_connection.rs b/proxy_agent/src/proxy/proxy_connection.rs index 90a29893..e20956a2 100644 --- a/proxy_agent/src/proxy/proxy_connection.rs +++ b/proxy_agent/src/proxy/proxy_connection.rs @@ -10,7 +10,6 @@ use crate::proxy::Claims; use crate::redirector::{self, AuditEntry}; use crate::shared_state::proxy_server_wrapper::ProxyServerSharedState; use crate::shared_state::redirector_wrapper::RedirectorSharedState; -use http_body_util::Full; use hyper::body::Bytes; use hyper::client::conn::http1; use hyper::Request; @@ -23,7 +22,8 @@ use std::sync::Arc; use std::time::Instant; use tokio::sync::Mutex; -pub type RequestBody = Full; +pub type RequestBody = + http_body_util::combinators::BoxBody>; struct Client { sender: 
http1::SendRequest, } diff --git a/proxy_agent/src/proxy/proxy_server.rs b/proxy_agent/src/proxy/proxy_server.rs index c23fdb1a..cdea7b37 100644 --- a/proxy_agent/src/proxy/proxy_server.rs +++ b/proxy_agent/src/proxy/proxy_server.rs @@ -25,21 +25,23 @@ use super::proxy_connection::{ConnectionLogger, HttpConnectionContext, TcpConnec use crate::common::{constants, error::Error, helpers, logger, result::Result}; use crate::provision; use crate::proxy::{proxy_authorizer, proxy_summary::ProxySummary, Claims}; +use crate::shared_state::access_control_wrapper::AccessControlSharedState; use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState}; +use crate::shared_state::connection_summary_wrapper::ConnectionSummarySharedState; use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; use crate::shared_state::provision_wrapper::ProvisionSharedState; use crate::shared_state::proxy_server_wrapper::ProxyServerSharedState; use crate::shared_state::redirector_wrapper::RedirectorSharedState; -use crate::shared_state::telemetry_wrapper::TelemetrySharedState; -use crate::shared_state::SharedState; +use crate::shared_state::{EventThreadsSharedState, SharedState}; use http_body_util::Full; use http_body_util::{combinators::BoxBody, BodyExt}; -use hyper::body::{Bytes, Frame, Incoming}; +use hyper::body::{Bytes, Incoming}; use hyper::header::{HeaderName, HeaderValue}; use hyper::service::service_fn; use hyper::StatusCode; use hyper::{Request, Response}; use hyper_util::rt::TokioIo; +use proxy_agent_shared::common_state::CommonState; use proxy_agent_shared::error::HyperErrorType; use proxy_agent_shared::hyper_client; use proxy_agent_shared::logger::LoggerLevel; @@ -49,6 +51,7 @@ use proxy_agent_shared::telemetry::event_logger; use std::time::Duration; use tokio::net::TcpListener; use tokio::net::TcpStream; +use tokio_util::bytes::BytesMut; use tokio_util::sync::CancellationToken; use tower::Service; use tower_http::{body::Limited, 
limit::RequestBodyLimitLayer}; @@ -63,11 +66,13 @@ pub struct ProxyServer { port: u16, cancellation_token: CancellationToken, key_keeper_shared_state: KeyKeeperSharedState, - telemetry_shared_state: TelemetrySharedState, + common_state: CommonState, provision_shared_state: ProvisionSharedState, agent_status_shared_state: AgentStatusSharedState, redirector_shared_state: RedirectorSharedState, proxy_server_shared_state: ProxyServerSharedState, + access_control_shared_state: AccessControlSharedState, + connection_summary_shared_state: ConnectionSummarySharedState, } impl ProxyServer { @@ -76,11 +81,13 @@ impl ProxyServer { port, cancellation_token: shared_state.get_cancellation_token(), key_keeper_shared_state: shared_state.get_key_keeper_shared_state(), - telemetry_shared_state: shared_state.get_telemetry_shared_state(), + common_state: shared_state.get_common_state(), provision_shared_state: shared_state.get_provision_shared_state(), agent_status_shared_state: shared_state.get_agent_status_shared_state(), redirector_shared_state: shared_state.get_redirector_shared_state(), proxy_server_shared_state: shared_state.get_proxy_server_shared_state(), + access_control_shared_state: shared_state.get_access_control_shared_state(), + connection_summary_shared_state: shared_state.get_connection_summary_shared_state(), } } @@ -182,13 +189,16 @@ impl ProxyServer { { logger::write_warning(format!("Failed to set module state: {e}")); } - provision::listener_started( - self.cancellation_token.clone(), - self.key_keeper_shared_state.clone(), - self.telemetry_shared_state.clone(), - self.provision_shared_state.clone(), - self.agent_status_shared_state.clone(), - ) + provision::listener_started(EventThreadsSharedState { + cancellation_token: self.cancellation_token.clone(), + common_state: self.common_state.clone(), + access_control_shared_state: self.access_control_shared_state.clone(), + redirector_shared_state: self.redirector_shared_state.clone(), + key_keeper_shared_state: 
self.key_keeper_shared_state.clone(), + provision_shared_state: self.provision_shared_state.clone(), + agent_status_shared_state: self.agent_status_shared_state.clone(), + connection_summary_shared_state: self.connection_summary_shared_state.clone(), + }) .await; // We start a loop to continuously accept incoming connections @@ -467,7 +477,7 @@ impl ProxyServer { let access_control_rules = match proxy_authorizer::get_access_control_rules( ip.to_string(), port, - self.key_keeper_shared_state.clone(), + self.access_control_shared_state.clone(), ) .await { @@ -574,16 +584,10 @@ impl ProxyServer { http_connection_context.method, http_connection_context.url ), ); - let request = match Self::convert_request(proxy_request).await { - Ok(r) => r, - Err(e) => { - http_connection_context.log( - LoggerLevel::Error, - format!("Failed to convert request: {e}"), - ); - return Ok(Self::closed_response(StatusCode::BAD_REQUEST)); - } - }; + + let (head, body) = proxy_request.into_parts(); + // Stream the request body directly without buffering + let request = Request::from_parts(head, body.boxed()); http_connection_context.log( LoggerLevel::Trace, format!( @@ -603,20 +607,6 @@ impl ProxyServer { .await } - async fn convert_request( - request: Request>, - ) -> Result>> { - let (head, body) = request.into_parts(); - let whole_body = match body.collect().await { - Ok(data) => data.to_bytes(), - Err(e) => { - return Err(Error::Hyper(HyperErrorType::RequestBody(e.to_string()))); - } - }; - - Ok(Request::from_parts(head, Full::new(whole_body))) - } - async fn handle_provision_state_check_request( &self, logger: &mut ConnectionLogger, @@ -739,23 +729,9 @@ impl ProxyServer { LoggerLevel::Trace, "Converting response to the client format.".to_string(), ); - let mut logger = http_connection_context.logger.clone(); let (head, body) = proxy_response.into_parts(); - let frame_stream = body.map_frame(move |frame| { - let frame = match frame.into_data() { - Ok(data) => data.iter().map(|byte| 
byte.to_be()).collect::(), - Err(e) => { - logger.write( - LoggerLevel::Error, - format!("Failed to get frame data: {e:?}"), - ); - Bytes::new() - } - }; - - Frame::data(frame) - }); - let mut response = Response::from_parts(head, frame_stream.boxed()); + // Stream the response body directly without buffering + let mut response = Response::from_parts(head, body.boxed()); http_connection_context.log(LoggerLevel::Trace, "Adding proxy agent header.".to_string()); // insert default x-ms-azure-host-authorization header to let the client know it is through proxy agent @@ -857,7 +833,7 @@ impl ProxyServer { if log_authorize_failed { if let Err(e) = self - .agent_status_shared_state + .connection_summary_shared_state .add_one_failed_connection_summary(summary) .await { @@ -867,7 +843,7 @@ impl ProxyServer { ); } } else if let Err(e) = self - .agent_status_shared_state + .connection_summary_shared_state .add_one_connection_summary(summary) .await { @@ -912,17 +888,7 @@ impl ProxyServer { LoggerLevel::Trace, "Starting to collect the client request body.".to_string(), ); - let whole_body = match body.collect().await { - Ok(data) => data.to_bytes(), - Err(e) => { - http_connection_context.log( - LoggerLevel::Error, - format!("Failed to receive the request body: {e}"), - ); - return Ok(Self::closed_response(StatusCode::BAD_REQUEST)); - } - }; - + let whole_body = Self::read_body_bytes(body).await?; http_connection_context.log( LoggerLevel::Trace, format!( @@ -934,8 +900,11 @@ impl ProxyServer { ); // create a new request to the Host endpoint - let mut proxy_request: Request> = - Request::from_parts(head.clone(), Full::new(whole_body.clone())); + let body = Full::new(whole_body.clone()) + .map_err(|never| -> Box { match never {} }) + .boxed(); + let mut proxy_request: Request = + Request::from_parts(head.clone(), body); // sign the request // Add header x-ms-azure-host-authorization @@ -1017,6 +986,32 @@ impl ProxyServer { self.forward_response(proxy_response, 
http_connection_context) .await } + + /// It reads the body in chunks and concatenates them into a single Bytes object + /// It also yields control to the tokio scheduler to avoid blocking the thread if the body is large + async fn read_body_bytes(mut body: B) -> Result + where + B: hyper::body::Body + Unpin, + B::Error: std::fmt::Display + Send + Sync + 'static, + { + let body_size = body.size_hint().upper().unwrap_or(4 * 1024 * 1024); + let mut buf = BytesMut::with_capacity(body_size as usize); + while let Some(chunk) = body.frame().await { + match chunk { + Ok(chunk) => { + if let Ok(data) = chunk.into_data() { + buf.extend_from_slice(&data) + } + } + Err(e) => { + return Err(Error::Hyper(HyperErrorType::ReceiveBody(e.to_string()))); + } + } + // yield control to the tokio scheduler to avoid blocking the thread if the body is large + tokio::task::yield_now().await; + } + Ok(buf.freeze()) + } } #[cfg(test)] diff --git a/proxy_agent/src/proxy/windows.rs b/proxy_agent/src/proxy/windows.rs index 145b84d7..9b4ac2f7 100644 --- a/proxy_agent/src/proxy/windows.rs +++ b/proxy_agent/src/proxy/windows.rs @@ -15,7 +15,7 @@ use windows_sys::Wdk::System::Threading::{ NtQueryInformationProcess, // ntdll.dll PROCESSINFOCLASS, }; -use windows_sys::Win32::Foundation::{CloseHandle, BOOL, HANDLE, LUID, NTSTATUS, UNICODE_STRING}; +use windows_sys::Win32::Foundation::{LUID, NTSTATUS, UNICODE_STRING}; use windows_sys::Win32::Security::Authentication::Identity; use windows_sys::Win32::Security::Authentication::Identity::{ LSA_UNICODE_STRING, SECURITY_LOGON_SESSION_DATA, @@ -25,9 +25,6 @@ use windows_sys::Win32::System::ProcessStatus::{ K32GetModuleFileNameExW, // kernel32.dll }; use windows_sys::Win32::System::Threading::PROCESS_BASIC_INFORMATION; -use windows_sys::Win32::System::Threading::{ - OpenProcess, //kernel32.dll -}; const LG_INCLUDE_INDIRECT: u32 = 1u32; const MAX_PREFERRED_LENGTH: u32 = 4294967295u32; @@ -238,14 +235,10 @@ fn to_pwstr(s: &str) -> Vec { /* Get process 
information */ -const PROCESS_QUERY_INFORMATION: u32 = 0x0400; -const PROCESS_VM_READ: u32 = 0x0010; -const FALSE: BOOL = 0; const MAX_PATH: usize = 260; const STATUS_BUFFER_OVERFLOW: NTSTATUS = -2147483643; const STATUS_BUFFER_TOO_SMALL: NTSTATUS = -1073741789; const STATUS_INFO_LENGTH_MISMATCH: NTSTATUS = -1073741820; - const PROCESS_BASIC_INFORMATION_CLASS: PROCESSINFOCLASS = 0; const PROCESS_COMMAND_LINE_INFORMATION_CLASS: PROCESSINFOCLASS = 60; @@ -270,51 +263,6 @@ pub fn query_basic_process_info(handler: isize) -> Result` - Process handler -/// # Errors -/// * `Error::Invalid` - If the pid is 0 -/// * `Error::WindowsApi` - If the OpenProcess call fails -/// # Safety -/// This function is safe to call as it does not dereference any raw pointers. -/// However, the caller is responsible for closing the process handler using `close_process_handler` -/// when it is no longer needed to avoid resource leaks. -pub fn get_process_handler(pid: u32) -> Result { - if pid == 0 { - return Err(Error::Invalid("pid 0".to_string())); - } - let options = PROCESS_QUERY_INFORMATION | PROCESS_VM_READ; - - // https://learn.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-openprocess - let handler = unsafe { OpenProcess(options, FALSE, pid) }; - if handler == 0 { - return Err(Error::WindowsApi(WindowsApiErrorType::WindowsOsError( - std::io::Error::last_os_error(), - ))); - } - Ok(handler) -} - -/// Close process handler -/// # Arguments -/// * `handler` - Process handler -/// # Returns -/// * `Result<()>` - Ok if successful, Err if failed -pub fn close_process_handler(handler: HANDLE) -> Result<()> { - if handler != 0 { - // https://learn.microsoft.com/en-us/windows/win32/api/handleapi/nf-handleapi-closehandle - if 0 != unsafe { CloseHandle(handler) } { - return Err(Error::WindowsApi(WindowsApiErrorType::WindowsOsError( - std::io::Error::last_os_error(), - ))); - } - } - Ok(()) -} - pub fn get_process_cmd(handler: isize) -> Result { unsafe { let mut 
return_length = 0; @@ -432,8 +380,11 @@ mod tests { #[test] fn get_process_test() { + use windows_sys::Win32::System::Threading::{PROCESS_QUERY_INFORMATION, PROCESS_VM_READ}; + let pid = std::process::id(); - let handler = super::get_process_handler(pid).unwrap(); + let options = PROCESS_QUERY_INFORMATION | PROCESS_VM_READ; + let handler = proxy_agent_shared::windows::get_process_handler(pid, options).unwrap(); let name = super::get_process_name(handler).unwrap(); let full_name = super::get_process_full_name(handler).unwrap(); let cmd = super::get_process_cmd(handler).unwrap(); diff --git a/proxy_agent/src/proxy_agent_status.rs b/proxy_agent/src/proxy_agent_status.rs index 777c1620..6a248474 100644 --- a/proxy_agent/src/proxy_agent_status.rs +++ b/proxy_agent/src/proxy_agent_status.rs @@ -31,6 +31,7 @@ use crate::common::logger; use crate::key_keeper::UNKNOWN_STATE; use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState}; +use crate::shared_state::connection_summary_wrapper::ConnectionSummarySharedState; use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::misc_helpers; @@ -50,6 +51,7 @@ pub struct ProxyAgentStatusTask { cancellation_token: CancellationToken, key_keeper_shared_state: KeyKeeperSharedState, agent_status_shared_state: AgentStatusSharedState, + connection_summary_shared_state: ConnectionSummarySharedState, } impl ProxyAgentStatusTask { @@ -59,6 +61,7 @@ impl ProxyAgentStatusTask { cancellation_token: CancellationToken, key_keeper_shared_state: KeyKeeperSharedState, agent_status_shared_state: AgentStatusSharedState, + connection_summary_shared_state: ConnectionSummarySharedState, ) -> ProxyAgentStatusTask { ProxyAgentStatusTask { interval, @@ -66,6 +69,7 @@ impl ProxyAgentStatusTask { cancellation_token, key_keeper_shared_state, agent_status_shared_state, + connection_summary_shared_state, } } @@ -126,6 +130,11 @@ impl 
ProxyAgentStatusTask { } loop { + #[cfg(not(windows))] + { + self.monitor_memory_usage(); + } + let aggregate_status = self.guest_proxy_agent_aggregate_status_new().await; // write proxyAgentStatus event if status_report_time.elapsed() >= status_report_duration { @@ -151,7 +160,11 @@ impl ProxyAgentStatusTask { "Clearing the connection summary map and failed authenticate summary map." .to_string(), ); - if let Err(e) = self.agent_status_shared_state.clear_all_summary().await { + if let Err(e) = self + .connection_summary_shared_state + .clear_all_summary() + .await + { logger::write_error(format!("Error clearing the connection summary map and failed authenticate summary map: {e}")); } start_time = Instant::now(); @@ -261,7 +274,7 @@ impl ProxyAgentStatusTask { timestamp: misc_helpers::get_date_time_string_with_milliseconds(), proxyAgentStatus: self.proxy_agent_status_new().await, proxyConnectionSummary: match self - .agent_status_shared_state + .connection_summary_shared_state .get_all_connection_summary() .await { @@ -272,7 +285,7 @@ impl ProxyAgentStatusTask { } }, failedAuthenticateSummary: match self - .agent_status_shared_state + .connection_summary_shared_state .get_all_failed_connection_summary() .await { @@ -301,6 +314,46 @@ impl ProxyAgentStatusTask { .await; } } + + /// Monitor the memory usage of the current process and log it. + /// If the memory usage exceeds the limit, log a warning. + /// If the memory usage exceeds the limits for multiple times, take action (e.g., restart the process). 
+ #[cfg(not(windows))] + fn monitor_memory_usage(&self) { + const RAM_LIMIT_IN_MB: u64 = 20; + match proxy_agent_shared::linux::read_proc_memory_status(std::process::id()) { + Ok(memory_status) => { + if let Some(vmrss_kb) = memory_status.vmrss_kb { + let ram_in_mb = vmrss_kb / 1024; + logger::write_information(format!( + "Current process memory usage: {ram_in_mb} MB", + )); + + if ram_in_mb > RAM_LIMIT_IN_MB { + logger::write_warning(format!( + "Current process memory usage {ram_in_mb} MB exceeds the limit of {RAM_LIMIT_IN_MB} MB.", + )); + // take action if needed, e.g., restart the process + } + } else { + logger::write_information("Current process memory usage: Unknown".to_string()); + } + if let Some(vmhwm_kb) = memory_status.vmhwm_kb { + logger::write_information(format!( + "Current process peak memory usage: {} MB", + vmhwm_kb / 1024 + )); + } else { + logger::write_information( + "Current process peak memory usage: Unknown".to_string(), + ); + } + } + Err(e) => { + logger::write_error(format!("Error reading process memory status: {e}")); + } + } + } } #[cfg(test)] @@ -308,7 +361,9 @@ mod tests { use crate::{ proxy_agent_status::ProxyAgentStatusTask, shared_state::{ - agent_status_wrapper::AgentStatusSharedState, key_keeper_wrapper::KeyKeeperSharedState, + agent_status_wrapper::AgentStatusSharedState, + connection_summary_wrapper::ConnectionSummarySharedState, + key_keeper_wrapper::KeyKeeperSharedState, }, }; use proxy_agent_shared::{ @@ -330,6 +385,7 @@ mod tests { CancellationToken::new(), KeyKeeperSharedState::start_new(), AgentStatusSharedState::start_new(), + ConnectionSummarySharedState::start_new(), ); let aggregate_status = task.guest_proxy_agent_aggregate_status_new().await; task.write_aggregate_status_to_file(aggregate_status).await; diff --git a/proxy_agent/src/redirector.rs b/proxy_agent/src/redirector.rs index 07a33433..2dc87a46 100644 --- a/proxy_agent/src/redirector.rs +++ b/proxy_agent/src/redirector.rs @@ -45,20 +45,21 @@ mod windows; 
#[cfg(not(windows))] mod linux; -use crate::common::constants; use crate::common::error::BpfErrorType; use crate::common::error::Error; use crate::common::helpers; use crate::common::result::Result; use crate::common::{config, logger}; use crate::provision; -use crate::proxy::authorization_rules::AuthorizationMode; +use crate::shared_state::access_control_wrapper::AccessControlSharedState; use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState}; +use crate::shared_state::connection_summary_wrapper::ConnectionSummarySharedState; use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; use crate::shared_state::provision_wrapper::ProvisionSharedState; use crate::shared_state::redirector_wrapper::RedirectorSharedState; -use crate::shared_state::telemetry_wrapper::TelemetrySharedState; +use crate::shared_state::EventThreadsSharedState; use crate::shared_state::SharedState; +use proxy_agent_shared::common_state::CommonState; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::misc_helpers; use proxy_agent_shared::proxy_agent_aggregate_status::ModuleState; @@ -110,8 +111,10 @@ pub struct Redirector { key_keeper_shared_state: KeyKeeperSharedState, agent_status_shared_state: AgentStatusSharedState, cancellation_token: CancellationToken, - telemetry_shared_state: TelemetrySharedState, + common_state: CommonState, provision_shared_state: ProvisionSharedState, + access_control_shared_state: AccessControlSharedState, + connection_summary_shared_state: ConnectionSummarySharedState, } impl Redirector { @@ -120,10 +123,12 @@ impl Redirector { local_port, cancellation_token: shared_state.get_cancellation_token(), key_keeper_shared_state: shared_state.get_key_keeper_shared_state(), - telemetry_shared_state: shared_state.get_telemetry_shared_state(), + common_state: shared_state.get_common_state(), provision_shared_state: shared_state.get_provision_shared_state(), agent_status_shared_state: 
shared_state.get_agent_status_shared_state(), redirector_shared_state: shared_state.get_redirector_shared_state(), + access_control_shared_state: shared_state.get_access_control_shared_state(), + connection_summary_shared_state: shared_state.get_connection_summary_shared_state(), } } @@ -186,49 +191,9 @@ impl Redirector { logger::write_information(format!( "Success updated bpf skip_process map with pid={pid}." )); - let wireserver_mode = - if let Ok(Some(rules)) = self.key_keeper_shared_state.get_wireserver_rules().await { - rules.mode - } else { - AuthorizationMode::Audit - }; - if wireserver_mode != AuthorizationMode::Disabled { - bpf_object.update_policy_elem_bpf_map( - "WireServer endpoints", - self.local_port, - constants::WIRE_SERVER_IP_NETWORK_BYTE_ORDER, //0x10813FA8 - 168.63.129.16 - constants::WIRE_SERVER_PORT, - )?; - logger::write_information( - "Success updated bpf map for WireServer support.".to_string(), - ); - } - let imds_mode = if let Ok(Some(rules)) = self.key_keeper_shared_state.get_imds_rules().await - { - rules.mode - } else { - AuthorizationMode::Audit - }; - if imds_mode != AuthorizationMode::Disabled { - bpf_object.update_policy_elem_bpf_map( - "IMDS endpoints", - self.local_port, - constants::IMDS_IP_NETWORK_BYTE_ORDER, //0xFEA9FEA9, // 169.254.169.254 - constants::IMDS_PORT, - )?; - logger::write_information("Success updated bpf map for IMDS support.".to_string()); - } - if config::get_host_gaplugin_support() > 0 { - bpf_object.update_policy_elem_bpf_map( - "Host GAPlugin endpoints", - self.local_port, - constants::GA_PLUGIN_IP_NETWORK_BYTE_ORDER, //0x10813FA8, // 168.63.129.16 - constants::GA_PLUGIN_PORT, - )?; - logger::write_information( - "Success updated bpf map for Host GAPlugin support.".to_string(), - ); - } + + // Do not update redirect policy map here, it will be updated by provision module + // When provision is finished, it will call update_xxx_redirect_policy functions to update the redirect policy maps. 
// programs self.attach_bpf_prog(&mut bpf_object)?; @@ -272,13 +237,16 @@ impl Redirector { } // report redirector ready for provision - provision::redirector_ready( - self.cancellation_token.clone(), - self.key_keeper_shared_state.clone(), - self.telemetry_shared_state.clone(), - self.provision_shared_state.clone(), - self.agent_status_shared_state.clone(), - ) + provision::redirector_ready(EventThreadsSharedState { + cancellation_token: self.cancellation_token.clone(), + common_state: self.common_state.clone(), + access_control_shared_state: self.access_control_shared_state.clone(), + redirector_shared_state: self.redirector_shared_state.clone(), + key_keeper_shared_state: self.key_keeper_shared_state.clone(), + provision_shared_state: self.provision_shared_state.clone(), + agent_status_shared_state: self.agent_status_shared_state.clone(), + connection_summary_shared_state: self.connection_summary_shared_state.clone(), + }) .await; Ok(()) diff --git a/proxy_agent/src/redirector/linux.rs b/proxy_agent/src/redirector/linux.rs index f973428e..cf5e1776 100644 --- a/proxy_agent/src/redirector/linux.rs +++ b/proxy_agent/src/redirector/linux.rs @@ -279,7 +279,7 @@ impl BpfObject { dest_port: u16, local_port: u16, redirect: bool, - ) { + ) -> bool { let policy_map_name = "policy_map"; match self.0.map_mut(policy_map_name) { Some(map) => match HashMap::<&mut MapData, [u32; 6], [u32; 6]>::try_from(map) { @@ -301,7 +301,7 @@ impl BpfObject { ); } Err(err) => { - logger::write(format!("Failed to remove destination: {}:{} from policy_map with error: {}", ip_to_string(dest_ipv4), dest_port, err)); + logger::write(format!("Failed to remove destination: {}:{} from policy_map with error: {}. 
The policy_map may not contain this entry, skip and continue.", ip_to_string(dest_ipv4), dest_port, err)); } }; } else { @@ -319,22 +319,26 @@ impl BpfObject { let local_ip: u32 = super::string_to_ip(&local_ip); let value = destination_entry::from_ipv4(local_ip, local_port); match policy_map.insert(key.to_array(), value.to_array(), 0) { - Ok(_) => event_logger::write_event( - LoggerLevel::Info, - format!( - "policy_map updated for destination: {}:{}", - ip_to_string(dest_ipv4), - dest_port - ), - "update_redirect_policy_internal", - "redirector/linux", - logger::AGENT_LOGGER_KEY, - ), + Ok(_) => { + event_logger::write_event( + LoggerLevel::Info, + format!( + "policy_map updated for destination: {}:{}", + ip_to_string(dest_ipv4), + dest_port + ), + "update_redirect_policy_internal", + "redirector/linux", + logger::AGENT_LOGGER_KEY, + ); + return true; + } Err(err) => { - logger::write(format!("Failed to insert destination: {}:{} to policy_map with error: {}", ip_to_string(dest_ipv4), dest_port, err)); + logger::write_error(format!("Failed to insert destination: {}:{} to policy_map with error: {}", ip_to_string(dest_ipv4), dest_port, err)); } } } + return true; } Err(err) => { logger::write(format!( @@ -346,6 +350,8 @@ impl BpfObject { logger::write("Failed to get map 'policy_map'.".to_string()); } } + + false } pub fn remove_audit_map_entry(&mut self, source_port: u16) -> Result<()> { @@ -425,52 +431,56 @@ impl super::Redirector { pub async fn update_wire_server_redirect_policy( redirect: bool, redirector_shared_state: RedirectorSharedState, -) { +) -> bool { if let (Ok(Some(bpf_object)), Ok(local_port)) = ( redirector_shared_state.get_bpf_object().await, redirector_shared_state.get_local_port().await, ) { - bpf_object.lock().unwrap().update_redirect_policy( + return bpf_object.lock().unwrap().update_redirect_policy( constants::WIRE_SERVER_IP_NETWORK_BYTE_ORDER, constants::WIRE_SERVER_PORT, local_port, redirect, ); } + + false } pub async fn 
update_imds_redirect_policy( redirect: bool, redirector_shared_state: RedirectorSharedState, -) { +) -> bool { if let (Ok(Some(bpf_object)), Ok(local_port)) = ( redirector_shared_state.get_bpf_object().await, redirector_shared_state.get_local_port().await, ) { - bpf_object.lock().unwrap().update_redirect_policy( + return bpf_object.lock().unwrap().update_redirect_policy( constants::IMDS_IP_NETWORK_BYTE_ORDER, constants::IMDS_PORT, local_port, redirect, ); } + false } pub async fn update_hostga_redirect_policy( redirect: bool, redirector_shared_state: RedirectorSharedState, -) { +) -> bool { if let (Ok(Some(bpf_object)), Ok(local_port)) = ( redirector_shared_state.get_bpf_object().await, redirector_shared_state.get_local_port().await, ) { - bpf_object.lock().unwrap().update_redirect_policy( + return bpf_object.lock().unwrap().update_redirect_policy( constants::GA_PLUGIN_IP_NETWORK_BYTE_ORDER, constants::GA_PLUGIN_PORT, local_port, redirect, ); } + false } #[cfg(test)] diff --git a/proxy_agent/src/redirector/windows.rs b/proxy_agent/src/redirector/windows.rs index 6b9118ec..3f61902a 100644 --- a/proxy_agent/src/redirector/windows.rs +++ b/proxy_agent/src/redirector/windows.rs @@ -144,7 +144,7 @@ pub fn get_audit_from_redirect_context(raw_socket_id: usize) -> Result bool { if let Ok(Some(bpf_object)) = redirector_shared_state.get_bpf_object().await { if redirect { if let Ok(local_port) = redirector_shared_state.get_local_port().await { @@ -173,13 +173,16 @@ pub async fn update_wire_server_redirect_policy( } else { logger::write("Success deleted bpf map for wireserver redirect policy.".to_string()); } + true + } else { + false } } pub async fn update_imds_redirect_policy( redirect: bool, redirector_shared_state: RedirectorSharedState, -) { +) -> bool { if let Ok(Some(bpf_object)) = redirector_shared_state.get_bpf_object().await { if redirect { if let Ok(local_port) = redirector_shared_state.get_local_port().await { @@ -207,13 +210,16 @@ pub async fn 
update_imds_redirect_policy( } else { logger::write("Success deleted bpf map for IMDS redirect policy.".to_string()); } + true + } else { + false } } pub async fn update_hostga_redirect_policy( redirect: bool, redirector_shared_state: RedirectorSharedState, -) { +) -> bool { if let Ok(Some(bpf_object)) = redirector_shared_state.get_bpf_object().await { if redirect { if let Ok(local_port) = redirector_shared_state.get_local_port().await { @@ -242,5 +248,9 @@ pub async fn update_hostga_redirect_policy( } else { logger::write("Success deleted bpf map for HostGAPlugin redirect policy.".to_string()); } + + true + } else { + false } } diff --git a/proxy_agent/src/service.rs b/proxy_agent/src/service.rs index a56e8a32..23e8011e 100644 --- a/proxy_agent/src/service.rs +++ b/proxy_agent/src/service.rs @@ -9,12 +9,13 @@ use crate::proxy::proxy_connection::ConnectionLogger; use crate::proxy::proxy_server::ProxyServer; use crate::redirector::{self, Redirector}; use crate::shared_state::SharedState; +use proxy_agent_shared::current_info; use proxy_agent_shared::logger::rolling_logger::RollingLogger; use proxy_agent_shared::logger::{logger_manager, LoggerLevel}; use proxy_agent_shared::proxy_agent_aggregate_status; use proxy_agent_shared::telemetry::event_logger; - use std::path::PathBuf; + #[cfg(not(windows))] use std::time::Duration; @@ -46,8 +47,8 @@ pub async fn start_service(shared_state: SharedState) { let start_message = format!( "============== GuestProxyAgent ({}) is starting on {}({}), elapsed: {}", proxy_agent_shared::misc_helpers::get_current_version(), - helpers::get_long_os_version(), - helpers::get_cpu_arch(), + current_info::get_long_os_version(), + current_info::get_cpu_arch(), helpers::get_elapsed_time_in_millisec() ); logger::write_information(start_message.clone()); diff --git a/proxy_agent/src/shared_state.rs b/proxy_agent/src/shared_state.rs index 4615552a..58553849 100644 --- a/proxy_agent/src/shared_state.rs +++ b/proxy_agent/src/shared_state.rs @@ -1,13 
+1,15 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT +pub mod access_control_wrapper; pub mod agent_status_wrapper; +pub mod connection_summary_wrapper; pub mod key_keeper_wrapper; pub mod provision_wrapper; pub mod proxy_server_wrapper; pub mod redirector_wrapper; -pub mod telemetry_wrapper; +use proxy_agent_shared::common_state::CommonState; use tokio_util::sync::CancellationToken; const UNKNOWN_STATUS_MESSAGE: &str = "Status unknown."; @@ -22,7 +24,7 @@ const UNKNOWN_STATUS_MESSAGE: &str = "Status unknown."; /// use proxy_agent::shared_state::SharedState; /// let shared_state = SharedState::start_all(); /// let key_keeper_shared_state = shared_state.get_key_keeper_shared_state(); -/// let telemetry_shared_state = shared_state.get_telemetry_shared_state(); +/// let common_state = shared_state.get_common_state(); /// let provision_shared_state = shared_state.get_provision_shared_state(); /// let agent_status_shared_state = shared_state.get_agent_status_shared_state(); /// let redirector_shared_state = shared_state.get_redirector_shared_state(); @@ -34,10 +36,10 @@ const UNKNOWN_STATUS_MESSAGE: &str = "Status unknown."; pub struct SharedState { /// The cancellation token is used to cancel the agent when the agent is stopped cancellation_token: CancellationToken, + /// The sender for the common states + common_state: proxy_agent_shared::common_state::CommonState, /// The sender for the key keeper module key_keeper_shared_state: key_keeper_wrapper::KeyKeeperSharedState, - /// The sender for the telemetry event modules - telemetry_shared_state: telemetry_wrapper::TelemetrySharedState, /// The sender for the provision module provision_shared_state: provision_wrapper::ProvisionSharedState, /// The sender for the agent status module @@ -46,6 +48,10 @@ pub struct SharedState { redirector_shared_state: redirector_wrapper::RedirectorSharedState, /// The sender for the proxy server module proxy_server_shared_state: 
proxy_server_wrapper::ProxyServerSharedState, + /// The sender for the access control module + access_control_shared_state: access_control_wrapper::AccessControlSharedState, + /// The sender for the connection summary module + connection_summary_shared_state: connection_summary_wrapper::ConnectionSummarySharedState, } impl SharedState { @@ -53,11 +59,15 @@ impl SharedState { SharedState { cancellation_token: CancellationToken::new(), key_keeper_shared_state: key_keeper_wrapper::KeyKeeperSharedState::start_new(), - telemetry_shared_state: telemetry_wrapper::TelemetrySharedState::start_new(), + common_state: CommonState::start_new(), provision_shared_state: provision_wrapper::ProvisionSharedState::start_new(), agent_status_shared_state: agent_status_wrapper::AgentStatusSharedState::start_new(), redirector_shared_state: redirector_wrapper::RedirectorSharedState::start_new(), proxy_server_shared_state: proxy_server_wrapper::ProxyServerSharedState::start_new(), + access_control_shared_state: + access_control_wrapper::AccessControlSharedState::start_new(), + connection_summary_shared_state: + connection_summary_wrapper::ConnectionSummarySharedState::start_new(), } } @@ -65,8 +75,8 @@ impl SharedState { self.key_keeper_shared_state.clone() } - pub fn get_telemetry_shared_state(&self) -> telemetry_wrapper::TelemetrySharedState { - self.telemetry_shared_state.clone() + pub fn get_common_state(&self) -> CommonState { + self.common_state.clone() } pub fn get_provision_shared_state(&self) -> provision_wrapper::ProvisionSharedState { @@ -85,6 +95,18 @@ impl SharedState { self.proxy_server_shared_state.clone() } + pub fn get_access_control_shared_state( + &self, + ) -> access_control_wrapper::AccessControlSharedState { + self.access_control_shared_state.clone() + } + + pub fn get_connection_summary_shared_state( + &self, + ) -> connection_summary_wrapper::ConnectionSummarySharedState { + self.connection_summary_shared_state.clone() + } + pub fn get_cancellation_token(&self) -> 
CancellationToken { self.cancellation_token.clone() } @@ -93,3 +115,34 @@ impl SharedState { self.cancellation_token.cancel(); } } + +/// The shared state for the lower priority event threads, including event logger & reader tasks and status reporting task +/// It contains the cancellation token, which is used to cancel the event threads when the agent is stopped. +/// It also contains the senders for the key keeper, provision, agent status, and connection summary modules. +/// This struct contains multiple shared states to avoid too_many_arguments error from `cargo clippy`. +#[derive(Clone)] +pub struct EventThreadsSharedState { + pub cancellation_token: CancellationToken, + pub common_state: CommonState, + pub access_control_shared_state: access_control_wrapper::AccessControlSharedState, + pub redirector_shared_state: redirector_wrapper::RedirectorSharedState, + pub key_keeper_shared_state: key_keeper_wrapper::KeyKeeperSharedState, + pub provision_shared_state: provision_wrapper::ProvisionSharedState, + pub agent_status_shared_state: agent_status_wrapper::AgentStatusSharedState, + pub connection_summary_shared_state: connection_summary_wrapper::ConnectionSummarySharedState, +} + +impl EventThreadsSharedState { + pub fn new(shared_state: &SharedState) -> Self { + EventThreadsSharedState { + cancellation_token: shared_state.get_cancellation_token(), + common_state: shared_state.get_common_state(), + access_control_shared_state: shared_state.get_access_control_shared_state(), + redirector_shared_state: shared_state.get_redirector_shared_state(), + key_keeper_shared_state: shared_state.get_key_keeper_shared_state(), + provision_shared_state: shared_state.get_provision_shared_state(), + agent_status_shared_state: shared_state.get_agent_status_shared_state(), + connection_summary_shared_state: shared_state.get_connection_summary_shared_state(), + } + } +} diff --git a/proxy_agent/src/shared_state/access_control_wrapper.rs 
b/proxy_agent/src/shared_state/access_control_wrapper.rs new file mode 100644 index 00000000..8f3749a6 --- /dev/null +++ b/proxy_agent/src/shared_state/access_control_wrapper.rs @@ -0,0 +1,262 @@ +// Copyright (c) Microsoft Corporation +// SPDX-License-Identifier: MIT + +use crate::common::error::Error; +use crate::common::logger; +use crate::common::result::Result; +use crate::key_keeper::key::AuthorizationItem; +use crate::proxy::authorization_rules::ComputedAuthorizationItem; +use tokio::sync::{mpsc, oneshot}; + +/// The AccessControlAction enum represents the actions that can be performed on the Access Control +enum AccessControlAction { + SetWireServer { + rules: Option, + response: oneshot::Sender<()>, + }, + GetWireServer { + response: oneshot::Sender>, + }, + SetImds { + rules: Option, + response: oneshot::Sender<()>, + }, + GetImds { + response: oneshot::Sender>, + }, + SetHostGA { + rules: Option, + response: oneshot::Sender<()>, + }, + GetHostGA { + response: oneshot::Sender>, + }, +} + +/// The AccessControlState struct is used to send actions to the Access Control Rules related shared state fields +#[derive(Clone, Debug)] +pub struct AccessControlSharedState(mpsc::Sender); + +impl AccessControlSharedState { + pub fn start_new() -> Self { + let (sender, mut receiver) = mpsc::channel(100); + + tokio::spawn(async move { + // The authorization rules for the WireServer endpoints + let mut wireserver_rules: Option = None; + // The authorization rules for the IMDS endpoints + let mut imds_rules: Option = None; + // The authorization rules for the HostGAPlugin endpoints + let mut hostga_rules: Option = None; + loop { + match receiver.recv().await { + Some(AccessControlAction::SetWireServer { rules, response }) => { + wireserver_rules = rules; + if response.send(()).is_err() { + logger::write_warning( + "Failed to send response to AccessControlAction::SetWireServer" + .to_string(), + ); + } + } + Some(AccessControlAction::GetWireServer { response }) => { + if 
response.send(wireserver_rules.clone()).is_err() { + logger::write_warning( + "Failed to send response to AccessControlAction::GetWireServer" + .to_string(), + ); + } + } + Some(AccessControlAction::SetImds { rules, response }) => { + imds_rules = rules; + if response.send(()).is_err() { + logger::write_warning( + "Failed to send response to AccessControlAction::SetImds" + .to_string(), + ); + } + } + Some(AccessControlAction::GetImds { response }) => { + if response.send(imds_rules.clone()).is_err() { + logger::write_warning( + "Failed to send response to AccessControlAction::GetImds" + .to_string(), + ); + } + } + Some(AccessControlAction::SetHostGA { rules, response }) => { + hostga_rules = rules; + if response.send(()).is_err() { + logger::write_warning( + "Failed to send response to AccessControlAction::SetHostGA" + .to_string(), + ); + } + } + Some(AccessControlAction::GetHostGA { response }) => { + if response.send(hostga_rules.clone()).is_err() { + logger::write_warning( + "Failed to send response to AccessControlAction::GetHostGA" + .to_string(), + ); + } + } + None => break, + } + } + }); + + Self(sender) + } + + pub async fn set_wireserver_rules(&self, rules: Option) -> Result<()> { + let (response, receiver) = oneshot::channel(); + self.0 + .send(AccessControlAction::SetWireServer { + rules: rules.map(ComputedAuthorizationItem::from_authorization_item), + response, + }) + .await + .map_err(|e| { + Error::SendError( + "AccessControlAction::SetWireServer".to_string(), + e.to_string(), + ) + })?; + receiver + .await + .map_err(|e| Error::RecvError("AccessControlAction::GetWireServer".to_string(), e)) + } + + pub async fn get_wireserver_rules(&self) -> Result> { + let (response, receiver) = oneshot::channel(); + self.0 + .send(AccessControlAction::GetWireServer { response }) + .await + .map_err(|e| { + Error::SendError( + "AccessControlAction::GetWireServer".to_string(), + e.to_string(), + ) + })?; + receiver + .await + .map_err(|e| 
Error::RecvError("AccessControlAction::GetWireServer".to_string(), e)) + } + + pub async fn set_imds_rules(&self, rules: Option) -> Result<()> { + let (response, receiver) = oneshot::channel(); + self.0 + .send(AccessControlAction::SetImds { + rules: rules.map(ComputedAuthorizationItem::from_authorization_item), + response, + }) + .await + .map_err(|e| { + Error::SendError("AccessControlAction::SetImds".to_string(), e.to_string()) + })?; + receiver + .await + .map_err(|e| Error::RecvError("AccessControlAction::SetImds".to_string(), e)) + } + + pub async fn get_imds_rules(&self) -> Result> { + let (response, receiver) = oneshot::channel(); + self.0 + .send(AccessControlAction::GetImds { response }) + .await + .map_err(|e| { + Error::SendError("AccessControlAction::GetImds".to_string(), e.to_string()) + })?; + receiver + .await + .map_err(|e| Error::RecvError("AccessControlAction::GetImds".to_string(), e)) + } + + pub async fn set_hostga_rules(&self, rules: Option) -> Result<()> { + let (response, receiver) = oneshot::channel(); + self.0 + .send(AccessControlAction::SetHostGA { + rules: rules.map(ComputedAuthorizationItem::from_authorization_item), + response, + }) + .await + .map_err(|e| { + Error::SendError("AccessControlAction::SetHostGA".to_string(), e.to_string()) + })?; + receiver + .await + .map_err(|e| Error::RecvError("AccessControlAction::GetHostGA".to_string(), e)) + } + + pub async fn get_hostga_rules(&self) -> Result> { + let (response, receiver) = oneshot::channel(); + self.0 + .send(AccessControlAction::GetHostGA { response }) + .await + .map_err(|e| { + Error::SendError("AccessControlAction::GetHostGA".to_string(), e.to_string()) + })?; + receiver + .await + .map_err(|e| Error::RecvError("AccessControlAction::GetHostGA".to_string(), e)) + } +} + +#[cfg(test)] +mod tests { + use crate::proxy::authorization_rules; + + use super::*; + + #[tokio::test] + async fn test_access_control_shared_state() { + let access_control = 
AccessControlSharedState::start_new(); + + // test WireServer Rule + let rule_id = "test_rule_id"; + let rules = AuthorizationItem { + defaultAccess: "allow".to_string(), + mode: "audit".to_string(), + id: rule_id.to_string(), + rules: None, + }; + access_control + .set_wireserver_rules(Some(rules.clone())) + .await + .unwrap(); + let retrieved_rules = access_control.get_wireserver_rules().await.unwrap(); + assert!(retrieved_rules.is_some()); + let retrieved_rules = retrieved_rules.unwrap(); + assert_eq!(rules.id, retrieved_rules.id); + assert_eq!(true, retrieved_rules.defaultAllowed); + assert_eq!( + authorization_rules::AuthorizationMode::Audit, + retrieved_rules.mode + ); + assert_eq!(0, retrieved_rules.privilegeAssignments.len()); + assert_eq!(0, retrieved_rules.privileges.len()); + assert_eq!(0, retrieved_rules.identities.len()); + + // test IMDS Rule + let rules = AuthorizationItem { + defaultAccess: "deny".to_string(), + mode: "enforce".to_string(), + id: rule_id.to_string(), + rules: None, + }; + access_control + .set_imds_rules(Some(rules.clone())) + .await + .unwrap(); + let retrieved_rules = access_control.get_imds_rules().await.unwrap(); + assert!(retrieved_rules.is_some()); + let retrieved_rules = retrieved_rules.unwrap(); + assert_eq!(rules.id, retrieved_rules.id); + assert_eq!(false, retrieved_rules.defaultAllowed); + assert_eq!( + authorization_rules::AuthorizationMode::Enforce, + retrieved_rules.mode + ); + } +} diff --git a/proxy_agent/src/shared_state/agent_status_wrapper.rs b/proxy_agent/src/shared_state/agent_status_wrapper.rs index e7f55aa0..3be23d0e 100644 --- a/proxy_agent/src/shared_state/agent_status_wrapper.rs +++ b/proxy_agent/src/shared_state/agent_status_wrapper.rs @@ -3,19 +3,14 @@ //! This module contains the logic to interact with the proxy agent status. //! The proxy agent status contains the 'state' and 'status message' of the key keeper, telemetry reader, telemetry logger, redirector, and proxy server modules. -//! 
The proxy agent status contains the 'connection summary' of the proxy server. -//! The proxy agent status contains the 'failed connection summary' of the proxy server. //! The proxy agent status contains the 'connection count' of the proxy server. +use crate::common::error::Error; use crate::common::logger; use crate::common::result::Result; -use crate::{common::error::Error, proxy::proxy_summary::ProxySummary}; use proxy_agent_shared::logger::LoggerLevel; -use proxy_agent_shared::proxy_agent_aggregate_status::{ - ModuleState, ProxyAgentDetailStatus, ProxyConnectionSummary, -}; +use proxy_agent_shared::proxy_agent_aggregate_status::{ModuleState, ProxyAgentDetailStatus}; use proxy_agent_shared::telemetry::event_logger; -use std::collections::{hash_map, HashMap}; use tokio::sync::{mpsc, oneshot}; const MAX_STATUS_MESSAGE_LENGTH: usize = 1024; @@ -39,23 +34,6 @@ enum AgentStatusAction { module: AgentStatusModule, response: oneshot::Sender, }, - AddOneConnectionSummary { - summary: ProxySummary, - response: oneshot::Sender<()>, - }, - AddOneFailedConnectionSummary { - summary: ProxySummary, - response: oneshot::Sender<()>, - }, - GetAllConnectionSummary { - response: oneshot::Sender>, - }, - GetAllFailedConnectionSummary { - response: oneshot::Sender>, - }, - ClearAllSummary { - response: oneshot::Sender<()>, - }, GetConnectionCount { response: oneshot::Sender, }, @@ -97,11 +75,6 @@ impl AgentStatusSharedState { let mut proxy_agent_status_state = ModuleState::UNKNOWN; let mut proxy_agent_status_message = super::UNKNOWN_STATUS_MESSAGE.to_string(); - // The proxy connection summary from the proxy - let mut proxy_summary: HashMap = HashMap::new(); - // The failed authenticate summary from the proxy - let mut failed_authenticate_summary: HashMap = - HashMap::new(); // The proxied connection count for the listener let mut tcp_connection_count: u128 = 0; let mut http_connection_count: u128 = 0; @@ -227,68 +200,6 @@ impl AgentStatusSharedState { )); } } - 
AgentStatusAction::AddOneConnectionSummary { summary, response } => { - let key = summary.to_key_string(); - if let hash_map::Entry::Vacant(e) = proxy_summary.entry(key.clone()) { - e.insert(summary.into()); - } else if let Some(connection_summary) = proxy_summary.get_mut(&key) { - //increase_count(connection_summary); - connection_summary.count += 1; - } - if response.send(()).is_err() { - logger::write_warning("Failed to send response to AgentStatusAction::AddOneConnectionSummary".to_string()); - } - } - AgentStatusAction::AddOneFailedConnectionSummary { summary, response } => { - let key = summary.to_key_string(); - if let hash_map::Entry::Vacant(e) = - failed_authenticate_summary.entry(key.clone()) - { - e.insert(summary.into()); - } else if let Some(connection_summary) = - failed_authenticate_summary.get_mut(&key) - { - //increase_count(connection_summary); - connection_summary.count += 1; - } - if response.send(()).is_err() { - logger::write_warning("Failed to send response to AgentStatusAction::AddOneFailedConnectionSummary".to_string()); - } - } - AgentStatusAction::GetAllConnectionSummary { response } => { - let mut copy_summary: Vec = Vec::new(); - for (_, connection_summary) in proxy_summary.iter() { - copy_summary.push(connection_summary.clone()); - } - if let Err(summary) = response.send(copy_summary) { - logger::write_warning(format!( - "Failed to send response to AgentStatusAction::GetAllConnectionSummary with summary count '{:?}'", - summary.len() - )); - } - } - AgentStatusAction::GetAllFailedConnectionSummary { response } => { - let mut copy_summary: Vec = Vec::new(); - for (_, connection_summary) in failed_authenticate_summary.iter() { - copy_summary.push(connection_summary.clone()); - } - if let Err(summary) = response.send(copy_summary) { - logger::write_warning(format!( - "Failed to send response to AgentStatusAction::GetAllFailedConnectionSummary with summary count '{:?}'", - summary.len() - )); - } - } - AgentStatusAction::ClearAllSummary { 
response } => { - proxy_summary.clear(); - failed_authenticate_summary.clear(); - if response.send(()).is_err() { - logger::write_warning( - "Failed to send response to AgentStatusAction::ClearAllSummary" - .to_string(), - ); - } - } AgentStatusAction::GetConnectionCount { response } => { if let Err(count) = response.send(http_connection_count) { logger::write_warning(format!( @@ -321,105 +232,6 @@ impl AgentStatusSharedState { AgentStatusSharedState(tx) } - pub async fn add_one_connection_summary(&self, summary: ProxySummary) -> Result<()> { - let (response_tx, response_rx) = oneshot::channel(); - self.0 - .send(AgentStatusAction::AddOneConnectionSummary { - summary, - response: response_tx, - }) - .await - .map_err(|e| { - Error::SendError( - "AgentStatusAction::AddOneConnectionSummary".to_string(), - e.to_string(), - ) - })?; - response_rx.await.map_err(|e| { - Error::RecvError("AgentStatusAction::AddOneConnectionSummary".to_string(), e) - }) - } - - pub async fn add_one_failed_connection_summary(&self, summary: ProxySummary) -> Result<()> { - let (response_tx, response_rx) = oneshot::channel(); - self.0 - .send(AgentStatusAction::AddOneFailedConnectionSummary { - summary, - response: response_tx, - }) - .await - .map_err(|e| { - Error::SendError( - "AgentStatusAction::AddOneFailedConnectionSummary".to_string(), - e.to_string(), - ) - })?; - response_rx.await.map_err(|e| { - Error::RecvError( - "AgentStatusAction::AddOneFailedConnectionSummary".to_string(), - e, - ) - }) - } - - pub async fn clear_all_summary(&self) -> Result<()> { - let (response_tx, response_rx) = oneshot::channel(); - self.0 - .send(AgentStatusAction::ClearAllSummary { - response: response_tx, - }) - .await - .map_err(|e| { - Error::SendError( - "AgentStatusAction::ClearAllSummary".to_string(), - e.to_string(), - ) - })?; - response_rx - .await - .map_err(|e| Error::RecvError("AgentStatusAction::ClearAllSummary".to_string(), e))?; - Ok(()) - } - - pub async fn 
get_all_connection_summary(&self) -> Result> { - let (response_tx, response_rx) = oneshot::channel(); - self.0 - .send(AgentStatusAction::GetAllConnectionSummary { - response: response_tx, - }) - .await - .map_err(|e| { - Error::SendError( - "AgentStatusAction::GetAllConnectionSummary".to_string(), - e.to_string(), - ) - })?; - response_rx.await.map_err(|e| { - Error::RecvError("AgentStatusAction::GetAllConnectionSummary".to_string(), e) - }) - } - - pub async fn get_all_failed_connection_summary(&self) -> Result> { - let (response_tx, response_rx) = oneshot::channel(); - self.0 - .send(AgentStatusAction::GetAllFailedConnectionSummary { - response: response_tx, - }) - .await - .map_err(|e| { - Error::SendError( - "AgentStatusAction::GetAllFailedConnectionSummary".to_string(), - e.to_string(), - ) - })?; - response_rx.await.map_err(|e| { - Error::RecvError( - "AgentStatusAction::GetAllFailedConnectionSummary".to_string(), - e, - ) - }) - } - async fn get_module_state(&self, module: AgentStatusModule) -> Result { let (response_tx, response_rx) = oneshot::channel(); self.0 @@ -629,9 +441,8 @@ impl AgentStatusSharedState { #[cfg(test)] mod tests { use super::*; - use crate::{proxy::proxy_summary::ProxySummary, shared_state}; + use crate::shared_state; use proxy_agent_shared::proxy_agent_aggregate_status::ModuleState; - use std::path::PathBuf; #[tokio::test] async fn test_agent_status_shared_state() { @@ -697,82 +508,10 @@ mod tests { .unwrap(); assert_eq!(1, connection_count); - let connection_summary = ProxySummary { - id: connection_id, - method: "GET".to_string(), - url: "/status".to_string(), - clientIp: "127.0.0.1".to_string(), - clientPort: 6080, - ip: "127.0.0.1".to_string(), - port: 8080, - userId: 999, - userName: "user1".to_string(), - userGroups: vec!["group1".to_string()], - processFullPath: PathBuf::from("C:\\path\\to\\process.exe"), - processCmdLine: "process --arg1 --arg2".to_string(), - runAsElevated: true, - responseStatus: "200 OK".to_string(), - 
elapsedTime: 123, - errorDetails: "".to_string(), - }; - agent_status_shared_state - .add_one_connection_summary(connection_summary.clone()) - .await - .unwrap(); - let get_all_connection_summary = agent_status_shared_state - .get_all_connection_summary() - .await - .unwrap(); - assert_eq!(1, get_all_connection_summary.len()); - assert_eq!(1, get_all_connection_summary[0].count); - let connection_id = agent_status_shared_state .increase_connection_count() .await .unwrap(); assert_eq!(2, connection_id); - - let failed_connection_summary = ProxySummary { - id: connection_id, - method: "GET".to_string(), - url: "/status".to_string(), - clientIp: "127.0.0.1".to_string(), - clientPort: 6080, - ip: "127.0.0.1".to_string(), - port: 8080, - userId: 999, - userName: "user1".to_string(), - userGroups: vec!["group1".to_string()], - processFullPath: PathBuf::from("C:\\path\\to\\process.exe"), - processCmdLine: "process --arg1 --arg2".to_string(), - runAsElevated: true, - responseStatus: "500 Internal Server Error".to_string(), - elapsedTime: 123, - errorDetails: "Some error occurred".to_string(), - }; - agent_status_shared_state - .add_one_failed_connection_summary(failed_connection_summary.clone()) - .await - .unwrap(); - let get_all_failed_connection_summary = agent_status_shared_state - .get_all_failed_connection_summary() - .await - .unwrap(); - assert_eq!(1, get_all_failed_connection_summary.len()); - - // clear all summaries - agent_status_shared_state.clear_all_summary().await.unwrap(); - let get_all_connection_summary = agent_status_shared_state - .get_all_connection_summary() - .await - .unwrap(); - assert_eq!(0, get_all_connection_summary.len()); - - // connection count should not be reset - let connection_id = agent_status_shared_state - .increase_connection_count() - .await - .unwrap(); - assert_eq!(3, connection_id); } } diff --git a/proxy_agent/src/shared_state/connection_summary_wrapper.rs b/proxy_agent/src/shared_state/connection_summary_wrapper.rs new file 
mode 100644 index 00000000..dd472c78 --- /dev/null +++ b/proxy_agent/src/shared_state/connection_summary_wrapper.rs @@ -0,0 +1,297 @@ +// Copyright (c) Microsoft Corporation +// SPDX-License-Identifier: MIT + +//! This module contains the logic to interact with the connection summary status. +//! The proxy agent status contains the 'connection summary' of the proxy server. +//! The proxy agent status contains the 'failed connection summary' of the proxy server. + +use crate::common::logger; +use crate::common::result::Result; +use crate::{common::error::Error, proxy::proxy_summary::ProxySummary}; +use proxy_agent_shared::proxy_agent_aggregate_status::ProxyConnectionSummary; +use std::collections::{hash_map, HashMap}; +use tokio::sync::{mpsc, oneshot}; + +enum ConnectionSummaryAction { + AddOneConnection { + summary: ProxySummary, + response: oneshot::Sender<()>, + }, + AddOneFailedConnection { + summary: ProxySummary, + response: oneshot::Sender<()>, + }, + GetAllConnection { + response: oneshot::Sender>, + }, + GetAllFailedConnection { + response: oneshot::Sender>, + }, + ClearAll { + response: oneshot::Sender<()>, + }, +} + +#[derive(Clone, Debug)] +pub struct ConnectionSummarySharedState(mpsc::Sender); + +impl ConnectionSummarySharedState { + pub fn start_new() -> Self { + let (tx, mut rx) = mpsc::channel(100); + tokio::spawn(async move { + // The proxy connection summary from the proxy + let mut proxy_summary: HashMap = HashMap::new(); + // The failed authenticate summary from the proxy + let mut failed_authenticate_summary: HashMap = + HashMap::new(); + + while let Some(action) = rx.recv().await { + match action { + ConnectionSummaryAction::AddOneConnection { summary, response } => { + let key = summary.to_key_string(); + if let hash_map::Entry::Vacant(e) = proxy_summary.entry(key.clone()) { + e.insert(summary.into()); + } else if let Some(connection_summary) = proxy_summary.get_mut(&key) { + //increase_count(connection_summary); + connection_summary.count += 
1; + } + if response.send(()).is_err() { + logger::write_warning("Failed to send response to ConnectionSummaryAction::AddOneConnection".to_string()); + } + } + ConnectionSummaryAction::AddOneFailedConnection { summary, response } => { + let key = summary.to_key_string(); + if let hash_map::Entry::Vacant(e) = + failed_authenticate_summary.entry(key.clone()) + { + e.insert(summary.into()); + } else if let Some(connection_summary) = + failed_authenticate_summary.get_mut(&key) + { + //increase_count(connection_summary); + connection_summary.count += 1; + } + if response.send(()).is_err() { + logger::write_warning("Failed to send response to ConnectionSummaryAction::AddOneFailedConnection".to_string()); + } + } + ConnectionSummaryAction::GetAllConnection { response } => { + let mut copy_summary: Vec = Vec::new(); + for (_, connection_summary) in proxy_summary.iter() { + copy_summary.push(connection_summary.clone()); + } + if let Err(summary) = response.send(copy_summary) { + logger::write_warning(format!( + "Failed to send response to ConnectionSummaryAction::GetAllConnection with summary count '{:?}'", + summary.len() + )); + } + } + ConnectionSummaryAction::GetAllFailedConnection { response } => { + let mut copy_summary: Vec = Vec::new(); + for (_, connection_summary) in failed_authenticate_summary.iter() { + copy_summary.push(connection_summary.clone()); + } + if let Err(summary) = response.send(copy_summary) { + logger::write_warning(format!( + "Failed to send response to ConnectionSummaryAction::GetAllFailedConnection with summary count '{:?}'", + summary.len() + )); + } + } + ConnectionSummaryAction::ClearAll { response } => { + proxy_summary.clear(); + failed_authenticate_summary.clear(); + if response.send(()).is_err() { + logger::write_warning( + "Failed to send response to ConnectionSummaryAction::ClearAll" + .to_string(), + ); + } + } + } + } + }); + + ConnectionSummarySharedState(tx) + } + + pub async fn add_one_connection_summary(&self, summary: 
ProxySummary) -> Result<()> { + let (response_tx, response_rx) = oneshot::channel(); + self.0 + .send(ConnectionSummaryAction::AddOneConnection { + summary, + response: response_tx, + }) + .await + .map_err(|e| { + Error::SendError( + "ConnectionSummaryAction::AddOneConnection".to_string(), + e.to_string(), + ) + })?; + response_rx.await.map_err(|e| { + Error::RecvError("ConnectionSummaryAction::AddOneConnection".to_string(), e) + }) + } + + pub async fn add_one_failed_connection_summary(&self, summary: ProxySummary) -> Result<()> { + let (response_tx, response_rx) = oneshot::channel(); + self.0 + .send(ConnectionSummaryAction::AddOneFailedConnection { + summary, + response: response_tx, + }) + .await + .map_err(|e| { + Error::SendError( + "ConnectionSummaryAction::AddOneFailedConnection".to_string(), + e.to_string(), + ) + })?; + response_rx.await.map_err(|e| { + Error::RecvError( + "ConnectionSummaryAction::AddOneFailedConnection".to_string(), + e, + ) + }) + } + + pub async fn clear_all_summary(&self) -> Result<()> { + let (response_tx, response_rx) = oneshot::channel(); + self.0 + .send(ConnectionSummaryAction::ClearAll { + response: response_tx, + }) + .await + .map_err(|e| { + Error::SendError( + "ConnectionSummaryAction::ClearAll".to_string(), + e.to_string(), + ) + })?; + response_rx + .await + .map_err(|e| Error::RecvError("ConnectionSummaryAction::ClearAll".to_string(), e))?; + Ok(()) + } + + pub async fn get_all_connection_summary(&self) -> Result> { + let (response_tx, response_rx) = oneshot::channel(); + self.0 + .send(ConnectionSummaryAction::GetAllConnection { + response: response_tx, + }) + .await + .map_err(|e| { + Error::SendError( + "ConnectionSummaryAction::GetAllConnection".to_string(), + e.to_string(), + ) + })?; + response_rx.await.map_err(|e| { + Error::RecvError("ConnectionSummaryAction::GetAllConnection".to_string(), e) + }) + } + + pub async fn get_all_failed_connection_summary(&self) -> Result> { + let (response_tx, response_rx) = 
oneshot::channel(); + self.0 + .send(ConnectionSummaryAction::GetAllFailedConnection { + response: response_tx, + }) + .await + .map_err(|e| { + Error::SendError( + "ConnectionSummaryAction::GetAllFailedConnection".to_string(), + e.to_string(), + ) + })?; + response_rx.await.map_err(|e| { + Error::RecvError( + "ConnectionSummaryAction::GetAllFailedConnection".to_string(), + e, + ) + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::proxy::proxy_summary::ProxySummary; + use std::path::PathBuf; + + #[tokio::test] + async fn test_agent_status_shared_state() { + let connection_summary_shared_state = ConnectionSummarySharedState::start_new(); + + let connection_summary = ProxySummary { + id: 1, + method: "GET".to_string(), + url: "/status".to_string(), + clientIp: "127.0.0.1".to_string(), + clientPort: 6080, + ip: "127.0.0.1".to_string(), + port: 8080, + userId: 999, + userName: "user1".to_string(), + userGroups: vec!["group1".to_string()], + processFullPath: PathBuf::from("C:\\path\\to\\process.exe"), + processCmdLine: "process --arg1 --arg2".to_string(), + runAsElevated: true, + responseStatus: "200 OK".to_string(), + elapsedTime: 123, + errorDetails: "".to_string(), + }; + connection_summary_shared_state + .add_one_connection_summary(connection_summary.clone()) + .await + .unwrap(); + let get_all_connection_summary = connection_summary_shared_state + .get_all_connection_summary() + .await + .unwrap(); + assert_eq!(1, get_all_connection_summary.len()); + assert_eq!(1, get_all_connection_summary[0].count); + + let failed_connection_summary = ProxySummary { + id: 2, + method: "GET".to_string(), + url: "/status".to_string(), + clientIp: "127.0.0.1".to_string(), + clientPort: 6080, + ip: "127.0.0.1".to_string(), + port: 8080, + userId: 999, + userName: "user1".to_string(), + userGroups: vec!["group1".to_string()], + processFullPath: PathBuf::from("C:\\path\\to\\process.exe"), + processCmdLine: "process --arg1 --arg2".to_string(), + runAsElevated: true, + 
responseStatus: "500 Internal Server Error".to_string(), + elapsedTime: 123, + errorDetails: "Some error occurred".to_string(), + }; + connection_summary_shared_state + .add_one_failed_connection_summary(failed_connection_summary.clone()) + .await + .unwrap(); + let get_all_failed_connection_summary = connection_summary_shared_state + .get_all_failed_connection_summary() + .await + .unwrap(); + assert_eq!(1, get_all_failed_connection_summary.len()); + + // clear all summaries + connection_summary_shared_state + .clear_all_summary() + .await + .unwrap(); + let get_all_connection_summary = connection_summary_shared_state + .get_all_connection_summary() + .await + .unwrap(); + assert_eq!(0, get_all_connection_summary.len()); + } +} diff --git a/proxy_agent/src/shared_state/key_keeper_wrapper.rs b/proxy_agent/src/shared_state/key_keeper_wrapper.rs index de75b566..4c3b46aa 100644 --- a/proxy_agent/src/shared_state/key_keeper_wrapper.rs +++ b/proxy_agent/src/shared_state/key_keeper_wrapper.rs @@ -26,7 +26,6 @@ /// let state = key_keeper_state.get_current_secure_channel_state().await?; /// let rule_id = key_keeper_state.get_wireserver_rule_id().await?; /// let rule_id = key_keeper_state.get_imds_rule_id().await?; -/// let status_message = key_keeper_state.get_status_message().await?; /// /// // clear the key once the feature is disabled /// key_keeper_state.clear_key().await?; @@ -38,8 +37,6 @@ /// ``` use crate::common::error::Error; use crate::common::result::Result; -use crate::key_keeper::key::AuthorizationItem; -use crate::proxy::authorization_rules::ComputedAuthorizationItem; use crate::{common::logger, key_keeper::key::Key}; use std::sync::Arc; use tokio::sync::{mpsc, oneshot, Notify}; @@ -81,27 +78,6 @@ enum KeyKeeperAction { rule_id: String, response: oneshot::Sender<()>, }, - SetWireServerRules { - rules: Option, - response: oneshot::Sender<()>, - }, - GetWireServerRules { - response: oneshot::Sender>, - }, - SetImdsRules { - rules: Option, - response: 
oneshot::Sender<()>, - }, - GetImdsRules { - response: oneshot::Sender>, - }, - SetHostGARules { - rules: Option, - response: oneshot::Sender<()>, - }, - GetHostGARules { - response: oneshot::Sender>, - }, GetNotify { response: oneshot::Sender>, }, @@ -126,12 +102,6 @@ impl KeyKeeperSharedState { let mut imds_rule_id: String = String::new(); // The rule ID for the HostGA endpoints let mut hostga_rule_id: String = String::new(); - // The authorization rules for the WireServer endpoints - let mut wireserver_rules: Option = None; - // The authorization rules for the IMDS endpoints - let mut imds_rules: Option = None; - // The authorization rules for the HostGAPlugin endpoints - let mut hostga_rules: Option = None; let notify = Arc::new(Notify::new()); loop { @@ -216,57 +186,6 @@ impl KeyKeeperSharedState { )); } } - Some(KeyKeeperAction::SetWireServerRules { rules, response }) => { - wireserver_rules = rules; - if response.send(()).is_err() { - logger::write_warning( - "Failed to send response to KeyKeeperAction::SetWireServerRules" - .to_string(), - ); - } - } - Some(KeyKeeperAction::GetWireServerRules { response }) => { - if response.send(wireserver_rules.clone()).is_err() { - logger::write_warning( - "Failed to send response to KeyKeeperAction::GetWireServerRules" - .to_string(), - ); - } - } - Some(KeyKeeperAction::SetImdsRules { rules, response }) => { - imds_rules = rules; - if response.send(()).is_err() { - logger::write_warning( - "Failed to send response to KeyKeeperAction::SetImdsRules" - .to_string(), - ); - } - } - Some(KeyKeeperAction::GetImdsRules { response }) => { - if response.send(imds_rules.clone()).is_err() { - logger::write_warning( - "Failed to send response to KeyKeeperAction::GetImdsRules" - .to_string(), - ); - } - } - Some(KeyKeeperAction::SetHostGARules { rules, response }) => { - hostga_rules = rules; - if response.send(()).is_err() { - logger::write_warning( - "Failed to send response to KeyKeeperAction::SetHostGARules" - .to_string(), - 
); - } - } - Some(KeyKeeperAction::GetHostGARules { response }) => { - if response.send(hostga_rules.clone()).is_err() { - logger::write_warning( - "Failed to send response to KeyKeeperAction::GetHostGARules" - .to_string(), - ); - } - } Some(KeyKeeperAction::GetNotify { response }) => { if response.send(notify.clone()).is_err() { logger::write_warning( @@ -528,99 +447,6 @@ impl KeyKeeperSharedState { } } - pub async fn set_wireserver_rules(&self, rules: Option) -> Result<()> { - let (response, receiver) = oneshot::channel(); - self.0 - .send(KeyKeeperAction::SetWireServerRules { - rules: rules.map(ComputedAuthorizationItem::from_authorization_item), - response, - }) - .await - .map_err(|e| { - Error::SendError( - "KeyKeeperAction::SetWireServerRules".to_string(), - e.to_string(), - ) - })?; - receiver - .await - .map_err(|e| Error::RecvError("KeyKeeperAction::SetWireServerRules".to_string(), e)) - } - - pub async fn get_wireserver_rules(&self) -> Result> { - let (response, receiver) = oneshot::channel(); - self.0 - .send(KeyKeeperAction::GetWireServerRules { response }) - .await - .map_err(|e| { - Error::SendError( - "KeyKeeperAction::GetWireServerRules".to_string(), - e.to_string(), - ) - })?; - receiver - .await - .map_err(|e| Error::RecvError("KeyKeeperAction::GetWireServerRules".to_string(), e)) - } - - pub async fn set_imds_rules(&self, rules: Option) -> Result<()> { - let (response, receiver) = oneshot::channel(); - self.0 - .send(KeyKeeperAction::SetImdsRules { - rules: rules.map(ComputedAuthorizationItem::from_authorization_item), - response, - }) - .await - .map_err(|e| { - Error::SendError("KeyKeeperAction::SetImdsRules".to_string(), e.to_string()) - })?; - receiver - .await - .map_err(|e| Error::RecvError("KeyKeeperAction::SetImdsRules".to_string(), e)) - } - - pub async fn get_imds_rules(&self) -> Result> { - let (response, receiver) = oneshot::channel(); - self.0 - .send(KeyKeeperAction::GetImdsRules { response }) - .await - .map_err(|e| { - 
Error::SendError("KeyKeeperAction::GetImdsRules".to_string(), e.to_string()) - })?; - receiver - .await - .map_err(|e| Error::RecvError("KeyKeeperAction::GetImdsRules".to_string(), e)) - } - - pub async fn set_hostga_rules(&self, rules: Option) -> Result<()> { - let (response, receiver) = oneshot::channel(); - self.0 - .send(KeyKeeperAction::SetHostGARules { - rules: rules.map(ComputedAuthorizationItem::from_authorization_item), - response, - }) - .await - .map_err(|e| { - Error::SendError("KeyKeeperAction::SetHostGARules".to_string(), e.to_string()) - })?; - receiver - .await - .map_err(|e| Error::RecvError("KeyKeeperAction::SetHostGARules".to_string(), e)) - } - - pub async fn get_hostga_rules(&self) -> Result> { - let (response, receiver) = oneshot::channel(); - self.0 - .send(KeyKeeperAction::GetHostGARules { response }) - .await - .map_err(|e| { - Error::SendError("KeyKeeperAction::GetHostGARules".to_string(), e.to_string()) - })?; - receiver - .await - .map_err(|e| Error::RecvError("KeyKeeperAction::GetHostGARules".to_string(), e)) - } - pub async fn get_notify(&self) -> Result> { let (response, receiver) = oneshot::channel(); self.0 @@ -643,8 +469,6 @@ impl KeyKeeperSharedState { #[cfg(test)] mod tests { - use crate::proxy::authorization_rules; - use super::*; #[tokio::test] @@ -690,28 +514,6 @@ mod tests { assert!(updated); assert_eq!(old_rule_id, ""); assert_eq!(key_keeper.get_wireserver_rule_id().await.unwrap(), rule_id); - let rules = AuthorizationItem { - defaultAccess: "allow".to_string(), - mode: "audit".to_string(), - id: rule_id.to_string(), - rules: None, - }; - key_keeper - .set_wireserver_rules(Some(rules.clone())) - .await - .unwrap(); - let retrieved_rules = key_keeper.get_wireserver_rules().await.unwrap(); - assert!(retrieved_rules.is_some()); - let retrieved_rules = retrieved_rules.unwrap(); - assert_eq!(rules.id, retrieved_rules.id); - assert_eq!(true, retrieved_rules.defaultAllowed); - assert_eq!( - 
authorization_rules::AuthorizationMode::Audit, - retrieved_rules.mode - ); - assert_eq!(0, retrieved_rules.privilegeAssignments.len()); - assert_eq!(0, retrieved_rules.privileges.len()); - assert_eq!(0, retrieved_rules.identities.len()); // test IMDS Rule let rule_id = "test_imds_rule_id".to_string(); @@ -722,25 +524,6 @@ mod tests { assert!(updated); assert_eq!(old_rule_id, ""); assert_eq!(key_keeper.get_imds_rule_id().await.unwrap(), rule_id); - let rules = AuthorizationItem { - defaultAccess: "deny".to_string(), - mode: "enforce".to_string(), - id: rule_id.to_string(), - rules: None, - }; - key_keeper - .set_imds_rules(Some(rules.clone())) - .await - .unwrap(); - let retrieved_rules = key_keeper.get_imds_rules().await.unwrap(); - assert!(retrieved_rules.is_some()); - let retrieved_rules = retrieved_rules.unwrap(); - assert_eq!(rules.id, retrieved_rules.id); - assert_eq!(false, retrieved_rules.defaultAllowed); - assert_eq!( - authorization_rules::AuthorizationMode::Enforce, - retrieved_rules.mode - ); // test HostGA Rule let rule_id = "test_hostga_rule_id".to_string(); diff --git a/proxy_agent/src/shared_state/telemetry_wrapper.rs b/proxy_agent/src/shared_state/telemetry_wrapper.rs deleted file mode 100644 index f386de37..00000000 --- a/proxy_agent/src/shared_state/telemetry_wrapper.rs +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) Microsoft Corporation -// SPDX-License-Identifier: MIT - -//! This module contains the logic to interact with the telemetry module. -//! Example -//! ```rust -//! use proxy_agent::shared_state::telemetry_wrapper::TelemetrySharedState; -//! use proxy_agent::telemetry::event_reader::VmMetaData; -//! -//! let telemetry_shared_state = TelemetrySharedState::start_new(); -//! let vm_meta_data = VmMetaData::new("vm_id".to_string(), "vm_name".to_string()); -//! telemetry_shared_state.set_vm_meta_data(Some(vm_meta_data.clone())).await.unwrap(); -//! let meta_data = telemetry_shared_state.get_vm_meta_data().await.unwrap().unwrap(); -//! 
assert_eq!(meta_data, vm_meta_data); -//! ``` - -use crate::common::result::Result; -use crate::common::{error::Error, logger}; -use crate::telemetry::event_reader::VmMetaData; -use tokio::sync::{mpsc, oneshot}; - -enum TelemetryAction { - SetVmMetaData { - vm_meta_data: Option, - response: oneshot::Sender<()>, - }, - GetVmMetaData { - response: oneshot::Sender>, - }, -} - -#[derive(Clone, Debug)] -pub struct TelemetrySharedState(mpsc::Sender); - -impl TelemetrySharedState { - pub fn start_new() -> Self { - let (sender, mut receiver) = mpsc::channel(100); - tokio::spawn(async move { - let mut vm_meta_data: Option = None; - loop { - match receiver.recv().await { - Some(TelemetryAction::SetVmMetaData { - vm_meta_data: meta_data, - response, - }) => { - vm_meta_data = meta_data.clone(); - if response.send(()).is_err() { - logger::write_warning(format!( - "Failed to send response to TelemetryAction::SetVmMetaData '{meta_data:?}'" - )); - } - } - Some(TelemetryAction::GetVmMetaData { response }) => { - if let Err(meta_data) = response.send(vm_meta_data.clone()) { - logger::write_warning(format!( - "Failed to send response to TelemetryAction::GetVmMetaData '{meta_data:?}'" - )); - } - } - None => { - break; - } - } - } - }); - - Self(sender) - } - - pub async fn set_vm_meta_data(&self, vm_meta_data: Option) -> Result<()> { - let (response, receiver) = oneshot::channel(); - self.0 - .send(TelemetryAction::SetVmMetaData { - vm_meta_data, - response, - }) - .await - .map_err(|e| { - Error::SendError("TelemetryAction::SetVmMetaData".to_string(), e.to_string()) - })?; - receiver - .await - .map_err(|e| Error::RecvError("TelemetryAction::SetVmMetaData".to_string(), e)) - } - - pub async fn get_vm_meta_data(&self) -> Result> { - let (response, receiver) = oneshot::channel(); - self.0 - .send(TelemetryAction::GetVmMetaData { response }) - .await - .map_err(|e| { - Error::SendError("TelemetryAction::GetVmMetaData".to_string(), e.to_string()) - })?; - receiver - .await - 
.map_err(|e| Error::RecvError("TelemetryAction::GetVmMetaData".to_string(), e)) - } -} diff --git a/proxy_agent/src/telemetry.rs b/proxy_agent/src/telemetry.rs deleted file mode 100644 index e01bd963..00000000 --- a/proxy_agent/src/telemetry.rs +++ /dev/null @@ -1,4 +0,0 @@ -// Copyright (c) Microsoft Corporation -// SPDX-License-Identifier: MIT -pub mod event_reader; -pub mod telemetry_event; diff --git a/proxy_agent_extension/Cargo.toml b/proxy_agent_extension/Cargo.toml index 7b30e8ed..10e66563 100644 --- a/proxy_agent_extension/Cargo.toml +++ b/proxy_agent_extension/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "ProxyAgentExt" -version = "1.0.38" # always 3-number version +version = "1.0.39" # always 3-number version edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/proxy_agent_extension/src/common.rs b/proxy_agent_extension/src/common.rs index fea6c0b5..4d5ce712 100644 --- a/proxy_agent_extension/src/common.rs +++ b/proxy_agent_extension/src/common.rs @@ -8,15 +8,13 @@ use crate::structs; use crate::structs::FormattedMessage; use crate::structs::HandlerEnvironment; use crate::structs::TopLevelStatus; +use proxy_agent_shared::service; use proxy_agent_shared::{misc_helpers, telemetry}; use std::fs; use std::path::Path; use std::path::PathBuf; use std::process; -#[cfg(windows)] -use proxy_agent_shared::service; - pub fn get_handler_environment(exe_path: &Path) -> HandlerEnvironment { let mut handler_env_path: PathBuf = exe_path.to_path_buf(); handler_env_path.push(constants::HANDLER_ENVIRONMENT_FILE); @@ -169,15 +167,7 @@ pub fn get_current_seq_no(exe_path: &Path) -> String { } pub fn get_proxy_agent_service_path() -> PathBuf { - #[cfg(windows)] - { - service::query_service_executable_path(constants::PROXY_AGENT_SERVICE_NAME) - } - #[cfg(not(windows))] - { - // linux service hard-coded to this location - PathBuf::from(proxy_agent_shared::linux::EXE_FOLDER_PATH).join("azure-proxy-agent") - 
} + service::query_service_executable_path(constants::PROXY_AGENT_SERVICE_NAME) } pub fn get_proxy_agent_exe_path() -> PathBuf { diff --git a/proxy_agent_extension/src/constants.rs b/proxy_agent_extension/src/constants.rs index b25f5e71..e66ffd4f 100644 --- a/proxy_agent_extension/src/constants.rs +++ b/proxy_agent_extension/src/constants.rs @@ -13,7 +13,10 @@ pub const EXTENSION_PROCESS_NAME: &str = "ProxyAgentExt"; #[cfg(windows)] pub const EXTENSION_PROCESS_NAME: &str = "ProxyAgentExt.exe"; pub const EXTENSION_SERVICE_DISPLAY_NAME: &str = "Microsoft Azure GuestProxyAgent VMExtension"; +#[cfg(windows)] pub const PROXY_AGENT_SERVICE_NAME: &str = "GuestProxyAgent"; +#[cfg(not(windows))] +pub const PROXY_AGENT_SERVICE_NAME: &str = "azure-proxy-agent"; pub const UPDATE_TAG_FILE: &str = "update.tag"; pub const ENABLE_OPERATION: &str = "Enable"; pub const LANG_EN_US: &str = "en-US"; @@ -64,6 +67,11 @@ pub const MIN_SUPPORTED_OS_BUILD: u32 = 17763; pub const STATE_KEY_READ_PROXY_AGENT_STATUS_FILE: &str = "ReadProxyAgentStatusFile"; pub const STATE_KEY_FILE_VERSION: &str = "FileVersion"; +pub const STATE_KEY_STALE_PROXY_AGENT_STATUS: &str = "StaleProxyAgentStatus"; +pub const STATE_KEY_PARSE_TIMESTAMP_ERROR: &str = "ParseTimestampError"; + +// Max time in seconds before proxy agent status is considered stale +pub const MAX_TIME_BEFORE_STALE_STATUS_SECS: u64 = 5 * 60; pub const EBPF_CORE: &str = "EbpfCore"; pub const EBPF_EXT: &str = "NetEbpfExt"; diff --git a/proxy_agent_extension/src/service_main.rs b/proxy_agent_extension/src/service_main.rs index 4ad85e32..aeeac69a 100644 --- a/proxy_agent_extension/src/service_main.rs +++ b/proxy_agent_extension/src/service_main.rs @@ -421,6 +421,61 @@ fn report_proxy_agent_aggregate_status( } } +fn report_error_status( + status: &mut StatusObj, + status_state_obj: &mut common::StatusState, + service_state: &mut ServiceState, + error_key: &str, + error_message: String, +) { + use proxy_agent_shared::logger::LoggerLevel; + use 
proxy_agent_shared::telemetry::event_logger; + + status.status = status_state_obj.update_state(false); + if service_state.update_service_state_entry(error_key, constants::ERROR_STATUS, MAX_STATE_COUNT) + { + event_logger::write_event( + LoggerLevel::Info, + error_message.clone(), + "extension_substatus", + "service_main", + &logger::get_logger_key(), + ); + } + status.configurationAppliedTime = misc_helpers::get_date_time_string(); + status.substatus = { + vec![ + SubStatus { + name: constants::PLUGIN_CONNECTION_NAME.to_string(), + status: constants::TRANSITIONING_STATUS.to_string(), + code: constants::STATUS_CODE_NOT_OK, + formattedMessage: FormattedMessage { + lang: constants::LANG_EN_US.to_string(), + message: error_message.to_string(), + }, + }, + SubStatus { + name: constants::PLUGIN_STATUS_NAME.to_string(), + status: constants::TRANSITIONING_STATUS.to_string(), + code: constants::STATUS_CODE_NOT_OK, + formattedMessage: FormattedMessage { + lang: constants::LANG_EN_US.to_string(), + message: error_message.to_string(), + }, + }, + SubStatus { + name: constants::PLUGIN_FAILED_AUTH_NAME.to_string(), + status: constants::TRANSITIONING_STATUS.to_string(), + code: constants::STATUS_CODE_NOT_OK, + formattedMessage: FormattedMessage { + lang: constants::LANG_EN_US.to_string(), + message: error_message.to_string(), + }, + }, + ] + }; +} + fn extension_substatus( proxy_agent_aggregate_status_top_level: GuestProxyAgentAggregateStatus, proxyagent_file_version_in_extension: &String, @@ -428,173 +483,162 @@ fn extension_substatus( status_state_obj: &mut common::StatusState, service_state: &mut ServiceState, ) { + let proxy_agent_status_timestamp_result = + proxy_agent_aggregate_status_top_level.get_status_timestamp(); let proxy_agent_aggregate_status_obj = proxy_agent_aggregate_status_top_level.proxyAgentStatus; - let proxy_agent_aggregate_status_file_version = proxy_agent_aggregate_status_obj.version.to_string(); - if proxy_agent_aggregate_status_file_version != 
*proxyagent_file_version_in_extension { - status.status = status_state_obj.update_state(false); + + // Check for timestamp staleness or parse errors + let timestamp_error = match proxy_agent_status_timestamp_result { + Ok(status_timestamp) => { + let current_time = misc_helpers::get_current_utc_time(); + let duration = current_time - status_timestamp; + if duration > Duration::from_secs(constants::MAX_TIME_BEFORE_STALE_STATUS_SECS) { + Some((constants::STATE_KEY_STALE_PROXY_AGENT_STATUS, format!("Proxy agent aggregate status file is stale. Status timestamp: {}, Current time: {}", status_timestamp, current_time))) + } else { + None + } + } + Err(e) => Some(( + constants::STATE_KEY_PARSE_TIMESTAMP_ERROR, + format!("Error in parsing timestamp from proxy agent aggregate status file: {e}"), + )), + }; + + // Determine error for status reporting + if let Some((error_key, error_message)) = timestamp_error { + report_error_status( + status, + status_state_obj, + service_state, + error_key, + error_message, + ); + return; + } else if proxy_agent_aggregate_status_file_version != *proxyagent_file_version_in_extension { let version_mismatch_message = format!("Proxy agent aggregate status file version {proxy_agent_aggregate_status_file_version} does not match proxy agent file version in extension {proxyagent_file_version_in_extension}"); - write_state_event( - constants::STATE_KEY_FILE_VERSION, - constants::ERROR_STATUS, - version_mismatch_message.to_string(), - "extension_substatus", - "service_main", - &logger::get_logger_key(), + report_error_status( + status, + status_state_obj, service_state, + constants::STATE_KEY_FILE_VERSION, + version_mismatch_message, ); - status.configurationAppliedTime = misc_helpers::get_date_time_string(); - status.substatus = { - vec![ - SubStatus { - name: constants::PLUGIN_CONNECTION_NAME.to_string(), - status: constants::TRANSITIONING_STATUS.to_string(), - code: constants::STATUS_CODE_NOT_OK, - formattedMessage: FormattedMessage { - lang: 
constants::LANG_EN_US.to_string(), - message: version_mismatch_message.to_string(), - }, - }, - SubStatus { - name: constants::PLUGIN_STATUS_NAME.to_string(), - status: constants::TRANSITIONING_STATUS.to_string(), - code: constants::STATUS_CODE_NOT_OK, - formattedMessage: FormattedMessage { - lang: constants::LANG_EN_US.to_string(), - message: version_mismatch_message.to_string(), - }, - }, - SubStatus { - name: constants::PLUGIN_FAILED_AUTH_NAME.to_string(), - status: constants::TRANSITIONING_STATUS.to_string(), - code: constants::STATUS_CODE_NOT_OK, - formattedMessage: FormattedMessage { - lang: constants::LANG_EN_US.to_string(), - message: version_mismatch_message.to_string(), - }, - }, - ] - }; + return; } // Success Status and report to status file for CRP to read from - else { - let substatus_proxy_agent_message = - match serde_json::to_string(&proxy_agent_aggregate_status_obj) { - Ok(proxy_agent_aggregate_status) => proxy_agent_aggregate_status, - Err(e) => { - let error_message = - format!("Error in serializing proxy agent aggregate status: {e}"); - logger::write(error_message.to_string()); - error_message - } - }; - let mut substatus_proxy_agent_connection_message: String; - if !proxy_agent_aggregate_status_top_level - .proxyConnectionSummary - .is_empty() - { - let proxy_agent_aggregate_connection_status_obj = get_top_proxy_connection_summary( - proxy_agent_aggregate_status_top_level - .proxyConnectionSummary - .clone(), - constants::MAX_CONNECTION_SUMMARY_LEN, - ); - match serde_json::to_string(&proxy_agent_aggregate_connection_status_obj) { - Ok(proxy_agent_aggregate_connection_status) => { - substatus_proxy_agent_connection_message = - proxy_agent_aggregate_connection_status; - } - Err(e) => { - let error_message = format!( - "Error in serializing proxy agent aggregate connection status: {e}" - ); - logger::write(error_message.to_string()); - substatus_proxy_agent_connection_message = error_message; - } + let substatus_proxy_agent_message = + match 
serde_json::to_string(&proxy_agent_aggregate_status_obj) { + Ok(proxy_agent_aggregate_status) => proxy_agent_aggregate_status, + Err(e) => { + let error_message = + format!("Error in serializing proxy agent aggregate status: {e}"); + logger::write(error_message.to_string()); + error_message + } + }; + let mut substatus_proxy_agent_connection_message: String; + if !proxy_agent_aggregate_status_top_level + .proxyConnectionSummary + .is_empty() + { + let proxy_agent_aggregate_connection_status_obj = get_top_proxy_connection_summary( + proxy_agent_aggregate_status_top_level + .proxyConnectionSummary + .clone(), + constants::MAX_CONNECTION_SUMMARY_LEN, + ); + match serde_json::to_string(&proxy_agent_aggregate_connection_status_obj) { + Ok(proxy_agent_aggregate_connection_status) => { + substatus_proxy_agent_connection_message = proxy_agent_aggregate_connection_status; + } + Err(e) => { + let error_message = + format!("Error in serializing proxy agent aggregate connection status: {e}"); + logger::write(error_message.to_string()); + substatus_proxy_agent_connection_message = error_message; } - } else { - logger::write("proxy connection summary is empty".to_string()); - substatus_proxy_agent_connection_message = - "proxy connection summary is empty".to_string(); } - let mut substatus_failed_auth_message: String; - if !proxy_agent_aggregate_status_top_level - .failedAuthenticateSummary - .is_empty() - { - let proxy_agent_aggregate_failed_auth_status_obj = get_top_proxy_connection_summary( - proxy_agent_aggregate_status_top_level - .failedAuthenticateSummary - .clone(), - constants::MAX_FAILED_AUTH_SUMMARY_LEN, - ); - match serde_json::to_string(&proxy_agent_aggregate_failed_auth_status_obj) { - Ok(proxy_agent_aggregate_failed_auth_status) => { - substatus_failed_auth_message = proxy_agent_aggregate_failed_auth_status; - } - Err(e) => { - let error_message = format!( - "Error in serializing proxy agent aggregate failed auth status: {e}" - ); - 
logger::write(error_message.to_string()); - substatus_failed_auth_message = error_message; - } + } else { + logger::write("proxy connection summary is empty".to_string()); + substatus_proxy_agent_connection_message = "proxy connection summary is empty".to_string(); + } + let mut substatus_failed_auth_message: String; + if !proxy_agent_aggregate_status_top_level + .failedAuthenticateSummary + .is_empty() + { + let proxy_agent_aggregate_failed_auth_status_obj = get_top_proxy_connection_summary( + proxy_agent_aggregate_status_top_level + .failedAuthenticateSummary + .clone(), + constants::MAX_FAILED_AUTH_SUMMARY_LEN, + ); + match serde_json::to_string(&proxy_agent_aggregate_failed_auth_status_obj) { + Ok(proxy_agent_aggregate_failed_auth_status) => { + substatus_failed_auth_message = proxy_agent_aggregate_failed_auth_status; + } + Err(e) => { + let error_message = + format!("Error in serializing proxy agent aggregate failed auth status: {e}"); + logger::write(error_message.to_string()); + substatus_failed_auth_message = error_message; } - } else { - logger::write("proxy failed auth summary is empty".to_string()); - substatus_failed_auth_message = "proxy failed auth summary is empty".to_string(); } + } else { + logger::write("proxy failed auth summary is empty".to_string()); + substatus_failed_auth_message = "proxy failed auth summary is empty".to_string(); + } - trim_proxy_agent_status_file( - &mut substatus_failed_auth_message, - &mut substatus_proxy_agent_connection_message, - constants::MAX_PROXYAGENT_CONNECTION_DATA_SIZE_IN_KB, - ); + trim_proxy_agent_status_file( + &mut substatus_failed_auth_message, + &mut substatus_proxy_agent_connection_message, + constants::MAX_PROXYAGENT_CONNECTION_DATA_SIZE_IN_KB, + ); - status.substatus = { - vec![ - SubStatus { - name: constants::PLUGIN_CONNECTION_NAME.to_string(), - status: constants::SUCCESS_STATUS.to_string(), - code: constants::STATUS_CODE_OK, - formattedMessage: FormattedMessage { - lang: 
constants::LANG_EN_US.to_string(), - message: substatus_proxy_agent_connection_message.to_string(), - }, + status.substatus = { + vec![ + SubStatus { + name: constants::PLUGIN_CONNECTION_NAME.to_string(), + status: constants::SUCCESS_STATUS.to_string(), + code: constants::STATUS_CODE_OK, + formattedMessage: FormattedMessage { + lang: constants::LANG_EN_US.to_string(), + message: substatus_proxy_agent_connection_message.to_string(), }, - SubStatus { - name: constants::PLUGIN_STATUS_NAME.to_string(), - status: constants::SUCCESS_STATUS.to_string(), - code: constants::STATUS_CODE_OK, - formattedMessage: FormattedMessage { - lang: constants::LANG_EN_US.to_string(), - message: substatus_proxy_agent_message.to_string(), - }, + }, + SubStatus { + name: constants::PLUGIN_STATUS_NAME.to_string(), + status: constants::SUCCESS_STATUS.to_string(), + code: constants::STATUS_CODE_OK, + formattedMessage: FormattedMessage { + lang: constants::LANG_EN_US.to_string(), + message: substatus_proxy_agent_message.to_string(), }, - SubStatus { - name: constants::PLUGIN_FAILED_AUTH_NAME.to_string(), - status: constants::SUCCESS_STATUS.to_string(), - code: constants::STATUS_CODE_OK, - formattedMessage: FormattedMessage { - lang: constants::LANG_EN_US.to_string(), - message: substatus_failed_auth_message.to_string(), - }, + }, + SubStatus { + name: constants::PLUGIN_FAILED_AUTH_NAME.to_string(), + status: constants::SUCCESS_STATUS.to_string(), + code: constants::STATUS_CODE_OK, + formattedMessage: FormattedMessage { + lang: constants::LANG_EN_US.to_string(), + message: substatus_failed_auth_message.to_string(), }, - ] - }; - status.status = status_state_obj.update_state(true); - status.configurationAppliedTime = misc_helpers::get_date_time_string(); - write_state_event( - constants::STATE_KEY_FILE_VERSION, - constants::SUCCESS_STATUS, - substatus_proxy_agent_connection_message.to_string(), - "extension_substatus", - "service_main", - &logger::get_logger_key(), - service_state, - ); - } + }, 
+ ] + }; + status.status = status_state_obj.update_state(true); + status.configurationAppliedTime = misc_helpers::get_date_time_string(); + write_state_event( + constants::STATE_KEY_FILE_VERSION, + constants::SUCCESS_STATUS, + substatus_proxy_agent_connection_message.to_string(), + "extension_substatus", + "service_main", + &logger::get_logger_key(), + service_state, + ); } fn trim_proxy_agent_status_file( @@ -921,6 +965,12 @@ mod tests { proxyConnectionSummary: vec![proxy_connection_summary_obj], failedAuthenticateSummary: vec![proxy_failedAuthenticateSummary_obj], }; + let result = toplevel_status.get_status_timestamp(); + assert!( + result.is_ok(), + "Status timestamp parse expected Ok result, got Err: {:?}", + result.err() + ); let mut status = StatusObj { name: constants::PLUGIN_NAME.to_string(), @@ -1099,4 +1149,187 @@ mod tests { assert_eq!(connection_summary, orig_conn); assert_eq!(failed_auth_summary, orig_auth); } + + #[test] + fn test_stale_status_timestamp_greater_than_5_minutes() { + let proxy_agent_status_obj = ProxyAgentStatus { + version: "1.0.0".to_string(), + status: OverallState::SUCCESS, + monitorStatus: ProxyAgentDetailStatus { + status: ModuleState::RUNNING, + message: "test".to_string(), + states: None, + }, + keyLatchStatus: ProxyAgentDetailStatus { + status: ModuleState::RUNNING, + message: "test".to_string(), + states: None, + }, + ebpfProgramStatus: ProxyAgentDetailStatus { + status: ModuleState::RUNNING, + message: "test".to_string(), + states: None, + }, + proxyListenerStatus: ProxyAgentDetailStatus { + status: ModuleState::RUNNING, + message: "test".to_string(), + states: None, + }, + telemetryLoggerStatus: ProxyAgentDetailStatus { + status: ModuleState::RUNNING, + message: "test".to_string(), + states: None, + }, + proxyConnectionsCount: 1, + }; + + let proxy_connection_summary_obj = ProxyConnectionSummary { + userName: "test".to_string(), + ip: "test".to_string(), + port: 1, + processCmdLine: "test".to_string(), + responseStatus: 
"test".to_string(), + count: 1, + processFullPath: Some("test".to_string()), + userGroups: Some(vec!["test".to_string()]), + }; + + // Create a timestamp that is 10 minutes old (greater than 5 minutes) + // Use a fixed old timestamp format to simulate staleness + let stale_timestamp = "2024-01-01T00:00:00Z".to_string(); + + let toplevel_status = GuestProxyAgentAggregateStatus { + timestamp: stale_timestamp, + proxyAgentStatus: proxy_agent_status_obj, + proxyConnectionSummary: vec![proxy_connection_summary_obj.clone()], + failedAuthenticateSummary: vec![proxy_connection_summary_obj], + }; + + let mut status = StatusObj { + name: constants::PLUGIN_NAME.to_string(), + operation: constants::ENABLE_OPERATION.to_string(), + configurationAppliedTime: misc_helpers::get_date_time_string(), + code: constants::STATUS_CODE_OK, + status: constants::SUCCESS_STATUS.to_string(), + formattedMessage: FormattedMessage { + lang: constants::LANG_EN_US.to_string(), + message: "Update Proxy Agent command output successfully".to_string(), + }, + substatus: Default::default(), + }; + + let mut status_state_obj = super::common::StatusState::new(); + let proxyagent_file_version_in_extension: &String = &"1.0.0".to_string(); + let mut service_state = super::service_state::ServiceState::default(); + + super::extension_substatus( + toplevel_status, + proxyagent_file_version_in_extension, + &mut status, + &mut status_state_obj, + &mut service_state, + ); + + // Verify that status is not successful due to stale timestamp + assert_ne!(status.status, constants::SUCCESS_STATUS.to_string()); + assert_eq!(status.substatus.len(), 3); + assert_eq!( + status.substatus[0].status, + constants::TRANSITIONING_STATUS.to_string() + ); + assert_eq!(status.substatus[0].code, constants::STATUS_CODE_NOT_OK); + assert!(status.substatus[0] + .formattedMessage + .message + .contains("stale")); + } + + #[test] + fn test_fresh_status_timestamp_within_5_minutes() { + let proxy_agent_status_obj = ProxyAgentStatus { + 
version: "1.0.0".to_string(), + status: OverallState::SUCCESS, + monitorStatus: ProxyAgentDetailStatus { + status: ModuleState::RUNNING, + message: "test".to_string(), + states: None, + }, + keyLatchStatus: ProxyAgentDetailStatus { + status: ModuleState::RUNNING, + message: "test".to_string(), + states: None, + }, + ebpfProgramStatus: ProxyAgentDetailStatus { + status: ModuleState::RUNNING, + message: "test".to_string(), + states: None, + }, + proxyListenerStatus: ProxyAgentDetailStatus { + status: ModuleState::RUNNING, + message: "test".to_string(), + states: None, + }, + telemetryLoggerStatus: ProxyAgentDetailStatus { + status: ModuleState::RUNNING, + message: "test".to_string(), + states: None, + }, + proxyConnectionsCount: 1, + }; + + let proxy_connection_summary_obj = ProxyConnectionSummary { + userName: "test".to_string(), + ip: "test".to_string(), + port: 1, + processCmdLine: "test".to_string(), + responseStatus: "test".to_string(), + count: 1, + processFullPath: Some("test".to_string()), + userGroups: Some(vec!["test".to_string()]), + }; + + // Create a fresh timestamp (current time) + let fresh_timestamp = misc_helpers::get_date_time_string(); + + let toplevel_status = GuestProxyAgentAggregateStatus { + timestamp: fresh_timestamp, + proxyAgentStatus: proxy_agent_status_obj, + proxyConnectionSummary: vec![proxy_connection_summary_obj.clone()], + failedAuthenticateSummary: vec![proxy_connection_summary_obj], + }; + + let mut status = StatusObj { + name: constants::PLUGIN_NAME.to_string(), + operation: constants::ENABLE_OPERATION.to_string(), + configurationAppliedTime: misc_helpers::get_date_time_string(), + code: constants::STATUS_CODE_OK, + status: constants::SUCCESS_STATUS.to_string(), + formattedMessage: FormattedMessage { + lang: constants::LANG_EN_US.to_string(), + message: "Update Proxy Agent command output successfully".to_string(), + }, + substatus: Default::default(), + }; + + let mut status_state_obj = super::common::StatusState::new(); + let 
proxyagent_file_version_in_extension: &String = &"1.0.0".to_string(); + let mut service_state = super::service_state::ServiceState::default(); + + super::extension_substatus( + toplevel_status, + proxyagent_file_version_in_extension, + &mut status, + &mut status_state_obj, + &mut service_state, + ); + + // Verify that status is successful with fresh timestamp + assert_eq!(status.status, constants::SUCCESS_STATUS.to_string()); + assert_eq!(status.substatus.len(), 3); + assert_eq!( + status.substatus[0].status, + constants::SUCCESS_STATUS.to_string() + ); + assert_eq!(status.substatus[0].code, constants::STATUS_CODE_OK); + } } diff --git a/proxy_agent_extension/src/windows/HandlerManifest.json b/proxy_agent_extension/src/windows/HandlerManifest.json index 07dafeea..16d82ca6 100644 --- a/proxy_agent_extension/src/windows/HandlerManifest.json +++ b/proxy_agent_extension/src/windows/HandlerManifest.json @@ -18,9 +18,7 @@ "memoryQuotaMB": 75 }, { - "name": "GuestProxyAgent", - "cpuQuotaPercentage": 15, - "memoryQuotaMB": 17 + "name": "GuestProxyAgent" }] } }] \ No newline at end of file diff --git a/proxy_agent_setup/Cargo.toml b/proxy_agent_setup/Cargo.toml index 0bd7f0b7..ccc017b0 100644 --- a/proxy_agent_setup/Cargo.toml +++ b/proxy_agent_setup/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "proxy_agent_setup" -version = "1.0.38" +version = "1.0.39" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/proxy_agent_shared/Cargo.toml b/proxy_agent_shared/Cargo.toml index f97ba1df..d8137078 100644 --- a/proxy_agent_shared/Cargo.toml +++ b/proxy_agent_shared/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "proxy_agent_shared" -version = "1.0.38" +version = "1.0.39" edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -8,7 +8,7 @@ edition = "2021" [dependencies] concurrent-queue = "2.1.0" # for event queue once_cell = "1.17.0" # use Lazy -time = 
{ version = "0.3.30", features = ["formatting"] } +time = { version = "0.3.30", features = ["formatting", "parsing"] } thread-id = "4.0.0" serde = "1.0.152" serde_derive = "1.0.152" @@ -54,10 +54,12 @@ features = [ "Win32_System_Diagnostics_Debug", "Win32_System_SystemInformation", "Win32_Storage_FileSystem", + "Win32_System_JobObjects", ] [target.'cfg(not(windows))'.dependencies] os_info = "3.7.0" # read Linux OS version and arch +sysinfo = "0.30.13" # read CPU & RAM information for Linux # For MUSL targets (Linux MUSL) [target.'cfg(all(target_env = "musl", not(target_os = "windows")))'.dependencies] diff --git a/proxy_agent_shared/src/common_state.rs b/proxy_agent_shared/src/common_state.rs new file mode 100644 index 00000000..e5ab29a5 --- /dev/null +++ b/proxy_agent_shared/src/common_state.rs @@ -0,0 +1,156 @@ +// Copyright (c) Microsoft Corporation +// SPDX-License-Identifier: MIT + +//! This module contains the logic to get and update common states. + +use crate::result::Result; +use crate::{error::Error, logger::logger_manager, telemetry::event_reader::VmMetaData}; +use tokio::sync::{mpsc, oneshot}; + +pub const SECURE_KEY_GUID: &str = "key_guid"; +pub const SECURE_KEY_VALUE: &str = "key_value"; + +enum CommonStateAction { + SetVmMetaData { + vm_meta_data: Option, + response: oneshot::Sender<()>, + }, + GetVmMetaData { + response: oneshot::Sender>, + }, + SetState { + key: String, + value: String, + response: oneshot::Sender<()>, + }, + GetState { + key: String, + response: oneshot::Sender>, + }, +} + +#[derive(Clone, Debug)] +pub struct CommonState(mpsc::Sender); + +impl CommonState { + pub fn start_new() -> Self { + let (sender, mut receiver) = mpsc::channel(100); + tokio::spawn(async move { + let mut vm_meta_data: Option = None; + let mut states: std::collections::HashMap = + std::collections::HashMap::new(); + loop { + match receiver.recv().await { + Some(CommonStateAction::SetVmMetaData { + vm_meta_data: meta_data, + response, + }) => { + vm_meta_data = 
meta_data.clone(); + if response.send(()).is_err() { + logger_manager::write_warn(format!( + "Failed to send response to CommonStateAction::SetVmMetaData '{meta_data:?}'" + )); + } + } + Some(CommonStateAction::GetVmMetaData { response }) => { + if let Err(meta_data) = response.send(vm_meta_data.clone()) { + logger_manager::write_warn(format!( + "Failed to send response to CommonStateAction::GetVmMetaData '{meta_data:?}'" + )); + } + } + Some(CommonStateAction::SetState { + key, + value, + response, + }) => { + states.insert(key.clone(), value.clone()); + if response.send(()).is_err() { + logger_manager::write_warn(format!( + "Failed to send response to CommonStateAction::SetState '{key}':'{value}'" + )); + } + } + Some(CommonStateAction::GetState { key, response }) => { + let value = states.get(&key).cloned(); + if let Err(v) = response.send(value) { + logger_manager::write_warn(format!( + "Failed to send response to CommonStateAction::GetState '{key}':'{v:?}'" + )); + } + } + None => { + break; + } + } + } + }); + + Self(sender) + } + + pub async fn set_vm_meta_data(&self, vm_meta_data: Option) -> Result<()> { + let (response, receiver) = oneshot::channel(); + self.0 + .send(CommonStateAction::SetVmMetaData { + vm_meta_data, + response, + }) + .await + .map_err(|e| { + Error::SendError( + "CommonStateAction::SetVmMetaData".to_string(), + e.to_string(), + ) + })?; + receiver + .await + .map_err(|e| Error::RecvError("CommonStateAction::SetVmMetaData".to_string(), e)) + } + + pub async fn get_vm_meta_data(&self) -> Result> { + let (response, receiver) = oneshot::channel(); + self.0 + .send(CommonStateAction::GetVmMetaData { response }) + .await + .map_err(|e| { + Error::SendError( + "CommonStateAction::GetVmMetaData".to_string(), + e.to_string(), + ) + })?; + receiver + .await + .map_err(|e| Error::RecvError("CommonStateAction::GetVmMetaData".to_string(), e)) + } + + pub async fn set_state(&self, key: String, value: String) -> Result<()> { + let (response, receiver) 
= oneshot::channel(); + self.0 + .send(CommonStateAction::SetState { + key, + value, + response, + }) + .await + .map_err(|e| { + Error::SendError("CommonStateAction::SetState".to_string(), e.to_string()) + })?; + receiver + .await + .map_err(|e| Error::RecvError("CommonStateAction::SetState".to_string(), e)) + } + + pub async fn get_state(&self, key: String) -> Result> { + let (response, receiver) = oneshot::channel(); + self.0 + .send(CommonStateAction::GetState { key, response }) + .await + .map_err(|e| { + Error::SendError("CommonStateAction::GetState".to_string(), e.to_string()) + })?; + receiver + .await + .map_err(|e| Error::RecvError("CommonStateAction::GetState".to_string(), e)) + } +} diff --git a/proxy_agent_shared/src/current_info.rs b/proxy_agent_shared/src/current_info.rs new file mode 100644 index 00000000..e9e66544 --- /dev/null +++ b/proxy_agent_shared/src/current_info.rs @@ -0,0 +1,78 @@ +// Copyright (c) Microsoft Corporation +// SPDX-License-Identifier: MIT + +use crate::misc_helpers; +use once_cell::sync::Lazy; + +#[cfg(not(windows))] +use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind, System}; + +#[cfg(windows)] +use super::windows; + +static CURRENT_SYS_INFO: Lazy<(u64, usize)> = Lazy::new(|| { + #[cfg(windows)] + { + let ram_in_mb = match windows::get_memory_in_mb() { + Ok(ram) => ram, + Err(e) => { + crate::logger::logger_manager::write_err(format!("get_memory_in_mb failed: {e}")); + 0 + } + }; + let cpu_count = windows::get_processor_count(); + (ram_in_mb, cpu_count) + } + #[cfg(not(windows))] + { + let sys = System::new_with_specifics( + RefreshKind::new() + .with_memory(MemoryRefreshKind::everything()) + .with_cpu(CpuRefreshKind::everything()), + ); + let ram = sys.total_memory(); + let ram_in_mb = ram / 1024 / 1024; + let cpu_count = sys.cpus().len(); + (ram_in_mb, cpu_count) + } +}); + +static CURRENT_OS_INFO: Lazy<(String, String)> = Lazy::new(|| { + //arch + let arch = misc_helpers::get_processor_arch(); + // os + let os = 
misc_helpers::get_long_os_version(); + (arch, os) +}); + +pub fn get_ram_in_mb() -> u64 { + CURRENT_SYS_INFO.0 +} + +pub fn get_cpu_count() -> usize { + CURRENT_SYS_INFO.1 +} + +pub fn get_cpu_arch() -> String { + CURRENT_OS_INFO.0.to_string() +} + +pub fn get_long_os_version() -> String { + CURRENT_OS_INFO.1.to_string() +} + +#[cfg(test)] +mod tests { + #[test] + fn get_system_info_tests() { + let ram = super::get_ram_in_mb(); + assert!(ram > 100, "total ram must be greater than 100MB"); + let cpu_count = super::get_cpu_count(); + assert!( + cpu_count >= 1, + "total cpu count must be greater than or equal to 1" + ); + let cpu_arch = super::get_cpu_arch(); + assert_ne!("unknown", cpu_arch, "cpu arch cannot be 'unknown'"); + } +} diff --git a/proxy_agent_shared/src/error.rs b/proxy_agent_shared/src/error.rs index c529508f..b933a999 100644 --- a/proxy_agent_shared/src/error.rs +++ b/proxy_agent_shared/src/error.rs @@ -47,6 +47,15 @@ pub enum Error { #[error("{0} command: {1}")] Command(CommandErrorType, String), + + #[error("Failed to send '{0}' action response with error {1}")] + SendError(String, String), + + #[error("Failed to receive '{0}' action response with error {1}")] + RecvError(String, tokio::sync::oneshot::error::RecvError), + + #[error("Parse datetime string error: {0}")] + ParseDateTimeStringError(String), } #[derive(Debug, thiserror::Error)] @@ -80,8 +89,8 @@ pub enum HyperErrorType { #[error("Failed to build request with error: {0}")] RequestBuilder(String), - #[error("Failed to receive the request body with error: {0}")] - RequestBody(String), + #[error("Failed to receive the body with error: {0}")] + ReceiveBody(String), #[error("Failed to get response from {0}, status code: {1}")] ServerError(String, StatusCode), diff --git a/proxy_agent_shared/src/lib.rs b/proxy_agent_shared/src/lib.rs index 3d75e50a..3961bb6a 100644 --- a/proxy_agent_shared/src/lib.rs +++ b/proxy_agent_shared/src/lib.rs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation // 
SPDX-License-Identifier: MIT +pub mod common_state; +pub mod current_info; pub mod error; #[cfg(windows)] pub mod etw; @@ -14,6 +16,7 @@ pub mod secrets_redactor; pub mod service; pub mod telemetry; pub mod version; + #[cfg(windows)] pub mod windows; diff --git a/proxy_agent_shared/src/linux.rs b/proxy_agent_shared/src/linux.rs index a95f82c6..b09cbe0b 100644 --- a/proxy_agent_shared/src/linux.rs +++ b/proxy_agent_shared/src/linux.rs @@ -134,6 +134,46 @@ pub fn compute_signature(hex_encoded_key: &str, input_to_sign: &[u8]) -> Result< } } +/// Set the CPU quota for a service. +/// The CPU quota is set in percentage of the one CPU time available. +/// For example, if the total CPU time available is 100%, setting the CPU quota to 50% will limit the service to use up to 50% of the total CPU time available. +pub fn set_cpu_quota(service_name: &str, cpu_quota: u16) -> Result<()> { + misc_helpers::execute_command( + "systemctl", + vec![ + "set-property", + service_name, + &format!("CPUQuota={cpu_quota}%"), + ], + -1, + )?; + + Ok(()) +} + +#[derive(Debug)] +pub struct MemStatus { + pub vmrss_kb: Option, + pub vmhwm_kb: Option, +} + +pub fn read_proc_memory_status(pid: u32) -> Result { + let s = fs::read_to_string(format!("/proc/{pid}/status"))?; + let mut vmrss_kb = None; + let mut vmhwm_kb = None; + for line in s.lines() { + if line.starts_with("VmRSS:") { + // Format: "VmRSS:\t 12345 kB" + let val = line.split_whitespace().nth(1).and_then(|x| x.parse().ok()); + vmrss_kb = val; + } else if line.starts_with("VmHWM:") { + let val = line.split_whitespace().nth(1).and_then(|x| x.parse().ok()); + vmhwm_kb = val; + } + } + Ok(MemStatus { vmrss_kb, vmhwm_kb }) +} + #[cfg(test)] mod tests { use crate::misc_helpers; diff --git a/proxy_agent_shared/src/misc_helpers.rs b/proxy_agent_shared/src/misc_helpers.rs index d96d501e..b6ee12b2 100644 --- a/proxy_agent_shared/src/misc_helpers.rs +++ b/proxy_agent_shared/src/misc_helpers.rs @@ -13,7 +13,7 @@ use std::{ process::Command, }; use 
thread_id; -use time::{format_description, OffsetDateTime}; +use time::{format_description, OffsetDateTime, PrimitiveDateTime}; #[cfg(windows)] use super::windows; @@ -57,6 +57,58 @@ pub fn get_date_time_unix_nano() -> i128 { OffsetDateTime::now_utc().unix_timestamp_nanos() } +pub fn get_current_utc_time() -> OffsetDateTime { + OffsetDateTime::now_utc() +} + +/// Parse a datetime string to OffsetDateTime (UTC) +/// Supports multiple formats: +/// - ISO 8601 with/without 'Z': "YYYY-MM-DDTHH:MM:SS" or "YYYY-MM-DDTHH:MM:SSZ" +/// - With milliseconds: "YYYY-MM-DDTHH:MM:SS.mmm" +/// # Arguments +/// * `datetime_str` - A datetime string to parse +/// # Returns +/// A Result containing the parsed OffsetDateTime (UTC) or an error if parsing fails +/// # Example +/// ```rust +/// use proxy_agent_shared::misc_helpers; +/// let datetime1 = misc_helpers::parse_date_time_string("2024-01-15T10:30:45Z").unwrap(); +/// let datetime2 = misc_helpers::parse_date_time_string("2024-01-15T10:30:45").unwrap(); +/// let datetime3 = misc_helpers::parse_date_time_string("2024-01-15T10:30:45.123").unwrap(); +/// ``` +pub fn parse_date_time_string(datetime_str: &str) -> Result { + // Remove the 'Z' suffix if present + let datetime_str_trimmed = datetime_str.trim_end_matches('Z'); + + // Try parsing with milliseconds first + let date_format_with_millis = + format_description::parse("[year]-[month]-[day]T[hour]:[minute]:[second].[subsecond]") + .map_err(|e| { + Error::ParseDateTimeStringError(format!("Failed to parse date format: {e}")) + })?; + + if let Ok(primitive_datetime) = + PrimitiveDateTime::parse(datetime_str_trimmed, &date_format_with_millis) + { + return Ok(primitive_datetime.assume_utc()); + } + + // Fall back to parsing without milliseconds + let date_format = format_description::parse("[year]-[month]-[day]T[hour]:[minute]:[second]") + .map_err(|e| { + Error::ParseDateTimeStringError(format!("Failed to parse date format: {e}")) + })?; + + let primitive_datetime = + 
PrimitiveDateTime::parse(datetime_str_trimmed, &date_format).map_err(|e| { + Error::ParseDateTimeStringError(format!( + "Failed to parse datetime string '{datetime_str}': {e}" + )) + })?; + + Ok(primitive_datetime.assume_utc()) +} + pub fn try_create_folder(dir: &Path) -> Result<()> { match dir.try_exists() { Ok(exists) => { @@ -150,12 +202,44 @@ pub fn get_file_name(path: &Path) -> String { } } +/// It is the version from Cargo.toml of proxy_agent_shared crate const VERSION: &str = env!("CARGO_PKG_VERSION"); pub fn get_current_version() -> String { VERSION.to_string() } +/// Get the current executable version, +/// trying to read version from file properties on Windows, +/// otherwise fallback to Cargo.toml version. +/// # Returns +/// A string representing the current executable version +pub fn get_current_exe_version() -> String { + #[cfg(windows)] + { + match try_get_current_exe_version() { + Ok(version) => version, + Err(e) => { + eprintln!( + "Failed to get current exe version from file properties, fallback to Cargo.toml version: {e}", + ); + get_current_version() + } + } + } + #[cfg(not(windows))] + { + get_current_version() + } +} + +#[cfg(windows)] +pub fn try_get_current_exe_version() -> Result { + let exe_path = std::env::current_exe()?; + let version = windows::get_file_product_version(&exe_path)?; + Ok(version.to_string()) +} + pub fn get_files(dir: &Path) -> Result> { // search files let mut files: Vec = Vec::new(); @@ -329,6 +413,15 @@ pub use linux::compute_signature; #[cfg(windows)] pub use windows::compute_signature; +// replace xml escape characters +pub fn xml_escape(s: String) -> String { + s.replace('&', "&") + .replace('\'', "'") + .replace('"', """) + .replace('<', "<") + .replace('>', ">") +} + #[cfg(test)] mod tests { use serde_derive::{Deserialize, Serialize}; @@ -585,4 +678,104 @@ mod tests { "Error does not contains the invalid key" ) } + + #[test] + fn get_current_exe_version_test() { + let version = super::get_current_exe_version(); + 
println!("get_current_exe_version: {version}"); + assert!( + !version.is_empty(), + "get_current_exe_version should return a non-empty string" + ); + + let cargo_version = super::get_current_version(); + #[cfg(windows)] + { + // "%UserProfile%\\.cargo\\bin\\rustup.exe" does not have file version info + // so get_current_exe_version uses the version from current Cargo.toml file + assert_eq!( + cargo_version, version, + "get_current_exe_version should return the same version as Cargo.toml as '%UserProfile%\\.cargo\\bin\\rustup.exe' does not have file version info" + ); + } + #[cfg(not(windows))] + { + assert_eq!( + cargo_version, version, + "get_current_exe_version should return the same version as Cargo.toml in Linux" + ); + } + } + + #[test] + fn parse_date_time_string_test() { + // Test parsing with milliseconds + let datetime_str = "2024-01-15T10:30:45.123"; + let result = super::parse_date_time_string(datetime_str); + assert!( + result.is_ok(), + "Failed to parse datetime string with milliseconds" + ); + + let datetime = result.unwrap(); + assert_eq!(datetime.year(), 2024); + assert_eq!(datetime.month() as u8, 1); + assert_eq!(datetime.day(), 15); + assert_eq!(datetime.hour(), 10); + assert_eq!(datetime.minute(), 30); + assert_eq!(datetime.second(), 45); + assert_eq!(datetime.millisecond(), 123); + + // Test parsing with 'Z' suffix + let datetime_str = "2024-01-15T10:30:45Z"; + let result = super::parse_date_time_string(datetime_str); + assert!( + result.is_ok(), + "Failed to parse datetime string with Z suffix" + ); + + let datetime = result.unwrap(); + assert_eq!(datetime.year(), 2024); + assert_eq!(datetime.month() as u8, 1); + assert_eq!(datetime.day(), 15); + assert_eq!(datetime.hour(), 10); + assert_eq!(datetime.minute(), 30); + assert_eq!(datetime.second(), 45); + + // Test parsing without 'Z' suffix + let datetime_str_without_z = "2024-01-15T10:30:45"; + let result = super::parse_date_time_string(datetime_str_without_z); + assert!(result.is_ok(), "Should 
parse datetime string without 'Z'"); + + // Test round-trip with milliseconds format + let original_datetime_str = super::get_date_time_string_with_milliseconds(); + let result = super::parse_date_time_string(&original_datetime_str); + assert!( + result.is_ok(), + "Failed to parse datetime string with milliseconds" + ); + + // Test round-trip with standard format + let original_datetime_str = super::get_date_time_string(); + let result = super::parse_date_time_string(&original_datetime_str); + assert!( + result.is_ok(), + "Failed to parse datetime string without milliseconds" + ); + + // Test invalid format + let invalid_datetime_str = "2024-01-15 10:30:45"; // space instead of 'T' + let result = super::parse_date_time_string(invalid_datetime_str); + assert!( + result.is_err(), + "Should fail to parse invalid datetime string" + ); + + let invalid_datetime_str = "2024-01-15T10:30"; // without seconds + let result = super::parse_date_time_string(invalid_datetime_str); + assert!( + result.is_err(), + "Should fail to parse invalid datetime string" + ); + } } diff --git a/proxy_agent_shared/src/proxy_agent_aggregate_status.rs b/proxy_agent_shared/src/proxy_agent_aggregate_status.rs index 5a3f2e6a..6a6a2971 100644 --- a/proxy_agent_shared/src/proxy_agent_aggregate_status.rs +++ b/proxy_agent_shared/src/proxy_agent_aggregate_status.rs @@ -3,6 +3,7 @@ use crate::misc_helpers; use serde_derive::{Deserialize, Serialize}; use std::{collections::HashMap, path::PathBuf}; +use time::OffsetDateTime; #[cfg(windows)] const PROXY_AGENT_AGGREGATE_STATUS_FOLDER: &str = "%SYSTEMDRIVE%\\WindowsAzure\\ProxyAgent\\Logs\\"; @@ -88,3 +89,9 @@ pub struct GuestProxyAgentAggregateStatus { pub proxyConnectionSummary: Vec, pub failedAuthenticateSummary: Vec, } + +impl GuestProxyAgentAggregateStatus { + pub fn get_status_timestamp(&self) -> crate::result::Result { + misc_helpers::parse_date_time_string(&self.timestamp) + } +} diff --git a/proxy_agent_shared/src/service.rs 
b/proxy_agent_shared/src/service.rs index 88e57658..53c2cd6f 100644 --- a/proxy_agent_shared/src/service.rs +++ b/proxy_agent_shared/src/service.rs @@ -94,44 +94,46 @@ pub fn update_service( } } -pub fn query_service_executable_path(_service_name: &str) -> PathBuf { +pub fn query_service_executable_path(service_name: &str) -> PathBuf { #[cfg(windows)] { - match windows_service::query_service_config(_service_name) { + match windows_service::query_service_config(service_name) { Ok(service_config) => { - logger_manager::write_info( - format!("Service {_service_name} successfully queried",), - ); + logger_manager::write_info(format!("Service {service_name} successfully queried",)); service_config.executable_path.to_path_buf() } Err(e) => { - logger_manager::write_info(format!("Service {_service_name} query failed: {e}",)); - eprintln!("Service {_service_name} query failed: {e}"); + logger_manager::write_info(format!("Service {service_name} query failed: {e}",)); + eprintln!("Service {service_name} query failed: {e}"); PathBuf::new() } } } #[cfg(not(windows))] { - println!("Not support query service on this platform"); - PathBuf::new() + match linux_service::query_service_executable_path(service_name) { + Ok(path) => path, + Err(e) => { + eprintln!("Service {service_name} query failed: {e}"); + PathBuf::new() + } + } } } -pub fn check_service_installed(_service_name: &str) -> (bool, String) { - let message; +pub fn check_service_installed(service_name: &str) -> (bool, String) { #[cfg(windows)] { - match windows_service::query_service_config(_service_name) { - Ok(_service_config) => { - message = format!( - "check_service_installed: service: {_service_name} successfully queried.", + match windows_service::query_service_config(service_name) { + Ok(_) => { + let message = format!( + "check_service_installed: service: {service_name} successfully queried.", ); (true, message) } Err(e) => { - message = format!( - "check_service_installed: service: {_service_name} 
unsuccessfully queried with error: {e}" + let message = format!( + "check_service_installed: service: {service_name} unsuccessfully queried with error: {e}" ); (false, message) } @@ -139,8 +141,7 @@ pub fn check_service_installed(_service_name: &str) -> (bool, String) { } #[cfg(not(windows))] { - message = "Not support query service on this platform".to_string(); - (false, message) + linux_service::check_service_installed(service_name) } } diff --git a/proxy_agent_shared/src/service/linux_service.rs b/proxy_agent_shared/src/service/linux_service.rs index 84efd3e8..cfa632a4 100644 --- a/proxy_agent_shared/src/service/linux_service.rs +++ b/proxy_agent_shared/src/service/linux_service.rs @@ -1,5 +1,7 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT +use crate::error::CommandErrorType; +use crate::error::Error; use crate::linux; use crate::logger::logger_manager; use crate::misc_helpers; @@ -17,14 +19,23 @@ pub fn stop_service(service_name: &str) -> Result<()> { Ok(()) } +/// Starts the specified service with `systemctl start` command. +/// If the command fails, an Error is returned. pub fn start_service(service_name: &str) -> Result<()> { let output = misc_helpers::execute_command("systemctl", vec!["start", service_name], -1)?; - logger_manager::write_info(format!( - "start_service: {} result: {}", - service_name, - output.message() - )); - Ok(()) + if output.is_success() { + logger_manager::write_info(format!("Service {service_name} started successfully")); + Ok(()) + } else { + let error_message = format!( + "start_service: {service_name} failed with error: {}", + output.message() + ); + Err(Error::Command( + CommandErrorType::CommandName("systemctl start".to_string()), + error_message, + )) + } } pub fn install_or_update_service(service_name: &str) -> Result<()> { @@ -96,3 +107,68 @@ fn delete_service_config_file(service_name: &str) -> Result<()> { } Ok(()) } + +/// Queries the executable path of the specified service. 
+/// It uses systemctl show command to get the ExecStart property. +/// If the command fails or the output cannot be parsed, an Error is returned. +pub fn query_service_executable_path(service_name: &str) -> Result { + let output = misc_helpers::execute_command( + "systemctl", + vec!["show", "--property=ExecStart", service_name], + -1, + )?; + + if !output.is_success() { + let error_message = format!( + "query_service_executable_path: {service_name} failed with error: {}", + output.message() + ); + return Err(Error::Command( + CommandErrorType::CommandName("systemctl show --property=ExecStart".to_string()), + error_message, + )); + } + + let stdout = output.stdout(); + logger_manager::write_info(format!( + "query_service_executable_path: {service_name} result: {stdout}", + )); + + // Parse ExecStart output + // Format: ExecStart={ path=/path/to/executable ; argv[]=/path/to/executable [args] ; ... } + if let Some(path_start) = stdout.find("path=") { + let path_str = &stdout[path_start + 5..]; + if let Some(semicolon_pos) = path_str.find(" ;") { + let executable_path = path_str[..semicolon_pos].trim(); + return Ok(PathBuf::from(executable_path)); + } + } + + let error_message = format!( + "query_service_executable_path: {service_name} failed to parse ExecStart output: {stdout}" + ); + Err(Error::Command( + CommandErrorType::CommandName("systemctl show --property=ExecStart".to_string()), + error_message, + )) +} + +/// Check if the service is installed by verifying the existence of its unit file. 
+pub fn check_service_installed(service_name: &str) -> (bool, String) { + let config_file_path = + PathBuf::from(linux::SERVICE_CONFIG_FOLDER_PATH).join(format!("{service_name}.service")); + + if config_file_path.exists() && config_file_path.is_file() { + let message = + format!("check_service_installed: service: {service_name} successfully queried."); + logger_manager::write_info(message.clone()); + (true, message) + } else { + let message = format!( + "check_service_installed: service: {service_name} unit file not found at '{}'", + misc_helpers::path_to_string(&config_file_path) + ); + logger_manager::write_info(message.clone()); + (false, message) + } +} diff --git a/proxy_agent_shared/src/service/windows_service.rs b/proxy_agent_shared/src/service/windows_service.rs index 975e7803..c2a86de2 100644 --- a/proxy_agent_shared/src/service/windows_service.rs +++ b/proxy_agent_shared/src/service/windows_service.rs @@ -56,7 +56,11 @@ pub async fn start_service_with_retry( } tokio::time::sleep(duration).await; } - Ok(()) + + // reach here means all retries exhausted + Err(Error::Io(std::io::Error::other( + "start_service_with_retry exceeded maximum retry attempts".to_string(), + ))) } async fn start_service_once(service_name: &str) -> Result { diff --git a/proxy_agent_shared/src/telemetry.rs b/proxy_agent_shared/src/telemetry.rs index e7090938..77be8ee0 100644 --- a/proxy_agent_shared/src/telemetry.rs +++ b/proxy_agent_shared/src/telemetry.rs @@ -1,7 +1,9 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT pub mod event_logger; +pub mod event_reader; pub mod span; +pub mod telemetry_event; use crate::misc_helpers; use serde_derive::{Deserialize, Serialize}; @@ -24,7 +26,7 @@ impl Event { Event { EventLevel: level, Message: message, - Version: misc_helpers::get_current_version(), + Version: misc_helpers::get_current_exe_version(), TaskName: task_name, EventPid: std::process::id().to_string(), EventTid: misc_helpers::get_thread_identity(), diff --git 
a/proxy_agent/src/telemetry/event_reader.rs b/proxy_agent_shared/src/telemetry/event_reader.rs similarity index 68% rename from proxy_agent/src/telemetry/event_reader.rs rename to proxy_agent_shared/src/telemetry/event_reader.rs index 86291fef..ac3d526f 100644 --- a/proxy_agent/src/telemetry/event_reader.rs +++ b/proxy_agent_shared/src/telemetry/event_reader.rs @@ -3,59 +3,30 @@ //! This module contains the logic to read the telemetry event files and send them to the wire server. //! The telemetry event files are written by the event_logger module. -//! Example -//! ```rust -//! use proxy_agent::telemetry::event_reader; -//! use proxy_agent::shared_state::agent_status::wrapper::AgentStatusSharedState; -//! use proxy_agent::shared_state::key_keeper::wrapper::KeyKeeperSharedState; -//! use proxy_agent::shared_state::telemetry::wrapper::TelemetrySharedState; -//! use std::path::PathBuf; -//! use std::time::Duration; -//! use tokio_util::sync::CancellationToken; -//! -//! // start the telemetry event reader with the shared state -//! let agent_status_shared_state = AgentStatusSharedState::start_new(); -//! let key_keeper_shared_state = KeyKeeperSharedState::start_new(); -//! let telemetry_shared_state = TelemetrySharedState::start_new(); -//! let cancellation_token = CancellationToken::new(); -//! -//! let dir_path = PathBuf::from("/tmp"); -//! let interval = Some(Duration::from_secs(300)); -//! let delay_start = false; -//! let server_ip = None; -//! let server_port = None; -//! let event_reader = event_reader::EventReader::new( -//! dir_path, -//! delay_start, -//! cancellation_token, -//! key_keeper_shared_state, -//! telemetry_shared_state, -//! agent_status_shared_state, -//! ); -//! -//! tokio::spawn(event_reader.start(interval, server_ip, server_port)); -//! -//! // stop the telemetry event reader -//! cancellation_token.cancel(); -//! 
``` - -use super::telemetry_event::TelemetryData; -use super::telemetry_event::TelemetryEvent; -use crate::common::{constants, logger, result::Result}; -use crate::shared_state::agent_status_wrapper::AgentStatusModule; -use crate::shared_state::agent_status_wrapper::AgentStatusSharedState; -use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; -use crate::shared_state::telemetry_wrapper::TelemetrySharedState; -use proxy_agent_shared::host_clients::imds_client::ImdsClient; -use proxy_agent_shared::host_clients::wire_server_client::WireServerClient; -use proxy_agent_shared::misc_helpers; -use proxy_agent_shared::proxy_agent_aggregate_status::ModuleState; -use proxy_agent_shared::telemetry::Event; + +use crate::common_state; +use crate::common_state::CommonState; +use crate::host_clients::imds_client::ImdsClient; +use crate::host_clients::wire_server_client::WireServerClient; +use crate::logger::logger_manager; +use crate::misc_helpers; +use crate::result::Result; +use crate::telemetry::telemetry_event::TelemetryData; +use crate::telemetry::telemetry_event::TelemetryEvent; +use crate::telemetry::Event; use std::fs::remove_file; use std::path::PathBuf; use std::time::Duration; use tokio_util::sync::CancellationToken; +#[cfg(test)] +const EMPTY_GUID: &str = "00000000-0000-0000-0000-000000000000"; + +const WIRE_SERVER_IP: &str = "168.63.129.16"; +const WIRE_SERVER_PORT: u16 = 80u16; +const IMDS_IP: &str = "169.254.169.254"; +const IMDS_PORT: u16 = 80u16; + /// VmMetaData contains the metadata of the VM. /// The metadata is used to identify the VM and the image origin. /// It will be part of the telemetry data send to the wire server. 
@@ -76,13 +47,13 @@ impl VmMetaData { #[cfg(test)] pub fn empty() -> Self { VmMetaData { - container_id: constants::EMPTY_GUID.to_string(), - tenant_name: constants::EMPTY_GUID.to_string(), - role_name: constants::EMPTY_GUID.to_string(), - role_instance_name: constants::EMPTY_GUID.to_string(), - subscription_id: constants::EMPTY_GUID.to_string(), - resource_group_name: constants::EMPTY_GUID.to_string(), - vm_id: constants::EMPTY_GUID.to_string(), + container_id: EMPTY_GUID.to_string(), + tenant_name: EMPTY_GUID.to_string(), + role_name: EMPTY_GUID.to_string(), + role_instance_name: EMPTY_GUID.to_string(), + subscription_id: EMPTY_GUID.to_string(), + resource_group_name: EMPTY_GUID.to_string(), + vm_id: EMPTY_GUID.to_string(), image_origin: 3, // unknown } } @@ -92,9 +63,9 @@ pub struct EventReader { dir_path: PathBuf, delay_start: bool, cancellation_token: CancellationToken, - key_keeper_shared_state: KeyKeeperSharedState, - telemetry_shared_state: TelemetrySharedState, - agent_status_shared_state: AgentStatusSharedState, + common_state: CommonState, + execution_mode: String, + event_name: String, } impl EventReader { @@ -102,17 +73,17 @@ impl EventReader { dir_path: PathBuf, delay_start: bool, cancellation_token: CancellationToken, - key_keeper_shared_state: KeyKeeperSharedState, - telemetry_shared_state: TelemetrySharedState, - agent_status_shared_state: AgentStatusSharedState, + common_state: CommonState, + execution_mode: String, + event_name: String, ) -> EventReader { EventReader { dir_path, delay_start, cancellation_token, - key_keeper_shared_state, - telemetry_shared_state, - agent_status_shared_state, + common_state, + execution_mode, + event_name, } } @@ -122,23 +93,22 @@ impl EventReader { server_ip: Option<&str>, server_port: Option, ) { - logger::write_information("telemetry event reader task started.".to_string()); + logger_manager::write_info("telemetry event reader task started.".to_string()); let wire_server_client = WireServerClient::new( - 
server_ip.unwrap_or(constants::WIRE_SERVER_IP), - server_port.unwrap_or(constants::WIRE_SERVER_PORT), + server_ip.unwrap_or(WIRE_SERVER_IP), + server_port.unwrap_or(WIRE_SERVER_PORT), ); let imds_client = ImdsClient::new( - server_ip.unwrap_or(constants::IMDS_IP), - server_port.unwrap_or(constants::IMDS_PORT), + server_ip.unwrap_or(IMDS_IP), + server_port.unwrap_or(IMDS_PORT), ); let interval = interval.unwrap_or(Duration::from_secs(300)); tokio::select! { _ = self.loop_reader(interval, wire_server_client, imds_client ) => {} _ = self.cancellation_token.cancelled() => { - logger::write_warning("cancellation token signal received, stop the telemetry event reader task.".to_string()); - self.stop().await; + logger_manager::write_warn("cancellation token signal received, stop the telemetry event reader task.".to_string()); } } } @@ -166,18 +136,21 @@ impl EventReader { .await { Ok(()) => { - logger::write("success updated the vm metadata.".to_string()); + logger_manager::write_info("success updated the vm metadata.".to_string()); } Err(e) => { - logger::write_warning(format!("Failed to read vm metadata with error {e}.")); + logger_manager::write_warn(format!( + "Failed to read vm metadata with error {e}." 
+ )); } } - if let Ok(Some(vm_meta_data)) = self.telemetry_shared_state.get_vm_meta_data().await { + if let Ok(Some(vm_meta_data)) = self.common_state.get_vm_meta_data().await { let _processed = self .process_events(&wire_server_client, &vm_meta_data) .await; } + tokio::time::sleep(interval).await; } } @@ -198,10 +171,10 @@ impl EventReader { let message = format!( "Telemetry event reader sent {event_count} events from {file_count} files" ); - logger::write(message); + logger_manager::write_info(message); } Err(e) => { - logger::write_warning(format!( + logger_manager::write_warn(format!( "Event Files not found in directory {}: {}", self.dir_path.display(), e @@ -212,26 +185,19 @@ impl EventReader { event_count } - async fn stop(&self) { - let _ = self - .agent_status_shared_state - .set_module_state(ModuleState::STOPPED, AgentStatusModule::TelemetryReader) - .await; - } - async fn update_vm_meta_data( &self, wire_server_client: &WireServerClient, imds_client: &ImdsClient, ) -> Result<()> { let guid = self - .key_keeper_shared_state - .get_current_key_guid() + .common_state + .get_state(common_state::SECURE_KEY_GUID.to_string()) .await .unwrap_or(None); let key = self - .key_keeper_shared_state - .get_current_key_value() + .common_state + .get_state(common_state::SECURE_KEY_VALUE.to_string()) .await .unwrap_or(None); let goal_state = wire_server_client @@ -259,11 +225,10 @@ impl EventReader { image_origin: instance_info.get_image_origin(), }; - self.telemetry_shared_state - .set_vm_meta_data(Some(vm_meta_data.clone())) + self.common_state + .set_vm_meta_data(Some(vm_meta_data)) .await?; - logger::write(format!("Updated VM Metadata: {vm_meta_data:?}")); Ok(()) } @@ -278,10 +243,11 @@ impl EventReader { match misc_helpers::json_read_from_file::>(&file) { Ok(events) => { num_events_logged += events.len(); - Self::send_events(events, wire_server_client, vm_meta_data).await; + self.send_events(events, wire_server_client, vm_meta_data) + .await; } Err(e) => { - 
logger::write_warning(format!( + logger_manager::write_warn(format!( "Failed to read events from file {}: {}", file.display(), e @@ -295,6 +261,7 @@ impl EventReader { const MAX_MESSAGE_SIZE: usize = 1024 * 64; async fn send_events( + &self, mut events: Vec, wire_server_client: &WireServerClient, vm_meta_data: &VmMetaData, @@ -308,6 +275,8 @@ impl EventReader { telemetry_data.add_event(TelemetryEvent::from_event_log( &event, vm_meta_data.clone(), + self.execution_mode.clone(), + self.event_name.clone(), )); if telemetry_data.get_size() >= Self::MAX_MESSAGE_SIZE { @@ -315,12 +284,12 @@ impl EventReader { if telemetry_data.event_count() == 0 { match serde_json::to_string(&event) { Ok(json) => { - logger::write_warning(format!( + logger_manager::write_warn(format!( "Event data too large. Not sending to wire-server. Event: {json}.", )); } Err(_) => { - logger::write_warning( + logger_manager::write_warn( "Event data too large. Not sending to wire-server. Event cannot be displayed.".to_string() ); } @@ -358,7 +327,7 @@ impl EventReader { break; } Err(e) => { - logger::write_warning(format!( + logger_manager::write_warn(format!( "Failed to send telemetry data to host with error: {e}" )); // wait 15 seconds and retry @@ -371,17 +340,21 @@ impl EventReader { fn clean_files(file: PathBuf) { match remove_file(&file) { Ok(_) => { - logger::write(format!("Removed File: {}", file.display())); + logger_manager::write_info(format!("Removed File: {}", file.display())); } Err(e) => { - logger::write_warning(format!("Failed to remove file {}: {}", file.display(), e)); + logger_manager::write_warn(format!( + "Failed to remove file {}: {}", + file.display(), + e + )); } } } #[cfg(test)] async fn get_vm_meta_data(&self) -> VmMetaData { - if let Ok(Some(vm_meta_data)) = self.telemetry_shared_state.get_vm_meta_data().await { + if let Ok(Some(vm_meta_data)) = self.common_state.get_vm_meta_data().await { vm_meta_data } else { VmMetaData::empty() @@ -392,10 +365,8 @@ impl EventReader { 
#[cfg(test)] mod tests { use super::*; - use crate::common::logger; - use crate::key_keeper::key::Key; - use proxy_agent_shared::misc_helpers; - use proxy_agent_shared::server_mock; + use crate::misc_helpers; + use crate::server_mock; use std::{env, fs}; #[tokio::test] @@ -411,39 +382,34 @@ mod tests { let ip = "127.0.0.1"; let port = 7071u16; let cancellation_token = CancellationToken::new(); - let key_keeper_shared_state = KeyKeeperSharedState::start_new(); + let common_state = CommonState::start_new(); let event_reader = EventReader { dir_path: events_dir.clone(), delay_start: false, - key_keeper_shared_state: key_keeper_shared_state.clone(), - telemetry_shared_state: TelemetrySharedState::start_new(), cancellation_token: cancellation_token.clone(), - agent_status_shared_state: AgentStatusSharedState::start_new(), + common_state: common_state.clone(), + execution_mode: "Test".to_string(), + event_name: "test_event_reader_thread".to_string(), }; let wire_server_client = WireServerClient::new(ip, port); let imds_client = ImdsClient::new(ip, port); - - key_keeper_shared_state - .update_key(Key::empty()) - .await - .unwrap(); tokio::spawn(server_mock::start( ip.to_string(), port, cancellation_token.clone(), )); tokio::time::sleep(Duration::from_millis(100)).await; - logger::write("server_mock started.".to_string()); + logger_manager::write_info("server_mock started.".to_string()); match event_reader .update_vm_meta_data(&wire_server_client, &imds_client) .await { Ok(()) => { - logger::write("success updated the vm metadata.".to_string()); + logger_manager::write_info("success updated the vm metadata.".to_string()); } Err(e) => { - logger::write_warning(format!("Failed to read vm metadata with error {}.", e)); + logger_manager::write_warn(format!("Failed to read vm metadata with error {}.", e)); } } @@ -458,7 +424,7 @@ mod tests { "test_deserialize_events_from_file".to_string(), )); } - logger::write("10 events created.".to_string()); + logger_manager::write_info("10 
events created.".to_string()); misc_helpers::try_create_folder(&events_dir).unwrap(); let mut file_path = events_dir.to_path_buf(); file_path.push(format!("{}.json", misc_helpers::get_date_time_unix_nano())); @@ -469,7 +435,7 @@ mod tests { let events_processed = event_reader .process_events(&wire_server_client, &vm_meta_data) .await; - logger::write(format!("Send {} events from event files", events_processed)); + logger_manager::write_info(format!("Send {} events from event files", events_processed)); //Should be 10 events written and read into events Vector assert_eq!(events_processed, 10, "Events processed should be 10"); let files = misc_helpers::get_files(&events_dir).unwrap(); diff --git a/proxy_agent/src/telemetry/telemetry_event.rs b/proxy_agent_shared/src/telemetry/telemetry_event.rs similarity index 67% rename from proxy_agent/src/telemetry/telemetry_event.rs rename to proxy_agent_shared/src/telemetry/telemetry_event.rs index 84adee44..79149d70 100644 --- a/proxy_agent/src/telemetry/telemetry_event.rs +++ b/proxy_agent_shared/src/telemetry/telemetry_event.rs @@ -2,54 +2,11 @@ // SPDX-License-Identifier: MIT //! This module contains the logic to generate the telemetry data to be send to wire server. -//! Example -//! ```rust -//! use proxy_agent::telemetry::TelemetryData; -//! -//! // Create the telemetry data -//! let mut telemetry_data = TelemetryData::new(); -//! -//! // Add the event to the telemetry data -//! let event_log = Event { -//! EventPid: "123".to_string(), -//! EventTid: "456".to_string(), -//! Version: "1.0".to_string(), -//! TaskName: "TaskName".to_string(), -//! TimeStamp: "2024-09-04T02:00:00.222z".to_string(), -//! EventLevel: "Info".to_string(), -//! Message: "Message".to_string(), -//! OperationId: "OperationId".to_string(), -//! }; -//! let vm_meta_data = VmMetaData { -//! container_id: "container_id".to_string(), -//! tenant_name: "tenant_name".to_string(), -//! role_name: "role_name".to_string(), -//! 
role_instance_name: "role_instance_name".to_string(), -//! subscription_id: "subscription_id".to_string(), -//! resource_group_name: "resource_group_name".to_string(), -//! vm_id: "vm_id".to_string(), -//! image_origin: 1, -//! }; -//! let event = TelemetryEvent::from_event_log(&event_log, vm_meta_data); -//! telemetry_data.add_event(event); -//! -//! // Get the size of the telemetry data -//! let size = telemetry_data.get_size(); -//! -//! // Get the xml of the telemetry data -//! let xml = telemetry_data.to_xml(); -//! -//! // Remove the last event from the telemetry data -//! let event = telemetry_data.remove_last_event(); -//! -//! // Get the event count of the telemetry data -//! let count = telemetry_data.event_count(); -//! ``` use super::event_reader::VmMetaData; -use crate::common::helpers; +use crate::telemetry::Event; +use crate::{current_info, misc_helpers}; use once_cell::sync::Lazy; -use proxy_agent_shared::telemetry::Event; use serde_derive::{Deserialize, Serialize}; /// TelemetryData struct to hold the telemetry events send to wire server. 
@@ -129,7 +86,12 @@ pub struct TelemetryEvent { } impl TelemetryEvent { - pub fn from_event_log(event_log: &Event, vm_meta_data: VmMetaData) -> Self { + pub fn from_event_log( + event_log: &Event, + vm_meta_data: VmMetaData, + execution_mode: String, + event_name: String, + ) -> Self { TelemetryEvent { event_pid: event_log.EventPid.parse::().unwrap_or(0), event_tid: event_log.EventTid.parse::().unwrap_or(0), @@ -141,12 +103,12 @@ impl TelemetryEvent { context2: event_log.TimeStamp.to_string(), context3: event_log.OperationId.to_string(), - execution_mode: "ProxyAgent".to_string(), - event_name: "MicrosoftAzureGuestProxyAgent".to_string(), - os_version: helpers::get_long_os_version(), + execution_mode, + event_name, + os_version: current_info::get_long_os_version(), keyword_name: CURRENT_KEYWORD_NAME.to_string(), - ram: helpers::get_ram_in_mb(), - processors: helpers::get_cpu_count() as u64, + ram: current_info::get_ram_in_mb(), + processors: current_info::get_cpu_count() as u64, container_id: vm_meta_data.container_id, tenant_name: vm_meta_data.tenant_name, @@ -165,43 +127,43 @@ impl TelemetryEvent { xml.push_str(&format!( "", - helpers::xml_escape(self.opcode_name.to_string()) + misc_helpers::xml_escape(self.opcode_name.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.keyword_name.to_string()) + misc_helpers::xml_escape(self.keyword_name.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.task_name.to_string()) + misc_helpers::xml_escape(self.task_name.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.tenant_name.to_string()) + misc_helpers::xml_escape(self.tenant_name.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.role_name.to_string()) + misc_helpers::xml_escape(self.role_name.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.role_instance_name.to_string()) + misc_helpers::xml_escape(self.role_instance_name.to_string()) )); xml.push_str(&format!( "", - 
helpers::xml_escape(self.container_id.to_string()) + misc_helpers::xml_escape(self.container_id.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.resource_group_name.to_string()) + misc_helpers::xml_escape(self.resource_group_name.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.subscription_id.to_string()) + misc_helpers::xml_escape(self.subscription_id.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.vm_id.to_string()) + misc_helpers::xml_escape(self.vm_id.to_string()) )); xml.push_str(&format!( "", @@ -218,15 +180,15 @@ impl TelemetryEvent { xml.push_str(&format!( "", - helpers::xml_escape(self.execution_mode.to_string()) + misc_helpers::xml_escape(self.execution_mode.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.os_version.to_string()) + misc_helpers::xml_escape(self.os_version.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.ga_version.to_string()) + misc_helpers::xml_escape(self.ga_version.to_string()) )); xml.push_str(&format!( "", @@ -239,23 +201,23 @@ impl TelemetryEvent { xml.push_str(&format!( "", - helpers::xml_escape(self.event_name.to_string()) + misc_helpers::xml_escape(self.event_name.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.capability_used.to_string()) + misc_helpers::xml_escape(self.capability_used.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.context1.to_string()) + misc_helpers::xml_escape(self.context1.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.context2.to_string()) + misc_helpers::xml_escape(self.context2.to_string()) )); xml.push_str(&format!( "", - helpers::xml_escape(self.context3.to_string()) + misc_helpers::xml_escape(self.context3.to_string()) )); xml.push_str("]]>"); @@ -264,7 +226,7 @@ impl TelemetryEvent { } static CURRENT_KEYWORD_NAME: Lazy = - Lazy::new(|| KeywordName::new(helpers::get_cpu_arch()).to_json()); + Lazy::new(|| 
KeywordName::new(current_info::get_cpu_arch()).to_json()); #[derive(Serialize, Deserialize)] #[allow(non_snake_case)] diff --git a/proxy_agent_shared/src/windows.rs b/proxy_agent_shared/src/windows.rs index b4b40a71..dd49d5d1 100644 --- a/proxy_agent_shared/src/windows.rs +++ b/proxy_agent_shared/src/windows.rs @@ -10,6 +10,7 @@ use std::os::windows::ffi::OsStrExt; use std::path::Path; use windows_service::service::{ServiceAccess, ServiceState}; use windows_service::service_manager::{ServiceManager, ServiceManagerAccess}; +use windows_sys::Win32::Foundation::{CloseHandle, FALSE, HANDLE}; use windows_sys::Win32::Security::Cryptography::{ //bcrypt.dll functions BCryptCreateHash, @@ -24,7 +25,24 @@ use windows_sys::Win32::Storage::FileSystem::{ VerQueryValueW, VS_FIXEDFILEINFO, }; -use windows_sys::Win32::System::SystemInformation::SYSTEM_INFO; +use windows_sys::Win32::System::JobObjects::{ + AssignProcessToJobObject, CreateJobObjectW, IsProcessInJob, JobObjectCpuRateControlInformation, + JobObjectExtendedLimitInformation, SetInformationJobObject, + JOBOBJECT_CPU_RATE_CONTROL_INFORMATION, JOBOBJECT_CPU_RATE_CONTROL_INFORMATION_0, + JOBOBJECT_CPU_RATE_CONTROL_INFORMATION_0_0, JOBOBJECT_EXTENDED_LIMIT_INFORMATION, + JOB_OBJECT_CPU_RATE_CONTROL_ENABLE, JOB_OBJECT_CPU_RATE_CONTROL_MIN_MAX_RATE, + JOB_OBJECT_LIMIT_PROCESS_MEMORY, JOB_OBJECT_LIMIT_WORKINGSET, +}; +use windows_sys::Win32::System::SystemInformation::{ + GetSystemInfo, // kernel32.dll + GlobalMemoryStatusEx, // kernel32.dll + MEMORYSTATUSEX, + SYSTEM_INFO, +}; +use windows_sys::Win32::System::Threading::{ + OpenProcess, PROCESS_ACCESS_RIGHTS, PROCESS_QUERY_INFORMATION, PROCESS_SET_QUOTA, + PROCESS_TERMINATE, +}; use winreg::enums::*; use winreg::RegKey; @@ -206,6 +224,30 @@ pub fn get_processor_arch() -> String { } } +pub fn get_processor_count() -> usize { + let mut data = MaybeUninit::::uninit(); + unsafe { GetSystemInfo(data.as_mut_ptr()) }; + + let data = unsafe { data.assume_init() }; + 
data.dwNumberOfProcessors as usize +} + +pub fn get_memory_in_mb() -> Result { + let mut data = MaybeUninit::::uninit(); + let data = data.as_mut_ptr(); + unsafe { + (*data).dwLength = std::mem::size_of::() as u32; + if GlobalMemoryStatusEx(data) == 0 { + return Err(Error::WindowsApi( + "GlobalMemoryStatusEx".to_string(), + std::io::Error::last_os_error(), + )); + } + let memory_in_mb = (*data).ullTotalPhys / 1024 / 1024; + Ok(memory_in_mb) + } +} + pub fn ensure_service_running(service_name: &str) -> (bool, String) { let mut message = String::new(); let service_manager = @@ -401,6 +443,179 @@ pub fn compute_signature(hex_encoded_key: &str, input_to_sign: &[u8]) -> Result< } } +/// Set CPU quota for a process +/// # Arguments +/// * `process_id` - Process ID +/// * `cpu_percent` - CPU quota percentage (0-100) +/// # Returns +/// * `Result` - Job object handle +/// The job object handle must remain valid for the lifetime of the job. +/// The caller is responsible for closing the job object handle when it is no longer needed. 
+/// # Errors +/// * `Error::Io` - If the cpu_percent is invalid +/// * `Error::WindowsApi` - If any Windows API call fails +pub fn set_resource_limits(process_id: u32, cpu_percent: u16, ram_in_mb: usize) -> Result { + if cpu_percent == 0 || cpu_percent > 100 { + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "CPU quota percent must be between 1 and 100", + ))); + } + + // create job object + let job_object = unsafe { CreateJobObjectW(std::ptr::null(), std::ptr::null()) }; + if job_object == 0 { + return Err(Error::WindowsApi( + "CreateJobObjectW".to_string(), + std::io::Error::last_os_error(), + )); + } + + // Configure the CPU cap first + let mut cpu = JOBOBJECT_CPU_RATE_CONTROL_INFORMATION { + ControlFlags: JOB_OBJECT_CPU_RATE_CONTROL_ENABLE | JOB_OBJECT_CPU_RATE_CONTROL_MIN_MAX_RATE, + Anonymous: JOBOBJECT_CPU_RATE_CONTROL_INFORMATION_0 { + Anonymous: JOBOBJECT_CPU_RATE_CONTROL_INFORMATION_0_0 { + // Per MSDN: Set CpuRate to a percentage times 100. For example, to let the job use 20% of the CPU, set CpuRate to 20 times 100, or 2,000. 
+ MinRate: 300, + MaxRate: cpu_percent * 100, + }, + }, + }; + let ok = unsafe { + SetInformationJobObject( + job_object, + JobObjectCpuRateControlInformation, + &mut cpu as *mut _ as *mut _, + std::mem::size_of::() as u32, + ) + }; + if ok == 0 { + _ = close_handler(job_object); + return Err(Error::WindowsApi( + "SetInformationJobObject - JOBOBJECT_CPU_RATE_CONTROL_INFORMATION".to_string(), + std::io::Error::last_os_error(), + )); + } + + // Configure the memory limit + if ram_in_mb > 0 { + let mut ext: JOBOBJECT_EXTENDED_LIMIT_INFORMATION = unsafe { std::mem::zeroed() }; + ext.BasicLimitInformation.LimitFlags = + JOB_OBJECT_LIMIT_PROCESS_MEMORY | JOB_OBJECT_LIMIT_WORKINGSET; + let ram = ram_in_mb * 1024 * 1024; // Convert MB to bytes + ext.BasicLimitInformation.MaximumWorkingSetSize = ram; + ext.BasicLimitInformation.MinimumWorkingSetSize = ram / 8; + // Set the maximum amount of committed virtual memory too. + ext.ProcessMemoryLimit = ram * 4; + let ok = unsafe { + SetInformationJobObject( + job_object, + JobObjectExtendedLimitInformation, + &mut ext as *mut _ as *mut _, + std::mem::size_of::() as u32, + ) + }; + if ok == 0 { + _ = close_handler(job_object); + return Err(Error::WindowsApi( + "SetInformationJobObject - JOBOBJECT_EXTENDED_LIMIT_INFORMATION".to_string(), + std::io::Error::last_os_error(), + )); + } + } + + // Open the target process with sufficient rights + // The handle must have the PROCESS_SET_QUOTA and PROCESS_TERMINATE access rights. 
+ let process_handle = get_process_handler( + process_id, + PROCESS_QUERY_INFORMATION | PROCESS_SET_QUOTA | PROCESS_TERMINATE, + )?; + + // Check if process is already in a job + let mut in_job: i32 = 0; // BOOL + if unsafe { IsProcessInJob(process_handle, 0, &mut in_job as *mut i32) } == 0 { + let e = std::io::Error::last_os_error(); + _ = close_handler(process_handle); + _ = close_handler(job_object); + return Err(Error::WindowsApi("IsProcessInJob".to_string(), e)); + } + if in_job != 0 { + // Already in a job -> likely cause of ERROR_ACCESS_DENIED + _ = close_handler(process_handle); + _ = close_handler(job_object); + return Err(Error::WindowsApi( + "IsProcessInJob".to_string(), + std::io::Error::new( + std::io::ErrorKind::PermissionDenied, + "Target process is already in a job; cannot assign to a new job", + ), + )); + } + + // Assign the process to the job object + let ok = unsafe { AssignProcessToJobObject(job_object, process_handle) }; + let err = std::io::Error::last_os_error(); + _ = close_handler(process_handle); + if ok == 0 { _ = close_handler(job_object); // release the job handle on failure to avoid leaking it + return Err(Error::WindowsApi( + "AssignProcessToJobObject".to_string(), + err, + )); + } + + // Do NOT close job_object here; it must stay open for as long as the limits should remain in effect. + Ok(job_object) +} + +/// Get process handler by pid +/// # Arguments +/// * `pid` - Process ID +/// # Returns +/// * `Result<HANDLE>` - Process handler +/// # Errors +/// * `Error::Invalid` - If the pid is 0 +/// * `Error::WindowsApi` - If the OpenProcess call fails +/// # Safety +/// This function is safe to call as it does not dereference any raw pointers. +/// However, the caller is responsible for closing the process handler using `close_handler` +/// when it is no longer needed to avoid resource leaks.
+pub fn get_process_handler(pid: u32, options: PROCESS_ACCESS_RIGHTS) -> Result<HANDLE> { + if pid == 0 { + return Err(Error::Io(std::io::Error::new( + std::io::ErrorKind::InvalidInput, + "Process ID cannot be 0", + ))); + } + // https://learn.microsoft.com/en-us/windows/win32/api/processthreadsapi/nf-processthreadsapi-openprocess + let handler = unsafe { OpenProcess(options, FALSE, pid) }; + if handler == 0 { + return Err(Error::WindowsApi( + "OpenProcess".to_string(), + std::io::Error::last_os_error(), + )); + } + Ok(handler) +} + +/// Close process handler +/// # Arguments +/// * `handler` - Process handler +/// # Returns +/// * `Result<()>` - Ok if successful, Err if failed +pub fn close_handler(handler: HANDLE) -> Result<()> { + if handler != 0 { + // https://learn.microsoft.com/en-us/windows/win32/api/handleapi/nf-handleapi-closehandle + // CloseHandle returns nonzero on success and zero on failure, so zero is the error path. + if 0 == unsafe { CloseHandle(handler) } { + return Err(Error::WindowsApi( + "CloseHandle".to_string(), + std::io::Error::last_os_error(), + )); + } + } + Ok(()) +} + #[cfg(test)] mod tests { @@ -467,4 +682,19 @@ // Clean up super::remove_reg_key(key_name).unwrap(); } + + #[test] + fn set_resource_limits_test() { + let pid = std::process::id(); + match super::set_resource_limits(pid, 100, 10) { + Ok(job_handle) => { + // Close the job handle + super::close_handler(job_handle).unwrap(); + } + Err(e) => { + // Print the error but do not fail the test as setting resource limits may fail due to the test environment already under a job object + println!("Failed to set resource limits: {}", e); + } + } + } }