From 49a722ea84013811f1983e4ad4072f0287dbdb50 Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Wed, 23 Apr 2025 11:05:02 -0500 Subject: [PATCH 01/10] Support for the PG wire protocol --- Cargo.lock | 310 ++++++++++- Cargo.toml | 7 + crates/cli/src/subcommands/sql.rs | 2 + crates/client-api-messages/src/name.rs | 2 +- crates/client-api/src/auth.rs | 46 +- crates/client-api/src/routes/database.rs | 42 +- crates/core/src/auth/mod.rs | 9 +- crates/pg/Cargo.toml | 27 + crates/pg/LICENSE | 0 crates/pg/README.md | 3 + crates/pg/src/lib.rs | 1 + crates/pg/src/pg_server.rs | 573 +++++++++++++++++++++ crates/sats/src/product_value.rs | 8 + crates/sats/src/satn.rs | 131 +++-- crates/sats/src/time_duration.rs | 14 + crates/sats/src/timestamp.rs | 10 +- crates/standalone/Cargo.toml | 1 + crates/standalone/src/subcommands/start.rs | 21 +- docker-compose.yml | 2 + smoketests/tests/pg_wire.py | 169 ++++++ 20 files changed, 1302 insertions(+), 76 deletions(-) create mode 100644 crates/pg/Cargo.toml create mode 100644 crates/pg/LICENSE create mode 100644 crates/pg/README.md create mode 100644 crates/pg/src/lib.rs create mode 100644 crates/pg/src/pg_server.rs create mode 100644 smoketests/tests/pg_wire.py diff --git a/Cargo.lock b/Cargo.lock index 87c02c91a7b..130b472e256 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -195,6 +195,12 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "array-init" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d62b7694a562cdf5a74227903507c56ab2cc8bdd1f781ed5cb4cf9c9f810bfc" + [[package]] name = "arrayref" version = "0.3.9" @@ -207,6 +213,45 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "asn1-rs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" +dependencies = [ + "asn1-rs-derive", + "asn1-rs-impl", + "displaydoc", + "nom", + "num-traits", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + +[[package]] +name = "asn1-rs-derive" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", + "synstructure 0.13.2", +] + +[[package]] +name = "asn1-rs-impl" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -293,6 +338,30 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-lc-rs" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b756939cb2f8dc900aa6dcd505e6e2428e9cae7ff7b028c49e3946efa70878" +dependencies = [ + "aws-lc-sys", + "untrusted 0.7.1", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.28.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa9b6986f250236c27e5a204062434a773a13243d2ffc2955f37bdba4c5c6a1" +dependencies = [ + "bindgen 0.69.5", + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "axum" version = "0.7.9" @@ -429,6 +498,29 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.9.0", + "cexpr", + "clang-sys", + "itertools 0.10.5", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.101", + "which 4.4.2", +] + [[package]] name = "bindgen" version = "0.71.1" @@ -444,7 +536,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "rustc-hash", + "rustc-hash 2.1.1", "shlex", "syn 2.0.101", ] @@ -925,6 +1017,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "cobs" version = "0.2.3" @@ -1092,7 +1193,7 @@ dependencies = [ "hashbrown 0.14.5", "log", "regalloc2", - "rustc-hash", + "rustc-hash 2.1.1", "smallvec", "target-lexicon", ] @@ -1475,6 +1576,20 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da692b8d1080ea3045efaab14434d40468c3d8657e42abddfffca87b428f4c1b" +[[package]] +name = "der-parser" +version = "9.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" +dependencies = [ + "asn1-rs", + "displaydoc", + "nom", + "num-bigint", + "num-traits", + "rusticata-macros", +] + [[package]] name = "deranged" version = "0.4.0" @@ -1485,6 +1600,17 @@ dependencies = [ "serde", ] +[[package]] +name = "derive-new" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "derive_arbitrary" version = "1.4.1" @@ -1602,6 +1728,12 @@ dependencies = [ "shared_child", ] +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "educe" version = "0.4.23" @@ -3011,12 +3143,41 @@ dependencies = [ "spacetimedb 1.4.0", ] +[[package]] +name = "lazy-regex" +version = "3.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60c7310b93682b36b98fa7ea4de998d3463ccbebd94d935d6b48ba5b6ffa7126" +dependencies = [ + "lazy-regex-proc_macros", + "once_cell", + "regex-lite", +] + +[[package]] +name = "lazy-regex-proc_macros" +version = "3.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ba01db5ef81e17eb10a5e0f2109d1b3a3e29bac3070fdbd7d156bf7dbd206a1" +dependencies = [ + "proc-macro2", + "quote", + "regex", + "syn 2.0.101", +] + [[package]] name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "leb128" version = "0.2.5" @@ -3215,6 +3376,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" + [[package]] name = "memchr" version = "2.7.4" @@ -3570,6 +3737,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "oid-registry" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" +dependencies = [ + "asn1-rs", +] + [[package]] name = "once_cell" version = "1.21.3" @@ -3805,6 +3981,30 @@ dependencies = [ "postgres-types", ] +[[package]] +name = "pgwire" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c84e671791f3a354f265e55e400be8bb4b6262c1ec04fac4289e710ccf22ab43" +dependencies = [ + "async-trait", + "aws-lc-rs", + "bytes", + "chrono", + "derive-new", + "futures", + "hex", + "lazy-regex", + "md5", + "postgres-types", + "rand 0.8.5", + "rust_decimal", + "thiserror 2.0.12", + "tokio", + "tokio-rustls", + "tokio-util", +] + [[package]] name = "phf" version = "0.11.3" @@ -3956,6 +4156,7 @@ version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48" dependencies = [ + "array-init", "bytes", "chrono", "fallible-iterator 0.2.0", @@ -4399,6 +4600,20 @@ dependencies = [ "crossbeam-utils", ] +[[package]] +name = "rcgen" +version = "0.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "x509-parser", + "yasna", +] + [[package]] name = "rdrand" version = "0.4.0" @@ -4445,7 +4660,7 @@ checksum = "12908dbeb234370af84d0579b9f68258a0f67e201412dd9a2814e6f45b2fc0f0" dependencies = [ "hashbrown 0.14.5", "log", - "rustc-hash", + "rustc-hash 2.1.1", "slice-group-by", "smallvec", ] @@ -4482,6 +4697,12 @@ dependencies = [ "regex-syntax 0.8.5", ] +[[package]] +name = "regex-lite" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" + [[package]] name = "regex-syntax" version = "0.6.29" @@ -4623,7 +4844,7 @@ dependencies = [ "cfg-if", "getrandom 0.2.16", "libc", - "untrusted", + "untrusted 0.9.0", "windows-sys 0.52.0", ] @@ -4734,6 +4955,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc-hash" version = "2.1.1" @@ -4749,6 +4976,15 @@ dependencies = [ "semver", ] +[[package]] +name = "rusticata-macros" +version = "4.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" +dependencies = [ + "nom", +] + [[package]] name = "rustix" version = "0.38.44" @@ -4781,6 +5017,8 @@ version = "0.23.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" dependencies = [ + "aws-lc-rs", + "log", "once_cell", "rustls-pki-types", "rustls-webpki", @@ -4821,9 +5059,10 @@ version = "0.103.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7149975849f1abb3832b246010ef62ccc80d3a76169517ada7188252b9cfb437" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", - "untrusted", + "untrusted 0.9.0", ] [[package]] @@ -5674,7 +5913,7 @@ dependencies = [ "regex", "reqwest 0.12.15", "rustc-demangle", - "rustc-hash", + "rustc-hash 2.1.1", "scopeguard", "semver", "serde", @@ -5960,6 +6199,29 @@ dependencies = [ "xdg", ] +[[package]] +name = "spacetimedb-pg" +version = "1.4.0" +dependencies = [ + "anyhow", + "async-trait", + "axum", + "futures", + "http 1.3.1", + "log", + "pgwire", + "rcgen", + "rustls", + "rustls-pki-types", + "serde_json", + "spacetimedb-client-api", + "spacetimedb-client-api-messages", + "spacetimedb-lib 1.4.0", + "thiserror 1.0.69", + "tokio", + "tokio-rustls", +] + [[package]] name = "spacetimedb-physical-plan" version = "1.4.0" @@ -6207,6 +6469,7 @@ dependencies = [ "spacetimedb-datastore", "spacetimedb-lib 1.4.0", "spacetimedb-paths", + "spacetimedb-pg", "spacetimedb-schema", "spacetimedb-table", "tempfile", @@ -7397,6 +7660,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + [[package]] name = "untrusted" version = "0.9.0" @@ -7477,7 +7746,7 @@ version = "137.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ca393e2032ddba2a57169e15cac5d0a81cdb3d872a8886f4468bc0f486098d2" dependencies = [ - "bindgen", + "bindgen 0.71.1", "bitflags 2.9.0", "fslock", "gzip-header", @@ -8509,6 +8778,24 @@ dependencies = [ "tap", ] +[[package]] +name = "x509-parser" +version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" +dependencies = [ + "asn1-rs", + "data-encoding", + "der-parser", + "lazy_static", + "nom", + "oid-registry", + "ring", + "rusticata-macros", + "thiserror 1.0.69", + "time", +] + [[package]] name = "xattr" version = "1.5.0" @@ -8555,6 +8842,15 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "yoke" version = "0.7.5" diff --git a/Cargo.toml b/Cargo.toml index 76529fe1ef9..a18ace7c8c4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "crates/lib", "crates/metrics", "crates/paths", + "crates/pg", "crates/physical-plan", "crates/primitives", "crates/query", @@ -115,6 +116,7 @@ spacetimedb-lib = { path = "crates/lib", default-features = false, version = "1. spacetimedb-memory-usage = { path = "crates/memory-usage", version = "1.4.0", default-features = false } spacetimedb-metrics = { path = "crates/metrics", version = "1.4.0" } spacetimedb-paths = { path = "crates/paths", version = "1.4.0" } +spacetimedb-pg = { path = "crates/pg", version = "1.4.0" } spacetimedb-physical-plan = { path = "crates/physical-plan", version = "1.4.0" } spacetimedb-primitives = { path = "crates/primitives", version = "1.4.0" } spacetimedb-query = { path = "crates/query", version = "1.4.0" } @@ -214,6 +216,7 @@ paste = "1.0" percent-encoding = "2.3" petgraph = { version = "0.6.5", default-features = false } pin-project-lite = "0.2.9" +pgwire = { version = "0.28.0", features = ["server-api"] } postgres-types = "0.2.5" pretty_assertions = { version = "1.4", features = ["unstable"] } proc-macro2 = "1.0" @@ -226,6 +229,7 @@ rand08 = { package = "rand", version = "0.8" } rand = "0.9" rayon = "1.8" rayon-core = "1.11.0" +rcgen = { version = "0.13.1", features = ["pem", "x509-parser", "crypto", "ring"] } regex = "1" reqwest = { version = "0.12", features = ["stream", "json"] } ron = "0.8" @@ -234,6 +238,8 @@ rust_decimal = { version = "1.29.1", features = ["db-tokio-postgres"] } rustc-demangle = "0.1.21" rustc-hash = "2" rustyline = { version = "12.0.0", features = [] } +rustls-pki-types = "1.11.0" +rustls = "0.23.26" scoped-tls = "1.0.1" scopeguard = "1.1.0" second-stack = "0.3" @@ -265,6 +271,7 @@ termcolor = "1.2.0" thin-vec = "0.2.13" thiserror = "1.0.37" tokio = { version = "1.37", features = ["full"] } +tokio-rustls = "0.26.2" tokio_metrics = { version = "0.4.0" } tokio-postgres = { version = "0.7.8", features = ["with-chrono-0_4"] } tokio-stream = "0.1.17" diff --git a/crates/cli/src/subcommands/sql.rs b/crates/cli/src/subcommands/sql.rs index 12ad41a496f..27bf77edaf7 100644 --- a/crates/cli/src/subcommands/sql.rs +++ b/crates/cli/src/subcommands/sql.rs @@ -10,6 +10,7 @@ use anyhow::Context; use clap::{Arg, ArgAction, ArgMatches}; use reqwest::RequestBuilder; use spacetimedb_lib::de::serde::SeedWrapper; +use spacetimedb_lib::sats::satn::PsqlClient; use spacetimedb_lib::sats::{satn, ProductType, ProductValue, Typespace}; pub fn cli() -> clap::Command { @@ -211,6 +212,7 @@ fn build_table( let row = row?; builder.push_record(ty.with_values(&row).enumerate().map(|(idx, value)| { let ty = satn::PsqlType { + client: PsqlClient::SpacetimeDB, tuple: ty.ty(), field: &ty.ty().elements[idx], idx, diff --git a/crates/client-api-messages/src/name.rs b/crates/client-api-messages/src/name.rs index 5fac594f9c8..49cee23a156 100644 --- a/crates/client-api-messages/src/name.rs +++ b/crates/client-api-messages/src/name.rs @@ -171,7 +171,7 @@ pub enum SetDefaultDomainResult { /// /// Must match the regex `^[a-z0-9]+(-[a-z0-9]+)*$` #[derive(Clone, Debug, serde_with::DeserializeFromStr, serde_with::SerializeDisplay)] -pub struct DatabaseName(String); +pub struct DatabaseName(pub String); impl AsRef for DatabaseName { fn as_ref(&self) -> &str { diff --git a/crates/client-api/src/auth.rs b/crates/client-api/src/auth.rs index 61031625867..d87f24a10f1 100644 --- a/crates/client-api/src/auth.rs +++ b/crates/client-api/src/auth.rs @@ -132,6 +132,30 @@ impl TokenClaims { } impl SpacetimeAuth { + pub fn from_claims( + ctx: &(impl NodeDelegate + ControlStateDelegate + ?Sized), + claims: SpacetimeIdentityClaims, + ) -> axum::response::Result { + let claims = TokenClaims { + issuer: claims.issuer, + subject: claims.subject, + audience: claims.audience, + }; + + let creds = { + let token = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?; + SpacetimeCreds::from_signed_token(token) + }; + let identity = claims.id(); + + Ok(Self { + creds, + identity, + subject: claims.subject, + issuer: claims.issuer, + }) + } + /// Allocate a new identity, and mint a new token for it. pub async fn alloc(ctx: &(impl NodeDelegate + ControlStateDelegate + ?Sized)) -> axum::response::Result { // Generate claims with a random subject. @@ -186,6 +210,8 @@ pub trait JwtAuthProvider: Sync + Send + TokenSigner { /// /// The `/identity/public-key` route calls this method to return the public key to callers. fn public_key_bytes(&self) -> &[u8]; + /// Return the private key used to verify JWTs, as the bytes of a PEM private key file. + fn private_key_bytes(&self) -> &[u8]; } pub struct JwtKeyAuthProvider { @@ -222,6 +248,10 @@ impl TokenSigner for JwtKeyAuthProvider { impl JwtAuthProvider for JwtKeyAuthProvider { type TV = TV; + fn validator(&self) -> &Self::TV { + &self.validator + } + fn local_issuer(&self) -> &str { &self.local_issuer } @@ -230,8 +260,8 @@ impl JwtAuthProvider for JwtKeyAuthProvider &Self::TV { - &self.validator + fn private_key_bytes(&self) -> &[u8] { + &self.keys.private_pem } } @@ -260,6 +290,13 @@ mod tests { } } +pub async fn validate_token( + state: &S, + token: &str, +) -> Result { + state.jwt_auth_provider().validator().validate_token(token).await +} + pub struct SpacetimeAuthHeader { auth: Option, } @@ -272,10 +309,7 @@ impl axum::extract::FromRequestParts for Space return Ok(Self { auth: None }); }; - let claims = state - .jwt_auth_provider() - .validator() - .validate_token(&creds.token) + let claims = validate_token(state, &creds.token) .await .map_err(AuthorizationRejection::Custom)?; diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index 52b2458bf08..f3a137a3e0d 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -7,7 +7,7 @@ use crate::auth::{ SpacetimeIdentityToken, }; use crate::routes::subscribe::generate_random_connection_id; -use crate::util::{ByteStringBody, NameOrIdentity}; +pub use crate::util::{ByteStringBody, NameOrIdentity}; use crate::{log_and_500, ControlStateDelegate, DatabaseDef, NodeDelegate}; use axum::body::{Body, Bytes}; use axum::extract::{Path, Query, State}; @@ -31,7 +31,7 @@ use spacetimedb_client_api_messages::name::{ }; use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9; use spacetimedb_lib::identity::AuthCtx; -use spacetimedb_lib::{sats, Timestamp}; +use spacetimedb_lib::{sats, ProductValue, Timestamp}; use spacetimedb_schema::auto_migrate::{ MigrationPolicy as SchemaMigrationPolicy, MigrationToken, PrettyPrintStyle as AutoMigratePrettyPrintStyle, }; @@ -383,7 +383,7 @@ pub(crate) async fn worker_ctx_find_database( #[derive(Deserialize)] pub struct SqlParams { - name_or_identity: NameOrIdentity, + pub name_or_identity: NameOrIdentity, } #[derive(Deserialize)] @@ -391,16 +391,16 @@ pub struct SqlQueryParams { /// If `true`, return the query result only after its transaction offset /// is confirmed to be durable. #[serde(default)] - confirmed: bool, + pub confirmed: bool, } -pub async fn sql( - State(worker_ctx): State, - Path(SqlParams { name_or_identity }): Path, - Query(SqlQueryParams { confirmed }): Query, - Extension(auth): Extension, - body: String, -) -> axum::response::Result +pub async fn sql_direct( + worker_ctx: S, + SqlParams { name_or_identity }: SqlParams, + SqlQueryParams { confirmed }: SqlQueryParams, + caller_identity: Identity, + sql: String, +) -> axum::response::Result>> where S: NodeDelegate + ControlStateDelegate, { @@ -412,7 +412,7 @@ where .await? .ok_or(NO_SUCH_DATABASE)?; - let auth = AuthCtx::new(database.owner_identity, auth.identity); + let auth = AuthCtx::new(database.owner_identity, caller_identity); log::debug!("auth: {auth:?}"); let host = worker_ctx @@ -420,7 +420,21 @@ where .await .map_err(log_and_500)? .ok_or(StatusCode::NOT_FOUND)?; - let json = host.exec_sql(auth, database, confirmed, body).await?; + + host.exec_sql(auth, database, confirmed, sql).await +} + +pub async fn sql( + State(worker_ctx): State, + Path(name_or_identity): Path, + Query(params): Query, + Extension(auth): Extension, + body: String, +) -> axum::response::Result +where + S: NodeDelegate + ControlStateDelegate, +{ + let json = sql_direct(worker_ctx, name_or_identity, params, auth.identity, body).await?; let total_duration = json.iter().fold(0, |acc, x| acc + x.total_duration_micros); @@ -488,7 +502,9 @@ pub struct PublishDatabaseQueryParams { policy: MigrationPolicy, } +use spacetimedb_client_api_messages::http::SqlStmtResult; use std::env; + fn require_spacetime_auth_for_creation() -> bool { env::var("TEMP_REQUIRE_SPACETIME_AUTH").is_ok_and(|v| !v.is_empty()) } diff --git a/crates/core/src/auth/mod.rs b/crates/core/src/auth/mod.rs index f9c381902e1..e1e38a667c4 100644 --- a/crates/core/src/auth/mod.rs +++ b/crates/core/src/auth/mod.rs @@ -15,6 +15,7 @@ pub struct JwtKeys { pub public: DecodingKey, pub public_pem: Box<[u8]>, pub private: EncodingKey, + pub private_pem: Box<[u8]>, pub kid: Option, } @@ -23,15 +24,17 @@ impl JwtKeys { /// respectively. /// /// The key files must be PEM encoded ECDSA P256 keys. - pub fn new(public_pem: impl Into>, private_pem: &[u8]) -> anyhow::Result { + pub fn new(public_pem: impl Into>, private_pem: impl Into>) -> anyhow::Result { let public_pem = public_pem.into(); + let private_pem = private_pem.into(); let public = DecodingKey::from_ec_pem(&public_pem)?; - let private = EncodingKey::from_ec_pem(private_pem)?; + let private = EncodingKey::from_ec_pem(&private_pem)?; Ok(Self { public, private, public_pem, + private_pem, kid: None, }) } @@ -75,7 +78,7 @@ pub struct EcKeyPair { impl TryFrom for JwtKeys { type Error = anyhow::Error; fn try_from(pair: EcKeyPair) -> anyhow::Result { - JwtKeys::new(pair.public_key_bytes, &pair.private_key_bytes) + JwtKeys::new(pair.public_key_bytes, pair.private_key_bytes) } } diff --git a/crates/pg/Cargo.toml b/crates/pg/Cargo.toml new file mode 100644 index 00000000000..68e41d746b5 --- /dev/null +++ b/crates/pg/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "spacetimedb-pg" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license-file = "LICENSE" +description = "Postgres wire protocol Server support for SpacetimeDB" + +[dependencies] +spacetimedb-client-api-messages.workspace = true +spacetimedb-client-api.workspace = true +spacetimedb-lib.workspace = true + +anyhow.workspace = true +async-trait.workspace = true +axum.workspace = true +futures.workspace = true +http.workspace = true +log.workspace = true +pgwire.workspace = true +rustls-pki-types.workspace = true +rcgen.workspace = true +rustls.workspace = true +thiserror.workspace = true +tokio.workspace = true +tokio-rustls.workspace = true +serde_json.workspace = true diff --git a/crates/pg/LICENSE b/crates/pg/LICENSE new file mode 100644 index 00000000000..e69de29bb2d diff --git a/crates/pg/README.md b/crates/pg/README.md new file mode 100644 index 00000000000..fc5b684dd86 --- /dev/null +++ b/crates/pg/README.md @@ -0,0 +1,3 @@ +> ⚠️ **Internal Crate** ⚠️ +> +> This crate is intended for internal use only. It is **not** stable and may change without notice. diff --git a/crates/pg/src/lib.rs b/crates/pg/src/lib.rs new file mode 100644 index 00000000000..00efd53b3e9 --- /dev/null +++ b/crates/pg/src/lib.rs @@ -0,0 +1 @@ +pub mod pg_server; diff --git a/crates/pg/src/pg_server.rs b/crates/pg/src/pg_server.rs new file mode 100644 index 00000000000..6e8fdc94574 --- /dev/null +++ b/crates/pg/src/pg_server.rs @@ -0,0 +1,573 @@ +use std::fmt::Debug; +use std::sync::Arc; + +use async_trait::async_trait; +use axum::body::to_bytes; +use axum::response::IntoResponse; +use futures::{stream, Sink}; +use futures::{SinkExt, Stream}; +use http::StatusCode; +use pgwire::api::auth::{ + finish_authentication, save_startup_parameters_to_metadata, DefaultServerParameterProvider, LoginInfo, + StartupHandler, +}; +use pgwire::api::copy::NoopCopyHandler; +use pgwire::api::portal::Format; +use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; +use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response, Tag}; +use pgwire::api::{ClientInfo, Type}; +use pgwire::api::{NoopErrorHandler, METADATA_DATABASE, METADATA_USER}; +use pgwire::api::{PgWireConnectionState, PgWireServerHandlers}; +use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use pgwire::messages::data::DataRow; +use pgwire::messages::startup::Authentication; +use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; +use pgwire::tokio::process_socket; +use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair}; +use rustls_pki_types::pem::PemObject; +use rustls_pki_types::PrivateKeyDer; +use spacetimedb_client_api::auth::{validate_token, SpacetimeAuth}; +use spacetimedb_client_api::routes::database; +use spacetimedb_client_api::routes::database::{SqlParams, SqlQueryParams}; +use spacetimedb_client_api::{ControlStateReadAccess, ControlStateWriteAccess, NodeDelegate}; +use spacetimedb_client_api_messages::http::SqlStmtResult; +use spacetimedb_client_api_messages::name::DatabaseName; +use spacetimedb_lib::sats::satn::{PsqlPrintFmt, Satn}; +use spacetimedb_lib::sats::ArrayValue; +use spacetimedb_lib::version::spacetimedb_lib_version; +use spacetimedb_lib::{ + AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, TimeDuration, Timestamp, +}; +use thiserror::Error; +use tokio::net::TcpListener; +use tokio::sync::{watch, Mutex}; +use tokio_rustls::rustls::ServerConfig; +use tokio_rustls::TlsAcceptor; + +#[derive(Error, Debug)] +enum PgError { + #[error("(metadata) {0}")] + MetadataError(anyhow::Error), + #[error("(Sql) {0}")] + Sql(String), + #[error("Database name is required")] + DatabaseNameRequired, + #[error(transparent)] + Pg(#[from] PgWireError), + #[error(transparent)] + RcGen(#[from] rcgen::Error), + #[error(transparent)] + Pem(#[from] rustls_pki_types::pem::Error), + #[error(transparent)] + RustTls(#[from] rustls::Error), + #[error("Special type with format {0} is invalid for {1}")] + SpecialTypeInvalid(PsqlPrintFmt, String), + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +impl From for PgWireError { + fn from(err: PgError) -> Self { + if let PgError::Pg(err) = err { + err + } else { + PgWireError::ApiError(Box::new(err)) + } + } +} + +#[derive(Clone)] +struct Metadata { + database: String, + auth: SpacetimeAuth, +} + +fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type { + let format = PsqlPrintFmt::use_fmt(schema, ty, ty.name()); + match &ty.algebraic_type { + AlgebraicType::String => Type::VARCHAR, + AlgebraicType::Bool => Type::BOOL, + AlgebraicType::I8 | AlgebraicType::U8 | AlgebraicType::I16 => Type::INT2, + AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4, + AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8, + AlgebraicType::U64 | AlgebraicType::I128 | AlgebraicType::U128 | AlgebraicType::I256 | AlgebraicType::U256 => { + Type::NUMERIC + } + AlgebraicType::F32 => Type::FLOAT4, + AlgebraicType::F64 => Type::FLOAT8, + AlgebraicType::Array(ty) => match *ty.elem_ty { + AlgebraicType::String => Type::VARCHAR_ARRAY, + AlgebraicType::Bool => Type::BOOL_ARRAY, + AlgebraicType::U8 => Type::BYTEA_ARRAY, + AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2_ARRAY, + AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4_ARRAY, + AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8_ARRAY, + AlgebraicType::U64 + | AlgebraicType::I128 + | AlgebraicType::U128 + | AlgebraicType::I256 + | AlgebraicType::U256 => Type::NUMERIC_ARRAY, + _ => Type::ANYARRAY, + }, + AlgebraicType::Product(_) => match format { + PsqlPrintFmt::Hex => Type::BYTEA_ARRAY, + PsqlPrintFmt::Timestamp => Type::TIMESTAMP, + PsqlPrintFmt::Duration => Type::INTERVAL, + _ => Type::JSON, + }, + x if x.as_sum().map(|x| x.is_simple_enum()).unwrap_or(false) => Type::ANYENUM, + _ => Type::UNKNOWN, + } +} + +fn encode_value( + encoder: &mut DataRowEncoder, + schema: &ProductType, + ty: &ProductTypeElement, + value: &AlgebraicValue, +) -> Result<(), PgError> { + let format = PsqlPrintFmt::use_fmt(schema, ty, ty.name()); + + match value { + AlgebraicValue::Bool(x) => encoder.encode_field(x)?, + AlgebraicValue::I8(x) => encoder.encode_field(x)?, + AlgebraicValue::U8(x) => encoder.encode_field(&(*x as i16))?, + AlgebraicValue::I16(x) => encoder.encode_field(x)?, + AlgebraicValue::U16(x) => encoder.encode_field(&(*x as u32))?, + AlgebraicValue::I32(x) => encoder.encode_field(x)?, + AlgebraicValue::U32(x) => encoder.encode_field(x)?, + AlgebraicValue::I64(x) => encoder.encode_field(&x)?, + AlgebraicValue::U64(x) => encoder.encode_field(&x.to_string())?, + AlgebraicValue::I128(x) => { + let x = x.0; + encoder.encode_field(&(x.to_string()))? + } + AlgebraicValue::U128(x) => { + let x = x.0; + encoder.encode_field(&x.to_string())?; + } + AlgebraicValue::I256(x) => encoder.encode_field(&(x.to_string()))?, + AlgebraicValue::U256(x) => encoder.encode_field(&x.to_string())?, + AlgebraicValue::F32(x) => encoder.encode_field(&x.into_inner())?, + AlgebraicValue::F64(x) => encoder.encode_field(&x.into_inner())?, + AlgebraicValue::String(x) => encoder.encode_field(&x.to_string())?, + AlgebraicValue::Array(x) => match x { + ArrayValue::Bool(x) => { + encoder.encode_field(&x.as_ref())?; + } + ArrayValue::I8(x) => { + encoder.encode_field(&x.as_ref())?; + } + ArrayValue::U8(x) => { + encoder.encode_field(&x.as_ref())?; + } + ArrayValue::I16(x) => { + encoder.encode_field(&x.as_ref())?; + } + ArrayValue::I32(x) => { + encoder.encode_field(&x.as_ref())?; + } + ArrayValue::U32(x) => { + encoder.encode_field(&x.as_ref())?; + } + ArrayValue::I64(x) => { + encoder.encode_field(&x.as_ref())?; + } + ArrayValue::F32(x) => { + let x = x.iter().map(|x| x.into_inner()).collect::>(); + encoder.encode_field(&x)?; + } + ArrayValue::F64(x) => { + let x = x.iter().map(|x| x.into_inner()).collect::>(); + encoder.encode_field(&x)?; + } + ArrayValue::String(x) => { + let x = x.iter().map(|x| x.as_ref()).collect::>(); + encoder.encode_field(&x)?; + } + _ => { + let json = serde_json::to_string(&value).unwrap(); + encoder.encode_field(&json)?; + } + }, + AlgebraicValue::Product(x) => match (format, x.as_special_value_raw()) { + (PsqlPrintFmt::Hex, Some(AlgebraicValue::U128(x))) => encoder.encode_field(&x.0.to_be_bytes())?, + (PsqlPrintFmt::Hex, Some(AlgebraicValue::U256(x))) => encoder.encode_field(&x.to_be_bytes())?, + (PsqlPrintFmt::Timestamp, Some(AlgebraicValue::I64(x))) => { + encoder.encode_field(&Timestamp::from_micros_since_unix_epoch(*x).to_rfc3339()?)? + } + (PsqlPrintFmt::Duration, Some(AlgebraicValue::I64(x))) => { + encoder.encode_field(&TimeDuration::from_micros(*x).to_iso8601())? + } + (PsqlPrintFmt::Satn, Some(..)) + | (PsqlPrintFmt::Hex | PsqlPrintFmt::Timestamp | PsqlPrintFmt::Duration, _) => { + return Err(PgError::SpecialTypeInvalid(format, value.to_satn())) + } + (PsqlPrintFmt::Satn, None) => { + let json = serde_json::to_string(&value).unwrap(); + + encoder.encode_field(&json)? + } + }, + x => encoder.encode_field(&x.to_satn())?, + } + + Ok(()) +} + +fn to_rows( + stmt: SqlStmtResult, + header: Arc>, +) -> Result>, PgError> { + let mut results = Vec::with_capacity(stmt.rows.len()); + + for row in stmt.rows { + let mut encoder = DataRowEncoder::new(header.clone()); + + for (idx, ty) in stmt.schema.elements.iter().enumerate() { + let value = row.get_field(idx, None).unwrap(); + + encode_value(&mut encoder, &stmt.schema, ty, value)?; + } + results.push(encoder.finish()); + } + Ok(stream::iter(results)) +} + +fn row_desc_from_stmt(stmt: &SqlStmtResult, format: &Format) -> Vec { + let mut field_info = Vec::with_capacity(stmt.schema.elements.len()); + for (idx, ty) in stmt.schema.elements.iter().enumerate() { + let field_name = ty.name.clone().map(Into::into).unwrap_or_else(|| format!("col {idx}")); + let field_type = type_of(&stmt.schema, ty); + let field_desc = FieldInfo::new(field_name, None, None, field_type, format.format_for(idx)); + field_info.push(field_desc); + } + field_info +} + +fn stats(stmt: &SqlStmtResult) -> String { + let mut info = Vec::new(); + if stmt.stats.rows_inserted != 0 { + info.push(format!("inserted: {}", stmt.stats.rows_inserted)); + } + if stmt.stats.rows_deleted != 0 { + info.push(format!("deleted: {}", stmt.stats.rows_deleted)); + } + if stmt.stats.rows_updated != 0 { + info.push(format!("updated: {}", stmt.stats.rows_updated)); + } + info.push(format!( + "server: {:.2?}", + std::time::Duration::from_micros(stmt.total_duration_micros) + )); + + info.join(", ") +} + +struct ResponseWrapper(T); +impl IntoResponse for ResponseWrapper { + fn into_response(self) -> axum::response::Response { + unreachable!("Blank impl to satisfy IntoResponse") + } +} + +async fn response(res: axum::response::Result, database: &str) -> Result { + match res.map(ResponseWrapper) { + Ok(sql) => Ok(sql.0), + err => { + let res = err.into_response(); + if res.status() == StatusCode::NOT_FOUND { + return Err(PgWireError::UserError(Box::new(ErrorInfo::new( + "FATAL".to_string(), + "3D000".to_string(), + format!("database \"{database}\" does not exist"), + ))) + .into()); + } + let bytes = to_bytes(res.into_body(), usize::MAX) + .await + .map_err(|err| PgWireError::ApiError(Box::new(err)))?; + let err = String::from_utf8_lossy(&bytes); + Err(PgError::Sql(format!("{err}"))) + } + } +} + +struct PgSpacetimeDB { + ctx: Arc, + cached: Mutex, + parameter_provider: DefaultServerParameterProvider, +} + +impl PgSpacetimeDB { + async fn exe_sql<'a>(&self, query: String) -> PgWireResult>> { + let params = self.cached.lock().await; + let db = SqlParams { + name_or_identity: database::NameOrIdentity::Name(DatabaseName(params.database.clone())), + }; + + let sql = match response( + database::sql_direct( + self.ctx.clone(), + db, + SqlQueryParams { confirmed: true }, + params.auth.identity, + query.to_string(), + ) + .await, + ¶ms.database, + ) + .await + { + Ok(sql) => sql, + Err(PgError::Pg(PgWireError::UserError(err))) => { + return Ok(vec![Response::Error(err)]); + } + Err(err) => { + return Err(err.into()); + } + }; + + let mut result = Vec::with_capacity(sql.len()); + for sql_result in sql { + let header = Arc::new(row_desc_from_stmt(&sql_result, &Format::UnifiedText)); + + let tag = Tag::new(&stats(&sql_result)); + + if sql_result.rows.is_empty() { + result.push(Response::EmptyQuery); + } else { + let rows = to_rows(sql_result, header.clone())?; + result.push(Response::Query(QueryResponse::new(header, rows))); + } + result.push(Response::Execution(tag)); + } + Ok(result) + } +} + +async fn close_client(client: &mut C, err: E) -> PgWireResult<()> +where + C: ClientInfo + Sink + Unpin + Send, + C::Error: Debug, + PgWireError: From<>::Error>, + pgwire::messages::response::ErrorResponse: From, +{ + let err = pgwire::messages::response::ErrorResponse::from(err); + client.feed(PgWireBackendMessage::ErrorResponse(err)).await?; + client.close().await?; + Ok(()) +} + +#[async_trait] +impl StartupHandler + for PgSpacetimeDB +{ + async fn on_startup(&self, client: &mut C, message: PgWireFrontendMessage) -> PgWireResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + C::Error: Debug, + PgWireError: From<>::Error>, + { + match message { + PgWireFrontendMessage::Startup(ref startup) => { + save_startup_parameters_to_metadata(client, startup); + client.set_state(PgWireConnectionState::AuthenticationInProgress); + + let login_info = LoginInfo::from_client_info(client); + + log::debug!("PG: Login info: {login_info:?}"); + + if login_info.database().is_none() { + return Err(PgError::DatabaseNameRequired.into()); + } + + client + .send(PgWireBackendMessage::Authentication(Authentication::CleartextPassword)) + .await?; + } + PgWireFrontendMessage::PasswordMessageFamily(pwd) => { + let params = client.metadata(); + let param = |param: &str| { + params + .get(param) + .map(String::from) + .ok_or_else(|| PgError::MetadataError(anyhow::anyhow!("Missing parameter: {}", param))) + }; + + let user = param(METADATA_USER)?; + let database = param(METADATA_DATABASE)?; + let pwd = pwd.into_password()?; + if let Ok(application_name) = param("application_name") { + log::info!("PG: Connecting to database: {user}@{database}, by {application_name}",); + } else { + log::info!("PG: Connecting to database: {user}@{database}"); + } + + let name = database::NameOrIdentity::Name(DatabaseName(database.clone())); + match response(name.resolve(&self.ctx).await, &database).await { + Ok(identity) => identity, + Err(PgError::Pg(PgWireError::UserError(err))) => { + return close_client(client, *err).await; + } + Err(err) => { + return Err(err.into()); + } + }; + + let auth = match validate_token(&self.ctx, &pwd.password).await { + Ok(claims) => response(SpacetimeAuth::from_claims(&self.ctx, claims), &database).await?, + Err(err) => { + let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string()); + return close_client(client, err).await; + } + }; + + log::info!( + "PG: Connected to database: {user}@{database} using identity `{}`", + auth.identity + ); + + let metadata = Metadata { database, auth }; + self.cached.lock().await.clone_from(&metadata); + finish_authentication(client, &self.parameter_provider).await?; + } + _ => {} + } + Ok(()) + } +} + +#[async_trait] +impl SimpleQueryHandler + for PgSpacetimeDB +{ + async fn do_query<'a, C>(&self, _client: &mut C, query: &'a str) -> PgWireResult>> + where + C: ClientInfo + Unpin + Send + Sync, + { + self.exe_sql(query.to_string()).await + } +} + +#[derive(Clone)] +struct PgSpacetimeDBFactory { + handler: Arc>, +} + +impl PgSpacetimeDBFactory { + pub fn new(ctx: Arc, auth: SpacetimeAuth) -> Self { + let mut parameter_provider = DefaultServerParameterProvider::default(); + parameter_provider.server_version = format!("spacetime {}", spacetimedb_lib_version()); + + Self { + handler: Arc::new(PgSpacetimeDB { + ctx, + cached: Mutex::new(Metadata { + // This is a placeholder, it will be set in the startup handler + database: "".to_string(), + auth, + }), + parameter_provider, + }), + } + } +} + +impl PgWireServerHandlers + for PgSpacetimeDBFactory +{ + type StartupHandler = PgSpacetimeDB; + type SimpleQueryHandler = PgSpacetimeDB; + type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; + type CopyHandler = NoopCopyHandler; + type ErrorHandler = NoopErrorHandler; + + fn simple_query_handler(&self) -> Arc { + self.handler.clone() + } + + fn extended_query_handler(&self) -> Arc { + Arc::new(PlaceholderExtendedQueryHandler) + } + + fn startup_handler(&self) -> Arc { + self.handler.clone() + } + + fn copy_handler(&self) -> Arc { + Arc::new(NoopCopyHandler) + } + + fn error_handler(&self) -> Arc { + Arc::new(NoopErrorHandler) + } +} + +fn setup_tls(_ctx: &T, private_key: &[u8]) -> Result { + let private: PrivateKeyDer = PrivateKeyDer::from_pem_slice(private_key)?; + + let keypair = KeyPair::from_der_and_sign_algo(&private, &rcgen::PKCS_ECDSA_P256_SHA256)?; + + let mut params = CertificateParams::new(vec![ + "localhost".to_string(), + "127.0.0.1".to_string(), + "::1".to_string(), + ])?; + params.distinguished_name = DistinguishedName::new(); + params.distinguished_name.push(DnType::CommonName, "localhost"); + let cert = params.self_signed(&keypair)?; + let cert_der = cert.der().clone(); + + let mut config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert_der], private)?; + + config.alpn_protocols = vec![b"postgresql".to_vec(), b"spacetime".to_vec()]; + + Ok(TlsAcceptor::from(Arc::new(config))) +} + +pub async fn start_pg( + mut shutdown: watch::Receiver<()>, + ctx: Arc, + listen_address: &str, + private_key: &[u8], +) { + let tls_acceptor = Arc::new(setup_tls(&ctx, private_key).unwrap()); + + let auth = SpacetimeAuth::alloc(&ctx).await.unwrap(); + let factory = Arc::new(PgSpacetimeDBFactory::new(ctx, auth)); + + let server_addr = format!("{}:5432", listen_address.split(':').next().unwrap()); + let tcp = TcpListener::bind(server_addr).await.unwrap(); + + log::debug!( + "PG: Starting SpacetimeDB Protocol listening on {}", + tcp.local_addr().unwrap() + ); + loop { + tokio::select! { + accept_result = tcp.accept() => { + match accept_result { + Ok((stream, _addr)) => { + let tls_acceptor_ref = tls_acceptor.clone(); + let factory_ref = factory.clone(); + tokio::spawn(async move { + process_socket(stream, Some(tls_acceptor_ref), factory_ref).await.inspect_err(|err|{ + log::error!("PG: Error processing socket: {err:?}"); + }) + }); + } + Err(e) => { + log::error!("PG: Accept error: {e}"); + } + } + } + _ = shutdown.changed() => { + log::info!("PG: Shutting down PostgreSQL server."); + break; + } + } + } +} diff --git a/crates/sats/src/product_value.rs b/crates/sats/src/product_value.rs index 5db2b7c534d..68aaf6433ce 100644 --- a/crates/sats/src/product_value.rs +++ b/crates/sats/src/product_value.rs @@ -114,6 +114,14 @@ impl ProductValue { }) } + /// Assumes that we are dealing with a [`ProductType::is_special`] value if it has a single element. + pub fn as_special_value_raw(&self) -> Option<&AlgebraicValue> { + match &*self.elements { + [val] => Some(val), + _ => None, + } + } + /// Interprets the value at field of `self` identified by `index` as a `bool`. pub fn field_as_bool(&self, index: usize, named: Option<&'static str>) -> Result { self.extract_field(index, named, |f| f.as_bool().copied()) diff --git a/crates/sats/src/satn.rs b/crates/sats/src/satn.rs index 4462609520e..504f59c4332 100644 --- a/crates/sats/src/satn.rs +++ b/crates/sats/src/satn.rs @@ -5,7 +5,7 @@ use crate::{i256, u256, AlgebraicValue, WithTypespace}; use crate::{ser, ProductType, ProductTypeElement}; use core::fmt; use core::fmt::Write as _; -use derive_more::{From, Into}; +use derive_more::{Display, From, Into}; /// An extension trait for [`Serialize`](ser::Serialize) providing formatting methods. pub trait Satn: ser::Serialize { @@ -231,7 +231,7 @@ struct SatnFormatter<'a, 'f> { /// An error occurred during serialization to the SATS data format. #[derive(From, Into)] -struct SatnError(fmt::Error); +pub struct SatnError(fmt::Error); impl ser::Error for SatnError { fn custom(_msg: T) -> Self { @@ -435,7 +435,7 @@ struct PsqlEntryWrapper<'a, 'f, const SEP: char> { } /// Provides the data format for named products for `SQL`. -struct PsqlNamedFormatter<'a, 'f> { +pub struct PsqlNamedFormatter<'a, 'f> { /// The formatter for each element separating elements by a `,`. f: PsqlEntryWrapper<'a, 'f, ','>, /// If is not [Self::is_special] to control if we start with `(` @@ -445,7 +445,7 @@ struct PsqlNamedFormatter<'a, 'f> { } impl<'a, 'f> PsqlNamedFormatter<'a, 'f> { - pub fn new(ty: &'a PsqlType<'a>, f: Writer<'a, 'f>) -> Self { + fn new(ty: &'a PsqlType<'a>, f: Writer<'a, 'f>) -> Self { Self { start: true, f: PsqlEntryWrapper { @@ -472,10 +472,20 @@ impl ser::SerializeNamedProduct for PsqlNamedFormatter<'_, '_> { // We need to check for both the enclosing(`self.f.ty`) type and the inner element(`name`) type. self.use_fmt = self.f.ty.use_fmt(name); let res = self.f.entry.entry(|mut f| { - let PsqlType { tuple, field, idx } = self.f.ty; + let PsqlType { + client, + tuple, + field, + idx, + } = self.f.ty; if !self.use_fmt.is_special() { + let start = match client { + PsqlClient::SpacetimeDB => "(", + PsqlClient::Postgres => "{", + }; if self.start { - write!(f, "(")?; + write!(f, "{start}")?; + self.start = false; } // Format the name or use the index if unnamed. @@ -484,7 +494,11 @@ impl ser::SerializeNamedProduct for PsqlNamedFormatter<'_, '_> { } else { write!(f, "{idx}")?; } - write!(f, " = ")?; + let sep = match client { + PsqlClient::SpacetimeDB => "=", + PsqlClient::Postgres => ":", + }; + write!(f, " {sep} ")?; } //Is a nested product type? let (tuple, field, idx) = if let Some(product) = field.algebraic_type.as_product() { @@ -495,7 +509,12 @@ impl ser::SerializeNamedProduct for PsqlNamedFormatter<'_, '_> { elem.serialize(PsqlFormatter { fmt: SatnFormatter { f }, - ty: &PsqlType { tuple, field, idx }, + ty: &PsqlType { + client: *client, + tuple, + field, + idx, + }, })?; Ok(()) @@ -513,7 +532,11 @@ impl ser::SerializeNamedProduct for PsqlNamedFormatter<'_, '_> { fn end(mut self) -> Result { if !self.use_fmt.is_special() { - write!(self.f.entry.fmt, ")")?; + let end = match self.f.ty.client { + PsqlClient::SpacetimeDB => ")", + PsqlClient::Postgres => "}", + }; + write!(self.f.entry.fmt, "{end}")?; } Ok(()) } @@ -538,8 +561,15 @@ impl ser::SerializeSeqProduct for PsqlSeqFormatter<'_, '_> { } } +/// Which client is used to format the `SQL` output? +#[derive(PartialEq, Copy, Clone, Debug)] +pub enum PsqlClient { + SpacetimeDB, + Postgres, +} + /// How format of the `SQL` output? -#[derive(PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq, Display)] pub enum PsqlPrintFmt { /// Print as `hex` format Hex, @@ -552,46 +582,32 @@ pub enum PsqlPrintFmt { } impl PsqlPrintFmt { - fn is_special(&self) -> bool { + pub fn is_special(&self) -> bool { self != &PsqlPrintFmt::Satn } -} - -/// A wrapper that remember the `header` of the tuple/struct and the current field -#[derive(Debug, Clone)] -pub struct PsqlType<'a> { - /// The header of the tuple/struct - pub tuple: &'a ProductType, - /// The current field - pub field: &'a ProductTypeElement, - /// The index of the field in the tuple/struct - pub idx: usize, -} - -impl PsqlType<'_> { /// Returns if the type is a special type /// /// Is required to check both the enclosing type and the inner element type - fn use_fmt(&self, name: Option<&str>) -> PsqlPrintFmt { - if self.tuple.is_identity() - || self.tuple.is_connection_id() - || self.field.algebraic_type.is_identity() - || self.field.algebraic_type.is_connection_id() + pub fn use_fmt(tuple: &ProductType, field: &ProductTypeElement, name: Option<&str>) -> PsqlPrintFmt { + if tuple.is_identity() + || tuple.is_connection_id() + || field.algebraic_type.is_identity() + || field.algebraic_type.is_connection_id() || name.map(ProductType::is_identity_tag).unwrap_or_default() || name.map(ProductType::is_connection_id_tag).unwrap_or_default() { return PsqlPrintFmt::Hex; }; - if self.tuple.is_timestamp() - || self.field.algebraic_type.is_timestamp() + if tuple.is_timestamp() + || field.algebraic_type.is_timestamp() || name.map(ProductType::is_timestamp_tag).unwrap_or_default() { return PsqlPrintFmt::Timestamp; }; - if self.tuple.is_time_duration() - || self.field.algebraic_type.is_time_duration() + if tuple.is_time_duration() + || field.algebraic_type.is_time_duration() || name.map(ProductType::is_time_duration_tag).unwrap_or_default() { return PsqlPrintFmt::Duration; @@ -601,6 +617,28 @@ impl PsqlType<'_> { } } +/// A wrapper that remember the `header` of the tuple/struct and the current field +#[derive(Debug, Clone)] +pub struct PsqlType<'a> { + /// The client used to format the output + pub client: PsqlClient, + /// The header of the tuple/struct + pub tuple: &'a ProductType, + /// The current field + pub field: &'a ProductTypeElement, + /// The index of the field in the tuple/struct + pub idx: usize, +} + +impl PsqlType<'_> { + /// Returns if the type is a special type + /// + /// Is required to check both the enclosing type and the inner element type + pub fn use_fmt(&self, name: Option<&str>) -> PsqlPrintFmt { + PsqlPrintFmt::use_fmt(self.tuple, self.field, name) + } +} + /// An implementation of [`Serializer`](ser::Serializer) for `SQL` output. struct PsqlFormatter<'a, 'f> { fmt: SatnFormatter<'a, 'f>, @@ -614,8 +652,13 @@ impl<'a, 'f> ser::Serializer for PsqlFormatter<'a, 'f> { type SerializeSeqProduct = PsqlSeqFormatter<'a, 'f>; type SerializeNamedProduct = PsqlNamedFormatter<'a, 'f>; - fn serialize_bool(self, v: bool) -> Result { - self.fmt.serialize_bool(v) + fn serialize_bool(mut self, v: bool) -> Result { + match self.ty.client { + PsqlClient::SpacetimeDB => self.fmt.serialize_bool(v), + PsqlClient::Postgres => { + write!(self.fmt, "{}", if v { "t" } else { "f" }) + } + } } fn serialize_u8(self, v: u8) -> Result { self.fmt.serialize_u8(v) @@ -676,12 +719,20 @@ impl<'a, 'f> ser::Serializer for PsqlFormatter<'a, 'f> { self.fmt.serialize_f64(v) } - fn serialize_str(self, v: &str) -> Result { - self.fmt.serialize_str(v) + fn serialize_str(mut self, v: &str) -> Result { + match self.ty.client { + PsqlClient::SpacetimeDB => self.fmt.serialize_str(v), + PsqlClient::Postgres => write!(self.fmt, "{v}"), + } } - fn serialize_bytes(self, v: &[u8]) -> Result { - self.fmt.serialize_bytes(v) + fn serialize_bytes(mut self, v: &[u8]) -> Result { + match self.ty.client { + PsqlClient::Postgres if self.ty.use_fmt(None) == PsqlPrintFmt::Satn => { + write!(self.fmt, "\\x{}", hex::encode(v)) + } + _ => self.fmt.serialize_bytes(v), + } } fn serialize_array(self, len: usize) -> Result { diff --git a/crates/sats/src/time_duration.rs b/crates/sats/src/time_duration.rs index bd8d4543fbc..17f27ba988f 100644 --- a/crates/sats/src/time_duration.rs +++ b/crates/sats/src/time_duration.rs @@ -82,6 +82,20 @@ impl TimeDuration { pub fn checked_sub(self, other: Self) -> Option { self.to_micros().checked_sub(other.to_micros()).map(Self::from_micros) } + + /// Generate an `iso8601` format string + /// + /// Example: + /// ```rust + /// use std::time::Duration; + /// use spacetimedb_sats::time_duration::TimeDuration; + /// assert_eq!( TimeDuration::from_micros(0).to_iso8601().as_str(), "P0D"); + /// assert_eq!( TimeDuration::from_micros(-1_000_000).to_iso8601().as_str(), "-PT1S"); + /// assert_eq!( TimeDuration::from_duration(Duration::from_secs(60 * 24)).to_iso8601().as_str(), "PT1440S"); + /// ``` + pub fn to_iso8601(self) -> String { + chrono::Duration::microseconds(self.to_micros()).to_string() + } } impl From for TimeDuration { diff --git a/crates/sats/src/timestamp.rs b/crates/sats/src/timestamp.rs index 5da2e73964a..50affcf6094 100644 --- a/crates/sats/src/timestamp.rs +++ b/crates/sats/src/timestamp.rs @@ -171,13 +171,17 @@ impl Timestamp { pub fn checked_sub_duration(&self, duration: Duration) -> Option { self.checked_sub(TimeDuration::from_duration(duration)) } - /// Returns an RFC 3339 and ISO 8601 date and time string such as `1996-12-19T16:39:57-08:00`. - pub fn to_rfc3339(&self) -> anyhow::Result { + + pub fn to_chrono_date_time(&self) -> anyhow::Result> { DateTime::from_timestamp_micros(self.to_micros_since_unix_epoch()) - .map(|t| t.to_rfc3339()) .ok_or_else(|| anyhow::anyhow!("Timestamp with i64 microseconds since Unix epoch overflows DateTime")) .with_context(|| self.to_micros_since_unix_epoch()) } + + /// Returns an RFC 3339 and ISO 8601 date and time string such as `1996-12-19T16:39:57-08:00`. + pub fn to_rfc3339(&self) -> anyhow::Result { + Ok(self.to_chrono_date_time()?.to_rfc3339()) + } } impl Add for Timestamp { diff --git a/crates/standalone/Cargo.toml b/crates/standalone/Cargo.toml index 625f751457c..e719721721d 100644 --- a/crates/standalone/Cargo.toml +++ b/crates/standalone/Cargo.toml @@ -27,6 +27,7 @@ spacetimedb-core.workspace = true spacetimedb-datastore.workspace = true spacetimedb-lib.workspace = true spacetimedb-paths.workspace = true +spacetimedb-pg.workspace = true spacetimedb-table.workspace = true spacetimedb-schema.workspace = true diff --git a/crates/standalone/src/subcommands/start.rs b/crates/standalone/src/subcommands/start.rs index 10a52830053..9823a7095db 100644 --- a/crates/standalone/src/subcommands/start.rs +++ b/crates/standalone/src/subcommands/start.rs @@ -1,3 +1,4 @@ +use spacetimedb_pg::pg_server; use std::sync::Arc; use crate::{StandaloneEnv, StandaloneOptions}; @@ -10,9 +11,11 @@ use spacetimedb::db::{self, Storage}; use spacetimedb::startup::{self, TracingOptions}; use spacetimedb::util::jobs::JobCores; use spacetimedb::worker_metrics; +use spacetimedb_client_api::auth::JwtAuthProvider; use spacetimedb_client_api::routes::database::DatabaseRoutes; use spacetimedb_client_api::routes::router; use spacetimedb_client_api::routes::subscribe::WebSocketOptions; +use spacetimedb_client_api::NodeDelegate; use spacetimedb_paths::cli::{PrivKeyPath, PubKeyPath}; use spacetimedb_paths::server::{ConfigToml, ServerDataDir}; use tokio::net::TcpListener; @@ -176,12 +179,24 @@ pub async fn exec(args: &ArgMatches, db_cores: JobCores) -> anyhow::Result<()> { db_routes.root_post = db_routes.root_post.layer(DefaultBodyLimit::disable()); db_routes.db_put = db_routes.db_put.layer(DefaultBodyLimit::disable()); let extra = axum::Router::new().nest("/health", spacetimedb_client_api::routes::health::router()); - let service = router(&ctx, db_routes, extra).with_state(ctx); + let service = router(&ctx, db_routes, extra).with_state(ctx.clone()); let tcp = TcpListener::bind(listen_addr).await?; socket2::SockRef::from(&tcp).set_nodelay(true)?; - log::debug!("Starting SpacetimeDB listening on {}", tcp.local_addr().unwrap()); - axum::serve(tcp, service).await?; + log::debug!("Starting SpacetimeDB listening on {}", tcp.local_addr()?); + let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(()); + let private_key = ctx.jwt_auth_provider().private_key_bytes().to_vec(); + tokio::select! { + _ = pg_server::start_pg(shutdown_rx.clone(), ctx, listen_addr, &private_key) => {}, + _ = axum::serve(tcp, service).with_graceful_shutdown(async move { + shutdown_rx.changed().await.ok(); + }) => {}, + _ = tokio::signal::ctrl_c() => { + println!("Shutting down servers..."); + let _ = shutdown_tx.send(()); // Notify all tasks + } + } + Ok(()) } diff --git a/docker-compose.yml b/docker-compose.yml index 8e2dd39fb87..a31ed13d89a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -25,6 +25,8 @@ services: - /stdb ports: - "3000:3000" + # Postgres + - "5432:5432" # Tracy - "8086:8086" entrypoint: cargo watch -i flamegraphs -i log.conf --why -C crates/standalone -x 'run start --data-dir=/stdb/data --jwt-pub-key-path=/etc/spacetimedb/id_ecdsa.pub --jwt-priv-key-path=/etc/spacetimedb/id_ecdsa' diff --git a/smoketests/tests/pg_wire.py b/smoketests/tests/pg_wire.py new file mode 100644 index 00000000000..cad3f7b02e9 --- /dev/null +++ b/smoketests/tests/pg_wire.py @@ -0,0 +1,169 @@ +from .. import Smoketest +import subprocess +import os +import tomllib + + +def psql(identity: str, sql: str) -> str: + """Call `psql` and execute the given SQL statement.""" + + result = subprocess.run( + ["psql", "port=5432 host=127.0.0.1 user=postgres dbname=quickstart", "--quiet", "-c", sql], + encoding="utf8", + env={**os.environ, "PGPASSWORD": identity}, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + # check=True + ) + + if result.stderr: + raise Exception(result.stderr.strip()) + return result.stdout.strip() + + +class SqlFormat(Smoketest): + AUTOPUBLISH = False + MODULE_CODE = """ +use spacetimedb::sats::{i256, u256}; +use spacetimedb::{ConnectionId, Identity, ReducerContext, Table, Timestamp, TimeDuration}; + +#[derive(Copy, Clone)] +#[spacetimedb::table(name = t_ints, public)] +pub struct TInts { + i8: i8, + i16: i16, + i32: i32, + i64: i64, + i128: i128, + i256: i256, +} + +#[spacetimedb::table(name = t_ints_tuple, public)] +pub struct TIntsTuple { + tuple: TInts, +} + +#[derive(Copy, Clone)] +#[spacetimedb::table(name = t_uints, public)] +pub struct TUints { + u8: u8, + u16: u16, + u32: u32, + u64: u64, + u128: u128, + u256: u256, +} + +#[spacetimedb::table(name = t_uints_tuple, public)] +pub struct TUintsTuple { + tuple: TUints, +} + +#[derive(Clone)] +#[spacetimedb::table(name = t_others, public)] +pub struct TOthers { + bool: bool, + f32: f32, + f64: f64, + str: String, + bytes: Vec, + identity: Identity, + connection_id: ConnectionId, + timestamp: Timestamp, + duration: TimeDuration, +} + +#[spacetimedb::table(name = t_others_tuple, public)] +pub struct TOthersTuple { + tuple: TOthers +} + +#[spacetimedb::reducer] +pub fn test(ctx: &ReducerContext) { + let tuple = TInts { + i8: -25, + i16: -3224, + i32: -23443, + i64: -2344353, + i128: -234434897853, + i256: (-234434897853i128).into(), + }; + ctx.db.t_ints().insert(tuple); + ctx.db.t_ints_tuple().insert(TIntsTuple { tuple }); + + let tuple = TUints { + u8: 105, + u16: 1050, + u32: 83892, + u64: 48937498, + u128: 4378528978889, + u256: 4378528978889u128.into(), + }; + ctx.db.t_uints().insert(tuple); + ctx.db.t_uints_tuple().insert(TUintsTuple { tuple }); + + let tuple = TOthers { + bool: true, + f32: 594806.58906, + f64: -3454353.345389043278459, + str: "This is spacetimedb".to_string(), + bytes: vec!(1, 2, 3, 4, 5, 6, 7), + identity: Identity::ONE, + connection_id: ConnectionId::ZERO, + timestamp: Timestamp::UNIX_EPOCH, + duration: TimeDuration::from_micros(1000 * 10000), + }; + ctx.db.t_others().insert(tuple.clone()); + ctx.db.t_others_tuple().insert(TOthersTuple { tuple }); +} +""" + + def assertSql(self, token: str, sql: str, expected): + self.maxDiff = None + sql_out = psql(token, sql) + sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()]) + expected = "\n".join([line.rstrip() for line in expected.splitlines()]) + print(sql_out) + self.assertMultiLineEqual(sql_out, expected) + + def test_sql_format(self): + """This test is designed to test calling `psql` to execute SQL statements""" + with open(self.config_path, "rb") as f: + config = tomllib.load(f) + token = config['spacetimedb_token'] + self.publish_module("quickstart", clear=False) + + self.call("test") + + self.assertSql(token, "SELECT * FROM t_ints", """\ +i8 | i16 | i32 | i64 | i128 | i256 +-----+-------+--------+----------+---------------+--------------- + -25 | -3224 | -23443 | -2344353 | -234434897853 | -234434897853 +(1 row)""") + self.assertSql(token, "SELECT * FROM t_ints_tuple", """\ +tuple +----------------------------------------------------------- + [-25,-3224,-23443,-2344353,-234434897853,"-0x369568a7bd"] +(1 row)""") + self.assertSql(token, "SELECT * FROM t_uints", """\ +u8 | u16 | u32 | u64 | u128 | u256 +-----+------+-------+----------+---------------+--------------- + 105 | 1050 | 83892 | 48937498 | 4378528978889 | 4378528978889 +(1 row)""") + self.assertSql(token, "SELECT * FROM t_uints_tuple", """\ +tuple +--------------------------------------------------------- + [105,1050,83892,48937498,4378528978889,"0x3fb74aa17c9"] +(1 row)""") + self.assertSql(token, "SELECT * FROM t_others", """\ +bool | f32 | f64 | str | bytes | identity | connection_id | timestamp | duration +------+-----------+---------------------+---------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+---------- + t | 594806.56 | -3454353.3453890434 | This is spacetimedb | \\x01020304050607 | \\x0000000000000000000000000000000000000000000000000000000000000001 | \\x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | PT10S +(1 row)""") + # TODO: Uncomment when tuple support is added + self.assertSql(token, "SELECT * FROM t_others_tuple", """\ +tuple +-------------------------------------------------------------------------------------------------------- + [true,594806.56,-3454353.3453890434,"This is spacetimedb","01020304050607",["0x1"],[0],[0],[10000000]] +(1 row)""") From 24850ae17f46efc6b002ab0356ab21c943184207 Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Mon, 12 May 2025 12:04:03 -0500 Subject: [PATCH 02/10] Route sats to deserialize and create a Typed version for decode special types --- Cargo.lock | 1 - crates/cli/src/subcommands/sql.rs | 16 +- crates/pg/Cargo.toml | 3 +- crates/pg/src/encoder.rs | 250 +++++++++++++ crates/pg/src/lib.rs | 1 + crates/pg/src/pg_server.rs | 180 ++-------- crates/sats/src/satn.rs | 572 +++++++++++++++++------------- crates/sats/src/ser.rs | 27 +- crates/sats/src/ser/impls.rs | 15 +- crates/sats/src/time_duration.rs | 4 +- smoketests/tests/pg_wire.py | 13 +- 11 files changed, 649 insertions(+), 433 deletions(-) create mode 100644 crates/pg/src/encoder.rs diff --git a/Cargo.lock b/Cargo.lock index 130b472e256..0e1e2c4dcbe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6213,7 +6213,6 @@ dependencies = [ "rcgen", "rustls", "rustls-pki-types", - "serde_json", "spacetimedb-client-api", "spacetimedb-client-api-messages", "spacetimedb-lib 1.4.0", diff --git a/crates/cli/src/subcommands/sql.rs b/crates/cli/src/subcommands/sql.rs index 27bf77edaf7..12107b43a44 100644 --- a/crates/cli/src/subcommands/sql.rs +++ b/crates/cli/src/subcommands/sql.rs @@ -482,8 +482,8 @@ Roundtrip time: 1.00ms"#, vec![value], r#" column 0 | column 1 | column 2 | column 3 | column 4 | column 5 -----------+----------+--------------------------------------------------------------------+------------------------------------+---------------------------+----------- - "a" | 0 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#, +----------+----------+--------------------------------------------------------------------+------------------------------------+---------------------------+---------- + "a" | 0 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | P0D"#, ); // Check struct @@ -513,8 +513,8 @@ Roundtrip time: 1.00ms"#, vec![value.clone()], r#" bool | str | bytes | identity | connection_id | timestamp | duration -------+-----------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+----------- - true | "This is spacetimedb" | 0x01020304050607 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#, +------+-----------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+---------- + true | "This is spacetimedb" | 0x01020304050607 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | P0D"#, ); // Check nested struct, tuple... @@ -527,8 +527,8 @@ Roundtrip time: 1.00ms"#, vec![value.clone()], r#" column 0 ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- - (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000)"#, +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = P0D)"#, ); let kind: ProductType = [("tuple", AlgebraicType::product(kind))].into(); @@ -540,8 +540,8 @@ Roundtrip time: 1.00ms"#, vec![value], r#" tuple ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- - (0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000))"#, +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + (col_0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = P0D))"#, ); Ok(()) diff --git a/crates/pg/Cargo.toml b/crates/pg/Cargo.toml index 68e41d746b5..60eb617e760 100644 --- a/crates/pg/Cargo.toml +++ b/crates/pg/Cargo.toml @@ -23,5 +23,4 @@ rcgen.workspace = true rustls.workspace = true thiserror.workspace = true tokio.workspace = true -tokio-rustls.workspace = true -serde_json.workspace = true +tokio-rustls.workspace = true \ No newline at end of file diff --git a/crates/pg/src/encoder.rs b/crates/pg/src/encoder.rs new file mode 100644 index 00000000000..5b4eb8d8062 --- /dev/null +++ b/crates/pg/src/encoder.rs @@ -0,0 +1,250 @@ +use crate::pg_server::PgError; +use pgwire::api::portal::Format; +use pgwire::api::results::{DataRowEncoder, FieldInfo}; +use pgwire::api::Type; +use spacetimedb_lib::sats::satn::{PsqlPrintFmt, PsqlType, TypedWriter}; +use spacetimedb_lib::sats::{satn, ValueWithType}; +use spacetimedb_lib::{ + ser, AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, TimeDuration, Timestamp, +}; +use std::borrow::Cow; +use std::sync::Arc; + +pub(crate) fn row_desc(schema: &ProductType, format: &Format) -> Arc> { + Arc::new( + schema + .elements + .iter() + .enumerate() + .map(|(pos, ty)| { + let field_name = ty.name.clone().map(Into::into).unwrap_or_else(|| format!("col_{pos}")); + let field_type = type_of(schema, ty); + FieldInfo::new(field_name, None, None, field_type, format.format_for(pos)) + }) + .collect(), + ) +} + +pub(crate) fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type { + let format = PsqlPrintFmt::use_fmt(schema, ty, ty.name()); + match &ty.algebraic_type { + AlgebraicType::String => Type::VARCHAR, + AlgebraicType::Bool => Type::BOOL, + AlgebraicType::U8 | AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2, + AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4, + AlgebraicType::I64 => Type::INT8, + AlgebraicType::U32 + | AlgebraicType::U64 + | AlgebraicType::I128 + | AlgebraicType::U128 + | AlgebraicType::I256 + | AlgebraicType::U256 => Type::NUMERIC, + AlgebraicType::F32 => Type::FLOAT4, + AlgebraicType::F64 => Type::FLOAT8, + AlgebraicType::Array(ty) => match *ty.elem_ty { + AlgebraicType::String => Type::VARCHAR_ARRAY, + AlgebraicType::Bool => Type::BOOL_ARRAY, + AlgebraicType::U8 => Type::BYTEA, + AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2_ARRAY, + AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4_ARRAY, + AlgebraicType::I64 => Type::INT8_ARRAY, + AlgebraicType::U32 + | AlgebraicType::U64 + | AlgebraicType::I128 + | AlgebraicType::U128 + | AlgebraicType::I256 + | AlgebraicType::U256 => Type::NUMERIC_ARRAY, + _ => Type::ANYARRAY, + }, + AlgebraicType::Product(_) => match format { + PsqlPrintFmt::Hex => Type::BYTEA_ARRAY, + PsqlPrintFmt::Timestamp => Type::TIMESTAMP, + PsqlPrintFmt::Duration => Type::INTERVAL, + _ => Type::JSON, + }, + AlgebraicType::Sum(sum) if sum.is_simple_enum() => Type::ANYENUM, + _ => Type::UNKNOWN, + } +} + +impl ser::Error for PgError { + fn custom(msg: T) -> Self { + PgError::Other(anyhow::anyhow!(msg.to_string())) + } +} + +pub(crate) struct PsqlFormatter<'a> { + pub(crate) encoder: &'a mut DataRowEncoder, +} + +impl TypedWriter for PsqlFormatter<'_> { + type Error = PgError; + + fn write(&mut self, value: W) -> Result<(), Self::Error> { + self.encoder.encode_field(&value.to_string())?; + Ok(()) + } + + fn write_bool(&mut self, value: bool) -> Result<(), Self::Error> { + self.encoder.encode_field(&value)?; + Ok(()) + } + + fn write_string(&mut self, value: &str) -> Result<(), Self::Error> { + self.encoder.encode_field(&value)?; + Ok(()) + } + + fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> { + self.encoder.encode_field(&value)?; + Ok(()) + } + + fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error> { + self.encoder.encode_field(&value)?; + Ok(()) + } + + fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error> { + self.encoder.encode_field(&value.to_rfc3339()?)?; + Ok(()) + } + + fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error> { + self.encoder.encode_field(&value.to_iso8601())?; + Ok(()) + } + + fn write_alt_record( + &mut self, + ty: &PsqlType, + value: &ValueWithType<'_, ProductValue>, + ) -> Result { + let json = satn::PsqlWrapper { ty: ty.clone(), value }.to_string(); + self.encoder.encode_field(&json)?; + Ok(true) + } + + fn write_record( + &mut self, + _fields: Vec<(Cow, PsqlType, ValueWithType)>, + ) -> Result<(), Self::Error> { + unreachable!("Use `write_alt_record` for records in PSQL format"); + } + + fn write_variant( + &mut self, + _tag: u8, + ty: PsqlType, + _name: Option<&str>, + value: ValueWithType, + ) -> Result<(), Self::Error> { + let json = satn::PsqlWrapper { ty, value }.to_string(); + self.encoder.encode_field(&json)?; + Ok(()) + } +} + +// Tests +#[cfg(test)] +mod tests { + use super::*; + use crate::pg_server::to_rows; + use futures::StreamExt; + use spacetimedb_client_api_messages::http::SqlStmtResult; + use spacetimedb_lib::sats::algebraic_value::Packed; + use spacetimedb_lib::sats::{i256, product, u256, AlgebraicType, ProductType}; + use spacetimedb_lib::{ConnectionId, Identity}; + + async fn run(schema: ProductType, row: ProductValue) -> String { + let header = row_desc(&schema, &Format::UnifiedText); + + let stmt = SqlStmtResult { + schema, + rows: vec![row], + total_duration_micros: 0, + stats: Default::default(), + }; + let mut stream = to_rows(stmt, header).unwrap(); + let mut result = String::new(); + if let Some(row) = stream.next().await { + result = String::from_utf8_lossy(row.unwrap().data.freeze().as_ref()).to_string(); + } + result + } + + #[tokio::test] + async fn test_primitives() { + let schema = ProductType::from([ + AlgebraicType::U8, + AlgebraicType::I8, + AlgebraicType::I16, + AlgebraicType::U16, + AlgebraicType::I32, + AlgebraicType::U32, + AlgebraicType::I64, + AlgebraicType::U64, + AlgebraicType::I128, + AlgebraicType::U128, + AlgebraicType::I256, + AlgebraicType::U256, + AlgebraicType::F32, + AlgebraicType::F64, + AlgebraicType::String, + AlgebraicType::Bool, + ]); + let value = product![ + 1u8, + -1i8, + -2i16, + 3u16, + -4i32, + 5u32, + -6i64, + 7u64, + Packed::from(-8i128), + Packed::from(9u128), + i256::from(-10), + u256::from(11u128), + 12.34f32, + 56.78f64, + "test".to_string(), + true, + ]; + + let row = run(schema, value).await; + assert_eq!(row, "\0\0\0\u{1}1\0\0\0\u{2}-1\0\0\0\u{2}-2\0\0\0\u{1}3\0\0\0\u{2}-4\0\0\0\u{1}5\0\0\0\u{2}-6\0\0\0\u{1}7\0\0\0\u{2}-8\0\0\0\u{1}9\0\0\0\u{3}-10\0\0\0\u{2}11\0\0\0\u{5}12.34\0\0\0\u{5}56.78\0\0\0\u{4}test\0\0\0\u{1}t"); + } + + #[tokio::test] + async fn test_enum() { + let schema = ProductType::from([AlgebraicType::option(AlgebraicType::I64)]); + let value = product![ + AlgebraicValue::sum(0, AlgebraicValue::I64(1)), // Some(1) + ]; + + let row = run(schema, value).await; + assert_eq!(row, "\0\0\0\u{1}1"); + } + + #[tokio::test] + async fn test_special_types() { + let schema = ProductType::from([ + AlgebraicType::identity(), + AlgebraicType::connection_id(), + AlgebraicType::time_duration(), + AlgebraicType::timestamp(), + AlgebraicType::bytes(), + ]); + let value = product![ + Identity::ZERO, + ConnectionId::ZERO, + TimeDuration::from_micros(0), + Timestamp::from_micros_since_unix_epoch(1622545800000), + AlgebraicValue::Bytes("test".as_bytes().into()), + ]; + + let row = run(schema, value).await; + assert_eq!(row, "\0\0\0B\\x0000000000000000000000000000000000000000000000000000000000000000\0\0\0\"\\x00000000000000000000000000000000\0\0\0\u{3}P0D\0\0\0\u{1d}1970-01-19T18:42:25.800+00:00\0\0\0\n\\x74657374"); + } +} diff --git a/crates/pg/src/lib.rs b/crates/pg/src/lib.rs index 00efd53b3e9..c4466bbc50d 100644 --- a/crates/pg/src/lib.rs +++ b/crates/pg/src/lib.rs @@ -1 +1,2 @@ +mod encoder; pub mod pg_server; diff --git a/crates/pg/src/pg_server.rs b/crates/pg/src/pg_server.rs index 6e8fdc94574..9c23f724310 100644 --- a/crates/pg/src/pg_server.rs +++ b/crates/pg/src/pg_server.rs @@ -1,6 +1,7 @@ use std::fmt::Debug; use std::sync::Arc; +use crate::encoder::{row_desc, PsqlFormatter}; use async_trait::async_trait; use axum::body::to_bytes; use axum::response::IntoResponse; @@ -15,7 +16,7 @@ use pgwire::api::copy::NoopCopyHandler; use pgwire::api::portal::Format; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response, Tag}; -use pgwire::api::{ClientInfo, Type}; +use pgwire::api::ClientInfo; use pgwire::api::{NoopErrorHandler, METADATA_DATABASE, METADATA_USER}; use pgwire::api::{PgWireConnectionState, PgWireServerHandlers}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; @@ -32,12 +33,10 @@ use spacetimedb_client_api::routes::database::{SqlParams, SqlQueryParams}; use spacetimedb_client_api::{ControlStateReadAccess, ControlStateWriteAccess, NodeDelegate}; use spacetimedb_client_api_messages::http::SqlStmtResult; use spacetimedb_client_api_messages::name::DatabaseName; -use spacetimedb_lib::sats::satn::{PsqlPrintFmt, Satn}; -use spacetimedb_lib::sats::ArrayValue; +use spacetimedb_lib::sats::satn::{PsqlClient, TypedSerializer}; +use spacetimedb_lib::sats::{satn, Serialize, Typespace}; use spacetimedb_lib::version::spacetimedb_lib_version; -use spacetimedb_lib::{ - AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, TimeDuration, Timestamp, -}; +use spacetimedb_lib::ProductValue; use thiserror::Error; use tokio::net::TcpListener; use tokio::sync::{watch, Mutex}; @@ -45,7 +44,7 @@ use tokio_rustls::rustls::ServerConfig; use tokio_rustls::TlsAcceptor; #[derive(Error, Debug)] -enum PgError { +pub(crate) enum PgError { #[error("(metadata) {0}")] MetadataError(anyhow::Error), #[error("(Sql) {0}")] @@ -60,8 +59,6 @@ enum PgError { Pem(#[from] rustls_pki_types::pem::Error), #[error(transparent)] RustTls(#[from] rustls::Error), - #[error("Special type with format {0} is invalid for {1}")] - SpecialTypeInvalid(PsqlPrintFmt, String), #[error(transparent)] Other(#[from] anyhow::Error), } @@ -82,169 +79,31 @@ struct Metadata { auth: SpacetimeAuth, } -fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type { - let format = PsqlPrintFmt::use_fmt(schema, ty, ty.name()); - match &ty.algebraic_type { - AlgebraicType::String => Type::VARCHAR, - AlgebraicType::Bool => Type::BOOL, - AlgebraicType::I8 | AlgebraicType::U8 | AlgebraicType::I16 => Type::INT2, - AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4, - AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8, - AlgebraicType::U64 | AlgebraicType::I128 | AlgebraicType::U128 | AlgebraicType::I256 | AlgebraicType::U256 => { - Type::NUMERIC - } - AlgebraicType::F32 => Type::FLOAT4, - AlgebraicType::F64 => Type::FLOAT8, - AlgebraicType::Array(ty) => match *ty.elem_ty { - AlgebraicType::String => Type::VARCHAR_ARRAY, - AlgebraicType::Bool => Type::BOOL_ARRAY, - AlgebraicType::U8 => Type::BYTEA_ARRAY, - AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2_ARRAY, - AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4_ARRAY, - AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8_ARRAY, - AlgebraicType::U64 - | AlgebraicType::I128 - | AlgebraicType::U128 - | AlgebraicType::I256 - | AlgebraicType::U256 => Type::NUMERIC_ARRAY, - _ => Type::ANYARRAY, - }, - AlgebraicType::Product(_) => match format { - PsqlPrintFmt::Hex => Type::BYTEA_ARRAY, - PsqlPrintFmt::Timestamp => Type::TIMESTAMP, - PsqlPrintFmt::Duration => Type::INTERVAL, - _ => Type::JSON, - }, - x if x.as_sum().map(|x| x.is_simple_enum()).unwrap_or(false) => Type::ANYENUM, - _ => Type::UNKNOWN, - } -} - -fn encode_value( - encoder: &mut DataRowEncoder, - schema: &ProductType, - ty: &ProductTypeElement, - value: &AlgebraicValue, -) -> Result<(), PgError> { - let format = PsqlPrintFmt::use_fmt(schema, ty, ty.name()); - - match value { - AlgebraicValue::Bool(x) => encoder.encode_field(x)?, - AlgebraicValue::I8(x) => encoder.encode_field(x)?, - AlgebraicValue::U8(x) => encoder.encode_field(&(*x as i16))?, - AlgebraicValue::I16(x) => encoder.encode_field(x)?, - AlgebraicValue::U16(x) => encoder.encode_field(&(*x as u32))?, - AlgebraicValue::I32(x) => encoder.encode_field(x)?, - AlgebraicValue::U32(x) => encoder.encode_field(x)?, - AlgebraicValue::I64(x) => encoder.encode_field(&x)?, - AlgebraicValue::U64(x) => encoder.encode_field(&x.to_string())?, - AlgebraicValue::I128(x) => { - let x = x.0; - encoder.encode_field(&(x.to_string()))? - } - AlgebraicValue::U128(x) => { - let x = x.0; - encoder.encode_field(&x.to_string())?; - } - AlgebraicValue::I256(x) => encoder.encode_field(&(x.to_string()))?, - AlgebraicValue::U256(x) => encoder.encode_field(&x.to_string())?, - AlgebraicValue::F32(x) => encoder.encode_field(&x.into_inner())?, - AlgebraicValue::F64(x) => encoder.encode_field(&x.into_inner())?, - AlgebraicValue::String(x) => encoder.encode_field(&x.to_string())?, - AlgebraicValue::Array(x) => match x { - ArrayValue::Bool(x) => { - encoder.encode_field(&x.as_ref())?; - } - ArrayValue::I8(x) => { - encoder.encode_field(&x.as_ref())?; - } - ArrayValue::U8(x) => { - encoder.encode_field(&x.as_ref())?; - } - ArrayValue::I16(x) => { - encoder.encode_field(&x.as_ref())?; - } - ArrayValue::I32(x) => { - encoder.encode_field(&x.as_ref())?; - } - ArrayValue::U32(x) => { - encoder.encode_field(&x.as_ref())?; - } - ArrayValue::I64(x) => { - encoder.encode_field(&x.as_ref())?; - } - ArrayValue::F32(x) => { - let x = x.iter().map(|x| x.into_inner()).collect::>(); - encoder.encode_field(&x)?; - } - ArrayValue::F64(x) => { - let x = x.iter().map(|x| x.into_inner()).collect::>(); - encoder.encode_field(&x)?; - } - ArrayValue::String(x) => { - let x = x.iter().map(|x| x.as_ref()).collect::>(); - encoder.encode_field(&x)?; - } - _ => { - let json = serde_json::to_string(&value).unwrap(); - encoder.encode_field(&json)?; - } - }, - AlgebraicValue::Product(x) => match (format, x.as_special_value_raw()) { - (PsqlPrintFmt::Hex, Some(AlgebraicValue::U128(x))) => encoder.encode_field(&x.0.to_be_bytes())?, - (PsqlPrintFmt::Hex, Some(AlgebraicValue::U256(x))) => encoder.encode_field(&x.to_be_bytes())?, - (PsqlPrintFmt::Timestamp, Some(AlgebraicValue::I64(x))) => { - encoder.encode_field(&Timestamp::from_micros_since_unix_epoch(*x).to_rfc3339()?)? - } - (PsqlPrintFmt::Duration, Some(AlgebraicValue::I64(x))) => { - encoder.encode_field(&TimeDuration::from_micros(*x).to_iso8601())? - } - (PsqlPrintFmt::Satn, Some(..)) - | (PsqlPrintFmt::Hex | PsqlPrintFmt::Timestamp | PsqlPrintFmt::Duration, _) => { - return Err(PgError::SpecialTypeInvalid(format, value.to_satn())) - } - (PsqlPrintFmt::Satn, None) => { - let json = serde_json::to_string(&value).unwrap(); - - encoder.encode_field(&json)? - } - }, - x => encoder.encode_field(&x.to_satn())?, - } - - Ok(()) -} - -fn to_rows( +pub(crate) fn to_rows( stmt: SqlStmtResult, header: Arc>, ) -> Result>, PgError> { let mut results = Vec::with_capacity(stmt.rows.len()); + let ty = Typespace::EMPTY.with_type(&stmt.schema); for row in stmt.rows { let mut encoder = DataRowEncoder::new(header.clone()); - for (idx, ty) in stmt.schema.elements.iter().enumerate() { - let value = row.get_field(idx, None).unwrap(); - - encode_value(&mut encoder, &stmt.schema, ty, value)?; + for (idx, value) in ty.with_values(&row).enumerate() { + let ty = satn::PsqlType { + client: PsqlClient::Postgres, + tuple: ty.ty(), + field: &ty.ty().elements[idx], + idx, + }; + let mut fmt = PsqlFormatter { encoder: &mut encoder }; + value.serialize(TypedSerializer { ty: &ty, f: &mut fmt })?; } results.push(encoder.finish()); } Ok(stream::iter(results)) } -fn row_desc_from_stmt(stmt: &SqlStmtResult, format: &Format) -> Vec { - let mut field_info = Vec::with_capacity(stmt.schema.elements.len()); - for (idx, ty) in stmt.schema.elements.iter().enumerate() { - let field_name = ty.name.clone().map(Into::into).unwrap_or_else(|| format!("col {idx}")); - let field_type = type_of(&stmt.schema, ty); - let field_desc = FieldInfo::new(field_name, None, None, field_type, format.format_for(idx)); - field_info.push(field_desc); - } - field_info -} - fn stats(stmt: &SqlStmtResult) -> String { let mut info = Vec::new(); if stmt.stats.rows_inserted != 0 { @@ -277,6 +136,7 @@ async fn response(res: axum::response::Result, database: &str) -> Result { let res = err.into_response(); if res.status() == StatusCode::NOT_FOUND { + log::error!("PG: Database not found: {database}"); return Err(PgWireError::UserError(Box::new(ErrorInfo::new( "FATAL".to_string(), "3D000".to_string(), @@ -288,6 +148,7 @@ async fn response(res: axum::response::Result, database: &str) -> Result PgSpace let mut result = Vec::with_capacity(sql.len()); for sql_result in sql { - let header = Arc::new(row_desc_from_stmt(&sql_result, &Format::UnifiedText)); + let header = row_desc(&sql_result.schema, &Format::UnifiedText); let tag = Tag::new(&stats(&sql_result)); @@ -418,6 +279,7 @@ impl response(SpacetimeAuth::from_claims(&self.ctx, claims), &database).await?, Err(err) => { + log::error!("PG: Authentication failed for user {user} on database {database}: {err}"); let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string()); return close_client(client, err).await; } diff --git a/crates/sats/src/satn.rs b/crates/sats/src/satn.rs index 504f59c4332..140fa0af4f5 100644 --- a/crates/sats/src/satn.rs +++ b/crates/sats/src/satn.rs @@ -1,13 +1,14 @@ -use crate::de::DeserializeSeed; use crate::time_duration::TimeDuration; use crate::timestamp::Timestamp; -use crate::{i256, u256, AlgebraicValue, WithTypespace}; +use crate::{i256, u256, AlgebraicValue, ProductValue, Serialize, SumValue, ValueWithType}; use crate::{ser, ProductType, ProductTypeElement}; use core::fmt; use core::fmt::Write as _; use derive_more::{Display, From, Into}; +use std::borrow::Cow; +use std::marker::PhantomData; -/// An extension trait for [`Serialize`](ser::Serialize) providing formatting methods. +/// An extension trait for [`Serialize`] providing formatting methods. pub trait Satn: ser::Serialize { /// Formats the value using the SATN data format into the formatter `f`. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -18,9 +19,12 @@ pub trait Satn: ser::Serialize { /// Formats the value using the postgres SATN(PsqlFormatter { f }, /* PsqlType */) formatter `f`. fn fmt_psql(&self, f: &mut fmt::Formatter, ty: &PsqlType<'_>) -> fmt::Result { Writer::with(f, |f| { - self.serialize(PsqlFormatter { - fmt: SatnFormatter { f }, + self.serialize(TypedSerializer { ty, + f: &mut SqlFormatter { + fmt: SatnFormatter { f }, + ty, + }, }) })?; Ok(()) @@ -229,6 +233,27 @@ struct SatnFormatter<'a, 'f> { f: Writer<'a, 'f>, } +impl SatnFormatter<'_, '_> { + fn ser_variant( + &mut self, + _tag: u8, + name: Option<&str>, + value: &T, + ) -> Result<(), SatnError> { + write!(self, "(")?; + EntryWrapper::<','>::new(self.f.as_mut()).entry(|mut f| { + if let Some(name) = name { + write!(f, "{name}")?; + } + write!(f, " = ")?; + value.serialize(SatnFormatter { f })?; + Ok(()) + })?; + write!(self, ")")?; + + Ok(()) + } +} /// An error occurred during serialization to the SATS data format. #[derive(From, Into)] pub struct SatnError(fmt::Error); @@ -331,20 +356,11 @@ impl<'a, 'f> ser::Serializer for SatnFormatter<'a, 'f> { fn serialize_variant( mut self, - _tag: u8, + tag: u8, name: Option<&str>, value: &T, ) -> Result { - write!(self, "(")?; - EntryWrapper::<','>::new(self.f.as_mut()).entry(|mut f| { - if let Some(name) = name { - write!(f, "{name}")?; - } - write!(f, " = ")?; - value.serialize(SatnFormatter { f })?; - Ok(()) - })?; - write!(self, ")") + self.ser_variant(tag, name, value) } } @@ -427,140 +443,6 @@ impl ser::SerializeNamedProduct for NamedFormatter<'_, '_> { } } -struct PsqlEntryWrapper<'a, 'f, const SEP: char> { - entry: EntryWrapper<'a, 'f, SEP>, - /// The index of the element. - idx: usize, - ty: &'a PsqlType<'a>, -} - -/// Provides the data format for named products for `SQL`. -pub struct PsqlNamedFormatter<'a, 'f> { - /// The formatter for each element separating elements by a `,`. - f: PsqlEntryWrapper<'a, 'f, ','>, - /// If is not [Self::is_special] to control if we start with `(` - start: bool, - /// Remember what format we are using - use_fmt: PsqlPrintFmt, -} - -impl<'a, 'f> PsqlNamedFormatter<'a, 'f> { - fn new(ty: &'a PsqlType<'a>, f: Writer<'a, 'f>) -> Self { - Self { - start: true, - f: PsqlEntryWrapper { - entry: EntryWrapper::new(f), - idx: 0, - ty, - }, - // Will set later - use_fmt: PsqlPrintFmt::Satn, - } - } -} - -impl ser::SerializeNamedProduct for PsqlNamedFormatter<'_, '_> { - type Ok = (); - type Error = SatnError; - - fn serialize_element( - &mut self, - name: Option<&str>, - elem: &T, - ) -> Result<(), Self::Error> { - // For binary data & special types, output in `hex` format and skip the tagging of each value - // We need to check for both the enclosing(`self.f.ty`) type and the inner element(`name`) type. - self.use_fmt = self.f.ty.use_fmt(name); - let res = self.f.entry.entry(|mut f| { - let PsqlType { - client, - tuple, - field, - idx, - } = self.f.ty; - if !self.use_fmt.is_special() { - let start = match client { - PsqlClient::SpacetimeDB => "(", - PsqlClient::Postgres => "{", - }; - if self.start { - write!(f, "{start}")?; - - self.start = false; - } - // Format the name or use the index if unnamed. - if let Some(name) = name { - write!(f, "{name}")?; - } else { - write!(f, "{idx}")?; - } - let sep = match client { - PsqlClient::SpacetimeDB => "=", - PsqlClient::Postgres => ":", - }; - write!(f, " {sep} ")?; - } - //Is a nested product type? - let (tuple, field, idx) = if let Some(product) = field.algebraic_type.as_product() { - (product, &product.elements[self.f.idx], self.f.idx) - } else { - (*tuple, *field, *idx) - }; - - elem.serialize(PsqlFormatter { - fmt: SatnFormatter { f }, - ty: &PsqlType { - client: *client, - tuple, - field, - idx, - }, - })?; - - Ok(()) - }); - - // Advance to the next field. - if !self.use_fmt.is_special() { - self.f.idx += 1; - } - - res?; - - Ok(()) - } - - fn end(mut self) -> Result { - if !self.use_fmt.is_special() { - let end = match self.f.ty.client { - PsqlClient::SpacetimeDB => ")", - PsqlClient::Postgres => "}", - }; - write!(self.f.entry.fmt, "{end}")?; - } - Ok(()) - } -} - -/// Provides the data format for unnamed products for `SQL`. -struct PsqlSeqFormatter<'a, 'f> { - /// Delegates to the named format. - inner: PsqlNamedFormatter<'a, 'f>, -} - -impl ser::SerializeSeqProduct for PsqlSeqFormatter<'_, '_> { - type Ok = (); - type Error = SatnError; - - fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { - ser::SerializeNamedProduct::serialize_element(&mut self.inner, None, elem) - } - - fn end(self) -> Result { - ser::SerializeNamedProduct::end(self.inner) - } -} - /// Which client is used to format the `SQL` output? #[derive(PartialEq, Copy, Clone, Debug)] pub enum PsqlClient { @@ -634,157 +516,363 @@ impl PsqlType<'_> { /// Returns if the type is a special type /// /// Is required to check both the enclosing type and the inner element type - pub fn use_fmt(&self, name: Option<&str>) -> PsqlPrintFmt { - PsqlPrintFmt::use_fmt(self.tuple, self.field, name) + pub fn use_fmt(&self) -> PsqlPrintFmt { + PsqlPrintFmt::use_fmt(self.tuple, self.field, None) } } /// An implementation of [`Serializer`](ser::Serializer) for `SQL` output. -struct PsqlFormatter<'a, 'f> { +struct SqlFormatter<'a, 'f> { fmt: SatnFormatter<'a, 'f>, ty: &'a PsqlType<'a>, } -impl<'a, 'f> ser::Serializer for PsqlFormatter<'a, 'f> { +/// A trait for writing values, after the special types has been determined. +/// +/// This is used to write values that could have different representations depending on the output format, +/// as defined by [`PsqlClient`] and [`PsqlPrintFmt`]. +pub trait TypedWriter { + type Error: ser::Error; + + /// Writes a value using [`ser::Serializer`] + fn write(&mut self, value: W) -> Result<(), Self::Error>; + + // Values that need special handling: + + fn write_bool(&mut self, value: bool) -> Result<(), Self::Error>; + fn write_string(&mut self, value: &str) -> Result<(), Self::Error>; + fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error>; + fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error>; + fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error>; + fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error>; + /// Writes a value as an alternative record format, e.g., for use `JSON` inside `SQL`. + fn write_alt_record( + &mut self, + _ty: &PsqlType, + _value: &ValueWithType<'_, ProductValue>, + ) -> Result { + Ok(false) + } + + fn write_record( + &mut self, + fields: Vec<(Cow, PsqlType, ValueWithType)>, + ) -> Result<(), Self::Error>; + + fn write_variant( + &mut self, + tag: u8, + ty: PsqlType, + name: Option<&str>, + value: ValueWithType, + ) -> Result<(), Self::Error>; +} + +/// A formatter for arrays that uses the `TypedWriter` trait to write elements. +pub struct TypedArrayFormatter<'a, 'f, F> { + ty: &'a PsqlType<'a>, + f: &'f mut F, +} + +impl ser::SerializeArray for TypedArrayFormatter<'_, '_, F> { type Ok = (); - type Error = SatnError; - type SerializeArray = ArrayFormatter<'a, 'f>; - type SerializeSeqProduct = PsqlSeqFormatter<'a, 'f>; - type SerializeNamedProduct = PsqlNamedFormatter<'a, 'f>; + type Error = F::Error; - fn serialize_bool(mut self, v: bool) -> Result { - match self.ty.client { - PsqlClient::SpacetimeDB => self.fmt.serialize_bool(v), - PsqlClient::Postgres => { - write!(self.fmt, "{}", if v { "t" } else { "f" }) - } - } + fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { + elem.serialize(TypedSerializer { ty: self.ty, f: self.f })?; + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +/// A formatter for sequences that uses the `TypedWriter` trait to write elements. +pub struct TypedSeqFormatter<'a, 'f, F> { + ty: &'a PsqlType<'a>, + f: &'f mut F, +} + +impl ser::SerializeSeqProduct for TypedSeqFormatter<'_, '_, F> { + type Ok = (); + type Error = F::Error; + + fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { + elem.serialize(TypedSerializer { ty: self.ty, f: self.f })?; + Ok(()) + } + + fn end(self) -> Result { + Ok(()) } +} + +/// A formatter for named products that uses the `TypedWriter` trait to write elements. +pub struct TypedNamedProductFormatter { + f: PhantomData, +} + +impl ser::SerializeNamedProduct for TypedNamedProductFormatter { + type Ok = (); + type Error = F::Error; + + fn serialize_element( + &mut self, + _name: Option<&str>, + _elem: &T, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +/// A serializer that uses the `TypedWriter` trait to serialize values +pub struct TypedSerializer<'a, 'f, F> { + pub ty: &'a PsqlType<'a>, + pub f: &'f mut F, +} + +impl<'a, 'f, F: TypedWriter> ser::Serializer for TypedSerializer<'a, 'f, F> { + type Ok = (); + type Error = F::Error; + type SerializeArray = TypedArrayFormatter<'a, 'f, F>; + type SerializeSeqProduct = TypedSeqFormatter<'a, 'f, F>; + type SerializeNamedProduct = TypedNamedProductFormatter; + + fn serialize_bool(self, v: bool) -> Result { + self.f.write_bool(v) + } + fn serialize_u8(self, v: u8) -> Result { - self.fmt.serialize_u8(v) + self.f.write(v) } + fn serialize_u16(self, v: u16) -> Result { - self.fmt.serialize_u16(v) + self.f.write(v) } + fn serialize_u32(self, v: u32) -> Result { - self.fmt.serialize_u32(v) + self.f.write(v) } + fn serialize_u64(self, v: u64) -> Result { - self.fmt.serialize_u64(v) + self.f.write(v) } + fn serialize_u128(self, v: u128) -> Result { - match self.ty.use_fmt(None) { - PsqlPrintFmt::Hex => self.serialize_bytes(&v.to_be_bytes()), - _ => self.fmt.serialize_u128(v), + match self.ty.use_fmt() { + PsqlPrintFmt::Hex => self.f.write_hex(&v.to_be_bytes()), + _ => self.f.write(v), } } + fn serialize_u256(self, v: u256) -> Result { - match self.ty.use_fmt(None) { - PsqlPrintFmt::Hex => self.serialize_bytes(&v.to_be_bytes()), - _ => self.fmt.serialize_u256(v), + match self.ty.use_fmt() { + PsqlPrintFmt::Hex => self.f.write_hex(&v.to_be_bytes()), + _ => self.f.write(v), } } + fn serialize_i8(self, v: i8) -> Result { - self.fmt.serialize_i8(v) + self.f.write(v) } + fn serialize_i16(self, v: i16) -> Result { - self.fmt.serialize_i16(v) + self.f.write(v) } + fn serialize_i32(self, v: i32) -> Result { - self.fmt.serialize_i32(v) + self.f.write(v) } - fn serialize_i64(mut self, v: i64) -> Result { - match self.ty.use_fmt(None) { - PsqlPrintFmt::Duration => { - write!(self.fmt, "{}", TimeDuration::from_micros(v))?; - Ok(()) - } - PsqlPrintFmt::Timestamp => { - write!(self.fmt, "{}", Timestamp::from_micros_since_unix_epoch(v))?; - Ok(()) - } - _ => self.fmt.serialize_i64(v), + + fn serialize_i64(self, v: i64) -> Result { + match self.ty.use_fmt() { + PsqlPrintFmt::Duration => self.f.write_duration(TimeDuration::from_micros(v)), + PsqlPrintFmt::Timestamp => self.f.write_timestamp(Timestamp::from_micros_since_unix_epoch(v)), + _ => self.f.write(v), } } + fn serialize_i128(self, v: i128) -> Result { - self.fmt.serialize_i128(v) + self.f.write(v) } + fn serialize_i256(self, v: i256) -> Result { - self.fmt.serialize_i256(v) + self.f.write(v) } + fn serialize_f32(self, v: f32) -> Result { - self.fmt.serialize_f32(v) + self.f.write(v) } + fn serialize_f64(self, v: f64) -> Result { - self.fmt.serialize_f64(v) + self.f.write(v) } - fn serialize_str(mut self, v: &str) -> Result { - match self.ty.client { - PsqlClient::SpacetimeDB => self.fmt.serialize_str(v), - PsqlClient::Postgres => write!(self.fmt, "{v}"), - } + fn serialize_str(self, v: &str) -> Result { + self.f.write_string(v) } - fn serialize_bytes(mut self, v: &[u8]) -> Result { - match self.ty.client { - PsqlClient::Postgres if self.ty.use_fmt(None) == PsqlPrintFmt::Satn => { - write!(self.fmt, "\\x{}", hex::encode(v)) - } - _ => self.fmt.serialize_bytes(v), + fn serialize_bytes(self, v: &[u8]) -> Result { + if self.ty.use_fmt() == PsqlPrintFmt::Satn { + self.f.write_hex(v) + } else { + self.f.write_bytes(v) } } - fn serialize_array(self, len: usize) -> Result { - self.fmt.serialize_array(len) + fn serialize_array(self, _len: usize) -> Result { + Ok(TypedArrayFormatter { ty: self.ty, f: self.f }) } - fn serialize_seq_product(self, len: usize) -> Result { - Ok(PsqlSeqFormatter { - inner: self.serialize_named_product(len)?, - }) + fn serialize_seq_product(self, _len: usize) -> Result { + Ok(TypedSeqFormatter { ty: self.ty, f: self.f }) } fn serialize_named_product(self, _len: usize) -> Result { - Ok(PsqlNamedFormatter::new(self.ty, self.fmt.f)) + unreachable!("This should never be called, use `serialize_named_product_raw` instead."); + } + + fn serialize_named_product_raw(self, value: &ValueWithType<'_, ProductValue>) -> Result { + let val = &value.val.elements; + assert_eq!(val.len(), value.ty().elements.len()); + // If the value is a special type, we can write it directly + if self.ty.use_fmt().is_special() { + // Is a nested product type? + // We need to check for both the enclosing(`self.ty`) type and the inner element type. + let (tuple, field) = if let Some(product) = self.ty.field.algebraic_type.as_product() { + (product, &product.elements[0]) + } else { + (self.ty.tuple, self.ty.field) + }; + return value.val.serialize(TypedSerializer { + ty: &PsqlType { + client: self.ty.client, + tuple, + field, + idx: self.ty.idx, + }, + f: self.f, + }); + } + // Allow to switch to an alternative record format, for example to write a `JSON` record. + if self.f.write_alt_record(self.ty, value)? { + return Ok(()); + } + let mut record = Vec::with_capacity(val.len()); + + for (idx, (val, field)) in val.iter().zip(&*value.ty().elements).enumerate() { + let ty = PsqlType { + client: self.ty.client, + tuple: value.ty(), + field, + idx, + }; + record.push(( + field + .name() + .map(Cow::from) + .unwrap_or_else(|| Cow::from(format!("col_{idx}"))), + ty, + value.with(&field.algebraic_type, val), + )); + } + self.f.write_record(record) + } + + fn serialize_variant_raw(self, sum: &ValueWithType<'_, SumValue>) -> Result { + let sv = sum.value(); + let (tag, val) = (sv.tag, &*sv.value); + let var_ty = &sum.ty().variants[tag as usize]; // Extract the variant type by tag. + let product = ProductType::from([var_ty.algebraic_type.clone()]); + let ty = PsqlType { + client: self.ty.client, + tuple: &product, + field: &product.elements[0], + idx: 0, + }; + self.f + .write_variant(tag, ty, var_ty.name(), sum.with(&var_ty.algebraic_type, val)) } - fn serialize_variant( + fn serialize_variant( self, - tag: u8, - name: Option<&str>, - value: &T, + _tag: u8, + _name: Option<&str>, + _value: &T, ) -> Result { - self.fmt.serialize_variant(tag, name, value) + unreachable!("Use `serialize_variant_raw` instead."); + } +} + +impl TypedWriter for SqlFormatter<'_, '_> { + type Error = SatnError; + + fn write(&mut self, value: W) -> Result<(), Self::Error> { + write!(self.fmt, "{value}") } - unsafe fn serialize_bsatn(self, ty: &Ty, bsatn: &[u8]) -> Result - where - for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into>, - { - // SAFETY: Forward caller requirements of this method to that we are calling. - unsafe { self.fmt.serialize_bsatn(ty, bsatn) } + fn write_bool(&mut self, value: bool) -> Result<(), Self::Error> { + write!(self.fmt, "{value}") } - unsafe fn serialize_bsatn_in_chunks<'c, Ty, I: Clone + Iterator>( - self, - ty: &Ty, - total_bsatn_len: usize, - bsatn: I, - ) -> Result - where - for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into>, - { - // SAFETY: Forward caller requirements of this method to that we are calling. - unsafe { self.fmt.serialize_bsatn_in_chunks(ty, total_bsatn_len, bsatn) } - } - - unsafe fn serialize_str_in_chunks<'c, I: Clone + Iterator>( - self, - total_len: usize, - string: I, - ) -> Result { - // SAFETY: Forward caller requirements of this method to that we are calling. - unsafe { self.fmt.serialize_str_in_chunks(total_len, string) } + fn write_string(&mut self, value: &str) -> Result<(), Self::Error> { + write!(self.fmt, "\"{value}\"") + } + + fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> { + write!(self.fmt, "0x{}", hex::encode(value)) + } + + fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error> { + write!(self.fmt, "0x{}", hex::encode(value)) + } + + fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error> { + write!(self.fmt, "{}", value.to_rfc3339().unwrap()) + } + + fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error> { + write!(self.fmt, "{}", value.to_iso8601()) + } + + fn write_record( + &mut self, + fields: Vec<(Cow, PsqlType<'_>, ValueWithType)>, + ) -> Result<(), Self::Error> { + let (start, sep, end, quote) = match self.ty.client { + PsqlClient::SpacetimeDB => ("(", " =", ")", ""), + PsqlClient::Postgres => ("{", ":", "}", "\""), + }; + write!(self.fmt, "{start}")?; + for (idx, (name, ty, value)) in fields.into_iter().enumerate() { + if idx > 0 { + write!(self.fmt, ", ")?; + } + write!(self.fmt, "{quote}{name}{quote}{sep} ")?; + + // Serialize the value + value.serialize(TypedSerializer { ty: &ty, f: self })?; + } + write!(self.fmt, "{end}")?; + Ok(()) + } + + fn write_variant( + &mut self, + tag: u8, + ty: PsqlType, + name: Option<&str>, + value: ValueWithType, + ) -> Result<(), Self::Error> { + self.write_record(vec![( + name.map(Cow::from).unwrap_or_else(|| Cow::from(format!("col_{tag}"))), + ty, + value, + )]) } } diff --git a/crates/sats/src/ser.rs b/crates/sats/src/ser.rs index a9fd0fe01f8..4d56fc5aabc 100644 --- a/crates/sats/src/ser.rs +++ b/crates/sats/src/ser.rs @@ -6,7 +6,7 @@ mod impls; pub mod serde; use crate::de::DeserializeSeed; -use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter}; +use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter, ProductValue, SumValue, ValueWithType}; use crate::{AlgebraicValue, WithTypespace}; use core::marker::PhantomData; use core::{convert::Infallible, fmt}; @@ -117,6 +117,31 @@ pub trait Serializer: Sized { /// The argument is the number of fields in the product. fn serialize_named_product(self, len: usize) -> Result; + /// Serialize a product with named fields. + /// + /// Allow to override the default serialization for where we need to switch the output format, + /// see [`crate::satn::TypedWriter`]. + fn serialize_named_product_raw(self, value: &ValueWithType<'_, ProductValue>) -> Result { + let val = &value.val.elements; + assert_eq!(val.len(), value.ty().elements.len()); + let mut prod = self.serialize_named_product(val.len())?; + for (val, el_ty) in val.iter().zip(&*value.ty().elements) { + prod.serialize_element(el_ty.name(), &value.with(&el_ty.algebraic_type, val))? + } + prod.end() + } + + /// Serialize a sum value + /// + /// Allow to override the default serialization for where we need to switch the output format, + /// see [`crate::satn::TypedWriter`]. + fn serialize_variant_raw(self, sum: &ValueWithType<'_, SumValue>) -> Result { + let sv = sum.value(); + let (tag, val) = (sv.tag, &*sv.value); + let var_ty = &sum.ty().variants[tag as usize]; // Extract the variant type by tag. + self.serialize_variant(tag, var_ty.name(), &sum.with(&var_ty.algebraic_type, val)) + } + /// Serialize a sum value provided the chosen `tag`, `name`, and `value`. fn serialize_variant( self, diff --git a/crates/sats/src/ser/impls.rs b/crates/sats/src/ser/impls.rs index 914099e14c5..39d08d38d64 100644 --- a/crates/sats/src/ser/impls.rs +++ b/crates/sats/src/ser/impls.rs @@ -1,4 +1,4 @@ -use super::{Serialize, SerializeArray, SerializeNamedProduct, SerializeSeqProduct, Serializer}; +use super::{Serialize, SerializeArray, SerializeSeqProduct, Serializer}; use crate::{i256, u256}; use crate::{AlgebraicType, AlgebraicValue, ArrayValue, ProductValue, SumValue, ValueWithType, F32, F64}; use core::ops::Bound; @@ -190,19 +190,10 @@ impl_serialize!( } ); impl_serialize!([] ValueWithType<'_, SumValue>, (self, ser) => { - let sv = self.value(); - let (tag, val) = (sv.tag, &*sv.value); - let var_ty = &self.ty().variants[tag as usize]; // Extract the variant type by tag. - ser.serialize_variant(tag, var_ty.name(), &self.with(&var_ty.algebraic_type, val)) + ser.serialize_variant_raw(self) }); impl_serialize!([] ValueWithType<'_, ProductValue>, (self, ser) => { - let val = &self.value().elements; - assert_eq!(val.len(), self.ty().elements.len()); - let mut prod = ser.serialize_named_product(val.len())?; - for (val, el_ty) in val.iter().zip(&*self.ty().elements) { - prod.serialize_element(el_ty.name(), &self.with(&el_ty.algebraic_type, val))? - } - prod.end() + ser.serialize_named_product_raw(self) }); impl_serialize!([] ValueWithType<'_, ArrayValue>, (self, ser) => { let mut ty = &*self.ty().elem_ty; diff --git a/crates/sats/src/time_duration.rs b/crates/sats/src/time_duration.rs index 17f27ba988f..f6d932c4bd8 100644 --- a/crates/sats/src/time_duration.rs +++ b/crates/sats/src/time_duration.rs @@ -83,7 +83,9 @@ impl TimeDuration { self.to_micros().checked_sub(other.to_micros()).map(Self::from_micros) } - /// Generate an `iso8601` format string + /// Generate an `iso8601` format string. + /// + /// This is the better supported format for use for the `pg wire protocol`. /// /// Example: /// ```rust diff --git a/smoketests/tests/pg_wire.py b/smoketests/tests/pg_wire.py index cad3f7b02e9..553e3938b9f 100644 --- a/smoketests/tests/pg_wire.py +++ b/smoketests/tests/pg_wire.py @@ -143,8 +143,8 @@ def test_sql_format(self): (1 row)""") self.assertSql(token, "SELECT * FROM t_ints_tuple", """\ tuple ------------------------------------------------------------ - [-25,-3224,-23443,-2344353,-234434897853,"-0x369568a7bd"] +--------------------------------------------------------------------------------------------------------- + {"i8": -25, "i16": -3224, "i32": -23443, "i64": -2344353, "i128": -234434897853, "i256": -234434897853} (1 row)""") self.assertSql(token, "SELECT * FROM t_uints", """\ u8 | u16 | u32 | u64 | u128 | u256 @@ -153,17 +153,16 @@ def test_sql_format(self): (1 row)""") self.assertSql(token, "SELECT * FROM t_uints_tuple", """\ tuple ---------------------------------------------------------- - [105,1050,83892,48937498,4378528978889,"0x3fb74aa17c9"] +------------------------------------------------------------------------------------------------------- + {"u8": 105, "u16": 1050, "u32": 83892, "u64": 48937498, "u128": 4378528978889, "u256": 4378528978889} (1 row)""") self.assertSql(token, "SELECT * FROM t_others", """\ bool | f32 | f64 | str | bytes | identity | connection_id | timestamp | duration ------+-----------+---------------------+---------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+---------- t | 594806.56 | -3454353.3453890434 | This is spacetimedb | \\x01020304050607 | \\x0000000000000000000000000000000000000000000000000000000000000001 | \\x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | PT10S (1 row)""") - # TODO: Uncomment when tuple support is added self.assertSql(token, "SELECT * FROM t_others_tuple", """\ tuple --------------------------------------------------------------------------------------------------------- - [true,594806.56,-3454353.3453890434,"This is spacetimedb","01020304050607",["0x1"],[0],[0],[10000000]] +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"bool": true, "f32": 594806.56, "f64": -3454353.3453890434, "str": "This is spacetimedb", "bytes": 0x01020304050607, "identity": 0x0000000000000000000000000000000000000000000000000000000000000001, "connection_id": 0x00000000000000000000000000000000, "timestamp": 1970-01-01T00:00:00+00:00, "duration": PT10S} (1 row)""") From cd2e4a5612403da6c54b2500d2aa48d73fa41b06 Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Fri, 18 Jul 2025 11:08:47 -0500 Subject: [PATCH 03/10] Add Ci step for install psql --- .github/workflows/ci.yml | 6 +++++- smoketests/tests/pg_wire.py | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 62edf2d64a9..831b38203e5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: include: - { runner: spacetimedb-runner, smoketest_args: --docker } - { runner: windows-latest, smoketest_args: --no-build-cli } - runner: [spacetimedb-runner, windows-latest] + runner: [ spacetimedb-runner, windows-latest ] runs-on: ${{ matrix.runner }} steps: - name: Find Git ref @@ -44,6 +44,10 @@ jobs: - uses: actions/setup-dotnet@v4 with: global-json-file: modules/global.json + - name: Install psql (Windows) + if: runner.os == 'Windows' + run: choco install psql -y --no-progress + shell: powershell - name: Build and start database (Linux) if: runner.os == 'Linux' run: docker compose up -d diff --git a/smoketests/tests/pg_wire.py b/smoketests/tests/pg_wire.py index 553e3938b9f..ebef08dd9f6 100644 --- a/smoketests/tests/pg_wire.py +++ b/smoketests/tests/pg_wire.py @@ -8,7 +8,7 @@ def psql(identity: str, sql: str) -> str: """Call `psql` and execute the given SQL statement.""" result = subprocess.run( - ["psql", "port=5432 host=127.0.0.1 user=postgres dbname=quickstart", "--quiet", "-c", sql], + ["psql", "-h", "127.0.0.1", "-p", "5432", "-U", "postgres", "-d", "quickstart", "--quiet", "-c", sql], encoding="utf8", env={**os.environ, "PGPASSWORD": identity}, stdout=subprocess.PIPE, From 137e8fdc04ba666513545733bada365bf11fd739 Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Thu, 7 Aug 2025 13:07:09 -0500 Subject: [PATCH 04/10] Output duration differently based on client to avoid a breaking change --- crates/cli/src/subcommands/sql.rs | 66 ++++++++++++++++++++++++++----- crates/sats/src/satn.rs | 5 ++- 2 files changed, 61 insertions(+), 10 deletions(-) diff --git a/crates/cli/src/subcommands/sql.rs b/crates/cli/src/subcommands/sql.rs index 12107b43a44..15917821fd7 100644 --- a/crates/cli/src/subcommands/sql.rs +++ b/crates/cli/src/subcommands/sql.rs @@ -112,7 +112,7 @@ fn print_stmt_result( for (pos, result) in if_empty .into_iter() .chain(stmt_results.iter().map(|stmt_result| { - let (stats, table) = stmt_result_to_table(stmt_result)?; + let (stats, table) = stmt_result_to_table(PsqlClient::SpacetimeDB, stmt_result)?; anyhow::Ok(StmtResult { stats: with_stats.is_some().then_some(stats), @@ -158,12 +158,13 @@ pub(crate) async fn run_sql(builder: RequestBuilder, sql: &str, with_stats: bool Ok(()) } -fn stmt_result_to_table(stmt_result: &SqlStmtResult) -> anyhow::Result<(StmtStats, tabled::Table)> { +fn stmt_result_to_table(client: PsqlClient, stmt_result: &SqlStmtResult) -> anyhow::Result<(StmtStats, tabled::Table)> { let stats = StmtStats::from(stmt_result); let SqlStmtResult { schema, rows, .. } = stmt_result; let ty = Typespace::EMPTY.with_type(schema); let table = build_table( + client, schema, rows.iter().map(|row| from_json_seed(row.get(), SeedWrapper(ty))), )?; @@ -195,6 +196,7 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error /// Generates a [`tabled::Table`] from a schema and rows, using the style of a psql table. fn build_table( + client: PsqlClient, schema: &ProductType, rows: impl Iterator>, ) -> Result { @@ -212,7 +214,7 @@ fn build_table( let row = row?; builder.push_record(ty.with_values(&row).enumerate().map(|(idx, value)| { let ty = satn::PsqlType { - client: PsqlClient::SpacetimeDB, + client, tuple: ty.ty(), field: &ty.ty().elements[idx], idx, @@ -448,8 +450,10 @@ Roundtrip time: 1.00ms"#, Ok(()) } - fn expect_psql_table(ty: &ProductType, rows: Vec, expected: &str) { - let table = build_table(ty, rows.into_iter().map(Ok::<_, ()>)).unwrap().to_string(); + fn expect_psql_table(client: PsqlClient, ty: &ProductType, rows: Vec, expected: &str) { + let table = build_table(client, ty, rows.into_iter().map(Ok::<_, ()>)) + .unwrap() + .to_string(); let mut table = table.split('\n').map(|x| x.trim_end()).join("\n"); table.insert(0, '\n'); assert_eq!(expected, table); @@ -478,6 +482,17 @@ Roundtrip time: 1.00ms"#, ]; expect_psql_table( + PsqlClient::SpacetimeDB, + &kind, + vec![value.clone()], + r#" + column 0 | column 1 | column 2 | column 3 | column 4 | column 5 +----------+----------+--------------------------------------------------------------------+------------------------------------+---------------------------+----------- + "a" | 0 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#, + ); + + expect_psql_table( + PsqlClient::Postgres, &kind, vec![value], r#" @@ -509,6 +524,17 @@ Roundtrip time: 1.00ms"#, ]; expect_psql_table( + PsqlClient::SpacetimeDB, + &kind, + vec![value.clone()], + r#" + bool | str | bytes | identity | connection_id | timestamp | duration +------+-----------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+----------- + true | "This is spacetimedb" | 0x01020304050607 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#, + ); + + expect_psql_table( + PsqlClient::Postgres, &kind, vec![value.clone()], r#" @@ -523,12 +549,23 @@ Roundtrip time: 1.00ms"#, let value = product![value.clone()]; expect_psql_table( + PsqlClient::SpacetimeDB, &kind, vec![value.clone()], r#" column 0 ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- - (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = P0D)"#, +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000)"#, + ); + + expect_psql_table( + PsqlClient::Postgres, + &kind, + vec![value.clone()], + r#" + column 0 +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"bool": true, "str": "This is spacetimedb", "bytes": 0x01020304050607, "identity": 0x0000000000000000000000000000000000000000000000000000000000000000, "connection_id": 0x00000000000000000000000000000000, "timestamp": 1970-01-01T00:00:00+00:00, "duration": P0D}"#, ); let kind: ProductType = [("tuple", AlgebraicType::product(kind))].into(); @@ -536,12 +573,23 @@ Roundtrip time: 1.00ms"#, let value = product![value]; expect_psql_table( + PsqlClient::SpacetimeDB, + &kind, + vec![value.clone()], + r#" + tuple +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + (col_0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000))"#, + ); + + expect_psql_table( + PsqlClient::Postgres, &kind, vec![value], r#" tuple --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- - (col_0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = P0D))"#, +---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"col_0": {"bool": true, "str": "This is spacetimedb", "bytes": 0x01020304050607, "identity": 0x0000000000000000000000000000000000000000000000000000000000000000, "connection_id": 0x00000000000000000000000000000000, "timestamp": 1970-01-01T00:00:00+00:00, "duration": P0D}}"#, ); Ok(()) diff --git a/crates/sats/src/satn.rs b/crates/sats/src/satn.rs index 140fa0af4f5..2804a533784 100644 --- a/crates/sats/src/satn.rs +++ b/crates/sats/src/satn.rs @@ -837,7 +837,10 @@ impl TypedWriter for SqlFormatter<'_, '_> { } fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error> { - write!(self.fmt, "{}", value.to_iso8601()) + match self.ty.client { + PsqlClient::SpacetimeDB => write!(self.fmt, "{value}"), + PsqlClient::Postgres => write!(self.fmt, "{}", value.to_iso8601()), + } } fn write_record( From 2a45f07ed9ab00b24bc3da887ed12f8da7443560 Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Mon, 11 Aug 2025 10:03:51 -0500 Subject: [PATCH 05/10] Removed tls support, TBD awaiting decision for general solution --- Cargo.lock | 116 --------------------- Cargo.toml | 4 - crates/client-api/src/auth.rs | 6 -- crates/pg/Cargo.toml | 4 - crates/pg/src/encoder.rs | 1 - crates/pg/src/pg_server.rs | 44 +------- crates/standalone/src/subcommands/start.rs | 5 +- 7 files changed, 3 insertions(+), 177 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0e1e2c4dcbe..92abffae5fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -213,45 +213,6 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" -[[package]] -name = "asn1-rs" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5493c3bedbacf7fd7382c6346bbd66687d12bbaad3a89a2d2c303ee6cf20b048" -dependencies = [ - "asn1-rs-derive", - "asn1-rs-impl", - "displaydoc", - "nom", - "num-traits", - "rusticata-macros", - "thiserror 1.0.69", - "time", -] - -[[package]] -name = "asn1-rs-derive" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "965c2d33e53cb6b267e148a4cb0760bc01f4904c1cd4bb4002a085bb016d1490" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.101", - "synstructure 0.13.2", -] - -[[package]] -name = "asn1-rs-impl" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7b18050c2cd6fe86c3a76584ef5e0baf286d038cda203eb6223df2cc413565f7" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.101", -] - [[package]] name = "async-stream" version = "0.3.6" @@ -1576,20 +1537,6 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da692b8d1080ea3045efaab14434d40468c3d8657e42abddfffca87b428f4c1b" -[[package]] -name = "der-parser" -version = "9.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cd0a5c643689626bec213c4d8bd4d96acc8ffdb4ad4bb6bc16abf27d5f4b553" -dependencies = [ - "asn1-rs", - "displaydoc", - "nom", - "num-bigint", - "num-traits", - "rusticata-macros", -] - [[package]] name = "deranged" version = "0.4.0" @@ -3737,15 +3684,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "oid-registry" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8d8034d9489cdaf79228eb9f6a3b8d7bb32ba00d6645ebd48eef4077ceb5bd9" -dependencies = [ - "asn1-rs", -] - [[package]] name = "once_cell" version = "1.21.3" @@ -4600,20 +4538,6 @@ dependencies = [ "crossbeam-utils", ] -[[package]] -name = "rcgen" -version = "0.13.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75e669e5202259b5314d1ea5397316ad400819437857b90861765f24c4cf80a2" -dependencies = [ - "pem", - "ring", - "rustls-pki-types", - "time", - "x509-parser", - "yasna", -] - [[package]] name = "rdrand" version = "0.4.0" @@ -4976,15 +4900,6 @@ dependencies = [ "semver", ] -[[package]] -name = "rusticata-macros" -version = "4.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "faf0c4a6ece9950b9abdb62b1cfcf2a68b3b67a10ba445b3bb85be2a293d0632" -dependencies = [ - "nom", -] - [[package]] name = "rustix" version = "0.38.44" @@ -6210,15 +6125,11 @@ dependencies = [ "http 1.3.1", "log", "pgwire", - "rcgen", - "rustls", - "rustls-pki-types", "spacetimedb-client-api", "spacetimedb-client-api-messages", "spacetimedb-lib 1.4.0", "thiserror 1.0.69", "tokio", - "tokio-rustls", ] [[package]] @@ -8777,24 +8688,6 @@ dependencies = [ "tap", ] -[[package]] -name = "x509-parser" -version = "0.16.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fcbc162f30700d6f3f82a24bf7cc62ffe7caea42c0b2cba8bf7f3ae50cf51f69" -dependencies = [ - "asn1-rs", - "data-encoding", - "der-parser", - "lazy_static", - "nom", - "oid-registry", - "ring", - "rusticata-macros", - "thiserror 1.0.69", - "time", -] - [[package]] name = "xattr" version = "1.5.0" @@ -8841,15 +8734,6 @@ version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cfe53a6657fd280eaa890a3bc59152892ffa3e30101319d168b781ed6529b049" -[[package]] -name = "yasna" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" -dependencies = [ - "time", -] - [[package]] name = "yoke" version = "0.7.5" diff --git a/Cargo.toml b/Cargo.toml index a18ace7c8c4..3b202f019f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -229,7 +229,6 @@ rand08 = { package = "rand", version = "0.8" } rand = "0.9" rayon = "1.8" rayon-core = "1.11.0" -rcgen = { version = "0.13.1", features = ["pem", "x509-parser", "crypto", "ring"] } regex = "1" reqwest = { version = "0.12", features = ["stream", "json"] } ron = "0.8" @@ -238,8 +237,6 @@ rust_decimal = { version = "1.29.1", features = ["db-tokio-postgres"] } rustc-demangle = "0.1.21" rustc-hash = "2" rustyline = { version = "12.0.0", features = [] } -rustls-pki-types = "1.11.0" -rustls = "0.23.26" scoped-tls = "1.0.1" scopeguard = "1.1.0" second-stack = "0.3" @@ -271,7 +268,6 @@ termcolor = "1.2.0" thin-vec = "0.2.13" thiserror = "1.0.37" tokio = { version = "1.37", features = ["full"] } -tokio-rustls = "0.26.2" tokio_metrics = { version = "0.4.0" } tokio-postgres = { version = "0.7.8", features = ["with-chrono-0_4"] } tokio-stream = "0.1.17" diff --git a/crates/client-api/src/auth.rs b/crates/client-api/src/auth.rs index d87f24a10f1..988b8a3cba2 100644 --- a/crates/client-api/src/auth.rs +++ b/crates/client-api/src/auth.rs @@ -210,8 +210,6 @@ pub trait JwtAuthProvider: Sync + Send + TokenSigner { /// /// The `/identity/public-key` route calls this method to return the public key to callers. fn public_key_bytes(&self) -> &[u8]; - /// Return the private key used to verify JWTs, as the bytes of a PEM private key file. - fn private_key_bytes(&self) -> &[u8]; } pub struct JwtKeyAuthProvider { @@ -259,10 +257,6 @@ impl JwtAuthProvider for JwtKeyAuthProvider &[u8] { &self.keys.public_pem } - - fn private_key_bytes(&self) -> &[u8] { - &self.keys.private_pem - } } #[cfg(test)] diff --git a/crates/pg/Cargo.toml b/crates/pg/Cargo.toml index 60eb617e760..dd49122dea0 100644 --- a/crates/pg/Cargo.toml +++ b/crates/pg/Cargo.toml @@ -18,9 +18,5 @@ futures.workspace = true http.workspace = true log.workspace = true pgwire.workspace = true -rustls-pki-types.workspace = true -rcgen.workspace = true -rustls.workspace = true thiserror.workspace = true tokio.workspace = true -tokio-rustls.workspace = true \ No newline at end of file diff --git a/crates/pg/src/encoder.rs b/crates/pg/src/encoder.rs index 5b4eb8d8062..3cd565b1a92 100644 --- a/crates/pg/src/encoder.rs +++ b/crates/pg/src/encoder.rs @@ -145,7 +145,6 @@ impl TypedWriter for PsqlFormatter<'_> { } } -// Tests #[cfg(test)] mod tests { use super::*; diff --git a/crates/pg/src/pg_server.rs b/crates/pg/src/pg_server.rs index 9c23f724310..2c400e34cdb 100644 --- a/crates/pg/src/pg_server.rs +++ b/crates/pg/src/pg_server.rs @@ -16,17 +16,13 @@ use pgwire::api::copy::NoopCopyHandler; use pgwire::api::portal::Format; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response, Tag}; -use pgwire::api::ClientInfo; -use pgwire::api::{NoopErrorHandler, METADATA_DATABASE, METADATA_USER}; +use pgwire::api::{ClientInfo, NoopErrorHandler, METADATA_DATABASE, METADATA_USER}; use pgwire::api::{PgWireConnectionState, PgWireServerHandlers}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::data::DataRow; use pgwire::messages::startup::Authentication; use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; use pgwire::tokio::process_socket; -use rcgen::{CertificateParams, DistinguishedName, DnType, KeyPair}; -use rustls_pki_types::pem::PemObject; -use rustls_pki_types::PrivateKeyDer; use spacetimedb_client_api::auth::{validate_token, SpacetimeAuth}; use spacetimedb_client_api::routes::database; use spacetimedb_client_api::routes::database::{SqlParams, SqlQueryParams}; @@ -40,8 +36,6 @@ use spacetimedb_lib::ProductValue; use thiserror::Error; use tokio::net::TcpListener; use tokio::sync::{watch, Mutex}; -use tokio_rustls::rustls::ServerConfig; -use tokio_rustls::TlsAcceptor; #[derive(Error, Debug)] pub(crate) enum PgError { @@ -54,12 +48,6 @@ pub(crate) enum PgError { #[error(transparent)] Pg(#[from] PgWireError), #[error(transparent)] - RcGen(#[from] rcgen::Error), - #[error(transparent)] - Pem(#[from] rustls_pki_types::pem::Error), - #[error(transparent)] - RustTls(#[from] rustls::Error), - #[error(transparent)] Other(#[from] anyhow::Error), } @@ -366,38 +354,11 @@ impl(_ctx: &T, private_key: &[u8]) -> Result { - let private: PrivateKeyDer = PrivateKeyDer::from_pem_slice(private_key)?; - - let keypair = KeyPair::from_der_and_sign_algo(&private, &rcgen::PKCS_ECDSA_P256_SHA256)?; - - let mut params = CertificateParams::new(vec![ - "localhost".to_string(), - "127.0.0.1".to_string(), - "::1".to_string(), - ])?; - params.distinguished_name = DistinguishedName::new(); - params.distinguished_name.push(DnType::CommonName, "localhost"); - let cert = params.self_signed(&keypair)?; - let cert_der = cert.der().clone(); - - let mut config = ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(vec![cert_der], private)?; - - config.alpn_protocols = vec![b"postgresql".to_vec(), b"spacetime".to_vec()]; - - Ok(TlsAcceptor::from(Arc::new(config))) -} - pub async fn start_pg( mut shutdown: watch::Receiver<()>, ctx: Arc, listen_address: &str, - private_key: &[u8], ) { - let tls_acceptor = Arc::new(setup_tls(&ctx, private_key).unwrap()); - let auth = SpacetimeAuth::alloc(&ctx).await.unwrap(); let factory = Arc::new(PgSpacetimeDBFactory::new(ctx, auth)); @@ -413,10 +374,9 @@ pub async fn start_pg { match accept_result { Ok((stream, _addr)) => { - let tls_acceptor_ref = tls_acceptor.clone(); let factory_ref = factory.clone(); tokio::spawn(async move { - process_socket(stream, Some(tls_acceptor_ref), factory_ref).await.inspect_err(|err|{ + process_socket(stream, None, factory_ref).await.inspect_err(|err|{ log::error!("PG: Error processing socket: {err:?}"); }) }); diff --git a/crates/standalone/src/subcommands/start.rs b/crates/standalone/src/subcommands/start.rs index 9823a7095db..bbd3180d5ba 100644 --- a/crates/standalone/src/subcommands/start.rs +++ b/crates/standalone/src/subcommands/start.rs @@ -11,11 +11,9 @@ use spacetimedb::db::{self, Storage}; use spacetimedb::startup::{self, TracingOptions}; use spacetimedb::util::jobs::JobCores; use spacetimedb::worker_metrics; -use spacetimedb_client_api::auth::JwtAuthProvider; use spacetimedb_client_api::routes::database::DatabaseRoutes; use spacetimedb_client_api::routes::router; use spacetimedb_client_api::routes::subscribe::WebSocketOptions; -use spacetimedb_client_api::NodeDelegate; use spacetimedb_paths::cli::{PrivKeyPath, PubKeyPath}; use spacetimedb_paths::server::{ConfigToml, ServerDataDir}; use tokio::net::TcpListener; @@ -185,9 +183,8 @@ pub async fn exec(args: &ArgMatches, db_cores: JobCores) -> anyhow::Result<()> { socket2::SockRef::from(&tcp).set_nodelay(true)?; log::debug!("Starting SpacetimeDB listening on {}", tcp.local_addr()?); let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(()); - let private_key = ctx.jwt_auth_provider().private_key_bytes().to_vec(); tokio::select! { - _ = pg_server::start_pg(shutdown_rx.clone(), ctx, listen_addr, &private_key) => {}, + _ = pg_server::start_pg(shutdown_rx.clone(), ctx, listen_addr) => {}, _ = axum::serve(tcp, service).with_graceful_shutdown(async move { shutdown_rx.changed().await.ok(); }) => {}, From 11cc983566e07e107c2b4fb538f2d3c7e3850901 Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Wed, 13 Aug 2025 09:56:26 -0500 Subject: [PATCH 06/10] Adding more tests --- .github/workflows/ci.yml | 2 + crates/client-api/src/auth.rs | 24 ---- crates/pg/LICENSE | 1 + crates/pg/src/encoder.rs | 87 ++++++++++--- crates/pg/src/pg_server.rs | 89 +++++++------ crates/sats/src/product_value.rs | 8 -- crates/sats/src/satn.rs | 37 +++++- crates/standalone/src/subcommands/start.rs | 5 +- smoketests/tests/pg_wire.py | 143 +++++++++++++++++++-- 9 files changed, 288 insertions(+), 108 deletions(-) mode change 100644 => 120000 crates/pg/LICENSE diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 831b38203e5..62fec318a9d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -63,6 +63,8 @@ jobs: - uses: actions/setup-python@v5 with: { python-version: '3.12' } if: runner.os == 'Windows' + - name: Install psycopg2 + run: python -m pip install psycopg2-binary - name: Run smoketests # Note: clear_database and replication only work in private run: python -m smoketests ${{ matrix.smoketest_args }} -x clear_database replication diff --git a/crates/client-api/src/auth.rs b/crates/client-api/src/auth.rs index 988b8a3cba2..c616ff53ea6 100644 --- a/crates/client-api/src/auth.rs +++ b/crates/client-api/src/auth.rs @@ -132,30 +132,6 @@ impl TokenClaims { } impl SpacetimeAuth { - pub fn from_claims( - ctx: &(impl NodeDelegate + ControlStateDelegate + ?Sized), - claims: SpacetimeIdentityClaims, - ) -> axum::response::Result { - let claims = TokenClaims { - issuer: claims.issuer, - subject: claims.subject, - audience: claims.audience, - }; - - let creds = { - let token = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?; - SpacetimeCreds::from_signed_token(token) - }; - let identity = claims.id(); - - Ok(Self { - creds, - identity, - subject: claims.subject, - issuer: claims.issuer, - }) - } - /// Allocate a new identity, and mint a new token for it. pub async fn alloc(ctx: &(impl NodeDelegate + ControlStateDelegate + ?Sized)) -> axum::response::Result { // Generate claims with a random subject. diff --git a/crates/pg/LICENSE b/crates/pg/LICENSE deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/crates/pg/LICENSE b/crates/pg/LICENSE new file mode 120000 index 00000000000..8540cf8a991 --- /dev/null +++ b/crates/pg/LICENSE @@ -0,0 +1 @@ +../../licenses/BSL.txt \ No newline at end of file diff --git a/crates/pg/src/encoder.rs b/crates/pg/src/encoder.rs index 3cd565b1a92..fbe7cfd4fd6 100644 --- a/crates/pg/src/encoder.rs +++ b/crates/pg/src/encoder.rs @@ -2,7 +2,7 @@ use crate::pg_server::PgError; use pgwire::api::portal::Format; use pgwire::api::results::{DataRowEncoder, FieldInfo}; use pgwire::api::Type; -use spacetimedb_lib::sats::satn::{PsqlPrintFmt, PsqlType, TypedWriter}; +use spacetimedb_lib::sats::satn::{PsqlChars, PsqlPrintFmt, PsqlType, TypedWriter}; use spacetimedb_lib::sats::{satn, ValueWithType}; use spacetimedb_lib::{ ser, AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, TimeDuration, Timestamp, @@ -32,13 +32,10 @@ pub(crate) fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type { AlgebraicType::Bool => Type::BOOL, AlgebraicType::U8 | AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2, AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4, - AlgebraicType::I64 => Type::INT8, - AlgebraicType::U32 - | AlgebraicType::U64 - | AlgebraicType::I128 - | AlgebraicType::U128 - | AlgebraicType::I256 - | AlgebraicType::U256 => Type::NUMERIC, + AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8, + AlgebraicType::U64 | AlgebraicType::I128 | AlgebraicType::U128 | AlgebraicType::I256 | AlgebraicType::U256 => { + Type::NUMERIC + } AlgebraicType::F32 => Type::FLOAT4, AlgebraicType::F64 => Type::FLOAT8, AlgebraicType::Array(ty) => match *ty.elem_ty { @@ -47,9 +44,8 @@ pub(crate) fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type { AlgebraicType::U8 => Type::BYTEA, AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2_ARRAY, AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4_ARRAY, - AlgebraicType::I64 => Type::INT8_ARRAY, - AlgebraicType::U32 - | AlgebraicType::U64 + AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8_ARRAY, + AlgebraicType::U64 | AlgebraicType::I128 | AlgebraicType::U128 | AlgebraicType::I256 @@ -63,6 +59,7 @@ pub(crate) fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type { _ => Type::JSON, }, AlgebraicType::Sum(sum) if sum.is_simple_enum() => Type::ANYENUM, + AlgebraicType::Sum(_) => Type::JSON, _ => Type::UNKNOWN, } } @@ -134,12 +131,27 @@ impl TypedWriter for PsqlFormatter<'_> { fn write_variant( &mut self, - _tag: u8, + tag: u8, ty: PsqlType, - _name: Option<&str>, + name: Option<&str>, value: ValueWithType, ) -> Result<(), Self::Error> { - let json = satn::PsqlWrapper { ty, value }.to_string(); + // Is a simple enum? + if let AlgebraicType::Sum(sum) = &ty.field.algebraic_type { + if sum.is_simple_enum() { + if let Some(variant_name) = name { + self.encoder.encode_field(&variant_name)?; + return Ok(()); + } + } + } + + let PsqlChars { start, sep, end, quote } = ty.client.format_chars(); + let name = name.map(Cow::from).unwrap_or_else(|| Cow::from(tag.to_string())); + let json = format!( + "{start}{quote}{name}{quote}{sep} {}{end}", + satn::PsqlWrapper { ty, value } + ); self.encoder.encode_field(&json)?; Ok(()) } @@ -152,7 +164,7 @@ mod tests { use futures::StreamExt; use spacetimedb_client_api_messages::http::SqlStmtResult; use spacetimedb_lib::sats::algebraic_value::Packed; - use spacetimedb_lib::sats::{i256, product, u256, AlgebraicType, ProductType}; + use spacetimedb_lib::sats::{i256, product, u256, AlgebraicType, ProductType, SumTypeVariant}; use spacetimedb_lib::{ConnectionId, Identity}; async fn run(schema: ProductType, row: ProductValue) -> String { @@ -217,13 +229,54 @@ mod tests { #[tokio::test] async fn test_enum() { - let schema = ProductType::from([AlgebraicType::option(AlgebraicType::I64)]); + let some = AlgebraicType::option(AlgebraicType::I64); + let schema = ProductType::from([some.clone(), some]); let value = product![ AlgebraicValue::sum(0, AlgebraicValue::I64(1)), // Some(1) + AlgebraicValue::sum(1, AlgebraicValue::unit()), // None ]; let row = run(schema, value).await; - assert_eq!(row, "\0\0\0\u{1}1"); + assert_eq!(row, "\0\0\0\u{b}{\"some\": 1}\0\0\0\u{c}{\"none\": {}}"); + + let color = AlgebraicType::Sum([SumTypeVariant::new_named(AlgebraicType::I64, "Gray")].into()); + let nested = AlgebraicType::option(color.clone()); + let schema = ProductType::from([color, nested]); + // {"Gray": 1}, {"some": {"Gray": 2}} + let value = product![ + AlgebraicValue::sum(0, AlgebraicValue::I64(1)), // Gray(1) + AlgebraicValue::sum(0, AlgebraicValue::sum(0, AlgebraicValue::I64(2))), // Some(Gray(2)) + ]; + let row = run(schema.clone(), value.clone()).await; + assert_eq!(row, "\0\0\0\u{b}{\"Gray\": 1}\0\0\0\u{15}{\"some\": {\"Gray\": 2}}"); + + // Now a product with a nested option + // {x: {"Gray": 1}, y: "a"}, {some: {x: {"Gray": 2}, y: "b"}} + let product = AlgebraicType::product([ + ProductTypeElement::new(AlgebraicType::Product(schema), Some("x".into())), + ProductTypeElement::new(AlgebraicType::String, Some("y".into())), + ]); + let schema = ProductType::from([product.clone()]); + let value = product![AlgebraicValue::product(vec![ + value.into(), + AlgebraicValue::String("a".into()), + ])]; + let row = run(schema, value).await; + assert_eq!( + row, + "\0\0\0G{\"x\": {\"col_0\": {\"Gray\": 1}, \"col_1\": {\"some\": {\"Gray\": 2}}}, \"y\": \"a\"}" + ); + + // Now a simple enum + let names = AlgebraicType::simple_enum(["A", "B", "C"].into_iter()); + let schema = ProductType::from([names.clone(), names.clone(), names]); + let value = product![ + AlgebraicValue::enum_simple(0), // A + AlgebraicValue::enum_simple(1), // B + AlgebraicValue::enum_simple(2), // C + ]; + let row = run(schema, value).await; + assert_eq!(row, "\0\0\0\u{1}A\0\0\0\u{1}B\0\0\0\u{1}C"); } #[tokio::test] diff --git a/crates/pg/src/pg_server.rs b/crates/pg/src/pg_server.rs index 2c400e34cdb..f1bef37e132 100644 --- a/crates/pg/src/pg_server.rs +++ b/crates/pg/src/pg_server.rs @@ -16,14 +16,14 @@ use pgwire::api::copy::NoopCopyHandler; use pgwire::api::portal::Format; use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response, Tag}; -use pgwire::api::{ClientInfo, NoopErrorHandler, METADATA_DATABASE, METADATA_USER}; +use pgwire::api::{ClientInfo, NoopErrorHandler, METADATA_DATABASE}; use pgwire::api::{PgWireConnectionState, PgWireServerHandlers}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::data::DataRow; use pgwire::messages::startup::Authentication; use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; use pgwire::tokio::process_socket; -use spacetimedb_client_api::auth::{validate_token, SpacetimeAuth}; +use spacetimedb_client_api::auth::validate_token; use spacetimedb_client_api::routes::database; use spacetimedb_client_api::routes::database::{SqlParams, SqlQueryParams}; use spacetimedb_client_api::{ControlStateReadAccess, ControlStateWriteAccess, NodeDelegate}; @@ -32,7 +32,7 @@ use spacetimedb_client_api_messages::name::DatabaseName; use spacetimedb_lib::sats::satn::{PsqlClient, TypedSerializer}; use spacetimedb_lib::sats::{satn, Serialize, Typespace}; use spacetimedb_lib::version::spacetimedb_lib_version; -use spacetimedb_lib::ProductValue; +use spacetimedb_lib::{Identity, ProductValue}; use thiserror::Error; use tokio::net::TcpListener; use tokio::sync::{watch, Mutex}; @@ -47,6 +47,8 @@ pub(crate) enum PgError { DatabaseNameRequired, #[error(transparent)] Pg(#[from] PgWireError), + #[error("SSL is not supported by SpacetimeDB")] + SSLNotSupported, #[error(transparent)] Other(#[from] anyhow::Error), } @@ -64,7 +66,7 @@ impl From for PgWireError { #[derive(Clone)] struct Metadata { database: String, - auth: SpacetimeAuth, + caller_identity: Identity, } pub(crate) fn to_rows( @@ -144,13 +146,13 @@ async fn response(res: axum::response::Result, database: &str) -> Result { ctx: Arc, - cached: Mutex, + cached: Mutex>, parameter_provider: DefaultServerParameterProvider, } impl PgSpacetimeDB { async fn exe_sql<'a>(&self, query: String) -> PgWireResult>> { - let params = self.cached.lock().await; + let params = self.cached.lock().await.clone().unwrap(); let db = SqlParams { name_or_identity: database::NameOrIdentity::Name(DatabaseName(params.database.clone())), }; @@ -160,7 +162,7 @@ impl PgSpace self.ctx.clone(), db, SqlQueryParams { confirmed: true }, - params.auth.identity, + params.caller_identity, query.to_string(), ) .await, @@ -180,16 +182,14 @@ impl PgSpace let mut result = Vec::with_capacity(sql.len()); for sql_result in sql { let header = row_desc(&sql_result.schema, &Format::UnifiedText); - - let tag = Tag::new(&stats(&sql_result)); - - if sql_result.rows.is_empty() { - result.push(Response::EmptyQuery); + if sql_result.rows.is_empty() && !query.to_uppercase().contains("SELECT") { + let tag = Tag::new(&stats(&sql_result)); + result.push(Response::Execution(tag)); } else { let rows = to_rows(sql_result, header.clone())?; - result.push(Response::Query(QueryResponse::new(header, rows))); + let q = QueryResponse::new(header, rows); + result.push(Response::Query(q)); } - result.push(Response::Execution(tag)); } Ok(result) } @@ -225,8 +225,6 @@ impl response(SpacetimeAuth::from_claims(&self.ctx, claims), &database).await?, + let caller_identity = match validate_token(&self.ctx, &pwd.password).await { + Ok(claims) => claims.identity, Err(err) => { - log::error!("PG: Authentication failed for user {user} on database {database}: {err}"); + log::error!( + "PG: Authentication failed for identity `{}` on database {database}: {err}", + pwd.password + ); let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string()); return close_client(client, err).await; } }; - log::info!( - "PG: Connected to database: {user}@{database} using identity `{}`", - auth.identity - ); + log::info!("PG: Connected to database: {database} using identity `{caller_identity}`"); - let metadata = Metadata { database, auth }; - self.cached.lock().await.clone_from(&metadata); + let metadata = Metadata { + database, + caller_identity, + }; + self.cached.lock().await.clone_from(&Some(metadata)); finish_authentication(client, &self.parameter_provider).await?; } - _ => {} + PgWireFrontendMessage::SslRequest(ssl) => { + if ssl.is_some() { + let err = PgError::SSLNotSupported; + log::error!("{err}"); + let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string()); + return close_client(client, err).await; + } + } + // The other messages are for features not supported by SpacetimeDB, that are rejected by the parser. + _ => { + unreachable!("Unsupported startup message: {message:?}"); + } } Ok(()) } @@ -306,18 +318,15 @@ struct PgSpacetimeDBFactory { } impl PgSpacetimeDBFactory { - pub fn new(ctx: Arc, auth: SpacetimeAuth) -> Self { + pub fn new(ctx: Arc) -> Self { let mut parameter_provider = DefaultServerParameterProvider::default(); parameter_provider.server_version = format!("spacetime {}", spacetimedb_lib_version()); Self { handler: Arc::new(PgSpacetimeDB { ctx, - cached: Mutex::new(Metadata { - // This is a placeholder, it will be set in the startup handler - database: "".to_string(), - auth, - }), + // This is a placeholder, it will be set in the startup handler + cached: None.into(), parameter_provider, }), } @@ -357,13 +366,9 @@ impl( mut shutdown: watch::Receiver<()>, ctx: Arc, - listen_address: &str, + tcp: TcpListener, ) { - let auth = SpacetimeAuth::alloc(&ctx).await.unwrap(); - let factory = Arc::new(PgSpacetimeDBFactory::new(ctx, auth)); - - let server_addr = format!("{}:5432", listen_address.split(':').next().unwrap()); - let tcp = TcpListener::bind(server_addr).await.unwrap(); + let factory = Arc::new(PgSpacetimeDBFactory::new(ctx)); log::debug!( "PG: Starting SpacetimeDB Protocol listening on {}", @@ -376,7 +381,7 @@ pub async fn start_pg { let factory_ref = factory.clone(); tokio::spawn(async move { - process_socket(stream, None, factory_ref).await.inspect_err(|err|{ + process_socket(stream, None, factory_ref).await.inspect_err(|err|{ log::error!("PG: Error processing socket: {err:?}"); }) }); diff --git a/crates/sats/src/product_value.rs b/crates/sats/src/product_value.rs index 68aaf6433ce..5db2b7c534d 100644 --- a/crates/sats/src/product_value.rs +++ b/crates/sats/src/product_value.rs @@ -114,14 +114,6 @@ impl ProductValue { }) } - /// Assumes that we are dealing with a [`ProductType::is_special`] value if it has a single element. - pub fn as_special_value_raw(&self) -> Option<&AlgebraicValue> { - match &*self.elements { - [val] => Some(val), - _ => None, - } - } - /// Interprets the value at field of `self` identified by `index` as a `bool`. pub fn field_as_bool(&self, index: usize, named: Option<&'static str>) -> Result { self.extract_field(index, named, |f| f.as_bool().copied()) diff --git a/crates/sats/src/satn.rs b/crates/sats/src/satn.rs index 2804a533784..cd388213c03 100644 --- a/crates/sats/src/satn.rs +++ b/crates/sats/src/satn.rs @@ -1,6 +1,6 @@ use crate::time_duration::TimeDuration; use crate::timestamp::Timestamp; -use crate::{i256, u256, AlgebraicValue, ProductValue, Serialize, SumValue, ValueWithType}; +use crate::{i256, u256, AlgebraicType, AlgebraicValue, ProductValue, Serialize, SumValue, ValueWithType}; use crate::{ser, ProductType, ProductTypeElement}; use core::fmt; use core::fmt::Write as _; @@ -450,6 +450,32 @@ pub enum PsqlClient { Postgres, } +pub struct PsqlChars { + pub start: char, + pub sep: &'static str, + pub end: char, + pub quote: &'static str, +} + +impl PsqlClient { + pub fn format_chars(&self) -> PsqlChars { + match self { + PsqlClient::SpacetimeDB => PsqlChars { + start: '(', + sep: " =", + end: ')', + quote: "", + }, + PsqlClient::Postgres => PsqlChars { + start: '{', + sep: ":", + end: '}', + quote: "\"", + }, + } + } +} + /// How format of the `SQL` output? #[derive(Debug, Copy, Clone, PartialEq, Display)] pub enum PsqlPrintFmt { @@ -522,7 +548,7 @@ impl PsqlType<'_> { } /// An implementation of [`Serializer`](ser::Serializer) for `SQL` output. -struct SqlFormatter<'a, 'f> { +pub struct SqlFormatter<'a, 'f> { fmt: SatnFormatter<'a, 'f>, ty: &'a PsqlType<'a>, } @@ -788,7 +814,7 @@ impl<'a, 'f, F: TypedWriter> ser::Serializer for TypedSerializer<'a, 'f, F> { let sv = sum.value(); let (tag, val) = (sv.tag, &*sv.value); let var_ty = &sum.ty().variants[tag as usize]; // Extract the variant type by tag. - let product = ProductType::from([var_ty.algebraic_type.clone()]); + let product = ProductType::from([AlgebraicType::sum(sum.ty().clone())]); let ty = PsqlType { client: self.ty.client, tuple: &product, @@ -847,10 +873,7 @@ impl TypedWriter for SqlFormatter<'_, '_> { &mut self, fields: Vec<(Cow, PsqlType<'_>, ValueWithType)>, ) -> Result<(), Self::Error> { - let (start, sep, end, quote) = match self.ty.client { - PsqlClient::SpacetimeDB => ("(", " =", ")", ""), - PsqlClient::Postgres => ("{", ":", "}", "\""), - }; + let PsqlChars { start, sep, end, quote } = self.ty.client.format_chars(); write!(self.fmt, "{start}")?; for (idx, (name, ty, value)) in fields.into_iter().enumerate() { if idx > 0 { diff --git a/crates/standalone/src/subcommands/start.rs b/crates/standalone/src/subcommands/start.rs index bbd3180d5ba..e8f2e347005 100644 --- a/crates/standalone/src/subcommands/start.rs +++ b/crates/standalone/src/subcommands/start.rs @@ -182,9 +182,12 @@ pub async fn exec(args: &ArgMatches, db_cores: JobCores) -> anyhow::Result<()> { let tcp = TcpListener::bind(listen_addr).await?; socket2::SockRef::from(&tcp).set_nodelay(true)?; log::debug!("Starting SpacetimeDB listening on {}", tcp.local_addr()?); + let pg_server_addr = format!("{}:5432", listen_addr.split(':').next().unwrap()); + let tcp_pg = TcpListener::bind(pg_server_addr).await?; + let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(()); tokio::select! { - _ = pg_server::start_pg(shutdown_rx.clone(), ctx, listen_addr) => {}, + _ = pg_server::start_pg(shutdown_rx.clone(), ctx, tcp_pg) => {}, _ = axum::serve(tcp, service).with_graceful_shutdown(async move { shutdown_rx.changed().await.ok(); }) => {}, diff --git a/smoketests/tests/pg_wire.py b/smoketests/tests/pg_wire.py index ebef08dd9f6..7efef115634 100644 --- a/smoketests/tests/pg_wire.py +++ b/smoketests/tests/pg_wire.py @@ -2,19 +2,20 @@ import subprocess import os import tomllib +import psycopg2 -def psql(identity: str, sql: str) -> str: +def psql(identity: str, sql: str, extra=None) -> str: """Call `psql` and execute the given SQL statement.""" - + if extra is None: + extra = dict() result = subprocess.run( ["psql", "-h", "127.0.0.1", "-p", "5432", "-U", "postgres", "-d", "quickstart", "--quiet", "-c", sql], encoding="utf8", - env={**os.environ, "PGPASSWORD": identity}, + env={**os.environ, **extra, "PGPASSWORD": identity}, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, - # check=True ) if result.stderr: @@ -22,11 +23,18 @@ def psql(identity: str, sql: str) -> str: return result.stdout.strip() +def connect_db(identity: str): + """Connect to the database using `psycopg2`.""" + conn = psycopg2.connect(host="127.0.0.1", port=5432, user="postgres", password=identity, dbname="quickstart") + conn.set_session(autocommit=True) # Disable automic transaction + return conn + + class SqlFormat(Smoketest): AUTOPUBLISH = False MODULE_CODE = """ use spacetimedb::sats::{i256, u256}; -use spacetimedb::{ConnectionId, Identity, ReducerContext, Table, Timestamp, TimeDuration}; +use spacetimedb::{ConnectionId, Identity, ReducerContext, SpacetimeType, Table, Timestamp, TimeDuration}; #[derive(Copy, Clone)] #[spacetimedb::table(name = t_ints, public)] @@ -79,6 +87,37 @@ class SqlFormat(Smoketest): tuple: TOthers } +#[derive(SpacetimeType, Debug, Clone, Copy)] +pub enum Action { + Inactive, + Active, +} + +#[derive(SpacetimeType, Debug, Clone, Copy)] +pub enum Color { + Gray(u8), +} + +#[derive(Copy, Clone)] +#[spacetimedb::table(name = t_simple_enum, public)] +pub struct TSimpleEnum { + id : u32, + action: Action, +} + +#[spacetimedb::table(name = t_enum, public)] +pub struct TEnum { + id : u32, + color: Color, +} + +#[spacetimedb::table(name = t_nested, public)] +pub struct TNested { + en: TEnum, + se: TSimpleEnum, + ints: TInts, +} + #[spacetimedb::reducer] pub fn test(ctx: &ReducerContext) { let tuple = TInts { @@ -89,6 +128,7 @@ class SqlFormat(Smoketest): i128: -234434897853, i256: (-234434897853i128).into(), }; + let ints = tuple; ctx.db.t_ints().insert(tuple); ctx.db.t_ints_tuple().insert(TIntsTuple { tuple }); @@ -116,6 +156,17 @@ class SqlFormat(Smoketest): }; ctx.db.t_others().insert(tuple.clone()); ctx.db.t_others_tuple().insert(TOthersTuple { tuple }); + + ctx.db.t_simple_enum().insert(TSimpleEnum { id: 1, action: Action::Inactive }); + ctx.db.t_simple_enum().insert(TSimpleEnum { id: 2, action: Action::Active }); + + ctx.db.t_enum().insert(TEnum { id: 1, color: Color::Gray(128) }); + + ctx.db.t_nested().insert(TNested { + en: TEnum { id: 1, color: Color::Gray(128) }, + se: TSimpleEnum { id: 2, action: Action::Active }, + ints, + }); } """ @@ -127,12 +178,16 @@ def assertSql(self, token: str, sql: str, expected): print(sql_out) self.assertMultiLineEqual(sql_out, expected) - def test_sql_format(self): - """This test is designed to test calling `psql` to execute SQL statements""" + def read_token(self): + """Read the token from the config file.""" with open(self.config_path, "rb") as f: config = tomllib.load(f) - token = config['spacetimedb_token'] - self.publish_module("quickstart", clear=False) + return config['spacetimedb_token'] + + def test_sql_format(self): + """This test is designed to test calling `psql` to execute SQL statements""" + token = self.read_token() + self.publish_module("quickstart", clear=True) self.call("test") @@ -166,3 +221,73 @@ def test_sql_format(self): ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- {"bool": true, "f32": 594806.56, "f64": -3454353.3453890434, "str": "This is spacetimedb", "bytes": 0x01020304050607, "identity": 0x0000000000000000000000000000000000000000000000000000000000000001, "connection_id": 0x00000000000000000000000000000000, "timestamp": 1970-01-01T00:00:00+00:00, "duration": PT10S} (1 row)""") + self.assertSql(token, "SELECT * FROM t_simple_enum", """\ +id | action +----+---------- + 1 | Inactive + 2 | Active +(2 rows)""") + self.assertSql(token, "SELECT * FROM t_enum", """\ +id | color +----+--------------- + 1 | {"Gray": 128} +(1 row)""") + self.assertSql(token, "SELECT * FROM t_nested", """\ +en | se | ints +-----------------------------------+-------------------------------------+--------------------------------------------------------------------------------------------------------- + {"id": 1, "color": {"Gray": 128}} | {"id": 2, "action": {"Active": {}}} | {"i8": -25, "i16": -3224, "i32": -23443, "i64": -2344353, "i128": -234434897853, "i256": -234434897853} +(1 row)""") + + def test_sql_conn(self): + """This test is designed to test connecting to the database and executing queries using `psycopg2`""" + token = self.read_token() + self.publish_module("quickstart", clear=True) + self.call("test") + + conn = connect_db(token) + # Check prepared statements (faked by `psycopg2`) + with conn.cursor() as cur: + cur.execute("select * from t_uints where u8 = %s and u16 = %s", (105, 1050)) + rows = cur.fetchall() + self.assertEqual(rows[0], (105, 1050, 83892, 48937498, 4378528978889, 4378528978889)) + # Check long-lived connection + with conn.cursor() as cur: + for _ in range(10): + cur.execute("select count(*) as t from t_uints") + rows = cur.fetchall() + self.assertEqual(rows[0], (1,)) + conn.close() + + def test_failures(self): + """This test is designed to test failure cases""" + token = self.read_token() + self.publish_module("quickstart", clear=True) + + # Empty query + sql_out = psql(token, "") + self.assertEqual(sql_out, "") + + # Connection fails when `ssl` is required + for ssl_mode in ["require", "verify-ca", "verify-full"]: + with self.assertRaises(Exception) as cm: + psql(token, "SELECT * FROM t_uints", extra={"PGSSLMODE": ssl_mode}) + self.assertIn("not support SSL", str(cm.exception)) + + # But works with `ssl` is disabled or optional + for ssl_mode in ["disable", "allow", "prefer"]: + psql(token, "SELECT * FROM t_uints", extra={"PGSSLMODE": ssl_mode}) + + # Connection fails with invalid token + with self.assertRaises(Exception) as cm: + psql("invalid_token", "SELECT * FROM t_uints") + self.assertIn("Invalid token", str(cm.exception)) + + # Returns error for unsupported `sql` statements + with self.assertRaises(Exception) as cm: + psql(token, "SELECT CASE a WHEN 1 THEN 'one' ELSE 'other' END FROM t_uints") + self.assertIn("Unsupported", str(cm.exception)) + + # And prepared statements + with self.assertRaises(Exception) as cm: + psql(token, "SELECT * FROM t_uints where u8 = $1") + self.assertIn("Unsupported", str(cm.exception)) From 9bfa676b94eb303c95d6bbcae63601e71d7a7674 Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Tue, 2 Sep 2025 11:17:20 -0500 Subject: [PATCH 07/10] Change to notify --- crates/pg/src/encoder.rs | 3 +-- crates/pg/src/pg_server.rs | 6 +++--- crates/standalone/src/subcommands/start.rs | 11 ++++++----- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/crates/pg/src/encoder.rs b/crates/pg/src/encoder.rs index fbe7cfd4fd6..f5a6ed990ed 100644 --- a/crates/pg/src/encoder.rs +++ b/crates/pg/src/encoder.rs @@ -250,8 +250,7 @@ mod tests { let row = run(schema.clone(), value.clone()).await; assert_eq!(row, "\0\0\0\u{b}{\"Gray\": 1}\0\0\0\u{15}{\"some\": {\"Gray\": 2}}"); - // Now a product with a nested option - // {x: {"Gray": 1}, y: "a"}, {some: {x: {"Gray": 2}, y: "b"}} + // Now nested product let product = AlgebraicType::product([ ProductTypeElement::new(AlgebraicType::Product(schema), Some("x".into())), ProductTypeElement::new(AlgebraicType::String, Some("y".into())), diff --git a/crates/pg/src/pg_server.rs b/crates/pg/src/pg_server.rs index f1bef37e132..6750588728d 100644 --- a/crates/pg/src/pg_server.rs +++ b/crates/pg/src/pg_server.rs @@ -35,7 +35,7 @@ use spacetimedb_lib::version::spacetimedb_lib_version; use spacetimedb_lib::{Identity, ProductValue}; use thiserror::Error; use tokio::net::TcpListener; -use tokio::sync::{watch, Mutex}; +use tokio::sync::{Mutex, Notify}; #[derive(Error, Debug)] pub(crate) enum PgError { @@ -364,7 +364,7 @@ impl( - mut shutdown: watch::Receiver<()>, + shutdown: Arc, ctx: Arc, tcp: TcpListener, ) { @@ -391,7 +391,7 @@ pub async fn start_pg { + _ = shutdown.notified() => { log::info!("PG: Shutting down PostgreSQL server."); break; } diff --git a/crates/standalone/src/subcommands/start.rs b/crates/standalone/src/subcommands/start.rs index e8f2e347005..5615945e547 100644 --- a/crates/standalone/src/subcommands/start.rs +++ b/crates/standalone/src/subcommands/start.rs @@ -185,15 +185,16 @@ pub async fn exec(args: &ArgMatches, db_cores: JobCores) -> anyhow::Result<()> { let pg_server_addr = format!("{}:5432", listen_addr.split(':').next().unwrap()); let tcp_pg = TcpListener::bind(pg_server_addr).await?; - let (shutdown_tx, mut shutdown_rx) = tokio::sync::watch::channel(()); + let notify = Arc::new(tokio::sync::Notify::new()); + let shutdown_notify = notify.clone(); tokio::select! { - _ = pg_server::start_pg(shutdown_rx.clone(), ctx, tcp_pg) => {}, - _ = axum::serve(tcp, service).with_graceful_shutdown(async move { - shutdown_rx.changed().await.ok(); + _ = pg_server::start_pg(notify.clone(), ctx, tcp_pg) => {}, + _ = axum::serve(tcp, service).with_graceful_shutdown(async move { + shutdown_notify.notified().await; }) => {}, _ = tokio::signal::ctrl_c() => { println!("Shutting down servers..."); - let _ = shutdown_tx.send(()); // Notify all tasks + notify.notify_waiters(); // Notify all tasks } } From da5caa1fb3bbed6b25e9582d9bd018789887dfb0 Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Thu, 4 Sep 2025 15:40:13 -0500 Subject: [PATCH 08/10] Apply patch to upgrade pg_wire --- Cargo.lock | 13 +++--- Cargo.toml | 2 +- crates/client-api/src/routes/database.rs | 2 +- crates/client-api/src/util.rs | 9 ++-- crates/pg/src/pg_server.rs | 57 ++++++++---------------- 5 files changed, 34 insertions(+), 49 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 92abffae5fd..92a4baf56ea 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -468,7 +468,7 @@ dependencies = [ "bitflags 2.9.0", "cexpr", "clang-sys", - "itertools 0.10.5", + "itertools 0.12.1", "lazy_static", "lazycell", "log", @@ -3325,9 +3325,9 @@ dependencies = [ [[package]] name = "md5" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "490cc448043f947bae3cbee9c203358d62dbee0db12107a74be5c30ccfd09771" +checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0" [[package]] name = "memchr" @@ -3921,9 +3921,9 @@ dependencies = [ [[package]] name = "pgwire" -version = "0.28.0" +version = "0.32.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c84e671791f3a354f265e55e400be8bb4b6262c1ec04fac4289e710ccf22ab43" +checksum = "ddf403a6ee31cf7f2217b2bd8447cb13dbb6c268d7e81501bc78a4d3daafd294" dependencies = [ "async-trait", "aws-lc-rs", @@ -3935,8 +3935,9 @@ dependencies = [ "lazy-regex", "md5", "postgres-types", - "rand 0.8.5", + "rand 0.9.1", "rust_decimal", + "rustls-pki-types", "thiserror 2.0.12", "tokio", "tokio-rustls", diff --git a/Cargo.toml b/Cargo.toml index 3b202f019f9..cb45e9d89d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -216,7 +216,7 @@ paste = "1.0" percent-encoding = "2.3" petgraph = { version = "0.6.5", default-features = false } pin-project-lite = "0.2.9" -pgwire = { version = "0.28.0", features = ["server-api"] } +pgwire = { version = "0.32", features = ["server-api"] } postgres-types = "0.2.5" pretty_assertions = { version = "1.4", features = ["unstable"] } proc-macro2 = "1.0" diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index f3a137a3e0d..fdee60f490d 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -542,7 +542,7 @@ pub async fn publish( // so, unless you are the owner, this will fail. let (database_identity, db_name) = match &name_or_identity { - Some(noa) => match noa.try_resolve(&ctx).await? { + Some(noa) => match noa.try_resolve(&ctx).await.map_err(log_and_500)? { Ok(resolved) => (resolved, noa.name()), Err(name) => { // `name_or_identity` was a `NameOrIdentity::Name`, but no record diff --git a/crates/client-api/src/util.rs b/crates/client-api/src/util.rs index 91386986fe4..c38bf33c0ae 100644 --- a/crates/client-api/src/util.rs +++ b/crates/client-api/src/util.rs @@ -97,10 +97,10 @@ impl NameOrIdentity { pub async fn try_resolve( &self, ctx: &(impl ControlStateReadAccess + ?Sized), - ) -> axum::response::Result> { + ) -> anyhow::Result> { Ok(match self { Self::Identity(identity) => Ok(Identity::from(*identity)), - Self::Name(name) => ctx.lookup_identity(name.as_ref()).map_err(log_and_500)?.ok_or(name), + Self::Name(name) => ctx.lookup_identity(name.as_ref())?.ok_or(name), }) } @@ -108,7 +108,10 @@ impl NameOrIdentity { /// response if `self` is a [`NameOrIdentity::Name`] for which no /// corresponding [`Identity`] is found in the SpacetimeDB DNS. pub async fn resolve(&self, ctx: &(impl ControlStateReadAccess + ?Sized)) -> axum::response::Result { - self.try_resolve(ctx).await?.map_err(|_| StatusCode::NOT_FOUND.into()) + self.try_resolve(ctx) + .await + .map_err(log_and_500)? + .map_err(|_| StatusCode::NOT_FOUND.into()) } } diff --git a/crates/pg/src/pg_server.rs b/crates/pg/src/pg_server.rs index 6750588728d..39b0fcacacd 100644 --- a/crates/pg/src/pg_server.rs +++ b/crates/pg/src/pg_server.rs @@ -12,11 +12,10 @@ use pgwire::api::auth::{ finish_authentication, save_startup_parameters_to_metadata, DefaultServerParameterProvider, LoginInfo, StartupHandler, }; -use pgwire::api::copy::NoopCopyHandler; use pgwire::api::portal::Format; -use pgwire::api::query::{PlaceholderExtendedQueryHandler, SimpleQueryHandler}; +use pgwire::api::query::SimpleQueryHandler; use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response, Tag}; -use pgwire::api::{ClientInfo, NoopErrorHandler, METADATA_DATABASE}; +use pgwire::api::{ClientInfo, METADATA_DATABASE}; use pgwire::api::{PgWireConnectionState, PgWireServerHandlers}; use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; use pgwire::messages::data::DataRow; @@ -145,12 +144,12 @@ async fn response(res: axum::response::Result, database: &str) -> Result { - ctx: Arc, + ctx: T, cached: Mutex>, parameter_provider: DefaultServerParameterProvider, } -impl PgSpacetimeDB { +impl PgSpacetimeDB { async fn exe_sql<'a>(&self, query: String) -> PgWireResult>> { let params = self.cached.lock().await.clone().unwrap(); let db = SqlParams { @@ -283,13 +282,11 @@ impl { - if ssl.is_some() { - let err = PgError::SSLNotSupported; - log::error!("{err}"); - let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string()); - return close_client(client, err).await; - } + PgWireFrontendMessage::SslRequest(_) => { + let err = PgError::SSLNotSupported; + log::error!("{err}"); + let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string()); + return close_client(client, err).await; } // The other messages are for features not supported by SpacetimeDB, that are rejected by the parser. _ => { @@ -301,10 +298,10 @@ impl SimpleQueryHandler +impl SimpleQueryHandler for PgSpacetimeDB { - async fn do_query<'a, C>(&self, _client: &mut C, query: &'a str) -> PgWireResult>> + async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult>> where C: ClientInfo + Unpin + Send + Sync, { @@ -313,12 +310,12 @@ impl { +pub struct PgSpacetimeDBFactory { handler: Arc>, } impl PgSpacetimeDBFactory { - pub fn new(ctx: Arc) -> Self { + pub fn new(ctx: T) -> Self { let mut parameter_provider = DefaultServerParameterProvider::default(); parameter_provider.server_version = format!("spacetime {}", spacetimedb_lib_version()); @@ -333,39 +330,23 @@ impl PgSpacetimeDBFactory { } } -impl PgWireServerHandlers +impl PgWireServerHandlers for PgSpacetimeDBFactory { - type StartupHandler = PgSpacetimeDB; - type SimpleQueryHandler = PgSpacetimeDB; - type ExtendedQueryHandler = PlaceholderExtendedQueryHandler; - type CopyHandler = NoopCopyHandler; - type ErrorHandler = NoopErrorHandler; - - fn simple_query_handler(&self) -> Arc { + fn simple_query_handler(&self) -> Arc { self.handler.clone() } - fn extended_query_handler(&self) -> Arc { - Arc::new(PlaceholderExtendedQueryHandler) - } + // TODO: fn extended_query_handler(&self) -> Arc {} - fn startup_handler(&self) -> Arc { + fn startup_handler(&self) -> Arc { self.handler.clone() } - - fn copy_handler(&self) -> Arc { - Arc::new(NoopCopyHandler) - } - - fn error_handler(&self) -> Arc { - Arc::new(NoopErrorHandler) - } } -pub async fn start_pg( +pub async fn start_pg( shutdown: Arc, - ctx: Arc, + ctx: T, tcp: TcpListener, ) { let factory = Arc::new(PgSpacetimeDBFactory::new(ctx)); From c966ccedd8275ce78b2b9c1af7ceedaebab13b88 Mon Sep 17 00:00:00 2001 From: Noa Date: Fri, 5 Sep 2025 15:32:56 -0500 Subject: [PATCH 09/10] Add pg_addr field to Node --- crates/core/src/messages/control_db.rs | 4 ++++ crates/standalone/src/lib.rs | 1 + 2 files changed, 5 insertions(+) diff --git a/crates/core/src/messages/control_db.rs b/crates/core/src/messages/control_db.rs index 21a9155dc64..8299875e339 100644 --- a/crates/core/src/messages/control_db.rs +++ b/crates/core/src/messages/control_db.rs @@ -63,6 +63,10 @@ pub struct Node { /// /// If `None`, the node is not currently live. pub advertise_addr: Option, + /// The address this node is running its postgres API at. + /// + /// If `None`, the node is not currently live. + pub pg_addr: Option, } #[derive(Clone, PartialEq, Serialize, Deserialize)] pub struct NodeStatus { diff --git a/crates/standalone/src/lib.rs b/crates/standalone/src/lib.rs index 9ebb9c23047..860dddd682f 100644 --- a/crates/standalone/src/lib.rs +++ b/crates/standalone/src/lib.rs @@ -186,6 +186,7 @@ impl spacetimedb_client_api::ControlStateReadAccess for StandaloneEnv { id: 0, unschedulable: false, advertise_addr: Some("node:80".to_owned()), + pg_addr: Some("node:5432".to_owned()), })); } Ok(None) From 599d622996bb00ef23e6b9c5740efb0dde9b5f9c Mon Sep 17 00:00:00 2001 From: Mario Alejandro Montoya Cortes Date: Tue, 9 Sep 2025 11:11:16 -0500 Subject: [PATCH 10/10] Quote special types in JSON --- .github/workflows/ci.yml | 2 +- crates/cli/src/subcommands/sql.rs | 20 ++++++++++---------- crates/sats/src/satn.rs | 14 ++++++++++---- smoketests/tests/pg_wire.py | 4 ++-- 4 files changed, 23 insertions(+), 17 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 62fec318a9d..84576981d69 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -58,7 +58,7 @@ jobs: Start-Process target/debug/spacetimedb-cli.exe start cd modules # the sdk-manifests on windows-latest are messed up, so we need to update them - dotnet workload config --update-mode workload-set + dotnet workload config --update-mode manifests dotnet workload update - uses: actions/setup-python@v5 with: { python-version: '3.12' } diff --git a/crates/cli/src/subcommands/sql.rs b/crates/cli/src/subcommands/sql.rs index 15917821fd7..00a8d645e42 100644 --- a/crates/cli/src/subcommands/sql.rs +++ b/crates/cli/src/subcommands/sql.rs @@ -496,9 +496,9 @@ Roundtrip time: 1.00ms"#, &kind, vec![value], r#" - column 0 | column 1 | column 2 | column 3 | column 4 | column 5 -----------+----------+--------------------------------------------------------------------+------------------------------------+---------------------------+---------- - "a" | 0 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | P0D"#, + column 0 | column 1 | column 2 | column 3 | column 4 | column 5 +----------+----------+----------------------------------------------------------------------+--------------------------------------+-----------------------------+---------- + "a" | 0 | "0x0000000000000000000000000000000000000000000000000000000000000000" | "0x00000000000000000000000000000000" | "1970-01-01T00:00:00+00:00" | "P0D""#, ); // Check struct @@ -538,9 +538,9 @@ Roundtrip time: 1.00ms"#, &kind, vec![value.clone()], r#" - bool | str | bytes | identity | connection_id | timestamp | duration -------+-----------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+---------- - true | "This is spacetimedb" | 0x01020304050607 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | P0D"#, + bool | str | bytes | identity | connection_id | timestamp | duration +------+-----------------------+--------------------+----------------------------------------------------------------------+--------------------------------------+-----------------------------+---------- + true | "This is spacetimedb" | "0x01020304050607" | "0x0000000000000000000000000000000000000000000000000000000000000000" | "0x00000000000000000000000000000000" | "1970-01-01T00:00:00+00:00" | "P0D""#, ); // Check nested struct, tuple... @@ -564,8 +564,8 @@ Roundtrip time: 1.00ms"#, vec![value.clone()], r#" column 0 ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ - {"bool": true, "str": "This is spacetimedb", "bytes": 0x01020304050607, "identity": 0x0000000000000000000000000000000000000000000000000000000000000000, "connection_id": 0x00000000000000000000000000000000, "timestamp": 1970-01-01T00:00:00+00:00, "duration": P0D}"#, +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"bool": true, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000000", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "P0D"}"#, ); let kind: ProductType = [("tuple", AlgebraicType::product(kind))].into(); @@ -588,8 +588,8 @@ Roundtrip time: 1.00ms"#, vec![value], r#" tuple ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- - {"col_0": {"bool": true, "str": "This is spacetimedb", "bytes": 0x01020304050607, "identity": 0x0000000000000000000000000000000000000000000000000000000000000000, "connection_id": 0x00000000000000000000000000000000, "timestamp": 1970-01-01T00:00:00+00:00, "duration": P0D}}"#, +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"col_0": {"bool": true, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000000", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "P0D"}}"#, ); Ok(()) diff --git a/crates/sats/src/satn.rs b/crates/sats/src/satn.rs index cd388213c03..2a5f128262e 100644 --- a/crates/sats/src/satn.rs +++ b/crates/sats/src/satn.rs @@ -851,21 +851,27 @@ impl TypedWriter for SqlFormatter<'_, '_> { } fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> { - write!(self.fmt, "0x{}", hex::encode(value)) + self.write_hex(value) } fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error> { - write!(self.fmt, "0x{}", hex::encode(value)) + match self.ty.client { + PsqlClient::SpacetimeDB => write!(self.fmt, "0x{}", hex::encode(value)), + PsqlClient::Postgres => write!(self.fmt, "\"0x{}\"", hex::encode(value)), + } } fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error> { - write!(self.fmt, "{}", value.to_rfc3339().unwrap()) + match self.ty.client { + PsqlClient::SpacetimeDB => write!(self.fmt, "{}", value.to_rfc3339().unwrap()), + PsqlClient::Postgres => write!(self.fmt, "\"{}\"", value.to_rfc3339().unwrap()), + } } fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error> { match self.ty.client { PsqlClient::SpacetimeDB => write!(self.fmt, "{value}"), - PsqlClient::Postgres => write!(self.fmt, "{}", value.to_iso8601()), + PsqlClient::Postgres => write!(self.fmt, "\"{}\"", value.to_iso8601()), } } diff --git a/smoketests/tests/pg_wire.py b/smoketests/tests/pg_wire.py index 7efef115634..89c7880c33c 100644 --- a/smoketests/tests/pg_wire.py +++ b/smoketests/tests/pg_wire.py @@ -218,8 +218,8 @@ def test_sql_format(self): (1 row)""") self.assertSql(token, "SELECT * FROM t_others_tuple", """\ tuple ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ - {"bool": true, "f32": 594806.56, "f64": -3454353.3453890434, "str": "This is spacetimedb", "bytes": 0x01020304050607, "identity": 0x0000000000000000000000000000000000000000000000000000000000000001, "connection_id": 0x00000000000000000000000000000000, "timestamp": 1970-01-01T00:00:00+00:00, "duration": PT10S} +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"bool": true, "f32": 594806.56, "f64": -3454353.3453890434, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000001", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "PT10S"} (1 row)""") self.assertSql(token, "SELECT * FROM t_simple_enum", """\ id | action