diff --git a/Cargo.lock b/Cargo.lock index 759a94a031..87bf8aba8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2173,6 +2173,40 @@ dependencies = [ "sha2", ] +[[package]] +name = "deadpool" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ed5957ff93768adf7a65ab167a17835c3d2c3c50d084fe305174c112f468e2f" +dependencies = [ + "deadpool-runtime", + "num_cpus", + "tokio", +] + +[[package]] +name = "deadpool-postgres" +version = "0.14.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d697d376cbfa018c23eb4caab1fd1883dd9c906a8c034e8d9a3cb06a7e0bef9" +dependencies = [ + "async-trait", + "deadpool", + "getrandom 0.2.15", + "tokio", + "tokio-postgres", + "tracing", +] + +[[package]] +name = "deadpool-runtime" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "092966b41edc516079bdf31ec78a2e0588d1d0c08f78b91d8307215928642b2b" +dependencies = [ + "tokio", +] + [[package]] name = "debugid" version = "0.8.0" @@ -3301,6 +3335,20 @@ dependencies = [ "seq-macro", ] +[[package]] +name = "generator" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d18470a76cb7f8ff746cf1f7470914f900252ec36bbc40b569d74b1258446827" +dependencies = [ + "cc", + "cfg-if", + "libc", + "log", + "rustversion", + "windows 0.61.3", +] + [[package]] name = "generic-array" version = "0.14.7" @@ -3561,7 +3609,7 @@ dependencies = [ "dirs 4.0.0", "gix-path", "libc", - "windows", + "windows 0.43.0", ] [[package]] @@ -4084,7 +4132,7 @@ dependencies = [ "iana-time-zone-haiku", "js-sys", "wasm-bindgen", - "windows-core", + "windows-core 0.52.0", ] [[package]] @@ -4927,6 +4975,19 @@ dependencies = [ "logos-codegen", ] +[[package]] +name = "loom" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "419e0dc8046cb947daa77eb95ae174acfbddb7673b4151f56d1eed8e93fbfaca" +dependencies = [ + "cfg-if", + "generator", + "scoped-tls", + "tracing", + "tracing-subscriber", +] + [[package]] name = "lru" version = "0.12.5" @@ -5173,6 +5234,25 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "moka" +version = "0.12.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9321642ca94a4282428e6ea4af8cc2ca4eac48ac7a6a4ea8f33f76d0ce70926" +dependencies = [ + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "loom", + "parking_lot", + "portable-atomic", + "rustc_version", + "smallvec", + "tagptr", + "thiserror 1.0.69", + "uuid", +] + [[package]] name = "monostate" version = "0.1.13" @@ -7426,6 +7506,12 @@ dependencies = [ "syn 2.0.87", ] +[[package]] +name = "scoped-tls" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294" + [[package]] name = "scopeguard" version = "1.2.0" @@ -8239,6 +8325,8 @@ version = "3.4.0-pre0" dependencies = [ "anyhow", "chrono", + "deadpool-postgres", + "moka", "native-tls", "postgres-native-tls", "spin-core", @@ -9136,6 +9224,12 @@ dependencies = [ "winx", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tar" version = "0.4.43" @@ -11248,6 +11342,28 @@ dependencies = [ "windows_x86_64_msvc 0.42.2", ] +[[package]] +name = "windows" +version = "0.61.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9babd3a767a4c1aef6900409f85f5d53ce2544ccdfaa86dad48c91782c6d6893" +dependencies = [ + "windows-collections", + "windows-core 0.61.2", + "windows-future", + "windows-link", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3beeceb5e5cfd9eb1d76b381630e82c4241ccd0d27f1a39ed41b2760b255c5e8" +dependencies = [ + "windows-core 0.61.2", +] + [[package]] name = "windows-core" version = "0.52.0" @@ -11257,14 +11373,76 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-core" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result 0.3.4", + "windows-strings 0.4.2", +] + +[[package]] +name = "windows-future" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc6a41e98427b19fe4b73c550f060b59fa592d7d686537eebf9385621bfbad8e" +dependencies = [ + "windows-core 0.61.2", + "windows-link", + "windows-threading", +] + +[[package]] +name = "windows-implement" +version = "0.60.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "windows-interface" +version = "0.59.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.87", +] + +[[package]] +name = "windows-link" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e6ad25900d524eaabdbbb96d20b4311e1e7ae1699af4fb28c17ae66c80d798a" + +[[package]] +name = "windows-numerics" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9150af68066c4c5c07ddc0ce30421554771e528bde427614c61038bc2c92c2b1" +dependencies = [ + "windows-core 0.61.2", + "windows-link", +] + [[package]] name = "windows-registry" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" dependencies = [ - "windows-result", - "windows-strings", + "windows-result 0.2.0", + "windows-strings 0.1.0", "windows-targets 0.52.6", ] @@ -11277,16 +11455,34 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-result" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-strings" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" dependencies = [ - "windows-result", + "windows-result 0.2.0", "windows-targets 0.52.6", ] +[[package]] +name = "windows-strings" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.45.0" @@ -11369,6 +11565,15 @@ dependencies = [ "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows-threading" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b66463ad2e0ea3bbf808b7f1d371311c80e115c0b71d60efc142cafbcfb057a6" +dependencies = [ + "windows-link", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.42.2" diff --git a/crates/factor-outbound-pg/Cargo.toml b/crates/factor-outbound-pg/Cargo.toml index da7bf6ed14..f7a1aef389 100644 --- a/crates/factor-outbound-pg/Cargo.toml +++ b/crates/factor-outbound-pg/Cargo.toml @@ -7,6 +7,8 @@ edition = { workspace = true } [dependencies] anyhow = { workspace = true } chrono = { workspace = true } +deadpool-postgres = { version = "0.14", features = ["rt_tokio_1"] } +moka = { version = "0.12", features = ["sync"] } native-tls = "0.2" postgres-native-tls = "0.5" spin-core = { path = "../core" } diff --git a/crates/factor-outbound-pg/src/client.rs b/crates/factor-outbound-pg/src/client.rs index 9af7ab2aa3..72a686e53c 100644 --- a/crates/factor-outbound-pg/src/client.rs +++ b/crates/factor-outbound-pg/src/client.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use native_tls::TlsConnector; use postgres_native_tls::MakeTlsConnector; use spin_world::async_trait; @@ -6,15 +6,84 @@ use spin_world::spin::postgres::postgres::{ self as v3, Column, DbDataType, DbValue, ParameterValue, RowSet, }; use tokio_postgres::types::Type; -use tokio_postgres::{config::SslMode, types::ToSql, Row}; -use tokio_postgres::{Client as TokioClient, NoTls, Socket}; +use tokio_postgres::{config::SslMode, types::ToSql, NoTls, Row}; +/// Max connections in a given address' connection pool +const CONNECTION_POOL_SIZE: usize = 64; +/// Max addresses for which to keep pools in cache. +const CONNECTION_POOL_CACHE_CAPACITY: u64 = 16; + +/// A factory object for Postgres clients. This abstracts +/// details of client creation such as pooling. #[async_trait] -pub trait Client { - async fn build_client(address: &str) -> Result - where - Self: Sized; +pub trait ClientFactory: Default + Send + Sync + 'static { + /// The type of client produced by `get_client`. + type Client: Client; + /// Gets a client from the factory. + async fn get_client(&self, address: &str) -> Result; +} + +/// A `ClientFactory` that uses a connection pool per address. +pub struct PooledTokioClientFactory { + pools: moka::sync::Cache, +} + +impl Default for PooledTokioClientFactory { + fn default() -> Self { + Self { + pools: moka::sync::Cache::new(CONNECTION_POOL_CACHE_CAPACITY), + } + } +} + +#[async_trait] +impl ClientFactory for PooledTokioClientFactory { + type Client = deadpool_postgres::Object; + + async fn get_client(&self, address: &str) -> Result { + let pool = self + .pools + .try_get_with_by_ref(address, || create_connection_pool(address)) + .map_err(ArcError) + .context("establishing PostgreSQL connection pool")?; + + Ok(pool.get().await?) + } +} + +/// Creates a Postgres connection pool for the given address. +fn create_connection_pool(address: &str) -> Result { + let config = address + .parse::() + .context("parsing Postgres connection string")?; + + tracing::debug!("Build new connection: {}", address); + + let mgr_config = deadpool_postgres::ManagerConfig { + recycling_method: deadpool_postgres::RecyclingMethod::Clean, + }; + + let mgr = if config.get_ssl_mode() == SslMode::Disable { + deadpool_postgres::Manager::from_config(config, NoTls, mgr_config) + } else { + let builder = TlsConnector::builder(); + let connector = MakeTlsConnector::new(builder.build()?); + deadpool_postgres::Manager::from_config(config, connector, mgr_config) + }; + + // TODO: what is our max size heuristic? Should this be passed in so that different + // hosts can manage it according to their needs? Will a plain number suffice for + // sophisticated hosts anyway? + let pool = deadpool_postgres::Pool::builder(mgr) + .max_size(CONNECTION_POOL_SIZE) + .build() + .context("building Postgres connection pool")?; + + Ok(pool) +} +#[async_trait] +pub trait Client: Send + Sync + 'static { async fn execute( &self, statement: String, @@ -29,28 +98,7 @@ pub trait Client { } #[async_trait] -impl Client for TokioClient { - async fn build_client(address: &str) -> Result - where - Self: Sized, - { - let config = address.parse::()?; - - tracing::debug!("Build new connection: {}", address); - - if config.get_ssl_mode() == SslMode::Disable { - let (client, connection) = config.connect(NoTls).await?; - spawn_connection(connection); - Ok(client) - } else { - let builder = TlsConnector::builder(); - let connector = MakeTlsConnector::new(builder.build()?); - let (client, connection) = config.connect(connector).await?; - spawn_connection(connection); - Ok(client) - } - } - +impl Client for deadpool_postgres::Object { async fn execute( &self, statement: String, @@ -67,7 +115,8 @@ impl Client for TokioClient { .map(|b| b.as_ref() as &(dyn ToSql + Sync)) .collect(); - self.execute(&statement, params_refs.as_slice()) + self.as_ref() + .execute(&statement, params_refs.as_slice()) .await .map_err(|e| v3::Error::QueryFailed(format!("{e:?}"))) } @@ -89,6 +138,7 @@ impl Client for TokioClient { .collect(); let results = self + .as_ref() .query(&statement, params_refs.as_slice()) .await .map_err(|e| v3::Error::QueryFailed(format!("{e:?}")))?; @@ -111,17 +161,6 @@ impl Client for TokioClient { } } -fn spawn_connection(connection: tokio_postgres::Connection) -where - T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static, -{ - tokio::spawn(async move { - if let Err(e) = connection.await { - tracing::error!("Postgres connection error: {}", e); - } - }); -} - fn to_sql_parameter(value: &ParameterValue) -> Result> { match value { ParameterValue::Boolean(v) => Ok(Box::new(*v)), @@ -373,3 +412,25 @@ impl std::fmt::Debug for PgNull { f.debug_struct("NULL").finish() } } + +/// Workaround for moka returning Arc which, although +/// necessary for concurrency, does not play well with others. +struct ArcError(std::sync::Arc); + +impl std::error::Error for ArcError { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.0.source() + } +} + +impl std::fmt::Debug for ArcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Debug::fmt(&self.0, f) + } +} + +impl std::fmt::Display for ArcError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self.0, f) + } +} diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs index d0b4b52a9b..b40574a0af 100644 --- a/crates/factor-outbound-pg/src/host.rs +++ b/crates/factor-outbound-pg/src/host.rs @@ -9,17 +9,18 @@ use tracing::field::Empty; use tracing::instrument; use tracing::Level; -use crate::client::Client; +use crate::client::{Client, ClientFactory}; use crate::InstanceState; -impl InstanceState { +impl InstanceState { async fn open_connection( &mut self, address: &str, ) -> Result, v3::Error> { self.connections .push( - C::build_client(address) + self.client_factory + .get_client(address) .await .map_err(|e| v3::Error::ConnectionFailed(format!("{e:?}")))?, ) @@ -28,9 +29,9 @@ impl InstanceState { } async fn get_client( - &mut self, + &self, connection: Resource, - ) -> Result<&C, v3::Error> { + ) -> Result<&CF::Client, v3::Error> { self.connections .get(connection.rep()) .ok_or_else(|| v3::Error::ConnectionFailed("no connection found".into())) @@ -71,9 +72,7 @@ fn v2_params_to_v3( params.into_iter().map(|p| p.try_into()).collect() } -impl spin_world::spin::postgres::postgres::HostConnection - for InstanceState -{ +impl spin_world::spin::postgres::postgres::HostConnection for InstanceState { #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))] async fn open(&mut self, address: String) -> Result, v3::Error> { spin_factor_outbound_networking::record_address_fields(&address); @@ -122,13 +121,13 @@ impl spin_world::spin::postgres::postgres::HostConnecti } } -impl v2_types::Host for InstanceState { +impl v2_types::Host for InstanceState { fn convert_error(&mut self, error: v2::Error) -> Result { Ok(error) } } -impl v3::Host for InstanceState { +impl v3::Host for InstanceState { fn convert_error(&mut self, error: v3::Error) -> Result { Ok(error) } @@ -152,9 +151,9 @@ macro_rules! delegate { }}; } -impl v2::Host for InstanceState {} +impl v2::Host for InstanceState {} -impl v2::HostConnection for InstanceState { +impl v2::HostConnection for InstanceState { #[instrument(name = "spin_outbound_pg.open", skip(self, address), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", db.address = Empty, server.port = Empty, db.namespace = Empty))] async fn open(&mut self, address: String) -> Result, v2::Error> { spin_factor_outbound_networking::record_address_fields(&address); @@ -206,7 +205,7 @@ impl v2::HostConnection for InstanceState { } } -impl v1::Host for InstanceState { +impl v1::Host for InstanceState { async fn execute( &mut self, address: String, diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs index 4b38ded051..d410ef64f3 100644 --- a/crates/factor-outbound-pg/src/lib.rs +++ b/crates/factor-outbound-pg/src/lib.rs @@ -1,7 +1,9 @@ pub mod client; mod host; -use client::Client; +use std::sync::Arc; + +use client::ClientFactory; use spin_factor_outbound_networking::{ config::allowed_hosts::OutboundAllowedHosts, OutboundNetworkingFactor, }; @@ -9,16 +11,15 @@ use spin_factors::{ anyhow, ConfigureAppContext, Factor, FactorData, PrepareContext, RuntimeFactors, SelfInstanceBuilder, }; -use tokio_postgres::Client as PgClient; -pub struct OutboundPgFactor { - _phantom: std::marker::PhantomData, +pub struct OutboundPgFactor { + _phantom: std::marker::PhantomData, } -impl Factor for OutboundPgFactor { +impl Factor for OutboundPgFactor { type RuntimeConfig = (); - type AppState = (); - type InstanceBuilder = InstanceState; + type AppState = Arc; + type InstanceBuilder = InstanceState; fn init(&mut self, ctx: &mut impl spin_factors::InitContext) -> anyhow::Result<()> { ctx.link_bindings(spin_world::v1::postgres::add_to_linker::<_, FactorData>)?; @@ -33,7 +34,7 @@ impl Factor for OutboundPgFactor { &self, _ctx: ConfigureAppContext, ) -> anyhow::Result { - Ok(()) + Ok(Arc::new(CF::default())) } fn prepare( @@ -45,6 +46,7 @@ impl Factor for OutboundPgFactor { .allowed_hosts(); Ok(InstanceState { allowed_hosts, + client_factory: ctx.app_state().clone(), connections: Default::default(), }) } @@ -64,9 +66,10 @@ impl OutboundPgFactor { } } -pub struct InstanceState { +pub struct InstanceState { allowed_hosts: OutboundAllowedHosts, - connections: spin_resource_table::Table, + client_factory: Arc, + connections: spin_resource_table::Table, } -impl SelfInstanceBuilder for InstanceState {} +impl SelfInstanceBuilder for InstanceState {} diff --git a/crates/factor-outbound-pg/tests/factor_test.rs b/crates/factor-outbound-pg/tests/factor_test.rs index ae0ab28767..6c996d8fc9 100644 --- a/crates/factor-outbound-pg/tests/factor_test.rs +++ b/crates/factor-outbound-pg/tests/factor_test.rs @@ -1,6 +1,7 @@ use anyhow::{bail, Result}; use spin_factor_outbound_networking::OutboundNetworkingFactor; use spin_factor_outbound_pg::client::Client; +use spin_factor_outbound_pg::client::ClientFactory; use spin_factor_outbound_pg::OutboundPgFactor; use spin_factor_variables::VariablesFactor; use spin_factors::{anyhow, RuntimeFactors}; @@ -15,14 +16,14 @@ use spin_world::spin::postgres::postgres::{ParameterValue, RowSet}; struct TestFactors { variables: VariablesFactor, networking: OutboundNetworkingFactor, - pg: OutboundPgFactor, + pg: OutboundPgFactor, } fn factors() -> TestFactors { TestFactors { variables: VariablesFactor::default(), networking: OutboundNetworkingFactor::new(), - pg: OutboundPgFactor::::new(), + pg: OutboundPgFactor::::new(), } } @@ -104,17 +105,20 @@ async fn exercise_query() -> anyhow::Result<()> { } // TODO: We can expand this mock to track calls and simulate return values +#[derive(Default)] +pub struct MockClientFactory {} pub struct MockClient {} #[async_trait] -impl Client for MockClient { - async fn build_client(_address: &str) -> anyhow::Result - where - Self: Sized, - { +impl ClientFactory for MockClientFactory { + type Client = MockClient; + async fn get_client(&self, _address: &str) -> Result { Ok(MockClient {}) } +} +#[async_trait] +impl Client for MockClient { async fn execute( &self, _statement: String,