diff --git a/Cargo.toml b/Cargo.toml index c303ddb..9cd9655 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -59,6 +59,8 @@ chrono = { version = "0.4.31", optional = true, default-features = false, featur chrono-tz = { version = ">=0.6, <0.11.0", optional = true } chronoutil = { version = "0.2", optional = true } duration-str = { version = ">=0.11, <0.16", optional = true, default-features = false } +http = { version = "1.3.1", optional = true } +reqwest = { version = "0.12", features = ["json"], optional = true } # used for tests [dev-dependencies.tokio] version = "1.5" @@ -69,6 +71,7 @@ wasmtime = { version = ">=22, <34", default-features = false, features = [ "cranelift", ] } insta = { version = "1", features = ["yaml"] } +httpmock = "0.7.0" [build-dependencies] # We would like at least this version of rayon, because older versions depend on old rand, @@ -103,6 +106,7 @@ fast = ["wasmtime/cranelift", "wasmtime/parallel-compilation"] rng = ["dep:rand"] time = ["dep:chrono"] +http = ["dep:http"] base64url-builtins = ["dep:base64", "dep:hex"] crypto-digest-builtins = ["dep:digest", "dep:hex"] @@ -117,6 +121,7 @@ json-builtins = ["dep:json-patch"] units-builtins = ["dep:parse-size"] rand-builtins = ["rng"] yaml-builtins = ["dep:serde_yaml"] +http-builtins = ["http", "dep:serde_yaml", "dep:duration-str"] urlquery-builtins = ["dep:form_urlencoded", "dep:urlencoding"] time-builtins = ["time", "dep:chrono-tz", "dep:duration-str", "dep:chronoutil"] @@ -138,13 +143,16 @@ all-builtins = [ "sprintf-builtins", "units-builtins", "yaml-builtins", + "http-builtins", "urlquery-builtins", "time-builtins", ] +testing = ["http", "dep:reqwest"] + [[test]] name = "smoke_test" -required-features = ["loader"] +required-features = ["loader", "http-builtins", "testing"] [[bin]] name = "opa-eval" diff --git a/features.txt b/features.txt index a609839..9f11d0d 100644 --- a/features.txt +++ b/features.txt @@ -2,6 +2,7 @@ loader cli rng +http base64url-builtins crypto-digest-builtins 
crypto-md5-builtins crypto-digest-builtins crypto-sha1-builtins @@ -16,6 +17,8 @@ json-builtins units-builtins rand-builtins yaml-builtins +http-builtins time-builtins all-crypto-builtins all-builtins +testing diff --git a/src/builtins/impls/http.rs b/src/builtins/impls/http.rs index 097f1d3..0586b3a 100644 --- a/src/builtins/impls/http.rs +++ b/src/builtins/impls/http.rs @@ -14,10 +14,209 @@ //! Builtins used to make HTTP request -use anyhow::{bail, Result}; +use std::{collections::HashMap, future::Future, pin::Pin, time::Duration}; + +use anyhow::{Context, Result}; +use serde_json::{self, Map}; +use serde_yaml; +use tokio::time::sleep; + +use crate::{builtins::traits::Builtin, EvaluationContext}; + +/// This builtin is needed because the wrapper in traits.rs doesn't work when +/// dealing with async+context. +pub struct HttpSendBuiltin; + +impl Builtin for HttpSendBuiltin +where + C: EvaluationContext, +{ + fn call<'a>( + &'a self, + context: &'a mut C, + args: &'a [&'a [u8]], + ) -> Pin, anyhow::Error>> + Send + 'a>> { + Box::pin(async move { + let [opa_req]: [&'a [u8]; 1] = args.try_into().ok().context("invalid arguments")?; + let opa_req: serde_json::Value = + serde_json::from_slice(opa_req).context("failed to convert opa_req argument")?; + let res = send(context, opa_req).await?; + let res = serde_json::to_vec(&res).context("could not serialize result")?; + Ok(res) + }) + } +} /// Returns a HTTP response to the given HTTP request. -#[tracing::instrument(name = "http.send", err)] -pub fn send(request: serde_json::Value) -> Result { - bail!("not implemented"); +/// +/// Wraps [`internal_send`] to add error handling regarding the `raise_error` +/// field in the OPA request. 
+///
+/// # Errors
+///
+/// Propagates the error from [`internal_send`] unless the request sets
+/// `raise_error` to `false`, in which case the failure is returned in-band as
+/// `{"status_code": 0, "error": {"message": ...}}`.
+#[tracing::instrument(name = "http.send", skip(ctx), err)]
+pub async fn send<C: EvaluationContext>(
+    ctx: &mut C,
+    opa_req: serde_json::Value,
+) -> Result<serde_json::Value> {
+    // Read the flag before `opa_req` is moved into `internal_send`. Anything
+    // other than an explicit boolean keeps the default of raising.
+    let raise_error = opa_req
+        .get("raise_error")
+        .and_then(serde_json::Value::as_bool)
+        .unwrap_or(true);
+
+    match internal_send(ctx, opa_req).await {
+        Err(e) if !raise_error => Ok(serde_json::json!({
+            "status_code": 0,
+            "error": { "message": e.to_string() },
+        })),
+        outcome => outcome,
+    }
+}
+
+/// Sends a HTTP request and returns the response.
+///
+/// Besides the fields consumed by [`convert_opa_req_to_http_req`], the
+/// following request fields are honoured:
+///
+/// - `timeout`: a number of nanoseconds, or a duration string understood by
+///   `duration_str`. A value of any other shape, or a string that does not
+///   parse, is passed to the context as `None`.
+/// - `enable_redirect`: forwarded to the context when it is a boolean.
+/// - `max_retry_attempts`: how many times a failed request is retried
+///   (default `0`), so at most `max_retry_attempts + 1` requests are sent.
+/// - `force_json_decode` / `force_yaml_decode`: forwarded to
+///   [`convert_http_resp_to_opa_resp`] (default `false`).
+///
+/// # Errors
+///
+/// Fails when the request is not a JSON object, when it cannot be converted
+/// into an HTTP request, or when every attempt to send it failed; in the
+/// latter case the error of the last attempt is returned.
+async fn internal_send<C: EvaluationContext>(
+    ctx: &mut C,
+    opa_req: serde_json::Value,
+) -> Result<serde_json::Value> {
+    let opa_req = opa_req
+        .as_object()
+        .ok_or_else(|| anyhow::anyhow!("request must be a JSON object"))?;
+
+    let http_req = convert_opa_req_to_http_req(opa_req)?;
+
+    let timeout = opa_req.get("timeout").and_then(|value| {
+        value.as_u64().map(Duration::from_nanos).or_else(|| {
+            value
+                .as_str()
+                .and_then(|text| duration_str::parse(text).ok())
+        })
+    });
+
+    let enable_redirect = opa_req
+        .get("enable_redirect")
+        .and_then(serde_json::Value::as_bool);
+
+    let max_retry_attempts = opa_req
+        .get("max_retry_attempts")
+        .and_then(serde_json::Value::as_u64)
+        .unwrap_or(0);
+
+    let mut attempt: u64 = 0;
+    let http_resp = loop {
+        let outcome = ctx
+            .send_http(http_req.clone(), timeout, enable_redirect)
+            .await;
+        match outcome {
+            Ok(resp) => break resp,
+            // Out of retries: report the failure right away instead of
+            // backing off once more before returning it.
+            Err(err) if attempt >= max_retry_attempts => return Err(err),
+            Err(_) => {}
+        }
+        sleep(retry_backoff(attempt)).await;
+        attempt += 1;
+    };
+
+    let force_json_decode = opa_req
+        .get("force_json_decode")
+        .and_then(serde_json::Value::as_bool)
+        .unwrap_or(false);
+    let force_yaml_decode = opa_req
+        .get("force_yaml_decode")
+        .and_then(serde_json::Value::as_bool)
+        .unwrap_or(false);
+
+    Ok(convert_http_resp_to_opa_resp(
+        http_resp,
+        force_json_decode,
+        force_yaml_decode,
+    ))
+}
+
+/// Returns the delay to wait after the failed attempt number `attempt`
+/// (zero-based) before trying again: 500 ms doubled for every previous
+/// attempt, capped at 60 s.
+fn retry_backoff(attempt: u64) -> Duration {
+    const BASE_DELAY_MS: u64 = 500;
+    const MAX_DELAY_MS: u64 = 60_000;
+
+    // `max_retry_attempts` comes from the policy, so the exponent is not
+    // bounded: saturate instead of overflowing `u64` or truncating the cast.
+    let factor = u32::try_from(attempt)
+        .ok()
+        .and_then(|exp| 2_u64.checked_pow(exp))
+        .unwrap_or(u64::MAX);
+    Duration::from_millis(BASE_DELAY_MS.saturating_mul(factor).min(MAX_DELAY_MS))
+}
+
+/// Renders a JSON value as plain text.
+///
+/// Strings are returned as-is, without the surrounding quotes and escaping
+/// that their JSON encoding would add; any other value falls back to its
+/// compact JSON encoding.
+fn json_to_plain_text(value: &serde_json::Value) -> String {
+    value
+        .as_str()
+        .map_or_else(|| value.to_string(), str::to_owned)
+}
+
+/// Converts an OPA request to an HTTP request.
+///
+/// Reads the following fields of the request object:
+///
+/// - `url` (required string): the request URI.
+/// - `method` (required string): upper-cased before use.
+/// - `headers` (optional object): one header per entry; string values are
+///   sent verbatim, other values as their JSON encoding.
+/// - `body` (optional): sent as its JSON encoding; takes precedence over
+///   `raw_body`.
+/// - `raw_body` (optional): sent verbatim when it is a string. Without
+///   `body` and `raw_body` the request body is empty.
+///
+/// # Errors
+///
+/// Fails when `url` or `method` is missing or not a string, or when the
+/// `http` request builder rejects the method, URI or a header.
+fn convert_opa_req_to_http_req(
+    opa_req: &Map<String, serde_json::Value>,
+) -> Result<http::Request<String>> {
+    let url = opa_req
+        .get("url")
+        .ok_or_else(|| anyhow::anyhow!("missing url"))?
+        .as_str()
+        .ok_or_else(|| anyhow::anyhow!("url must be a string"))?;
+    let method = opa_req
+        .get("method")
+        .ok_or_else(|| anyhow::anyhow!("missing method"))?
+        .as_str()
+        .map(str::to_uppercase)
+        .ok_or_else(|| anyhow::anyhow!("method must be a string"))?;
+
+    let mut req_builder = http::Request::builder().method(method.as_str()).uri(url);
+    if let Some(headers) = opa_req
+        .get("headers")
+        .and_then(serde_json::Value::as_object)
+    {
+        for (name, value) in headers {
+            // `Value::to_string` would JSON-encode a string value and send the
+            // quotes on the wire (`"text/plain"` instead of `text/plain`).
+            req_builder = req_builder.header(name, json_to_plain_text(value));
+        }
+    }
+
+    let body = match opa_req.get("body") {
+        Some(json_body) => json_body.to_string(),
+        None => opa_req
+            .get("raw_body")
+            .map(json_to_plain_text)
+            .unwrap_or_default(),
+    };
+
+    Ok(req_builder.body(body)?)
+}
+
+/// Converts an HTTP response to an OPA response.
+///
+/// The result always carries `status_code`, `headers` and `raw_body`. A
+/// decoded `body` is added when the payload parses as JSON (decoding forced,
+/// or the media type is `application/json`) or, failing those conditions, as
+/// YAML (decoding forced, or the media type is `application/yaml` or
+/// `application/x-yaml`). A payload that does not parse leaves `body` unset.
+fn convert_http_resp_to_opa_resp(
+    response: http::Response<String>,
+    force_json_decode: bool,
+    force_yaml_decode: bool,
+) -> serde_json::Value {
+    let response_headers = response
+        .headers()
+        .iter()
+        // A header value that cannot be viewed as a string is reported empty.
+        .map(|(k, v)| (k.as_str(), v.to_str().unwrap_or("")))
+        .collect::<HashMap<&str, &str>>();
+
+    // Compare the media type only: parameters such as `; charset=utf-8` are
+    // dropped and the comparison ignores ASCII case, so that
+    // `application/json; charset=utf-8` is still decoded.
+    let media_type = response
+        .headers()
+        .get("content-type")
+        .and_then(|v| v.to_str().ok())
+        .and_then(|v| v.split(';').next())
+        .map(str::trim)
+        .unwrap_or_default();
+    let has_media_type = |candidates: &[&str]| {
+        candidates
+            .iter()
+            .any(|candidate| media_type.eq_ignore_ascii_case(candidate))
+    };
+
+    let raw_body = response.body().as_str();
+    let parsed_body = if force_json_decode || has_media_type(&["application/json"]) {
+        serde_json::from_str::<serde_json::Value>(raw_body).ok()
+    } else if force_yaml_decode
+        || has_media_type(&["application/yaml", "application/x-yaml"])
+    {
+        serde_yaml::from_str::<serde_json::Value>(raw_body).ok()
+    } else {
+        None
+    };
+
+    let mut opa_resp = serde_json::json!({
+        "status_code": response.status().as_u16(),
+        "headers": response_headers,
+        "raw_body": raw_body,
+    });
+    if let Some(parsed_body) = parsed_body {
+        opa_resp["body"] = parsed_body;
+    }
+
+    opa_resp
 }
diff --git a/src/builtins/impls/mod.rs b/src/builtins/impls/mod.rs index 8013b1a..4f228b0 100644 --- a/src/builtins/impls/mod.rs +++ b/src/builtins/impls/mod.rs @@ -27,6 +27,7 @@ pub mod graph; pub mod graphql; #[cfg(feature = "hex-builtins")] pub mod hex; +#[cfg(feature = "http-builtins")] pub mod http; pub mod io; #[cfg(feature = "json-builtins")] diff --git a/src/builtins/impls/rand.rs b/src/builtins/impls/rand.rs index dbedd2a..d949ace 100644 --- a/src/builtins/impls/rand.rs +++ b/src/builtins/impls/rand.rs @@ -35,7 +35,7 @@ pub fn intn(ctx: &mut C, str: String, n: i64) -> Result(name: &str) -> Result>> #[cfg(feature = "hex-builtins")] "hex.encode" => Ok(self::impls::hex::encode.wrap()), - "http.send" => Ok(self::impls::http::send.wrap()), + #[cfg(feature = "http-builtins")] + "http.send" => Ok(Box::new(self::impls::http::HttpSendBuiltin)), "indexof_n" => Ok(self::impls::indexof_n.wrap()), "io.jwt.decode" => 
Ok(self::impls::io::jwt::decode.wrap()), "io.jwt.decode_verify" => Ok(self::impls::io::jwt::decode_verify.wrap()), diff --git a/src/context.rs b/src/context.rs index 110cb0f..746160a 100644 --- a/src/context.rs +++ b/src/context.rs @@ -17,6 +17,8 @@ #![allow(clippy::module_name_repetitions)] use std::collections::HashMap; +#[cfg(feature = "http")] +use std::time::Duration; use anyhow::Result; #[cfg(feature = "time")] @@ -37,6 +39,15 @@ pub trait EvaluationContext: Send + 'static { #[cfg(feature = "time")] fn now(&self) -> chrono::DateTime; + /// Send an HTTP request + #[cfg(feature = "http")] + fn send_http( + &self, + req: http::Request, + timeout: Option, + enable_redirect: Option, + ) -> impl std::future::Future>> + Send + Sync; + /// Notify the context on evaluation start, so it can clean itself up fn evaluation_start(&mut self); @@ -91,6 +102,21 @@ impl EvaluationContext for DefaultContext { self.evaluation_time } + #[cfg(feature = "http")] + fn send_http( + &self, + _req: http::Request, + _timeout: Option, + _enable_redirect: Option, + ) -> impl std::future::Future>> + Send + Sync { + // This is a stub implementation. Default context does not implement + // actual HTTP requests due to security reasons - HTTP calls from policy + // should be explicitly allowed/moderated by the integration. + // For an example of a context that does implement HTTP requests, see + // the `TestContext` in the `tests` module below. 
+ Box::pin(async { anyhow::bail!("http.send not implemented in DefaultContext") }) + } + fn evaluation_start(&mut self) { // Clear the cache self.cache = HashMap::new(); @@ -121,14 +147,33 @@ impl EvaluationContext for DefaultContext { } /// Test utilities +#[cfg(feature = "testing")] pub mod tests { use anyhow::Result; #[cfg(feature = "time")] use chrono::TimeZone; + #[cfg(feature = "http")] + use reqwest; use serde::{de::DeserializeOwned, Serialize}; + #[cfg(feature = "http")] + use std::time::Duration; use crate::{DefaultContext, EvaluationContext}; + /// Builds a [`reqwest::Client`] with the given timeout and redirect policy. + #[cfg(feature = "http")] + fn build_reqwest_client(timeout: Duration, enable_redirect: bool) -> reqwest::Client { + let mut client_builder = reqwest::Client::builder(); + client_builder = client_builder.timeout(timeout); + client_builder = client_builder.redirect(if enable_redirect { + reqwest::redirect::Policy::default() + } else { + reqwest::redirect::Policy::none() + }); + #[allow(clippy::unwrap_used)] + client_builder.build().unwrap() + } + /// A context used in tests pub struct TestContext { /// The inner [`DefaultContext`] @@ -182,6 +227,33 @@ pub mod tests { rand::rngs::StdRng::seed_from_u64(self.seed) } + #[cfg(feature = "http")] + async fn send_http( + &self, + req: http::Request, + timeout: Option, + enable_redirect: Option, + ) -> Result> { + let client = build_reqwest_client( + timeout.unwrap_or(Duration::from_secs(5)), + enable_redirect.unwrap_or(false), + ); + + let response: reqwest::Response = + client.execute(reqwest::Request::try_from(req)?).await?; + + let mut builder = http::Response::builder().status(response.status()); + for (name, value) in response.headers() { + builder = builder.header(name, value); + } + + let bytes_body = response.bytes().await?; + let string_body = String::from_utf8(bytes_body.to_vec())?; + builder + .body(string_body) + .map_err(|e| anyhow::anyhow!("Failed to build response: {}", e)) + } + fn 
cache_get(&mut self, key: &K) -> Result> { self.inner.cache_get(key) } diff --git a/src/lib.rs b/src/lib.rs index ea0c33a..7b0f561 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,7 +41,9 @@ pub use wasmtime; #[cfg(feature = "loader")] pub use self::loader::{load_bundle, read_bundle}; pub use self::{ - context::{tests::TestContext, DefaultContext, EvaluationContext}, + context::{DefaultContext, EvaluationContext}, policy::{Policy, Runtime}, types::AbiVersion, }; +#[cfg(feature = "testing")] +pub use context::tests::TestContext; diff --git a/tests/fixtures/test-http.rego b/tests/fixtures/test-http.rego new file mode 100644 index 0000000..d0f9102 --- /dev/null +++ b/tests/fixtures/test-http.rego @@ -0,0 +1,15 @@ +package fixtures + +import rego.v1 + +# Test that automatic ser/der is working fine for request and response +get_json := http.send({"url": sprintf("%s/json", [input.base_url]), "method": "get"}) +get_yaml := http.send({"url": sprintf("%s/yaml", [input.base_url]), "method": "get"}) +post_json := http.send({"url": sprintf("%s/post", [input.base_url]), "method": "post", "body": {"key": "value"}}) + +# Test a connection error doesn't error out the whole policy when using raise_error=false +get_no_conn := http.send({"url": "https://cahbe8ang5umaiwavai1shuchiehae7u.com", "method": "get", "raise_error": false}) + +# Test automatic redirection +get_redirect := http.send({"url": sprintf("%s/redirect", [input.base_url]), "method": "get"}) +get_redirect_follow := http.send({"url": sprintf("%s/redirect", [input.base_url]), "method": "get", "enable_redirect": true}) diff --git a/tests/smoke_test.rs b/tests/smoke_test.rs index e728501..d9a28ac 100644 --- a/tests/smoke_test.rs +++ b/tests/smoke_test.rs @@ -23,7 +23,7 @@ macro_rules! 
integration_test { ($name:ident, $suite:expr) => { #[tokio::test] async fn $name() { - assert_yaml_snapshot!(test_policy($suite, None) + assert_yaml_snapshot!(test_policy_with_datafile($suite, None) .await .expect("error in test suite")); } @@ -31,7 +31,7 @@ macro_rules! integration_test { ($name:ident, $suite:expr, input = $input:expr) => { #[tokio::test] async fn $name() { - assert_yaml_snapshot!(test_policy($suite, Some($input)) + assert_yaml_snapshot!(test_policy_with_datafile($suite, Some($input)) .await .expect("error in test suite")); } @@ -82,13 +82,25 @@ fn input(name: &str) -> String { .into() } -async fn test_policy(bundle_name: &str, data: Option<&str>) -> AnyResult { - let input = if let Some(data) = data { - let input_bytes = tokio::fs::read(input(&format!("{}.json", data))).await?; - serde_json::from_slice(&input_bytes[..])? - } else { - serde_json::Value::Object(serde_json::Map::default()) +async fn test_policy_with_datafile( + bundle_name: &str, + datafile_path: Option<&str>, +) -> AnyResult { + let input = match datafile_path { + Some(path) => { + let bytes = tokio::fs::read(input(&format!("{}.json", path))).await?; + Some(serde_json::from_slice(&bytes[..])?) 
+ } + None => None, }; + test_policy(bundle_name, input).await +} + +async fn test_policy( + bundle_name: &str, + data: Option, +) -> AnyResult { + let input = data.unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::default())); eval_policy( &bundle(&format!("{}.rego.tar.gz", bundle_name)), "fixtures", @@ -122,6 +134,95 @@ integration_test!(test_yaml, "test-yaml"); integration_test!(test_urlquery, "test-urlquery"); integration_test!(test_time, "test-time"); +#[tokio::test] +async fn test_http() { + let server = httpmock::MockServer::start(); + + let content_value: serde_json::Value = serde_json::json!({ "key": "value" }); + + let get_json_mock = server.mock(|when, then| { + when.method(httpmock::Method::GET).path("/json"); + then.status(200) + .header("Content-Type", "application/json") + .body(content_value.to_string()); + }); + + let get_yaml_mock = server.mock(|when, then| { + when.method(httpmock::Method::GET).path("/yaml"); + then.status(200) + .header("Content-Type", "application/yaml") + .body(serde_yaml::to_string(&content_value).unwrap()); + }); + + let post_json_mock = server.mock(|when, then| { + when.method(httpmock::Method::POST) + .path("/post") + .json_body(content_value.clone()); + then.status(200); + }); + + let redirect_json_mock = server.mock(|when, then| { + when.method(httpmock::Method::GET).path("/redirect"); + then.status(302).header("Location", "/target"); + }); + + let target_json_mock = server.mock(|when, then| { + when.method(httpmock::Method::GET).path("/target"); + then.status(200); + }); + + let res = test_policy( + "test-http", + Some(serde_json::json!({"base_url": server.url("")})), + ) + .await + .expect("error in test suite"); + + let result = res.as_array().unwrap()[0] + .as_object() + .unwrap() + .get("result") + .unwrap() + .as_object() + .unwrap(); + + get_json_mock.assert(); + let get_json_res = result.get("get_json").unwrap().as_object().unwrap(); + assert_eq!( + get_json_res.get("raw_body").unwrap(), + 
&content_value.to_string() + ); + assert_eq!(get_json_res.get("body").unwrap(), &content_value); + + get_yaml_mock.assert(); + let get_yaml_res = result.get("get_yaml").unwrap().as_object().unwrap(); + assert_eq!( + get_yaml_res.get("raw_body").unwrap(), + &serde_yaml::to_string(&content_value).unwrap() + ); + assert_eq!(get_yaml_res.get("body").unwrap(), &content_value); + + post_json_mock.assert(); + let post_json_res = result.get("post_json").unwrap().as_object().unwrap(); + assert_eq!(post_json_res.get("status_code").unwrap(), &200); + + let get_no_conn_res = result.get("get_no_conn").unwrap().as_object().unwrap(); + assert_eq!(get_no_conn_res.get("status_code").unwrap(), &0); + + redirect_json_mock.assert_hits(2); + target_json_mock.assert(); + + let get_redirect_res = result.get("get_redirect").unwrap().as_object().unwrap(); + assert_eq!(get_redirect_res.get("status_code").unwrap(), &302); + + let get_redirect_follow_res = result + .get("get_redirect_follow") + .unwrap() + .as_object() + .unwrap(); + assert_eq!(get_redirect_follow_res.get("status_code").unwrap(), &200); +} + /* #[tokio::test] async fn test_uuid() {