diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt index a798d29b..82b9b1c0 100644 --- a/.github/actions/spelling/expect.txt +++ b/.github/actions/spelling/expect.txt @@ -365,4 +365,22 @@ xsi xxxx xxxxxxxx xxxxxxxxxxx -zipsas \ No newline at end of file +zipsas +CALG +CNG +detials +HCRYPTPROV +hostgap +innerbandwidth +innerlatency +KSP +ncrypt +outwardbandwidth +outwardlatency +PCWSTR +PSTR +psz +rsm +SELFSIGN +vmsettings +hostgaplugin \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index f1dba32b..ee57a771 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -223,6 +223,12 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "base64" +version = "0.22.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" + [[package]] name = "bitflags" version = "2.6.0" @@ -876,6 +882,7 @@ dependencies = [ name = "proxy_agent_shared" version = "9.9.9" dependencies = [ + "base64", "chrono", "concurrent-queue", "ctor", @@ -900,6 +907,7 @@ dependencies = [ "tokio", "tokio-util", "uuid", + "windows 0.61.3", "windows-service", "windows-sys", "winreg", @@ -1142,7 +1150,7 @@ dependencies = [ "ntapi", "once_cell", "rayon", - "windows", + "windows 0.52.0", ] [[package]] @@ -1482,6 +1490,19 @@ dependencies = [ "windows-targets", ] +[[package]] +name = "windows" +version = "0.61.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +dependencies = [ + "windows-collections", + "windows-core 0.61.2", + "windows-future", + "windows-link", + "windows-numerics", +] + [[package]] name = "windows-acl" version = "0.3.0" @@ -1494,6 +1515,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core 0.61.2", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -1516,6 +1546,17 @@ dependencies = [ "windows-strings", ] +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core 0.61.2", + "windows-link", + "windows-threading", +] + [[package]] name = "windows-implement" version = "0.60.0" @@ -1544,6 +1585,16 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core 0.61.2", + "windows-link", +] + [[package]] name = "windows-result" version = "0.3.4" @@ -1598,6 +1649,15 @@ dependencies = [ "windows_x86_64_msvc", ] +[[package]] +name = "windows-threading" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" diff --git a/proxy_agent_shared/Cargo.toml b/proxy_agent_shared/Cargo.toml index c4756061..ddd14c15 100644 --- a/proxy_agent_shared/Cargo.toml +++ b/proxy_agent_shared/Cargo.toml @@ -27,6 +27,7 @@ http = "1.1.0" http-body-util = "0.1" hyper = { version = "1", features = ["server", "http1", "client"] } hyper-util = { version = "0.1", features = ["tokio"] } +base64 = "0.22.1" [dependencies.uuid] version = "1.3.0" @@ -56,5 +57,14 @@ features = [ "Win32_Storage_FileSystem", ] +[target.'cfg(windows)'.dependencies.windows] +version = "0.61.3" +features = [ + "Win32_Foundation", + "Win32_Security_Cryptography", + "Win32_System_SystemInformation", + "Win32_System_Memory", +] + [target.'cfg(not(windows))'.dependencies] os_info = "3.7.0" # read Linux OS version and arch \ No newline at end of file diff --git a/proxy_agent_shared/src/certificate.rs b/proxy_agent_shared/src/certificate.rs new file mode 100644 index 00000000..9ac1744b --- /dev/null +++ b/proxy_agent_shared/src/certificate.rs @@ -0,0 +1,3 @@ +pub mod certificate_helper; +#[cfg(windows)] +pub mod certificate_helper_windows; diff --git a/proxy_agent_shared/src/certificate/certificate_helper.rs b/proxy_agent_shared/src/certificate/certificate_helper.rs new file mode 100644 index 00000000..acec684e --- /dev/null +++ b/proxy_agent_shared/src/certificate/certificate_helper.rs @@ -0,0 +1,103 @@ +#[cfg(windows)] +use crate::certificate::certificate_helper_windows::CertificateDetailsWindows; +use crate::formatted_error_message::FormattedErrorMessage; + +#[cfg(windows)] +type CertDetailsType = CertificateDetailsWindows; + +#[cfg(not(windows))] +type CertDetailsType = (); + +pub struct CertificateDetailsWrapper { + pub cert_details: CertDetailsType, +} + +impl CertificateDetailsWrapper { + pub fn get_public_cert_der(&self) -> &[u8] { + #[cfg(windows)] + { + &self.cert_details.public_key_der + } + #[cfg(not(windows))] + { + todo!() + } + } +} + +pub fn generate_self_signed_certificate( + _subject_name: &str, +) -> Result { + #[cfg(windows)] + { + use crate::certificate::certificate_helper_windows::generate_self_signed_certificate_windows; + + generate_self_signed_certificate_windows(_subject_name) + } + #[cfg(not(windows))] + { + Err("Linux version is not implemented.".to_string().into()) + } +} + +pub fn decrypt_from_base64( + _base64_input: &str, + _cert_details: &CertificateDetailsWrapper, +) -> Result { + #[cfg(windows)] + { + use crate::certificate::certificate_helper_windows::decrypt_from_base64_windows; + + decrypt_from_base64_windows(_base64_input, _cert_details) + } + #[cfg(not(windows))] + { + Err("Linux version is not implemented.".to_string().into()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn generate_self_signed_certificate_test() { + #[cfg(windows)] + { + let subject_name = "TestSubject"; + let cert_details_result = generate_self_signed_certificate(subject_name); + assert!(cert_details_result.is_ok()); + let cert_details = cert_details_result.unwrap(); + let public_cert_der = cert_details.get_public_cert_der(); + assert!(!public_cert_der.is_empty()); + } + #[cfg(not(windows))] + { + // On non-Windows platforms, the function is not implemented. + // This test will simply ensure that the function is called without panic. + let subject_name = "TestSubject"; + let cert_details_result = generate_self_signed_certificate(subject_name); + assert!(cert_details_result.is_err()); + } + } + + #[test] + fn decrypt_from_base64_test() { + #[cfg(windows)] + { + let subject_name = "TestSubject"; + let cert_details_result = generate_self_signed_certificate(subject_name); + assert!(cert_details_result.is_ok()); + let cert_details = cert_details_result.unwrap(); + let result = decrypt_from_base64("invalid input", &cert_details); + assert!(result.is_err()); + } + #[cfg(not(windows))] + { + let result = decrypt_from_base64( + "invalid input", + &CertificateDetailsWrapper { cert_details: () }, + ); + assert!(result.is_err()); + } + } +} diff --git a/proxy_agent_shared/src/certificate/certificate_helper_windows.rs b/proxy_agent_shared/src/certificate/certificate_helper_windows.rs new file mode 100644 index 00000000..4179c091 --- /dev/null +++ b/proxy_agent_shared/src/certificate/certificate_helper_windows.rs @@ -0,0 +1,451 @@ +use base64::Engine; +use uuid::Uuid; +use windows::{ + core::{BOOL, PCWSTR, PSTR}, + Win32::{ + Security::Cryptography::{ + szOID_KEY_USAGE, szOID_SUBJECT_KEY_IDENTIFIER, CertCreateSelfSignCertificate, + CertFreeCertificateContext, CertStrToNameW, CryptAcquireCertificatePrivateKey, + CryptEncodeObjectEx, CryptExportPublicKeyInfo, CryptHashPublicKeyInfo, CryptMsgClose, + CryptMsgControl, CryptMsgGetParam, CryptMsgOpenToDecode, CryptMsgUpdate, + NCryptCreatePersistedKey, NCryptFinalizeKey, NCryptFreeObject, + NCryptOpenStorageProvider, NCryptSetProperty, CALG_SHA1, CERT_CONTEXT, + CERT_DATA_ENCIPHERMENT_KEY_USAGE, CERT_DIGITAL_SIGNATURE_KEY_USAGE, CERT_EXTENSION, + CERT_EXTENSIONS, CERT_KEY_CERT_SIGN_KEY_USAGE, CERT_KEY_ENCIPHERMENT_KEY_USAGE, + CERT_KEY_SPEC, CERT_OFFLINE_CRL_SIGN_KEY_USAGE, CERT_PUBLIC_KEY_INFO, + CERT_X500_NAME_STR, CMSG_CTRL_DECRYPT, CMSG_CTRL_DECRYPT_PARA, + CMSG_CTRL_DECRYPT_PARA_0, CRYPT_ACQUIRE_ALLOW_NCRYPT_KEY_FLAG, + CRYPT_ACQUIRE_COMPARE_KEY_FLAG, CRYPT_ACQUIRE_FLAGS, + CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG, CRYPT_BIT_BLOB, CRYPT_INTEGER_BLOB, + HCRYPTPROV_OR_NCRYPT_KEY_HANDLE, MS_KEY_STORAGE_PROVIDER, NCRYPT_ALLOW_EXPORT_FLAG, + NCRYPT_ALLOW_PLAINTEXT_EXPORT_FLAG, NCRYPT_EXPORT_POLICY_PROPERTY, NCRYPT_HANDLE, + NCRYPT_KEY_HANDLE, NCRYPT_LENGTH_PROPERTY, NCRYPT_PROV_HANDLE, NCRYPT_RSA_ALGORITHM, + PKCS_7_ASN_ENCODING, X509_ASN_ENCODING, + }, + System::SystemInformation::GetSystemTime, + }, +}; + +use crate::{ + certificate::certificate_helper::CertificateDetailsWrapper, + formatted_error_message::FormattedErrorMessage, +}; + +pub struct CertificateDetailsWindows { + pub public_key_der: Vec, + pub p_cert_ctx: *mut CERT_CONTEXT, +} + +impl Drop for CertificateDetailsWindows { + fn drop(&mut self) { + if !self.p_cert_ctx.is_null() + && !unsafe { CertFreeCertificateContext(Some(self.p_cert_ctx)) }.as_bool() + { + eprintln!("Failed to free certificate context."); + } + } +} + +pub fn generate_self_signed_certificate_windows( + subject_name: &str, +) -> Result { + // Open KSP + let mut h_prov = NCRYPT_PROV_HANDLE(0); + unsafe { + NCryptOpenStorageProvider(&mut h_prov, MS_KEY_STORAGE_PROVIDER, 0)?; + } + + // Create an RSA key + let key_name: Vec = Uuid::new_v4().to_string().encode_utf16().collect(); + let mut h_key = NCRYPT_KEY_HANDLE(0); + unsafe { + NCryptCreatePersistedKey( + h_prov, + &mut h_key, + NCRYPT_RSA_ALGORITHM, + PCWSTR(key_name.as_ptr()), // not NULL + windows::Win32::Security::Cryptography::CERT_KEY_SPEC(0), + windows::Win32::Security::Cryptography::NCRYPT_FLAGS(0), + )?; + } + + let key_length: u32 = 2048; + unsafe { + // Set key length property to 2048 bits + NCryptSetProperty( + h_key.into(), + NCRYPT_LENGTH_PROPERTY, + &key_length.to_ne_bytes(), + windows::Win32::Security::Cryptography::NCRYPT_FLAGS(0), + )?; + + // Set key export policy + NCryptSetProperty( + h_key.into(), + NCRYPT_EXPORT_POLICY_PROPERTY, + &(NCRYPT_ALLOW_EXPORT_FLAG | NCRYPT_ALLOW_PLAINTEXT_EXPORT_FLAG).to_ne_bytes(), + windows::Win32::Security::Cryptography::NCRYPT_FLAGS(0), + )?; + + // Finalize the key + NCryptFinalizeKey( + h_key, + windows::Win32::Security::Cryptography::NCRYPT_FLAGS(0), + )?; + } + + // Set up subject name for cert + let subject = format!("CN={subject_name}"); + let subject_w: Vec = subject.encode_utf16().chain(Some(0)).collect(); + let mut size = 0u32; + unsafe { + CertStrToNameW( + X509_ASN_ENCODING, + PCWSTR(subject_w.as_ptr()), + CERT_X500_NAME_STR, + Some(std::ptr::null_mut()), + Some(std::ptr::null_mut()), + &mut size, + Some(std::ptr::null_mut()), + )?; + } + let mut name_buf = vec![0u8; size as usize]; + unsafe { + CertStrToNameW( + X509_ASN_ENCODING, + PCWSTR(subject_w.as_ptr()), + CERT_X500_NAME_STR, + None, + Some(name_buf.as_mut_ptr()), + &mut size, + Some(std::ptr::null_mut()), + )?; + } + + let subject_blob = CRYPT_INTEGER_BLOB { + cbData: size, + pbData: name_buf.as_mut_ptr(), + }; + + // Validity period + let mut start = unsafe { GetSystemTime() }; + start.wYear -= 1; + let mut end = unsafe { GetSystemTime() }; + end.wYear += 3; + + let mut exts = build_cert_extensions(HCRYPTPROV_OR_NCRYPT_KEY_HANDLE(h_key.0))?; + + let cert_exts = CERT_EXTENSIONS { + cExtension: exts.len() as u32, + rgExtension: exts.as_mut_ptr(), + }; + + // Create cert + let cert_ctx = unsafe { + CertCreateSelfSignCertificate( + Some(HCRYPTPROV_OR_NCRYPT_KEY_HANDLE(h_key.0)), + &subject_blob, + windows::Win32::Security::Cryptography::CERT_CREATE_SELFSIGN_FLAGS(0), + None, + None, + Some(&start), + Some(&end), + Some(&cert_exts), + ) + }; + + // Get cert data + let cert_der = unsafe { + std::slice::from_raw_parts( + (*cert_ctx).pbCertEncoded, + (*cert_ctx).cbCertEncoded as usize, + ) + }; + + let cert_detials_windows = CertificateDetailsWindows { + public_key_der: cert_der.to_vec(), + p_cert_ctx: cert_ctx, + }; + let res = CertificateDetailsWrapper { + cert_details: cert_detials_windows, + }; + + // Cleanup + unsafe { + NCryptFreeObject(h_prov.into())?; + NCryptFreeObject(h_key.into())?; + }; + Ok(res) +} + +pub fn decrypt_from_base64_windows( + base64_input: &str, + cert_details_wrapper: &CertificateDetailsWrapper, +) -> Result { + let encrypted = base64_input.replace("\r", "").replace("\n", ""); + let encrypted_payload = &base64::engine::general_purpose::STANDARD.decode(encrypted)?; + let p_cert_ctx = cert_details_wrapper.cert_details.p_cert_ctx; + + // Acquire the private key handle using the CNG-compatible function. + let mut h_key = HCRYPTPROV_OR_NCRYPT_KEY_HANDLE(0); + let mut key_spec = CERT_KEY_SPEC(0u32); + let mut must_free = BOOL(0); + unsafe { + CryptAcquireCertificatePrivateKey( + p_cert_ctx, + CRYPT_ACQUIRE_FLAGS( + CRYPT_ACQUIRE_COMPARE_KEY_FLAG.0 + | CRYPT_ACQUIRE_ALLOW_NCRYPT_KEY_FLAG.0 + | CRYPT_ACQUIRE_ONLY_NCRYPT_KEY_FLAG.0, + ), + None, + &mut h_key, + Some(&mut key_spec), + Some(&mut must_free), + ) + }?; + + // Decode the encrypted message. + let msg_handle = unsafe { + CryptMsgOpenToDecode( + X509_ASN_ENCODING.0 | PKCS_7_ASN_ENCODING.0, + 0, + 0, + None, + None, + None, + ) + }; + + if msg_handle.is_null() { + return Err(FormattedErrorMessage { + code: -1, + message: "Failed to open message handle to decrypt.".to_string(), + }); + } + unsafe { CryptMsgUpdate(msg_handle, Some(encrypted_payload), true) }?; + // Create an instance of the nested struct (the union) + let anonymous_union = CMSG_CTRL_DECRYPT_PARA_0 { + hCryptProv: h_key.0, + }; + + // Create the main struct instance + let mut decrypt_para = CMSG_CTRL_DECRYPT_PARA { + cbSize: std::mem::size_of::() as u32, + Anonymous: anonymous_union, + dwKeySpec: key_spec.0, + dwRecipientIndex: 0, + }; + + unsafe { + CryptMsgControl( + msg_handle, + 0, + CMSG_CTRL_DECRYPT, + Some(&mut decrypt_para as *mut _ as *mut _), + ) + }?; + + // Get the decrypted message size. + let mut content_size = 0; + unsafe { CryptMsgGetParam(msg_handle, 2, 0, None, &mut content_size) }?; + + // Get the decrypted message content. + let mut decrypted_data_buffer = vec![0u8; content_size as usize]; + unsafe { + CryptMsgGetParam( + msg_handle, + 2, + 0, + Some(decrypted_data_buffer.as_mut_ptr() as *mut _), + &mut content_size, + ) + }?; + + unsafe { CryptMsgClose(Some(msg_handle)) }?; + if must_free.as_bool() { + unsafe { NCryptFreeObject(NCRYPT_HANDLE(h_key.0)) }?; + } + + let res = String::from_utf8(decrypted_data_buffer)?; + Ok(res) +} + +fn build_cert_extensions( + h_key: HCRYPTPROV_OR_NCRYPT_KEY_HANDLE, +) -> Result, FormattedErrorMessage> { + let mut extensions: Vec = Vec::new(); + + // Key Usage + let key_usage: u8 = CERT_DIGITAL_SIGNATURE_KEY_USAGE as u8 + | CERT_KEY_CERT_SIGN_KEY_USAGE as u8 + | CERT_OFFLINE_CRL_SIGN_KEY_USAGE as u8 + | CERT_KEY_ENCIPHERMENT_KEY_USAGE as u8 + | CERT_DATA_ENCIPHERMENT_KEY_USAGE as u8; + + let key_usage_blob = CRYPT_BIT_BLOB { + cbData: 1, + pbData: &key_usage as *const u8 as *mut u8, + cUnusedBits: 0, + }; + + let mut encoded_key_usage_len: u32 = 0; + unsafe { + CryptEncodeObjectEx( + X509_ASN_ENCODING, + szOID_KEY_USAGE, + &key_usage_blob as *const _ as *const _, + windows::Win32::Security::Cryptography::CRYPT_ENCODE_OBJECT_FLAGS(0), + None, + None, + &mut encoded_key_usage_len, + )?; + } + + let mut encoded_key_usage = vec![0u8; encoded_key_usage_len as usize]; + + unsafe { + CryptEncodeObjectEx( + X509_ASN_ENCODING, + szOID_KEY_USAGE, + &key_usage_blob as *const _ as *const _, + windows::Win32::Security::Cryptography::CRYPT_ENCODE_OBJECT_FLAGS(0), + None, + Some(encoded_key_usage.as_mut_ptr() as *mut _), + &mut encoded_key_usage_len, + )?; + } + + extensions.push(CERT_EXTENSION { + pszObjId: PSTR(szOID_KEY_USAGE.as_ptr() as *mut _), + fCritical: BOOL(0), // FALSE + Value: CRYPT_INTEGER_BLOB { + cbData: encoded_key_usage_len, + pbData: encoded_key_usage.as_mut_ptr(), + }, + }); + + let mut size = 0; + unsafe { + CryptExportPublicKeyInfo(h_key, Some(0), X509_ASN_ENCODING, None, &mut size)?; + } + let mut buffer = vec![0u8; size as usize]; + + let p_info = buffer.as_mut_ptr() as *mut CERT_PUBLIC_KEY_INFO; + + // Subject Key Identifier (let Windows generate it) + unsafe { + CryptExportPublicKeyInfo(h_key, Some(0), X509_ASN_ENCODING, Some(p_info), &mut size) + }?; + + let mut ski_hash = [0u8; 20]; + let mut ski_size = ski_hash.len() as u32; + unsafe { + CryptHashPublicKeyInfo( + None, + CALG_SHA1, + 0, + X509_ASN_ENCODING, + p_info, + Some(ski_hash.as_mut_ptr()), + &mut ski_size, + ) + }?; + + let ski_blob = CRYPT_INTEGER_BLOB { + cbData: ski_size, + pbData: ski_hash.as_mut_ptr(), + }; + + let mut encoded_ski_size = 0; + unsafe { + CryptEncodeObjectEx( + X509_ASN_ENCODING, + szOID_SUBJECT_KEY_IDENTIFIER, + &ski_blob as *const _ as *const _, + windows::Win32::Security::Cryptography::CRYPT_ENCODE_OBJECT_FLAGS(0), + None, + None, + &mut encoded_ski_size, + ) + }?; + + let mut encoded_ski = vec![0u8; encoded_ski_size as usize]; + + unsafe { + CryptEncodeObjectEx( + X509_ASN_ENCODING, + szOID_SUBJECT_KEY_IDENTIFIER, + &ski_blob as *const _ as *const _, + windows::Win32::Security::Cryptography::CRYPT_ENCODE_OBJECT_FLAGS(0), + None, + Some(encoded_ski.as_mut_ptr() as *mut _), + &mut encoded_ski_size, + ) + }?; + + extensions.push(CERT_EXTENSION { + pszObjId: PSTR(szOID_SUBJECT_KEY_IDENTIFIER.as_ptr() as *mut _), + fCritical: BOOL(0), // FALSE + Value: CRYPT_INTEGER_BLOB { + cbData: encoded_ski_size, + pbData: encoded_ski.as_mut_ptr(), + }, + }); + Ok(extensions) +} + +#[cfg(test)] +mod tests { + use super::*; + use windows::{ + core::PSTR, + Win32::Security::Cryptography::{ + szOID_NIST_AES256_CBC, CryptEncryptMessage, CRYPT_ENCRYPT_MESSAGE_PARA, + }, + }; + + #[test] + fn test_certificate_decryption() { + let cert = generate_self_signed_certificate_windows(&Uuid::new_v4().to_string()).unwrap(); + + let org_str = "Hello, World!"; + let encrypted = encrypt(&cert.cert_details, org_str); + + let decrypted = decrypt_from_base64_windows(&encrypted, &cert).unwrap(); + + assert!(decrypted.eq(org_str)) + } + + pub fn encrypt(cert: &CertificateDetailsWindows, org_str: &str) -> String { + let mut info = CRYPT_ENCRYPT_MESSAGE_PARA::default(); + info.cbSize = std::mem::size_of::() as u32; + info.dwMsgEncodingType = X509_ASN_ENCODING.0 | PKCS_7_ASN_ENCODING.0; + info.ContentEncryptionAlgorithm.pszObjId = PSTR(szOID_NIST_AES256_CBC.as_ptr() as *mut _); + info.dwFlags = 0; + let cert_ctx_ptrs = [cert.p_cert_ctx as *const _]; + let mut encrypted_size: u32 = 0; + unsafe { + CryptEncryptMessage( + &info, + &cert_ctx_ptrs, + Some(org_str.as_bytes()), + None, + &mut encrypted_size as *mut u32, + ) + .unwrap(); + } + let mut encrypted_data = vec![0u8; encrypted_size as usize]; + unsafe { + CryptEncryptMessage( + &info, + &cert_ctx_ptrs, + Some(org_str.as_bytes()), + Some(encrypted_data.as_mut_ptr()), + &mut encrypted_size as *mut u32, + ) + .unwrap(); + } + return base64::engine::general_purpose::STANDARD.encode(&encrypted_data); + } +} diff --git a/proxy_agent_shared/src/error.rs b/proxy_agent_shared/src/error.rs index 28f22b3b..0f001d18 100644 --- a/proxy_agent_shared/src/error.rs +++ b/proxy_agent_shared/src/error.rs @@ -2,6 +2,8 @@ // SPDX-License-Identifier: MIT use http::StatusCode; +use crate::formatted_error_message::FormattedErrorMessage; + #[derive(Debug, thiserror::Error)] pub enum Error { // windows_service::Error is a custom error type from the windows-service crate @@ -40,6 +42,9 @@ pub enum Error { #[error("{0} command: {1}")] Command(CommandErrorType, String), + + #[error("{0}")] + OtherError(FormattedErrorMessage), } #[derive(Debug, thiserror::Error)] diff --git a/proxy_agent_shared/src/formatted_error_message.rs b/proxy_agent_shared/src/formatted_error_message.rs new file mode 100644 index 00000000..82a5a8c5 --- /dev/null +++ b/proxy_agent_shared/src/formatted_error_message.rs @@ -0,0 +1,134 @@ +use std::{fmt, string::FromUtf8Error}; + +use base64::DecodeError; +use tokio::time::error::Elapsed; + +use crate::error::Error; + +#[derive(Debug, Clone)] +pub struct FormattedErrorMessage { + pub message: String, + pub code: i32, +} + +impl fmt::Display for FormattedErrorMessage { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "code: {}, message: {}", self.code, self.message) + } +} + +impl std::error::Error for FormattedErrorMessage {} + +impl From for FormattedErrorMessage { + fn from(value: DecodeError) -> Self { + FormattedErrorMessage { + message: format!("Decode Error: {value:?}"), + code: -1, + } + } +} + +impl From for FormattedErrorMessage { + fn from(value: FromUtf8Error) -> Self { + FormattedErrorMessage { + message: format!("Utf-8 Convert Error: {value:?}"), + code: -1, + } + } +} + +impl From for FormattedErrorMessage { + fn from(value: serde_json::Error) -> Self { + FormattedErrorMessage { + message: format!("Json Error: {value:?}"), + code: -1, + } + } +} + +impl From for FormattedErrorMessage { + fn from(value: String) -> Self { + FormattedErrorMessage { + message: format!("GeneralError: {value}"), + code: -1, + } + } +} + +impl From for FormattedErrorMessage { + fn from(value: Elapsed) -> Self { + FormattedErrorMessage { + message: format!("Operation timeout: {value}"), + code: -1, + } + } +} + +impl From for Error { + fn from(value: FormattedErrorMessage) -> Self { + Error::OtherError(value) + } +} + +#[cfg(windows)] +impl From for FormattedErrorMessage { + fn from(value: windows::core::Error) -> Self { + FormattedErrorMessage { + message: format!("Windows API Error: {value:?}"), + code: -1, + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use tokio::time::timeout; + + use super::*; + + #[test] + fn formatted_error_display_test() { + let error = FormattedErrorMessage { + message: "An error occurred".to_string(), + code: 404, + }; + assert_eq!( + format!("{}", error), + "code: 404, message: An error occurred" + ); + } + + #[tokio::test] + async fn formatted_error_from_test() { + let decode_error = DecodeError::InvalidLength(0); + let formatted_error: FormattedErrorMessage = decode_error.into(); + assert_eq!(formatted_error.message, "Decode Error: InvalidLength(0)"); + + let utf8_bytes = vec![0, 159, 146, 150]; + let utf8_error = String::from_utf8(utf8_bytes).unwrap_err(); + let formatted_error: FormattedErrorMessage = utf8_error.into(); + assert!(formatted_error.message.starts_with("Utf-8 Convert Error:")); + + let json_error = serde_json::from_str::("invalid json").unwrap_err(); + let formatted_error: FormattedErrorMessage = json_error.into(); + assert!(formatted_error.message.starts_with("Json Error:")); + + let elapsed_error = timeout(Duration::from_millis(10), async { + tokio::time::sleep(Duration::from_secs(1)).await; + }) + .await; + + let elapsed_error = elapsed_error.unwrap_err(); + let formatted_error: FormattedErrorMessage = elapsed_error.into(); + assert!(formatted_error.message.starts_with("Operation timeout")); + + #[cfg(windows)] + { + let windows_error = windows::core::Error::from_win32(); + let formatted_error: FormattedErrorMessage = windows_error.into(); + assert!(formatted_error.message.starts_with("Windows API Error:")); + } + } +} diff --git a/proxy_agent_shared/src/host_clients.rs b/proxy_agent_shared/src/host_clients.rs index 4134ca50..4b261e1c 100644 --- a/proxy_agent_shared/src/host_clients.rs +++ b/proxy_agent_shared/src/host_clients.rs @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT pub mod goal_state; +pub mod hostga_plugin_client; +pub mod hostga_plugin_model; pub mod imds_client; pub mod instance_info; pub mod wire_server_client; diff --git a/proxy_agent_shared/src/host_clients/hostga_plugin_client.rs b/proxy_agent_shared/src/host_clients/hostga_plugin_client.rs new file mode 100644 index 00000000..ce2e1808 --- /dev/null +++ b/proxy_agent_shared/src/host_clients/hostga_plugin_client.rs @@ -0,0 +1,450 @@ +use std::collections::HashMap; +use std::time::Duration; + +use crate::certificate::certificate_helper::{ + decrypt_from_base64, generate_self_signed_certificate, +}; +use crate::error::Error; +use crate::formatted_error_message::FormattedErrorMessage; +use crate::host_clients::hostga_plugin_model::{Certificates, RawCertificatesPayload, VMSettings}; +use crate::hyper_client; +use crate::logger::LoggerLevel; +use crate::result::Result; +use base64::Engine; +use http::{Method, StatusCode, Uri}; +use serde_derive::{Deserialize, Serialize}; +use tokio::time::timeout; +use uuid::Uuid; + +pub struct HostGAPluginClient { + base_url: String, + logger: fn(LoggerLevel, String) -> (), + timeout_in_seconds: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct HostGAPluginResponse { + pub body: Option, + pub etag: Option, + pub certificates_revision: Option, + pub version: Option, +} + +impl HostGAPluginClient { + const CERTIFICATES_URL: &'static str = "certificates"; + const VMSETTINGS_URL: &'static str = "vmSettings"; + const HOSTGAP_CIPHER: &'static str = "AES256_CBC"; + + const ETAG_HEADER: &'static str = "etag"; + const X_MS_SERVER_VERSION_HEADER: &'static str = "x-ms-server-version"; + const X_MS_CERTIFICATES_REVISION_HEADER: &'static str = "x-ms-certificates-revision"; + const TRANSPORT_CERTIFICATE_HEADER: &'static str = "x-ms-guest-agent-public-x509-cert"; + const TRANSPORT_CERTIFICATE_ENCRYPT_CIPHER_HEADER: &'static str = "x-ms-cipher-name"; + + pub fn new( + base_url: &str, + logger: fn(LoggerLevel, String) -> (), + timeout_in_seconds: Option, + ) -> HostGAPluginClient { + HostGAPluginClient { + base_url: base_url.to_string(), + logger, + timeout_in_seconds, + } + } + + pub async fn get_vmsettings( + &self, + etag: Option, + ) -> Result> { + let logger = self.logger; + + logger( + LoggerLevel::Info, + format!("Requesting VMSettings with etag: {etag:?}"), + ); + + let headers = self.vmsettings_request_headers(etag); + + self.get::( + &format!("{}/{}", self.base_url, Self::VMSETTINGS_URL), + &headers, + ) + .await + } + + pub async fn get_certificates( + &self, + cert_revision: u32, + ) -> Result> { + let logger = self.logger; + logger( + LoggerLevel::Info, + format!("Requesting certificates with revision: {cert_revision}"), + ); + + let cert = generate_self_signed_certificate(&Uuid::new_v4().to_string())?; + let cert_der = cert.get_public_cert_der(); + let cert_base64 = base64::engine::general_purpose::STANDARD.encode(cert_der); + + let headers = self.certificate_request_headers(&cert_base64); + + let raw_certs_resp = self + .get::( + &format!( + "{}/{}/{}", + self.base_url, + Self::CERTIFICATES_URL, + cert_revision + ), + &headers, + ) + .await?; + + if let Some(cert_base64) = raw_certs_resp + .body + .as_ref() + .and_then(|body| body.pkcs7_blob_with_pfx_contents.as_ref()) + { + let certs = decrypt_from_base64(cert_base64, &cert)?; + + return Ok(HostGAPluginResponse { + body: Some( + serde_json::from_str::(&certs) + .map_err(FormattedErrorMessage::from)?, + ), + etag: raw_certs_resp.etag.clone(), + certificates_revision: raw_certs_resp.certificates_revision, + version: raw_certs_resp.version.clone(), + }); + } + Err(FormattedErrorMessage { + message: "certificate payload is empty.".to_string(), + code: -1, + } + .into()) + } + + pub async fn get( + &self, + url: &str, + headers: &HashMap, + ) -> Result> + where + for<'a> T: serde::Deserialize<'a>, + { + let logger = self.logger; + let url: Uri = url + .parse::() + .map_err(|e| Error::ParseUrl(url.to_string(), e.to_string()))?; + + let request = hyper_client::build_request(Method::GET, &url, headers, None, None, None)?; + + let (host, port) = hyper_client::host_port_from_uri(&url)?; + + let response = if let Some(timeout_in_seconds) = self.timeout_in_seconds { + timeout( + Duration::from_secs(timeout_in_seconds as u64), + hyper_client::send_request(&host, port, request, move |m| { + logger(LoggerLevel::Warn, m) + }), + ) + .await + .map_err(Into::::into)?? + } else { + hyper_client::send_request(&host, port, request, move |m| logger(LoggerLevel::Warn, m)) + .await? + }; + + let etag = response + .headers() + .get(Self::ETAG_HEADER) + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_string()); + let version = response + .headers() + .get(Self::X_MS_SERVER_VERSION_HEADER) + .and_then(|v| v.to_str().ok()) + .map(|v| v.to_string()); + let certificates_revision = response + .headers() + .get(Self::X_MS_CERTIFICATES_REVISION_HEADER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| s.parse::().ok()); + + let status = response.status(); + if status == StatusCode::NOT_MODIFIED { + let _hostgap_response: HostGAPluginResponse = HostGAPluginResponse { + body: None, + etag, + version, + certificates_revision, + }; + return Ok(_hostgap_response); + } else if status.is_success() { + let body_obj = hyper_client::read_response_body::(response).await?; + return Ok(HostGAPluginResponse { + body: Some(body_obj), + etag, + certificates_revision, + version, + }); + } + let body_string = hyper_client::read_response_body_as_string(response, "utf-8").await?; + Err(FormattedErrorMessage { + code: status.as_u16() as i32, + message: format!("Http Error Status: {}, Body: {}", status, &body_string), + } + .into()) + } + + fn certificate_request_headers(&self, cert: &str) -> HashMap { + let mut headers = HashMap::new(); + headers.insert( + Self::TRANSPORT_CERTIFICATE_HEADER.to_string(), + cert.to_string(), + ); + headers.insert( + Self::TRANSPORT_CERTIFICATE_ENCRYPT_CIPHER_HEADER.to_string(), + Self::HOSTGAP_CIPHER.to_string(), + ); + headers + } + + fn vmsettings_request_headers(&self, etag: Option) -> HashMap { + let mut headers = HashMap::new(); + if let Some(etag) = etag { + headers.insert(Self::ETAG_HEADER.to_string(), etag); + } + headers + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn hostgaplugin_client_creation_test() { + let client = HostGAPluginClient::new( + "http://localhost:8080", + |level, message| { + println!("{:?}: {}", level, message); + }, + None, + ); + assert_eq!(client.base_url, "http://localhost:8080"); + assert_eq!(client.timeout_in_seconds, None); + } + + #[test] + fn certificate_request_headers_test() { + let client = HostGAPluginClient::new( + "http://localhost:8080", + |level, message| { + println!("{:?}: {}", level, message); + }, + None, + ); + let cert = "test_cert"; + let headers = client.certificate_request_headers(cert); + assert_eq!( + headers + .get(HostGAPluginClient::TRANSPORT_CERTIFICATE_HEADER) + .unwrap(), + cert + ); + assert_eq!( + headers + .get(HostGAPluginClient::TRANSPORT_CERTIFICATE_ENCRYPT_CIPHER_HEADER) + .unwrap(), + HostGAPluginClient::HOSTGAP_CIPHER + ); + } + + #[test] + fn vmsettings_request_headers_test() { + let client = HostGAPluginClient::new( + "http://localhost:8080", + |level, message| { + println!("{:?}: {}", level, message); + }, + None, + ); + let etag = Some("test_etag".to_string()); + let headers = client.vmsettings_request_headers(etag.clone()); + assert_eq!( + headers.get(HostGAPluginClient::ETAG_HEADER).unwrap(), + etag.as_ref().unwrap() + ); + + let headers_no_etag = client.vmsettings_request_headers(None); + assert!(headers_no_etag + .get(HostGAPluginClient::ETAG_HEADER) + .is_none()); + } + + #[tokio::test] + async fn get_vmsettings_negative_test() { + let client = HostGAPluginClient::new( + "http://invalid:8080", + |level, message| { + println!("{:?}: {}", level, message); + }, + Some(2), + ); + let response = client.get_vmsettings(None).await; + assert!(response.is_err()); + } + + #[tokio::test] + async fn get_certificates_negative_test() { + let client = HostGAPluginClient::new( + "http://invalid:8080", + |level, message| { + println!("{:?}: {}", level, message); + }, + Some(2), + ); + let response = client.get_certificates(0).await; + assert!(response.is_err()); + } + + #[test] + fn get_hostgaplugin_certificates_response_test() { + let response = r#" + { + "body": { + "activityId": "11111111-1111111-11111111111-111111", + "correlationId": "80e22e3b-3f9a-424e-b300-6cda2dd7e718", + "certificates": [ + { + "name": "TenantEncryptionCert", + "storeName": "My", + "configurationLevel": "System", + "certificateInBase64": "certificateInBase64_test", + "includePrivateKey": false, + "thumbprint": "thumbprint_test", + "certificateBlobFormatType": "PfxInClear" + } + ] + }, + "etag": null, + "certificates_revision": null, + "version": "1.0.8.179" + } + "#; + let resp: HostGAPluginResponse = + serde_json::from_str(response).expect("Deserialize HostGAPluginResponse failed"); + assert!(resp.body.is_some()); + let certs = resp.body.unwrap(); + assert_eq!( + certs.activity_id.unwrap(), + "11111111-1111111-11111111111-111111" + ); + assert_eq!( + certs.correlation_id.unwrap(), + "80e22e3b-3f9a-424e-b300-6cda2dd7e718" + ); + assert!(certs.certificates.is_some()); + let cert_list = certs.certificates.unwrap(); + assert_eq!(cert_list.len(), 1); + let cert = &cert_list[0]; + assert_eq!(cert.name.as_ref().unwrap(), "TenantEncryptionCert"); + assert_eq!( + cert.certificate_in_base64.as_ref().unwrap(), + "certificateInBase64_test" + ); + assert_eq!(cert.thumbprint.as_ref().unwrap(), "thumbprint_test"); + } + + #[test] + fn get_hostgaplugin_vmsettings_response_test() { + let response = r#" + { + "body": { + "hostGAPluginVersion": "1.0.8.179", + "activityId": "1111-11111111-1111-11-1111", + "correlationId": "000000-00000000-000000-0000", + "inSvdSeqNo": 1, + "certificatesRevision": 0, + "extensionsLastModifiedTickCount": 638931417044754873, + "extensionGoalStatesSource": "FastTrack", + "statusUploadBlob": { + "statusBlobType": "PageBlob", + "value": "string" + }, + "gaFamilies": [ + { + "name": "Win7", + "version": "2.7.41491.1176", + "isVersionFromRSM": false, + "isVMEnabledForRSMUpgrades": true, + "uris": [ + "uri" + ] + }, + { + "name": "Win8", + "version": "2.7.41491.1176", + "isVersionFromRSM": false, + "isVMEnabledForRSMUpgrades": true, + "uris": [ + "uri" + ] + } + ], + "extensionGoalStates": [ + { + "name": "extension.test", + "version": "1.0.1", + "location": "location", + "failoverLocation": "location", + "additionalLocations": [ + "location" + ], + "state": "enabled", + "autoUpgrade": true, + "runAsStartupTask": false, + "isJson": true, + "useExactVersion": true, + "settingsSeqNo": 0, + "isMultiConfig": false, + "settings": [ + { + "protectedSettingsCertThumbprint": null, + "protectedSettings": null, + "publicSettings": "{}" + } + ] + } + ] + }, + "etag": "5048704324908356042", + "certificates_revision": 1, + "version": "1.0.8.179" + } + "#; + let resp: HostGAPluginResponse = + serde_json::from_str(response).expect("Deserialize HostGAPluginResponse failed"); + assert!(resp.body.is_some()); + assert_eq!(resp.etag.unwrap(), "5048704324908356042"); + assert_eq!(resp.certificates_revision.unwrap(), 1); + + let vmsettings = resp.body.unwrap(); + assert_eq!( + vmsettings.activity_id.unwrap(), + "1111-11111111-1111-11-1111" + ); + assert_eq!( + vmsettings.extensions_last_modified_tick_count.unwrap(), + 638931417044754873 + ); + assert_eq!(vmsettings.extension_goal_states.as_ref().unwrap().len(), 1); + assert_eq!(vmsettings.ga_families.as_ref().unwrap().len(), 2); + + let extension = &vmsettings.extension_goal_states.as_ref().unwrap()[0]; + assert_eq!(extension.name.as_ref().unwrap(), "extension.test"); + assert_eq!(extension.version.as_ref().unwrap(), "1.0.1"); + } +} diff --git a/proxy_agent_shared/src/host_clients/hostga_plugin_model.rs b/proxy_agent_shared/src/host_clients/hostga_plugin_model.rs new file mode 100644 index 00000000..5e8cf747 --- /dev/null +++ b/proxy_agent_shared/src/host_clients/hostga_plugin_model.rs @@ -0,0 +1,251 @@ +use serde_derive::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +pub struct VMSettings { + #[serde(rename = "hostGAPluginVersion")] + pub host_ga_plugin_version: Option, + #[serde(rename = "activityId")] + pub activity_id: Option, + #[serde(rename = "correlationId")] + pub correlation_id: Option, + #[serde(rename = "inSvdSeqNo")] + pub in_svd_seq_no: Option, + #[serde(rename = "certificatesRevision")] + pub certificates_revision: Option, + #[serde(rename = "extensionsLastModifiedTickCount")] + pub extensions_last_modified_tick_count: Option, + #[serde(rename = "extensionGoalStatesSource")] + pub extension_goal_states_source: Option, + #[serde(rename = "statusUploadBlob")] + pub status_upload_blob: Option, + #[serde(rename = "gaFamilies")] + pub ga_families: Option>, + #[serde(rename = "extensionGoalStates")] + pub extension_goal_states: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct StatusUploadBlob { + #[serde(rename = "statusBlobType")] + pub status_blob_type: Option, + pub value: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct GaFamily { + pub name: Option, + pub version: Option, + #[serde(rename = "isVersionFromRSM")] + pub is_version_from_rsm: Option, + #[serde(rename = "isVMEnabledForRSMUpgrades")] + pub is_vm_enabled_for_rsm_upgrades: Option, + pub uris: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ExtensionGoalState { + pub name: Option, + pub version: Option, + pub location: Option, + #[serde(rename = "failoverLocation")] + pub failover_location: Option, + #[serde(rename = "additionalLocations")] + pub additional_locations: Option>, + pub state: Option, + #[serde(rename = "autoUpgrade")] + pub auto_upgrade: Option, + #[serde(rename = "runAsStartupTask")] + pub run_as_startup_task: Option, + #[serde(rename = "isJson")] + pub is_json: Option, + #[serde(rename = "useExactVersion")] + pub use_exact_version: Option, + #[serde(rename = "settingsSeqNo")] + pub settings_seq_no: Option, + #[serde(rename = "isMultiConfig")] + pub is_multi_config: Option, + pub settings: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Settings { + #[serde(rename = "protectedSettingsCertThumbprint")] + pub protected_settings_cert_thumbprint: Option, + #[serde(rename = "protectedSettings")] + pub protected_settings: Option, + #[serde(rename = "publicSettings")] + pub public_settings: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct RawCertificatesPayload { + #[serde(rename = "Pkcs7BlobWithPfxContents")] + pub pkcs7_blob_with_pfx_contents: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Certificates { + #[serde(rename = "activityId")] + pub activity_id: Option, + + #[serde(rename = "correlationId")] + pub correlation_id: Option, + #[serde(rename = "certificates")] + pub certificates: Option>, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct Certificate { + #[serde(rename = "name")] + pub name: Option, + + #[serde(rename = "storeName")] + pub store_name: Option, + + #[serde(rename = "configurationLevel")] + pub configuration_level: Option, + + #[serde(rename = "certificateInBase64")] + pub certificate_in_base64: Option, + + #[serde(rename = "includePrivateKey")] + pub include_private_key: Option, + + #[serde(rename = "thumbprint")] + pub thumbprint: Option, + + #[serde(rename = "certificateBlobFormatType")] + pub certificate_blob_format_type: Option, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn certificates_deserialization_test() { + let certificates_json = r#" + { + "activityId": "11111111-1111111-11111111111-111111", + "correlationId": "80e22e3b-3f9a-424e-b300-6cda2dd7e718", + "certificates": [ + { + "name": "TenantEncryptionCert", + "storeName": "My", + "configurationLevel": "System", + "certificateInBase64": "certificateInBase64_test", + "includePrivateKey": false, + "thumbprint": "thumbprint_test", + "certificateBlobFormatType": "PfxInClear" + } + ] + } + "#; + let certificates: Certificates = serde_json::from_str(certificates_json).unwrap(); + assert_eq!( + certificates.activity_id.unwrap(), + "11111111-1111111-11111111111-111111" + ); + assert_eq!( + certificates.correlation_id.unwrap(), + "80e22e3b-3f9a-424e-b300-6cda2dd7e718" + ); + assert_eq!(certificates.certificates.as_ref().unwrap().len(), 1); + assert_eq!( + certificates.certificates.as_ref().unwrap()[0] + .certificate_in_base64 + .as_ref() + .unwrap(), + "certificateInBase64_test" + ); + assert_eq!( + certificates.certificates.as_ref().unwrap()[0] + .thumbprint + .as_ref() + .unwrap(), + "thumbprint_test" + ); + } + + #[test] + fn vmsettings_deserialization_test() { + let vmsettings_json = r#" + { + "hostGAPluginVersion": "1.0.8.179", + "activityId": "1111-11111111-1111-11-1111", + "correlationId": "000000-00000000-000000-0000", + "inSvdSeqNo": 1, + "certificatesRevision": 0, + "extensionsLastModifiedTickCount": 638931417044754873, + "extensionGoalStatesSource": "FastTrack", + "statusUploadBlob": { + "statusBlobType": "PageBlob", + "value": "string" + }, + "gaFamilies": [ + { + "name": "Win7", + "version": "2.7.41491.1176", + "isVersionFromRSM": false, + "isVMEnabledForRSMUpgrades": true, + "uris": [ + "uri" + ] + }, + { + "name": "Win8", + "version": "2.7.41491.1176", + "isVersionFromRSM": false, + "isVMEnabledForRSMUpgrades": true, + "uris": [ + "uri" + ] + } + ], + "extensionGoalStates": [ + { + "name": "test", + "version": "1.0.1", + "location": "location", + "failoverLocation": "location", + "additionalLocations": [ + "location" + ], + "state": "enabled", + "autoUpgrade": true, + "runAsStartupTask": false, + "isJson": true, + "useExactVersion": true, + "settingsSeqNo": 0, + "isMultiConfig": false, + "settings": [ + { + "protectedSettingsCertThumbprint": null, + "protectedSettings": null, + "publicSettings": "{}" + } + ] + } + ] + }"#; + + let vmsettings: VMSettings = serde_json::from_str(vmsettings_json).unwrap(); + assert_eq!(vmsettings.host_ga_plugin_version.unwrap(), "1.0.8.179"); + assert_eq!( + vmsettings.activity_id.unwrap(), + "1111-11111111-1111-11-1111" + ); + assert_eq!( + vmsettings.correlation_id.unwrap(), + "000000-00000000-000000-0000" + ); + assert_eq!(vmsettings.in_svd_seq_no.unwrap(), 1); + assert_eq!(vmsettings.certificates_revision.unwrap(), 0); + assert_eq!( + vmsettings.extensions_last_modified_tick_count.unwrap(), + 638931417044754873 + ); + assert_eq!(vmsettings.ga_families.as_ref().unwrap().len(), 2); + assert_eq!(vmsettings.extension_goal_states.as_ref().unwrap().len(), 1); + } +} diff --git a/proxy_agent_shared/src/hyper_client.rs b/proxy_agent_shared/src/hyper_client.rs index 6d9224ce..ed368b96 100644 --- a/proxy_agent_shared/src/hyper_client.rs +++ b/proxy_agent_shared/src/hyper_client.rs @@ -88,9 +88,7 @@ where read_response_body(response).await } -pub async fn read_response_body( - mut response: hyper::Response, -) -> Result +pub async fn read_response_body(response: hyper::Response) -> Result where T: DeserializeOwned, { @@ -128,7 +126,37 @@ where } else { ("unknown", "unknown") }; + let body_string = read_response_body_as_string(response, charset_type).await?; + match content_type { + "xml" => match serde_xml_rs::from_str(&body_string) { + Ok(t) => Ok(t), + Err(e) => Err(Error::Hyper( + HyperErrorType::Deserialize( + format!( + "Failed to xml deserialize response body with content_type {content_type} from: {body_string} with error {e}" + ) + ), + )), + }, + // default to json + _ => match serde_json::from_str(&body_string) { + Ok(t) => Ok(t), + Err(e) => Err(Error::Hyper( + HyperErrorType::Deserialize( + format!( + "Failed to json deserialize response body with {content_type} from: {body_string} with error {e}" + ) + ), + )), + }, + } +} + +pub async fn read_response_body_as_string( + mut response: hyper::Response, + charset_type: &str, +) -> Result { let mut body_string = String::new(); while let Some(next) = response.frame().await { let frame = match next { @@ -166,29 +194,7 @@ where } } - match content_type { - "xml" => match serde_xml_rs::from_str(&body_string) { - Ok(t) => Ok(t), - Err(e) => Err(Error::Hyper( - HyperErrorType::Deserialize( - format!( - "Failed to xml deserialize response body with content_type {content_type} from: {body_string} with error {e}" - ) - ), - )), - }, - // default to json - _ => match serde_json::from_str(&body_string) { - Ok(t) => Ok(t), - Err(e) => Err(Error::Hyper( - HyperErrorType::Deserialize( - format!( - "Failed to json deserialize response body with {content_type} from: {body_string} with error {e}" - ) - ), - )), - }, - } + Ok(body_string) } pub fn build_request( diff --git a/proxy_agent_shared/src/lib.rs b/proxy_agent_shared/src/lib.rs index 3d75e50a..5685798b 100644 --- a/proxy_agent_shared/src/lib.rs +++ b/proxy_agent_shared/src/lib.rs @@ -1,9 +1,11 @@ // Copyright (c) Microsoft Corporation // SPDX-License-Identifier: MIT +pub mod certificate; pub mod error; #[cfg(windows)] pub mod etw; +pub mod formatted_error_message; pub mod host_clients; pub mod hyper_client; pub mod logger;