|
| 1 | +use aws_sdk_s3::primitives::ByteStream; |
| 2 | +use aws_config::{BehaviorVersion, Region}; |
| 3 | +use rand::RngCore; |
| 4 | +use rand_chacha::rand_core::OsRng; |
| 5 | +use crate::tokio_block_on; |
| 6 | + |
| 7 | +pub fn get_test_region() -> String { |
| 8 | + std::env::var("S3_REGION").expect("Set S3_REGION to run integration tests") |
| 9 | +} |
| 10 | + |
| 11 | +pub fn get_standard_bucket() -> String { |
| 12 | + std::env::var("S3_BUCKET_NAME").expect("Set S3_BUCKET_NAME to run integration tests") |
| 13 | +} |
| 14 | + |
| 15 | +pub fn get_test_bucket_forbidden() -> String { |
| 16 | + std::env::var("S3_FORBIDDEN_BUCKET_NAME").expect("Set S3_FORBIDDEN_BUCKET_NAME to run integration tests") |
| 17 | +} |
| 18 | + |
| 19 | +pub fn get_test_kms_key_id() -> String { |
| 20 | + std::env::var("KMS_TEST_KEY_ID").expect("Set KMS_TEST_KEY_ID to run integration tests") |
| 21 | +} |
| 22 | + |
| 23 | +// Get a region other than what configured in S3_REGION |
| 24 | +pub fn get_non_test_region() -> String { |
| 25 | + match get_test_region().as_str() { |
| 26 | + "us-east-1" => String::from("us-west-2"), |
| 27 | + _ => String::from("us-east-1"), |
| 28 | + } |
| 29 | +} |
| 30 | + |
| 31 | +/// Optional config for testing against a custom endpoint url |
| 32 | +pub fn get_test_endpoint_url() -> Option<String> { |
| 33 | + if cfg!(feature = "s3express_tests") { |
| 34 | + std::env::var("S3_EXPRESS_ONE_ZONE_ENDPOINT_URL") |
| 35 | + .ok() |
| 36 | + .filter(|str| !str.is_empty()) |
| 37 | + } else { |
| 38 | + std::env::var("S3_ENDPOINT_URL").ok().filter(|str| !str.is_empty()) |
| 39 | + } |
| 40 | +} |
| 41 | + |
| 42 | +pub async fn get_test_sdk_client(region: &str) -> aws_sdk_s3::Client { |
| 43 | + let mut sdk_config = aws_config::defaults(BehaviorVersion::latest()).region(Region::new(region.to_owned())); |
| 44 | + if let Some(endpoint_url) = get_test_endpoint_url() { |
| 45 | + sdk_config = sdk_config.endpoint_url(endpoint_url); |
| 46 | + } |
| 47 | + aws_sdk_s3::Client::new(&sdk_config.load().await) |
| 48 | +} |
| 49 | + |
| 50 | +pub fn create_objects(bucket: &str, prefix: &str, region: &str, key: &str, value: &[u8]) { |
| 51 | + let sdk_client = tokio_block_on(get_test_sdk_client(region)); |
| 52 | + let full_key = format!("{prefix}{key}"); |
| 53 | + tokio_block_on( |
| 54 | + sdk_client |
| 55 | + .put_object() |
| 56 | + .bucket(bucket) |
| 57 | + .key(full_key) |
| 58 | + .body(ByteStream::from(value.to_vec())) |
| 59 | + .send(), |
| 60 | + ) |
| 61 | + .unwrap(); |
| 62 | +} |
| 63 | + |
| 64 | +pub fn get_test_bucket_and_prefix(test_name: &str) -> (String, String) { |
| 65 | + let bucket = get_test_bucket(); |
| 66 | + let prefix = get_test_prefix(test_name); |
| 67 | + |
| 68 | + (bucket, prefix) |
| 69 | +} |
| 70 | + |
| 71 | +pub fn get_test_prefix(test_name: &str) -> String { |
| 72 | + // Generate a random nonce to make sure this prefix is truly unique |
| 73 | + let nonce = OsRng.next_u64(); |
| 74 | + |
| 75 | + // Prefix always has a trailing "/" to keep meaning in sync with the S3 API. |
| 76 | + let prefix = std::env::var("S3_BUCKET_TEST_PREFIX").unwrap_or(String::from("mountpoint-test/")); |
| 77 | + assert!(prefix.ends_with('/'), "S3_BUCKET_TEST_PREFIX should end in '/'"); |
| 78 | + |
| 79 | + format!("{prefix}{test_name}/{nonce}/") |
| 80 | +} |
| 81 | + |
| 82 | +pub fn get_test_bucket() -> String { |
| 83 | + #[cfg(not(feature = "s3express_tests"))] |
| 84 | + return get_standard_bucket(); |
| 85 | + #[cfg(feature = "s3express_tests")] |
| 86 | + return get_express_bucket(); |
| 87 | +} |
0 commit comments