diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt index 7809ed51..0b73b241 100644 --- a/.github/actions/spelling/expect.txt +++ b/.github/actions/spelling/expect.txt @@ -364,4 +364,22 @@ xsi xxxx xxxxxxxx xxxxxxxxxxx -zipsas \ No newline at end of file +zipsas +CALG +CNG +detials +HCRYPTPROV +hostgap +innerbandwidth +innerlatency +KSP +ncrypt +outwardbandwidth +outwardlatency +PCWSTR +PSTR +psz +rsm +SELFSIGN +vmsettings +hostgaplugin \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index f07e20dc..14483e3d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -176,15 +176,11 @@ version = "9.9.9" dependencies = [ "aya", "bitflags", - "clap", "ctor", - "hex", - "hmac-sha256", "http", "http-body-util", "hyper", "hyper-util", - "itertools", "libc", "libloading", "nix", @@ -226,6 +222,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "2.6.0" @@ -464,9 +466,9 @@ checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" [[package]] name = "futures-sink" -version = "0.3.30" +version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" [[package]] name = "futures-task" @@ -553,9 +555,9 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hmac-sha256" -version = "1.1.7" +version = "1.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3688e69b38018fec1557254f64c8dc2cc8ec502890182f395dbb0aa997aa5735" +checksum = "ad6880c8d4a9ebf39c6e8b77007ce223f646a4d21ce29d99f70cb16420545425" [[package]] name = "http" @@ -711,9 +713,9 @@ checksum = "c19937216e9d3aa9956d9bb8dfc0b0c8beb6058fc4f7a4dc4d850edf86a237d6" [[package]] name = "libloading" -version = "0.8.5" +version = "0.8.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" +checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" dependencies = [ "cfg-if", "windows-targets", @@ -879,26 +881,50 @@ dependencies = [ name = "proxy_agent_shared" version = "9.9.9" dependencies = [ + "base64", "chrono", + "clap", "concurrent-queue", "ctor", + "hex", + "hmac-sha256", + "http", + "http-body-util", + "hyper", + "hyper-util", + "itertools", + "libloading", "log", "once_cell", "os_info", + "quick-xml", "regex", "serde", "serde-xml-rs", "serde_derive", "serde_json", + "sysinfo", "thiserror", "thread-id", "time", "tokio", + "uuid", + "windows 0.61.3", "windows-service", "windows-sys", "winreg", ] +[[package]] +name = "quick-xml" +version = "0.38.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42a232e7487fc2ef313d96dde7948e7a3c05101870d8985e4fd8d26aedd27b89" +dependencies = [ + "memchr", + "serde", +] + [[package]] name = "quote" version = "1.0.37" @@ -1136,7 +1162,7 @@ dependencies = [ "ntapi", "once_cell", "rayon", - "windows", + "windows 0.52.0", ] [[package]] @@ -1488,6 +1514,19 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows" +version = "0.61.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +dependencies = [ + "windows-collections", + "windows-core 0.61.2", + "windows-future", + "windows-link", + "windows-numerics", +] + [[package]] name = "windows-acl" version = "0.3.0" @@ -1500,6 +1539,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core 0.61.2", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -1522,6 +1570,17 @@ dependencies = [ "windows-strings", ] +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core 0.61.2", + "windows-link", + "windows-threading", +] + [[package]] name = "windows-implement" version = "0.60.0" @@ -1550,6 +1609,16 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core 0.61.2", + "windows-link", +] + [[package]] name = "windows-result" version = "0.3.4" @@ -1604,6 +1673,15 @@ dependencies = [ "windows_x86_64_msvc", ] +[[package]] +name = "windows-threading" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" diff --git a/proxy_agent/Cargo.toml b/proxy_agent/Cargo.toml index 355c1e55..5926ef7a 100644 --- a/proxy_agent/Cargo.toml +++ b/proxy_agent/Cargo.toml @@ -10,15 +10,12 @@ license = "MIT" [dependencies] proxy_agent_shared = { path ="../proxy_agent_shared"} -itertools = "0.10.5" # use to sort iterator elements into a new iterator in ascending order once_cell = "1.17.0" # use Lazy serde = "1.0.152" serde_derive = "1.0.152" serde_json = "1.0.91" # json Deserializer serde-xml-rs = "0.8.1" # xml Deserializer with xml attribute bitflags = "2.6.0" # support bitflag enum -hmac-sha256 = "1.1.6" # use HMAC using the SHA-256 hash function -hex = "0.4.3" # hex encode regex = "1.11" # match process name in cmdline tokio = { version = "1", features = ["rt", "rt-multi-thread", "time", "net", "macros", "sync"] } tokio-util = "0.7.11" @@ -28,7 +25,6 @@ hyper = { version = "1", features = ["server", "http1", "client"] } hyper-util = { version = "0.1", features = ["tokio"] } tower = { version = "0.5.2", features = ["full"] } tower-http = { version = "0.6.2", features = ["limit"] } -clap = { version = "4.5.17", features =["derive"] } # Command Line Argument Parser thiserror = "1.0.64" ctor = "0.3.6" # used for test setup and clean up diff --git a/proxy_agent/src/acl.rs b/proxy_agent/src/acl.rs index f569d803..f4f4e063 100644 --- a/proxy_agent/src/acl.rs +++ b/proxy_agent/src/acl.rs @@ -19,7 +19,7 @@ mod windows_acl; #[cfg(not(windows))] mod linux_acl; -use crate::common::result::Result; +use proxy_agent_shared::common::result::Result; use std::path::PathBuf; pub fn acl_directory(dir_to_acl: PathBuf) -> Result<()> { diff --git a/proxy_agent/src/acl/linux_acl.rs b/proxy_agent/src/acl/linux_acl.rs index d588c3d9..6d9df1cb 100644 --- a/proxy_agent/src/acl/linux_acl.rs +++ b/proxy_agent/src/acl/linux_acl.rs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT -use crate::common::{logger, result::Result}; use nix::unistd::{chown, Gid, Uid}; +use proxy_agent_shared::common::{logger, result::Result}; use proxy_agent_shared::misc_helpers; use std::fs; use std::os::unix::fs::PermissionsExt; diff --git a/proxy_agent/src/acl/windows_acl.rs b/proxy_agent/src/acl/windows_acl.rs index eb63a896..8aa0757b 100644 --- a/proxy_agent/src/acl/windows_acl.rs +++ b/proxy_agent/src/acl/windows_acl.rs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT -use crate::common::{ +use proxy_agent_shared::common::{ error::{AclErrorType, Error}, logger, result::Result, diff --git a/proxy_agent/src/common/cli.rs b/proxy_agent/src/common/cli.rs deleted file mode 100644 index ead9eb74..00000000 --- a/proxy_agent/src/common/cli.rs +++ /dev/null @@ -1,51 +0,0 @@ -// Copyright (c) Microsoft Corporation -// SPDX-License-Identifier: MIT - -use clap::{Parser, Subcommand}; -use once_cell::sync::Lazy; - -/// azure-proxy-agent console - launch a long run process of GPA in console mode. -/// azure-proxy-agent --version - print the version of the GPA. -/// azure-proxy-agent --status [--wait ] - get the provision status of the GPA service. -/// azure-proxy-agent - start the GPA as an OS service. -/// The GPA service will be started as an OS service in the background. -#[derive(Parser)] -#[command()] -pub struct Cli { - /// get the provision status of the GPA service - #[arg(short, long)] - pub status: bool, - - /// wait for the provision status to finish - #[arg(short, long, requires = "status")] - pub wait: Option, - - /// print the version of the GPA - #[arg(short, long)] - pub version: bool, - - #[cfg(test)] - #[arg(short, long)] - test_threads: Option, - - #[cfg(test)] - #[arg(short, long)] - nocapture: bool, - - #[command(subcommand)] - pub command: Option, -} - -#[derive(Subcommand)] -pub enum Commands { - /// launch a long run process of GPA in console mode - Console, -} - -impl Cli { - pub fn is_console_mode(&self) -> bool { - self.command.is_some() - } -} - -pub static CLI: Lazy = Lazy::new(Cli::parse); diff --git a/proxy_agent/src/host_clients/imds_client.rs b/proxy_agent/src/host_clients/imds_client.rs index 7c506ef1..ceeaa927 100644 --- a/proxy_agent/src/host_clients/imds_client.rs +++ b/proxy_agent/src/host_clients/imds_client.rs @@ -21,9 +21,9 @@ //! ``` use super::instance_info::InstanceInfo; -use crate::common::{error::Error, hyper_client, logger, result::Result}; use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; use hyper::Uri; +use proxy_agent_shared::common::{error::Error, hyper_client, logger, result::Result}; use std::collections::HashMap; pub struct ImdsClient { diff --git a/proxy_agent/src/host_clients/wire_server_client.rs b/proxy_agent/src/host_clients/wire_server_client.rs index ced98348..1cee0b11 100644 --- a/proxy_agent/src/host_clients/wire_server_client.rs +++ b/proxy_agent/src/host_clients/wire_server_client.rs @@ -20,16 +20,14 @@ //! ``` use crate::host_clients::goal_state::{GoalState, SharedConfig}; -use crate::{ - common::{ - error::{Error, WireServerErrorType}, - hyper_client, logger, - result::Result, - }, - shared_state::key_keeper_wrapper::KeyKeeperSharedState, -}; +use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; use http::Method; use hyper::Uri; +use proxy_agent_shared::common::{ + error::{Error, WireServerErrorType}, + hyper_client, logger, + result::Result, +}; use std::collections::HashMap; pub struct WireServerClient { diff --git a/proxy_agent/src/key_keeper.rs b/proxy_agent/src/key_keeper.rs index 19170e15..0555185d 100644 --- a/proxy_agent/src/key_keeper.rs +++ b/proxy_agent/src/key_keeper.rs @@ -26,9 +26,6 @@ pub mod key; use self::key::Key; -use crate::common::error::{Error, KeyErrorType}; -use crate::common::result::Result; -use crate::common::{constants, helpers, logger}; use crate::provision; use crate::proxy::authorization_rules::{AuthorizationRulesForLogging, ComputedAuthorizationRules}; use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState}; @@ -39,6 +36,9 @@ use crate::shared_state::telemetry_wrapper::TelemetrySharedState; use crate::shared_state::SharedState; use crate::{acl, redirector}; use hyper::Uri; +use proxy_agent_shared::common::error::{Error, KeyErrorType}; +use proxy_agent_shared::common::result::Result; +use proxy_agent_shared::common::{constants, helpers, logger}; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::misc_helpers; use proxy_agent_shared::proxy_agent_aggregate_status::ModuleState; @@ -607,7 +607,7 @@ impl KeyKeeper { key_file.set_extension("encrypted"); #[cfg(windows)] { - crate::common::store_key_data( + proxy_agent_shared::common::store_key_data( &key_file, serde_json::to_string(&key).map_err(|e| { Error::Key(KeyErrorType::StoreLocalKey(format!( @@ -657,7 +657,7 @@ impl KeyKeeper { if !key_file.exists() { // guid.key file does not exist locally return Err(Error::Key( - crate::common::error::KeyErrorType::FetchLocalKey(format!( + proxy_agent_shared::common::error::KeyErrorType::FetchLocalKey(format!( "Key file '{}' does not exist locally.", key_file.display() )), @@ -674,7 +674,7 @@ impl KeyKeeper { } #[cfg(windows)] { - crate::common::fetch_key_data(&key_file)? + proxy_agent_shared::common::fetch_key_data(&key_file)? } } else { fs::read_to_string(&key_file).map_err(|e| { @@ -683,9 +683,11 @@ impl KeyKeeper { }; serde_json::from_str::(&key_data).map_err(|e| { - Error::Key(crate::common::error::KeyErrorType::FetchLocalKey(format!( - "Parse key data with error: {e}" - ))) + Error::Key( + proxy_agent_shared::common::error::KeyErrorType::FetchLocalKey(format!( + "Parse key data with error: {e}" + )), + ) }) } @@ -726,7 +728,7 @@ impl KeyKeeper { } else { // guid.key file found but guid or key value is not matched Err(Error::Key( - crate::common::error::KeyErrorType::CheckLocalKey( + proxy_agent_shared::common::error::KeyErrorType::CheckLocalKey( "Local key guid or key value is not matched.".to_string(), ), )) diff --git a/proxy_agent/src/key_keeper/key.rs b/proxy_agent/src/key_keeper/key.rs index bf8ba5ac..756fc863 100644 --- a/proxy_agent/src/key_keeper/key.rs +++ b/proxy_agent/src/key_keeper/key.rs @@ -21,18 +21,15 @@ //! Key::attest_key(base_url.clone(), &key).await.unwrap(); //! //! ``` - -use crate::{ - common::{ - constants, - error::{Error, KeyErrorType}, - hyper_client, logger, - result::Result, - }, - proxy::{proxy_connection::ConnectionLogger, Claims}, -}; +use crate::proxy::{proxy_connection::ConnectionLogger, Claims}; use http::{Method, StatusCode}; use hyper::Uri; +use proxy_agent_shared::common::{ + constants, + error::{Error, KeyErrorType}, + hyper_client, logger, + result::Result, +}; use proxy_agent_shared::logger::LoggerLevel; use serde_derive::{Deserialize, Serialize}; use std::ffi::OsString; @@ -831,11 +828,11 @@ mod tests { use super::Key; use super::KeyStatus; - use crate::common::constants; use crate::key_keeper::key::Identity; use crate::key_keeper::key::Privilege; use crate::proxy::proxy_connection::ConnectionLogger; use hyper::Uri; + use proxy_agent_shared::common::constants; use serde_json::json; #[test] diff --git a/proxy_agent/src/main.rs b/proxy_agent/src/main.rs index 4dcc3ce1..a1bd5323 100644 --- a/proxy_agent/src/main.rs +++ b/proxy_agent/src/main.rs @@ -2,7 +2,6 @@ // SPDX-License-Identifier: MIT pub mod acl; -pub mod common; pub mod host_clients; pub mod key_keeper; pub mod provision; @@ -16,16 +15,16 @@ pub mod telemetry; #[cfg(test)] pub mod test_mock; -use common::cli::{Commands, CLI}; -use common::constants; -use common::helpers; use provision::provision_query::ProvisionQuery; +use proxy_agent_shared::common::cli::{Commands, CLI}; +use proxy_agent_shared::common::constants; +use proxy_agent_shared::common::helpers; use proxy_agent_shared::misc_helpers; use shared_state::SharedState; use std::{process, time::Duration}; #[cfg(windows)] -use common::logger; +use proxy_agent_shared::common::logger; #[cfg(windows)] use service::windows_main; #[cfg(windows)] diff --git a/proxy_agent/src/provision.rs b/proxy_agent/src/provision.rs index bec9c783..f184ad24 100644 --- a/proxy_agent/src/provision.rs +++ b/proxy_agent/src/provision.rs @@ -79,7 +79,6 @@ //! assert_eq!(0, provision_state.1.len()); //! ``` -use crate::common::{config, helpers, logger}; use crate::key_keeper::{DISABLE_STATE, UNKNOWN_STATE}; use crate::proxy_agent_status; use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState}; @@ -87,6 +86,7 @@ use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; use crate::shared_state::provision_wrapper::ProvisionSharedState; use crate::shared_state::telemetry_wrapper::TelemetrySharedState; use crate::telemetry::event_reader::EventReader; +use proxy_agent_shared::common::{config, helpers, logger}; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::telemetry::event_logger; use proxy_agent_shared::{misc_helpers, proxy_agent_aggregate_status}; @@ -535,7 +535,9 @@ pub async fn get_provision_state_internal( /// provision query module designed for GPA command line, serves for --status [--wait seconds] option /// It is used to query the provision status from GPA service via http request pub mod provision_query { - use crate::common::{constants, error::Error, helpers, hyper_client, logger, result::Result}; + use proxy_agent_shared::common::{ + constants, error::Error, helpers, hyper_client, logger, result::Result, + }; use proxy_agent_shared::misc_helpers; use serde_derive::{Deserialize, Serialize}; use std::{collections::HashMap, net::Ipv4Addr, time::Duration}; diff --git a/proxy_agent/src/proxy.rs b/proxy_agent/src/proxy.rs index a6a2dce5..a5f5cbb0 100644 --- a/proxy_agent/src/proxy.rs +++ b/proxy_agent/src/proxy.rs @@ -40,9 +40,9 @@ pub mod proxy_summary; #[cfg(windows)] mod windows; -use crate::common::result::Result; use crate::redirector::AuditEntry; use crate::shared_state::proxy_server_wrapper::ProxyServerSharedState; +use proxy_agent_shared::common::result::Result; use serde_derive::{Deserialize, Serialize}; use std::{ffi::OsString, net::IpAddr, path::PathBuf}; diff --git a/proxy_agent/src/proxy/authorization_rules.rs b/proxy_agent/src/proxy/authorization_rules.rs index 27ef84fa..b5aed2ab 100644 --- a/proxy_agent/src/proxy/authorization_rules.rs +++ b/proxy_agent/src/proxy/authorization_rules.rs @@ -18,8 +18,8 @@ //! ``` use super::{proxy_connection::ConnectionLogger, Claims}; -use crate::common::logger; use crate::key_keeper::key::{AuthorizationItem, AuthorizationRules, Identity, Privilege, Role}; +use proxy_agent_shared::common::logger; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::misc_helpers; use serde_derive::{Deserialize, Serialize}; diff --git a/proxy_agent/src/proxy/proxy_authorizer.rs b/proxy_agent/src/proxy/proxy_authorizer.rs index 1707f55e..6ca85403 100644 --- a/proxy_agent/src/proxy/proxy_authorizer.rs +++ b/proxy_agent/src/proxy/proxy_authorizer.rs @@ -21,9 +21,10 @@ use super::authorization_rules::{AuthorizationMode, ComputedAuthorizationItem}; use super::proxy_connection::ConnectionLogger; +use crate::proxy::Claims; use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; -use crate::{common::constants, common::result::Result, proxy::Claims}; use proxy_agent_shared::logger::LoggerLevel; +use proxy_agent_shared::{common::constants, common::result::Result}; #[derive(PartialEq)] pub enum AuthorizeResult { @@ -268,8 +269,8 @@ mod tests { }; let mut test_logger = ConnectionLogger::new(0, 0); let auth: Box = super::get_authorizer( - crate::common::constants::WIRE_SERVER_IP.to_string(), - crate::common::constants::WIRE_SERVER_PORT, + proxy_agent_shared::common::constants::WIRE_SERVER_IP.to_string(), + proxy_agent_shared::common::constants::WIRE_SERVER_PORT, claims.clone(), ); let test_uri = hyper::Uri::from_str("test").unwrap(); @@ -283,8 +284,8 @@ mod tests { ); let auth = super::get_authorizer( - crate::common::constants::GA_PLUGIN_IP.to_string(), - crate::common::constants::GA_PLUGIN_PORT, + proxy_agent_shared::common::constants::GA_PLUGIN_IP.to_string(), + proxy_agent_shared::common::constants::GA_PLUGIN_PORT, claims.clone(), ); assert_eq!( @@ -297,8 +298,8 @@ mod tests { ); let auth = super::get_authorizer( - crate::common::constants::IMDS_IP.to_string(), - crate::common::constants::IMDS_PORT, + proxy_agent_shared::common::constants::IMDS_IP.to_string(), + proxy_agent_shared::common::constants::IMDS_PORT, claims.clone(), ); assert_eq!(auth.to_string(), "IMDS"); @@ -308,8 +309,8 @@ mod tests { ); let auth = super::get_authorizer( - crate::common::constants::PROXY_AGENT_IP.to_string(), - crate::common::constants::PROXY_AGENT_PORT, + proxy_agent_shared::common::constants::PROXY_AGENT_IP.to_string(), + proxy_agent_shared::common::constants::PROXY_AGENT_PORT, claims.clone(), ); assert_eq!(auth.to_string(), "ProxyAgent"); @@ -319,8 +320,8 @@ mod tests { ); let auth = super::get_authorizer( - crate::common::constants::PROXY_AGENT_IP.to_string(), - crate::common::constants::PROXY_AGENT_PORT + 1, + proxy_agent_shared::common::constants::PROXY_AGENT_IP.to_string(), + proxy_agent_shared::common::constants::PROXY_AGENT_PORT + 1, claims.clone(), ); assert_eq!(auth.to_string(), "Default"); @@ -342,8 +343,8 @@ mod tests { }; let mut test_logger = ConnectionLogger::new(1, 1); let auth = super::get_authorizer( - crate::common::constants::WIRE_SERVER_IP.to_string(), - crate::common::constants::WIRE_SERVER_PORT, + proxy_agent_shared::common::constants::WIRE_SERVER_IP.to_string(), + proxy_agent_shared::common::constants::WIRE_SERVER_PORT, claims.clone(), ); let url = hyper::Uri::from_str("http://localhost/test?").unwrap(); @@ -467,8 +468,8 @@ mod tests { clientPort: 0, // doesn't matter for this test }; let auth = super::get_authorizer( - crate::common::constants::IMDS_IP.to_string(), - crate::common::constants::IMDS_PORT, + proxy_agent_shared::common::constants::IMDS_IP.to_string(), + proxy_agent_shared::common::constants::IMDS_PORT, claims.clone(), ); let url = hyper::Uri::from_str("http://localhost/test?").unwrap(); @@ -577,8 +578,8 @@ mod tests { }; let mut test_logger = ConnectionLogger::new(1, 1); let auth = super::get_authorizer( - crate::common::constants::GA_PLUGIN_IP.to_string(), - crate::common::constants::GA_PLUGIN_PORT, + proxy_agent_shared::common::constants::GA_PLUGIN_IP.to_string(), + proxy_agent_shared::common::constants::GA_PLUGIN_PORT, claims.clone(), ); let url = hyper::Uri::from_str("http://localhost/test?").unwrap(); diff --git a/proxy_agent/src/proxy/proxy_connection.rs b/proxy_agent/src/proxy/proxy_connection.rs index eeb3b340..1b7bd60a 100644 --- a/proxy_agent/src/proxy/proxy_connection.rs +++ b/proxy_agent/src/proxy/proxy_connection.rs @@ -3,9 +3,6 @@ //! This module contains the connection context struct for the proxy listener, and write proxy processing logs to local file. -use crate::common::error::{Error, HyperErrorType}; -use crate::common::result::Result; -use crate::common::{config, hyper_client}; use crate::proxy::Claims; use crate::redirector::{self, AuditEntry}; use crate::shared_state::proxy_server_wrapper::ProxyServerSharedState; @@ -14,6 +11,9 @@ use http_body_util::Full; use hyper::body::Bytes; use hyper::client::conn::http1; use hyper::Request; +use proxy_agent_shared::common::error::{Error, HyperErrorType}; +use proxy_agent_shared::common::result::Result; +use proxy_agent_shared::common::{config, hyper_client}; use proxy_agent_shared::logger::{self, logger_manager, LoggerLevel}; use proxy_agent_shared::misc_helpers; use std::net::{Ipv4Addr, SocketAddr}; @@ -339,7 +339,9 @@ impl ConnectionLogger { // connection logger file log does not write to system log implicitly. logger_manager::write_system_log(logger_level, message.to_string()); - if let Some(log_for_event) = crate::common::config::get_file_log_level_for_events() { + if let Some(log_for_event) = + proxy_agent_shared::common::config::get_file_log_level_for_events() + { if log_for_event >= logger_level { // write to event proxy_agent_shared::telemetry::event_logger::write_event_only( diff --git a/proxy_agent/src/proxy/proxy_server.rs b/proxy_agent/src/proxy/proxy_server.rs index dc8fd445..918bdc45 100644 --- a/proxy_agent/src/proxy/proxy_server.rs +++ b/proxy_agent/src/proxy/proxy_server.rs @@ -10,7 +10,7 @@ //! //! Example: //! ```rust -//! use crate::common::config; +//! use proxy_agent_shared::common::config; //! use crate::proxy::proxy_server; //! use crate::shared_state::SharedState; //! @@ -22,12 +22,6 @@ use super::proxy_authorizer::AuthorizeResult; use super::proxy_connection::{ConnectionLogger, HttpConnectionContext, TcpConnectionContext}; -use crate::common::{ - constants, - error::{Error, HyperErrorType}, - helpers, hyper_client, logger, - result::Result, -}; use crate::provision; use crate::proxy::{proxy_authorizer, proxy_summary::ProxySummary, Claims}; use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState}; @@ -45,6 +39,12 @@ use hyper::service::service_fn; use hyper::StatusCode; use hyper::{Request, Response}; use hyper_util::rt::TokioIo; +use proxy_agent_shared::common::{ + constants, + error::{Error, HyperErrorType}, + helpers, hyper_client, logger, + result::Result, +}; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::misc_helpers; use proxy_agent_shared::proxy_agent_aggregate_status::ModuleState; @@ -279,7 +279,10 @@ impl ProxyServer { let large_limited_tower_service = tower::ServiceBuilder::new().layer(large_limit_layer); let tower_service_layer = - if crate::common::hyper_client::should_skip_sig(req.method(), req.uri()) { + if proxy_agent_shared::common::hyper_client::should_skip_sig( + req.method(), + req.uri(), + ) { // skip signature check for large request large_limited_tower_service.clone() } else { @@ -1012,11 +1015,11 @@ impl ProxyServer { #[cfg(test)] mod tests { - use crate::common::hyper_client; - use crate::common::logger; use crate::proxy::proxy_server; use crate::shared_state; use http::Method; + use proxy_agent_shared::common::hyper_client; + use proxy_agent_shared::common::logger; use std::collections::HashMap; use std::time::Duration; diff --git a/proxy_agent/src/proxy/windows.rs b/proxy_agent/src/proxy/windows.rs index 145b84d7..174000b6 100644 --- a/proxy_agent/src/proxy/windows.rs +++ b/proxy_agent/src/proxy/windows.rs @@ -1,13 +1,13 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT -use crate::common::{ +use libloading::{Library, Symbol}; +use once_cell::sync::Lazy; +use proxy_agent_shared::common::{ error::{Error, WindowsApiErrorType}, logger, result::Result, }; -use libloading::{Library, Symbol}; -use once_cell::sync::Lazy; use std::mem::MaybeUninit; use std::ptr::null_mut; use std::{collections::HashMap, ffi::OsString, os::windows::ffi::OsStringExt, path::PathBuf}; diff --git a/proxy_agent/src/proxy_agent_status.rs b/proxy_agent/src/proxy_agent_status.rs index 777c1620..3a6eb429 100644 --- a/proxy_agent/src/proxy_agent_status.rs +++ b/proxy_agent/src/proxy_agent_status.rs @@ -28,10 +28,10 @@ //! tokio::spawn(proxy_agent_status_task.start()); //! ``` -use crate::common::logger; use crate::key_keeper::UNKNOWN_STATE; use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState}; use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; +use proxy_agent_shared::common::logger; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::misc_helpers; use proxy_agent_shared::proxy_agent_aggregate_status::{ diff --git a/proxy_agent/src/redirector.rs b/proxy_agent/src/redirector.rs index 07a33433..bbb05640 100644 --- a/proxy_agent/src/redirector.rs +++ b/proxy_agent/src/redirector.rs @@ -45,12 +45,6 @@ mod windows; #[cfg(not(windows))] mod linux; -use crate::common::constants; -use crate::common::error::BpfErrorType; -use crate::common::error::Error; -use crate::common::helpers; -use crate::common::result::Result; -use crate::common::{config, logger}; use crate::provision; use crate::proxy::authorization_rules::AuthorizationMode; use crate::shared_state::agent_status_wrapper::{AgentStatusModule, AgentStatusSharedState}; @@ -59,6 +53,12 @@ use crate::shared_state::provision_wrapper::ProvisionSharedState; use crate::shared_state::redirector_wrapper::RedirectorSharedState; use crate::shared_state::telemetry_wrapper::TelemetrySharedState; use crate::shared_state::SharedState; +use proxy_agent_shared::common::constants; +use proxy_agent_shared::common::error::BpfErrorType; +use proxy_agent_shared::common::error::Error; +use proxy_agent_shared::common::helpers; +use proxy_agent_shared::common::result::Result; +use proxy_agent_shared::common::{config, logger}; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::misc_helpers; use proxy_agent_shared::proxy_agent_aggregate_status::ModuleState; diff --git a/proxy_agent/src/redirector/linux.rs b/proxy_agent/src/redirector/linux.rs index f973428e..58193076 100644 --- a/proxy_agent/src/redirector/linux.rs +++ b/proxy_agent/src/redirector/linux.rs @@ -2,12 +2,6 @@ // SPDX-License-Identifier: MIT mod ebpf_obj; -use crate::common::{ - config, constants, - error::{BpfErrorType, Error}, - logger, - result::Result, -}; use crate::redirector::{ip_to_string, AuditEntry}; use crate::shared_state::redirector_wrapper::RedirectorSharedState; use aya::programs::{CgroupSockAddr, KProbe}; @@ -19,6 +13,12 @@ use aya::{Btf, Ebpf, EbpfLoader}; use ebpf_obj::{ destination_entry, sock_addr_audit_entry, sock_addr_audit_key, sock_addr_skip_process_entry, }; +use proxy_agent_shared::common::{ + config, constants, + error::{BpfErrorType, Error}, + logger, + result::Result, +}; use proxy_agent_shared::telemetry::event_logger; use proxy_agent_shared::{logger::LoggerLevel, misc_helpers}; use std::convert::TryFrom; @@ -476,11 +476,11 @@ pub async fn update_hostga_redirect_policy( #[cfg(test)] #[cfg(feature = "test-with-root")] mod tests { - use crate::common::config; - use crate::common::constants; use crate::redirector::linux::ebpf_obj::sock_addr_audit_entry; use crate::redirector::linux::ebpf_obj::sock_addr_audit_key; use aya::maps::HashMap; + use proxy_agent_shared::common::config; + use proxy_agent_shared::common::constants; use proxy_agent_shared::misc_helpers; use std::env; diff --git a/proxy_agent/src/redirector/linux/ebpf_obj.rs b/proxy_agent/src/redirector/linux/ebpf_obj.rs index 706b2eaf..c00a9604 100644 --- a/proxy_agent/src/redirector/linux/ebpf_obj.rs +++ b/proxy_agent/src/redirector/linux/ebpf_obj.rs @@ -147,7 +147,7 @@ impl sock_addr_audit_entry { #[cfg(test)] mod tests { - use crate::common::constants; + use proxy_agent_shared::common::constants; #[test] fn destination_entry_test() { diff --git a/proxy_agent/src/redirector/windows.rs b/proxy_agent/src/redirector/windows.rs index cf23a3b8..66755940 100644 --- a/proxy_agent/src/redirector/windows.rs +++ b/proxy_agent/src/redirector/windows.rs @@ -5,11 +5,11 @@ mod bpf_api; mod bpf_obj; mod bpf_prog; -use crate::common::error::{BpfErrorType, Error, WindowsApiErrorType}; -use crate::common::{constants, logger, result::Result}; use crate::redirector::AuditEntry; use crate::shared_state::redirector_wrapper::RedirectorSharedState; use core::ffi::c_void; +use proxy_agent_shared::common::error::{BpfErrorType, Error, WindowsApiErrorType}; +use proxy_agent_shared::common::{constants, logger, result::Result}; use std::mem; use std::ptr; use windows_sys::Win32::Networking::WinSock; diff --git a/proxy_agent/src/redirector/windows/bpf_api.rs b/proxy_agent/src/redirector/windows/bpf_api.rs index 55b52bcb..0646b9d8 100644 --- a/proxy_agent/src/redirector/windows/bpf_api.rs +++ b/proxy_agent/src/redirector/windows/bpf_api.rs @@ -4,12 +4,12 @@ #![allow(non_snake_case)] use super::bpf_obj::*; -use crate::common::{ +use libloading::{Library, Symbol}; +use proxy_agent_shared::common::{ error::{BpfErrorType, Error}, logger, result::Result, }; -use libloading::{Library, Symbol}; use proxy_agent_shared::{ logger::LoggerLevel, misc_helpers, telemetry::event_logger, version::Version, }; diff --git a/proxy_agent/src/redirector/windows/bpf_prog.rs b/proxy_agent/src/redirector/windows/bpf_prog.rs index 71e51968..1372bcda 100644 --- a/proxy_agent/src/redirector/windows/bpf_prog.rs +++ b/proxy_agent/src/redirector/windows/bpf_prog.rs @@ -3,13 +3,13 @@ use super::bpf_api::*; use super::bpf_obj::*; use super::BpfObject; -use crate::common::constants; -use crate::common::logger; -use crate::common::{ +use crate::redirector::AuditEntry; +use proxy_agent_shared::common::constants; +use proxy_agent_shared::common::logger; +use proxy_agent_shared::common::{ error::{BpfErrorType, Error}, result::Result, }; -use crate::redirector::AuditEntry; use proxy_agent_shared::misc_helpers; use std::ffi::c_void; use std::mem::size_of_val; diff --git a/proxy_agent/src/service.rs b/proxy_agent/src/service.rs index 9de32681..00a5f172 100644 --- a/proxy_agent/src/service.rs +++ b/proxy_agent/src/service.rs @@ -3,12 +3,12 @@ #[cfg(windows)] pub mod windows_main; -use crate::common::{config, constants, helpers, logger}; use crate::key_keeper::KeyKeeper; use crate::proxy::proxy_connection::ConnectionLogger; use crate::proxy::proxy_server::ProxyServer; use crate::redirector::{self, Redirector}; use crate::shared_state::SharedState; +use proxy_agent_shared::common::{config, constants, helpers, logger}; use proxy_agent_shared::logger::rolling_logger::RollingLogger; use proxy_agent_shared::logger::{logger_manager, LoggerLevel}; use proxy_agent_shared::proxy_agent_aggregate_status; diff --git a/proxy_agent/src/service/windows_main.rs b/proxy_agent/src/service/windows_main.rs index d62170c3..2c0175c8 100644 --- a/proxy_agent/src/service/windows_main.rs +++ b/proxy_agent/src/service/windows_main.rs @@ -5,8 +5,8 @@ //! The GPA service is implemented as a Windows service using the windows_service crate. //! It is started, stopped, and controlled by the Windows service manager. -use crate::common::{constants, logger, result::Result}; use crate::{service, shared_state::SharedState}; +use proxy_agent_shared::common::{constants, logger, result::Result}; use std::time::Duration; use windows_service::service::{ ServiceControl, ServiceControlAccept, ServiceExitCode, ServiceState, ServiceStatus, ServiceType, diff --git a/proxy_agent/src/shared_state/agent_status_wrapper.rs b/proxy_agent/src/shared_state/agent_status_wrapper.rs index e7f55aa0..de00ad76 100644 --- a/proxy_agent/src/shared_state/agent_status_wrapper.rs +++ b/proxy_agent/src/shared_state/agent_status_wrapper.rs @@ -7,9 +7,10 @@ //! The proxy agent status contains the 'failed connection summary' of the proxy server. //! The proxy agent status contains the 'connection count' of the proxy server. -use crate::common::logger; -use crate::common::result::Result; -use crate::{common::error::Error, proxy::proxy_summary::ProxySummary}; +use crate::proxy::proxy_summary::ProxySummary; +use proxy_agent_shared::common::error::Error; +use proxy_agent_shared::common::logger; +use proxy_agent_shared::common::result::Result; use proxy_agent_shared::logger::LoggerLevel; use proxy_agent_shared::proxy_agent_aggregate_status::{ ModuleState, ProxyAgentDetailStatus, ProxyConnectionSummary, diff --git a/proxy_agent/src/shared_state/key_keeper_wrapper.rs b/proxy_agent/src/shared_state/key_keeper_wrapper.rs index de75b566..cccbbce3 100644 --- a/proxy_agent/src/shared_state/key_keeper_wrapper.rs +++ b/proxy_agent/src/shared_state/key_keeper_wrapper.rs @@ -1,12 +1,15 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT +use crate::key_keeper::key::AuthorizationItem; +use crate::key_keeper::key::Key; +use crate::proxy::authorization_rules::ComputedAuthorizationItem; /// The KeyKeeperState struct is used to send actions to the KeyKeeper module related shared state fields /// Example: /// ``` /// use crate::shared_state::key_keeper_wrapper::KeyKeeperState; /// use crate::key_keeper::key::Key; -/// use crate::common::result::Result; +/// use proxy_agent_shared::common::result::Result; /// use std::sync::Arc; /// use tokio::sync::Notify; /// @@ -36,11 +39,9 @@ /// Ok(()) /// } /// ``` -use crate::common::error::Error; -use crate::common::result::Result; -use crate::key_keeper::key::AuthorizationItem; -use crate::proxy::authorization_rules::ComputedAuthorizationItem; -use crate::{common::logger, key_keeper::key::Key}; +use proxy_agent_shared::common::error::Error; +use proxy_agent_shared::common::logger; +use proxy_agent_shared::common::result::Result; use std::sync::Arc; use tokio::sync::{mpsc, oneshot, Notify}; diff --git a/proxy_agent/src/shared_state/provision_wrapper.rs b/proxy_agent/src/shared_state/provision_wrapper.rs index e9e92609..7317c1df 100644 --- a/proxy_agent/src/shared_state/provision_wrapper.rs +++ b/proxy_agent/src/shared_state/provision_wrapper.rs @@ -28,10 +28,10 @@ //! assert_eq!(get_provision_finished, false); //! ``` -use crate::common::error::Error; -use crate::common::logger; -use crate::common::result::Result; use crate::provision::ProvisionFlags; +use proxy_agent_shared::common::error::Error; +use proxy_agent_shared::common::logger; +use proxy_agent_shared::common::result::Result; use proxy_agent_shared::misc_helpers; use tokio::sync::{mpsc, oneshot}; diff --git a/proxy_agent/src/shared_state/proxy_server_wrapper.rs b/proxy_agent/src/shared_state/proxy_server_wrapper.rs index e9ef02a2..82802ad9 100644 --- a/proxy_agent/src/shared_state/proxy_server_wrapper.rs +++ b/proxy_agent/src/shared_state/proxy_server_wrapper.rs @@ -16,10 +16,10 @@ //! assert_eq!(user.user_name, "user1"); //! ``` -use crate::common::error::Error; -use crate::common::logger; -use crate::common::result::Result; use crate::proxy::User; +use proxy_agent_shared::common::error::Error; +use proxy_agent_shared::common::logger; +use proxy_agent_shared::common::result::Result; use std::collections::HashMap; use tokio::sync::{mpsc, oneshot}; diff --git a/proxy_agent/src/shared_state/redirector_wrapper.rs b/proxy_agent/src/shared_state/redirector_wrapper.rs index a1e9d2dd..94d178e8 100644 --- a/proxy_agent/src/shared_state/redirector_wrapper.rs +++ b/proxy_agent/src/shared_state/redirector_wrapper.rs @@ -19,10 +19,10 @@ //! let bpf_object = redirector_shared_state.get_bpf_object().await.unwrap().unwrap(); //! ``` -use crate::common::error::Error; -use crate::common::logger; -use crate::common::result::Result; use crate::redirector; +use proxy_agent_shared::common::error::Error; +use proxy_agent_shared::common::logger; +use proxy_agent_shared::common::result::Result; use std::sync::{Arc, Mutex}; use tokio::sync::{mpsc, oneshot}; diff --git a/proxy_agent/src/shared_state/telemetry_wrapper.rs b/proxy_agent/src/shared_state/telemetry_wrapper.rs index f386de37..d4ae7ee3 100644 --- a/proxy_agent/src/shared_state/telemetry_wrapper.rs +++ b/proxy_agent/src/shared_state/telemetry_wrapper.rs @@ -14,9 +14,9 @@ //! assert_eq!(meta_data, vm_meta_data); //! ``` -use crate::common::result::Result; -use crate::common::{error::Error, logger}; use crate::telemetry::event_reader::VmMetaData; +use proxy_agent_shared::common::result::Result; +use proxy_agent_shared::common::{error::Error, logger}; use tokio::sync::{mpsc, oneshot}; enum TelemetryAction { diff --git a/proxy_agent/src/telemetry/event_reader.rs b/proxy_agent/src/telemetry/event_reader.rs index 06baadc7..b57c992f 100644 --- a/proxy_agent/src/telemetry/event_reader.rs +++ b/proxy_agent/src/telemetry/event_reader.rs @@ -41,13 +41,13 @@ use super::telemetry_event::TelemetryData; use super::telemetry_event::TelemetryEvent; -use crate::common::{constants, logger, result::Result}; use crate::host_clients::imds_client::ImdsClient; use crate::host_clients::wire_server_client::WireServerClient; use crate::shared_state::agent_status_wrapper::AgentStatusModule; use crate::shared_state::agent_status_wrapper::AgentStatusSharedState; use crate::shared_state::key_keeper_wrapper::KeyKeeperSharedState; use crate::shared_state::telemetry_wrapper::TelemetrySharedState; +use proxy_agent_shared::common::{constants, logger, result::Result}; use proxy_agent_shared::misc_helpers; use proxy_agent_shared::proxy_agent_aggregate_status::ModuleState; use proxy_agent_shared::telemetry::Event; @@ -376,9 +376,9 @@ impl EventReader { #[cfg(test)] mod tests { use super::*; - use crate::common::logger; use crate::key_keeper::key::Key; use crate::test_mock::server_mock; + use proxy_agent_shared::common::logger; use proxy_agent_shared::misc_helpers; use std::{env, fs}; diff --git a/proxy_agent/src/telemetry/telemetry_event.rs b/proxy_agent/src/telemetry/telemetry_event.rs index 84adee44..573d616f 100644 --- a/proxy_agent/src/telemetry/telemetry_event.rs +++ b/proxy_agent/src/telemetry/telemetry_event.rs @@ -47,8 +47,8 @@ //! ``` use super::event_reader::VmMetaData; -use crate::common::helpers; use once_cell::sync::Lazy; +use proxy_agent_shared::common::helpers; use proxy_agent_shared::telemetry::Event; use serde_derive::{Deserialize, Serialize}; diff --git a/proxy_agent/src/test_mock/server_mock.rs b/proxy_agent/src/test_mock/server_mock.rs index 394c7449..2ce86c1b 100644 --- a/proxy_agent/src/test_mock/server_mock.rs +++ b/proxy_agent/src/test_mock/server_mock.rs @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT -use crate::common::{hyper_client, logger, result::Result}; use crate::key_keeper; use crate::key_keeper::key::{Key, KeyStatus}; use http_body_util::combinators::BoxBody; @@ -13,6 +12,7 @@ use hyper::Response; use hyper::StatusCode; use hyper_util::rt::TokioIo; use once_cell::sync::Lazy; +use proxy_agent_shared::common::{hyper_client, logger, result::Result}; use tokio::net::TcpListener; use tokio_util::sync::CancellationToken; use uuid::Uuid; diff --git a/proxy_agent_shared/Cargo.toml b/proxy_agent_shared/Cargo.toml index f3f4e353..63ce6a59 100644 --- a/proxy_agent_shared/Cargo.toml +++ b/proxy_agent_shared/Cargo.toml @@ -18,12 +18,24 @@ thiserror = "1.0.64" tokio = { version = "1", features = ["rt", "macros", "sync", "time"] } log = { version = "0.4.26", features = ["std"] } ctor = "0.3.6" # used for test setup and clean up +base64 = "0.22.1" +quick-xml = { version = "0.38.0", features = ["serialize", "serde-types"]} +uuid = { version = "1.8", features = ["v4"] } +clap = { version = "4.5.17", features =["derive"] } # Command Line Argument Parser +http = "1.1.0" +http-body-util = "0.1" +hex = "0.4.3" # hex encode +hyper = { version = "1", features = ["server", "http1", "client"] } +hyper-util = { version = "0.1", features = ["tokio"] } +hmac-sha256 = "1.1.6" # use HMAC using the SHA-256 hash function +itertools = "0.10.5" # use to sort iterator elements into a new iterator in ascending order +serde-xml-rs = "0.8.1" # xml Deserializer with xml attribute [target.'cfg(windows)'.dependencies] windows-service = "0.7.0" # windows NT service winreg = "0.11.0" # windows reg read/write -serde-xml-rs = "0.8.1" # xml Deserializer with xml attribute chrono = "0.4.41" # parse date time string +libloading = "0.8.0" # for dynamic load libraries [target.'cfg(windows)'.dependencies.windows-sys] version = "0.52.0" @@ -38,7 +50,18 @@ features = [ "Win32_System_Diagnostics_Debug", "Win32_System_SystemInformation", "Win32_Storage_FileSystem", + "Win32_Security_Cryptography", ] [target.'cfg(not(windows))'.dependencies] -os_info = "3.7.0" # read Linux OS version and arch \ No newline at end of file +os_info = "3.7.0" # read Linux OS version and arch +sysinfo = "0.30.13" # read process information for Linux + +[target.'cfg(windows)'.dependencies.windows] +version = "0.61.3" +features = [ + "Win32_Foundation", + "Win32_Security_Cryptography", + "Win32_System_SystemInformation", + "Win32_System_Memory", +] \ No newline at end of file diff --git a/proxy_agent_shared/src/certificate/certificate_helper.rs b/proxy_agent_shared/src/certificate/certificate_helper.rs new file mode 100644 index 00000000..ce25b1e0 --- /dev/null +++ b/proxy_agent_shared/src/certificate/certificate_helper.rs @@ -0,0 +1,103 @@ +#[cfg(windows)] +use crate::certificate::certificate_helper_windows::CertificateDetailsWindows; +use crate::common::formatted_error::FormattedError; + +#[cfg(windows)] +type CertDetailsType = CertificateDetailsWindows; + +#[cfg(not(windows))] +type CertDetailsType = (); + +pub struct CertificateDetailsWrapper { + pub cert_details: CertDetailsType, +} + +impl CertificateDetailsWrapper { + pub fn get_public_cert_der(&self) -> &[u8] { + #[cfg(windows)] + { + &self.cert_details.public_key_der + } + #[cfg(not(windows))] + { + todo!() + } + } +} + +pub fn generate_self_signed_certificate( + _subject_name: &str, +) -> Result { + #[cfg(windows)] + { + use crate::certificate::certificate_helper_windows::generate_self_signed_certificate_windows; + + generate_self_signed_certificate_windows(_subject_name) + } + #[cfg(not(windows))] + { + Err("Linux version is not implemented.".to_string().into()) + } +} + +pub fn decrypt_from_base64( + _base64_input: &str, + _cert_details: &CertificateDetailsWrapper, +) -> Result { + #[cfg(windows)] + { + use crate::certificate::certificate_helper_windows::decrypt_from_base64_windows; + + decrypt_from_base64_windows(_base64_input, _cert_details) + } + #[cfg(not(windows))] + { + Err("Linux version is not implemented.".to_string().into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn generate_self_signed_certificate_test() { + #[cfg(windows)] + { + let subject_name = "TestSubject"; + let cert_details_result = generate_self_signed_certificate(subject_name); + assert!(cert_details_result.is_ok()); + let cert_details = cert_details_result.unwrap(); + let public_cert_der = cert_details.get_public_cert_der(); + assert!(!public_cert_der.is_empty()); + } + #[cfg(not(windows))] + { + // On non-Windows platforms, the function is not implemented. + // This test will simply ensure that the function is called without panic. + let subject_name = "TestSubject"; + let cert_details_result = generate_self_signed_certificate(subject_name); + assert!(cert_details_result.is_err()); + } + } + + #[test] + fn decrypt_from_base64_test() { + #[cfg(windows)] + { + let subject_name = "TestSubject"; + let cert_details_result = generate_self_signed_certificate(subject_name); + assert!(cert_details_result.is_ok()); + let cert_details = cert_details_result.unwrap(); + let result = decrypt_from_base64("invalid input", &cert_details); + assert!(result.is_err()); + } + #[cfg(not(windows))] + { + let result = decrypt_from_base64( + "invalid input", + &CertificateDetailsWrapper { cert_details: () }, + ); + assert!(result.is_err()); + } + } +} diff --git a/proxy_agent_shared/src/certificate/certificate_helper_windows.rs b/proxy_agent_shared/src/certificate/certificate_helper_windows.rs new file mode 100644 index 00000000..a2d310af --- /dev/null +++ b/proxy_agent_shared/src/certificate/certificate_helper_windows.rs @@ -0,0 +1,451 @@ +use base64::Engine; +use uuid::Uuid; +use windows::{ + core::{BOOL, PCWSTR, PSTR}, + Win32::{ + Security::Cryptography::{ + szOID_KEY_USAGE, szOID_SUBJECT_KEY_IDENTIFIER, CertCreateSelfSignCertificate, + CertFreeCertificateContext, CertStrToNameW, CryptAcquireCertificatePrivateKey, + CryptEncodeObjectEx, CryptExportPublicKeyInfo, CryptHashPublicKeyInfo, CryptMsgClose, + CryptMsgControl, CryptMsgGetParam, CryptMsgOpenToDecode, CryptMsgUpdate, + NCryptCreatePersistedKey, NCryptFinalizeKey, NCryptFreeObject, + NCryptOpenStorageProvider, NCryptSetProperty, CALG_SHA1, CERT_CONTEXT, + CERT_DATA_ENCIPHERMENT_KEY_USAGE, CERT_DIGITAL_SIGNATURE_KEY_USAGE, CERT_EXTENSION, + CERT_EXTENSIONS, CERT_KEY_CERT_SIGN_KEY_USAGE, CERT_KEY_ENCIPHERMENT_KEY_USAGE, + CERT_KEY_SPEC, CERT_OFFLINE_CRL_SIGN_KEY_USAGE, CERT_PUBLIC_KEY_INFO, + CERT_X500_NAME_STR, CMSG_CTRL_DECRYPT, CMSG_CTRL_DECRYPT_PARA, + CMSG_CTRL_DECRYPT_PARA_0, CRYPT_ACQUIRE_ALLOW_NCRYPT_KEY_FLAG, + CRYPT_ACQUIRE_COMPARE_KEY_FLAG, CRYPT_ACQUIRE_FLAGS, + CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG, CRYPT_BIT_BLOB, CRYPT_INTEGER_BLOB, + HCRYPTPROV_OR_NCRYPT_KEY_HANDLE, MS_KEY_STORAGE_PROVIDER, NCRYPT_ALLOW_EXPORT_FLAG, + NCRYPT_ALLOW_PLAINTEXT_EXPORT_FLAG, NCRYPT_EXPORT_POLICY_PROPERTY, NCRYPT_HANDLE, + NCRYPT_KEY_HANDLE, NCRYPT_LENGTH_PROPERTY, NCRYPT_PROV_HANDLE, NCRYPT_RSA_ALGORITHM, + PKCS_7_ASN_ENCODING, X509_ASN_ENCODING, + }, + System::SystemInformation::GetSystemTime, + }, +}; + +use crate::{ + certificate::certificate_helper::CertificateDetailsWrapper, + common::formatted_error::FormattedError, +}; + +pub struct CertificateDetailsWindows { + pub public_key_der: Vec, + pub p_cert_ctx: *mut CERT_CONTEXT, +} + +impl Drop for CertificateDetailsWindows { + fn drop(&mut self) { + if !self.p_cert_ctx.is_null() + && !unsafe { CertFreeCertificateContext(Some(self.p_cert_ctx)) }.as_bool() + { + eprintln!("Failed to free certificate context."); + } + } +} + +pub fn generate_self_signed_certificate_windows( + subject_name: &str, +) -> Result { + // Open KSP + let mut h_prov = NCRYPT_PROV_HANDLE(0); + unsafe { + NCryptOpenStorageProvider(&mut h_prov, MS_KEY_STORAGE_PROVIDER, 0)?; + } + + // Create an RSA key + let key_name: Vec = Uuid::new_v4().to_string().encode_utf16().collect(); + let mut h_key = NCRYPT_KEY_HANDLE(0); + unsafe { + NCryptCreatePersistedKey( + h_prov, + &mut h_key, + NCRYPT_RSA_ALGORITHM, + PCWSTR(key_name.as_ptr()), // not NULL + windows::Win32::Security::Cryptography::CERT_KEY_SPEC(0), + windows::Win32::Security::Cryptography::NCRYPT_FLAGS(0), + )?; + } + + let key_length: u32 = 2048; + unsafe { + // Set key length property to 2048 bits + NCryptSetProperty( + h_key.into(), + NCRYPT_LENGTH_PROPERTY, + &key_length.to_ne_bytes(), + windows::Win32::Security::Cryptography::NCRYPT_FLAGS(0), + )?; + + // Set key export policy + NCryptSetProperty( + h_key.into(), + NCRYPT_EXPORT_POLICY_PROPERTY, + &(NCRYPT_ALLOW_EXPORT_FLAG | NCRYPT_ALLOW_PLAINTEXT_EXPORT_FLAG).to_ne_bytes(), + windows::Win32::Security::Cryptography::NCRYPT_FLAGS(0), + )?; + + // Finalize the key + NCryptFinalizeKey( + h_key, + windows::Win32::Security::Cryptography::NCRYPT_FLAGS(0), + )?; + } + + // Set up subject name for cert + let subject = format!("CN={subject_name}"); + let subject_w: Vec = subject.encode_utf16().chain(Some(0)).collect(); + let mut size = 0u32; + unsafe { + CertStrToNameW( + X509_ASN_ENCODING, + PCWSTR(subject_w.as_ptr()), + CERT_X500_NAME_STR, + Some(std::ptr::null_mut()), + Some(std::ptr::null_mut()), + &mut size, + Some(std::ptr::null_mut()), + )?; + } + let mut name_buf = vec![0u8; size as usize]; + unsafe { + CertStrToNameW( + X509_ASN_ENCODING, + PCWSTR(subject_w.as_ptr()), + CERT_X500_NAME_STR, + None, + Some(name_buf.as_mut_ptr()), + &mut size, + Some(std::ptr::null_mut()), + )?; + } + + let subject_blob = CRYPT_INTEGER_BLOB { + cbData: size, + pbData: name_buf.as_mut_ptr(), + }; + + // Validity period + let mut start = unsafe { GetSystemTime() }; + start.wYear -= 1; + let mut end = unsafe { GetSystemTime() }; + end.wYear += 3; + + let mut exts = build_cert_extensions(HCRYPTPROV_OR_NCRYPT_KEY_HANDLE(h_key.0))?; + + let cert_exts = CERT_EXTENSIONS { + cExtension: exts.len() as u32, + rgExtension: exts.as_mut_ptr(), + }; + + // Create cert + let cert_ctx = unsafe { + CertCreateSelfSignCertificate( + Some(HCRYPTPROV_OR_NCRYPT_KEY_HANDLE(h_key.0)), + &subject_blob, + windows::Win32::Security::Cryptography::CERT_CREATE_SELFSIGN_FLAGS(0), + None, + None, + Some(&start), + Some(&end), + Some(&cert_exts), + ) + }; + + // Get cert data + let cert_der = unsafe { + std::slice::from_raw_parts( + (*cert_ctx).pbCertEncoded, + (*cert_ctx).cbCertEncoded as usize, + ) + }; + + let cert_detials_windows = CertificateDetailsWindows { + public_key_der: cert_der.to_vec(), + p_cert_ctx: cert_ctx, + }; + let res = CertificateDetailsWrapper { + cert_details: cert_detials_windows, + }; + + // Cleanup + unsafe { + NCryptFreeObject(h_prov.into())?; + NCryptFreeObject(h_key.into())?; + }; + Ok(res) +} + +pub fn decrypt_from_base64_windows( + base64_input: &str, + cert_details_wrapper: &CertificateDetailsWrapper, +) -> Result { + let encrypted = base64_input.replace("\r", "").replace("\n", ""); + let encrypted_payload = &base64::engine::general_purpose::STANDARD.decode(encrypted)?; + let p_cert_ctx = cert_details_wrapper.cert_details.p_cert_ctx; + + // Acquire the private key handle using the CNG-compatible function. + let mut h_key = HCRYPTPROV_OR_NCRYPT_KEY_HANDLE(0); + let mut key_spec = CERT_KEY_SPEC(0u32); + let mut must_free = BOOL(0); + unsafe { + CryptAcquireCertificatePrivateKey( + p_cert_ctx, + CRYPT_ACQUIRE_FLAGS( + CRYPT_ACQUIRE_COMPARE_KEY_FLAG.0 + | CRYPT_ACQUIRE_ALLOW_NCRYPT_KEY_FLAG.0 + | CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG.0, + ), + None, + &mut h_key, + Some(&mut key_spec), + Some(&mut must_free), + ) + }?; + + // Decode the encrypted message. + let msg_handle = unsafe { + CryptMsgOpenToDecode( + X509_ASN_ENCODING.0 | PKCS_7_ASN_ENCODING.0, + 0, + 0, + None, + None, + None, + ) + }; + + if msg_handle.is_null() { + return Err(FormattedError { + code: -1, + message: "Failed to open message handle to decrypt.".to_string(), + }); + } + unsafe { CryptMsgUpdate(msg_handle, Some(encrypted_payload), true) }?; + // Create an instance of the nested struct (the union) + let anonymous_union = CMSG_CTRL_DECRYPT_PARA_0 { + hCryptProv: h_key.0, + }; + + // Create the main struct instance + let mut decrypt_para = CMSG_CTRL_DECRYPT_PARA { + cbSize: std::mem::size_of::() as u32, + Anonymous: anonymous_union, + dwKeySpec: key_spec.0, + dwRecipientIndex: 0, + }; + + unsafe { + CryptMsgControl( + msg_handle, + 0, + CMSG_CTRL_DECRYPT, + Some(&mut decrypt_para as *mut _ as *mut _), + ) + }?; + + // Get the decrypted message size. + let mut content_size = 0; + unsafe { CryptMsgGetParam(msg_handle, 2, 0, None, &mut content_size) }?; + + // Get the decrypted message content. + let mut decrypted_data_buffer = vec![0u8; content_size as usize]; + unsafe { + CryptMsgGetParam( + msg_handle, + 2, + 0, + Some(decrypted_data_buffer.as_mut_ptr() as *mut _), + &mut content_size, + ) + }?; + + unsafe { CryptMsgClose(Some(msg_handle)) }?; + if must_free.as_bool() { + unsafe { NCryptFreeObject(NCRYPT_HANDLE(h_key.0)) }?; + } + + let res = String::from_utf8(decrypted_data_buffer)?; + Ok(res) +} + +fn build_cert_extensions( + h_key: HCRYPTPROV_OR_NCRYPT_KEY_HANDLE, +) -> Result, FormattedError> { + let mut extensions: Vec = Vec::new(); + + // Key Usage + let key_usage: u8 = CERT_DIGITAL_SIGNATURE_KEY_USAGE as u8 + | CERT_KEY_CERT_SIGN_KEY_USAGE as u8 + | CERT_OFFLINE_CRL_SIGN_KEY_USAGE as u8 + | CERT_KEY_ENCIPHERMENT_KEY_USAGE as u8 + | CERT_DATA_ENCIPHERMENT_KEY_USAGE as u8; + + let key_usage_blob = CRYPT_BIT_BLOB { + cbData: 1, + pbData: &key_usage as *const u8 as *mut u8, + cUnusedBits: 0, + }; + + let mut encoded_key_usage_len: u32 = 0; + unsafe { + CryptEncodeObjectEx( + X509_ASN_ENCODING, + szOID_KEY_USAGE, + &key_usage_blob as *const _ as *const _, + windows::Win32::Security::Cryptography::CRYPT_ENCODE_OBJECT_FLAGS(0), + None, + None, + &mut encoded_key_usage_len, + )?; + } + + let mut encoded_key_usage = vec![0u8; encoded_key_usage_len as usize]; + + unsafe { + CryptEncodeObjectEx( + X509_ASN_ENCODING, + szOID_KEY_USAGE, + &key_usage_blob as *const _ as *const _, + windows::Win32::Security::Cryptography::CRYPT_ENCODE_OBJECT_FLAGS(0), + None, + Some(encoded_key_usage.as_mut_ptr() as *mut _), + &mut encoded_key_usage_len, + )?; + } + + extensions.push(CERT_EXTENSION { + pszObjId: PSTR(szOID_KEY_USAGE.as_ptr() as *mut _), + fCritical: BOOL(0), // FALSE + Value: CRYPT_INTEGER_BLOB { + cbData: encoded_key_usage_len, + pbData: encoded_key_usage.as_mut_ptr(), + }, + }); + + let mut size = 0; + unsafe { + CryptExportPublicKeyInfo(h_key, Some(0), X509_ASN_ENCODING, None, &mut size)?; + } + let mut buffer = vec![0u8; size as usize]; + + let p_info = buffer.as_mut_ptr() as *mut CERT_PUBLIC_KEY_INFO; + + // Subject Key Identifier (let Windows generate it) + unsafe { + CryptExportPublicKeyInfo(h_key, Some(0), X509_ASN_ENCODING, Some(p_info), &mut size) + }?; + + let mut ski_hash = [0u8; 20]; + let mut ski_size = ski_hash.len() as u32; + unsafe { + CryptHashPublicKeyInfo( + None, + CALG_SHA1, + 0, + X509_ASN_ENCODING, + p_info, + Some(ski_hash.as_mut_ptr()), + &mut ski_size, + ) + }?; + + let ski_blob = CRYPT_INTEGER_BLOB { + cbData: ski_size, + pbData: ski_hash.as_mut_ptr(), + }; + + let mut encoded_ski_size = 0; + unsafe { + CryptEncodeObjectEx( + X509_ASN_ENCODING, + szOID_SUBJECT_KEY_IDENTIFIER, + &ski_blob as *const _ as *const _, + windows::Win32::Security::Cryptography::CRYPT_ENCODE_OBJECT_FLAGS(0), + None, + None, + &mut encoded_ski_size, + ) + }?; + + let mut encoded_ski = vec![0u8; encoded_ski_size as usize]; + + unsafe { + CryptEncodeObjectEx( + X509_ASN_ENCODING, + szOID_SUBJECT_KEY_IDENTIFIER, + &ski_blob as *const _ as *const _, + windows::Win32::Security::Cryptography::CRYPT_ENCODE_OBJECT_FLAGS(0), + None, + Some(encoded_ski.as_mut_ptr() as *mut _), + &mut encoded_ski_size, + ) + }?; + + extensions.push(CERT_EXTENSION { + pszObjId: PSTR(szOID_SUBJECT_KEY_IDENTIFIER.as_ptr() as *mut _), + fCritical: BOOL(0), // FALSE + Value: CRYPT_INTEGER_BLOB { + cbData: encoded_ski_size, + pbData: encoded_ski.as_mut_ptr(), + }, + }); + Ok(extensions) +} + +#[cfg(test)] +mod tests { + use super::*; + use windows::{ + core::PSTR, + Win32::Security::Cryptography::{ + szOID_NIST_AES256_CBC, CryptEncryptMessage, CRYPT_ENCRYPT_MESSAGE_PARA, + }, + }; + + #[test] + fn test_certificate_decryption() { + let cert = generate_self_signed_certificate_windows(&Uuid::new_v4().to_string()).unwrap(); + + let org_str = "Hello, World!"; + let encrypted = encrypt(&cert.cert_details, org_str); + + let decrypted = decrypt_from_base64_windows(&encrypted, &cert).unwrap(); + + assert!(decrypted.eq(org_str)) + } + + pub fn encrypt(cert: &CertificateDetailsWindows, org_str: &str) -> String { + let mut info = CRYPT_ENCRYPT_MESSAGE_PARA::default(); + info.cbSize = std::mem::size_of::() as u32; + info.dwMsgEncodingType = X509_ASN_ENCODING.0 | PKCS_7_ASN_ENCODING.0; + info.ContentEncryptionAlgorithm.pszObjId = PSTR(szOID_NIST_AES256_CBC.as_ptr() as *mut _); + info.dwFlags = 0; + let cert_ctx_ptrs = [cert.p_cert_ctx as *const _]; + let mut encrypted_size: u32 = 0; + unsafe { + CryptEncryptMessage( + &info, + &cert_ctx_ptrs, + Some(org_str.as_bytes()), + None, + &mut encrypted_size as *mut u32, + ) + .unwrap(); + } + let mut encrypted_data = vec![0u8; encrypted_size as usize]; + unsafe { + CryptEncryptMessage( + &info, + &cert_ctx_ptrs, + Some(org_str.as_bytes()), + Some(encrypted_data.as_mut_ptr()), + &mut encrypted_size as *mut u32, + ) + .unwrap(); + } + return base64::engine::general_purpose::STANDARD.encode(&encrypted_data); + } +} diff --git a/proxy_agent_shared/src/certificate/mod.rs b/proxy_agent_shared/src/certificate/mod.rs new file mode 100644 index 00000000..9ac1744b --- /dev/null +++ b/proxy_agent_shared/src/certificate/mod.rs @@ -0,0 +1,3 @@ +pub mod certificate_helper; +#[cfg(windows)] +pub mod certificate_helper_windows; diff --git a/proxy_agent/src/common.rs b/proxy_agent_shared/src/common.rs similarity index 93% rename from proxy_agent/src/common.rs rename to proxy_agent_shared/src/common.rs index 3afa2dac..37e2ebdd 100644 --- a/proxy_agent/src/common.rs +++ b/proxy_agent_shared/src/common.rs @@ -4,6 +4,7 @@ pub mod cli; pub mod config; pub mod constants; pub mod error; +pub mod formatted_error; pub mod helpers; pub mod hyper_client; pub mod logger; diff --git a/proxy_agent_shared/src/common/cli.rs b/proxy_agent_shared/src/common/cli.rs new file mode 100644 index 00000000..aba9239f --- /dev/null +++ b/proxy_agent_shared/src/common/cli.rs @@ -0,0 +1,119 @@ +// Copyright (c) Microsoft Corporation +// SPDX-License-Identifier: MIT + +use clap::{Parser, Subcommand}; +use once_cell::sync::Lazy; + +/// azure-proxy-agent console - launch a long run process of GPA in console mode. +/// azure-proxy-agent --version - print the version of the GPA. +/// azure-proxy-agent --status [--wait ] - get the provision status of the GPA service. +/// azure-proxy-agent - start the GPA as an OS service. +/// The GPA service will be started as an OS service in the background. +#[derive(Parser)] +#[command()] +pub struct Cli { + /// get the provision status of the GPA service + #[arg(short, long)] + pub status: bool, + + /// wait for the provision status to finish + #[arg(short, long, requires = "status")] + pub wait: Option, + + /// print the version of the GPA + #[arg(short, long)] + pub version: bool, + + #[cfg(test)] + #[arg(short, long)] + test_threads: Option, + + #[cfg(test)] + #[arg(short, long)] + nocapture: bool, + + #[command(subcommand)] + pub command: Option, +} + +#[derive(Subcommand)] +pub enum Commands { + /// launch a long run process of GPA in console mode + Console, +} + +impl Cli { + pub fn is_console_mode(&self) -> bool { + self.command.is_some() + } +} + +pub static CLI: Lazy = Lazy::new(Cli::parse); + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_version_flag() { + let cli = Cli::try_parse_from(["command", "--version"]).unwrap(); + assert!(cli.version); + assert!(!cli.status); + assert!(cli.wait.is_none()); + assert!(cli.command.is_none()); + } + + #[test] + fn test_parse_status_with_wait() { + let cli = Cli::try_parse_from(["command", "--status", "--wait", "10"]).unwrap(); + assert!(cli.status); + assert_eq!(cli.wait, Some(10)); + assert!(!cli.version); + assert!(cli.command.is_none()); + } + + #[test] + fn test_parse_console_command() { + let cli = Cli::try_parse_from(["command", "console"]).unwrap(); + assert!(cli.is_console_mode()); + match cli.command { + Some(Commands::Console) => {} + _ => panic!("Expected Commands::Console"), + } + } + + #[test] + fn test_parse_status_without_wait() { + let cli = Cli::try_parse_from(["command", "--status"]).unwrap(); + assert!(cli.status); + assert!(cli.wait.is_none()); + assert!(!cli.version); + } + + #[test] + fn test_conflicting_wait_without_status_fails() { + let result = Cli::try_parse_from(["command", "--wait", "5"]); + assert!(result.is_err(), "wait without --status should fail"); + } + + #[test] + fn test_no_args_defaults() { + let cli = Cli::try_parse_from(["command"]).unwrap(); + assert!(!cli.status); + assert!(!cli.version); + assert!(cli.wait.is_none()); + assert!(cli.command.is_none()); + assert!(!cli.is_console_mode()); + } + + #[cfg(test)] + #[test] + fn test_test_only_flags() { + let cli = + Cli::try_parse_from(["command", "--status", "--test-threads", "4", "--nocapture"]) + .unwrap(); + assert!(cli.status); + assert_eq!(cli.test_threads, Some(4)); + assert!(cli.nocapture); + } +} diff --git a/proxy_agent/src/common/config.rs b/proxy_agent_shared/src/common/config.rs similarity index 70% rename from proxy_agent/src/common/config.rs rename to proxy_agent_shared/src/common/config.rs index d9d1fdab..ff8dd16d 100644 --- a/proxy_agent/src/common/config.rs +++ b/proxy_agent_shared/src/common/config.rs @@ -6,7 +6,7 @@ //! //! Example //! ```rust -//! use proxy_agent::config; +//! use crate::common::config; //! //! // Get the logs directory //! let logs_dir = config::get_logs_dir(); @@ -17,8 +17,8 @@ //! ``` use crate::common::constants; +use crate::{logger::LoggerLevel, misc_helpers}; use once_cell::sync::Lazy; -use proxy_agent_shared::{logger::LoggerLevel, misc_helpers}; use serde_derive::{Deserialize, Serialize}; use std::str::FromStr; use std::{path::PathBuf, time::Duration}; @@ -221,7 +221,8 @@ impl Config { mod tests { use crate::common::config::Config; use crate::common::constants; - use proxy_agent_shared::misc_helpers; + use crate::logger::LoggerLevel; + use crate::misc_helpers; use std::fs::File; use std::io::Write; use std::path::PathBuf; @@ -297,13 +298,13 @@ mod tests { } assert_eq!( - proxy_agent_shared::logger::LoggerLevel::Info, + crate::logger::LoggerLevel::Info, config.get_file_log_level_for_events().unwrap(), "get_file_log_level_for_events mismatch" ); assert_eq!( - proxy_agent_shared::logger::LoggerLevel::Info, + crate::logger::LoggerLevel::Info, config.get_file_log_level_for_system_events().unwrap(), "get_file_log_level_for_system_events mismatch" ); @@ -354,4 +355,121 @@ mod tests { .unwrap(); Config::from_json_file(file_path) } + + #[test] + fn test_optional_fields_some_values() { + let mut temp_path = env::temp_dir(); + temp_path.push("config_optional_some"); + _ = fs::remove_dir_all(&temp_path); + fs::create_dir_all(&temp_path).unwrap(); + let config_path = temp_path.join("config.json"); + + let data = r#"{ + "logFolder": "C:\\logFolderName", + "eventFolder": "C:\\eventFolderName", + "latchKeyFolder": "C:\\latchKeyFolderName", + "monitorIntervalInSeconds": 120, + "pollKeyStatusIntervalInSeconds": 30, + "hostGAPluginSupport": 2, + "maxEventFileCount": 42, + "ebpfFileFullPath": "C:\\ebpf.o", + "ebpfProgramName": "ebpfCustom", + "fileLogLevel": "Error", + "fileLogLevelForEvents": "Warn", + "fileLogLevelForSystemEvents": "Debug", + "enableHttpProxyTrace": true + }"#; + + let config = create_custom_config_file(config_path, data); + + assert_eq!(config.get_max_event_file_count(), 42); + assert_eq!( + config.get_ebpf_file_full_path(), + Some(PathBuf::from("C:\\ebpf.o")) + ); + assert_eq!(config.get_ebpf_program_name(), "ebpfCustom"); + assert_eq!(config.get_file_log_level(), LoggerLevel::Error); + assert_eq!( + config.get_file_log_level_for_events().unwrap(), + LoggerLevel::Warn + ); + assert_eq!( + config.get_file_log_level_for_system_events().unwrap(), + LoggerLevel::Debug + ); + assert_eq!(config.enableHttpProxyTrace.unwrap(), true); + + _ = fs::remove_dir_all(&temp_path); + } + + #[test] + fn test_optional_fields_none_values() { + let mut temp_path = env::temp_dir(); + temp_path.push("config_optional_none"); + _ = fs::remove_dir_all(&temp_path); + fs::create_dir_all(&temp_path).unwrap(); + let config_path = temp_path.join("config.json"); + + let data = r#"{ + "logFolder": "C:\\logFolderName", + "eventFolder": "C:\\eventFolderName", + "latchKeyFolder": "C:\\latchKeyFolderName", + "monitorIntervalInSeconds": 60, + "pollKeyStatusIntervalInSeconds": 15, + "hostGAPluginSupport": 1, + "ebpfProgramName": "ebpfProgramName" + }"#; + + let config = create_custom_config_file(config_path, data); + + // optional fields fallback + assert_eq!( + config.get_max_event_file_count(), + constants::DEFAULT_MAX_EVENT_FILE_COUNT + ); + assert_eq!(config.get_ebpf_file_full_path(), None); + assert_eq!(config.get_file_log_level(), LoggerLevel::Info); + assert_eq!(config.get_file_log_level_for_events(), None); + assert_eq!(config.get_file_log_level_for_system_events(), None); + + _ = fs::remove_dir_all(&temp_path); + } + + #[test] + fn test_invalid_log_level_fallback() { + let config = Config { + logFolder: "log".to_string(), + eventFolder: "event".to_string(), + latchKeyFolder: "latch".to_string(), + monitorIntervalInSeconds: 0, + pollKeyStatusIntervalInSeconds: 0, + hostGAPluginSupport: 0, + maxEventFileCount: None, + ebpfFileFullPath: None, + ebpfProgramName: "ebpf".to_string(), + fileLogLevel: Some("InvalidLevel".to_string()), + fileLogLevelForEvents: Some("InvalidLevel".to_string()), + fileLogLevelForSystemEvents: Some("InvalidLevel".to_string()), + enableHttpProxyTrace: None, + #[cfg(not(windows))] + cgroupRoot: None, + }; + + assert_eq!(config.get_file_log_level(), LoggerLevel::Info); + assert_eq!( + config.get_file_log_level_for_events().unwrap(), + LoggerLevel::Info + ); + assert_eq!( + config.get_file_log_level_for_system_events().unwrap(), + LoggerLevel::Info + ); + } + + fn create_custom_config_file(file_path: PathBuf, content: &str) -> Config { + fs::create_dir_all(file_path.parent().unwrap()).unwrap(); + let mut file = fs::File::create(&file_path).unwrap(); + file.write_all(content.as_bytes()).unwrap(); + Config::from_json_file(file_path) + } } diff --git a/proxy_agent/src/common/constants.rs b/proxy_agent_shared/src/common/constants.rs similarity index 100% rename from proxy_agent/src/common/constants.rs rename to proxy_agent_shared/src/common/constants.rs diff --git a/proxy_agent/src/common/error.rs b/proxy_agent_shared/src/common/error.rs similarity index 96% rename from proxy_agent/src/common/error.rs rename to proxy_agent_shared/src/common/error.rs index 4dee7fb0..75c87c7c 100644 --- a/proxy_agent/src/common/error.rs +++ b/proxy_agent_shared/src/common/error.rs @@ -3,6 +3,8 @@ use http::{uri::InvalidUri, StatusCode}; +use crate::common::formatted_error::FormattedError; + #[derive(Debug, thiserror::Error)] pub enum Error { #[error("IO error: {0}: {1}")] @@ -48,6 +50,15 @@ pub enum Error { #[error("{0}")] FindAuditEntryError(String), + + #[error("{0}")] + OtherError(FormattedError), +} + +impl From for Error { + fn from(value: FormattedError) -> Self { + Error::OtherError(value) + } } #[derive(Debug, thiserror::Error)] diff --git a/proxy_agent_shared/src/common/formatted_error.rs b/proxy_agent_shared/src/common/formatted_error.rs new file mode 100644 index 00000000..ec09ed47 --- /dev/null +++ b/proxy_agent_shared/src/common/formatted_error.rs @@ -0,0 +1,126 @@ +use std::{fmt, string::FromUtf8Error}; + +use base64::DecodeError; +use tokio::time::error::Elapsed; + +#[derive(Debug, Clone)] +pub struct FormattedError { + pub message: String, + pub code: i32, +} + +impl fmt::Display for FormattedError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "code: {}, message: {}", self.code, self.message) + } +} + +impl std::error::Error for FormattedError {} + +impl From for FormattedError { + fn from(value: DecodeError) -> Self { + FormattedError { + message: format!("Decode Error: {value:?}"), + code: -1, + } + } +} + +impl From for FormattedError { + fn from(value: FromUtf8Error) -> Self { + FormattedError { + message: format!("Utf-8 Convert Error: {value:?}"), + code: -1, + } + } +} + +impl From for FormattedError { + fn from(value: serde_json::Error) -> Self { + FormattedError { + message: format!("Json Error: {value:?}"), + code: -1, + } + } +} + +impl From for FormattedError { + fn from(value: String) -> Self { + FormattedError { + message: format!("GeneralError: {value}"), + code: -1, + } + } +} + +impl From for FormattedError { + fn from(value: Elapsed) -> Self { + FormattedError { + message: format!("Operation timeout: {value}"), + code: -1, + } + } +} + +#[cfg(windows)] +impl From for FormattedError { + fn from(value: windows::core::Error) -> Self { + FormattedError { + message: format!("Windows API Error: {value:?}"), + code: -1, + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use tokio::time::timeout; + + use super::*; + + #[test] + fn formatted_error_display_test() { + let error = FormattedError { + message: "An error occurred".to_string(), + code: 404, + }; + assert_eq!( + format!("{}", error), + "code: 404, message: An error occurred" + ); + } + + #[tokio::test] + async fn formatted_error_from_test() { + let decode_error = DecodeError::InvalidLength(0); + let formatted_error: FormattedError = decode_error.into(); + assert_eq!(formatted_error.message, "Decode Error: InvalidLength(0)"); + + let utf8_bytes = vec![0, 159, 146, 150]; + let utf8_error = String::from_utf8(utf8_bytes).unwrap_err(); + let formatted_error: FormattedError = utf8_error.into(); + assert!(formatted_error.message.starts_with("Utf-8 Convert Error:")); + + let json_error = serde_json::from_str::("invalid json").unwrap_err(); + let formatted_error: FormattedError = json_error.into(); + assert!(formatted_error.message.starts_with("Json Error:")); + + let elapsed_error = timeout(Duration::from_millis(10), async { + tokio::time::sleep(Duration::from_secs(1)).await; + }) + .await; + + let elapsed_error = elapsed_error.unwrap_err(); + let formatted_error: FormattedError = elapsed_error.into(); + assert!(formatted_error.message.starts_with("Operation timeout")); + + #[cfg(windows)] + { + let windows_error = windows::core::Error::from_win32(); + let formatted_error: FormattedError = windows_error.into(); + assert!(formatted_error.message.starts_with("Windows API Error:")); + } + } +} diff --git a/proxy_agent/src/common/helpers.rs b/proxy_agent_shared/src/common/helpers.rs similarity index 97% rename from proxy_agent/src/common/helpers.rs rename to proxy_agent_shared/src/common/helpers.rs index 112c957a..225505e8 100644 --- a/proxy_agent/src/common/helpers.rs +++ b/proxy_agent_shared/src/common/helpers.rs @@ -2,9 +2,9 @@ // SPDX-License-Identifier: MIT use super::result::Result; use super::{error::Error, logger}; +use crate::misc_helpers; +use crate::telemetry::span::SimpleSpan; use once_cell::sync::Lazy; -use proxy_agent_shared::misc_helpers; -use proxy_agent_shared::telemetry::span::SimpleSpan; #[cfg(not(windows))] use sysinfo::{CpuRefreshKind, MemoryRefreshKind, RefreshKind, System}; diff --git a/proxy_agent/src/common/hyper_client.rs b/proxy_agent_shared/src/common/hyper_client.rs similarity index 86% rename from proxy_agent/src/common/hyper_client.rs rename to proxy_agent_shared/src/common/hyper_client.rs index bc52c375..a89116ad 100644 --- a/proxy_agent/src/common/hyper_client.rs +++ b/proxy_agent_shared/src/common/hyper_client.rs @@ -5,11 +5,12 @@ //! //! Example //! ```rust -//! use proxy_agent::hyper_client; -//! use host_clients::goal_state::GoalState; +//! use crate::common::hyper_client; +//! use crate::host_clients::data_model::wire_server_model::GoalState; //! use std::collections::HashMap; //! use hyper::Uri; //! use std::str::FromStr; +//! use http::Method; //! //! let mut headers = HashMap::new(); //! headers.insert("x-ms-version".to_string(), "2012-11-30".to_string()); @@ -37,6 +38,7 @@ use super::error::{Error, HyperErrorType}; use super::result::Result; use super::{constants, helpers}; +use crate::misc_helpers; use http::request::Builder; use http::request::Parts; use http::Method; @@ -49,7 +51,6 @@ use hyper::Request; use hyper::Uri; use hyper_util::rt::TokioIo; use itertools::Itertools; -use proxy_agent_shared::misc_helpers; use serde::de::DeserializeOwned; use std::collections::HashMap; use tokio::net::TcpStream; @@ -82,9 +83,7 @@ where read_response_body(response).await } -pub async fn read_response_body( - mut response: hyper::Response, -) -> Result +pub async fn read_response_body(response: hyper::Response) -> Result where T: DeserializeOwned, { @@ -123,6 +122,37 @@ where ("unknown", "unknown") }; + let body_string = read_response_body_as_string(response, charset_type).await?; + + match content_type { + "xml" => match serde_xml_rs::from_str(&body_string) { + Ok(t) => Ok(t), + Err(e) => Err(Error::Hyper( + HyperErrorType::Deserialize( + format!( + "Failed to xml deserialize response body with content_type {content_type} from: {body_string} with error {e}" + ) + ), + )), + }, + // default to json + _ => match serde_json::from_str(&body_string) { + Ok(t) => Ok(t), + Err(e) => Err(Error::Hyper( + HyperErrorType::Deserialize( + format!( + "Failed to json deserialize response body with {content_type} from: {body_string} with error {e}" + ) + ), + )), + }, + } +} + +pub async fn read_response_body_as_string( + mut response: hyper::Response, + charset_type: &str, +) -> Result { let mut body_string = String::new(); while let Some(next) = response.frame().await { let frame = match next { @@ -159,30 +189,7 @@ where }; } } - - match content_type { - "xml" => match serde_xml_rs::from_str(&body_string) { - Ok(t) => Ok(t), - Err(e) => Err(Error::Hyper( - HyperErrorType::Deserialize( - format!( - "Failed to xml deserialize response body with content_type {content_type} from: {body_string} with error {e}" - ) - ), - )), - }, - // default to json - _ => match serde_json::from_str(&body_string) { - Ok(t) => Ok(t), - Err(e) => Err(Error::Hyper( - HyperErrorType::Deserialize( - format!( - "Failed to json deserialize response body with {content_type} from: {body_string} with error {e}" - ) - ), - )), - }, - } + Ok(body_string) } pub fn build_request( @@ -501,6 +508,10 @@ pub fn should_skip_sig(method: &hyper::Method, relative_uri: &Uri) -> bool { #[cfg(test)] mod tests { + use http::{HeaderMap, HeaderValue, Method}; + + use crate::common::constants; + #[test] fn get_path_and_canonicalized_parameters_test() { let url_str = "/machine/a8016240-7286-49ef-8981-63520cb8f6d0/49c242ba%2Dc18a%2D4f6c%2D8cf8%2D85ff790b6431.%5Fzpeng%2Debpf%2Dvm2?comp=config&keyOnly&comp=again&type=hostingEnvironmentConfig&incarnation=1&resource=https%3a%2f%2fstorage.azure.com%2f"; @@ -513,4 +524,68 @@ mod tests { "query parameters mismatch" ); } + + #[test] + fn query_pairs_basic() { + let uri = "/test?name=test&key=value&empty=".parse().unwrap(); + let pairs = super::query_pairs(&uri); + assert_eq!( + pairs, + vec![ + ("name".to_string(), "test".to_string()), + ("key".to_string(), "value".to_string()), + ("empty".to_string(), "".to_string()), + ] + ); + } + + #[test] + fn query_pairs_ignore_empty_key() { + let uri = "/test?=value&valid=1".parse().unwrap(); + let pairs = super::query_pairs(&uri); + assert_eq!(pairs, vec![("valid".to_string(), "1".to_string())]); + } + + #[test] + fn headers_to_canonicalized_string_sorts_and_skips_auth() { + let mut headers = HeaderMap::new(); + headers.insert("Test-Header", HeaderValue::from_static("test")); + headers.insert( + constants::AUTHORIZATION_HEADER, + HeaderValue::from_static("should-skip"), + ); + + let result = super::headers_to_canonicalized_string(&headers); + // "a-header" should come first, auth header skipped + assert!(result.starts_with("test-header:test\n")); + assert!(!result.contains("should-skip")); + } + + #[test] + fn host_port_from_uri_defaults_port() { + let uri = "http://example.com/test".parse().unwrap(); + let (host, port) = super::host_port_from_uri(&uri).unwrap(); + assert_eq!(host, "example.com"); + assert_eq!(port, 80); + } + + #[test] + fn host_port_from_uri_with_port() { + let uri = "http://example.com:8080/test".parse().unwrap(); + let (host, port) = super::host_port_from_uri(&uri).unwrap(); + assert_eq!(host, "example.com"); + assert_eq!(port, 8080); + } + + #[test] + fn should_skip_sig_matches() { + let put_uri = "/vmAgentLog".parse().unwrap(); + assert!(super::should_skip_sig(&Method::PUT, &put_uri)); + + let post_uri = "/machine/?comp=telemetrydata".parse().unwrap(); + assert!(super::should_skip_sig(&Method::POST, &post_uri)); + + let other_uri = "/machine/?comp=goalstate".parse().unwrap(); + assert!(!super::should_skip_sig(&Method::GET, &other_uri)); + } } diff --git a/proxy_agent/src/common/logger.rs b/proxy_agent_shared/src/common/logger.rs similarity index 96% rename from proxy_agent/src/common/logger.rs rename to proxy_agent_shared/src/common/logger.rs index 437a3f69..27fe6f80 100644 --- a/proxy_agent/src/common/logger.rs +++ b/proxy_agent_shared/src/common/logger.rs @@ -1,6 +1,6 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT -use proxy_agent_shared::{ +use crate::{ logger::{logger_manager, LoggerLevel}, telemetry::event_logger, }; @@ -43,7 +43,7 @@ fn log(log_level: LoggerLevel, message: String) { #[cfg(not(windows))] pub fn write_serial_console_log(message: String) { - use proxy_agent_shared::misc_helpers; + use crate::misc_helpers; use std::io::Write; let message = format!( diff --git a/proxy_agent/src/common/result.rs b/proxy_agent_shared/src/common/result.rs similarity index 100% rename from proxy_agent/src/common/result.rs rename to proxy_agent_shared/src/common/result.rs diff --git a/proxy_agent/src/common/windows.rs b/proxy_agent_shared/src/common/windows.rs similarity index 99% rename from proxy_agent/src/common/windows.rs rename to proxy_agent_shared/src/common/windows.rs index 5513695e..aa433ce1 100644 --- a/proxy_agent/src/common/windows.rs +++ b/proxy_agent_shared/src/common/windows.rs @@ -139,7 +139,7 @@ pub fn fetch_key_data(encrypted_file_path: &Path) -> Result { mod tests { use std::{env, fs}; - use proxy_agent_shared::misc_helpers; + use crate::misc_helpers; #[test] fn get_processor_count_test() { diff --git a/proxy_agent_shared/src/host_clients/data_model/hostga_plugin_model.rs b/proxy_agent_shared/src/host_clients/data_model/hostga_plugin_model.rs new file mode 100644 index 00000000..7cb38cd9 --- /dev/null +++ b/proxy_agent_shared/src/host_clients/data_model/hostga_plugin_model.rs @@ -0,0 +1,251 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct VMSettings { + #[serde(rename = "hostGAPluginVersion")] + pub host_ga_plugin_version: Option, + #[serde(rename = "activityId")] + pub activity_id: Option, + #[serde(rename = "correlationId")] + pub correlation_id: Option, + #[serde(rename = "inSvdSeqNo")] + pub in_svd_seq_no: Option, + #[serde(rename = "certificatesRevision")] + pub certificates_revision: Option, + #[serde(rename = "extensionsLastModifiedTickCount")] + pub extensions_last_modified_tick_count: Option, + #[serde(rename = "extensionGoalStatesSource")] + pub extension_goal_states_source: Option, + #[serde(rename = "statusUploadBlob")] + pub status_upload_blob: Option, + #[serde(rename = "gaFamilies")] + pub ga_families: Option>, + #[serde(rename = "extensionGoalStates")] + pub extension_goal_states: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StatusUploadBlob { + #[serde(rename = "statusBlobType")] + pub status_blob_type: Option, + pub value: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GaFamily { + pub name: Option, + pub version: Option, + #[serde(rename = "isVersionFromRSM")] + pub is_version_from_rsm: Option, + #[serde(rename = "isVMEnabledForRSMUpgrades")] + pub is_vm_enabled_for_rsm_upgrades: Option, + pub uris: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ExtensionGoalState { + pub name: Option, + pub version: Option, + pub location: Option, + #[serde(rename = "failoverLocation")] + pub failover_location: Option, + #[serde(rename = "additionalLocations")] + pub additional_locations: Option>, + pub state: Option, + #[serde(rename = "autoUpgrade")] + pub auto_upgrade: Option, + #[serde(rename = "runAsStartupTask")] + pub run_as_startup_task: Option, + #[serde(rename = "isJson")] + pub is_json: Option, + #[serde(rename = "useExactVersion")] + pub use_exact_version: Option, + #[serde(rename = "settingsSeqNo")] + pub settings_seq_no: Option, + #[serde(rename = "isMultiConfig")] + pub is_multi_config: Option, + pub settings: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Settings { + #[serde(rename = "protectedSettingsCertThumbprint")] + pub protected_settings_cert_thumbprint: Option, + #[serde(rename = "protectedSettings")] + pub protected_settings: Option, + #[serde(rename = "publicSettings")] + pub public_settings: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RawCertificatesPayload { + #[serde(rename = "Pkcs7BlobWithPfxContents")] + pub pkcs7_blob_with_pfx_contents: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Certificates { + #[serde(rename = "activityId")] + pub activity_id: Option, + + #[serde(rename = "correlationId")] + pub correlation_id: Option, + #[serde(rename = "certificates")] + pub certificates: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Certificate { + #[serde(rename = "name")] + pub name: Option, + + #[serde(rename = "storeName")] + pub store_name: Option, + + #[serde(rename = "configurationLevel")] + pub configuration_level: Option, + + #[serde(rename = "certificateInBase64")] + pub certificate_in_base64: Option, + + #[serde(rename = "includePrivateKey")] + pub include_private_key: Option, + + #[serde(rename = "thumbprint")] + pub thumbprint: Option, + + #[serde(rename = "certificateBlobFormatType")] + pub certificate_blob_format_type: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn certificates_deserialization_test() { + let certificates_json = r#" + { + "activityId": "11111111-1111111-11111111111-111111", + "correlationId": "80e22e3b-3f9a-424e-b300-6cda2dd7e718", + "certificates": [ + { + "name": "TenantEncryptionCert", + "storeName": "My", + "configurationLevel": "System", + "certificateInBase64": "certificateInBase64_test", + "includePrivateKey": false, + "thumbprint": "thumbprint_test", + "certificateBlobFormatType": "PfxInClear" + } + ] + } + "#; + let certificates: Certificates = serde_json::from_str(certificates_json).unwrap(); + assert_eq!( + certificates.activity_id.unwrap(), + "11111111-1111111-11111111111-111111" + ); + assert_eq!( + certificates.correlation_id.unwrap(), + "80e22e3b-3f9a-424e-b300-6cda2dd7e718" + ); + assert_eq!(certificates.certificates.as_ref().unwrap().len(), 1); + assert_eq!( + certificates.certificates.as_ref().unwrap()[0] + .certificate_in_base64 + .as_ref() + .unwrap(), + "certificateInBase64_test" + ); + assert_eq!( + certificates.certificates.as_ref().unwrap()[0] + .thumbprint + .as_ref() + .unwrap(), + "thumbprint_test" + ); + } + + #[test] + fn vmsettings_deserialization_test() { + let vmsettings_json = r#" + { + "hostGAPluginVersion": "1.0.8.179", + "activityId": "1111-11111111-1111-11-1111", + "correlationId": "000000-00000000-000000-0000", + "inSvdSeqNo": 1, + "certificatesRevision": 0, + "extensionsLastModifiedTickCount": 638931417044754873, + "extensionGoalStatesSource": "FastTrack", + "statusUploadBlob": { + "statusBlobType": "PageBlob", + "value": "string" + }, + "gaFamilies": [ + { + "name": "Win7", + "version": "2.7.41491.1176", + "isVersionFromRSM": false, + "isVMEnabledForRSMUpgrades": true, + "uris": [ + "uri" + ] + }, + { + "name": "Win8", + "version": "2.7.41491.1176", + "isVersionFromRSM": false, + "isVMEnabledForRSMUpgrades": true, + "uris": [ + "uri" + ] + } + ], + "extensionGoalStates": [ + { + "name": "test", + "version": "1.0.1", + "location": "location", + "failoverLocation": "location", + "additionalLocations": [ + "location" + ], + "state": "enabled", + "autoUpgrade": true, + "runAsStartupTask": false, + "isJson": true, + "useExactVersion": true, + "settingsSeqNo": 0, + "isMultiConfig": false, + "settings": [ + { + "protectedSettingsCertThumbprint": null, + "protectedSettings": null, + "publicSettings": "{}" + } + ] + } + ] + }"#; + + let vmsettings: VMSettings = serde_json::from_str(vmsettings_json).unwrap(); + assert_eq!(vmsettings.host_ga_plugin_version.unwrap(), "1.0.8.179"); + assert_eq!( + vmsettings.activity_id.unwrap(), + "1111-11111111-1111-11-1111" + ); + assert_eq!( + vmsettings.correlation_id.unwrap(), + "000000-00000000-000000-0000" + ); + assert_eq!(vmsettings.in_svd_seq_no.unwrap(), 1); + assert_eq!(vmsettings.certificates_revision.unwrap(), 0); + assert_eq!( + vmsettings.extensions_last_modified_tick_count.unwrap(), + 638931417044754873 + ); + assert_eq!(vmsettings.ga_families.as_ref().unwrap().len(), 2); + assert_eq!(vmsettings.extension_goal_states.as_ref().unwrap().len(), 1); + } +} diff --git a/proxy_agent_shared/src/host_clients/data_model/mod.rs b/proxy_agent_shared/src/host_clients/data_model/mod.rs new file mode 100644 index 00000000..9f46f218 --- /dev/null +++ b/proxy_agent_shared/src/host_clients/data_model/mod.rs @@ -0,0 +1,2 @@ +pub mod hostga_plugin_model; +pub mod wire_server_model; diff --git a/proxy_agent_shared/src/host_clients/data_model/wire_server_model.rs b/proxy_agent_shared/src/host_clients/data_model/wire_server_model.rs new file mode 100644 index 00000000..e798e943 --- /dev/null +++ b/proxy_agent_shared/src/host_clients/data_model/wire_server_model.rs @@ -0,0 +1,509 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename = "Versions")] +pub struct Versions { + #[serde(rename = "Preferred")] + pub preferred: Preferred, + + #[serde(rename = "Supported")] + pub supported: Supported, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Preferred { + #[serde(rename = "Version")] + pub version: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Supported { + #[serde(rename = "Version")] + pub versions: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename = "GoalState")] +pub struct GoalState { + #[serde(rename = "Version")] + pub version: Option, + + #[serde(rename = "Incarnation")] + pub incarnation: Option, + + #[serde(rename = "Machine")] + pub machine: Option, + + #[serde(rename = "Container")] + pub container: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Machine { + #[serde(rename = "ExpectedState")] + pub expected_state: Option, + + #[serde(rename = "StopRolesDeadlineHint")] + pub stop_roles_deadline_hint: Option, + + #[serde(rename = "LBProbePorts")] + pub lb_probe_ports: Option, + + #[serde(rename = "ExpectHealthReport")] + pub expect_health_report: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct LBProbePorts { + #[serde(rename = "Port")] + pub port: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Container { + #[serde(rename = "ContainerId")] + pub container_id: Option, + + #[serde(rename = "RoleInstanceList")] + pub role_instance_list: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RoleInstanceList { + #[serde(rename = "RoleInstance")] + pub role_instance: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RoleInstance { + #[serde(rename = "InstanceId")] + pub instance_id: Option, + + #[serde(rename = "State")] + pub state: Option, + + #[serde(rename = "Configuration")] + pub configuration: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Configuration { + #[serde(rename = "HostingEnvironmentConfig")] + pub hosting_environment_config: Option, + + #[serde(rename = "SharedConfig")] + pub shared_config: Option, + + #[serde(rename = "ExtensionsConfig")] + pub extensions_config: Option, + + #[serde(rename = "FullConfig")] + pub full_config: Option, + + #[serde(rename = "Certificates")] + pub certificates: Option, + + #[serde(rename = "ConfigName")] + pub config_name: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +#[serde(rename = "RDConfig")] +pub struct RDConfig { + #[serde(rename = "@version")] + pub version: Option, + + #[serde(rename = "StoredCertificates")] + pub stored_certificates: Option, + + #[serde(rename = "Deployment")] + pub deployment: Option, + + #[serde(rename = "Incarnation")] + pub incarnation: Option, + + #[serde(rename = "Role")] + pub role: Option, + + #[serde(rename = "HostingEnvironmentSettings")] + pub hosting_environment_settings: Option, + + #[serde(rename = "ApplicationSettings")] + pub application_settings: Option, + + #[serde(rename = "OutputEndpoints")] + pub output_endpoints: Option, + + #[serde(rename = "Instances")] + pub instances: Option, + + #[serde(rename = "Neighborhoods")] + pub neighborhoods: Option, +} + +// ----- StoredCertificates ----- +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct StoredCertificates { + #[serde(rename = "StoredCertificate", default)] + pub stored_certificate: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct StoredCertificate { + #[serde(rename = "@name")] + pub name: Option, + #[serde(rename = "@certificateId")] + pub certificate_id: Option, + #[serde(rename = "@storeName")] + pub store_name: Option, + #[serde(rename = "@configurationLevel")] + pub configuration_level: Option, +} + +// ----- Deployment ----- +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Deployment { + #[serde(rename = "@name")] + pub name: Option, + #[serde(rename = "@incarnation")] + pub incarnation: Option, + #[serde(rename = "@guid")] + pub guid: Option, + + #[serde(rename = "Service", default)] + pub services: Vec, + #[serde(rename = "ServiceInstance", default)] + pub service_instances: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Service { + #[serde(rename = "@name")] + pub name: Option, + #[serde(rename = "@guid")] + pub guid: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct ServiceInstance { + #[serde(rename = "@name")] + pub name: Option, + #[serde(rename = "@guid")] + pub guid: Option, +} + +// ----- Incarnation ----- +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Incarnation { + #[serde(rename = "@number")] + pub number: Option, + #[serde(rename = "@instance")] + pub instance: Option, + #[serde(rename = "@guid")] + pub guid: Option, +} + +// ----- Role ----- +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Role { + #[serde(rename = "@guid")] + pub guid: Option, + #[serde(rename = "@name")] + pub name: Option, + #[serde(rename = "@hostingEnvironment")] + pub hosting_environment: Option, + #[serde(rename = "@hostingEnvironmentVersion")] + pub hosting_environment_version: Option, + #[serde(rename = "@software")] + pub software: Option, + #[serde(rename = "@softwareType")] + pub software_type: Option, + #[serde(rename = "@entryPoint")] + pub entry_point: Option, + #[serde(rename = "@parameters")] + pub parameters: Option, + #[serde(rename = "@cpu")] + pub cpu: Option, + #[serde(rename = "@memory")] + pub memory: Option, + #[serde(rename = "@bandwidth")] + pub bandwidth: Option, + #[serde(rename = "@isManagementRole")] + pub is_management_role: Option, +} + +// ----- HostingEnvironmentSettings ----- +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct HostingEnvironmentSettings { + #[serde(rename = "@name")] + pub name: Option, + #[serde(rename = "@Runtime")] + pub runtime: Option, + #[serde(rename = "CAS")] + pub cas: Option, + #[serde(rename = "PrivilegeLevel")] + pub privilege_level: Option, + #[serde(rename = "AdditionalProperties")] + pub additional_properties: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct CAS { + #[serde(rename = "@mode")] + pub mode: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct PrivilegeLevel { + #[serde(rename = "@mode")] + pub mode: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct AdditionalProperties { + #[serde(rename = "Extensions")] + pub extensions: Option, // CDATA content +} + +// ----- ApplicationSettings ----- +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct ApplicationSettings { + #[serde(rename = "Setting", default)] + pub settings: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Setting { + #[serde(rename = "@name")] + pub name: Option, + #[serde(rename = "@value")] + pub value: Option, +} + +// ----- Instances ----- +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Instances { + #[serde(rename = "Instance", default)] + pub instances: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Instance { + #[serde(rename = "@id")] + pub id: Option, + #[serde(rename = "@neighborhoodID")] + pub neighborhood_id: Option, + #[serde(rename = "@address")] + pub address: Option, + #[serde(rename = "FaultDomains")] + pub fault_domains: Option, + #[serde(rename = "InputEndpoints")] + pub input_endpoints: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct FaultDomains { + #[serde(rename = "@randomID")] + pub random_id: Option, + #[serde(rename = "@updateID")] + pub update_id: Option, + #[serde(rename = "@updateCount")] + pub update_count: Option, +} + +// ----- Neighborhoods ----- +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Neighborhoods { + #[serde(rename = "Neighborhood", default)] + pub neighborhoods: Vec, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +pub struct Neighborhood { + #[serde(rename = "@id")] + pub id: Option, + #[serde(rename = "@innerbandwidth")] + pub innerbandwidth: Option, + #[serde(rename = "@innerlatency")] + pub innerlatency: Option, + #[serde(rename = "@outwardbandwidth")] + pub outwardbandwidth: Option, + #[serde(rename = "@outwardlatency")] + pub outwardlatency: Option, + #[serde(rename = "@parentNeighborhoodID")] + pub parent_neighborhood_id: Option, +} + +#[derive(Debug, Serialize, Deserialize, Default)] +#[serde(rename = "HostingEnvironmentConfig")] +pub struct HostingEnvironmentConfig { + #[serde(rename = "@version")] + pub version: Option, + + #[serde(rename = "@goalStateIncarnation")] + pub goal_state_incarnation: Option, + + #[serde(rename = "StoredCertificates")] + pub stored_certificates: Option, + + #[serde(rename = "Deployment")] + pub deployment: Option, + + #[serde(rename = "Incarnation")] + pub incarnation: Option, + + #[serde(rename = "Role")] + pub role: Option, + + #[serde(rename = "HostingEnvironmentSettings")] + pub hosting_environment_settings: Option, + + #[serde(rename = "ApplicationSettings")] + pub application_settings: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn versions_deserialization_test() { + let xml_data = r#" + + + 2015-04-05 + + + 2015-04-05 + 2012-11-30 + 2012-09-15 + 2012-05-15 + 2011-12-31 + 2011-10-15 + 2011-08-31 + 2011-04-07 + 2010-12-15 + 2010-28-10 + + + "#; + + let versions: Versions = quick_xml::de::from_str(xml_data).unwrap(); + assert_eq!(versions.preferred.version, "2015-04-05"); + assert_eq!(versions.supported.versions.len(), 10); + } + + #[test] + fn goal_state_deserialization_test() { + let xml_data = r#" + + 2015-04-05 + 1 + + Started + 300000 + + 16001 + + FALSE + + + c9514be2-ff0a-4dee-a059-45a0452268e7 + + + 896a1f5d-459b-4e58-a337-d113f9e97d25.instance + Started + + HostingEnvironmentConfig_uri + SharedConfig_uri + ExtensionsConfig_uri + FullConfig_uri + Certificates_uri + ConfigName.xml + + + + + + "#; + let goal_state: GoalState = quick_xml::de::from_str(xml_data).unwrap(); + assert_eq!(goal_state.version.unwrap(), "2015-04-05"); + assert_eq!(goal_state.incarnation.unwrap(), 1); + + let role_instances = goal_state + .container + .unwrap() + .role_instance_list + .unwrap() + .role_instance + .unwrap(); + assert_eq!(role_instances.len(), 1); + let role_instance = &role_instances[0]; + assert_eq!( + role_instance.instance_id.as_ref().unwrap(), + "896a1f5d-459b-4e58-a337-d113f9e97d25.instance" + ); + let configuration = role_instance.configuration.as_ref().unwrap(); + assert_eq!( + configuration.full_config.as_ref().unwrap(), + "FullConfig_uri" + ); + assert_eq!( + configuration.certificates.as_ref().unwrap(), + "Certificates_uri" + ); + } + + #[test] + fn full_config_deserialization_test() { + let xml_data = r#" + + + + + + + + + + + + + + + + + + + + + + + + + + "#; + let rd_config: RDConfig = quick_xml::de::from_str(xml_data).unwrap(); + assert_eq!(rd_config.version.unwrap(), "1.0.0.0"); + assert_eq!( + rd_config + .stored_certificates + .as_ref() + .unwrap() + .stored_certificate + .len(), + 1 + ); + + let certificate = &rd_config + .stored_certificates + .as_ref() + .unwrap() + .stored_certificate[0]; + assert_eq!(certificate.name.as_ref().unwrap(), "TenantEncryptionCert"); + assert_eq!( + certificate.certificate_id.as_ref().unwrap(), + "sha1:45750FFF384A47DEC65C9C7BB829B27E0562726F" + ); + } +} diff --git a/proxy_agent_shared/src/host_clients/hostga_plugin_client.rs b/proxy_agent_shared/src/host_clients/hostga_plugin_client.rs new file mode 100644 index 00000000..8b37d059 --- /dev/null +++ b/proxy_agent_shared/src/host_clients/hostga_plugin_client.rs @@ -0,0 +1,452 @@ +use std::collections::HashMap; +use std::time::Duration; + +use crate::certificate::certificate_helper::{ + decrypt_from_base64, generate_self_signed_certificate, +}; +use crate::common::error::Error; +use crate::common::formatted_error::FormattedError; +use crate::common::hyper_client; +use crate::common::hyper_client::read_response_body_as_string; +use crate::common::result::Result; +use crate::host_clients::data_model::hostga_plugin_model::{ + Certificates, RawCertificatesPayload, VMSettings, +}; +use crate::logger::LoggerLevel; +use base64::Engine; +use http::{Method, StatusCode, Uri}; +use serde::{Deserialize, Serialize}; +use tokio::time::timeout; +use uuid::Uuid; + +pub struct HostGAPluginClient { + base_url: String, + logger: fn(LoggerLevel, String) -> (), + timeout_in_seconds: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct HostGAPluginResponse { + pub body: Option, + pub etag: Option, + pub certificates_revision: Option, + pub version: Option, +} + +impl HostGAPluginClient { + const CERTIFICATES_URL: &'static str = "certificates"; + const VMSETTINGS_URL: &'static str = "vmSettings"; + const HOSTGAP_CIPHER: &'static str = "AES256_CBC"; + + const ETAG_HEADER: &'static str = "etag"; + const X_MS_SERVER_VERSION_HEADER: &'static str = "x-ms-server-version"; + const X_MS_CERTIFICATES_REVISION_HEADER: &'static str = "x-ms-certificates-revision"; + const TRANSPORT_CERTIFICATE_HEADER: &'static str = "x-ms-guest-agent-public-x509-cert"; + const TRANSPORT_CERTIFICATE_ENCRYPT_CIPHER_HEADER: &'static str = "x-ms-cipher-name"; + + pub fn new( + base_url: &str, + logger: fn(LoggerLevel, String) -> (), + timeout_in_seconds: Option, + ) -> HostGAPluginClient { + HostGAPluginClient { + base_url: base_url.to_string(), + logger, + timeout_in_seconds, + } + } + + pub async fn get_vmsettings( + &self, + etag: Option, + ) -> Result> { + let logger = self.logger; + + logger( + LoggerLevel::Info, + format!("Requesting VMSettings with etag: {etag:?}"), + ); + + let headers = self.vmsettings_request_headers(etag); + + self.get::( + &format!("{}/{}", self.base_url, Self::VMSETTINGS_URL), + &headers, + ) + .await + } + + pub async fn get_certificates( + &self, + cert_revision: u32, + ) -> Result> { + let logger = self.logger; + logger( + LoggerLevel::Info, + format!("Requesting certificates with revision: {cert_revision}"), + ); + + let cert = generate_self_signed_certificate(&Uuid::new_v4().to_string())?; + let cert_der = cert.get_public_cert_der(); + let cert_base64 = base64::engine::general_purpose::STANDARD.encode(cert_der); + + let headers = self.certificate_request_headers(&cert_base64); + + let raw_certs_resp = self + .get::( + &format!( + "{}/{}/{}", + self.base_url, + Self::CERTIFICATES_URL, + cert_revision + ), + &headers, + ) + .await?; + + if let Some(cert_base64) = raw_certs_resp + .body + .as_ref() + .and_then(|body| body.pkcs7_blob_with_pfx_contents.as_ref()) + { + let certs = decrypt_from_base64(cert_base64, &cert)?; + + return Ok(HostGAPluginResponse { + body: Some( + serde_json::from_str::(&certs).map_err(FormattedError::from)?, + ), + etag: raw_certs_resp.etag.clone(), + certificates_revision: raw_certs_resp.certificates_revision, + version: raw_certs_resp.version.clone(), + }); + } + Err(FormattedError { + message: "certificate payload is empty.".to_string(), + code: -1, + } + .into()) + } + + pub async fn get( + &self, + url: &str, + headers: &HashMap, + ) -> Result> + where + for<'a> T: Deserialize<'a>, + { + let logger = self.logger; + let url: Uri = url + .parse::() + .map_err(|e| Error::ParseUrl(url.to_string(), e.to_string()))?; + + let request = hyper_client::build_request(Method::GET, &url, headers, None, None, None)?; + + let (host, port) = hyper_client::host_port_from_uri(&url)?; + + let response = if let Some(timeout_in_seconds) = self.timeout_in_seconds { + timeout( + Duration::from_secs(timeout_in_seconds as u64), + hyper_client::send_request(&host, port, request, move |m| { + logger(LoggerLevel::Warn, m) + }), + ) + .await + .map_err(Into::::into)?? + } else { + hyper_client::send_request(&host, port, request, move |m| logger(LoggerLevel::Warn, m)) + .await? + }; + + let etag = response + .headers() + .get(Self::ETAG_HEADER) + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_string()); + let version = response + .headers() + .get(Self::X_MS_SERVER_VERSION_HEADER) + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_string()); + let certificates_revision = response + .headers() + .get(Self::X_MS_CERTIFICATES_REVISION_HEADER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + + let status = response.status(); + if status == StatusCode::NOT_MODIFIED { + let _hostgap_response: HostGAPluginResponse = HostGAPluginResponse { + body: None, + etag, + version, + certificates_revision, + }; + return Ok(_hostgap_response); + } else if status.is_success() { + let body_obj = hyper_client::read_response_body::(response).await?; + return Ok(HostGAPluginResponse { + body: Some(body_obj), + etag, + certificates_revision, + version, + }); + } + let body_string = read_response_body_as_string(response, "utf-8").await?; + Err(FormattedError { + code: status.as_u16() as i32, + message: format!("Http Error Status: {}, Body: {}", status, &body_string), + } + .into()) + } + + fn certificate_request_headers(&self, cert: &str) -> HashMap { + let mut headers = HashMap::new(); + headers.insert( + Self::TRANSPORT_CERTIFICATE_HEADER.to_string(), + cert.to_string(), + ); + headers.insert( + Self::TRANSPORT_CERTIFICATE_ENCRYPT_CIPHER_HEADER.to_string(), + Self::HOSTGAP_CIPHER.to_string(), + ); + headers + } + + fn vmsettings_request_headers(&self, etag: Option) -> HashMap { + let mut headers = HashMap::new(); + if let Some(etag) = etag { + headers.insert(Self::ETAG_HEADER.to_string(), etag); + } + headers + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn hostgaplugin_client_creation_test() { + let client = HostGAPluginClient::new( + "http://localhost:8080", + |level, message| { + println!("{:?}: {}", level, message); + }, + None, + ); + assert_eq!(client.base_url, "http://localhost:8080"); + assert_eq!(client.timeout_in_seconds, None); + } + + #[test] + fn certificate_request_headers_test() { + let client = HostGAPluginClient::new( + "http://localhost:8080", + |level, message| { + println!("{:?}: {}", level, message); + }, + None, + ); + let cert = "test_cert"; + let headers = client.certificate_request_headers(cert); + assert_eq!( + headers + .get(HostGAPluginClient::TRANSPORT_CERTIFICATE_HEADER) + .unwrap(), + cert + ); + assert_eq!( + headers + .get(HostGAPluginClient::TRANSPORT_CERTIFICATE_ENCRYPT_CIPHER_HEADER) + .unwrap(), + HostGAPluginClient::HOSTGAP_CIPHER + ); + } + + #[test] + fn vmsettings_request_headers_test() { + let client = HostGAPluginClient::new( + "http://localhost:8080", + |level, message| { + println!("{:?}: {}", level, message); + }, + None, + ); + let etag = Some("test_etag".to_string()); + let headers = client.vmsettings_request_headers(etag.clone()); + assert_eq!( + headers.get(HostGAPluginClient::ETAG_HEADER).unwrap(), + etag.as_ref().unwrap() + ); + + let headers_no_etag = client.vmsettings_request_headers(None); + assert!(headers_no_etag + .get(HostGAPluginClient::ETAG_HEADER) + .is_none()); + } + + #[tokio::test] + async fn get_vmsettings_negative_test() { + let client = HostGAPluginClient::new( + "http://invalid:8080", + |level, message| { + println!("{:?}: {}", level, message); + }, + Some(2), + ); + let response = client.get_vmsettings(None).await; + assert!(response.is_err()); + } + + #[tokio::test] + async fn get_certificates_negative_test() { + let client = HostGAPluginClient::new( + "http://invalid:8080", + |level, message| { + println!("{:?}: {}", level, message); + }, + Some(2), + ); + let response = client.get_certificates(0).await; + assert!(response.is_err()); + } + + #[test] + fn get_hostgaplugin_certificates_response_test() { + let response = r#" + { + "body": { + "activityId": "11111111-1111111-11111111111-111111", + "correlationId": "80e22e3b-3f9a-424e-b300-6cda2dd7e718", + "certificates": [ + { + "name": "TenantEncryptionCert", + "storeName": "My", + "configurationLevel": "System", + "certificateInBase64": "certificateInBase64_test", + "includePrivateKey": false, + "thumbprint": "thumbprint_test", + "certificateBlobFormatType": "PfxInClear" + } + ] + }, + "etag": null, + "certificates_revision": null, + "version": "1.0.8.179" + } + "#; + let resp: HostGAPluginResponse = + serde_json::from_str(response).expect("Deserialize HostGAPluginResponse failed"); + assert!(resp.body.is_some()); + let certs = resp.body.unwrap(); + assert_eq!( + certs.activity_id.unwrap(), + "11111111-1111111-11111111111-111111" + ); + assert_eq!( + certs.correlation_id.unwrap(), + "80e22e3b-3f9a-424e-b300-6cda2dd7e718" + ); + assert!(certs.certificates.is_some()); + let cert_list = certs.certificates.unwrap(); + assert_eq!(cert_list.len(), 1); + let cert = &cert_list[0]; + assert_eq!(cert.name.as_ref().unwrap(), "TenantEncryptionCert"); + assert_eq!( + cert.certificate_in_base64.as_ref().unwrap(), + "certificateInBase64_test" + ); + assert_eq!(cert.thumbprint.as_ref().unwrap(), "thumbprint_test"); + } + + #[test] + fn get_hostgaplugin_vmsettings_response_test() { + let response = r#" + { + "body": { + "hostGAPluginVersion": "1.0.8.179", + "activityId": "1111-11111111-1111-11-1111", + "correlationId": "000000-00000000-000000-0000", + "inSvdSeqNo": 1, + "certificatesRevision": 0, + "extensionsLastModifiedTickCount": 638931417044754873, + "extensionGoalStatesSource": "FastTrack", + "statusUploadBlob": { + "statusBlobType": "PageBlob", + "value": "string" + }, + "gaFamilies": [ + { + "name": "Win7", + "version": "2.7.41491.1176", + "isVersionFromRSM": false, + "isVMEnabledForRSMUpgrades": true, + "uris": [ + "uri" + ] + }, + { + "name": "Win8", + "version": "2.7.41491.1176", + "isVersionFromRSM": false, + "isVMEnabledForRSMUpgrades": true, + "uris": [ + "uri" + ] + } + ], + "extensionGoalStates": [ + { + "name": "extension.test", + "version": "1.0.1", + "location": "location", + "failoverLocation": "location", + "additionalLocations": [ + "location" + ], + "state": "enabled", + "autoUpgrade": true, + "runAsStartupTask": false, + "isJson": true, + "useExactVersion": true, + "settingsSeqNo": 0, + "isMultiConfig": false, + "settings": [ + { + "protectedSettingsCertThumbprint": null, + "protectedSettings": null, + "publicSettings": "{}" + } + ] + } + ] + }, + "etag": "5048704324908356042", + "certificates_revision": 1, + "version": "1.0.8.179" + } + "#; + let resp: HostGAPluginResponse = + serde_json::from_str(response).expect("Deserialize HostGAPluginResponse failed"); + assert!(resp.body.is_some()); + assert_eq!(resp.etag.unwrap(), "5048704324908356042"); + assert_eq!(resp.certificates_revision.unwrap(), 1); + + let vmsettings = resp.body.unwrap(); + assert_eq!( + vmsettings.activity_id.unwrap(), + "1111-11111111-1111-11-1111" + ); + assert_eq!( + vmsettings.extensions_last_modified_tick_count.unwrap(), + 638931417044754873 + ); + assert_eq!(vmsettings.extension_goal_states.as_ref().unwrap().len(), 1); + assert_eq!(vmsettings.ga_families.as_ref().unwrap().len(), 2); + + let extension = &vmsettings.extension_goal_states.as_ref().unwrap()[0]; + assert_eq!(extension.name.as_ref().unwrap(), "extension.test"); + assert_eq!(extension.version.as_ref().unwrap(), "1.0.1"); + } +} diff --git a/proxy_agent_shared/src/host_clients/mod.rs b/proxy_agent_shared/src/host_clients/mod.rs new file mode 100644 index 00000000..25e8f2b5 --- /dev/null +++ b/proxy_agent_shared/src/host_clients/mod.rs @@ -0,0 +1,3 @@ +pub mod data_model; +pub mod hostga_plugin_client; +pub mod wire_server_client; diff --git a/proxy_agent_shared/src/host_clients/wire_server_client.rs b/proxy_agent_shared/src/host_clients/wire_server_client.rs new file mode 100644 index 00000000..0980239f --- /dev/null +++ b/proxy_agent_shared/src/host_clients/wire_server_client.rs @@ -0,0 +1,165 @@ +use std::{collections::HashMap, time::Duration}; + +use http::Uri; +use serde::Deserialize; +use tokio::time::timeout; + +use crate::{ + common::{error::Error, formatted_error::FormattedError, hyper_client, result::Result}, + host_clients::data_model::wire_server_model::{GoalState, Versions}, + logger::LoggerLevel, +}; + +pub struct WireServerClient { + base_url: String, + version: String, + logger: fn(LoggerLevel, String) -> (), + timeout_in_seconds: Option, +} + +impl WireServerClient { + const X_MS_VERSION_HEADER: &'static str = "x-ms-version"; + const VERSIONS_URL: &'static str = "?comp=Versions"; + const GOAL_STATE_URL: &'static str = "machine?comp=goalstate"; + + const DEFAULT_WIRE_VERSION: &'static str = "2012-11-30"; + + pub fn new( + base_url: &str, + logger: fn(LoggerLevel, String) -> (), + timeout_in_seconds: Option, + ) -> WireServerClient { + WireServerClient { + base_url: base_url.to_string(), + version: Self::DEFAULT_WIRE_VERSION.to_string(), + logger, + timeout_in_seconds, + } + } + + // http://168.63.129.16?comp=Versions + pub async fn get_versions(&self) -> Result { + self.get_sub_url::(Self::VERSIONS_URL).await + } + + // http://168.63.129.16/machine?comp=goalstate + pub async fn get_goal_state(&self) -> Result { + self.get_sub_url::(Self::GOAL_STATE_URL).await + } + + pub async fn refresh_wire_server_version(&mut self) -> Result<()> { + let versions = self.get_versions().await?; + self.update_version(versions); + Ok(()) + } + + pub async fn get_sub_url(&self, sub_url: &str) -> Result + where + T: for<'a> Deserialize<'a>, + { + self.get_url(&format!("{}/{}", &self.base_url, sub_url)) + .await + } + + pub async fn get_url(&self, url: &str) -> Result + where + T: for<'a> Deserialize<'a>, + { + let logger = self.logger; + let url: Uri = url + .parse::() + .map_err(|e| Error::ParseUrl(url.to_string(), e.to_string()))?; + + let headers = self.common_headers(); + + let res = if let Some(timeout_in_seconds) = self.timeout_in_seconds { + timeout( + Duration::from_secs(timeout_in_seconds as u64), + hyper_client::get(&url, &headers, None, None, move |message| { + logger(LoggerLevel::Warn, message) + }), + ) + .await + .map_err(Into::::into)?? + } else { + hyper_client::get(&url, &headers, None, None, move |message| { + logger(LoggerLevel::Warn, message) + }) + .await? + }; + + Ok(res) + } + + fn common_headers(&self) -> HashMap { + let mut headers = HashMap::new(); + headers.insert(Self::X_MS_VERSION_HEADER.to_string(), self.version.clone()); + headers + } + + fn update_version(&mut self, versions: Versions) { + self.version = versions.preferred.version; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn wire_server_client_creation_test() { + let client = WireServerClient::new("http://localhost:8080", test_logger, None); + assert_eq!(client.base_url, "http://localhost:8080"); + assert_eq!(client.version, WireServerClient::DEFAULT_WIRE_VERSION); + } + + #[test] + fn wire_server_client_common_headers_test() { + let client = WireServerClient::new("http://localhost:8080", test_logger, None); + let headers = client.common_headers(); + assert_eq!( + headers.get("x-ms-version").unwrap(), + WireServerClient::DEFAULT_WIRE_VERSION + ); + } + + #[test] + fn wire_server_client_update_version_test() { + let mut client = WireServerClient::new("http://localhost:8080", test_logger, None); + let versions = Versions { + preferred: crate::host_clients::data_model::wire_server_model::Preferred { + version: "2021-01-01".to_string(), + }, + supported: crate::host_clients::data_model::wire_server_model::Supported { + versions: vec![], + }, + }; + client.update_version(versions); + assert_eq!(client.version, "2021-01-01"); + } + + #[tokio::test] + async fn wire_server_client_negative_test() { + let mut client = WireServerClient::new("http://invalid:8080", test_logger, Some(1)); + assert!(client.get_goal_state().await.is_err()); + assert!(client.get_versions().await.is_err()); + assert!(client.refresh_wire_server_version().await.is_err()); + } + + #[tokio::test] + async fn wire_server_client_get_url_invalid_uri() { + let client = WireServerClient::new("http://localhost:8080", test_logger, None); + + let res: Result = client.get_url("http://invalid uri").await; + assert!(res.is_err()); + + match res { + Err(Error::ParseUrl(_, _)) => {} // expected + _ => panic!("Expected Parse Url Error"), + } + } + + fn test_logger(level: LoggerLevel, message: String) { + println!("{:?}: {}", level, message); + } +} diff --git a/proxy_agent_shared/src/lib.rs b/proxy_agent_shared/src/lib.rs index 41d2c220..ed809172 100644 --- a/proxy_agent_shared/src/lib.rs +++ b/proxy_agent_shared/src/lib.rs @@ -1,9 +1,12 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT +pub mod certificate; +pub mod common; pub mod error; #[cfg(windows)] pub mod etw; +pub mod host_clients; pub mod logger; pub mod misc_helpers; pub mod proxy_agent_aggregate_status; diff --git a/proxy_agent_shared/src/misc_helpers.rs b/proxy_agent_shared/src/misc_helpers.rs index 238b04cd..0d4f9d98 100644 --- a/proxy_agent_shared/src/misc_helpers.rs +++ b/proxy_agent_shared/src/misc_helpers.rs @@ -183,7 +183,7 @@ pub fn get_files(dir: &Path) -> Result> { /// # Example /// ```rust /// use std::path::PathBuf; -/// use proxy_agent_shared::misc_helpers; +/// use crate::misc_helpers; /// let dir = PathBuf::from("."); /// let search_regex_pattern = r"^(.*\.log)$"; // search for files with .log extension /// let files = misc_helpers::search_files(&dir, search_regex_pattern).unwrap(); diff --git a/proxy_agent_shared/src/proxy_agent_aggregate_status.rs b/proxy_agent_shared/src/proxy_agent_aggregate_status.rs index 5a3f2e6a..8c2feb0b 100644 --- a/proxy_agent_shared/src/proxy_agent_aggregate_status.rs +++ b/proxy_agent_shared/src/proxy_agent_aggregate_status.rs @@ -88,3 +88,121 @@ pub struct GuestProxyAgentAggregateStatus { pub proxyConnectionSummary: Vec, pub failedAuthenticateSummary: Vec, } + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn get_proxy_agent_aggregate_status_folder_test() { + let path = misc_helpers::resolve_env_variables(PROXY_AGENT_AGGREGATE_STATUS_FOLDER); + assert!(path.is_ok()); + let path_buf = get_proxy_agent_aggregate_status_folder(); + assert_eq!(path_buf.to_string_lossy().into_owned(), path.unwrap()); + } + + #[test] + fn guest_proxy_agent_aggregate_status_deserialize_test() { + let json_str = r#" + { + "timestamp": "2025-10-10T12:00:00Z", + "proxyAgentStatus": { + "version": "1.0.0", + "status": "SUCCESS", + "monitorStatus": { + "status": "RUNNING", + "message": "Monitor is running", + "states": { + "monitorState": "TestState" + } + }, + "keyLatchStatus": { + "status": "STOPPED", + "message": "Key latch is stopped" + }, + "ebpfProgramStatus": { + "status": "UNKNOWN", + "message": "eBPF program status unknown" + }, + "proxyListenerStatus": { + "status": "RUNNING", + "message": "Proxy listener is active" + }, + "telemetryLoggerStatus": { + "status": "RUNNING", + "message": "Telemetry logger is operational" + }, + "proxyConnectionsCount": 42 + }, + "proxyConnectionSummary": [ + { + "userName": "user1", + "ip": "192.168.1.1", + "port": 8080, + "processCmdLine": "cmd.exe /c whoami", + "responseStatus": "Success", + "count": 1, + "userGroups": ["Administrators"], + "processFullPath": "C:\\Windows\\System32\\cmd.exe" + } + ], + "failedAuthenticateSummary": [] + }"#; + let status: Result = + serde_json::from_str(json_str); + assert!(status.is_ok()); + let status = status.unwrap(); + assert_eq!(status.timestamp, "2025-10-10T12:00:00Z"); + assert_eq!(status.proxyAgentStatus.version, "1.0.0"); + assert_eq!(status.proxyAgentStatus.status, OverallState::SUCCESS); + assert_eq!( + status.proxyAgentStatus.monitorStatus.status, + ModuleState::RUNNING + ); + assert_eq!( + status.proxyAgentStatus.monitorStatus.message, + "Monitor is running" + ); + assert_eq!( + status + .proxyAgentStatus + .monitorStatus + .states + .unwrap() + .get("monitorState") + .unwrap(), + "TestState" + ); + assert_eq!( + status.proxyAgentStatus.keyLatchStatus.status, + ModuleState::STOPPED + ); + assert_eq!( + status.proxyAgentStatus.keyLatchStatus.message, + "Key latch is stopped" + ); + assert_eq!( + status.proxyAgentStatus.ebpfProgramStatus.status, + ModuleState::UNKNOWN + ); + assert_eq!( + status.proxyAgentStatus.ebpfProgramStatus.message, + "eBPF program status unknown" + ); + assert_eq!( + status.proxyAgentStatus.proxyListenerStatus.status, + ModuleState::RUNNING + ); + assert_eq!( + status.proxyAgentStatus.proxyListenerStatus.message, + "Proxy listener is active" + ); + assert_eq!( + status.proxyAgentStatus.telemetryLoggerStatus.status, + ModuleState::RUNNING + ); + assert_eq!( + status.proxyAgentStatus.telemetryLoggerStatus.message, + "Telemetry logger is operational" + ); + } +}