diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt
index d930bd2e..c8a46ad2 100644
--- a/.github/actions/spelling/expect.txt
+++ b/.github/actions/spelling/expect.txt
@@ -58,6 +58,7 @@ daddr
datacenter
davidanson
DDCE
+DDTHH
debbuild
DEBHELPER
debian
@@ -145,6 +146,8 @@ INVM
Ioctl
iusr
jetbrains
+JOBOBJECT
jobsjob
joutvhu
+jqlang
JScript
@@ -175,6 +178,7 @@ MFC
microsoftcblmariner
microsoftlosangeles
microsoftwindowsdesktop
+mmm
mnt
msasn
msp
@@ -270,6 +274,7 @@ spellright
splitn
SRPMS
SSRF
+SSZ
stackoverflow
stdbool
stdint
@@ -320,10 +325,12 @@ vflji
vhd
vmagentlog
VMGA
+vmhwm
VMId
vmlinux
vmr
vmrcs
+vmrss
vms
vns
VTeam
@@ -351,6 +358,7 @@ wireserverand
wireserverandimds
WMI
workarounds
WORKDIR
+WORKINGSET
WScript
wsf
diff --git a/.github/workflows/reusable-build.yml b/.github/workflows/reusable-build.yml
index 0c409b0d..750b320d 100644
--- a/.github/workflows/reusable-build.yml
+++ b/.github/workflows/reusable-build.yml
@@ -44,7 +44,7 @@ jobs:
- name: rust-toolchain
uses: actions-rs/toolchain@v1.0.6
with:
- toolchain: 1.69.0
+ toolchain: stable
- name: Install llvm Code Coverage
uses: taiki-e/install-action@cargo-llvm-cov
@@ -198,7 +198,7 @@ jobs:
- name: rust-toolchain
uses: actions-rs/toolchain@v1.0.6
with:
- toolchain: 1.69.0
+ toolchain: stable
- name: Run Build.cmd Debug arm64
run: .\build.cmd debug arm64
diff --git a/Cargo.lock b/Cargo.lock
index 875c9875..1b252ea4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4,7 +4,7 @@ version = 4
[[package]]
name = "ProxyAgentExt"
-version = "1.0.38"
+version = "1.0.39"
dependencies = [
"clap",
"ctor",
@@ -172,7 +172,7 @@ dependencies = [
[[package]]
name = "azure-proxy-agent"
-version = "1.0.38"
+version = "1.0.39"
dependencies = [
"aya",
"bitflags",
@@ -925,7 +925,7 @@ dependencies = [
[[package]]
name = "proxy_agent_setup"
-version = "1.0.38"
+version = "1.0.39"
dependencies = [
"clap",
"proxy_agent_shared",
@@ -937,7 +937,7 @@ dependencies = [
[[package]]
name = "proxy_agent_shared"
-version = "1.0.38"
+version = "1.0.39"
dependencies = [
"chrono",
"concurrent-queue",
@@ -957,6 +957,7 @@ dependencies = [
"serde-xml-rs",
"serde_derive",
"serde_json",
+ "sysinfo",
"thiserror",
"thread-id",
"time",
diff --git a/build-linux.sh b/build-linux.sh
index 5a3094c5..ae4cf0fd 100755
--- a/build-linux.sh
+++ b/build-linux.sh
@@ -60,7 +60,7 @@ echo "======= BuildEnvironment is $BuildEnvironment"
echo "======= rustup update to a particular version"
-rustup_version=1.88.0
+rustup_version=1.92.0
rustup update $rustup_version
# This command sets a specific Rust toolchain version for the current directory.
diff --git a/build.cmd b/build.cmd
index ae93d5d7..57e2a790 100644
--- a/build.cmd
+++ b/build.cmd
@@ -55,7 +55,7 @@ if "%Target%"=="arm64" (
)
echo ======= rustup update to a particular version of the Rust toolchain
-SET rustup_version=1.88.0
+SET rustup_version=1.92.0
call rustup update %rustup_version%
REM This command sets a specific Rust toolchain version for the current directory.
REM It means that whenever you are in this directory, Rust commands will use the specified toolchain version, regardless of the global default.
diff --git a/e2etest/GuestProxyAgentTest/GuestProxyAgentTest.csproj b/e2etest/GuestProxyAgentTest/GuestProxyAgentTest.csproj
index 1e27c879..9575c6bb 100644
--- a/e2etest/GuestProxyAgentTest/GuestProxyAgentTest.csproj
+++ b/e2etest/GuestProxyAgentTest/GuestProxyAgentTest.csproj
@@ -9,12 +9,13 @@
private preview feed for Azure SDK for .NET packages
https://pkgs.dev.azure.com/azure-sdk/public/_packaging/azure-sdk-for-net/nuget/v3/index.json
-->
+ false
-
-
+
+
diff --git a/e2etest/GuestProxyAgentTest/LinuxScripts/GuestProxyAgentValidation.sh b/e2etest/GuestProxyAgentTest/LinuxScripts/GuestProxyAgentValidation.sh
index 4b5d05fa..1ea12b8d 100755
--- a/e2etest/GuestProxyAgentTest/LinuxScripts/GuestProxyAgentValidation.sh
+++ b/e2etest/GuestProxyAgentTest/LinuxScripts/GuestProxyAgentValidation.sh
@@ -4,8 +4,11 @@
# SPDX-License-Identifier: MIT
customOutputJsonUrl=$(echo $customOutputJsonSAS | base64 -d)
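+# expectedSecureChannelState is supplied by the validation test case's run-command parameter list (see GuestProxyAgentValidationCase)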
+expectedSecureChannelState=$(echo $expectedSecureChannelState)
echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - Start Guest Proxy Agent Validation"
+echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - expectedSecureChannelState=$expectedSecureChannelState"
+
currentDir=$(pwd)
customOutputJsonPath=$currentDir/proxyagentvalidation.json
@@ -47,12 +50,104 @@ else
echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - logdir does not exist"
fi
+echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - detecting os and installing jq"
+os=$(hostnamectl | grep "Operating System")
+echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - os=$os"
+if [[ $os == *"Ubuntu"* ]]; then
+ for i in {1..3}; do
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - start installing jq via apt-get $i"
+        sudo apt-get update
+ sudo apt-get install -y jq
+ sleep 10
+ install=$(apt list --installed jq)
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - install=$install"
+ if [[ $install == *"jq"* ]]; then
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - jq installed successfully"
+ break
+ fi
+ done
+elif [[ $os == *"SUSE"* ]]; then
+    # SUSE repos do not provide a jq package, so we download the jq binary directly
+ # Detect architecture
+ arch=$(uname -m)
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - detected architecture: $arch"
+ if [[ $arch == "x86_64" ]]; then
+ jq_binary="jq-linux-amd64"
+ elif [[ $arch == "aarch64" ]] || [[ $arch == "arm64" ]]; then
+ jq_binary="jq-linux-arm64"
+ else
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - unsupported architecture: $arch"
+ jq_binary="jq-linux64"
+ fi
+ # Download jq binary directly from GitHub
+ for i in {1..3}; do
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - downloading $jq_binary binary (attempt $i)"
+        # use -f so curl fails on HTTP errors instead of saving an error page as the jq binary
+        sudo curl -fL https://github.com/jqlang/jq/releases/download/jq-1.8.1/$jq_binary -o /usr/local/bin/jq
+ if [ $? -eq 0 ]; then
+ sudo chmod +x /usr/local/bin/jq
+ if command -v jq &> /dev/null; then
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - jq installed successfully"
+ jq --version
+ break
+ fi
+ fi
+ sleep 5
+ done
+else
+ for i in {1..3}; do
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - start installing jq via dnf $i"
+ sudo dnf -y install jq
+ sleep 10
+ install=$(dnf list --installed jq)
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - install=$install"
+ if [[ $install == *"jq"* ]]; then
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - jq installed successfully"
+ break
+ fi
+ done
+fi
+
+# check status.json file Content
+## check timestamp of last entry in status.json file
+## check the secure channel status
+timeout=300
+elapsed=0
+statusFile=$logdir/status.json
+secureChannelState=""
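+# the loop below sleeps 3 seconds per iteration and advances elapsed by 3, enforcing the 300-second timeout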
+
+# Current UTC time in epoch seconds
+currentUtcTime=$(date -u +%s)
+echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - Checking GPA status file $statusFile with 5 minute timeout"
+while :; do
+    timestamp=$(jq -r '.timestamp' "$statusFile")
+ # Convert timestamp to epoch seconds
+ timestampEpoch=$(date -u -d "$timestamp" +%s)
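+    # status.json timestamps are ISO 8601 UTC (yyyy-MM-ddTHH:mm:ss[.mmm]Z), which GNU date -d parses into epoch seconds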
+ if ((timestampEpoch > currentUtcTime)); then
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - The last entry timestamp '$timestamp' is valid."
+ ## check secure channel status
+        secureChannelState=$(jq -r '.proxyAgentStatus.keyLatchStatus.states.secureChannelState' "$statusFile")
+ if [[ "$secureChannelState" == "$expectedSecureChannelState" ]]; then
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - The secure channel status '$secureChannelState' matches the expected state: '$expectedSecureChannelState'."
+ break
+ else
+ echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - The secure channel status '$secureChannelState' does not match the expected state: '$expectedSecureChannelState'."
+ fi
+ fi
+ ((elapsed += 3))
+ if [[ $elapsed -ge $timeout ]]; then
+        echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - Timeout reached. Error: the secureChannelState is '$secureChannelState'."
+ break
+ fi
+ sleep 3
+done
+
echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - guestProxyAgentServiceExist=$guestProxyAgentServiceExist"
echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - guestProxyAgentServiceStatus=$guestProxyAgentServiceStatus"
echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - guestProxyProcessStarted=$guestProxyProcessStarted"
echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - guestProxyAgentLogGenerated=$guestProxyAgentLogGenerated"
+echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - secureChannelState=$secureChannelState"
-jsonString='{"guestProxyAgentServiceInstalled": "'$guestProxyAgentServiceExist'", "guestProxyAgentServiceStatus": "'$guestProxyAgentServiceStatus'", "guestProxyProcessStarted": "'$guestProxyProcessStarted'", "guestProxyAgentLogGenerated": "'$guestProxyAgentLogGenerated'"}'
+jsonString='{"guestProxyAgentServiceInstalled": "'$guestProxyAgentServiceExist'", "guestProxyAgentServiceStatus": "'$guestProxyAgentServiceStatus'", "guestProxyProcessStarted": "'$guestProxyProcessStarted'", "secureChannelState": "'$secureChannelState'", "guestProxyAgentLogGenerated": "'$guestProxyAgentLogGenerated'"}'
echo "$(date -u +"%Y-%m-%dT%H:%M:%SZ") - $jsonString"
# write to $customOutputJsonPath
diff --git a/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentExtensionValidation.ps1 b/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentExtensionValidation.ps1
index e3f7c474..e2362587 100644
--- a/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentExtensionValidation.ps1
+++ b/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentExtensionValidation.ps1
@@ -3,7 +3,7 @@
param (
[Parameter(Mandatory=$true, Position=0)]
- [string]$customOutputJsonSAS,
+ [string]$customOutputJsonSAS,
[string]$expectedProxyAgentVersion
)
Write-Output "$((Get-Date).ToUniversalTime()) - expectedProxyAgentVersion=$expectedProxyAgentVersion"
diff --git a/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentValidation.ps1 b/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentValidation.ps1
index a8037d01..22c9c90b 100644
--- a/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentValidation.ps1
+++ b/e2etest/GuestProxyAgentTest/Scripts/GuestProxyAgentValidation.ps1
@@ -3,13 +3,15 @@
param (
[Parameter(Mandatory=$true, Position=0)]
- [string]$customOutputJsonSAS
+ [string]$customOutputJsonSAS,
+ [string]$expectedSecureChannelState
)
$decodedUrlBytes = [System.Convert]::FromBase64String($customOutputJsonSAS)
$decodedUrlString = [System.Text.Encoding]::UTF8.GetString($decodedUrlBytes)
Write-Output "$((Get-Date).ToUniversalTime()) - Start Guest Proxy Agent Validation"
+Write-Output "$((Get-Date).ToUniversalTime()) - expectedSecureChannelState=$expectedSecureChannelState"
$currentFolder = $PWD.Path
$customOutputJsonPath = $currentFolder + "\proxyagentvalidation.json";
@@ -21,7 +23,7 @@ $guestProxyAgentServiceExist = $true
$guestProxyAgentServiceStatus = ""
$guestProxyAgentProcessExist = $true
-if ($service -ne $null) {
+if ($null -ne $service) {
Write-Output "$((Get-Date).ToUniversalTime()) - The service $serviceName exists."
$guestProxyAgentServiceStatus = $service.Status
} else {
@@ -34,7 +36,7 @@ $processName = "GuestProxyAgent"
$process = Get-Process -Name $processName -ErrorAction SilentlyContinue
-if ($process -ne $null) {
+if ($null -ne $process) {
Write-Output "$((Get-Date).ToUniversalTime()) - The process $processName exists."
} else {
$guestProxyAgentProcessExist = $false
@@ -57,10 +59,52 @@ if (Test-Path -Path $folderPath -PathType Container) {
Write-Output "$((Get-Date).ToUniversalTime()) - The folder $folderPath does not exist."
}
+# check status.json file Content
+## check timestamp of last entry in status.json file
+## check the secure channel status
+$timeoutInSeconds = 300
+$statusFilePath = [IO.Path]::Combine($folderPath, "status.json")
+Write-Output "$((Get-Date).ToUniversalTime()) - Checking GPA status file $statusFilePath with a 5-minute timeout"
+$secureChannelState = ""
+$stopwatch = [System.Diagnostics.Stopwatch]::StartNew()
+$currentUtcTime = (Get-Date).ToUniversalTime()
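+# poll status.json every 3 seconds until the secure channel state matches the expected state or the timeout elapses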
+do {
+ $boolStatus = Test-Path -Path $statusFilePath
+ if ($boolStatus) {
+ $json = Get-Content $statusFilePath | Out-String | ConvertFrom-Json
+ $timestamp = $json.timestamp
+ if ($null -ne $timestamp -and $timestamp -ne "") {
+ Write-Output "$((Get-Date).ToUniversalTime()) - The status.json file contains a valid timestamp: $timestamp"
+            # parse the timestamp into a UTC DateTime object; it must be later than $currentUtcTime
+ $timestampDateTime = [DateTime]::Parse($timestamp).ToUniversalTime()
+ if ($timestampDateTime -gt $currentUtcTime) {
+ Write-Output "$((Get-Date).ToUniversalTime()) - The status.json timestamp $timestampDateTime is later than $currentUtcTime."
+ ## check secure channel status
+ $secureChannelState = $json.proxyAgentStatus.keyLatchStatus.states.secureChannelState
+ Write-Output "$((Get-Date).ToUniversalTime()) - The secure channel status is $secureChannelState."
+                if ($secureChannelState -eq $expectedSecureChannelState) {
+                    Write-Output "$((Get-Date).ToUniversalTime()) - The secure channel status '$secureChannelState' matches the expected state: '$expectedSecureChannelState'."
+                    break
+                } else {
+                    Write-Output "$((Get-Date).ToUniversalTime()) - The secure channel status '$secureChannelState' does not match the expected state: '$expectedSecureChannelState'."
+                }
+            }
+        } else {
+            Write-Output "$((Get-Date).ToUniversalTime()) - The status.json file does not contain a valid timestamp yet."
+        }
+    }
+    # check the timeout on every pass so the loop cannot wait forever on a missing or stale status file
+    if ($stopwatch.Elapsed.TotalSeconds -ge $timeoutInSeconds) {
+        Write-Output "$((Get-Date).ToUniversalTime()) - Timeout reached. Error: the secureChannelState is '$secureChannelState'."
+        break
+    }
+    Start-Sleep -Seconds 3
+} until ($false)
$jsonString = '{"guestProxyAgentServiceInstalled": ' + $guestProxyAgentServiceExist.ToString().ToLower() `
+ ', "guestProxyProcessStarted": ' + $guestProxyAgentProcessExist.ToString().ToLower() `
+ ', "guestProxyAgentServiceStatus": "' + $guestProxyAgentServiceStatus `
+ + '", "secureChannelState": "' + $secureChannelState `
+ '", "guestProxyAgentLogGenerated": ' + $guestProxyAgentLogGenerated.ToString().ToLower() + '}'
Write-Output "$((Get-Date).ToUniversalTime()) - $jsonString"
diff --git a/e2etest/GuestProxyAgentTest/Settings/TestSetting.cs b/e2etest/GuestProxyAgentTest/Settings/TestSetting.cs
index 19e2e066..8c17594a 100644
--- a/e2etest/GuestProxyAgentTest/Settings/TestSetting.cs
+++ b/e2etest/GuestProxyAgentTest/Settings/TestSetting.cs
@@ -3,10 +3,7 @@
using Azure.Core;
using GuestProxyAgentTest.Models;
using GuestProxyAgentTest.Utilities;
-using Newtonsoft.Json;
using System.Reflection;
-using System.Runtime.Serialization;
-using System.Text.Json.Serialization;
namespace GuestProxyAgentTest.Settings
{
diff --git a/e2etest/GuestProxyAgentTest/TestCases/EnableProxyAgentCase.cs b/e2etest/GuestProxyAgentTest/TestCases/EnableProxyAgentCase.cs
index 87fb512e..1373465e 100644
--- a/e2etest/GuestProxyAgentTest/TestCases/EnableProxyAgentCase.cs
+++ b/e2etest/GuestProxyAgentTest/TestCases/EnableProxyAgentCase.cs
@@ -1,25 +1,33 @@
-using Azure.ResourceManager.Compute.Models;
+// Copyright (c) Microsoft Corporation
+// SPDX-License-Identifier: MIT
+//
+using Azure.ResourceManager.Compute.Models;
using GuestProxyAgentTest.Models;
using GuestProxyAgentTest.Settings;
using GuestProxyAgentTest.TestScenarios;
+using GuestProxyAgentTest.Utilities;
using Newtonsoft.Json;
+using System.Xml.Linq;
namespace GuestProxyAgentTest.TestCases
{
internal class EnableProxyAgentCase : TestCaseBase
{
- public EnableProxyAgentCase() : this("EnableProxyAgentCase", true)
+ public EnableProxyAgentCase() : this("EnableProxyAgentCase", true, false)
{ }
- public EnableProxyAgentCase(string testCaseName) : this(testCaseName, true)
+ public EnableProxyAgentCase(string testCaseName) : this(testCaseName, true, false)
{ }
- public EnableProxyAgentCase(string testCaseName, bool enableProxyAgent) : base(testCaseName)
+ public EnableProxyAgentCase(string testCaseName, bool enableProxyAgent, bool addProxyAgentExtensionForLinuxVM) : base(testCaseName)
{
EnableProxyAgent = enableProxyAgent;
+ AddProxyAgentVMExtension = addProxyAgentExtensionForLinuxVM;
}
internal bool EnableProxyAgent { get; set; }
+ internal bool AddProxyAgentVMExtension { get; set; } = false;
+
public override async Task StartAsync(TestCaseExecutionContext context)
{
var vmr = context.VirtualMachineResource;
@@ -34,16 +42,23 @@ public override async Task StartAsync(TestCaseExecutionContext context)
}
}
};
+ // Only Linux VMs support flag 'AddProxyAgentExtension',
+ // Windows VMs always have the GPA VM Extension installed when ProxyAgentSettings.Enabled is true.
+ if (!Constants.IS_WINDOWS())
+ {
+ patch.SecurityProfile.ProxyAgentSettings.AddProxyAgentExtension = AddProxyAgentVMExtension;
+ }
if (EnableProxyAgent)
{
+ // property 'inVMAccessControlProfileReferenceId' cannot be used together with property 'mode'
patch.SecurityProfile.ProxyAgentSettings.WireServer = new HostEndpointSettings
{
- InVmAccessControlProfileReferenceId = TestSetting.Instance.InVmWireServerAccessControlProfileReferenceId
+ InVmAccessControlProfileReferenceId = TestSetting.Instance.InVmWireServerAccessControlProfileReferenceId,
};
patch.SecurityProfile.ProxyAgentSettings.Imds = new HostEndpointSettings
{
- InVmAccessControlProfileReferenceId = TestSetting.Instance.InVmIMDSAccessControlProfileReferenceId
+ InVmAccessControlProfileReferenceId = TestSetting.Instance.InVmIMDSAccessControlProfileReferenceId,
};
}
diff --git a/e2etest/GuestProxyAgentTest/TestCases/GuestProxyAgentValidationCase.cs b/e2etest/GuestProxyAgentTest/TestCases/GuestProxyAgentValidationCase.cs
index 40e02aac..acd55770 100644
--- a/e2etest/GuestProxyAgentTest/TestCases/GuestProxyAgentValidationCase.cs
+++ b/e2etest/GuestProxyAgentTest/TestCases/GuestProxyAgentValidationCase.cs
@@ -14,6 +14,7 @@ public class GuestProxyAgentValidationCase : TestCaseBase
{
private static readonly string EXPECTED_GUEST_PROXY_AGENT_SERVICE_STATUS;
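+        // defaults to "disabled", the expected state before EnableProxyAgentCase turns the secure channel on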
+ private string expectedSecureChannelState = "disabled";
static GuestProxyAgentValidationCase()
{
if (Constants.IS_WINDOWS())
@@ -28,9 +29,16 @@ static GuestProxyAgentValidationCase()
public GuestProxyAgentValidationCase() : base("GuestProxyAgentValidationCase")
{ }
+ public GuestProxyAgentValidationCase(string testCaseName, string expectedSecureChannelState) : base(testCaseName)
+ {
+ this.expectedSecureChannelState = expectedSecureChannelState;
+ }
+
public override async Task StartAsync(TestCaseExecutionContext context)
{
- context.TestResultDetails = (await RunScriptViaRunCommandV2Async(context, Constants.GUEST_PROXY_AGENT_VALIDATION_SCRIPT_NAME, null!)).ToTestResultDetails(ConsoleLog);
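+            // pass the expected secure channel state to the validation script as a named run-command parameter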
+ List<(string, string)> parameterList = new List<(string, string)>();
+ parameterList.Add(("expectedSecureChannelState", expectedSecureChannelState));
+ context.TestResultDetails = (await RunScriptViaRunCommandV2Async(context, Constants.GUEST_PROXY_AGENT_VALIDATION_SCRIPT_NAME, parameterList)).ToTestResultDetails(ConsoleLog);
if (context.TestResultDetails.Succeed && context.TestResultDetails.CustomOut != null)
{
var validationDetails = context.TestResultDetails.SafeDeserializedCustomOutAs<GuestProxyAgentValidationDetails>();
@@ -40,7 +48,8 @@ public override async Task StartAsync(TestCaseExecutionContext context)
&& validationDetails.GuestProxyAgentServiceInstalled
&& validationDetails.GuestProxyAgentServiceStatus.Equals(EXPECTED_GUEST_PROXY_AGENT_SERVICE_STATUS, StringComparison.OrdinalIgnoreCase)
&& validationDetails.GuestProxyProcessStarted
- && validationDetails.GuestProxyAgentLogGenerated)
+ && validationDetails.GuestProxyAgentLogGenerated
+ && validationDetails.SecureChannelState.Equals(expectedSecureChannelState, StringComparison.OrdinalIgnoreCase))
{
context.TestResultDetails.Succeed = true;
}
@@ -58,5 +67,6 @@ class GuestProxyAgentValidationDetails
public bool GuestProxyProcessStarted { get; set; }
public bool GuestProxyAgentLogGenerated { get; set; }
public string GuestProxyAgentServiceStatus { get; set; } = null!;
+ public string SecureChannelState { get; set; } = null!;
}
}
diff --git a/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Arm64-TestGroup.yml b/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Arm64-TestGroup.yml
index d5dfdccb..c68fd4df 100644
--- a/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Arm64-TestGroup.yml
+++ b/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Arm64-TestGroup.yml
@@ -9,4 +9,6 @@ scenarios:
- name: LinuxPackageScenario
className: GuestProxyAgentTest.TestScenarios.LinuxPackageScenario
- name: ProxyAgentExtension
- className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension
\ No newline at end of file
+ className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension
+ - name: LinuxImplicitExtension
+ className: GuestProxyAgentTest.TestScenarios.LinuxImplicitExtension
\ No newline at end of file
diff --git a/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Fips-TestGroup.yml b/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Fips-TestGroup.yml
index ad34139f..68dfb487 100644
--- a/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Fips-TestGroup.yml
+++ b/e2etest/GuestProxyAgentTest/TestMap/AzureLinux3-Fips-TestGroup.yml
@@ -9,4 +9,6 @@ scenarios:
- name: LinuxPackageScenario
className: GuestProxyAgentTest.TestScenarios.LinuxPackageScenario
- name: ProxyAgentExtension
- className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension
\ No newline at end of file
+ className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension
+ - name: LinuxImplicitExtension
+ className: GuestProxyAgentTest.TestScenarios.LinuxImplicitExtension
\ No newline at end of file
diff --git a/e2etest/GuestProxyAgentTest/TestMap/Test-Map-Linux.yml b/e2etest/GuestProxyAgentTest/TestMap/Test-Map-Linux.yml
index 0a0d1e6b..3cd46b95 100644
--- a/e2etest/GuestProxyAgentTest/TestMap/Test-Map-Linux.yml
+++ b/e2etest/GuestProxyAgentTest/TestMap/Test-Map-Linux.yml
@@ -1,6 +1,5 @@
testGroupList:
- include: AzureLinux3-Fips-TestGroup.yml
- - include: Mariner2-Fips-TestGroup.yml
- include: Ubuntu24-TestGroup.yml
- include: Ubuntu22-TestGroup.yml
- include: Ubuntu20-TestGroup.yml
diff --git a/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-Arm64-TestGroup.yml b/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-Arm64-TestGroup.yml
index e10e26b1..5664d9f6 100644
--- a/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-Arm64-TestGroup.yml
+++ b/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-Arm64-TestGroup.yml
@@ -9,4 +9,6 @@ scenarios:
- name: LinuxPackageScenario
className: GuestProxyAgentTest.TestScenarios.LinuxPackageScenario
- name: ProxyAgentExtension
- className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension
\ No newline at end of file
+ className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension
+ - name: LinuxImplicitExtension
+ className: GuestProxyAgentTest.TestScenarios.LinuxImplicitExtension
\ No newline at end of file
diff --git a/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-TestGroup.yml b/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-TestGroup.yml
index 24c28233..bda47658 100644
--- a/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-TestGroup.yml
+++ b/e2etest/GuestProxyAgentTest/TestMap/Ubuntu24-TestGroup.yml
@@ -9,4 +9,6 @@ scenarios:
- name: LinuxPackageScenario
className: GuestProxyAgentTest.TestScenarios.LinuxPackageScenario
- name: ProxyAgentExtension
- className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension
\ No newline at end of file
+ className: GuestProxyAgentTest.TestScenarios.ProxyAgentExtension
+ - name: LinuxImplicitExtension
+ className: GuestProxyAgentTest.TestScenarios.LinuxImplicitExtension
\ No newline at end of file
diff --git a/e2etest/GuestProxyAgentTest/TestScenarios/BVTScenario.cs b/e2etest/GuestProxyAgentTest/TestScenarios/BVTScenario.cs
index 4226d896..3bcc2c9b 100644
--- a/e2etest/GuestProxyAgentTest/TestScenarios/BVTScenario.cs
+++ b/e2etest/GuestProxyAgentTest/TestScenarios/BVTScenario.cs
@@ -27,6 +27,7 @@ public override void TestScenarioSetup()
// it will add GPA VM Extension and overwrite the private GPA package
AddTestCase(new EnableProxyAgentCase());
secureChannelEnabled = true;
+ AddTestCase(new GuestProxyAgentValidationCase("GuestProxyAgentValidationWithSecureChannelEnabled", "WireServer Enforce - IMDS Enforce - HostGA Enforce"));
}
AddTestCase(new IMDSPingTestCase("IMDSPingTestBeforeReboot", secureChannelEnabled));
diff --git a/e2etest/GuestProxyAgentTest/TestScenarios/LinuxImplicitExtension.cs b/e2etest/GuestProxyAgentTest/TestScenarios/LinuxImplicitExtension.cs
new file mode 100644
index 00000000..c535c097
--- /dev/null
+++ b/e2etest/GuestProxyAgentTest/TestScenarios/LinuxImplicitExtension.cs
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation
+// SPDX-License-Identifier: MIT
+using GuestProxyAgentTest.TestCases;
+using GuestProxyAgentTest.Utilities;
+
+namespace GuestProxyAgentTest.TestScenarios
+{
+ public class LinuxImplicitExtension : TestScenarioBase
+ {
+ public override void TestScenarioSetup()
+ {
+ if (Constants.IS_WINDOWS())
+ {
+ throw new InvalidOperationException("LinuxImplicitExtension scenario can only run on Linux VMs.");
+ }
+
+            // Passing in version "0" for the first validation case
+ string proxyAgentVersionBeforeUpdate = "0";
+ string proxyAgentVersion = Settings.TestSetting.Instance.proxyAgentVersion;
+ ConsoleLog(string.Format("Received ProxyAgent Version:{0}", proxyAgentVersion));
+ // implicitly enable the Guest Proxy Agent extension by setting EnableProxyAgent to true and AddProxyAgentVMExtension to true
+ AddTestCase(new EnableProxyAgentCase("EnableProxyAgentCase", true, true));
+ AddTestCase(new GuestProxyAgentExtensionValidationCase("GuestProxyAgentExtensionValidationCaseBeforeUpdate", proxyAgentVersionBeforeUpdate));
+ AddTestCase(new InstallOrUpdateGuestProxyAgentExtensionCase());
+ AddTestCase(new GuestProxyAgentExtensionValidationCase("GuestProxyAgentExtensionValidationCaseAfterUpdate", proxyAgentVersion));
+ AddTestCase(new IMDSPingTestCase("IMDSPingTestBeforeReboot", true));
+ AddTestCase(new RebootVMCase("RebootVMCaseAfterUpdateGuestProxyAgentExtension"));
+ AddTestCase(new IMDSPingTestCase("IMDSPingTestAfterReboot", true));
+ }
+ }
+}
diff --git a/e2etest/GuestProxyAgentTest/TestScenarios/LinuxPackageScenario.cs b/e2etest/GuestProxyAgentTest/TestScenarios/LinuxPackageScenario.cs
index 82661abe..15ba7bc8 100644
--- a/e2etest/GuestProxyAgentTest/TestScenarios/LinuxPackageScenario.cs
+++ b/e2etest/GuestProxyAgentTest/TestScenarios/LinuxPackageScenario.cs
@@ -13,6 +13,7 @@ public override void TestScenarioSetup()
AddTestCase(new InstallOrUpdateGuestProxyAgentPackageCase());
AddTestCase(new GuestProxyAgentValidationCase());
AddTestCase(new EnableProxyAgentCase());
+ AddTestCase(new GuestProxyAgentValidationCase("GuestProxyAgentValidationWithSecureChannelEnabled", "WireServer Enforce - IMDS Enforce - HostGA Enforce"));
AddTestCase(new IMDSPingTestCase("IMDSPingTestBeforeReboot", true));
AddTestCase(new RebootVMCase("RebootVMCaseAfterInstallOrUpdateGuestProxyAgent"));
AddTestCase(new IMDSPingTestCase("IMDSPingTestAfterReboot", true));
diff --git a/e2etest/GuestProxyAgentTest/TestScenarios/ProxyAgentExtension.cs b/e2etest/GuestProxyAgentTest/TestScenarios/ProxyAgentExtension.cs
index 5be330e7..7b5ae1c2 100644
--- a/e2etest/GuestProxyAgentTest/TestScenarios/ProxyAgentExtension.cs
+++ b/e2etest/GuestProxyAgentTest/TestScenarios/ProxyAgentExtension.cs
@@ -2,8 +2,6 @@
// SPDX-License-Identifier: MIT
using GuestProxyAgentTest.TestCases;
using GuestProxyAgentTest.Utilities;
-using System.Diagnostics;
-using System.IO.Compression;
namespace GuestProxyAgentTest.TestScenarios
{
diff --git a/e2etest/GuestProxyAgentTest/Utilities/VMBuilder.cs b/e2etest/GuestProxyAgentTest/Utilities/VMBuilder.cs
index 7dc8b460..f96d2518 100644
--- a/e2etest/GuestProxyAgentTest/Utilities/VMBuilder.cs
+++ b/e2etest/GuestProxyAgentTest/Utilities/VMBuilder.cs
@@ -140,6 +140,12 @@ private async Task DoCreateVMData(ResourceGroupResource rgr,
},
}
};
+ if (!Constants.IS_WINDOWS())
+ {
+ // Only Linux VMs support flag 'AddProxyAgentExtension',
+ // Windows VMs always have the GPA VM Extension installed when ProxyAgentSettings.Enabled is true.
+ vmData.SecurityProfile.ProxyAgentSettings.AddProxyAgentExtension = true;
+ }
}
if (Constants.IS_WINDOWS())
diff --git a/proxy_agent/Cargo.toml b/proxy_agent/Cargo.toml
index d128f2e0..8f279dda 100644
--- a/proxy_agent/Cargo.toml
+++ b/proxy_agent/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "azure-proxy-agent"
-version = "1.0.38" # always 3-number version
+version = "1.0.39" # always 3-number version
edition = "2021"
build = "build.rs"
readme = "README.md"
diff --git a/proxy_agent/src/common/config.rs b/proxy_agent/src/common/config.rs
index d9d1fdab..a9dae135 100644
--- a/proxy_agent/src/common/config.rs
+++ b/proxy_agent/src/common/config.rs
@@ -49,10 +49,6 @@ pub fn get_monitor_duration() -> Duration {
pub fn get_poll_key_status_duration() -> Duration {
Duration::from_secs(SYSTEM_CONFIG.get_poll_key_status_interval())
}
-//TODO: remove this config/function once the contract is defined for HostGAPlugin
-pub fn get_host_gaplugin_support() -> u8 {
- SYSTEM_CONFIG.hostGAPluginSupport
-}
pub fn get_max_event_file_count() -> usize {
SYSTEM_CONFIG.get_max_event_file_count()
@@ -90,7 +86,6 @@ pub struct Config {
latchKeyFolder: String,
monitorIntervalInSeconds: u64,
pollKeyStatusIntervalInSeconds: u64,
- hostGAPluginSupport: u8, // 0 not support; 1 proxy only; 2 proxy + authentication check
#[serde(skip_serializing_if = "Option::is_none")]
maxEventFileCount: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
@@ -170,10 +165,6 @@ impl Config {
self.pollKeyStatusIntervalInSeconds
}
- pub fn get_host_gaplugin_support(&self) -> u8 {
- self.hostGAPluginSupport
- }
-
pub fn get_max_event_file_count(&self) -> usize {
self.maxEventFileCount
.unwrap_or(constants::DEFAULT_MAX_EVENT_FILE_COUNT)
@@ -269,12 +260,6 @@ mod tests {
"get_poll_key_status_interval mismatch"
);
- assert_eq!(
- 1u8,
- config.get_host_gaplugin_support(),
- "get_host_gaplugin_support mismatch"
- );
-
assert_eq!(
constants::DEFAULT_MAX_EVENT_FILE_COUNT,
config.get_max_event_file_count(),
diff --git a/proxy_agent/src/common/constants.rs b/proxy_agent/src/common/constants.rs
index 22ae07d7..095aaf8f 100644
--- a/proxy_agent/src/common/constants.rs
+++ b/proxy_agent/src/common/constants.rs
@@ -17,8 +17,6 @@ pub const GA_PLUGIN_IP_NETWORK_BYTE_ORDER: u32 = 0x10813FA8; // 168.63.129.16
pub const IMDS_IP_NETWORK_BYTE_ORDER: u32 = 0xFEA9FEA9; //"169.254.169.254";
pub const PROXY_AGENT_IP_NETWORK_BYTE_ORDER: u32 = 0x100007F; //"127.0.0.1";
-pub const EMPTY_GUID: &str = "00000000-0000-0000-0000-000000000000";
-
pub const KEY_DELIVERY_METHOD_HTTP: &str = "http";
pub const KEY_DELIVERY_METHOD_VTPM: &str = "vtpm";
diff --git a/proxy_agent/src/common/helpers.rs b/proxy_agent/src/common/helpers.rs
index 0ee66384..518555cd 100644
--- a/proxy_agent/src/common/helpers.rs
+++ b/proxy_agent/src/common/helpers.rs
@@ -1,76 +1,9 @@
// Copyright (c) Microsoft Corporation
// SPDX-License-Identifier: MIT
-use super::logger;
+
use once_cell::sync::Lazy;
-use proxy_agent_shared::misc_helpers;
use proxy_agent_shared::telemetry::span::SimpleSpan;
-#[cfg(not(windows))]
-use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind, System};
-
-#[cfg(windows)]
-use super::windows;
-
-static CURRENT_SYS_INFO: Lazy<(u64, usize)> = Lazy::new(|| {
- #[cfg(windows)]
- {
- let ram_in_mb = match windows::get_memory_in_mb() {
- Ok(ram) => ram,
- Err(e) => {
- logger::write_error(format!("get_memory_in_mb failed: {e}"));
- 0
- }
- };
- let cpu_count = windows::get_processor_count();
- (ram_in_mb, cpu_count)
- }
- #[cfg(not(windows))]
- {
- let sys = System::new_with_specifics(
- RefreshKind::new()
- .with_memory(MemoryRefreshKind::everything())
- .with_cpu(CpuRefreshKind::everything()),
- );
- let ram = sys.total_memory();
- let ram_in_mb = ram / 1024 / 1024;
- let cpu_count = sys.cpus().len();
- (ram_in_mb, cpu_count)
- }
-});
-
-static CURRENT_OS_INFO: Lazy<(String, String)> = Lazy::new(|| {
- //arch
- let arch = misc_helpers::get_processor_arch();
- // os
- let os = misc_helpers::get_long_os_version();
- (arch, os)
-});
-
-pub fn get_ram_in_mb() -> u64 {
- CURRENT_SYS_INFO.0
-}
-
-pub fn get_cpu_count() -> usize {
- CURRENT_SYS_INFO.1
-}
-
-pub fn get_cpu_arch() -> String {
- CURRENT_OS_INFO.0.to_string()
-}
-
-pub fn get_long_os_version() -> String {
- CURRENT_OS_INFO.1.to_string()
-}
-
-// replace xml escape characters
-pub fn xml_escape(s: String) -> String {
-    s.replace('&', "&amp;")
-        .replace('\'', "&apos;")
-        .replace('"', "&quot;")
-        .replace('<', "&lt;")
-        .replace('>', "&gt;")
-}
-
static START: Lazy<SimpleSpan> = Lazy::new(SimpleSpan::new);
pub fn get_elapsed_time_in_millisec() -> u128 {
@@ -85,22 +18,6 @@ pub fn write_startup_event(
) -> String {
let message = START.write_event(task, method_name, module_name, logger_key);
#[cfg(not(windows))]
- logger::write_serial_console_log(message.clone());
+ crate::common::logger::write_serial_console_log(message.clone());
message
}
-
-#[cfg(test)]
-mod tests {
- #[test]
- fn get_system_info_tests() {
- let ram = super::get_ram_in_mb();
- assert!(ram > 100, "total ram must great than 100MB");
- let cpu_count = super::get_cpu_count();
- assert!(
- cpu_count >= 1,
- "total cpu count must great than or equal to 1"
- );
- let cpu_arch = super::get_cpu_arch();
- assert_ne!("unknown", cpu_arch, "cpu arch cannot be 'unknown'");
- }
-}
diff --git a/proxy_agent/src/key_keeper.rs b/proxy_agent/src/key_keeper.rs
index 1f444c19..fb527e9b 100644
--- a/proxy_agent/src/key_keeper.rs
+++ b/proxy_agent/src/key_keeper.rs
@@ -29,16 +29,19 @@ use self::key::Key;
use crate::common::error::{Error, KeyErrorType};
use crate::common::result::Result;
use crate::common::{constants, helpers, logger};
+use crate::key_keeper::key::KeyStatus;
use crate::provision;
use crate::proxy::authorization_rules::{AuthorizationRulesForLogging, ComputedAuthorizationRules};
+use crate::shared_state::access_control_wrapper::AccessControlSharedState;
use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState};
+use crate::shared_state::connection_summary_wrapper::ConnectionSummarySharedState;
use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState;
use crate::shared_state::provision_wrapper::ProvisionSharedState;
use crate::shared_state::redirector_wrapper::RedirectorSharedState;
-use crate::shared_state::telemetry_wrapper::TelemetrySharedState;
-use crate::shared_state::SharedState;
+use crate::shared_state::{EventThreadsSharedState, SharedState};
use crate::{acl, redirector};
use hyper::Uri;
+use proxy_agent_shared::common_state::CommonState;
use proxy_agent_shared::logger::LoggerLevel;
use proxy_agent_shared::misc_helpers;
use proxy_agent_shared::proxy_agent_aggregate_status::ModuleState;
@@ -56,7 +59,7 @@ pub const MUST_SIG_WIRESERVER_IMDS: &str = "wireserverandimds";
pub const UNKNOWN_STATE: &str = "Unknown";
static FREQUENT_PULL_INTERVAL: Duration = Duration::from_secs(1); // 1 second
const FREQUENT_PULL_TIMEOUT_IN_MILLISECONDS: u128 = 300000; // 5 minutes
-const PROVISION_TIMEUP_IN_MILLISECONDS: u128 = 120000; // 2 minute
+const PROVISION_TIMEOUT_IN_MILLISECONDS: u128 = 120000; // 2 minutes
const DELAY_START_EVENT_THREADS_IN_MILLISECONDS: u128 = 60000; // 1 minute
#[derive(Clone)]
@@ -73,14 +76,27 @@ pub struct KeyKeeper {
cancellation_token: CancellationToken,
/// key_keeper_shared_state: the sender for the key details, secure channel state, access control rule
key_keeper_shared_state: KeyKeeperSharedState,
- /// telemetry_shared_state: the sender for the telemetry events
- telemetry_shared_state: TelemetrySharedState,
+ /// common_state: the sender for the common states
+ common_state: CommonState,
/// redirector_shared_state: the sender for the redirector/eBPF module
redirector_shared_state: RedirectorSharedState,
/// provision_shared_state: the sender for the provision state
provision_shared_state: ProvisionSharedState,
/// agent_status_shared_state: the sender for the agent status
agent_status_shared_state: AgentStatusSharedState,
+ /// access_control_shared_state: the sender for the access control rules
+ access_control_shared_state: AccessControlSharedState,
+ /// connection_summary_shared_state: the sender for the connection summary module
+ connection_summary_shared_state: ConnectionSummarySharedState,
+}
+
+/// Reason for waking up from sleep
+/// Used in the poll_secure_channel_status loop
+/// Notified: woke up due to a notification
+/// TimerElapsed: woke up because the sleep timer elapsed
+enum WakeReason {
+ Notified,
+ TimerElapsed,
}
impl KeyKeeper {
@@ -98,10 +114,12 @@ impl KeyKeeper {
interval,
cancellation_token: shared_state.get_cancellation_token(),
key_keeper_shared_state: shared_state.get_key_keeper_shared_state(),
- telemetry_shared_state: shared_state.get_telemetry_shared_state(),
+ common_state: shared_state.get_common_state(),
redirector_shared_state: shared_state.get_redirector_shared_state(),
provision_shared_state: shared_state.get_provision_shared_state(),
agent_status_shared_state: shared_state.get_agent_status_shared_state(),
+ access_control_shared_state: shared_state.get_access_control_shared_state(),
+ connection_summary_shared_state: shared_state.get_connection_summary_shared_state(),
}
}
@@ -171,7 +189,7 @@ impl KeyKeeper {
async fn loop_poll(&self) {
let mut first_iteration: bool = true;
let mut started_event_threads: bool = false;
- let mut provision_timeup: bool = false;
+ let mut provision_timeout: bool = false;
let notify = match self.key_keeper_shared_state.get_notify().await {
Ok(notify) => notify,
Err(e) => {
@@ -191,11 +209,10 @@ impl KeyKeeper {
));
}
- let mut start = Instant::now();
+ let mut provision_start_time = Instant::now();
+ let mut redirect_policy_updated = false;
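+        // tracks whether the redirect policy has been applied to the redirector/eBPF module; retried below if it was not ready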
loop {
if !first_iteration {
- // skip the sleep for the first loop
-
let current_state = match self
.key_keeper_shared_state
.get_current_secure_channel_state()
@@ -210,376 +227,569 @@ impl KeyKeeper {
}
};
- let sleep = if current_state == UNKNOWN_STATE
- && helpers::get_elapsed_time_in_millisec()
- < FREQUENT_PULL_TIMEOUT_IN_MILLISECONDS
- {
- // frequent poll the secure channel status every second for the first 5 minutes
- // until the secure channel state is known
- FREQUENT_PULL_INTERVAL
- } else {
- self.interval
- };
+                let sleep_interval = self.calculate_sleep_duration(&current_state);
+                let (continue_loop, reset_timer) = self
+                    .handle_notification(&notify, sleep_interval, &current_state)
+                    .await;
- let time = Instant::now();
- tokio::select! {
- // notify to query the secure channel status immediately when the secure channel state is unknown or disabled
- // this is to handle quicker response to the secure channel state change during VM provisioning.
- _ = notify.notified() => {
- if current_state == DISABLE_STATE || current_state == UNKNOWN_STATE {
- logger::write_warning(format!("poll_secure_channel_status task notified and secure channel state is '{current_state}', reset states and start poll status now."));
- provision::key_latch_ready_state_reset(self.provision_shared_state.clone()).await;
- if let Err(e) = self.key_keeper_shared_state.update_current_secure_channel_state(UNKNOWN_STATE.to_string()).await{
- logger::write_warning(format!("Failed to update secure channel state to 'Unknown': {e}"));
- }
-
- if start.elapsed().as_millis() > PROVISION_TIMEUP_IN_MILLISECONDS {
- // already timeup, reset the start timer
- start = Instant::now();
- provision_timeup = false;
- }
- } else {
- // report key latched ready to try update the provision finished time_tick
- provision::key_latched(
- self.cancellation_token.clone(),
- self.key_keeper_shared_state.clone(),
- self.telemetry_shared_state.clone(),
- self.provision_shared_state.clone(),
- self.agent_status_shared_state.clone(),
- ).await;
- let slept_time_in_millisec = time.elapsed().as_millis();
- let continue_sleep = sleep.as_millis() - slept_time_in_millisec;
- if continue_sleep > 0 {
- let continue_sleep = Duration::from_millis(continue_sleep as u64);
- let message = format!("poll_secure_channel_status task notified but secure channel state is '{current_state}', continue with sleep wait for {continue_sleep:?}.");
- logger::write_warning(message);
- tokio::time::sleep(continue_sleep).await;
- }
- }
- },
- _ = tokio::time::sleep(sleep) => {}
+ if reset_timer {
+ provision_start_time = Instant::now();
+ provision_timeout = false;
}
- }
- first_iteration = false;
- if !provision_timeup && start.elapsed().as_millis() > PROVISION_TIMEUP_IN_MILLISECONDS {
- provision::provision_timeup(
- None,
- self.provision_shared_state.clone(),
- self.agent_status_shared_state.clone(),
- )
- .await;
- provision_timeup = true;
+ if !continue_loop {
+ continue;
+ }
}
+ first_iteration = false;
- if !started_event_threads
- && helpers::get_elapsed_time_in_millisec()
- > DELAY_START_EVENT_THREADS_IN_MILLISECONDS
- {
- provision::start_event_threads(
- self.cancellation_token.clone(),
- self.key_keeper_shared_state.clone(),
- self.telemetry_shared_state.clone(),
- self.provision_shared_state.clone(),
- self.agent_status_shared_state.clone(),
- )
+ self.handle_provision_timeout(&mut provision_start_time, &mut provision_timeout)
.await;
- started_event_threads = true;
- }
+ started_event_threads = self.handle_event_threads_start(started_event_threads).await;
let status = match key::get_status(&self.base_url).await {
Ok(s) => s,
Err(e) => {
self.update_status_message(format!("Failed to get key status - {e}"), true)
.await;
+ // failed to get status, skip to next iteration
continue;
}
};
self.update_status_message(format!("Got key status successfully: {status}."), true)
.await;
- let mut access_control_rules_changed = false;
- let wireserver_rule_id = status.get_wireserver_rule_id();
- let imds_rule_id: String = status.get_imds_rule_id();
- let hostga_rule_id: String = status.get_hostga_rule_id();
+ self.update_access_control_rules(&status).await;
- match self
+ let state = status.get_secure_channel_state();
+ let secure_channel_state_updated = self
.key_keeper_shared_state
- .update_wireserver_rule_id(wireserver_rule_id.to_string())
+ .update_current_secure_channel_state(state.to_string())
+ .await;
+
+ if !self.handle_key_acquisition(&status, &state).await {
+                // key acquisition failed, skip to the next iteration
+ continue;
+ }
+
+ if self
+ .handle_secure_channel_state_change(
+ secure_channel_state_updated,
+ &state,
+ &status,
+ &mut redirect_policy_updated,
+ )
.await
{
- Ok((updated, old_wire_server_rule_id)) => {
- if updated {
- logger::write_warning(format!(
- "Wireserver rule id changed from '{old_wire_server_rule_id}' to '{wireserver_rule_id}'."
- ));
- if let Err(e) = self
- .key_keeper_shared_state
- .set_wireserver_rules(status.get_wireserver_rules())
- .await
- {
- logger::write_error(format!("Failed to set wireserver rules: {e}"));
- }
- access_control_rules_changed = true;
+ // successfully handled secure channel state change, skip to next iteration
+ continue;
+ }
+
+            // if the redirect policy was not updated successfully before, try again here;
+            // this can happen when the eBPF/redirector module had not started yet
+ if !redirect_policy_updated {
+ logger::write_warning(
+                "redirect policy was not updated successfully before, retrying now".to_string(),
+ );
+ redirect_policy_updated = self.update_redirector_policy(&status).await;
+ }
+ }
+ }
+
+ /// Calculate sleep duration based on current secure channel state
+ fn calculate_sleep_duration(&self, current_state: &str) -> Duration {
+ if current_state == UNKNOWN_STATE
+ && helpers::get_elapsed_time_in_millisec() < FREQUENT_PULL_TIMEOUT_IN_MILLISECONDS
+ {
+            // poll the secure channel status frequently (every second) for the first 5 minutes,
+            // until the secure channel state is known
+ FREQUENT_PULL_INTERVAL
+ } else {
+ self.interval
+ }
+ }
+
+ async fn wait_for_wake(notify: &tokio::sync::Notify, sleep_interval: Duration) -> WakeReason {
+ tokio::select! {
+ // if a notification arrives
+ _ = notify.notified() => WakeReason::Notified,
+ // if the sleep duration elapses
+ _ = tokio::time::sleep(sleep_interval) => WakeReason::TimerElapsed,
+ }
+ }
+
+ /// Handle notification and sleep logic
+ /// Returns (continue_loop, reset_timer)
+ async fn handle_notification(
+ &self,
+ notify: &tokio::sync::Notify,
+ sleep_interval: Duration,
+ current_state: &str,
+ ) -> (bool, bool) {
+ let start_time = Instant::now();
+
+ match Self::wait_for_wake(notify, sleep_interval).await {
+ WakeReason::Notified => {
+ self.handle_notified_state(current_state, sleep_interval, start_time)
+ .await
+ }
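+            // timer elapsed normally: continue the loop without resetting the provision timer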
+ WakeReason::TimerElapsed => (true, false),
+ }
+ }
+
+ /// Handle the notified state - either reset or continue with key latched
+ /// Returns (continue_loop, reset_timer)
+ async fn handle_notified_state(
+ &self,
+ current_state: &str,
+ sleep_interval: Duration,
+ current_loop_iteration_start_time: Instant,
+ ) -> (bool, bool) {
+        // when notified, query the secure channel status immediately if the secure channel state is unknown or disabled;
+        // this allows a quicker response to secure channel state changes during VM provisioning.
+ if self.should_reset_state(current_state) {
+ self.reset_state_on_notification(current_state).await
+ } else {
+ self.continue_with_key_latched(
+ current_state,
+ sleep_interval,
+ current_loop_iteration_start_time,
+ )
+ .await
+ }
+ }
+
+ /// Check if the state should be reset based on current secure channel state
+ fn should_reset_state(&self, current_state: &str) -> bool {
+ current_state == DISABLE_STATE || current_state == UNKNOWN_STATE
+ }
+
+ /// Reset state when notified in disabled or unknown state
+ /// Returns (continue_loop, reset_timer)
+ async fn reset_state_on_notification(&self, current_state: &str) -> (bool, bool) {
+ logger::write_warning(format!(
+ "poll_secure_channel_status task notified and secure channel state is '{current_state}', reset states and start poll status now."
+ ));
+
+ provision::key_latch_ready_state_reset(self.provision_shared_state.clone()).await;
+
+ if let Err(e) = self
+ .key_keeper_shared_state
+ .update_current_secure_channel_state(UNKNOWN_STATE.to_string())
+ .await
+ {
+ logger::write_warning(format!(
+ "Failed to update secure channel state to 'Unknown': {e}"
+ ));
+ }
+
+ (true, true)
+ }
+
+ /// Continue with key latched when notified in a stable state
+ /// Returns (continue_loop, reset_timer)
+ async fn continue_with_key_latched(
+ &self,
+ current_state: &str,
+ sleep_interval: Duration,
+ current_loop_iteration_start_time: Instant,
+ ) -> (bool, bool) {
+ // report key latched ready to try update the provision finished time_tick
+ provision::key_latched(self.create_event_threads_shared_state()).await;
+
+ self.handle_remaining_sleep(
+ current_state,
+ sleep_interval,
+ current_loop_iteration_start_time,
+ )
+ .await;
+
+ (true, false)
+ }
+
+ /// Handle the remaining sleep time after the notification
+ async fn handle_remaining_sleep(
+ &self,
+ current_state: &str,
+ sleep_interval: Duration,
+ current_loop_iteration_start_time: Instant,
+ ) {
+ let slept_time_in_millisec = current_loop_iteration_start_time.elapsed().as_millis();
+ let continue_sleep = sleep_interval
+ .as_millis()
+ .saturating_sub(slept_time_in_millisec);
+ if continue_sleep > 0 {
+            // continue sleeping for the remaining time of the current loop iteration;
+            // this avoids polling too frequently while the secure channel state is stable
+ let continue_sleep = Duration::from_millis(continue_sleep as u64);
+ let message = format!(
+ "poll_secure_channel_status task notified but secure channel state is '{current_state}', continue with sleep wait for {continue_sleep:?}."
+ );
+ logger::write_warning(message);
+ tokio::time::sleep(continue_sleep).await;
+ }
+ }
+
+ /// Handle provision timeout logic
+ async fn handle_provision_timeout(
+ &self,
+ provision_start_time: &mut Instant,
+ provision_timeout: &mut bool,
+ ) {
+ if !*provision_timeout
+ && provision_start_time.elapsed().as_millis() > PROVISION_TIMEOUT_IN_MILLISECONDS
+ {
+ provision::provision_timeout(
+ None,
+ self.provision_shared_state.clone(),
+ self.agent_status_shared_state.clone(),
+ )
+ .await;
+ *provision_timeout = true;
+ }
+ }
+
+ /// Handle starting event threads
+ async fn handle_event_threads_start(&self, started_event_threads: bool) -> bool {
+ if !started_event_threads
+ && helpers::get_elapsed_time_in_millisec() > DELAY_START_EVENT_THREADS_IN_MILLISECONDS
+ {
+ provision::start_event_threads(self.create_event_threads_shared_state()).await;
+ return true;
+ }
+ started_event_threads
+ }
+
+ /// Update access control rules from the key status
+ /// Returns true if any rules changed
+ async fn update_access_control_rules(&self, status: &KeyStatus) -> bool {
+ let mut access_control_rules_changed = false;
+ let wireserver_rule_id = status.get_wireserver_rule_id();
+ let imds_rule_id = status.get_imds_rule_id();
+ let hostga_rule_id = status.get_hostga_rule_id();
+
+ // Update wireserver rules
+ match self
+ .key_keeper_shared_state
+ .update_wireserver_rule_id(wireserver_rule_id.to_string())
+ .await
+ {
+ Ok((updated, old_wire_server_rule_id)) => {
+ if updated {
+ logger::write_warning(format!(
+ "Wireserver rule id changed from '{old_wire_server_rule_id}' to '{wireserver_rule_id}'."
+ ));
+ if let Err(e) = self
+ .access_control_shared_state
+ .set_wireserver_rules(status.get_wireserver_rules())
+ .await
+ {
+ logger::write_error(format!("Failed to set wireserver rules: {e}"));
}
- }
- Err(e) => {
- logger::write_warning(format!("Failed to update wireserver rule id: {e}"));
+ access_control_rules_changed = true;
}
}
+ Err(e) => {
+ logger::write_warning(format!("Failed to update wireserver rule id: {e}"));
+ }
+ }
- match self
- .key_keeper_shared_state
- .update_imds_rule_id(imds_rule_id.to_string())
- .await
- {
- Ok((updated, old_imds_rule_id)) => {
- if updated {
- logger::write_warning(format!(
- "IMDS rule id changed from '{old_imds_rule_id}' to '{imds_rule_id}'."
- ));
- if let Err(e) = self
- .key_keeper_shared_state
- .set_imds_rules(status.get_imds_rules())
- .await
- {
- logger::write_error(format!("Failed to set imds rules: {e}"));
- }
- access_control_rules_changed = true;
+ // Update IMDS rules
+ match self
+ .key_keeper_shared_state
+ .update_imds_rule_id(imds_rule_id.to_string())
+ .await
+ {
+ Ok((updated, old_imds_rule_id)) => {
+ if updated {
+ logger::write_warning(format!(
+ "IMDS rule id changed from '{old_imds_rule_id}' to '{imds_rule_id}'."
+ ));
+ if let Err(e) = self
+ .access_control_shared_state
+ .set_imds_rules(status.get_imds_rules())
+ .await
+ {
+ logger::write_error(format!("Failed to set imds rules: {e}"));
}
- }
- Err(e) => {
- logger::write_warning(format!("Failed to update imds rule id: {e}"));
+ access_control_rules_changed = true;
}
}
+ Err(e) => {
+ logger::write_warning(format!("Failed to update imds rule id: {e}"));
+ }
+ }
- match self
- .key_keeper_shared_state
- .update_hostga_rule_id(hostga_rule_id.to_string())
- .await
- {
- Ok((updated, old_hostga_rule_id)) => {
- if updated {
- logger::write_warning(format!(
- "HostGA rule id changed from '{old_hostga_rule_id}' to '{hostga_rule_id}'."
- ));
- if let Err(e) = self
- .key_keeper_shared_state
- .set_hostga_rules(status.get_hostga_rules())
- .await
- {
- logger::write_error(format!("Failed to set HostGA rules: {e}"));
- }
- access_control_rules_changed = true;
+ // Update HostGA rules
+ match self
+ .key_keeper_shared_state
+ .update_hostga_rule_id(hostga_rule_id.to_string())
+ .await
+ {
+ Ok((updated, old_hostga_rule_id)) => {
+ if updated {
+ logger::write_warning(format!(
+ "HostGA rule id changed from '{old_hostga_rule_id}' to '{hostga_rule_id}'."
+ ));
+ if let Err(e) = self
+ .access_control_shared_state
+ .set_hostga_rules(status.get_hostga_rules())
+ .await
+ {
+ logger::write_error(format!("Failed to set HostGA rules: {e}"));
}
+ access_control_rules_changed = true;
}
- Err(e) => {
- logger::write_warning(format!("Failed to update HostGA rule id: {e}"));
- }
}
+ Err(e) => {
+ logger::write_warning(format!("Failed to update HostGA rule id: {e}"));
+ }
+ }
- if access_control_rules_changed {
- if let (Ok(wireserver_rules), Ok(imds_rules), Ok(hostga_rules)) = (
- self.key_keeper_shared_state.get_wireserver_rules().await,
- self.key_keeper_shared_state.get_imds_rules().await,
- self.key_keeper_shared_state.get_hostga_rules().await,
- ) {
- let rules = AuthorizationRulesForLogging::new(
- status.authorizationRules.clone(),
- ComputedAuthorizationRules {
- wireserver: wireserver_rules,
- imds: imds_rules,
- hostga: hostga_rules,
- },
+ // Write authorization rules to file if changed
+ if access_control_rules_changed {
+ if let (Ok(wireserver_rules), Ok(imds_rules), Ok(hostga_rules)) = (
+ self.access_control_shared_state
+ .get_wireserver_rules()
+ .await,
+ self.access_control_shared_state.get_imds_rules().await,
+ self.access_control_shared_state.get_hostga_rules().await,
+ ) {
+ let rules = AuthorizationRulesForLogging::new(
+ status.authorizationRules.clone(),
+ ComputedAuthorizationRules {
+ wireserver: wireserver_rules,
+ imds: imds_rules,
+ hostga: hostga_rules,
+ },
+ );
+ rules.write_all(&self.status_dir, constants::MAX_LOG_FILE_COUNT);
+ }
+ }
+
+ access_control_rules_changed
+ }
+
+ /// Handle key acquisition from local or server
+ /// Returns true if successful, false if should continue to next iteration
+ async fn handle_key_acquisition(&self, status: &KeyStatus, state: &str) -> bool {
+ // check if need fetch the key
+ if state != DISABLE_STATE
+            && (status.keyGuid.is_none() // key has not been latched yet
+ || status.keyGuid != self.key_keeper_shared_state.get_current_key_guid().await.unwrap_or(None))
+ // key changed
+ {
+ if self.try_fetch_local_key(&status.keyGuid).await {
+ return true;
+ }
+
+            // if the key has not been latched before, was not found, or could not be read locally,
+            // try to fetch it from the server
+ return self.acquire_key_from_server().await;
+ }
+ true
+ }
+
+ /// Try to fetch key from local storage
+ /// Returns true if key was found and loaded successfully
+    async fn try_fetch_local_key(&self, key_guid: &Option<String>) -> bool {
+ if let Some(guid) = key_guid {
+            // the key was latched before, so search for it locally first
+ match Self::fetch_key(&self.key_dir, guid) {
+ Ok(key) => {
+ if let Err(e) = self.update_key_to_shared_state(key.clone()).await {
+ logger::write_warning(format!("Failed to update key: {e}"));
+ }
+
+ let message = helpers::write_startup_event(
+ "Found key details from local and ready to use.",
+ "poll_secure_channel_status",
+ "key_keeper",
+ logger::AGENT_LOGGER_KEY,
+ );
+ self.update_status_message(message, false).await;
+
+ provision::key_latched(self.create_event_threads_shared_state()).await;
+ return true;
+ }
+ Err(e) => {
+ event_logger::write_event(
+ LoggerLevel::Info,
+                        format!("Failed to fetch local key details with error: {e:?}. Will try to acquire the key details from the server."),
+ "poll_secure_channel_status",
+ "key_keeper",
+ logger::AGENT_LOGGER_KEY,
);
- rules.write_all(&self.status_dir, constants::MAX_LOG_FILE_COUNT);
}
+ };
+ }
+ false
+ }
+
+ /// Acquire key from server, persist it, and attest it
+ /// Returns true if successful, false if should continue to next iteration
+ async fn acquire_key_from_server(&self) -> bool {
+ let key = match key::acquire_key(&self.base_url).await {
+ Ok(k) => k,
+ Err(e) => {
+ self.update_status_message(format!("Failed to acquire key details: {e:?}"), true)
+ .await;
+ return false;
}
+ };
- let state = status.get_secure_channel_state();
- let secure_channel_state_updated = self
- .key_keeper_shared_state
- .update_current_secure_channel_state(state.to_string())
+ // persist the new key to local disk
+ let guid = key.guid.to_string();
+ if let Err(e) = Self::store_key(&self.key_dir, &key) {
+ self.update_status_message(format!("Failed to save key details to file: {e:?}"), true)
.await;
+ return false;
+ }
+ logger::write_information(format!(
+ "Successfully acquired the key '{guid}' details from server and saved locally."
+ ));
- // check if need fetch the key
- if state != DISABLE_STATE
- && (status.keyGuid.is_none() // key has not latched yet
- || status.keyGuid != self.key_keeper_shared_state.get_current_key_guid().await.unwrap_or(None))
- // key changed
- {
- let mut key_found = false;
- if let Some(guid) = &status.keyGuid {
- // key latched before and search the key locally first
- match Self::fetch_key(&self.key_dir, guid) {
- Ok(key) => {
- if let Err(e) =
- self.key_keeper_shared_state.update_key(key.clone()).await
- {
- logger::write_warning(format!("Failed to update key: {e}"));
- }
-
- let message = helpers::write_startup_event(
- "Found key details from local and ready to use.",
- "poll_secure_channel_status",
- "key_keeper",
- logger::AGENT_LOGGER_KEY,
- );
- self.update_status_message(message, false).await;
- key_found = true;
-
- provision::key_latched(
- self.cancellation_token.clone(),
- self.key_keeper_shared_state.clone(),
- self.telemetry_shared_state.clone(),
- self.provision_shared_state.clone(),
- self.agent_status_shared_state.clone(),
- )
- .await;
- }
- Err(e) => {
- event_logger::write_event(
- LoggerLevel::Info,
- format!("Failed to fetch local key details with error: {e:?}. Will try acquire the key details from Server."),
- "poll_secure_channel_status",
- "key_keeper",
- logger::AGENT_LOGGER_KEY,
- );
- }
- };
+        // double-check that the key details were saved correctly to local disk
+ if let Err(e) = Self::check_key(&self.key_dir, &key) {
+ self.update_status_message(
+ format!("Failed to check the key '{guid}' details saved locally: {e:?}."),
+ true,
+ )
+ .await;
+ return false;
+ }
+
+ // attest the key
+ match key::attest_key(&self.base_url, &key).await {
+ Ok(()) => {
+ // update in memory
+ if let Err(e) = self.update_key_to_shared_state(key.clone()).await {
+ logger::write_warning(format!("Failed to update key: {e}"));
}
- // if key has not latched before,
- // or not found
- // or could not read locally,
- // try fetch from server
- if !key_found {
- let key = match key::acquire_key(&self.base_url).await {
- Ok(k) => k,
- Err(e) => {
- self.update_status_message(
- format!("Failed to acquire key details: {e:?}"),
- true,
- )
- .await;
- continue;
- }
- };
-
- // persist the new key to local disk
- let guid = key.guid.to_string();
- match Self::store_key(&self.key_dir, &key) {
- Ok(()) => {
- logger::write_information(format!(
- "Successfully acquired the key '{guid}' details from server and saved locally."));
- }
- Err(e) => {
- self.update_status_message(
- format!("Failed to save key details to file: {e:?}"),
- true,
- )
- .await;
- continue;
- }
- }
+ let message = helpers::write_startup_event(
+ "Successfully attest the key and ready to use.",
+ "poll_secure_channel_status",
+ "key_keeper",
+ logger::AGENT_LOGGER_KEY,
+ );
+ self.update_status_message(message, false).await;
+ provision::key_latched(self.create_event_threads_shared_state()).await;
+ true
+ }
+ Err(e) => {
+ logger::write_warning(format!("Failed to attest the key: {e:?}"));
+ false
+ }
+ }
+ }
- // double check the key details saved correctly to local disk
- if let Err(e) = Self::check_key(&self.key_dir, &key) {
- self.update_status_message(
- format!(
- "Failed to check the key '{guid}' details saved locally: {e:?}."
- ),
- true,
- )
- .await;
- continue;
- } else {
- match key::attest_key(&self.base_url, &key).await {
- Ok(()) => {
- // update in memory
- if let Err(e) =
- self.key_keeper_shared_state.update_key(key.clone()).await
- {
- logger::write_warning(format!("Failed to update key: {e}"));
- }
-
- let message = helpers::write_startup_event(
- "Successfully attest the key and ready to use.",
- "poll_secure_channel_status",
- "key_keeper",
- logger::AGENT_LOGGER_KEY,
- );
- self.update_status_message(message, false).await;
- provision::key_latched(
- self.cancellation_token.clone(),
- self.key_keeper_shared_state.clone(),
- self.telemetry_shared_state.clone(),
- self.provision_shared_state.clone(),
- self.agent_status_shared_state.clone(),
- )
- .await;
- }
- Err(e) => {
- logger::write_warning(format!("Failed to attest the key: {e:?}"));
- continue;
- }
+ /// Handle secure channel state change
+    /// Returns true if the caller should continue to the next iteration
+ async fn handle_secure_channel_state_change(
+ &self,
+        secure_channel_state_updated: std::result::Result<bool, Error>,
+ state: &str,
+ status: &KeyStatus,
+ redirect_policy_updated: &mut bool,
+ ) -> bool {
+ match secure_channel_state_updated {
+ Ok(updated) => {
+ if updated {
+ // secure channel state changed, update the redirect policy
+ *redirect_policy_updated = self.update_redirector_policy(status).await;
+
+                    // customer has not enforced the secure channel state
+ if state == DISABLE_STATE {
+ let message = helpers::write_startup_event(
+ "Customer has not enforce the secure channel state.",
+ "poll_secure_channel_status",
+ "key_keeper",
+ logger::AGENT_LOGGER_KEY,
+ );
+                    // Update the status message and let the provision continue
+ self.update_status_message(message, false).await;
+
+ // clear key in memory for disabled state
+ if let Err(e) = self.key_keeper_shared_state.clear_key().await {
+ logger::write_warning(format!("Failed to clear key: {e}"));
}
+ provision::key_latched(self.create_event_threads_shared_state()).await;
}
+
+ return true;
}
}
+ Err(e) => {
+ logger::write_warning(format!("Failed to update secure channel state: {e}"));
+ }
+ }
+ false
+ }
- // update redirect policy if current secure channel state updated
- match secure_channel_state_updated {
- Ok(updated) => {
- if updated {
- // update the redirector policy map
- redirector::update_wire_server_redirect_policy(
- status.get_wire_server_mode() != DISABLE_STATE,
- self.redirector_shared_state.clone(),
- )
- .await;
- redirector::update_imds_redirect_policy(
- status.get_imds_mode() != DISABLE_STATE,
- self.redirector_shared_state.clone(),
- )
- .await;
- redirector::update_hostga_redirect_policy(
- status.get_hostga_mode() != DISABLE_STATE,
- self.redirector_shared_state.clone(),
- )
- .await;
+ /// Create EventThreadsSharedState from current state
+ fn create_event_threads_shared_state(&self) -> EventThreadsSharedState {
+ EventThreadsSharedState {
+ cancellation_token: self.cancellation_token.clone(),
+ common_state: self.common_state.clone(),
+ access_control_shared_state: self.access_control_shared_state.clone(),
+ redirector_shared_state: self.redirector_shared_state.clone(),
+ key_keeper_shared_state: self.key_keeper_shared_state.clone(),
+ provision_shared_state: self.provision_shared_state.clone(),
+ agent_status_shared_state: self.agent_status_shared_state.clone(),
+ connection_summary_shared_state: self.connection_summary_shared_state.clone(),
+ }
+ }
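
// Reviewer sketch (not part of this change): from the construction above and the
// wrapper types used in the tests below, the EventThreadsSharedState bundle is
// assumed to be declared roughly as follows in shared_state.rs — every field is a
// cheaply cloneable handle, so the whole bundle derives Clone and can be passed
// to provision::* by value:
#[derive(Clone)]
pub struct EventThreadsSharedState {
    pub cancellation_token: CancellationToken,
    pub common_state: CommonState,
    pub access_control_shared_state: AccessControlSharedState,
    pub redirector_shared_state: RedirectorSharedState,
    pub key_keeper_shared_state: KeyKeeperSharedState,
    pub provision_shared_state: ProvisionSharedState,
    pub agent_status_shared_state: AgentStatusSharedState,
    pub connection_summary_shared_state: ConnectionSummarySharedState,
}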
- // customer has not enforce the secure channel state
- if state == DISABLE_STATE {
- let message = helpers::write_startup_event(
- "Customer has not enforce the secure channel state.",
- "poll_secure_channel_status",
- "key_keeper",
- logger::AGENT_LOGGER_KEY,
- );
- // Update the status message and let the provision to continue
- self.update_status_message(message, false).await;
-
- // clear key in memory for disabled state
- if let Err(e) = self.key_keeper_shared_state.clear_key().await {
- logger::write_warning(format!("Failed to clear key: {e}"));
- }
- provision::key_latched(
- self.cancellation_token.clone(),
- self.key_keeper_shared_state.clone(),
- self.telemetry_shared_state.clone(),
- self.provision_shared_state.clone(),
- self.agent_status_shared_state.clone(),
- )
- .await;
- }
- }
- }
- Err(e) => {
- logger::write_warning(format!("Failed to update secure channel state: {e}"));
- }
- }
+ async fn update_key_to_shared_state(&self, key: Key) -> Result<()> {
+ self.key_keeper_shared_state.update_key(key.clone()).await?;
+
+ // update the current key guid and value to common states
+ self.common_state
+ .set_state(
+ proxy_agent_shared::common_state::SECURE_KEY_GUID.to_string(),
+ key.guid.to_string(),
+ )
+ .await?;
+ self.common_state
+ .set_state(
+ proxy_agent_shared::common_state::SECURE_KEY_VALUE.to_string(),
+ key.key.to_string(),
+ )
+ .await?;
+ Ok(())
+ }
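
// Reviewer sketch (not part of this change): a consumer could read the latched
// key back out of the common state; `get_state` is an assumed counterpart to the
// set_state calls above, returning Result<Option<String>, _>:
async fn read_secure_key(common_state: &CommonState) -> Option<(String, String)> {
    let guid = common_state
        .get_state(proxy_agent_shared::common_state::SECURE_KEY_GUID.to_string())
        .await
        .ok()??; // Err or a missing entry both yield None
    let value = common_state
        .get_state(proxy_agent_shared::common_state::SECURE_KEY_VALUE.to_string())
        .await
        .ok()??;
    Some((guid, value))
}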
+
+    /// Update the redirector/eBPF policy based on the secure channel status.
+    /// It should be called when the secure channel state changes.
+ async fn update_redirector_policy(&self, status: &KeyStatus) -> bool {
+ // update the redirector policy map
+ if !redirector::update_wire_server_redirect_policy(
+ status.get_wire_server_mode() != DISABLE_STATE,
+ self.redirector_shared_state.clone(),
+ )
+ .await
+ {
+ return false;
+ }
+ if !redirector::update_imds_redirect_policy(
+ status.get_imds_mode() != DISABLE_STATE,
+ self.redirector_shared_state.clone(),
+ )
+ .await
+ {
+ return false;
}
+ if !redirector::update_hostga_redirect_policy(
+ status.get_hostga_mode() != DISABLE_STATE,
+ self.redirector_shared_state.clone(),
+ )
+ .await
+ {
+ return false;
+ }
+
+ true
}
async fn update_status_message(&self, message: String, log_to_file: bool) {
@@ -849,10 +1059,12 @@ mod tests {
interval: Duration::from_millis(10),
cancellation_token: cancellation_token.clone(),
key_keeper_shared_state: key_keeper::KeyKeeperSharedState::start_new(),
- telemetry_shared_state: key_keeper::TelemetrySharedState::start_new(),
+ common_state: key_keeper::CommonState::start_new(),
redirector_shared_state: key_keeper::RedirectorSharedState::start_new(),
provision_shared_state: key_keeper::ProvisionSharedState::start_new(),
agent_status_shared_state: key_keeper::AgentStatusSharedState::start_new(),
+ access_control_shared_state: key_keeper::AccessControlSharedState::start_new(),
+ connection_summary_shared_state: key_keeper::ConnectionSummarySharedState::start_new(),
};
tokio::spawn({
diff --git a/proxy_agent/src/key_keeper/key.rs b/proxy_agent/src/key_keeper/key.rs
index 7039b0e6..667b0932 100644
--- a/proxy_agent/src/key_keeper/key.rs
+++ b/proxy_agent/src/key_keeper/key.rs
@@ -580,12 +580,14 @@ impl KeyStatus {
}
    pub fn get_hostga_rules(&self) -> Option<AccessControlRules> {
- // short-term: HostGA has no rules
+ // match &self.authorizationRules {
+ // Some(rules) => rules.hostga.clone(),
+ // None => None,
+ // }
+
+ // short-term: HostGA uses wireserver rules
// long-term: TBD
- match &self.authorizationRules {
- Some(rules) => rules.hostga.clone(),
- None => None,
- }
+ self.get_wireserver_rules()
}
pub fn get_wire_server_mode(&self) -> String {
@@ -1261,95 +1263,105 @@ mod tests {
"roleAssignment identities mismatch"
);
- // Validate HostGA rules
+ // Validate HostGA rules to match WireServer rules
let hostga_rules = status.get_hostga_rules().unwrap();
assert_eq!(
- "allow", hostga_rules.defaultAccess,
- "defaultAccess mismatch"
+ wireserver_rules.defaultAccess, hostga_rules.defaultAccess,
+ "hostga_rules defaultAccess mismatch"
);
assert_eq!(
- "sigid",
+ status.get_wireserver_rule_id(),
status.get_hostga_rule_id(),
"HostGA rule id mismatch"
);
- assert_eq!("enforce", status.get_hostga_mode(), "HostGA mode mismatch");
-
- // Validate HostGA rule details
- // Retrieve and validate second privilege for HostGA
- let privilege = &hostga_rules
- .rules
- .as_ref()
- .unwrap()
- .privileges
- .as_ref()
- .unwrap()[1];
-
- assert_eq!("test2", privilege.name, "privilege name mismatch");
- assert_eq!("/test2", privilege.path, "privilege path mismatch");
-
- assert_eq!(
- "value3",
- privilege.queryParameters.as_ref().unwrap()["key1"],
- "privilege queryParameters mismatch"
- );
- assert_eq!(
- "value4",
- privilege.queryParameters.as_ref().unwrap()["key2"],
- "privilege queryParameters mismatch"
- );
-
- // Retrieve and validate second role for HostGA
- let role = &hostga_rules.rules.as_ref().unwrap().roles.as_ref().unwrap()[1];
- assert_eq!("test6", role.name, "role name mismatch");
- assert_eq!("test4", role.privileges[0], "role privilege mismatch");
- assert_eq!("test5", role.privileges[1], "role privilege mismatch");
-
- // Retrieve and validate first identity for HostGA
- let identity = &hostga_rules
- .rules
- .as_ref()
- .unwrap()
- .identities
- .as_ref()
- .unwrap()[0];
- assert_eq!("test", identity.name, "identity name mismatch");
- assert_eq!(
- "test",
- identity.userName.as_ref().unwrap(),
- "identity userName mismatch"
- );
- assert_eq!(
- "test",
- identity.groupName.as_ref().unwrap(),
- "identity groupName mismatch"
- );
- assert_eq!(
- "test",
- identity.exePath.as_ref().unwrap(),
- "identity exePath mismatch"
- );
- assert_eq!(
- "test",
- identity.processName.as_ref().unwrap(),
- "identity processName mismatch"
- );
-
- // Retrieve and validate first role assignment for HostGA
- let role_assignment = &hostga_rules
- .rules
- .as_ref()
- .unwrap()
- .roleAssignments
- .as_ref()
- .unwrap()[0];
- assert_eq!(
- "test4", role_assignment.role,
- "roleAssignment role mismatch"
- );
assert_eq!(
- "test", role_assignment.identities[0],
- "roleAssignment identities mismatch"
- );
+ status.get_wire_server_mode(),
+ status.get_hostga_mode(),
+ "HostGA mode mismatch"
+ );
+
+ /***
+ *
+ * Validate HostGA rule details in future when HostGA has independent rules
+
+ // Validate HostGA rule details
+ // Retrieve and validate second privilege for HostGA
+ let privilege = &hostga_rules
+ .rules
+ .as_ref()
+ .unwrap()
+ .privileges
+ .as_ref()
+ .unwrap()[1];
+
+ assert_eq!("test2", privilege.name, "privilege name mismatch");
+ assert_eq!("/test2", privilege.path, "privilege path mismatch");
+
+ assert_eq!(
+ "value3",
+ privilege.queryParameters.as_ref().unwrap()["key1"],
+ "privilege queryParameters mismatch"
+ );
+ assert_eq!(
+ "value4",
+ privilege.queryParameters.as_ref().unwrap()["key2"],
+ "privilege queryParameters mismatch"
+ );
+
+ // Retrieve and validate second role for HostGA
+ let role = &hostga_rules.rules.as_ref().unwrap().roles.as_ref().unwrap()[1];
+ assert_eq!("test6", role.name, "role name mismatch");
+ assert_eq!("test4", role.privileges[0], "role privilege mismatch");
+ assert_eq!("test5", role.privileges[1], "role privilege mismatch");
+
+ // Retrieve and validate first identity for HostGA
+ let identity = &hostga_rules
+ .rules
+ .as_ref()
+ .unwrap()
+ .identities
+ .as_ref()
+ .unwrap()[0];
+ assert_eq!("test", identity.name, "identity name mismatch");
+ assert_eq!(
+ "test",
+ identity.userName.as_ref().unwrap(),
+ "identity userName mismatch"
+ );
+ assert_eq!(
+ "test",
+ identity.groupName.as_ref().unwrap(),
+ "identity groupName mismatch"
+ );
+ assert_eq!(
+ "test",
+ identity.exePath.as_ref().unwrap(),
+ "identity exePath mismatch"
+ );
+ assert_eq!(
+ "test",
+ identity.processName.as_ref().unwrap(),
+ "identity processName mismatch"
+ );
+
+ // Retrieve and validate first role assignment for HostGA
+ let role_assignment = &hostga_rules
+ .rules
+ .as_ref()
+ .unwrap()
+ .roleAssignments
+ .as_ref()
+ .unwrap()[0];
+ assert_eq!(
+ "test4", role_assignment.role,
+ "roleAssignment role mismatch"
+ );
+ assert_eq!(
+ "test", role_assignment.identities[0],
+ "roleAssignment identities mismatch"
+ );
+ *
+ */
}
#[test]
diff --git a/proxy_agent/src/main.rs b/proxy_agent/src/main.rs
index 460425ca..7a73720c 100644
--- a/proxy_agent/src/main.rs
+++ b/proxy_agent/src/main.rs
@@ -10,7 +10,6 @@ pub mod proxy_agent_status;
pub mod redirector;
pub mod service;
pub mod shared_state;
-pub mod telemetry;
use common::cli::{Commands, CLI};
use common::constants;
diff --git a/proxy_agent/src/provision.rs b/proxy_agent/src/provision.rs
index 4c4774c5..fd11fbea 100644
--- a/proxy_agent/src/provision.rs
+++ b/proxy_agent/src/provision.rs
@@ -5,94 +5,23 @@
//! It is used to track the provision state for each module and write the provision state to provisioned.tag and status.tag files.
//! It also provides the http handler to query the provision status for GPA service.
//! It is used to query the provision status from GPA service http listener.
-//! Example for GPA service:
-//! ```rust
-//! use proxy_agent::provision;
-//! use proxy_agent::shared_state::agent_status_wrapper::AgentStatusModule;
-//! use proxy_agent::shared_state::agent_status_wrapper::AgentStatusSharedState;
-//! use proxy_agent::shared_state::key_keeper_wrapper::KeyKeeperSharedState;
-//! use proxy_agent::shared_state::provision_wrapper::ProvisionSharedState;
-//! use proxy_agent::shared_state::SharedState;
-//! use proxy_agent::shared_state::telemetry_wrapper::TelemetrySharedState;
-//!
-//! use std::time::Duration;
-//!
-//! let shared_state = SharedState::start_all();
-//! let cancellation_token = shared_state.get_cancellation_token();
-//! let key_keeper_shared_state = shared_state.get_key_keeper_shared_state();
-//! let telemetry_shared_state = shared_state.get_telemetry_shared_state();
-//! let provision_shared_state = shared_state.get_provision_shared_state();
-//! let agent_status_shared_state = shared_state.get_agent_status_shared_state();
-//!
-//! let provision_state = provision::get_provision_state(
-//! provision_shared_state.clone(),
-//! agent_status_shared_state.clone(),
-//! ).await;
-//! assert_eq!(false, provision_state.finished);
-//! assert_eq!(0, provision_state.errorMessage.len());
-//!
-//! // update provision state when each provision finished
-//! provision::redirector_ready(
-//! cancellation_token.clone(),
-//! key_keeper_shared_state.clone(),
-//! telemetry_shared_state.clone(),
-//! provision_shared_state.clone(),
-//! agent_status_shared_state.clone(),
-//! ).await;
-//! provision::key_latched(
-//! cancellation_token.clone(),
-//! key_keeper_shared_state.clone(),
-//! telemetry_shared_state.clone(),
-//! provision_shared_state.clone(),
-//! agent_status_shared_state.clone(),
-//! ).await;
-//! provision::listener_started(
-//! cancellation_token.clone(),
-//! key_keeper_shared_state.clone(),
-//! telemetry_shared_state.clone(),
-//! provision_shared_state.clone(),
-//! agent_status_shared_state.clone(),
-//! ).await;
-//!
-//! let provision_state = provision::get_provision_state(
-//! provision_shared_state.clone(),
-//! agent_status_shared_state.clone(),
-//! ).await;
-//! assert_eq!(true, provision_state.finished);
-//! assert_eq!(0, provision_state.errorMessage.len());
-//! ```
-//!
-//! Example for GPA command line option --status [--wait seconds]:
-//! ```rust
-//! use proxy_agent::provision::ProvisionQuery;
-//! use std::time::Duration;
-//!
-//! let proxy_server_port = 8092;
-//! let provision_query = ProvisionQuery::new(proxy_server_port, None);
-//! let provision_not_finished_state = provision_query.get_provision_status_wait().await;
-//! assert_eq!(false, provision_state.0);
-//! assert_eq!(0, provision_state.1.len());
-//!
-//! let provision_query = ProvisionQuery::new(proxy_server_port, Some(Duration::from_millis(5)));
-//! let provision_finished_state = provision_query.get_provision_status_wait().await;
-//! assert_eq!(true, provision_state.0);
-//! assert_eq!(0, provision_state.1.len());
-//! ```
-
-use crate::common::{config, helpers, logger};
+
+use crate::common::{config, logger};
use crate::key_keeper::{DISABLE_STATE, UNKNOWN_STATE};
-use crate::proxy_agent_status;
+use crate::proxy::authorization_rules::AuthorizationMode;
+use crate::shared_state::access_control_wrapper::AccessControlSharedState;
use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState};
use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState;
use crate::shared_state::provision_wrapper::ProvisionSharedState;
-use crate::shared_state::telemetry_wrapper::TelemetrySharedState;
-use crate::telemetry::event_reader::EventReader;
+use crate::shared_state::redirector_wrapper::RedirectorSharedState;
+use crate::shared_state::EventThreadsSharedState;
+use crate::{proxy_agent_status, redirector};
use proxy_agent_shared::logger::LoggerLevel;
use proxy_agent_shared::telemetry::event_logger;
+use proxy_agent_shared::telemetry::event_reader::EventReader;
use proxy_agent_shared::{misc_helpers, proxy_agent_aggregate_status};
use std::path::PathBuf;
use std::time::Duration;
-use tokio_util::sync::CancellationToken;
const PROVISION_TAG_FILE_NAME: &str = "provisioned.tag";
const STATUS_TAG_TMP_FILE_NAME: &str = "status.tag.tmp";
@@ -153,63 +82,33 @@ impl ProvisionStateInternal {
/// Update provision state when redirector provision finished
/// It could be called by redirector module
-pub async fn redirector_ready(
- cancellation_token: CancellationToken,
- key_keeper_shared_state: KeyKeeperSharedState,
- telemetry_shared_state: TelemetrySharedState,
- provision_shared_state: ProvisionSharedState,
- agent_status_shared_state: AgentStatusSharedState,
-) {
+pub async fn redirector_ready(event_threads_shared_state: EventThreadsSharedState) {
update_provision_state(
ProvisionFlags::REDIRECTOR_READY,
None,
- cancellation_token,
- key_keeper_shared_state,
- telemetry_shared_state,
- provision_shared_state,
- agent_status_shared_state,
+ event_threads_shared_state,
)
.await;
}
/// Update provision state when key latch provision finished
/// It could be called by key latch module
-pub async fn key_latched(
- cancellation_token: CancellationToken,
- key_keeper_shared_state: KeyKeeperSharedState,
- telemetry_shared_state: TelemetrySharedState,
- provision_shared_state: ProvisionSharedState,
- agent_status_shared_state: AgentStatusSharedState,
-) {
+pub async fn key_latched(event_threads_shared_state: EventThreadsSharedState) {
update_provision_state(
ProvisionFlags::KEY_LATCH_READY,
None,
- cancellation_token,
- key_keeper_shared_state,
- telemetry_shared_state,
- provision_shared_state,
- agent_status_shared_state,
+ event_threads_shared_state,
)
.await;
}
/// Update provision state when listener provision finished
/// It could be called by listener module
-pub async fn listener_started(
- cancellation_token: CancellationToken,
- key_keeper_shared_state: KeyKeeperSharedState,
- telemetry_shared_state: TelemetrySharedState,
- provision_shared_state: ProvisionSharedState,
- agent_status_shared_state: AgentStatusSharedState,
-) {
+pub async fn listener_started(event_threads_shared_state: EventThreadsSharedState) {
update_provision_state(
ProvisionFlags::LISTENER_READY,
None,
- cancellation_token,
- key_keeper_shared_state,
- telemetry_shared_state,
- provision_shared_state,
- agent_status_shared_state,
+ event_threads_shared_state,
)
.await;
}
@@ -218,15 +117,28 @@ pub async fn listener_started(
async fn update_provision_state(
state: ProvisionFlags,
    provision_dir: Option<PathBuf>,
- cancellation_token: CancellationToken,
- key_keeper_shared_state: KeyKeeperSharedState,
- telemetry_shared_state: TelemetrySharedState,
- provision_shared_state: ProvisionSharedState,
- agent_status_shared_state: AgentStatusSharedState,
+ event_threads_shared_state: EventThreadsSharedState,
) {
- if let Ok(provision_state) = provision_shared_state.update_one_state(state).await {
+ if let Ok(provision_state) = event_threads_shared_state
+ .provision_shared_state
+ .update_one_state(state)
+ .await
+ {
if provision_state.contains(ProvisionFlags::ALL_READY) {
- if let Err(e) = provision_shared_state.set_provision_finished(true).await {
+ // update redirector/eBPF policy based on access control status
+ update_redirector_policy(
+ event_threads_shared_state.redirector_shared_state.clone(),
+ event_threads_shared_state
+ .access_control_shared_state
+ .clone(),
+ )
+ .await;
+
+ if let Err(e) = event_threads_shared_state
+ .provision_shared_state
+ .set_provision_finished(true)
+ .await
+ {
// log the error and continue
logger::write_error(format!(
"update_provision_state::Failed to set provision finished with error: {e}"
@@ -236,24 +148,62 @@ async fn update_provision_state(
// write provision success state here
write_provision_state(
provision_dir,
- provision_shared_state.clone(),
- agent_status_shared_state.clone(),
+ event_threads_shared_state.provision_shared_state.clone(),
+ event_threads_shared_state.agent_status_shared_state.clone(),
)
.await;
// start event threads right after provision successfully
- start_event_threads(
- cancellation_token,
- key_keeper_shared_state,
- telemetry_shared_state,
- provision_shared_state,
- agent_status_shared_state,
- )
- .await;
+ start_event_threads(event_threads_shared_state).await;
}
}
}
+/// Update the redirector/eBPF policy based on the access control status.
+/// It should be called when provisioning finishes.
+async fn update_redirector_policy(
+ redirector_shared_state: RedirectorSharedState,
+ access_control_shared_state: AccessControlSharedState,
+) {
+ let wireserver_mode =
+ if let Ok(Some(rules)) = access_control_shared_state.get_wireserver_rules().await {
+ rules.mode
+ } else {
+ // default to disabled if the rules are not ready
+ AuthorizationMode::Disabled
+ };
+ redirector::update_wire_server_redirect_policy(
+ wireserver_mode != AuthorizationMode::Disabled,
+ redirector_shared_state.clone(),
+ )
+ .await;
+
+ let imds_mode = if let Ok(Some(rules)) = access_control_shared_state.get_imds_rules().await {
+ rules.mode
+ } else {
+ // default to disabled if the rules are not ready
+ AuthorizationMode::Disabled
+ };
+ redirector::update_imds_redirect_policy(
+ imds_mode != AuthorizationMode::Disabled,
+ redirector_shared_state.clone(),
+ )
+ .await;
+
+ let ga_plugin_mode =
+ if let Ok(Some(rules)) = access_control_shared_state.get_hostga_rules().await {
+ rules.mode
+ } else {
+ // default to disabled if the rules are not ready
+ AuthorizationMode::Disabled
+ };
+ redirector::update_hostga_redirect_policy(
+ ga_plugin_mode != AuthorizationMode::Disabled,
+ redirector_shared_state.clone(),
+ )
+ .await;
+}
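
// Reviewer sketch (not part of this change): the three lookups above share one
// shape. Assuming each getter resolves to Result<Option<R>, E> where R carries an
// AuthorizationMode, they could be folded into a generic helper (`rules_mode` is a
// hypothetical name):
async fn rules_mode<R, E, Fut, F>(rules: Fut, mode_of: F) -> AuthorizationMode
where
    Fut: std::future::Future<Output = std::result::Result<Option<R>, E>>,
    F: FnOnce(&R) -> AuthorizationMode,
{
    match rules.await {
        Ok(Some(r)) => mode_of(&r),
        // default to disabled if the rules are not ready
        _ => AuthorizationMode::Disabled,
    }
}
// e.g.: let imds_mode = rules_mode(access_control_shared_state.get_imds_rules(), |r| r.mode.clone()).await;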
+
pub async fn key_latch_ready_state_reset(provision_shared_state: ProvisionSharedState) {
reset_provision_state(ProvisionFlags::KEY_LATCH_READY, provision_shared_state).await;
}
@@ -287,9 +237,9 @@ async fn reset_provision_state(
/// use std::sync::{Arc, Mutex};
///
/// let shared_state = Arc::new(Mutex::new(SharedState::new()));
-/// provision::provision_timeup(None, shared_state.clone());
+/// provision::provision_timeout(None, shared_state.clone());
/// ```
-pub async fn provision_timeup(
+pub async fn provision_timeout(
    provision_dir: Option<PathBuf>,
provision_shared_state: ProvisionSharedState,
agent_status_shared_state: AgentStatusSharedState,
@@ -313,17 +263,78 @@ pub async fn provision_timeup(
}
}
+/// Set resource limits for the Guest Proxy Agent service
+/// It will be called when provisioning finishes or times out;
+/// it is designed to delay setting resource limits to give more CPU time to provisioning tasks
+/// For the Linux GPA service, it sets CPUQuota for azure-proxy-agent.service to limit the CPU usage
+/// For the Windows GPA service, it sets CPU and RAM limits for the current process
+fn set_resource_limits() {
+ #[cfg(not(windows))]
+ {
+        // Set CPUQuota for azure-proxy-agent.service to 15% to limit the CPU usage of the Linux azure-proxy-agent service
+        // The Linux GPA VM Extension is not required for the Linux GPA service; it should not set resource limits in its HandlerManifest.json file
+ const SERVICE_NAME: &str = "azure-proxy-agent.service";
+ const CPU_QUOTA: u16 = 15;
+ match proxy_agent_shared::linux::set_cpu_quota(SERVICE_NAME, CPU_QUOTA) {
+ Ok(_) => {
+ logger::write_warning(format!(
+ "Successfully set {SERVICE_NAME} CPU quota to {CPU_QUOTA}%"
+ ));
+ }
+ Err(e) => {
+ logger::write_error(format!(
+ "Failed to set {SERVICE_NAME} CPU quota with error: {e}"
+ ));
+ }
+ }
+
+        // Do not set MemoryMax or MemoryHigh for azure-proxy-agent.service to limit the RAM usage.
+        // The Linux GPA service is designed to be lightweight with a minimal memory footprint (~20MB),
+        // but its provisioning process may temporarily need more memory (e.g., up to 100MB) before shrinking back to ~20MB.
+        // If we set MemoryMax to 20MB, the provisioning process may be OOM-killed unexpectedly.
+        // If we set MemoryHigh to 20MB, the provisioning process may be throttled or hang unexpectedly.
+ }
+
+ #[cfg(windows)]
+ {
+        // Set a CPU quota for the GPA service process to limit the CPU usage of the Windows GPA service.
+        // As we need to adjust the total CPU quota based on the number of CPU cores,
+        // the Windows GPA VM Extension should not set resource limits in its HandlerManifest.json file
+ let cpu_count = proxy_agent_shared::current_info::get_cpu_count();
+ let percent = if cpu_count <= 4 {
+ 50
+ } else if cpu_count <= 8 {
+ 30
+ } else if cpu_count <= 16 {
+ 20
+ } else {
+ 15
+ };
+
+ const RAM_LIMIT_IN_MB: usize = 20;
+ match proxy_agent_shared::windows::set_resource_limits(
+ std::process::id(),
+ percent,
+ RAM_LIMIT_IN_MB,
+ ) {
+ Ok(_) => {
+ logger::write_warning(format!(
+ "Successfully set current process CPU quota to {percent}% and RAM limit to {RAM_LIMIT_IN_MB}MB"
+ ));
+ }
+ Err(e) => {
+ logger::write_error(format!("Failed to set CPU and RAM quota with error: {e}"));
+ }
+ }
+ }
+}
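
// Reviewer sketch (not part of this change): proxy_agent_shared::linux::set_cpu_quota
// is assumed to apply the quota through systemd; one plausible shape uses the real
// `systemctl set-property` interface, which persists a drop-in capping the unit's
// CPU bandwidth via the cgroup cpu controller:
#[cfg(not(windows))]
fn set_cpu_quota_sketch(service_name: &str, cpu_quota_percent: u16) -> std::io::Result<()> {
    use std::process::Command;

    let status = Command::new("systemctl")
        .args([
            "set-property",
            service_name,
            &format!("CPUQuota={cpu_quota_percent}%"),
        ])
        .status()?;
    if status.success() {
        Ok(())
    } else {
        Err(std::io::Error::other(format!(
            "systemctl set-property exited with {status}"
        )))
    }
}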
+
/// Start event logger & reader tasks and status reporting task
/// It will be called when provision finished or timedout,
/// it is designed to delay start those tasks to give more cpu time to provision tasks
-pub async fn start_event_threads(
- cancellation_token: CancellationToken,
- key_keeper_shared_state: KeyKeeperSharedState,
- telemetry_shared_state: TelemetrySharedState,
- provision_shared_state: ProvisionSharedState,
- agent_status_shared_state: AgentStatusSharedState,
-) {
- if let Ok(logger_threads_initialized) = provision_shared_state
+pub async fn start_event_threads(event_threads_shared_state: EventThreadsSharedState) {
+ if let Ok(logger_threads_initialized) = event_threads_shared_state
+ .provision_shared_state
.get_event_log_threads_initialized()
.await
{
@@ -332,7 +343,12 @@ pub async fn start_event_threads(
}
}
- let cloned_agent_status_shared_state = agent_status_shared_state.clone();
+    // set resource limits before launching lower-priority tasks;
+    // those tasks start to run after provisioning finished or timed out
+ set_resource_limits();
+
+ let cloned_agent_status_shared_state =
+ event_threads_shared_state.agent_status_shared_state.clone();
tokio::spawn({
async {
event_logger::start(
@@ -355,10 +371,10 @@ pub async fn start_event_threads(
let event_reader = EventReader::new(
config::get_events_dir(),
true,
- cancellation_token.clone(),
- key_keeper_shared_state.clone(),
- telemetry_shared_state.clone(),
- agent_status_shared_state.clone(),
+ event_threads_shared_state.cancellation_token.clone(),
+ event_threads_shared_state.common_state.clone(),
+ "ProxyAgent".to_string(),
+ "MicrosoftAzureGuestProxyAgent".to_string(),
);
async move {
event_reader
@@ -366,7 +382,8 @@ pub async fn start_event_threads(
.await;
}
});
- if let Err(e) = provision_shared_state
+ if let Err(e) = event_threads_shared_state
+ .provision_shared_state
.set_event_log_threads_initialized()
.await
{
@@ -379,9 +396,12 @@ pub async fn start_event_threads(
let agent_status_task = proxy_agent_status::ProxyAgentStatusTask::new(
Duration::from_secs(60),
proxy_agent_aggregate_status::get_proxy_agent_aggregate_status_folder(),
- cancellation_token.clone(),
- key_keeper_shared_state.clone(),
- agent_status_shared_state.clone(),
+ event_threads_shared_state.cancellation_token.clone(),
+ event_threads_shared_state.key_keeper_shared_state.clone(),
+ event_threads_shared_state.agent_status_shared_state.clone(),
+ event_threads_shared_state
+ .connection_summary_shared_state
+ .clone(),
);
async move {
agent_status_task.start().await;
@@ -429,7 +449,7 @@ async fn write_provision_state(
if !failed_state_message.is_empty() {
// escape xml characters to allow the message to able be composed into xml payload
- failed_state_message = helpers::xml_escape(failed_state_message);
+ failed_state_message = misc_helpers::xml_escape(failed_state_message);
// write provision failed error message to event
event_logger::write_event(
@@ -658,6 +678,7 @@ mod tests {
use crate::provision::provision_query::ProvisionQuery;
use crate::provision::ProvisionFlags;
use crate::proxy::proxy_server;
+ use crate::shared_state::EventThreadsSharedState;
use crate::shared_state::SharedState;
use std::env;
use std::fs;
@@ -677,8 +698,10 @@ mod tests {
let cancellation_token = shared_state.get_cancellation_token();
let provision_shared_state = shared_state.get_provision_shared_state();
let key_keeper_shared_state = shared_state.get_key_keeper_shared_state();
- let telemetry_shared_state = shared_state.get_telemetry_shared_state();
let agent_status_shared_state = shared_state.get_agent_status_shared_state();
+ let event_threads_shared_state = EventThreadsSharedState::new(&shared_state);
+
+ // initialize key keeper secure channel state to UNKNOWN
let port: u16 = 8092;
let proxy_server = proxy_server::ProxyServer::new(port, &shared_state);
@@ -709,11 +732,7 @@ mod tests {
_ = super::update_provision_state(
ProvisionFlags::KEY_LATCH_READY,
Some(temp_test_path.to_path_buf()),
- cancellation_token.clone(),
- key_keeper_shared_state.clone(),
- telemetry_shared_state.clone(),
- provision_shared_state.clone(),
- agent_status_shared_state.clone(),
+ event_threads_shared_state.clone(),
)
.await;
_ = key_keeper_shared_state
@@ -738,29 +757,17 @@ mod tests {
super::update_provision_state(
ProvisionFlags::REDIRECTOR_READY,
Some(dir1),
- cancellation_token.clone(),
- key_keeper_shared_state.clone(),
- telemetry_shared_state.clone(),
- provision_shared_state.clone(),
- agent_status_shared_state.clone(),
+ event_threads_shared_state.clone(),
),
super::update_provision_state(
ProvisionFlags::KEY_LATCH_READY,
Some(dir2),
- cancellation_token.clone(),
- key_keeper_shared_state.clone(),
- telemetry_shared_state.clone(),
- provision_shared_state.clone(),
- agent_status_shared_state.clone(),
+ event_threads_shared_state.clone(),
),
super::update_provision_state(
ProvisionFlags::LISTENER_READY,
Some(dir3),
- cancellation_token.clone(),
- key_keeper_shared_state.clone(),
- telemetry_shared_state.clone(),
- provision_shared_state.clone(),
- agent_status_shared_state.clone(),
+ event_threads_shared_state.clone(),
),
];
for handle in handles {
@@ -812,14 +819,7 @@ mod tests {
assert!(event_threads_initialized);
// update provision finish time_tick
- super::key_latched(
- cancellation_token.clone(),
- key_keeper_shared_state.clone(),
- telemetry_shared_state.clone(),
- provision_shared_state.clone(),
- agent_status_shared_state.clone(),
- )
- .await;
+ super::key_latched(event_threads_shared_state.clone()).await;
let provision_state_internal = super::get_provision_state_internal(
provision_shared_state.clone(),
agent_status_shared_state.clone(),
@@ -872,14 +872,7 @@ mod tests {
);
// test key_latched ready again
- super::key_latched(
- cancellation_token.clone(),
- key_keeper_shared_state.clone(),
- telemetry_shared_state.clone(),
- provision_shared_state.clone(),
- agent_status_shared_state.clone(),
- )
- .await;
+ super::key_latched(event_threads_shared_state.clone()).await;
let provision_state = provision_shared_state.get_state().await.unwrap();
assert!(
provision_state.contains(ProvisionFlags::ALL_READY),
@@ -910,39 +903,26 @@ mod tests {
let shared_state = SharedState::start_all();
let cancellation_token = shared_state.get_cancellation_token();
let provision_shared_state = shared_state.get_provision_shared_state();
- let key_keeper_shared_state = shared_state.get_key_keeper_shared_state();
- let telemetry_shared_state = shared_state.get_telemetry_shared_state();
let agent_status_shared_state = shared_state.get_agent_status_shared_state();
+ let event_threads_shared_state = EventThreadsSharedState::new(&shared_state);
// test all 3 provision states as ready
super::update_provision_state(
ProvisionFlags::LISTENER_READY,
Some(temp_test_path.clone()),
- cancellation_token.clone(),
- key_keeper_shared_state.clone(),
- telemetry_shared_state.clone(),
- provision_shared_state.clone(),
- agent_status_shared_state.clone(),
+ event_threads_shared_state.clone(),
)
.await;
super::update_provision_state(
ProvisionFlags::KEY_LATCH_READY,
Some(temp_test_path.clone()),
- cancellation_token.clone(),
- key_keeper_shared_state.clone(),
- telemetry_shared_state.clone(),
- provision_shared_state.clone(),
- agent_status_shared_state.clone(),
+ event_threads_shared_state.clone(),
)
.await;
super::update_provision_state(
ProvisionFlags::REDIRECTOR_READY,
Some(temp_test_path.clone()),
- cancellation_token.clone(),
- key_keeper_shared_state.clone(),
- telemetry_shared_state.clone(),
- provision_shared_state.clone(),
- agent_status_shared_state.clone(),
+ event_threads_shared_state.clone(),
)
.await;
@@ -964,7 +944,7 @@ mod tests {
super::AgentStatusModule::KeyKeeper,
)
.await;
- super::provision_timeup(
+ super::provision_timeout(
Some(temp_test_path.clone()),
provision_shared_state.clone(),
agent_status_shared_state.clone(),
diff --git a/proxy_agent/src/proxy.rs b/proxy_agent/src/proxy.rs
index a6a2dce5..dcc50b1a 100644
--- a/proxy_agent/src/proxy.rs
+++ b/proxy_agent/src/proxy.rs
@@ -165,10 +165,16 @@ impl Process {
let (process_full_path, cmd);
#[cfg(windows)]
{
- let handler = windows::get_process_handler(pid).unwrap_or_else(|e| {
- println!("Failed to get process handler: {e}");
- 0
- });
+ use windows_sys::Win32::System::Threading::{
+ PROCESS_QUERY_INFORMATION, PROCESS_VM_READ,
+ };
+
+ let options = PROCESS_QUERY_INFORMATION | PROCESS_VM_READ;
+ let handler = proxy_agent_shared::windows::get_process_handler(pid, options)
+ .unwrap_or_else(|e| {
+ println!("Failed to get process handler: {e}");
+ 0
+ });
let base_info = windows::query_basic_process_info(handler);
match base_info {
Ok(_) => {
@@ -182,7 +188,7 @@ impl Process {
}
}
// close the handle
- if let Err(e) = windows::close_process_handler(handler) {
+ if let Err(e) = proxy_agent_shared::windows::close_handler(handler) {
println!("Failed to close process handler: {e}");
}
}
diff --git a/proxy_agent/src/proxy/proxy_authorizer.rs b/proxy_agent/src/proxy/proxy_authorizer.rs
index 1707f55e..23d54dc5 100644
--- a/proxy_agent/src/proxy/proxy_authorizer.rs
+++ b/proxy_agent/src/proxy/proxy_authorizer.rs
@@ -8,12 +8,12 @@
//! ```rust
//! use proxy_agent::proxy_authorizer;
//! use proxy_agent::proxy::Claims;
-//! use proxy_agent::shared_state::key_keeper_wrapper::KeyKeeperSharedState;
+//! use proxy_agent::shared_state::access_control_wrapper::AccessControlSharedState;
//! use proxy_agent::common::constants;
//! use std::str::FromStr;
//!
-//! let key_keeper_shared_state = KeyKeeperSharedState::start_new();
-//! let vm_metadata = proxy_authorizer::get_access_control_rules(constants::WIRE_SERVER_IP.to_string(), constants::WIRE_SERVER_PORT, key_keeper_shared_state.clone()).await.unwrap();
+//! let access_control_shared_state = AccessControlSharedState::start_new();
+//! let vm_metadata = proxy_authorizer::get_access_control_rules(constants::WIRE_SERVER_IP.to_string(), constants::WIRE_SERVER_PORT, access_control_shared_state.clone()).await.unwrap();
//! let authorizer = proxy_authorizer::get_authorizer(constants::WIRE_SERVER_IP, constants::WIRE_SERVER_PORT, claims);
//! let url = hyper::Uri::from_str("http://localhost/test?").unwrap();
//! authorizer.authorize(logger, url, vm_metadata);
@@ -21,7 +21,7 @@
use super::authorization_rules::{AuthorizationMode, ComputedAuthorizationItem};
use super::proxy_connection::ConnectionLogger;
-use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState;
+use crate::shared_state::access_control_wrapper::AccessControlSharedState;
use crate::{common::constants, common::result::Result, proxy::Claims};
use proxy_agent_shared::logger::LoggerLevel;
@@ -211,17 +211,17 @@ pub fn get_authorizer(ip: String, port: u16, claims: Claims) -> Box Result