From 51ad2f234e732dfb6b730f0b8d8c95217f838e90 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Fri, 18 Jul 2025 13:25:40 -0700 Subject: [PATCH 01/17] Add system table for credentials --- .../locking_tx_datastore/committed_state.rs | 9 +++- .../src/locking_tx_datastore/datastore.rs | 14 ++++-- crates/datastore/src/system_tables.rs | 49 ++++++++++++++++++- crates/sdk/examples/quickstart-chat/main.rs | 12 ++--- 4 files changed, 71 insertions(+), 13 deletions(-) diff --git a/crates/datastore/src/locking_tx_datastore/committed_state.rs b/crates/datastore/src/locking_tx_datastore/committed_state.rs index a074d43c124..ed8b15ec1f5 100644 --- a/crates/datastore/src/locking_tx_datastore/committed_state.rs +++ b/crates/datastore/src/locking_tx_datastore/committed_state.rs @@ -6,6 +6,7 @@ use super::{ tx_state::{IndexIdMap, PendingSchemaChange, TxState}, IterByColEqTx, }; +use crate::system_tables::{ST_CONNECTION_CREDENTIALS_ID, ST_CONNECTION_CREDENTIALS_IDX}; use crate::{ db_metrics::DB_METRICS, error::{IndexError, TableError}, @@ -25,7 +26,7 @@ use core::{convert::Infallible, ops::RangeBounds}; use itertools::Itertools; use spacetimedb_data_structures::map::{HashSet, IntMap}; use spacetimedb_lib::{ - db::auth::{StAccess, StTableType}, + db::auth::StTableType, Identity, }; use spacetimedb_primitives::{ColList, ColSet, IndexId, TableId}; @@ -183,7 +184,7 @@ impl CommittedState { table_id, table_name: schema.table_name.clone(), table_type: StTableType::System, - table_access: StAccess::Public, + table_access: schema.table_access, table_primary_key: schema.primary_key.map(Into::into), }; let row = ProductValue::from(row); @@ -272,6 +273,10 @@ impl CommittedState { self.create_table(ST_SCHEDULED_ID, schemas[ST_SCHEDULED_IDX].clone()); self.create_table(ST_ROW_LEVEL_SECURITY_ID, schemas[ST_ROW_LEVEL_SECURITY_IDX].clone()); + self.create_table( + ST_CONNECTION_CREDENTIALS_ID, + schemas[ST_CONNECTION_CREDENTIALS_IDX].clone(), + ); // IMPORTANT: It is crucial that the `st_sequences` table is created last diff --git a/crates/datastore/src/locking_tx_datastore/datastore.rs b/crates/datastore/src/locking_tx_datastore/datastore.rs index d9b2ac32fe2..392ec7bad22 100644 --- a/crates/datastore/src/locking_tx_datastore/datastore.rs +++ b/crates/datastore/src/locking_tx_datastore/datastore.rs @@ -1147,9 +1147,10 @@ mod tests { use crate::error::IndexError; use crate::locking_tx_datastore::tx_state::PendingSchemaChange; use crate::system_tables::{ - system_tables, StColumnRow, StConstraintData, StConstraintFields, StConstraintRow, StIndexAlgorithm, - StIndexFields, StIndexRow, StRowLevelSecurityFields, StScheduledFields, StSequenceFields, StSequenceRow, - StTableRow, StVarFields, ST_CLIENT_NAME, ST_COLUMN_ID, ST_COLUMN_NAME, ST_CONSTRAINT_ID, ST_CONSTRAINT_NAME, + system_tables, StColumnRow, StConnectionCredentialsFields, StConstraintData, StConstraintFields, + StConstraintRow, StIndexAlgorithm, StIndexFields, StIndexRow, StRowLevelSecurityFields, StScheduledFields, + StSequenceFields, StSequenceRow, StTableRow, StVarFields, ST_CLIENT_NAME, ST_COLUMN_ID, ST_COLUMN_NAME, + ST_CONNECTION_CREDENTIALS_ID, ST_CONNECTION_CREDENTIALS_NAME, ST_CONSTRAINT_ID, ST_CONSTRAINT_NAME, ST_INDEX_ID, ST_INDEX_NAME, ST_MODULE_NAME, ST_RESERVED_SEQUENCE_RANGE, ST_ROW_LEVEL_SECURITY_ID, ST_ROW_LEVEL_SECURITY_NAME, ST_SCHEDULED_ID, ST_SCHEDULED_NAME, ST_SEQUENCE_ID, ST_SEQUENCE_NAME, ST_TABLE_NAME, ST_VAR_ID, ST_VAR_NAME, @@ -1603,6 +1604,7 @@ mod tests { TableRow { id: ST_VAR_ID.into(), name: ST_VAR_NAME, ty: StTableType::System, access: StAccess::Public, primary_key: Some(StVarFields::Name.into()) }, TableRow { id: ST_SCHEDULED_ID.into(), name: ST_SCHEDULED_NAME, ty: StTableType::System, access: StAccess::Public, primary_key: Some(StScheduledFields::ScheduleId.into()) }, TableRow { id: ST_ROW_LEVEL_SECURITY_ID.into(), name: ST_ROW_LEVEL_SECURITY_NAME, ty: StTableType::System, access: StAccess::Public, primary_key: Some(StRowLevelSecurityFields::Sql.into()) }, + TableRow { id: ST_CONNECTION_CREDENTIALS_ID.into(), name: ST_CONNECTION_CREDENTIALS_NAME, ty: StTableType::System, access: StAccess::Private, primary_key: Some(StConnectionCredentialsFields::ConnectionId.into()) }, ])); #[rustfmt::skip] assert_eq!(query.scan_st_columns()?, map_array([ @@ -1658,6 +1660,9 @@ mod tests { ColRow { table: ST_ROW_LEVEL_SECURITY_ID.into(), pos: 0, name: "table_id", ty: TableId::get_type() }, ColRow { table: ST_ROW_LEVEL_SECURITY_ID.into(), pos: 1, name: "sql", ty: AlgebraicType::String }, + + ColRow { table: ST_CONNECTION_CREDENTIALS_ID.into(), pos: 0, name: "connection_id", ty: AlgebraicType::U128 }, + ColRow { table: ST_CONNECTION_CREDENTIALS_ID.into(), pos: 1, name: "jwt_payload", ty: AlgebraicType::String }, ])); #[rustfmt::skip] assert_eq!(query.scan_st_indexes()?, map_array([ @@ -1673,6 +1678,7 @@ mod tests { IndexRow { id: 10, table: ST_SCHEDULED_ID.into(), col: col(1), name: "st_scheduled_table_id_idx_btree", }, IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "st_row_level_security_table_id_idx_btree", }, IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "st_row_level_security_sql_idx_btree", }, + IndexRow { id: 13, table: ST_CONNECTION_CREDENTIALS_ID.into(), col: col(0), name: "st_connection_credentials_connection_id_idx_btree", }, ])); let start = FIRST_NON_SYSTEM_ID as i128; #[rustfmt::skip] @@ -1702,6 +1708,7 @@ mod tests { ConstraintRow { constraint_id: 9, table_id: ST_SCHEDULED_ID.into(), unique_columns: col(0), constraint_name: "st_scheduled_schedule_id_key", }, ConstraintRow { constraint_id: 10, table_id: ST_SCHEDULED_ID.into(), unique_columns: col(1), constraint_name: "st_scheduled_table_id_key", }, ConstraintRow { constraint_id: 11, table_id: ST_ROW_LEVEL_SECURITY_ID.into(), unique_columns: col(1), constraint_name: "st_row_level_security_sql_key", }, + ConstraintRow { constraint_id: 12, table_id: ST_CONNECTION_CREDENTIALS_ID.into(), unique_columns: col(0), constraint_name: "st_connection_credentials_connection_id_key", }, ])); // Verify we get back the tables correctly with the proper ids... @@ -2099,6 +2106,7 @@ mod tests { IndexRow { id: 10, table: ST_SCHEDULED_ID.into(), col: col(1), name: "st_scheduled_table_id_idx_btree", }, IndexRow { id: 11, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(0), name: "st_row_level_security_table_id_idx_btree", }, IndexRow { id: 12, table: ST_ROW_LEVEL_SECURITY_ID.into(), col: col(1), name: "st_row_level_security_sql_idx_btree", }, + IndexRow { id: 13, table: ST_CONNECTION_CREDENTIALS_ID.into(), col: col(0), name: "st_connection_credentials_connection_id_idx_btree", }, IndexRow { id: seq_start, table: FIRST_NON_SYSTEM_ID, col: col(0), name: "Foo_id_idx_btree", }, IndexRow { id: seq_start + 1, table: FIRST_NON_SYSTEM_ID, col: col(1), name: "Foo_name_idx_btree", }, IndexRow { id: seq_start + 2, table: FIRST_NON_SYSTEM_ID, col: col(2), name: "Foo_age_idx_btree", }, diff --git a/crates/datastore/src/system_tables.rs b/crates/datastore/src/system_tables.rs index 1c56df19a65..1c6fdbbb036 100644 --- a/crates/datastore/src/system_tables.rs +++ b/crates/datastore/src/system_tables.rs @@ -60,6 +60,11 @@ pub const ST_SCHEDULED_ID: TableId = TableId(9); /// The static ID of the table that defines the row level security (RLS) policies pub const ST_ROW_LEVEL_SECURITY_ID: TableId = TableId(10); + +/// The static ID of the table that stores the credentials for each connection. +pub const ST_CONNECTION_CREDENTIALS_ID: TableId = TableId(11); + +pub(crate) const ST_CONNECTION_CREDENTIALS_NAME: &str = "st_connection_credentials"; pub const ST_TABLE_NAME: &str = "st_table"; pub const ST_COLUMN_NAME: &str = "st_column"; pub const ST_SEQUENCE_NAME: &str = "st_sequence"; @@ -97,7 +102,7 @@ pub enum SystemTable { st_row_level_security, } -pub fn system_tables() -> [TableSchema; 10] { +pub fn system_tables() -> [TableSchema; 11] { [ // The order should match the `id` of the system table, that start with [ST_TABLE_IDX]. st_table_schema(), @@ -109,6 +114,7 @@ pub fn system_tables() -> [TableSchema; 10] { st_var_schema(), st_scheduled_schema(), st_row_level_security_schema(), + st_connection_credential_schema(), // Is important this is always last, so the starting sequence for each // system table is correct. st_sequence_schema(), @@ -149,8 +155,9 @@ pub(crate) const ST_CLIENT_IDX: usize = 5; pub(crate) const ST_VAR_IDX: usize = 6; pub(crate) const ST_SCHEDULED_IDX: usize = 7; pub(crate) const ST_ROW_LEVEL_SECURITY_IDX: usize = 8; +pub(crate) const ST_CONNECTION_CREDENTIALS_IDX: usize = 9; // Must be the last index in the array. -pub(crate) const ST_SEQUENCE_IDX: usize = 9; +pub(crate) const ST_SEQUENCE_IDX: usize = 10; macro_rules! st_fields_enum { ($(#[$attr:meta])* enum $ty_name:ident { $($name:expr, $var:ident = $discr:expr,)* }) => { @@ -248,6 +255,13 @@ st_fields_enum!(enum StClientFields { "identity", Identity = 0, "connection_id", ConnectionId = 1, }); + +// WARNING: For a stable schema, don't change the field names and discriminants. +st_fields_enum!(enum StConnectionCredentialsFields { + "connection_id", ConnectionId = 0, + "jwt_payload", JwtPayload = 1, +}); + // WARNING: For a stable schema, don't change the field names and discriminants. st_fields_enum!(enum StVarFields { "name", Name = 0, @@ -341,6 +355,19 @@ fn system_module_def() -> ModuleDef { .with_type(TableType::System); // TODO: add empty unique constraint here, once we've implemented those. + let st_connection_credentials_type = builder.add_type::(); + // let st_connection_credentials_unique_cols = [StConnectionCredentialsFields::ConnectionId]; + builder + .build_table( + ST_CONNECTION_CREDENTIALS_NAME, + *st_connection_credentials_type.as_ref().expect("should be ref"), + ) + .with_type(TableType::System) + .with_unique_constraint(StConnectionCredentialsFields::ConnectionId) + .with_index_no_accessor_name(btree(StConnectionCredentialsFields::ConnectionId)) + .with_access(v9::TableAccess::Private) + .with_primary_key(StConnectionCredentialsFields::ConnectionId); + let st_client_type = builder.add_type::(); let st_client_unique_cols = [StClientFields::Identity, StClientFields::ConnectionId]; builder @@ -382,6 +409,7 @@ fn system_module_def() -> ModuleDef { validate_system_table::(&result, ST_CLIENT_NAME); validate_system_table::(&result, ST_VAR_NAME); validate_system_table::(&result, ST_SCHEDULED_NAME); + validate_system_table::(&result, ST_CONNECTION_CREDENTIALS_NAME); result } @@ -442,6 +470,10 @@ fn st_client_schema() -> TableSchema { st_schema(ST_CLIENT_NAME, ST_CLIENT_ID) } +fn st_connection_credential_schema() -> TableSchema { + st_schema(ST_CONNECTION_CREDENTIALS_NAME, ST_CONNECTION_CREDENTIALS_ID) +} + fn st_scheduled_schema() -> TableSchema { st_schema(ST_SCHEDULED_NAME, ST_SCHEDULED_ID) } @@ -466,6 +498,7 @@ pub(crate) fn system_table_schema(table_id: TableId) -> Option { ST_ROW_LEVEL_SECURITY_ID => Some(st_row_level_security_schema()), ST_MODULE_ID => Some(st_module_schema()), ST_CLIENT_ID => Some(st_client_schema()), + ST_CONNECTION_CREDENTIALS_ID => Some(st_connection_credential_schema()), ST_VAR_ID => Some(st_var_schema()), ST_SCHEDULED_ID => Some(st_scheduled_schema()), _ => None, @@ -927,6 +960,18 @@ pub struct StClientRow { pub connection_id: ConnectionIdViaU128, } +/// System table [ST_CONNECTION_CREDENTIALS_NAME] +/// +/// | connection_id | jwt_payload | +/// |------------------------------------|---------------------------------------------------------| +/// | 0x6bdea3ab517f5857dc9b1b5fe99e1b14 | '{"iss":"issuer","sub":"user-id","iat":1629212345,...}' | +#[derive(Clone, Debug, Eq, PartialEq, SpacetimeType)] +#[sats(crate = spacetimedb_lib)] +pub struct StConnectionCredentialsRow { + pub connection_id: ConnectionIdViaU128, + pub jwt_payload: String, +} + impl From for ProductValue { fn from(var: StClientRow) -> Self { to_product_value(&var) diff --git a/crates/sdk/examples/quickstart-chat/main.rs b/crates/sdk/examples/quickstart-chat/main.rs index e04ff9cb3cc..57ed21818f5 100644 --- a/crates/sdk/examples/quickstart-chat/main.rs +++ b/crates/sdk/examples/quickstart-chat/main.rs @@ -63,7 +63,7 @@ fn creds_store() -> credentials::File { /// Our `on_connect` callback: save our credentials to a file. fn on_connected(_ctx: &DbConnection, _identity: Identity, token: &str) { if let Err(e) = creds_store().save(token) { - eprintln!("Failed to save credentials: {:?}", e); + eprintln!("Failed to save credentials: {e:?}"); } } @@ -71,14 +71,14 @@ fn on_connected(_ctx: &DbConnection, _identity: Identity, token: &str) { /// Our `on_connect_error` callback: print the error, then exit the process. fn on_connect_error(_ctx: &ErrorContext, err: Error) { - eprintln!("Connection error: {}", err); + eprintln!("Connection error: {err}"); std::process::exit(1); } /// Our `on_disconnect` callback: print a note, then exit the process. fn on_disconnected(_ctx: &ErrorContext, err: Option) { if let Some(err) = err { - eprintln!("Disconnected: {}", err); + eprintln!("Disconnected: {err}"); std::process::exit(1); } else { println!("Disconnected."); @@ -166,14 +166,14 @@ fn print_message(ctx: &impl RemoteDbContext, message: &Message) { /// Our `on_set_name` callback: print a warning if the reducer failed. fn on_name_set(ctx: &ReducerEventContext, name: &String) { if let Status::Failed(err) = &ctx.event.status { - eprintln!("Failed to change name to {:?}: {}", name, err); + eprintln!("Failed to change name to {name:?}: {err}"); } } /// Our `on_send_message` callback: print a warning if the reducer failed. fn on_message_sent(ctx: &ReducerEventContext, text: &String) { if let Status::Failed(err) = &ctx.event.status { - eprintln!("Failed to send message {:?}: {}", text, err); + eprintln!("Failed to send message {text:?}: {err}"); } } @@ -206,7 +206,7 @@ fn on_sub_applied(ctx: &SubscriptionEventContext) { /// Or `on_error` callback: /// print the error, then exit the process. fn on_sub_error(_ctx: &ErrorContext, err: Error) { - eprintln!("Subscription failed: {}", err); + eprintln!("Subscription failed: {err}"); std::process::exit(1); } From cd6625a85a65a93d6735b980771b63329e95af10 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Mon, 21 Jul 2025 13:01:22 -0700 Subject: [PATCH 02/17] Add jwt payload to SpacetimeAuth --- Cargo.lock | 1 + crates/client-api/Cargo.toml | 1 + crates/client-api/src/auth.rs | 69 ++++++++++++++++++++- crates/sdk/examples/quickstart-chat/main.rs | 12 ++-- 4 files changed, 74 insertions(+), 9 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c5f1c04fc08..899f9080456 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5433,6 +5433,7 @@ dependencies = [ "async-trait", "axum", "axum-extra", + "base64 0.21.7", "bytes", "bytestring", "chrono", diff --git a/crates/client-api/Cargo.toml b/crates/client-api/Cargo.toml index b791ee05334..1d4aaf26918 100644 --- a/crates/client-api/Cargo.toml +++ b/crates/client-api/Cargo.toml @@ -15,6 +15,7 @@ spacetimedb-lib = { workspace = true, features = ["serde"] } spacetimedb-paths.workspace = true spacetimedb-schema.workspace = true +base64.workspace = true tokio = { version = "1.2", features = ["full"] } lazy_static = "1.4.0" log = "0.4.4" diff --git a/crates/client-api/src/auth.rs b/crates/client-api/src/auth.rs index 61031625867..08b7e77226c 100644 --- a/crates/client-api/src/auth.rs +++ b/crates/client-api/src/auth.rs @@ -1,5 +1,4 @@ -use std::time::{Duration, SystemTime}; - +use anyhow::anyhow; use axum::extract::{Query, Request, State}; use axum::middleware::Next; use axum::response::IntoResponse; @@ -15,9 +14,11 @@ use spacetimedb::auth::token_validation::{ use spacetimedb::auth::JwtKeys; use spacetimedb::energy::EnergyQuanta; use spacetimedb::identity::Identity; +use std::time::{Duration, SystemTime}; use uuid::Uuid; use crate::{log_and_500, ControlStateDelegate, NodeDelegate}; +use base64::{engine::general_purpose, Engine}; /// Credentials for login for a spacetime identity, represented as a JWT. /// @@ -41,6 +42,19 @@ impl SpacetimeCreds { Self { token } } + fn extract_jwt_payload_string(&self) -> Option { + let parts: Vec<&str> = self.token.split('.').collect(); + if parts.len() != 3 { + return None; + } + + let payload_encoded = parts[1]; + let decoded_bytes = general_purpose::URL_SAFE_NO_PAD.decode(payload_encoded).ok()?; + let json_str = String::from_utf8(decoded_bytes).ok()?; + + Some(json_str) + } + pub fn to_header_value(&self) -> HeaderValue { let mut val = HeaderValue::try_from(["Bearer ", self.token()].concat()).unwrap(); val.set_sensitive(true); @@ -73,6 +87,8 @@ pub struct SpacetimeAuth { pub identity: Identity, pub subject: String, pub issuer: String, + // The decoded JWT payload. + pub raw_payload: String, } use jsonwebtoken; @@ -148,12 +164,17 @@ impl SpacetimeAuth { let token = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?; SpacetimeCreds::from_signed_token(token) }; + // Pulling out the payload should never fail, since we just made it. + let payload = creds + .extract_jwt_payload_string() + .ok_or_else(|| log_and_500("internal error"))?; Ok(Self { creds, identity, subject, issuer: ctx.jwt_auth_provider().local_issuer().to_string(), + raw_payload: payload, }) } @@ -237,9 +258,11 @@ impl JwtAuthProvider for JwtKeyAuthProvider Result<(), anyhow::Error> { + let kp = JwtKeys::generate()?; + + let dummy_audience = "spacetimedb".to_string(); + let claims = TokenClaims { + issuer: "localhost".to_string(), + subject: "test-subject".to_string(), + audience: vec![dummy_audience.clone()], + }; + let token = claims.encode_and_sign(&kp.private)?; + let st_creds = SpacetimeCreds::from_signed_token(token); + let payload = st_creds + .extract_jwt_payload_string() + .ok_or_else(|| anyhow::anyhow!("Failed to extract JWT payload"))?; + // Make sure it is valid json. + let parsed: serde_json::Value = serde_json::from_str(&payload)?; + assert_eq!(parsed.get("iss").unwrap().as_str().unwrap(), claims.issuer); + assert_eq!(parsed.get("sub").unwrap().as_str().unwrap(), claims.subject); + assert_eq!( + parsed.get("aud").unwrap().as_array().unwrap()[0].as_str().unwrap(), + dummy_audience + ); + let as_object = parsed + .as_object() + .ok_or_else(|| anyhow::anyhow!("Failed to parse JWT payload as object"))?; + let keys: HashSet = as_object.keys().map(|s| s.to_string()).collect(); + let expected_keys = vec!["iss", "sub", "aud", "iat", "exp", "hex_identity"] + .into_iter() + .map(|s| s.to_string()) + .collect::>(); + assert_eq!(keys, expected_keys); + Ok(()) + } } pub struct SpacetimeAuthHeader { @@ -279,11 +338,15 @@ impl axum::extract::FromRequestParts for Space .await .map_err(AuthorizationRejection::Custom)?; + let payload = creds.extract_jwt_payload_string().ok_or_else(|| { + AuthorizationRejection::Custom(TokenValidationError::Other(anyhow!("Internal error parsing token"))) + })?; let auth = SpacetimeAuth { creds, identity: claims.identity, subject: claims.subject, issuer: claims.issuer, + raw_payload: payload, }; Ok(Self { auth: Some(auth) }) } diff --git a/crates/sdk/examples/quickstart-chat/main.rs b/crates/sdk/examples/quickstart-chat/main.rs index e04ff9cb3cc..57ed21818f5 100644 --- a/crates/sdk/examples/quickstart-chat/main.rs +++ b/crates/sdk/examples/quickstart-chat/main.rs @@ -63,7 +63,7 @@ fn creds_store() -> credentials::File { /// Our `on_connect` callback: save our credentials to a file. fn on_connected(_ctx: &DbConnection, _identity: Identity, token: &str) { if let Err(e) = creds_store().save(token) { - eprintln!("Failed to save credentials: {:?}", e); + eprintln!("Failed to save credentials: {e:?}"); } } @@ -71,14 +71,14 @@ fn on_connected(_ctx: &DbConnection, _identity: Identity, token: &str) { /// Our `on_connect_error` callback: print the error, then exit the process. fn on_connect_error(_ctx: &ErrorContext, err: Error) { - eprintln!("Connection error: {}", err); + eprintln!("Connection error: {err}"); std::process::exit(1); } /// Our `on_disconnect` callback: print a note, then exit the process. fn on_disconnected(_ctx: &ErrorContext, err: Option) { if let Some(err) = err { - eprintln!("Disconnected: {}", err); + eprintln!("Disconnected: {err}"); std::process::exit(1); } else { println!("Disconnected."); @@ -166,14 +166,14 @@ fn print_message(ctx: &impl RemoteDbContext, message: &Message) { /// Our `on_set_name` callback: print a warning if the reducer failed. fn on_name_set(ctx: &ReducerEventContext, name: &String) { if let Status::Failed(err) = &ctx.event.status { - eprintln!("Failed to change name to {:?}: {}", name, err); + eprintln!("Failed to change name to {name:?}: {err}"); } } /// Our `on_send_message` callback: print a warning if the reducer failed. fn on_message_sent(ctx: &ReducerEventContext, text: &String) { if let Status::Failed(err) = &ctx.event.status { - eprintln!("Failed to send message {:?}: {}", text, err); + eprintln!("Failed to send message {text:?}: {err}"); } } @@ -206,7 +206,7 @@ fn on_sub_applied(ctx: &SubscriptionEventContext) { /// Or `on_error` callback: /// print the error, then exit the process. fn on_sub_error(_ctx: &ErrorContext, err: Error) { - eprintln!("Subscription failed: {}", err); + eprintln!("Subscription failed: {err}"); std::process::exit(1); } From 4c1807f2fb9f2865ff1e0b0a8cf806161ef68704 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Wed, 23 Jul 2025 10:17:29 -0700 Subject: [PATCH 03/17] Plumb the jwt payload around --- crates/auth/Cargo.toml | 1 + crates/auth/src/identity.rs | 20 +++++++- crates/client-api/src/auth.rs | 54 ++++++++++++--------- crates/client-api/src/routes/database.rs | 41 +++++++++------- crates/client-api/src/routes/energy.rs | 6 +-- crates/client-api/src/routes/identity.rs | 6 +-- crates/client-api/src/routes/subscribe.rs | 15 ++++-- crates/core/src/client.rs | 3 +- crates/core/src/client/client_connection.rs | 19 ++++++-- crates/core/src/error.rs | 3 +- crates/core/src/host/module_host.rs | 9 ++-- crates/testing/src/modules.rs | 2 +- 12 files changed, 120 insertions(+), 59 deletions(-) diff --git a/crates/auth/Cargo.toml b/crates/auth/Cargo.toml index a5592c08f5f..f3db2f86af7 100644 --- a/crates/auth/Cargo.toml +++ b/crates/auth/Cargo.toml @@ -11,6 +11,7 @@ spacetimedb-lib = { workspace = true, features = ["serde"] } anyhow.workspace = true serde.workspace = true +serde_json.workspace = true serde_with.workspace = true jsonwebtoken.workspace = true diff --git a/crates/auth/src/identity.rs b/crates/auth/src/identity.rs index 286d28582dd..576e850d7ee 100644 --- a/crates/auth/src/identity.rs +++ b/crates/auth/src/identity.rs @@ -6,9 +6,27 @@ use serde::{Deserialize, Serialize}; use spacetimedb_lib::Identity; use std::time::SystemTime; +#[derive(Debug, Clone)] +pub struct ConnectionAuthCtx { + pub claims: SpacetimeIdentityClaims, + pub jwt_payload: String, +} + +impl TryFrom for ConnectionAuthCtx { + type Error = anyhow::Error; + fn try_from(claims: SpacetimeIdentityClaims) -> Result { + let payload = + serde_json::to_string(&claims).map_err(|e| anyhow::anyhow!("Failed to serialize claims: {}", e))?; + Ok(ConnectionAuthCtx { + claims, + jwt_payload: payload, + }) + } +} + // These are the claims that can be attached to a request/connection. #[serde_with::serde_as] -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct SpacetimeIdentityClaims { #[serde(rename = "hex_identity")] pub identity: Identity, diff --git a/crates/client-api/src/auth.rs b/crates/client-api/src/auth.rs index 08b7e77226c..ccf2218abf7 100644 --- a/crates/client-api/src/auth.rs +++ b/crates/client-api/src/auth.rs @@ -6,7 +6,7 @@ use axum_extra::typed_header::TypedHeader; use headers::{authorization, HeaderMapExt}; use http::{request, HeaderValue, StatusCode}; use serde::{Deserialize, Serialize}; -use spacetimedb::auth::identity::SpacetimeIdentityClaims; +use spacetimedb::auth::identity::{ConnectionAuthCtx, SpacetimeIdentityClaims}; use spacetimedb::auth::identity::{JwtError, JwtErrorKind}; use spacetimedb::auth::token_validation::{ new_validator, DefaultValidator, TokenSigner, TokenValidationError, TokenValidator, @@ -84,13 +84,25 @@ impl SpacetimeCreds { #[derive(Clone)] pub struct SpacetimeAuth { pub creds: SpacetimeCreds, + pub claims: SpacetimeIdentityClaims, + /* pub identity: Identity, pub subject: String, pub issuer: String, + */ // The decoded JWT payload. pub raw_payload: String, } +impl From for ConnectionAuthCtx { + fn from(auth: SpacetimeAuth) -> Self { + ConnectionAuthCtx { + claims: auth.claims, + jwt_payload: auth.raw_payload.clone(), + } + } +} + use jsonwebtoken; pub struct TokenClaims { @@ -100,10 +112,10 @@ pub struct TokenClaims { } impl From for TokenClaims { - fn from(claims: SpacetimeAuth) -> Self { + fn from(auth: SpacetimeAuth) -> Self { Self { - issuer: claims.issuer, - subject: claims.subject, + issuer: auth.claims.issuer, + subject: auth.claims.subject, // This will need to be changed when we care about audiencies. audience: Vec::new(), } @@ -128,7 +140,7 @@ impl TokenClaims { &self, signer: &impl TokenSigner, expiry: Option, - ) -> Result { + ) -> Result<(SpacetimeIdentityClaims, String), JwtError> { let iat = SystemTime::now(); let exp = expiry.map(|dur| iat + dur); let claims = SpacetimeIdentityClaims { @@ -139,10 +151,11 @@ impl TokenClaims { iat, exp, }; - signer.sign(&claims) + let token = signer.sign(&claims)?; + Ok((claims, token)) } - pub fn encode_and_sign(&self, signer: &impl TokenSigner) -> Result { + pub fn encode_and_sign(&self, signer: &impl TokenSigner) -> Result<(SpacetimeIdentityClaims, String), JwtError> { self.encode_and_sign_with_expiry(signer, None) } } @@ -159,11 +172,8 @@ impl SpacetimeAuth { audience: vec!["spacetimedb".to_string()], }; - let identity = claims.id(); - let creds = { - let token = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?; - SpacetimeCreds::from_signed_token(token) - }; + let (claims, token) = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?; + let creds = SpacetimeCreds::from_signed_token(token); // Pulling out the payload should never fail, since we just made it. let payload = creds .extract_jwt_payload_string() @@ -171,9 +181,7 @@ impl SpacetimeAuth { Ok(Self { creds, - identity, - subject, - issuer: ctx.jwt_auth_provider().local_issuer().to_string(), + claims, raw_payload: payload, }) } @@ -181,7 +189,7 @@ impl SpacetimeAuth { /// Get the auth credentials as headers to be returned from an endpoint. pub fn into_headers(self) -> (TypedHeader, TypedHeader) { ( - TypedHeader(SpacetimeIdentity(self.identity)), + TypedHeader(SpacetimeIdentity(self.claims.identity)), TypedHeader(SpacetimeIdentityToken(self.creds)), ) } @@ -189,7 +197,11 @@ impl SpacetimeAuth { // Sign a new token with the same claims and a new expiry. // Note that this will not change the issuer, so the private_key might not match. // We do this to create short-lived tokens that we will be able to verify. - pub fn re_sign_with_expiry(&self, signer: &impl TokenSigner, expiry: Duration) -> Result { + pub fn re_sign_with_expiry( + &self, + signer: &impl TokenSigner, + expiry: Duration, + ) -> Result<(SpacetimeIdentityClaims, String), JwtError> { TokenClaims::from(self.clone()).encode_and_sign_with_expiry(signer, Some(expiry)) } } @@ -275,7 +287,7 @@ mod tests { audience: vec!["spacetimedb".to_string()], }; let id = claims.id(); - let token = claims.encode_and_sign(&kp.private)?; + let (_, token) = claims.encode_and_sign(&kp.private)?; let decoded = kp.public.validate_token(&token).await?; assert_eq!(decoded.identity, id); @@ -293,7 +305,7 @@ mod tests { subject: "test-subject".to_string(), audience: vec![dummy_audience.clone()], }; - let token = claims.encode_and_sign(&kp.private)?; + let (_, token) = claims.encode_and_sign(&kp.private)?; let st_creds = SpacetimeCreds::from_signed_token(token); let payload = st_creds .extract_jwt_payload_string() @@ -343,9 +355,7 @@ impl axum::extract::FromRequestParts for Space })?; let auth = SpacetimeAuth { creds, - identity: claims.identity, - subject: claims.subject, - issuer: claims.issuer, + claims, raw_payload: payload, }; Ok(Self { auth: Some(auth) }) diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index c25e476b076..5802b3abfdf 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -54,7 +54,7 @@ pub async fn call( if content_type != headers::ContentType::json() { return Err(axum::extract::rejection::MissingJsonContentType::default().into()); } - let caller_identity = auth.identity; + let caller_identity = auth.claims.identity; let args = ReducerArgs::Json(body); @@ -78,7 +78,7 @@ pub async fn call( // so generate one. let connection_id = generate_random_connection_id(); - match module.call_identity_connected(caller_identity, connection_id).await { + match module.call_identity_connected(auth.into(), connection_id).await { // If `call_identity_connected` returns `Err(Rejected)`, then the `client_connected` reducer errored, // meaning the connection was refused. Return 403 forbidden. Err(ClientConnectedError::Rejected(msg)) => return Err((StatusCode::FORBIDDEN, msg).into()), @@ -225,7 +225,7 @@ where }; Ok(( - TypedHeader(SpacetimeIdentity(auth.identity)), + TypedHeader(SpacetimeIdentity(auth.claims.identity)), TypedHeader(SpacetimeIdentityToken(auth.creds)), response_json, )) @@ -300,13 +300,13 @@ where .await? .ok_or(NO_SUCH_DATABASE)?; - if database.owner_identity != auth.identity { + if database.owner_identity != auth.claims.identity { return Err(( StatusCode::BAD_REQUEST, format!( "Identity does not own database, expected: {} got: {}", database.owner_identity.to_hex(), - auth.identity.to_hex() + auth.claims.identity.to_hex() ), ) .into()); @@ -402,7 +402,7 @@ where .await? .ok_or(NO_SUCH_DATABASE)?; - let auth = AuthCtx::new(database.owner_identity, auth.identity); + let auth = AuthCtx::new(database.owner_identity, auth.claims.identity); log::debug!("auth: {auth:?}"); let host = worker_ctx @@ -481,10 +481,13 @@ fn allow_creation(auth: &SpacetimeAuth) -> Result<(), ErrorResponse> { if !require_spacetime_auth_for_creation() { return Ok(()); } - if auth.issuer.trim_end_matches('/') == "https://auth.spacetimedb.com" { + if auth.claims.issuer.trim_end_matches('/') == "https://auth.spacetimedb.com" { Ok(()) } else { - log::trace!("Rejecting creation request because auth issuer is {}", auth.issuer); + log::trace!( + "Rejecting creation request because auth issuer is {}", + auth.claims.issuer + ); Err(( StatusCode::UNAUTHORIZED, "To create a database, you must be logged in with a SpacetimeDB account.", @@ -511,9 +514,13 @@ pub async fn publish( // exists yet. Create it now with a fresh identity. allow_creation(&auth)?; let database_auth = SpacetimeAuth::alloc(&ctx).await?; - let database_identity = database_auth.identity; + let database_identity = database_auth.claims.identity; let tld: name::Tld = name.clone().into(); - let tld = match ctx.register_tld(&auth.identity, tld).await.map_err(log_and_500)? { + let tld = match ctx + .register_tld(&auth.claims.identity, tld) + .await + .map_err(log_and_500)? + { name::RegisterTldResult::Success { domain } | name::RegisterTldResult::AlreadyRegistered { domain } => domain, name::RegisterTldResult::Unauthorized { .. } => { @@ -525,7 +532,7 @@ pub async fn publish( } }; let res = ctx - .create_dns_record(&auth.identity, &tld.into(), &database_identity) + .create_dns_record(&auth.claims.identity, &tld.into(), &database_identity) .await .map_err(log_and_500)?; match res { @@ -541,7 +548,7 @@ pub async fn publish( }, None => { let database_auth = SpacetimeAuth::alloc(&ctx).await?; - let database_identity = database_auth.identity; + let database_identity = database_auth.claims.identity; (database_identity, None) } }; @@ -558,7 +565,7 @@ pub async fn publish( } if clear && exists { - ctx.delete_database(&auth.identity, &database_identity) + ctx.delete_database(&auth.claims.identity, &database_identity) .await .map_err(log_and_500)?; } @@ -580,7 +587,7 @@ pub async fn publish( let maybe_updated = ctx .publish_database( - &auth.identity, + &auth.claims.identity, DatabaseDef { database_identity, program_bytes: body.into(), @@ -626,7 +633,7 @@ pub async fn delete_database( ) -> axum::response::Result { let database_identity = name_or_identity.resolve(&ctx).await?; - ctx.delete_database(&auth.identity, &database_identity) + ctx.delete_database(&auth.claims.identity, &database_identity) .await .map_err(log_and_500)?; @@ -648,7 +655,7 @@ pub async fn add_name( let database_identity = name_or_identity.resolve(&ctx).await?; let response = ctx - .create_dns_record(&auth.identity, &name.into(), &database_identity) + .create_dns_record(&auth.claims.identity, &name.into(), &database_identity) .await // TODO: better error code handling .map_err(log_and_500)?; @@ -691,7 +698,7 @@ pub async fn set_names( )); }; - if database.owner_identity != auth.identity { + if database.owner_identity != auth.claims.identity { return Ok(( StatusCode::UNAUTHORIZED, axum::Json(name::SetDomainsResult::NotYourDatabase { diff --git a/crates/client-api/src/routes/energy.rs b/crates/client-api/src/routes/energy.rs index 16e66963cf6..34098987c68 100644 --- a/crates/client-api/src/routes/energy.rs +++ b/crates/client-api/src/routes/energy.rs @@ -49,14 +49,14 @@ pub async fn add_energy( })?; if let Some(satoshi) = amount { - ctx.add_energy(&auth.identity, EnergyQuanta::new(satoshi)) + ctx.add_energy(&auth.claims.identity, EnergyQuanta::new(satoshi)) .await .map_err(log_and_500)?; } // TODO: is this guaranteed to pull the updated balance? let balance = ctx - .get_energy_balance(&auth.identity) + .get_energy_balance(&auth.claims.identity) .map_err(log_and_500)? .map_or(0, |quanta| quanta.get()); @@ -87,7 +87,7 @@ pub async fn set_energy_balance( // This will be a natural rate limiter until we can begin to sell energy. // No one is able to be the dummy identity so this always returns unauthorized. - if auth.identity != Identity::__dummy() { + if auth.claims.identity != Identity::__dummy() { return Err(StatusCode::UNAUTHORIZED.into()); } diff --git a/crates/client-api/src/routes/identity.rs b/crates/client-api/src/routes/identity.rs index 69b27661fbe..be9adde55f9 100644 --- a/crates/client-api/src/routes/identity.rs +++ b/crates/client-api/src/routes/identity.rs @@ -24,7 +24,7 @@ pub async fn create_identity( let auth = SpacetimeAuth::alloc(&ctx).await?; let identity_response = CreateIdentityResponse { - identity: auth.identity, + identity: auth.claims.identity, token: auth.creds.token().to_owned(), }; Ok(axum::Json(identity_response)) @@ -103,7 +103,7 @@ pub async fn create_websocket_token( SpacetimeAuthRequired(auth): SpacetimeAuthRequired, ) -> axum::response::Result { let expiry = Duration::from_secs(60); - let token = auth + let (_, token) = auth .re_sign_with_expiry(ctx.jwt_auth_provider(), expiry) .map_err(log_and_500)?; // let token = encode_token_with_expiry(ctx.private_key(), auth.identity, Some(expiry)).map_err(log_and_500)?; @@ -121,7 +121,7 @@ pub async fn validate_token( ) -> axum::response::Result { let identity = Identity::from(identity); - if auth.identity != identity { + if auth.claims.identity != identity { return Err(StatusCode::BAD_REQUEST.into()); } diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index a76d3adacea..97e884beb71 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -136,8 +136,9 @@ where let module_rx = leader.module_watcher().await.map_err(log_and_500)?; + let client_identity = auth.claims.identity; let client_id = ClientActorId { - identity: auth.identity, + identity: client_identity, connection_id, name: ctx.client_actor_index().next_client_name(), }; @@ -164,7 +165,15 @@ where } let actor = |client, sendrx| ws_client_actor(client, ws, sendrx); - let client = match ClientConnection::spawn(client_id, client_config, leader.replica_id, module_rx, actor).await + let client = match ClientConnection::spawn( + client_id, + auth.into(), + client_config, + leader.replica_id, + module_rx, + actor, + ) + .await { Ok(s) => s, Err(e @ (ClientConnectedError::Rejected(_) | ClientConnectedError::OutOfEnergy)) => { @@ -183,7 +192,7 @@ where // Clients that receive the token from the response headers should ignore this // message. let message = IdentityTokenMessage { - identity: auth.identity, + identity: client_identity, token: identity_token, connection_id, }; diff --git a/crates/core/src/client.rs b/crates/core/src/client.rs index 11103824dca..1382b7882db 100644 --- a/crates/core/src/client.rs +++ b/crates/core/src/client.rs @@ -14,7 +14,8 @@ pub use client_connection_index::ClientActorIndex; pub use message_handlers::{MessageExecutionError, MessageHandleError}; use spacetimedb_lib::ConnectionId; -#[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)] +// #[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)] +#[derive(Clone, Debug, Copy)] pub struct ClientActorId { pub identity: Identity, pub connection_id: ConnectionId, diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index acd9466d5b1..02d0da541cc 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -5,7 +5,7 @@ use std::sync::atomic::Ordering; use std::sync::atomic::{AtomicBool, Ordering::Relaxed}; use std::sync::Arc; use std::task::{Context, Poll}; -use std::time::Instant; +use std::time::{Instant, SystemTime}; use super::messages::{OneOffQueryResponseMessage, SerializableMessage}; use super::{message_handlers, ClientActorId, MessageHandleError}; @@ -21,6 +21,7 @@ use bytestring::ByteString; use derive_more::From; use futures::prelude::*; use prometheus::{Histogram, IntCounter, IntGauge}; +use spacetimedb_auth::identity::{ConnectionAuthCtx, SpacetimeIdentityClaims}; use spacetimedb_client_api_messages::websocket::{ BsatnFormat, CallReducerFlags, Compression, FormatSwitch, JsonFormat, SubscribeMulti, SubscribeSingle, Unsubscribe, UnsubscribeMulti, @@ -78,6 +79,7 @@ impl ClientConfig { #[derive(Debug)] pub struct ClientConnectionSender { pub id: ClientActorId, + pub auth: ConnectionAuthCtx, pub config: ClientConfig, sendtx: mpsc::Sender, abort_handle: AbortHandle, @@ -146,8 +148,17 @@ impl ClientConnectionSender { let rx = MeteredReceiver::new(rx); let cancelled = AtomicBool::new(false); + let dummy_claims = SpacetimeIdentityClaims { + identity: id.identity, + subject: "".to_string(), + issuer: "".to_string(), + audience: vec![], + iat: SystemTime::now(), + exp: None, + }; let sender = Self { id, + auth: ConnectionAuthCtx::try_from(dummy_claims).expect("dummy claims should always be valid"), config, sendtx, abort_handle, @@ -396,6 +407,7 @@ impl ClientConnection { /// Returns an error if ModuleHost closed pub async fn spawn( id: ClientActorId, + auth: ConnectionAuthCtx, config: ClientConfig, replica_id: u64, mut module_rx: watch::Receiver, @@ -409,7 +421,7 @@ impl ClientConnection { // logically subscribed to the database, not any particular replica. We should handle failover for // them and stuff. Not right now though. let module = module_rx.borrow_and_update().clone(); - module.call_identity_connected(id.identity, id.connection_id).await?; + module.call_identity_connected(auth.clone(), id.connection_id).await?; let (sendtx, sendrx) = mpsc::channel::(CLIENT_CHANNEL_CAPACITY); @@ -417,6 +429,7 @@ impl ClientConnection { // weird dance so that we can get an abort_handle into ClientConnection let module_info = module.info.clone(); let database_identity = module_info.database_identity; + let client_identity = id.identity; let abort_handle = tokio::spawn(async move { let Ok(fut) = fut_rx.await else { return }; @@ -424,7 +437,6 @@ impl ClientConnection { module_info.metrics.ws_clients_spawned.inc(); scopeguard::defer! { let database_identity = module_info.database_identity; - let client_identity = id.identity; log::warn!("websocket connection aborted for client identity `{client_identity}` and database identity `{database_identity}`"); module_info.metrics.ws_clients_aborted.inc(); }; @@ -438,6 +450,7 @@ impl ClientConnection { let sender = Arc::new(ClientConnectionSender { id, + auth, config, sendtx, abort_handle, diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index 4e550b234b3..3dfd4ec465b 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -26,7 +26,8 @@ use spacetimedb_vm::expr::Crud; pub use spacetimedb_datastore::error::{DatastoreError, IndexError, SequenceError, TableError}; -#[derive(Error, Debug, PartialEq, Eq)] +// #[derive(Error, Debug, PartialEq, Eq)] +#[derive(Error, Debug)] pub enum ClientError { #[error("Client not found: {0}")] NotFound(ClientActorId), diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index dc8a75640a6..d18c02356dc 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -25,6 +25,7 @@ use derive_more::From; use indexmap::IndexSet; use itertools::Itertools; use prometheus::{Histogram, IntGauge}; +use spacetimedb_auth::identity::ConnectionAuthCtx; use spacetimedb_client_api_messages::websocket::{ByteListLen, Compression, OneOffTable, QueryUpdate, WebsocketFormat}; use spacetimedb_data_structures::error_stream::ErrorStream; use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; @@ -684,7 +685,7 @@ impl ModuleHost { /// In this case, the caller should terminate the connection. pub async fn call_identity_connected( &self, - caller_identity: Identity, + caller_auth: ConnectionAuthCtx, caller_connection_id: ConnectionId, ) -> Result<(), ClientConnectedError> { let me = self.clone(); @@ -697,7 +698,7 @@ impl ModuleHost { // If the call fails (as in, something unexpectedly goes wrong with WASM execution), // abort the connection: we can't really recover. let reducer_outcome = me.call_reducer_inner_with_inst( - caller_identity, + caller_auth.claims.identity, Some(caller_connection_id), None, None, @@ -739,7 +740,7 @@ impl ModuleHost { let workload = Workload::Reducer(ReducerContext { name: reducer_name.to_owned(), - caller_identity, + caller_identity: caller_auth.claims.identity, caller_connection_id, timestamp: Timestamp::now(), arg_bsatn: Bytes::new(), @@ -748,7 +749,7 @@ impl ModuleHost { let stdb = me.module.replica_ctx().relational_db.clone(); stdb.with_auto_commit(workload, |mut_tx| { mut_tx - .insert_st_client(caller_identity, caller_connection_id) + .insert_st_client(caller_auth.claims.identity, caller_connection_id) .map_err(DBError::from) }) .inspect_err(|e| { diff --git a/crates/testing/src/modules.rs b/crates/testing/src/modules.rs index 22c55b9af8a..11665c1ec3e 100644 --- a/crates/testing/src/modules.rs +++ b/crates/testing/src/modules.rs @@ -184,7 +184,7 @@ impl CompiledModule { .unwrap(); // TODO: Fix this when we update identity generation. let identity = Identity::ZERO; - let db_identity = SpacetimeAuth::alloc(&env).await.unwrap().identity; + let db_identity = SpacetimeAuth::alloc(&env).await.unwrap().claims.identity; let connection_id = generate_random_connection_id(); let program_bytes = self.program_bytes().to_owned(); From c65bb786ec6f23c2dba15f80095741776ac391e0 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Wed, 23 Jul 2025 10:51:23 -0700 Subject: [PATCH 04/17] Store client credentials. --- crates/core/src/host/module_host.rs | 3 ++ .../src/host/wasm_common/module_host_actor.rs | 7 +++- .../locking_tx_datastore/committed_state.rs | 5 +-- .../src/locking_tx_datastore/mut_tx.rs | 35 +++++++++++++++++-- 4 files changed, 43 insertions(+), 7 deletions(-) diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index d18c02356dc..b5561724404 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -750,6 +750,9 @@ impl ModuleHost { stdb.with_auto_commit(workload, |mut_tx| { mut_tx .insert_st_client(caller_auth.claims.identity, caller_connection_id) + .map_err(DBError::from)?; + mut_tx + .insert_st_client_credentials(caller_connection_id, &caller_auth.jwt_payload) .map_err(DBError::from) }) .inspect_err(|e| { diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index c337140ac94..3a245d927d5 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -363,9 +363,14 @@ impl WasmModuleInstance { .with_label_values(&database_identity, reducer_name); let workload = Workload::Reducer(ReducerContext::from(op.clone())); - let tx = tx.unwrap_or_else(|| stdb.begin_mut_tx(IsolationLevel::Serializable, workload)); + let mut tx = tx.unwrap_or_else(|| stdb.begin_mut_tx(IsolationLevel::Serializable, workload)); let _guard = metric_reducer_plus_query_duration.with_timer(tx.timer); + if let Some(Lifecycle::OnConnect) = reducer_def.lifecycle { + tx.insert_st_client_credentials(caller_connection_id, &client.clone().unwrap().auth.jwt_payload) + .unwrap(); + }; + let mut tx_slot = self.instance.instance_env().tx.clone(); let reducer_span = tracing::trace_span!( diff --git a/crates/datastore/src/locking_tx_datastore/committed_state.rs b/crates/datastore/src/locking_tx_datastore/committed_state.rs index ed8b15ec1f5..c05cd558751 100644 --- a/crates/datastore/src/locking_tx_datastore/committed_state.rs +++ b/crates/datastore/src/locking_tx_datastore/committed_state.rs @@ -25,10 +25,7 @@ use anyhow::anyhow; use core::{convert::Infallible, ops::RangeBounds}; use itertools::Itertools; use spacetimedb_data_structures::map::{HashSet, IntMap}; -use spacetimedb_lib::{ - db::auth::StTableType, - Identity, -}; +use spacetimedb_lib::{db::auth::StTableType, Identity}; use spacetimedb_primitives::{ColList, ColSet, IndexId, TableId}; use spacetimedb_sats::memory_usage::MemoryUsage; use spacetimedb_sats::{AlgebraicValue, ProductValue}; diff --git a/crates/datastore/src/locking_tx_datastore/mut_tx.rs b/crates/datastore/src/locking_tx_datastore/mut_tx.rs index e1c92e6d204..884b07c7538 100644 --- a/crates/datastore/src/locking_tx_datastore/mut_tx.rs +++ b/crates/datastore/src/locking_tx_datastore/mut_tx.rs @@ -10,6 +10,7 @@ use super::{ }; use crate::execution_context::ExecutionContext; use crate::execution_context::Workload; +use crate::system_tables::{StConnectionCredentialsFields, StConnectionCredentialsRow, ST_CONNECTION_CREDENTIALS_ID}; use crate::traits::{InsertFlags, RowTypeForTable, TxData, UpdateFlags}; use crate::{ error::{IndexError, SequenceError, TableError}, @@ -1357,6 +1358,36 @@ impl MutTxId { self.insert_via_serialize_bsatn(ST_CLIENT_ID, row).map(|_| ()) } + pub fn insert_st_client_credentials(&mut self, connection_id: ConnectionId, jwt_payload: &str) -> Result<()> { + let row = &StConnectionCredentialsRow { + connection_id: connection_id.into(), + jwt_payload: jwt_payload.to_owned(), + }; + self.insert_via_serialize_bsatn(ST_CONNECTION_CREDENTIALS_ID, row) + .map(|_| ()) + } + + pub fn delete_st_client_credentials( + &mut self, + database_identity: Identity, + connection_id: ConnectionId, + ) -> Result<()> { + if let Some(ptr) = self + .iter_by_col_eq( + ST_CONNECTION_CREDENTIALS_ID, + StConnectionCredentialsFields::ConnectionId, + &connection_id.into(), + )? + .next() + .map(|row| row.pointer()) + { + self.delete(ST_CONNECTION_CREDENTIALS_ID, ptr).map(drop) + } else { + log::warn!("[{database_identity}]: delete_st_client_credentials: attempting to credentials for missing connection id ({connection_id})"); + Ok(()) + } + } + pub fn delete_st_client( &mut self, identity: Identity, @@ -1378,11 +1409,11 @@ impl MutTxId { .next() .map(|row| row.pointer()) { - self.delete(ST_CLIENT_ID, ptr).map(drop) + self.delete(ST_CLIENT_ID, ptr).map(drop)? } else { log::error!("[{database_identity}]: delete_st_client: attempting to delete client ({identity}, {connection_id}), but no st_client row for that client is resident"); - Ok(()) } + self.delete_st_client_credentials(database_identity, connection_id) } pub fn insert_via_serialize_bsatn<'a, T: Serialize>( From 168fd84d50b5a40458da70bcd12ef152620e62ad Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Thu, 24 Jul 2025 08:08:35 -0700 Subject: [PATCH 05/17] Tweak error handling --- crates/client-api/src/auth.rs | 5 ---- .../src/host/wasm_common/module_host_actor.rs | 26 +++++++++++++++++-- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/crates/client-api/src/auth.rs b/crates/client-api/src/auth.rs index ccf2218abf7..e8175ee7b5b 100644 --- a/crates/client-api/src/auth.rs +++ b/crates/client-api/src/auth.rs @@ -85,11 +85,6 @@ impl SpacetimeCreds { pub struct SpacetimeAuth { pub creds: SpacetimeCreds, pub claims: SpacetimeIdentityClaims, - /* - pub identity: Identity, - pub subject: String, - pub issuer: String, - */ // The decoded JWT payload. pub raw_payload: String, } diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index 3a245d927d5..db899188d65 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -366,9 +366,31 @@ impl WasmModuleInstance { let mut tx = tx.unwrap_or_else(|| stdb.begin_mut_tx(IsolationLevel::Serializable, workload)); let _guard = metric_reducer_plus_query_duration.with_timer(tx.timer); + // For OnConnect, we insert the credentials before the reducer, so we can look them up + // inside that reducer. + // If the connection is rejected, this should get rolled back. if let Some(Lifecycle::OnConnect) = reducer_def.lifecycle { - tx.insert_st_client_credentials(caller_connection_id, &client.clone().unwrap().auth.jwt_payload) - .unwrap(); + let client_clone = match client.clone() { + Some(client) => client, + None => { + log::error!("OnConnect reducer called without a client"); + return ReducerCallResult { + outcome: ReducerOutcome::Failed("OnConnect reducer called without a client".into()), + energy_used: EnergyQuanta::ZERO, + execution_duration: Duration::ZERO, + }; + } + }; + if let Some(err) = tx + .insert_st_client_credentials(caller_connection_id, &client_clone.auth.jwt_payload) + .err() + { + return ReducerCallResult { + outcome: ReducerOutcome::Failed(format!("Error inserting client credentials: {err}")), + energy_used: EnergyQuanta::ZERO, + execution_duration: Duration::ZERO, + }; + } }; let mut tx_slot = self.instance.instance_env().tx.clone(); From a336ba855ad0e143b3d1d6666359de34799a21f0 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Wed, 30 Jul 2025 08:21:53 -0700 Subject: [PATCH 06/17] Fix some system table issues. --- crates/core/src/host/module_host.rs | 72 +++++++++++-------- .../src/host/wasm_common/module_host_actor.rs | 31 +------- .../src/locking_tx_datastore/mut_tx.rs | 24 +++---- crates/datastore/src/system_tables.rs | 6 ++ 4 files changed, 60 insertions(+), 73 deletions(-) diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index b5561724404..1c1e9b40d5f 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -691,6 +691,37 @@ impl ModuleHost { let me = self.clone(); self.call("call_identity_connected", move |inst| { let reducer_lookup = me.info.module_def.lifecycle_reducer(Lifecycle::OnConnect); + let stdb = me.module.replica_ctx().relational_db.clone(); + let workload = Workload::Reducer(ReducerContext { + name: "call_identity_connected".to_owned(), + caller_identity: caller_auth.claims.identity, + caller_connection_id, + timestamp: Timestamp::now(), + arg_bsatn: Bytes::new(), + }); + let mut mut_tx = stdb.begin_mut_tx( + IsolationLevel::Serializable, + workload + ); + mut_tx + .insert_st_client(caller_auth.claims.identity, caller_connection_id) + .inspect_err(|e| { + log::error!( + "`call_identity_connected`: fallback transaction to insert into `st_client` failed: {e:#?}" + ) + }) + .map_err(DBError::from)?; + mut_tx + .insert_st_client_credentials(caller_connection_id, &caller_auth.jwt_payload) + .inspect_err(|e| { + log::error!( + "`call_identity_connected`: fallback transaction to insert into `st_client_credetials` failed: {e:#?}" + ) + }) + .map_err(DBError::from)?; + + + // let mut tx = db.begin_mut_tx(IsolationLevel::Serializable, Workload::Internal); if let Some((reducer_id, reducer_def)) = reducer_lookup { // The module defined a lifecycle reducer to handle new connections. @@ -698,6 +729,7 @@ impl ModuleHost { // If the call fails (as in, something unexpectedly goes wrong with WASM execution), // abort the connection: we can't really recover. let reducer_outcome = me.call_reducer_inner_with_inst( + Some(mut_tx), caller_auth.claims.identity, Some(caller_connection_id), None, @@ -729,38 +761,16 @@ impl ModuleHost { } } else { // The module doesn't define a client_connected reducer. - // Commit a transaction to update `st_clients` - // and to ensure we always have those events paired in the commitlog. + // We need to commit the transaction to update st_clients and st_connection_credentials. // // This is necessary to be able to disconnect clients after a server crash. - let reducer_name = reducer_lookup - .as_ref() - .map(|(_, def)| &*def.name) - .unwrap_or("__identity_connected__"); - - let workload = Workload::Reducer(ReducerContext { - name: reducer_name.to_owned(), - caller_identity: caller_auth.claims.identity, - caller_connection_id, - timestamp: Timestamp::now(), - arg_bsatn: Bytes::new(), - }); - let stdb = me.module.replica_ctx().relational_db.clone(); - stdb.with_auto_commit(workload, |mut_tx| { - mut_tx - .insert_st_client(caller_auth.claims.identity, caller_connection_id) - .map_err(DBError::from)?; - mut_tx - .insert_st_client_credentials(caller_connection_id, &caller_auth.jwt_payload) - .map_err(DBError::from) - }) - .inspect_err(|e| { - log::error!( - "`call_identity_connected`: fallback transaction to insert into `st_client` failed: {e:#?}" - ) - }) - .map_err(Into::into) + // TODO: report the metrics. + // TODO: Is this being broadcast? Does it need to be, or are st_client table subscriptions + // not allowed? + // I don't think it was being broadcast previously. + mut_tx.commit(); + Ok(()) } }) .await @@ -814,6 +824,7 @@ impl ModuleHost { // If it succeeds, `WasmModuleInstance::call_reducer_with_tx` has already ensured // that `st_client` is updated appropriately. let result = me.call_reducer_inner_with_inst( + None, caller_identity, Some(caller_connection_id), None, @@ -918,6 +929,7 @@ impl ModuleHost { } fn call_reducer_inner_with_inst( &self, + tx: Option, caller_identity: Identity, caller_connection_id: Option, client: Option>, @@ -933,7 +945,7 @@ impl ModuleHost { let caller_connection_id = caller_connection_id.unwrap_or(ConnectionId::ZERO); Ok(module_instance.call_reducer( - None, + tx, CallReducerParams { timestamp: Timestamp::now(), caller_identity, diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index db899188d65..0b66ca197f1 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -363,36 +363,9 @@ impl WasmModuleInstance { .with_label_values(&database_identity, reducer_name); let workload = Workload::Reducer(ReducerContext::from(op.clone())); - let mut tx = tx.unwrap_or_else(|| stdb.begin_mut_tx(IsolationLevel::Serializable, workload)); + let tx = tx.unwrap_or_else(|| stdb.begin_mut_tx(IsolationLevel::Serializable, workload)); let _guard = metric_reducer_plus_query_duration.with_timer(tx.timer); - // For OnConnect, we insert the credentials before the reducer, so we can look them up - // inside that reducer. - // If the connection is rejected, this should get rolled back. - if let Some(Lifecycle::OnConnect) = reducer_def.lifecycle { - let client_clone = match client.clone() { - Some(client) => client, - None => { - log::error!("OnConnect reducer called without a client"); - return ReducerCallResult { - outcome: ReducerOutcome::Failed("OnConnect reducer called without a client".into()), - energy_used: EnergyQuanta::ZERO, - execution_duration: Duration::ZERO, - }; - } - }; - if let Some(err) = tx - .insert_st_client_credentials(caller_connection_id, &client_clone.auth.jwt_payload) - .err() - { - return ReducerCallResult { - outcome: ReducerOutcome::Failed(format!("Error inserting client credentials: {err}")), - energy_used: EnergyQuanta::ZERO, - execution_duration: Duration::ZERO, - }; - } - }; - let mut tx_slot = self.instance.instance_env().tx.clone(); let reducer_span = tracing::trace_span!( @@ -484,7 +457,7 @@ impl WasmModuleInstance { // and conversely removing from `st_clients` on disconnect. Ok(Ok(())) => { let res = match reducer_def.lifecycle { - Some(Lifecycle::OnConnect) => tx.insert_st_client(caller_identity, caller_connection_id), + Some(Lifecycle::OnConnect) => Ok(()), Some(Lifecycle::OnDisconnect) => { tx.delete_st_client(caller_identity, caller_connection_id, database_identity) } diff --git a/crates/datastore/src/locking_tx_datastore/mut_tx.rs b/crates/datastore/src/locking_tx_datastore/mut_tx.rs index 884b07c7538..d326ba9ad02 100644 --- a/crates/datastore/src/locking_tx_datastore/mut_tx.rs +++ b/crates/datastore/src/locking_tx_datastore/mut_tx.rs @@ -10,7 +10,9 @@ use super::{ }; use crate::execution_context::ExecutionContext; use crate::execution_context::Workload; -use crate::system_tables::{StConnectionCredentialsFields, StConnectionCredentialsRow, ST_CONNECTION_CREDENTIALS_ID}; +use crate::system_tables::{ + ConnectionIdViaU128, StConnectionCredentialsFields, StConnectionCredentialsRow, ST_CONNECTION_CREDENTIALS_ID, +}; use crate::traits::{InsertFlags, RowTypeForTable, TxData, UpdateFlags}; use crate::{ error::{IndexError, SequenceError, TableError}, @@ -1372,20 +1374,14 @@ impl MutTxId { database_identity: Identity, connection_id: ConnectionId, ) -> Result<()> { - if let Some(ptr) = self - .iter_by_col_eq( - ST_CONNECTION_CREDENTIALS_ID, - StConnectionCredentialsFields::ConnectionId, - &connection_id.into(), - )? - .next() - .map(|row| row.pointer()) - { - self.delete(ST_CONNECTION_CREDENTIALS_ID, ptr).map(drop) - } else { - log::warn!("[{database_identity}]: delete_st_client_credentials: attempting to credentials for missing connection id ({connection_id})"); - Ok(()) + if let Err(e) = self.delete_col_eq( + ST_CONNECTION_CREDENTIALS_ID, + StConnectionCredentialsFields::ConnectionId.col_id(), + &ConnectionIdViaU128::from(connection_id).into(), + ) { + log::error!("[{database_identity}]: delete_st_client_credentials: attempting to delete credentials for missing connection id ({connection_id}), error: {e}"); } + Ok(()) } pub fn delete_st_client( diff --git a/crates/datastore/src/system_tables.rs b/crates/datastore/src/system_tables.rs index 1c6fdbbb036..781c8835b0a 100644 --- a/crates/datastore/src/system_tables.rs +++ b/crates/datastore/src/system_tables.rs @@ -869,6 +869,12 @@ impl From for ConnectionIdViaU128 { } } +impl From for AlgebraicValue { + fn from(val: ConnectionIdViaU128) -> Self { + AlgebraicValue::U128(val.0.to_u128().into()) + } +} + /// A wrapper for [`Identity`] that acts like [`AlgebraicType::U256`] for serialization purposes. #[derive(Clone, Copy, Debug, Eq, PartialEq)] pub struct IdentityViaU256(pub Identity); From 14a61584ed5882022aa4e9297861be1f374dddcf Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Fri, 1 Aug 2025 12:41:19 -0700 Subject: [PATCH 07/17] fmt --- crates/client-api/src/routes/subscribe.rs | 20 +++++++++++++++++--- crates/core/src/client/client_connection.rs | 2 +- crates/core/src/host/module_host.rs | 2 +- 3 files changed, 19 insertions(+), 5 deletions(-) diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index 52aee06952d..07d444cd5fd 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -179,7 +179,13 @@ where log::debug!("websocket: New client connected from {client_log_string}"); - let connected = match ClientConnection::call_client_connected_maybe_reject(&mut module_rx, client_id, auth.clone().into()).await { + let connected = match ClientConnection::call_client_connected_maybe_reject( + &mut module_rx, + client_id, + auth.clone().into(), + ) + .await + { Ok(connected) => { log::debug!("websocket: client_connected returned Ok for {client_log_string}"); connected @@ -201,8 +207,16 @@ where ); let actor = |client, sendrx| ws_client_actor(ws_opts, client, ws, sendrx); - let client = - ClientConnection::spawn(client_id, auth.into(), client_config, leader.replica_id, module_rx, actor, connected).await; + let client = ClientConnection::spawn( + client_id, + auth.into(), + client_config, + leader.replica_id, + module_rx, + actor, + connected, + ) + .await; // Send the client their identity token message as the first message // NOTE: We're adding this to the protocol because some client libraries are diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index da9315a7af2..f6bcbc43784 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -426,7 +426,7 @@ impl ClientConnection { pub async fn call_client_connected_maybe_reject( module_rx: &mut watch::Receiver, id: ClientActorId, - auth: ConnectionAuthCtx + auth: ConnectionAuthCtx, ) -> Result { let module = module_rx.borrow_and_update().clone(); module.call_identity_connected(auth, id.connection_id).await?; diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index e19861903e9..ebe0b052573 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -768,7 +768,7 @@ impl ModuleHost { // TODO: Is this being broadcast? Does it need to be, or are st_client table subscriptions // not allowed? // I don't think it was being broadcast previously. - mut_tx.commit(); + let _ = mut_tx.commit(); Ok(()) } }) From 5df0d939fd4c67691eb60149a0424118772c7513 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Mon, 4 Aug 2025 14:10:54 -0700 Subject: [PATCH 08/17] Commit using the db, so it gets persisted. --- crates/core/src/host/module_host.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index ebe0b052573..e235727f45a 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -764,11 +764,15 @@ impl ModuleHost { // // This is necessary to be able to disconnect clients after a server crash. - // TODO: report the metrics. // TODO: Is this being broadcast? Does it need to be, or are st_client table subscriptions // not allowed? // I don't think it was being broadcast previously. - let _ = mut_tx.commit(); + let stdb = me.module.replica_ctx().relational_db.clone(); + stdb.finish_tx(mut_tx, Ok(())) + .map_err(|e: DBError| { + log::error!("`call_identity_connected`: finish transaction failed: {e:#?}"); + ClientConnectedError::DBError(e) + })?; Ok(()) } }) From 37467bfdcf67c47bd9960cfe43b33de2b66805fc Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Tue, 5 Aug 2025 09:25:07 -0700 Subject: [PATCH 09/17] Cleanup --- crates/client-api/src/auth.rs | 16 +++++++++++----- crates/core/src/host/module_host.rs | 6 +----- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/crates/client-api/src/auth.rs b/crates/client-api/src/auth.rs index e8175ee7b5b..2f0ca67ef75 100644 --- a/crates/client-api/src/auth.rs +++ b/crates/client-api/src/auth.rs @@ -85,15 +85,15 @@ impl SpacetimeCreds { pub struct SpacetimeAuth { pub creds: SpacetimeCreds, pub claims: SpacetimeIdentityClaims, - // The decoded JWT payload. - pub raw_payload: String, + /// The JWT payload as a json string (after base64 decoding). + pub jwt_payload: String, } impl From for ConnectionAuthCtx { fn from(auth: SpacetimeAuth) -> Self { ConnectionAuthCtx { claims: auth.claims, - jwt_payload: auth.raw_payload.clone(), + jwt_payload: auth.jwt_payload.clone(), } } } @@ -131,6 +131,9 @@ impl TokenClaims { Identity::from_claims(&self.issuer, &self.subject) } + /// Encode the claims into a JWT token and sign it with the provided signer. + /// This also adds claims for expiry and issued at time. + /// Returns an object representing the claims and the signed token. pub fn encode_and_sign_with_expiry( &self, signer: &impl TokenSigner, @@ -150,6 +153,9 @@ impl TokenClaims { Ok((claims, token)) } + /// Encode the claims into a JWT token and sign it with the provided signer. + /// This also adds a claim for issued at time. + /// Returns an object representing the claims and the signed token. pub fn encode_and_sign(&self, signer: &impl TokenSigner) -> Result<(SpacetimeIdentityClaims, String), JwtError> { self.encode_and_sign_with_expiry(signer, None) } @@ -177,7 +183,7 @@ impl SpacetimeAuth { Ok(Self { creds, claims, - raw_payload: payload, + jwt_payload: payload, }) } @@ -351,7 +357,7 @@ impl axum::extract::FromRequestParts for Space let auth = SpacetimeAuth { creds, claims, - raw_payload: payload, + jwt_payload: payload, }; Ok(Self { auth: Some(auth) }) } diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index e235727f45a..f39d341f16d 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -690,7 +690,7 @@ impl ModuleHost { let me = self.clone(); self.call("call_identity_connected", move |inst| { let reducer_lookup = me.info.module_def.lifecycle_reducer(Lifecycle::OnConnect); - let stdb = me.module.replica_ctx().relational_db.clone(); + let stdb = &me.module.replica_ctx().relational_db; let workload = Workload::Reducer(ReducerContext { name: "call_identity_connected".to_owned(), caller_identity: caller_auth.claims.identity, @@ -719,9 +719,6 @@ impl ModuleHost { }) .map_err(DBError::from)?; - - // let mut tx = db.begin_mut_tx(IsolationLevel::Serializable, Workload::Internal); - if let Some((reducer_id, reducer_def)) = reducer_lookup { // The module defined a lifecycle reducer to handle new connections. // Call this reducer. @@ -767,7 +764,6 @@ impl ModuleHost { // TODO: Is this being broadcast? Does it need to be, or are st_client table subscriptions // not allowed? // I don't think it was being broadcast previously. - let stdb = me.module.replica_ctx().relational_db.clone(); stdb.finish_tx(mut_tx, Ok(())) .map_err(|e: DBError| { log::error!("`call_identity_connected`: finish transaction failed: {e:#?}"); From b063fe3bb9e0631bf2df59c7890adcc1b45878f7 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Tue, 5 Aug 2025 15:30:37 -0700 Subject: [PATCH 10/17] Rollback on errors during identity connected. --- crates/core/src/host/module_host.rs | 49 ++++++++--------- .../subscription/module_subscription_actor.rs | 55 +++++++++---------- .../src/locking_tx_datastore/mut_tx.rs | 31 ++++++++--- smoketests/config.toml | 1 - 4 files changed, 70 insertions(+), 66 deletions(-) diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index f39d341f16d..6c52e92c496 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -26,6 +26,7 @@ use derive_more::From; use indexmap::IndexSet; use itertools::Itertools; use prometheus::{Histogram, IntGauge}; +use scopeguard::ScopeGuard; use spacetimedb_auth::identity::ConnectionAuthCtx; use spacetimedb_client_api_messages::websocket::{ByteListLen, Compression, OneOffTable, QueryUpdate}; use spacetimedb_data_structures::error_stream::ErrorStream; @@ -692,31 +693,27 @@ impl ModuleHost { let reducer_lookup = me.info.module_def.lifecycle_reducer(Lifecycle::OnConnect); let stdb = &me.module.replica_ctx().relational_db; let workload = Workload::Reducer(ReducerContext { - name: "call_identity_connected".to_owned(), - caller_identity: caller_auth.claims.identity, - caller_connection_id, - timestamp: Timestamp::now(), - arg_bsatn: Bytes::new(), - }); - let mut mut_tx = stdb.begin_mut_tx( - IsolationLevel::Serializable, - workload - ); - mut_tx - .insert_st_client(caller_auth.claims.identity, caller_connection_id) - .inspect_err(|e| { - log::error!( - "`call_identity_connected`: fallback transaction to insert into `st_client` failed: {e:#?}" - ) - }) - .map_err(DBError::from)?; + name: "call_identity_connected".to_owned(), + caller_identity: caller_auth.claims.identity, + caller_connection_id, + timestamp: Timestamp::now(), + arg_bsatn: Bytes::new(), + }); + let mut_tx = stdb.begin_mut_tx(IsolationLevel::Serializable, workload); + let mut mut_tx = scopeguard::guard(mut_tx, |mut_tx| { + // If we crash before committing, we need to ensure that the transaction is rolled back. + // This is necessary to avoid leaving the database in an inconsistent state. + log::debug!("call_identity_connected: rolling back transaction"); + let (metrics, reducer_name) = mut_tx.rollback(); + stdb.report_mut_tx_metrics(reducer_name, metrics, None); + }); + mut_tx - .insert_st_client_credentials(caller_connection_id, &caller_auth.jwt_payload) - .inspect_err(|e| { - log::error!( - "`call_identity_connected`: fallback transaction to insert into `st_client_credetials` failed: {e:#?}" - ) - }) + .insert_st_client( + caller_auth.claims.identity, + caller_connection_id, + &caller_auth.jwt_payload, + ) .map_err(DBError::from)?; if let Some((reducer_id, reducer_def)) = reducer_lookup { @@ -725,7 +722,7 @@ impl ModuleHost { // If the call fails (as in, something unexpectedly goes wrong with WASM execution), // abort the connection: we can't really recover. let reducer_outcome = me.call_reducer_inner_with_inst( - Some(mut_tx), + Some(ScopeGuard::into_inner(mut_tx)), caller_auth.claims.identity, Some(caller_connection_id), None, @@ -764,7 +761,7 @@ impl ModuleHost { // TODO: Is this being broadcast? Does it need to be, or are st_client table subscriptions // not allowed? // I don't think it was being broadcast previously. - stdb.finish_tx(mut_tx, Ok(())) + stdb.finish_tx(ScopeGuard::into_inner(mut_tx), Ok(())) .map_err(|e: DBError| { log::error!("`call_identity_connected`: finish transaction failed: {e:#?}"); ClientConnectedError::DBError(e) diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 48c2af5fb95..65211305898 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -876,37 +876,16 @@ impl ModuleSubscriptions { return Ok(Err(WriteConflict)); }; *db_update = DatabaseUpdate::from_writes(&tx_data); - (read_tx, Some(tx_data), tx_metrics) + (read_tx, Arc::new(tx_data), tx_metrics) } EventStatus::Failed(_) | EventStatus::OutOfEnergy => { - let (tx_metrics, tx) = stdb.rollback_mut_tx_downgrade(tx, Workload::Update); - (tx, None, tx_metrics) - } - }; - - let tx_data = tx_data.map(Arc::new); - - // When we're done with this method, release the tx and report metrics. - let mut read_tx = scopeguard::guard(read_tx, |tx| { - let (tx_metrics_read, reducer) = self.relational_db.release_tx(tx); - self.relational_db - .report_tx_metrics(reducer, tx_data.clone(), Some(tx_metrics_mut), Some(tx_metrics_read)); - }); - // Create the delta transaction we'll use to eval updates against. - let delta_read_tx = tx_data - .as_ref() - .as_ref() - .map(|tx_data| DeltaTx::new(&read_tx, tx_data, subscriptions.index_ids_for_subscriptions())) - .unwrap_or_else(|| DeltaTx::from(&*read_tx)); + // If the transaction failed, we need to rollback the mutable tx. + // We don't need to do any subscription updates in this case, so we will exit early. - let event = Arc::new(event); - let mut update_metrics: ExecutionMetrics = ExecutionMetrics::default(); - - match &event.status { - EventStatus::Committed(_) => { - update_metrics = subscriptions.eval_updates_sequential(&delta_read_tx, event.clone(), caller); - } - EventStatus::Failed(_) => { + let event = Arc::new(event); + let (tx_metrics, reducer) = stdb.rollback_mut_tx(tx); + self.relational_db + .report_tx_metrics(reducer, None, Some(tx_metrics), None); if let Some(client) = caller { let message = TransactionUpdateMessage { event: Some(event.clone()), @@ -917,9 +896,25 @@ impl ModuleSubscriptions { } else { log::trace!("Reducer failed but there is no client to send the failure to!") } + return Ok(Ok((event, ExecutionMetrics::default()))); } - EventStatus::OutOfEnergy => {} // ? - } + }; + let event = Arc::new(event); + + // When we're done with this method, release the tx and report metrics. + let mut read_tx = scopeguard::guard(read_tx, |tx| { + let (tx_metrics_read, reducer) = self.relational_db.release_tx(tx); + self.relational_db.report_tx_metrics( + reducer, + Some(tx_data.clone()), + Some(tx_metrics_mut), + Some(tx_metrics_read), + ); + }); + // Create the delta transaction we'll use to eval updates against. + let delta_read_tx = DeltaTx::new(&read_tx, tx_data.as_ref(), subscriptions.index_ids_for_subscriptions()); + + let update_metrics = subscriptions.eval_updates_sequential(&delta_read_tx, event.clone(), caller); // Merge in the subscription evaluation metrics. read_tx.metrics.merge(update_metrics); diff --git a/crates/datastore/src/locking_tx_datastore/mut_tx.rs b/crates/datastore/src/locking_tx_datastore/mut_tx.rs index d326ba9ad02..9a3009222f4 100644 --- a/crates/datastore/src/locking_tx_datastore/mut_tx.rs +++ b/crates/datastore/src/locking_tx_datastore/mut_tx.rs @@ -1173,7 +1173,7 @@ impl MutTxId { /// - [`TxData`], the set of inserts and deletes performed by this transaction. /// - [`TxMetrics`], various measurements of the work performed by this transaction. /// - `String`, the name of the reducer which ran during this transaction. - pub fn commit(mut self) -> (TxData, TxMetrics, String) { + pub(super) fn commit(mut self) -> (TxData, TxMetrics, String) { let tx_data = self.committed_state_write_lock.merge(self.tx_state, &self.ctx); // Compute and keep enough info that we can @@ -1352,33 +1352,46 @@ impl<'a, I: Iterator>> Iterator for FilterDeleted<'a, I> { } impl MutTxId { - pub fn insert_st_client(&mut self, identity: Identity, connection_id: ConnectionId) -> Result<()> { + pub fn insert_st_client( + &mut self, + identity: Identity, + connection_id: ConnectionId, + jwt_payload: &str, + ) -> Result<()> { let row = &StClientRow { identity: identity.into(), connection_id: connection_id.into(), }; - self.insert_via_serialize_bsatn(ST_CLIENT_ID, row).map(|_| ()) + self.insert_via_serialize_bsatn(ST_CLIENT_ID, row) + .map(|_| ()) + .inspect_err(|e| { + log::error!( + "[{identity}]: insert_st_client: failed to insert client ({identity}, {connection_id}), error: {e}" + ); + })?; + self.insert_st_client_credentials(connection_id, jwt_payload) } - pub fn insert_st_client_credentials(&mut self, connection_id: ConnectionId, jwt_payload: &str) -> Result<()> { + fn insert_st_client_credentials(&mut self, connection_id: ConnectionId, jwt_payload: &str) -> Result<()> { let row = &StConnectionCredentialsRow { connection_id: connection_id.into(), jwt_payload: jwt_payload.to_owned(), }; self.insert_via_serialize_bsatn(ST_CONNECTION_CREDENTIALS_ID, row) .map(|_| ()) + .inspect_err(|e| { + log::error!("[{connection_id}]: insert_st_client_credentials: failed to insert client credentials for connection id ({connection_id}), error: {e}"); + }) } - pub fn delete_st_client_credentials( - &mut self, - database_identity: Identity, - connection_id: ConnectionId, - ) -> Result<()> { + fn delete_st_client_credentials(&mut self, database_identity: Identity, connection_id: ConnectionId) -> Result<()> { if let Err(e) = self.delete_col_eq( ST_CONNECTION_CREDENTIALS_ID, StConnectionCredentialsFields::ConnectionId.col_id(), &ConnectionIdViaU128::from(connection_id).into(), ) { + // This is possible on restart if the database was previously running a version + // before this system table was added. log::error!("[{database_identity}]: delete_st_client_credentials: attempting to delete credentials for missing connection id ({connection_id}), error: {e}"); } Ok(()) diff --git a/smoketests/config.toml b/smoketests/config.toml index b7c4ad31a45..5a37a8381b6 100644 --- a/smoketests/config.toml +++ b/smoketests/config.toml @@ -1,5 +1,4 @@ default_server = "127.0.0.1:3000" -spacetimedb_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJoZXhfaWRlbnRpdHkiOiJjMjAwYzc3NDY1NTE5MDM2MTE4M2JiNjFmMWMxYzY3NDUzMzYzY2MxMTY4MmM1NTUwNWZiNjdlYzI0ZWMyMWViIiwic3ViIjoiOTJlMmNkOGQtNTk5Ny00NjZlLWIwNmYtZDNjOGQ1NzU3ODI4IiwiaXNzIjoibG9jYWxob3N0IiwiYXVkIjpbInNwYWNldGltZWRiIl0sImlhdCI6MTc1MjA0NjgwMCwiZXhwIjpudWxsfQ.dgefoxC7eCOONVUufu2JTVFo9876zQ4Mqwm0ivZ0PQK7Hacm3Ip_xqyav4bilZ0vIEf8IM8AB0_xawk8WcbvMg" [[server_configs]] nickname = "localhost" From 2462a95213db62af98e89c4520caff2e70ded944 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Wed, 6 Aug 2025 08:22:40 -0700 Subject: [PATCH 11/17] Update comments. --- crates/core/src/error.rs | 1 - crates/datastore/src/locking_tx_datastore/mut_tx.rs | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index 10bbeac171c..eff3b835f02 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -26,7 +26,6 @@ use spacetimedb_vm::expr::Crud; pub use spacetimedb_datastore::error::{DatastoreError, IndexError, SequenceError, TableError}; -// #[derive(Error, Debug, PartialEq, Eq)] #[derive(Error, Debug)] pub enum ClientError { #[error("Client not found: {0}")] diff --git a/crates/datastore/src/locking_tx_datastore/mut_tx.rs b/crates/datastore/src/locking_tx_datastore/mut_tx.rs index 9a3009222f4..076c9a1c17d 100644 --- a/crates/datastore/src/locking_tx_datastore/mut_tx.rs +++ b/crates/datastore/src/locking_tx_datastore/mut_tx.rs @@ -1167,7 +1167,8 @@ impl MutTxId { }) } - /// Commits this transaction, applying its changes to the committed state. + /// Commits this transaction in memory, applying its changes to the committed state. + /// This doesn't handle the persistence layer at all. /// /// Returns: /// - [`TxData`], the set of inserts and deletes performed by this transaction. From 7c384f3dac4aa894536646be1ebba297a871ff05 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Thu, 11 Sep 2025 13:42:55 -0400 Subject: [PATCH 12/17] Cleanup --- crates/core/src/db/relational_db.rs | 24 +++++++++---------- crates/core/src/host/host_controller.rs | 6 +++++ crates/core/src/host/module_host.rs | 16 +++++++++++++ .../src/host/wasm_common/module_host_actor.rs | 16 ++++++------- .../locking_tx_datastore/committed_state.rs | 5 +--- .../src/locking_tx_datastore/mut_tx.rs | 8 +++---- .../src/locking_tx_datastore/state_view.rs | 11 ++++++--- 7 files changed, 54 insertions(+), 32 deletions(-) diff --git a/crates/core/src/db/relational_db.rs b/crates/core/src/db/relational_db.rs index 95a16e25d6c..5c878fbd25d 100644 --- a/crates/core/src/db/relational_db.rs +++ b/crates/core/src/db/relational_db.rs @@ -393,7 +393,6 @@ impl RelationalDB { ); db.migrate_system_tables()?; - if let Some(meta) = db.metadata()? { if meta.database_identity != database_identity { return Err(anyhow!( @@ -418,18 +417,17 @@ impl RelationalDB { } fn migrate_system_tables(&self) -> Result<(), DBError> { - println!("Begin migrating system tables"); let mut tx = self.begin_mut_tx(IsolationLevel::Serializable, Workload::Internal); - println!("Got a lock"); for schema in system_tables() { if !self.table_id_exists_mut(&tx, &schema.table_id) { - println!("Missing system table {}: {}", schema.table_id, schema.table_name); + log::info!( + "[{}] DATABASE: adding missing system table {}", + self.database_identity, + schema.table_name + ); let _ = self.create_table(&mut tx, schema.clone())?; } - //let tid = tx.create_table(schema.clone())?; - } - println!("Going to commit"); let _ = self.commit_tx(tx)?; self.inner.assert_system_tables_match()?; Ok(()) @@ -3326,10 +3324,9 @@ mod tests { assert!(table_1_id > table_0_id); } + use crate::error::DBError; use fs_extra::dir::{copy, CopyOptions}; use itertools::Itertools; - use crate::error::DBError; - /// Make a temp dir with the contents of the given directory. fn copy_fixture_dir(src: &PathBuf) -> TempDir { @@ -3434,7 +3431,7 @@ mod tests { owner_identity, use_snapshot, )?; - let schemas = db.with_read_only(Workload::ForTests, |tx| db.get_all_tables(&tx))?; + let schemas = db.with_read_only(Workload::ForTests, |tx| db.get_all_tables(tx))?; let user_table_names: Vec = schemas .iter() .filter(|s| s.table_id.0 > 4000) @@ -3446,8 +3443,9 @@ mod tests { // directory. assert_eq!(user_table_names, expected_table_names); - db.with_auto_commit(Workload::ForTests, |mut tx| { - tx.insert_st_client(Identity::ZERO, ConnectionId::ZERO, "invalid_jwt".into()).unwrap(); + db.with_auto_commit(Workload::ForTests, |tx| { + tx.insert_st_client(Identity::ZERO, ConnectionId::ZERO, "invalid_jwt") + .unwrap(); Ok::<(), DBError>(()) })?; // Now we are going to shut it down and reopen it, to ensure that the new table can be @@ -3456,7 +3454,7 @@ mod tests { let handle = Arc::into_inner(durability_handle).expect("Durability handle should be dropped by db"); rt.block_on(handle.close()).expect("Failed to close durability handle"); - let (db, durability_handle) = TestDB::open_existing_durable( + let (db, _) = TestDB::open_existing_durable( &dir, rt.handle().clone(), 22000001, diff --git a/crates/core/src/host/host_controller.rs b/crates/core/src/host/host_controller.rs index 5c9f08fff52..85510b76053 100644 --- a/crates/core/src/host/host_controller.rs +++ b/crates/core/src/host/host_controller.rs @@ -931,6 +931,12 @@ impl Host { ) })?; } + // We should have no clients left, but we do this just in case. + // This should only matter if we crashed with something in st_client_credentials, + // then restarted with an older version of the code that doesn't use st_client_credentials. + // That case would cause some permanently dangling st_client_credentials. + // Since we have no clients on startup, this should be safe to do regardless. + module_host.clear_all_clients().await?; scheduler_starter.start(&module_host)?; let disk_metrics_recorder_task = tokio::spawn(metric_reporter(replica_ctx.clone())).abort_handle(); diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 6c52e92c496..5ca0ad4099b 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -33,6 +33,7 @@ use spacetimedb_data_structures::error_stream::ErrorStream; use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; use spacetimedb_datastore::execution_context::{ExecutionContext, ReducerContext, Workload, WorkloadType}; use spacetimedb_datastore::locking_tx_datastore::MutTxId; +use spacetimedb_datastore::system_tables::{ST_CLIENT_ID, ST_CONNECTION_CREDENTIALS_ID}; use spacetimedb_datastore::traits::{IsolationLevel, Program, TxData}; use spacetimedb_execution::pipelined::PipelinedProject; use spacetimedb_lib::db::raw_def::v9::Lifecycle; @@ -890,6 +891,21 @@ impl ModuleHost { .map_err(Into::::into)? } + /// Empty the system tables tracking clients without running any lifecycle reducers. + pub async fn clear_all_clients(&self) -> anyhow::Result<()> { + let me = self.clone(); + self.call("clear_all_clients", move |_| { + let stdb = &me.module.replica_ctx().relational_db; + let workload = Workload::Internal; + stdb.with_auto_commit(workload, |mut_tx| { + stdb.clear_table(mut_tx, ST_CONNECTION_CREDENTIALS_ID)?; + stdb.clear_table(mut_tx, ST_CLIENT_ID) + }) + }) + .await? + .map_err(anyhow::Error::from) + } + async fn call_reducer_inner( &self, caller_identity: Identity, diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index dc890d08e98..e8c3ca064b0 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -442,14 +442,14 @@ impl InstanceCommon { // We haven't actually committed yet - `commit_and_broadcast_event` will commit // for us and replace this with the actual database update. Ok(res) => match res.and_then(|()| { - // If this is an OnDisconnect lifecycle event, remove the client from st_clients. - // We handle OnConnect events before running the reducer. - match reducer_def.lifecycle { - Some(Lifecycle::OnDisconnect) => tx.delete_st_client(caller_identity, caller_connection_id, database_identity).map_err(|e| e.to_string().into()) - , - _ => Ok(()), - } - + // If this is an OnDisconnect lifecycle event, remove the client from st_clients. + // We handle OnConnect events before running the reducer. + match reducer_def.lifecycle { + Some(Lifecycle::OnDisconnect) => tx + .delete_st_client(caller_identity, caller_connection_id, database_identity) + .map_err(|e| e.to_string().into()), + _ => Ok(()), + } }) { Ok(()) => EventStatus::Committed(DatabaseUpdate::default()), Err(err) => { diff --git a/crates/datastore/src/locking_tx_datastore/committed_state.rs b/crates/datastore/src/locking_tx_datastore/committed_state.rs index c761227dba6..5fc4818988d 100644 --- a/crates/datastore/src/locking_tx_datastore/committed_state.rs +++ b/crates/datastore/src/locking_tx_datastore/committed_state.rs @@ -25,10 +25,7 @@ use crate::{ use anyhow::anyhow; use core::{convert::Infallible, ops::RangeBounds}; use spacetimedb_data_structures::map::{HashSet, IntMap}; -use spacetimedb_lib::{ - db::auth::{StAccess, StTableType}, - Identity, -}; +use spacetimedb_lib::{db::auth::StTableType, Identity}; use spacetimedb_primitives::{ColId, ColList, ColSet, IndexId, TableId}; use spacetimedb_sats::{algebraic_value::de::ValueDeserializer, memory_usage::MemoryUsage, Deserialize}; use spacetimedb_sats::{AlgebraicValue, ProductValue}; diff --git a/crates/datastore/src/locking_tx_datastore/mut_tx.rs b/crates/datastore/src/locking_tx_datastore/mut_tx.rs index c4e8261da95..94780a685c2 100644 --- a/crates/datastore/src/locking_tx_datastore/mut_tx.rs +++ b/crates/datastore/src/locking_tx_datastore/mut_tx.rs @@ -10,7 +10,10 @@ use super::{ }; use crate::execution_context::ExecutionContext; use crate::execution_context::Workload; -use crate::system_tables::{system_tables, ConnectionIdViaU128, StConnectionCredentialsFields, StConnectionCredentialsRow, ST_CONNECTION_CREDENTIALS_ID}; +use crate::system_tables::{ + system_tables, ConnectionIdViaU128, StConnectionCredentialsFields, StConnectionCredentialsRow, + ST_CONNECTION_CREDENTIALS_ID, +}; use crate::traits::{InsertFlags, RowTypeForTable, TxData, UpdateFlags}; use crate::{ error::{IndexError, SequenceError, TableError}, @@ -58,8 +61,6 @@ use std::{ sync::Arc, time::{Duration, Instant}, }; -use anyhow::anyhow; -use crate::error::DatastoreError; type DecodeResult = core::result::Result; @@ -1373,7 +1374,6 @@ impl<'a, I: Iterator>> Iterator for FilterDeleted<'a, I> { } impl MutTxId { - pub fn insert_st_client( &mut self, identity: Identity, diff --git a/crates/datastore/src/locking_tx_datastore/state_view.rs b/crates/datastore/src/locking_tx_datastore/state_view.rs index 3c6f34a5526..c36fd9e5ff2 100644 --- a/crates/datastore/src/locking_tx_datastore/state_view.rs +++ b/crates/datastore/src/locking_tx_datastore/state_view.rs @@ -1,8 +1,15 @@ use super::mut_tx::{FilterDeleted, IndexScanRanged}; use super::{committed_state::CommittedState, datastore::Result, tx_state::TxState}; use crate::error::{DatastoreError, TableError}; -use crate::system_tables::{ConnectionIdViaU128, StColumnFields, StColumnRow, StConnectionCredentialsFields, StConnectionCredentialsRow, StConstraintFields, StConstraintRow, StIndexFields, StIndexRow, StScheduledFields, StScheduledRow, StSequenceFields, StSequenceRow, StTableFields, StTableRow, SystemTable, ST_COLUMN_ID, ST_CONNECTION_CREDENTIALS_ID, ST_CONSTRAINT_ID, ST_INDEX_ID, ST_SCHEDULED_ID, ST_SEQUENCE_ID, ST_TABLE_ID}; +use crate::system_tables::{ + ConnectionIdViaU128, StColumnFields, StColumnRow, StConnectionCredentialsFields, StConnectionCredentialsRow, + StConstraintFields, StConstraintRow, StIndexFields, StIndexRow, StScheduledFields, StScheduledRow, + StSequenceFields, StSequenceRow, StTableFields, StTableRow, SystemTable, ST_COLUMN_ID, + ST_CONNECTION_CREDENTIALS_ID, ST_CONSTRAINT_ID, ST_INDEX_ID, ST_SCHEDULED_ID, ST_SEQUENCE_ID, ST_TABLE_ID, +}; +use anyhow::anyhow; use core::ops::RangeBounds; +use spacetimedb_lib::ConnectionId; use spacetimedb_primitives::{ColList, TableId}; use spacetimedb_sats::AlgebraicValue; use spacetimedb_schema::schema::{ColumnSchema, TableSchema}; @@ -11,8 +18,6 @@ use spacetimedb_table::{ table::{IndexScanRangeIter, RowRef, Table, TableScanIter}, }; use std::sync::Arc; -use anyhow::anyhow; -use spacetimedb_lib::ConnectionId; // StateView trait, is designed to define the behavior of viewing internal datastore states. // Currently, it applies to: CommittedState, MutTxId, and TxId. From 6819fe6d89996c951a82ac97bb83e30ffbd5dcc7 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Fri, 12 Sep 2025 17:32:24 -0400 Subject: [PATCH 13/17] Attempt to merge, but this fails some tests. --- .github/workflows/ci.yml | 10 +- Cargo.lock | 236 ++++++- Cargo.toml | 5 +- crates/cli/src/common_args.rs | 8 + crates/cli/src/main.rs | 4 +- crates/cli/src/subcommands/sql.rs | 73 ++- crates/cli/src/subcommands/subscribe.rs | 26 +- crates/client-api-messages/src/name.rs | 23 +- crates/client-api/src/auth.rs | 20 +- crates/client-api/src/lib.rs | 42 +- crates/client-api/src/routes/database.rs | 190 +++++- crates/client-api/src/routes/subscribe.rs | 129 ++-- crates/client-api/src/util.rs | 9 +- crates/core/src/auth/mod.rs | 9 +- crates/core/src/client.rs | 4 +- crates/core/src/client/client_connection.rs | 469 +++++++++++++- crates/core/src/db/relational_db.rs | 36 +- crates/core/src/error.rs | 3 + crates/core/src/host/host_controller.rs | 73 ++- crates/core/src/host/mod.rs | 4 +- crates/core/src/host/module_host.rs | 155 ++--- crates/core/src/host/v8/mod.rs | 8 +- .../src/host/wasm_common/module_host_actor.rs | 29 +- crates/core/src/messages/control_db.rs | 4 + crates/core/src/sql/execute.rs | 34 +- .../subscription/module_subscription_actor.rs | 538 +++++++++++---- .../module_subscription_manager.rs | 165 +++-- crates/datastore/Cargo.toml | 2 +- .../locking_tx_datastore/committed_state.rs | 4 +- .../src/locking_tx_datastore/datastore.rs | 12 +- .../src/locking_tx_datastore/mut_tx.rs | 27 +- .../datastore/src/locking_tx_datastore/tx.rs | 20 +- crates/datastore/src/traits.rs | 29 +- crates/durability/Cargo.toml | 1 + crates/durability/src/imp/local.rs | 45 +- crates/durability/src/lib.rs | 70 +- crates/execution/src/lib.rs | 2 +- crates/pg/Cargo.toml | 22 + crates/pg/LICENSE | 1 + crates/pg/README.md | 3 + crates/pg/src/encoder.rs | 301 +++++++++ crates/pg/src/lib.rs | 2 + crates/pg/src/pg_server.rs | 381 +++++++++++ crates/sats/src/satn.rs | 613 +++++++++++------- crates/sats/src/ser.rs | 27 +- crates/sats/src/ser/impls.rs | 15 +- crates/sats/src/time_duration.rs | 16 + crates/sats/src/timestamp.rs | 10 +- crates/schema/src/auto_migrate.rs | 110 +++- crates/standalone/Cargo.toml | 2 + crates/standalone/src/lib.rs | 34 +- crates/standalone/src/subcommands/start.rs | 22 +- crates/testing/src/modules.rs | 2 + docker-compose.yml | 2 + docs/docs/cli-reference.md | 2 + sdks/rust/src/db_connection.rs | 18 + sdks/rust/src/websocket.rs | 7 + smoketests/__init__.py | 26 +- smoketests/config.toml | 2 +- smoketests/tests/confirmed_reads.py | 52 ++ smoketests/tests/pg_wire.py | 293 +++++++++ smoketests/tests/replication.py | 13 +- 62 files changed, 3761 insertions(+), 733 deletions(-) create mode 100644 crates/pg/Cargo.toml create mode 120000 crates/pg/LICENSE create mode 100644 crates/pg/README.md create mode 100644 crates/pg/src/encoder.rs create mode 100644 crates/pg/src/lib.rs create mode 100644 crates/pg/src/pg_server.rs create mode 100644 smoketests/tests/confirmed_reads.py create mode 100644 smoketests/tests/pg_wire.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 62edf2d64a9..84576981d69 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -21,7 +21,7 @@ jobs: include: - { runner: spacetimedb-runner, smoketest_args: --docker } - { runner: windows-latest, smoketest_args: --no-build-cli } - runner: [spacetimedb-runner, windows-latest] + runner: [ spacetimedb-runner, windows-latest ] runs-on: ${{ matrix.runner }} steps: - name: Find Git ref @@ -44,6 +44,10 @@ jobs: - uses: actions/setup-dotnet@v4 with: global-json-file: modules/global.json + - name: Install psql (Windows) + if: runner.os == 'Windows' + run: choco install psql -y --no-progress + shell: powershell - name: Build and start database (Linux) if: runner.os == 'Linux' run: docker compose up -d @@ -54,11 +58,13 @@ jobs: Start-Process target/debug/spacetimedb-cli.exe start cd modules # the sdk-manifests on windows-latest are messed up, so we need to update them - dotnet workload config --update-mode workload-set + dotnet workload config --update-mode manifests dotnet workload update - uses: actions/setup-python@v5 with: { python-version: '3.12' } if: runner.os == 'Windows' + - name: Install psycopg2 + run: python -m pip install psycopg2-binary - name: Run smoketests # Note: clear_database and replication only work in private run: python -m smoketests ${{ matrix.smoketest_args }} -x clear_database replication diff --git a/Cargo.lock b/Cargo.lock index 014abee6ae8..5bed876a5bc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -195,6 +195,12 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "array-init" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d62b7694a562cdf5a74227903507c56ab2cc8bdd1f781ed5cb4cf9c9f810bfc" + [[package]] name = "arrayref" version = "0.3.9" @@ -293,6 +299,30 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-lc-rs" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b756939cb2f8dc900aa6dcd505e6e2428e9cae7ff7b028c49e3946efa70878" +dependencies = [ + "aws-lc-sys", + "untrusted 0.7.1", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.28.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa9b6986f250236c27e5a204062434a773a13243d2ffc2955f37bdba4c5c6a1" +dependencies = [ + "bindgen 0.69.5", + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "axum" version = "0.7.9" @@ -429,6 +459,29 @@ dependencies = [ "serde", ] +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.9.0", + "cexpr", + "clang-sys", + "itertools 0.12.1", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.101", + "which 4.4.2", +] + [[package]] name = "bindgen" version = "0.71.1" @@ -444,7 +497,7 @@ dependencies = [ "proc-macro2", "quote", "regex", - "rustc-hash", + "rustc-hash 2.1.1", "shlex", "syn 2.0.101", ] @@ -925,6 +978,15 @@ dependencies = [ "winapi", ] +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "cobs" version = "0.2.3" @@ -1092,7 +1154,7 @@ dependencies = [ "hashbrown 0.14.5", "log", "regalloc2", - "rustc-hash", + "rustc-hash 2.1.1", "smallvec", "target-lexicon", ] @@ -1485,6 +1547,17 @@ dependencies = [ "serde", ] +[[package]] +name = "derive-new" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "derive_arbitrary" version = "1.4.1" @@ -1602,6 +1675,12 @@ dependencies = [ "shared_child", ] +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + [[package]] name = "educe" version = "0.4.23" @@ -2424,7 +2503,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2", + "socket2 0.5.9", "tokio", "tower-service", "tracing", @@ -2512,7 +2591,7 @@ dependencies = [ "hyper 1.6.0", "libc", "pin-project-lite", - "socket2", + "socket2 0.5.9", "tokio", "tower-service", "tracing", @@ -2858,6 +2937,17 @@ dependencies = [ "cfg-if", ] +[[package]] +name = "io-uring" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d93587f37623a1a17d94ef2bc9ada592f5465fe7732084ab7beefabe5c77c0c4" +dependencies = [ + "bitflags 2.9.0", + "cfg-if", + "libc", +] + [[package]] name = "ipnet" version = "2.11.0" @@ -3000,12 +3090,41 @@ dependencies = [ "spacetimedb 1.4.0", ] +[[package]] +name = "lazy-regex" +version = "3.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60c7310b93682b36b98fa7ea4de998d3463ccbebd94d935d6b48ba5b6ffa7126" +dependencies = [ + "lazy-regex-proc_macros", + "once_cell", + "regex-lite", +] + +[[package]] +name = "lazy-regex-proc_macros" +version = "3.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ba01db5ef81e17eb10a5e0f2109d1b3a3e29bac3070fdbd7d156bf7dbd206a1" +dependencies = [ + "proc-macro2", + "quote", + "regex", + "syn 2.0.101", +] + [[package]] name = "lazy_static" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "leb128" version = "0.2.5" @@ -3204,6 +3323,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0" + [[package]] name = "memchr" version = "2.7.4" @@ -3794,6 +3919,31 @@ dependencies = [ "postgres-types", ] +[[package]] +name = "pgwire" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddf403a6ee31cf7f2217b2bd8447cb13dbb6c268d7e81501bc78a4d3daafd294" +dependencies = [ + "async-trait", + "aws-lc-rs", + "bytes", + "chrono", + "derive-new", + "futures", + "hex", + "lazy-regex", + "md5", + "postgres-types", + "rand 0.9.1", + "rust_decimal", + "rustls-pki-types", + "thiserror 2.0.12", + "tokio", + "tokio-rustls", + "tokio-util", +] + [[package]] name = "phf" version = "0.11.3" @@ -3945,6 +4095,7 @@ version = "0.2.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48" dependencies = [ + "array-init", "bytes", "chrono", "fallible-iterator 0.2.0", @@ -4434,7 +4585,7 @@ checksum = "12908dbeb234370af84d0579b9f68258a0f67e201412dd9a2814e6f45b2fc0f0" dependencies = [ "hashbrown 0.14.5", "log", - "rustc-hash", + "rustc-hash 2.1.1", "slice-group-by", "smallvec", ] @@ -4471,6 +4622,12 @@ dependencies = [ "regex-syntax 0.8.5", ] +[[package]] +name = "regex-lite" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" + [[package]] name = "regex-syntax" version = "0.6.29" @@ -4612,7 +4769,7 @@ dependencies = [ "cfg-if", "getrandom 0.2.16", "libc", - "untrusted", + "untrusted 0.9.0", "windows-sys 0.52.0", ] @@ -4723,6 +4880,12 @@ version = "0.1.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "719b953e2095829ee67db738b3bfa9fa368c94900df327b3f07fe6e794d2fe1f" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rustc-hash" version = "2.1.1" @@ -4770,6 +4933,8 @@ version = "0.23.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "730944ca083c1c233a75c09f199e973ca499344a2b7ba9e755c457e86fb4a321" dependencies = [ + "aws-lc-rs", + "log", "once_cell", "rustls-pki-types", "rustls-webpki", @@ -4810,9 +4975,10 @@ version = "0.103.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7149975849f1abb3832b246010ef62ccc80d3a76169517ada7188252b9cfb437" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", - "untrusted", + "untrusted 0.9.0", ] [[package]] @@ -5265,6 +5431,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "socket2" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "233504af464074f9d066d7b5416c5f9b894a5862a6506e306f7b816cdd6f1807" +dependencies = [ + "libc", + "windows-sys 0.59.0", +] + [[package]] name = "spacetime-module" version = "0.1.0" @@ -5654,7 +5830,7 @@ dependencies = [ "regex", "reqwest 0.12.15", "rustc-demangle", - "rustc-hash", + "rustc-hash 2.1.1", "scopeguard", "semver", "serde", @@ -5776,6 +5952,7 @@ dependencies = [ "spacetimedb-commitlog", "spacetimedb-paths", "spacetimedb-sats 1.4.0", + "thiserror 1.0.69", "tokio", "tracing", ] @@ -5939,6 +6116,24 @@ dependencies = [ "xdg", ] +[[package]] +name = "spacetimedb-pg" +version = "1.4.0" +dependencies = [ + "anyhow", + "async-trait", + "axum", + "futures", + "http 1.3.1", + "log", + "pgwire", + "spacetimedb-client-api", + "spacetimedb-client-api-messages", + "spacetimedb-lib 1.4.0", + "thiserror 1.0.69", + "tokio", +] + [[package]] name = "spacetimedb-physical-plan" version = "1.4.0" @@ -6179,13 +6374,15 @@ dependencies = [ "serde", "serde_json", "sled", - "socket2", + "socket2 0.5.9", "spacetimedb-client-api", "spacetimedb-client-api-messages", "spacetimedb-core", "spacetimedb-datastore", "spacetimedb-lib 1.4.0", "spacetimedb-paths", + "spacetimedb-pg", + "spacetimedb-schema", "spacetimedb-table", "tempfile", "thiserror 1.0.69", @@ -6909,20 +7106,22 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.45.0" +version = "1.47.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2513ca694ef9ede0fb23fe71a4ee4107cb102b9dc1930f6d0fd77aae068ae165" +checksum = "89e49afdadebb872d3145a5638b59eb0691ea23e46ca484037cfab3b76b95038" dependencies = [ "backtrace", "bytes", + "io-uring", "libc", "mio 1.0.3", "parking_lot 0.12.3", "pin-project-lite", "signal-hook-registry", - "socket2", + "slab", + "socket2 0.6.0", "tokio-macros", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -6978,7 +7177,7 @@ dependencies = [ "postgres-protocol", "postgres-types", "rand 0.9.1", - "socket2", + "socket2 0.5.9", "tokio", "tokio-util", "whoami", @@ -7291,6 +7490,7 @@ dependencies = [ "rand 0.9.1", "sha1", "thiserror 2.0.12", + "url", "utf-8", ] @@ -7372,6 +7572,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "untrusted" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a156c684c91ea7d62626509bce3cb4e1d9ed5c4d978f7b4352658f96a4c26b4a" + [[package]] name = "untrusted" version = "0.9.0" @@ -7452,7 +7658,7 @@ version = "137.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ca393e2032ddba2a57169e15cac5d0a81cdb3d872a8886f4468bc0f486098d2" dependencies = [ - "bindgen", + "bindgen 0.71.1", "bitflags 2.9.0", "fslock", "gzip-header", diff --git a/Cargo.toml b/Cargo.toml index 59776ccc151..cb45e9d89d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ members = [ "crates/lib", "crates/metrics", "crates/paths", + "crates/pg", "crates/physical-plan", "crates/primitives", "crates/query", @@ -115,6 +116,7 @@ spacetimedb-lib = { path = "crates/lib", default-features = false, version = "1. spacetimedb-memory-usage = { path = "crates/memory-usage", version = "1.4.0", default-features = false } spacetimedb-metrics = { path = "crates/metrics", version = "1.4.0" } spacetimedb-paths = { path = "crates/paths", version = "1.4.0" } +spacetimedb-pg = { path = "crates/pg", version = "1.4.0" } spacetimedb-physical-plan = { path = "crates/physical-plan", version = "1.4.0" } spacetimedb-primitives = { path = "crates/primitives", version = "1.4.0" } spacetimedb-query = { path = "crates/query", version = "1.4.0" } @@ -214,6 +216,7 @@ paste = "1.0" percent-encoding = "2.3" petgraph = { version = "0.6.5", default-features = false } pin-project-lite = "0.2.9" +pgwire = { version = "0.32", features = ["server-api"] } postgres-types = "0.2.5" pretty_assertions = { version = "1.4", features = ["unstable"] } proc-macro2 = "1.0" @@ -268,7 +271,7 @@ tokio = { version = "1.37", features = ["full"] } tokio_metrics = { version = "0.4.0" } tokio-postgres = { version = "0.7.8", features = ["with-chrono-0_4"] } tokio-stream = "0.1.17" -tokio-tungstenite = { version = "0.27.0", features = ["native-tls"] } +tokio-tungstenite = { version = "0.27.0", features = ["native-tls", "url"] } tokio-util = { version = "0.7.4", features = ["time"] } toml = "0.8" toml_edit = "0.22.22" diff --git a/crates/cli/src/common_args.rs b/crates/cli/src/common_args.rs index a238ecbf0d8..f07247d841c 100644 --- a/crates/cli/src/common_args.rs +++ b/crates/cli/src/common_args.rs @@ -29,3 +29,11 @@ pub fn yes() -> Arg { .action(SetTrue) .help("Run non-interactively wherever possible. This will answer \"yes\" to almost all prompts, but will sometimes answer \"no\" to preserve non-interactivity (e.g. when prompting whether to log in with spacetimedb.com).") } + +pub fn confirmed() -> Arg { + Arg::new("confirmed") + .required(false) + .long("confirmed") + .action(SetTrue) + .help("Instruct the server to deliver only updates of confirmed transactions") +} diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index 4221cb9f009..5d768046a7f 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -3,7 +3,7 @@ use std::process::ExitCode; use clap::{Arg, Command}; use spacetimedb_cli::*; use spacetimedb_paths::cli::CliTomlPath; -use spacetimedb_paths::{RootDir, SpacetimePaths}; +use spacetimedb_paths::RootDir; // Note that the standalone server is invoked through standaline/src/main.rs, so you will // also want to set the allocator there. @@ -24,6 +24,8 @@ static GLOBAL: MiMalloc = MiMalloc; #[cfg(not(feature = "markdown-docs"))] #[tokio::main] async fn main() -> anyhow::Result { + use spacetimedb_paths::SpacetimePaths; + // Compute matches before loading the config, because `Config` has an observable `drop` method // (which deletes a lockfile), // and Clap calls `exit` on parse failure rather than panicking, so destructors never run. diff --git a/crates/cli/src/subcommands/sql.rs b/crates/cli/src/subcommands/sql.rs index 5a79190a45b..00a8d645e42 100644 --- a/crates/cli/src/subcommands/sql.rs +++ b/crates/cli/src/subcommands/sql.rs @@ -10,6 +10,7 @@ use anyhow::Context; use clap::{Arg, ArgAction, ArgMatches}; use reqwest::RequestBuilder; use spacetimedb_lib::de::serde::SeedWrapper; +use spacetimedb_lib::sats::satn::PsqlClient; use spacetimedb_lib::sats::{satn, ProductType, ProductValue, Typespace}; pub fn cli() -> clap::Command { @@ -34,6 +35,7 @@ pub fn cli() -> clap::Command { .conflicts_with("query") .help("Instead of using a query, run an interactive command prompt for `SQL` expressions"), ) + .arg(common_args::confirmed()) .arg(common_args::anonymous()) .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database")) .arg(common_args::yes()) @@ -110,7 +112,7 @@ fn print_stmt_result( for (pos, result) in if_empty .into_iter() .chain(stmt_results.iter().map(|stmt_result| { - let (stats, table) = stmt_result_to_table(stmt_result)?; + let (stats, table) = stmt_result_to_table(PsqlClient::SpacetimeDB, stmt_result)?; anyhow::Ok(StmtResult { stats: with_stats.is_some().then_some(stats), @@ -156,12 +158,13 @@ pub(crate) async fn run_sql(builder: RequestBuilder, sql: &str, with_stats: bool Ok(()) } -fn stmt_result_to_table(stmt_result: &SqlStmtResult) -> anyhow::Result<(StmtStats, tabled::Table)> { +fn stmt_result_to_table(client: PsqlClient, stmt_result: &SqlStmtResult) -> anyhow::Result<(StmtStats, tabled::Table)> { let stats = StmtStats::from(stmt_result); let SqlStmtResult { schema, rows, .. } = stmt_result; let ty = Typespace::EMPTY.with_type(schema); let table = build_table( + client, schema, rows.iter().map(|row| from_json_seed(row.get(), SeedWrapper(ty))), )?; @@ -178,17 +181,22 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error crate::repl::exec(con).await?; } else { let query = args.get_one::("query").unwrap(); + let confirmed = args.get_flag("confirmed"); let con = parse_req(config, args).await?; - let api = ClientApi::new(con); + let mut api = ClientApi::new(con).sql(); + if confirmed { + api = api.query(&[("confirmed", "true")]); + } - run_sql(api.sql(), query, false).await?; + run_sql(api, query, false).await?; } Ok(()) } /// Generates a [`tabled::Table`] from a schema and rows, using the style of a psql table. fn build_table( + client: PsqlClient, schema: &ProductType, rows: impl Iterator>, ) -> Result { @@ -206,6 +214,7 @@ fn build_table( let row = row?; builder.push_record(ty.with_values(&row).enumerate().map(|(idx, value)| { let ty = satn::PsqlType { + client, tuple: ty.ty(), field: &ty.ty().elements[idx], idx, @@ -441,8 +450,10 @@ Roundtrip time: 1.00ms"#, Ok(()) } - fn expect_psql_table(ty: &ProductType, rows: Vec, expected: &str) { - let table = build_table(ty, rows.into_iter().map(Ok::<_, ()>)).unwrap().to_string(); + fn expect_psql_table(client: PsqlClient, ty: &ProductType, rows: Vec, expected: &str) { + let table = build_table(client, ty, rows.into_iter().map(Ok::<_, ()>)) + .unwrap() + .to_string(); let mut table = table.split('\n').map(|x| x.trim_end()).join("\n"); table.insert(0, '\n'); assert_eq!(expected, table); @@ -471,14 +482,25 @@ Roundtrip time: 1.00ms"#, ]; expect_psql_table( + PsqlClient::SpacetimeDB, &kind, - vec![value], + vec![value.clone()], r#" column 0 | column 1 | column 2 | column 3 | column 4 | column 5 ----------+----------+--------------------------------------------------------------------+------------------------------------+---------------------------+----------- "a" | 0 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#, ); + expect_psql_table( + PsqlClient::Postgres, + &kind, + vec![value], + r#" + column 0 | column 1 | column 2 | column 3 | column 4 | column 5 +----------+----------+----------------------------------------------------------------------+--------------------------------------+-----------------------------+---------- + "a" | 0 | "0x0000000000000000000000000000000000000000000000000000000000000000" | "0x00000000000000000000000000000000" | "1970-01-01T00:00:00+00:00" | "P0D""#, + ); + // Check struct let kind: ProductType = [ ("bool", AlgebraicType::Bool), @@ -502,6 +524,7 @@ Roundtrip time: 1.00ms"#, ]; expect_psql_table( + PsqlClient::SpacetimeDB, &kind, vec![value.clone()], r#" @@ -510,12 +533,23 @@ Roundtrip time: 1.00ms"#, true | "This is spacetimedb" | 0x01020304050607 | 0x0000000000000000000000000000000000000000000000000000000000000000 | 0x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | +0.000000"#, ); + expect_psql_table( + PsqlClient::Postgres, + &kind, + vec![value.clone()], + r#" + bool | str | bytes | identity | connection_id | timestamp | duration +------+-----------------------+--------------------+----------------------------------------------------------------------+--------------------------------------+-----------------------------+---------- + true | "This is spacetimedb" | "0x01020304050607" | "0x0000000000000000000000000000000000000000000000000000000000000000" | "0x00000000000000000000000000000000" | "1970-01-01T00:00:00+00:00" | "P0D""#, + ); + // Check nested struct, tuple... let kind: ProductType = [(None, AlgebraicType::product(kind))].into(); let value = product![value.clone()]; expect_psql_table( + PsqlClient::SpacetimeDB, &kind, vec![value.clone()], r#" @@ -524,17 +558,38 @@ Roundtrip time: 1.00ms"#, (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000)"#, ); + expect_psql_table( + PsqlClient::Postgres, + &kind, + vec![value.clone()], + r#" + column 0 +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"bool": true, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000000", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "P0D"}"#, + ); + let kind: ProductType = [("tuple", AlgebraicType::product(kind))].into(); let value = product![value]; expect_psql_table( + PsqlClient::SpacetimeDB, + &kind, + vec![value.clone()], + r#" + tuple +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + (col_0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000))"#, + ); + + expect_psql_table( + PsqlClient::Postgres, &kind, vec![value], r#" tuple ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- - (0 = (bool = true, str = "This is spacetimedb", bytes = 0x01020304050607, identity = 0x0000000000000000000000000000000000000000000000000000000000000000, connection_id = 0x00000000000000000000000000000000, timestamp = 1970-01-01T00:00:00+00:00, duration = +0.000000))"#, +-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"col_0": {"bool": true, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000000", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "P0D"}}"#, ); Ok(()) diff --git a/crates/cli/src/subcommands/subscribe.rs b/crates/cli/src/subcommands/subscribe.rs index c7b5058ba42..1220f987f5f 100644 --- a/crates/cli/src/subcommands/subscribe.rs +++ b/crates/cli/src/subcommands/subscribe.rs @@ -2,7 +2,7 @@ use anyhow::Context; use clap::{value_parser, Arg, ArgAction, ArgMatches}; use futures::{Sink, SinkExt, TryStream, TryStreamExt}; use http::header; -use http::uri::Scheme; +use reqwest::Url; use serde_json::Value; use spacetimedb_client_api_messages::websocket::{self as ws, JsonFormat}; use spacetimedb_data_structures::map::HashMap; @@ -65,6 +65,7 @@ pub fn cli() -> clap::Command { .action(ArgAction::SetTrue) .help("Print the initial update for the queries."), ) + .arg(common_args::confirmed()) .arg(common_args::anonymous()) .arg(common_args::yes()) .arg(common_args::server().help("The nickname, host name or URL of the server hosting the database")) @@ -130,25 +131,26 @@ pub async fn exec(config: Config, args: &ArgMatches) -> Result<(), anyhow::Error let num = args.get_one::("num-updates").copied(); let timeout = args.get_one::("timeout").copied(); let print_initial_update = args.get_flag("print_initial_update"); + let confirmed = args.get_flag("confirmed"); let conn = parse_req(config, args).await?; let api = ClientApi::new(conn); let module_def = api.module_def().await?; + let mut url = Url::parse(&api.con.db_uri("subscribe"))?; // Change the URI scheme from `http(s)` to `ws(s)`. - let mut uri = http::Uri::try_from(api.con.db_uri("subscribe"))?.into_parts(); - uri.scheme = uri.scheme.map(|s| { - if s == Scheme::HTTP { - "ws".parse().unwrap() - } else if s == Scheme::HTTPS { - "wss".parse().unwrap() - } else { - s - } - }); + url.set_scheme(match url.scheme() { + "http" => "ws", + "https" => "wss", + unknown => unreachable!("Invalid URL scheme in `Connection::db_uri`: {unknown}"), + }) + .unwrap(); + if confirmed { + url.query_pairs_mut().append_pair("confirmed", "true"); + } // Create the websocket request. - let mut req = http::Uri::from_parts(uri)?.into_client_request()?; + let mut req = url.into_client_request()?; req.headers_mut().insert( header::SEC_WEBSOCKET_PROTOCOL, http::HeaderValue::from_static(ws::TEXT_PROTOCOL), diff --git a/crates/client-api-messages/src/name.rs b/crates/client-api-messages/src/name.rs index 1966297cf26..49cee23a156 100644 --- a/crates/client-api-messages/src/name.rs +++ b/crates/client-api-messages/src/name.rs @@ -106,6 +106,27 @@ pub enum PublishResult { PermissionDenied { name: DatabaseName }, } +#[derive(serde::Serialize, serde::Deserialize, Debug, Default)] +pub enum MigrationPolicy { + #[default] + Compatible, + BreakClients, +} + +#[derive(serde::Serialize, serde::Deserialize, Debug, Default)] +pub enum PrettyPrintStyle { + #[default] + AnsiColor, + NoColor, +} + +#[derive(serde::Serialize, serde::Deserialize, Debug)] +pub struct PrintPlanResult { + pub migrate_plan: Box, + pub break_clients: bool, + pub token: spacetimedb_lib::Hash, +} + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub enum DnsLookupResponse { /// The lookup was successful and the domain and identity are returned. @@ -150,7 +171,7 @@ pub enum SetDefaultDomainResult { /// /// Must match the regex `^[a-z0-9]+(-[a-z0-9]+)*$` #[derive(Clone, Debug, serde_with::DeserializeFromStr, serde_with::SerializeDisplay)] -pub struct DatabaseName(String); +pub struct DatabaseName(pub String); impl AsRef for DatabaseName { fn as_ref(&self) -> &str { diff --git a/crates/client-api/src/auth.rs b/crates/client-api/src/auth.rs index 2f0ca67ef75..f017ae4d55d 100644 --- a/crates/client-api/src/auth.rs +++ b/crates/client-api/src/auth.rs @@ -256,6 +256,10 @@ impl TokenSigner for JwtKeyAuthProvider { impl JwtAuthProvider for JwtKeyAuthProvider { type TV = TV; + fn validator(&self) -> &Self::TV { + &self.validator + } + fn local_issuer(&self) -> &str { &self.local_issuer } @@ -263,10 +267,6 @@ impl JwtAuthProvider for JwtKeyAuthProvider &[u8] { &self.keys.public_pem } - - fn validator(&self) -> &Self::TV { - &self.validator - } } #[cfg(test)] @@ -332,6 +332,13 @@ mod tests { } } +pub async fn validate_token( + state: &S, + token: &str, +) -> Result { + state.jwt_auth_provider().validator().validate_token(token).await +} + pub struct SpacetimeAuthHeader { auth: Option, } @@ -344,10 +351,7 @@ impl axum::extract::FromRequestParts for Space return Ok(Self { auth: None }); }; - let claims = state - .jwt_auth_provider() - .validator() - .validate_token(&creds.token) + let claims = validate_token(state, &creds.token) .await .map_err(AuthorizationRejection::Custom)?; diff --git a/crates/client-api/src/lib.rs b/crates/client-api/src/lib.rs index ac82a3e142d..fd426bb154c 100644 --- a/crates/client-api/src/lib.rs +++ b/crates/client-api/src/lib.rs @@ -7,7 +7,7 @@ use http::StatusCode; use spacetimedb::client::ClientActorIndex; use spacetimedb::energy::{EnergyBalance, EnergyQuanta}; -use spacetimedb::host::{HostController, ModuleHost, NoSuchModule, UpdateDatabaseResult}; +use spacetimedb::host::{HostController, MigratePlanResult, ModuleHost, NoSuchModule, UpdateDatabaseResult}; use spacetimedb::identity::{AuthCtx, Identity}; use spacetimedb::messages::control_db::{Database, HostType, Node, Replica}; use spacetimedb::sql; @@ -15,6 +15,7 @@ use spacetimedb_client_api_messages::http::{SqlStmtResult, SqlStmtStats}; use spacetimedb_client_api_messages::name::{DomainName, InsertDomainResult, RegisterTldResult, SetDomainsResult, Tld}; use spacetimedb_lib::{ProductTypeElement, ProductValue}; use spacetimedb_paths::server::ModuleLogsDir; +use spacetimedb_schema::auto_migrate::{MigrationPolicy, PrettyPrintStyle}; use tokio::sync::watch; pub mod auth; @@ -67,6 +68,7 @@ impl Host { &self, auth: AuthCtx, database: Database, + confirmed_read: bool, body: String, ) -> axum::response::Result>> { let module_host = self @@ -74,7 +76,7 @@ impl Host { .await .map_err(|_| (StatusCode::NOT_FOUND, "module not found".to_string()))?; - let json = self + let (tx_offset, durable_offset, json) = self .host_controller .using_database( database, @@ -115,17 +117,28 @@ impl Host { .map(|(col_name, col_type)| ProductTypeElement::new(col_type, Some(col_name))) .collect(); - Ok(vec![SqlStmtResult { - schema, - rows: result.rows, - total_duration_micros: total_duration.as_micros() as u64, - stats: SqlStmtStats::from_metrics(&result.metrics), - }]) + Ok(( + result.tx_offset, + db.durable_tx_offset(), + vec![SqlStmtResult { + schema, + rows: result.rows, + total_duration_micros: total_duration.as_micros() as u64, + stats: SqlStmtStats::from_metrics(&result.metrics), + }], + )) }, ) .await .map_err(log_and_500)??; + if confirmed_read { + if let Some(mut durable_offset) = durable_offset { + let tx_offset = tx_offset.await.map_err(|_| log_and_500("transaction aborted"))?; + durable_offset.wait_for(tx_offset).await.map_err(log_and_500)?; + } + } + Ok(json) } @@ -134,9 +147,10 @@ impl Host { database: Database, host_type: HostType, program_bytes: Box<[u8]>, + policy: MigrationPolicy, ) -> anyhow::Result { self.host_controller - .update_module_host(database, host_type, self.replica_id, program_bytes) + .update_module_host(database, host_type, self.replica_id, program_bytes, policy) .await } } @@ -219,8 +233,11 @@ pub trait ControlStateWriteAccess: Send + Sync { &self, publisher: &Identity, spec: DatabaseDef, + policy: MigrationPolicy, ) -> anyhow::Result>; + async fn migrate_plan(&self, spec: DatabaseDef, style: PrettyPrintStyle) -> anyhow::Result; + async fn delete_database(&self, caller_identity: &Identity, database_identity: &Identity) -> anyhow::Result<()>; // Energy @@ -309,8 +326,13 @@ impl ControlStateWriteAccess for Arc { &self, identity: &Identity, spec: DatabaseDef, + policy: MigrationPolicy, ) -> anyhow::Result> { - (**self).publish_database(identity, spec).await + (**self).publish_database(identity, spec, policy).await + } + + async fn migrate_plan(&self, spec: DatabaseDef, style: PrettyPrintStyle) -> anyhow::Result { + (**self).migrate_plan(spec, style).await } async fn delete_database(&self, caller_identity: &Identity, database_identity: &Identity) -> anyhow::Result<()> { diff --git a/crates/client-api/src/routes/database.rs b/crates/client-api/src/routes/database.rs index ff90b1e165c..5ff5102d355 100644 --- a/crates/client-api/src/routes/database.rs +++ b/crates/client-api/src/routes/database.rs @@ -7,7 +7,7 @@ use crate::auth::{ SpacetimeIdentityToken, }; use crate::routes::subscribe::generate_random_connection_id; -use crate::util::{ByteStringBody, NameOrIdentity}; +pub use crate::util::{ByteStringBody, NameOrIdentity}; use crate::{log_and_500, ControlStateDelegate, DatabaseDef, NodeDelegate}; use axum::body::{Body, Bytes}; use axum::extract::{Path, Query, State}; @@ -20,16 +20,21 @@ use http::StatusCode; use serde::Deserialize; use spacetimedb::database_logger::DatabaseLogger; use spacetimedb::host::module_host::ClientConnectedError; -use spacetimedb::host::ReducerArgs; use spacetimedb::host::ReducerCallError; use spacetimedb::host::ReducerOutcome; use spacetimedb::host::UpdateDatabaseResult; +use spacetimedb::host::{MigratePlanResult, ReducerArgs}; use spacetimedb::identity::Identity; use spacetimedb::messages::control_db::{Database, HostType}; -use spacetimedb_client_api_messages::name::{self, DatabaseName, DomainName, PublishOp, PublishResult}; +use spacetimedb_client_api_messages::name::{ + self, DatabaseName, DomainName, MigrationPolicy, PrettyPrintStyle, PrintPlanResult, PublishOp, PublishResult, +}; use spacetimedb_lib::db::raw_def::v9::RawModuleDefV9; use spacetimedb_lib::identity::AuthCtx; -use spacetimedb_lib::{sats, Timestamp}; +use spacetimedb_lib::{sats, ProductValue, Timestamp}; +use spacetimedb_schema::auto_migrate::{ + MigrationPolicy as SchemaMigrationPolicy, MigrationToken, PrettyPrintStyle as AutoMigratePrettyPrintStyle, +}; use super::subscribe::{handle_websocket, HasWebSocketOptions}; @@ -378,19 +383,24 @@ pub(crate) async fn worker_ctx_find_database( #[derive(Deserialize)] pub struct SqlParams { - name_or_identity: NameOrIdentity, + pub name_or_identity: NameOrIdentity, } #[derive(Deserialize)] -pub struct SqlQueryParams {} +pub struct SqlQueryParams { + /// If `true`, return the query result only after its transaction offset + /// is confirmed to be durable. + #[serde(default)] + pub confirmed: bool, +} -pub async fn sql( - State(worker_ctx): State, - Path(SqlParams { name_or_identity }): Path, - Query(SqlQueryParams {}): Query, - Extension(auth): Extension, - body: String, -) -> axum::response::Result +pub async fn sql_direct( + worker_ctx: S, + SqlParams { name_or_identity }: SqlParams, + SqlQueryParams { confirmed }: SqlQueryParams, + caller_identity: Identity, + sql: String, +) -> axum::response::Result>> where S: NodeDelegate + ControlStateDelegate, { @@ -402,7 +412,7 @@ where .await? .ok_or(NO_SUCH_DATABASE)?; - let auth = AuthCtx::new(database.owner_identity, auth.claims.identity); + let auth = AuthCtx::new(database.owner_identity, caller_identity); log::debug!("auth: {auth:?}"); let host = worker_ctx @@ -410,7 +420,21 @@ where .await .map_err(log_and_500)? .ok_or(StatusCode::NOT_FOUND)?; - let json = host.exec_sql(auth, database, body).await?; + + host.exec_sql(auth, database, confirmed, sql).await +} + +pub async fn sql( + State(worker_ctx): State, + Path(name_or_identity): Path, + Query(params): Query, + Extension(auth): Extension, + body: String, +) -> axum::response::Result +where + S: NodeDelegate + ControlStateDelegate, +{ + let json = sql_direct(worker_ctx, name_or_identity, params, auth.claims.identity, body).await?; let total_duration = json.iter().fold(0, |acc, x| acc + x.total_duration_micros); @@ -469,9 +493,18 @@ pub struct PublishDatabaseQueryParams { #[serde(default)] clear: bool, num_replicas: Option, + /// [`Hash`] of [`MigrationToken`]` to be checked if `MigrationPolicy::BreakClients` is set. + /// + /// Users obtain such a hash via the `/database/:name_or_identity/pre-publish POST` route. + /// This is a safeguard to require explicit approval for updates which will break clients. + token: Option, + #[serde(default)] + policy: MigrationPolicy, } +use spacetimedb_client_api_messages::http::SqlStmtResult; use std::env; + fn require_spacetime_auth_for_creation() -> bool { env::var("TEMP_REQUIRE_SPACETIME_AUTH").is_ok_and(|v| !v.is_empty()) } @@ -499,7 +532,12 @@ fn allow_creation(auth: &SpacetimeAuth) -> Result<(), ErrorResponse> { pub async fn publish( State(ctx): State, Path(PublishDatabaseParams { name_or_identity }): Path, - Query(PublishDatabaseQueryParams { clear, num_replicas }): Query, + Query(PublishDatabaseQueryParams { + clear, + num_replicas, + token, + policy, + }): Query, Extension(auth): Extension, body: Bytes, ) -> axum::response::Result> { @@ -507,7 +545,7 @@ pub async fn publish( // so, unless you are the owner, this will fail. let (database_identity, db_name) = match &name_or_identity { - Some(noa) => match noa.try_resolve(&ctx).await? { + Some(noa) => match noa.try_resolve(&ctx).await.map_err(log_and_500)? { Ok(resolved) => (resolved, noa.name()), Err(name) => { // `name_or_identity` was a `NameOrIdentity::Name`, but no record @@ -553,6 +591,21 @@ pub async fn publish( } }; + let policy: SchemaMigrationPolicy = match policy { + MigrationPolicy::BreakClients => { + if let Some(token) = token { + Ok(SchemaMigrationPolicy::BreakClients(token)) + } else { + Err(( + StatusCode::BAD_REQUEST, + "Migration policy is set to `BreakClients`, but no migration token was provided.", + )) + } + } + + MigrationPolicy::Compatible => Ok(SchemaMigrationPolicy::Compatible), + }?; + log::trace!("Publishing to the identity: {}", database_identity.to_hex()); let op = { @@ -594,6 +647,7 @@ pub async fn publish( num_replicas, host_type: HostType::Wasm, }, + policy, ) .await .map_err(log_and_500)?; @@ -621,6 +675,101 @@ pub async fn publish( })) } +#[derive(serde::Deserialize)] +pub struct PrePublishParams { + name_or_identity: NameOrIdentity, +} + +#[derive(serde::Deserialize)] +pub struct PrePublishQueryParams { + #[serde(default)] + style: PrettyPrintStyle, +} + +pub async fn pre_publish( + State(ctx): State, + Path(PrePublishParams { name_or_identity }): Path, + Query(PrePublishQueryParams { style }): Query, + Extension(auth): Extension, + body: Bytes, +) -> axum::response::Result> { + // User should not be able to print migration plans for a database that they do not own + let database_identity = resolve_and_authenticate(&ctx, &name_or_identity, &auth).await?; + let style = match style { + PrettyPrintStyle::NoColor => AutoMigratePrettyPrintStyle::NoColor, + PrettyPrintStyle::AnsiColor => AutoMigratePrettyPrintStyle::AnsiColor, + }; + + let migrate_plan = ctx + .migrate_plan( + DatabaseDef { + database_identity, + program_bytes: body.into(), + num_replicas: None, + host_type: HostType::Wasm, + }, + style, + ) + .await + .map_err(log_and_500)?; + + match migrate_plan { + MigratePlanResult::Success { + old_module_hash, + new_module_hash, + breaks_client, + plan, + } => { + let token = MigrationToken { + database_identity, + old_module_hash, + new_module_hash, + } + .hash(); + + Ok(PrintPlanResult { + token, + migrate_plan: plan, + break_clients: breaks_client, + }) + } + MigratePlanResult::AutoMigrationError(e) => Err(( + StatusCode::BAD_REQUEST, + format!("Automatic migration is not possible: {e}"), + ) + .into()), + } + .map(axum::Json) +} + +/// Resolves the [`NameOrIdentity`] to a database identity and checks if the +/// `auth` identity owns the database. +async fn resolve_and_authenticate( + ctx: &S, + name_or_identity: &NameOrIdentity, + auth: &SpacetimeAuth, +) -> axum::response::Result { + let database_identity = name_or_identity.resolve(ctx).await?; + + let database = worker_ctx_find_database(ctx, &database_identity) + .await? + .ok_or(NO_SUCH_DATABASE)?; + + if database.owner_identity != auth.claims.identity { + return Err(( + StatusCode::UNAUTHORIZED, + format!( + "Identity does not own database, expected: {} got: {}", + database.owner_identity.to_hex(), + auth.claims.identity.to_hex() + ), + ) + .into()); + } + + Ok(database_identity) +} + #[derive(Deserialize)] pub struct DeleteDatabaseParams { name_or_identity: NameOrIdentity, @@ -790,7 +939,8 @@ pub struct DatabaseRoutes { pub logs_get: MethodRouter, /// POST: /database/:name_or_identity/sql pub sql_post: MethodRouter, - + /// POST: /database/:name_or_identity/pre-publish + pub pre_publish: MethodRouter, /// GET: /database/: name_or_identity/unstable/timestamp pub timestamp_get: MethodRouter, } @@ -815,6 +965,7 @@ where schema_get: get(schema::), logs_get: get(logs::), sql_post: post(sql::), + pre_publish: post(pre_publish::), timestamp_get: get(get_timestamp::), } } @@ -838,7 +989,8 @@ where .route("/schema", self.schema_get) .route("/logs", self.logs_get) .route("/sql", self.sql_post) - .route("/unstable/timestamp", self.timestamp_get); + .route("/unstable/timestamp", self.timestamp_get) + .route("/pre-publish", self.pre_publish); axum::Router::new() .route("/", self.root_post) diff --git a/crates/client-api/src/routes/subscribe.rs b/crates/client-api/src/routes/subscribe.rs index 33fb0c6ffe4..457fb0bf96b 100644 --- a/crates/client-api/src/routes/subscribe.rs +++ b/crates/client-api/src/routes/subscribe.rs @@ -25,8 +25,8 @@ use spacetimedb::client::messages::{ serialize, IdentityTokenMessage, SerializableMessage, SerializeBuffer, SwitchedServerMessage, ToProtocol, }; use spacetimedb::client::{ - ClientActorId, ClientConfig, ClientConnection, DataMessage, MessageExecutionError, MessageHandleError, - MeteredReceiver, MeteredSender, Protocol, + ClientActorId, ClientConfig, ClientConnection, ClientConnectionReceiver, DataMessage, MessageExecutionError, + MessageHandleError, MeteredReceiver, MeteredSender, Protocol, }; use spacetimedb::host::module_host::ClientConnectedError; use spacetimedb::host::NoSuchModule; @@ -80,6 +80,12 @@ pub struct SubscribeQueryParams { /// This knob works by setting other, more specific, knobs to the value. #[serde(default)] pub light: bool, + /// If `true`, send the subscription updates only after the transaction + /// offset they're computed from is confirmed to be durable. + /// + /// If `false`, send them immediately. + #[serde(default)] + pub confirmed: bool, } pub fn generate_random_connection_id() -> ConnectionId { @@ -93,6 +99,7 @@ pub async fn handle_websocket( connection_id, compression, light, + confirmed, }): Query, forwarded_for: Option>, Extension(auth): Extension, @@ -127,6 +134,7 @@ where protocol, compression, tx_update_full: !light, + confirmed_reads: confirmed, }; // TODO: Should also maybe refactor the code and the protocol to allow a single websocket @@ -206,7 +214,7 @@ where "websocket: Database accepted connection from {client_log_string}; spawning ws_client_actor and ClientConnection" ); - let actor = |client, sendrx| ws_client_actor(ws_opts, client, ws, sendrx); + let actor = |client, receiver| ws_client_actor(ws_opts, client, ws, receiver); let client = ClientConnection::spawn( client_id, auth.into(), @@ -228,7 +236,7 @@ where token: identity_token, connection_id, }; - if let Err(e) = client.send_message(message) { + if let Err(e) = client.send_message(None, message) { log::warn!("websocket: Error sending IdentityToken message to {client_log_string}: {e}"); } }); @@ -356,7 +364,7 @@ async fn ws_client_actor( options: WebSocketOptions, client: ClientConnection, ws: WebSocketStream, - sendrx: MeteredReceiver, + sendrx: ClientConnectionReceiver, ) { // ensure that even if this task gets cancelled, we always cleanup the connection let mut client = scopeguard::guard(client, |client| { @@ -372,7 +380,7 @@ async fn ws_client_actor_inner( client: &mut ClientConnection, config: WebSocketOptions, ws: WebSocketStream, - sendrx: MeteredReceiver, + sendrx: ClientConnectionReceiver, ) { let database = client.module().info().database_identity; let client_id = client.id; @@ -980,6 +988,33 @@ enum UnorderedWsMessage { Error(MessageExecutionError), } +/// Abstraction over [`ClientConnectionReceiver`], so tests can use a plain +/// [`mpsc::Receiver`]. +trait Receiver { + fn recv(&mut self) -> impl Future> + Send; + fn close(&mut self); +} + +impl Receiver for ClientConnectionReceiver { + async fn recv(&mut self) -> Option { + ClientConnectionReceiver::recv(self).await + } + + fn close(&mut self) { + ClientConnectionReceiver::close(self); + } +} + +impl Receiver for mpsc::Receiver { + async fn recv(&mut self) -> Option { + mpsc::Receiver::recv(self).await + } + + fn close(&mut self) { + mpsc::Receiver::close(self); + } +} + /// Sink that sends outgoing messages to the `ws` sink. /// /// Consumes `messages`, which yields subscription updates and reducer call @@ -1005,10 +1040,9 @@ async fn ws_send_loop( state: Arc, config: ClientConfig, mut ws: impl Sink + Unpin, - mut messages: MeteredReceiver, + mut messages: impl Receiver, mut unordered: mpsc::UnboundedReceiver, ) { - let mut messages_buf = Vec::with_capacity(32); let mut serialize_buf = SerializeBuffer::new(config); loop { @@ -1072,26 +1106,35 @@ async fn ws_send_loop( } }, - n = messages.recv_many(&mut messages_buf, 32), if !closed => { - if n == 0 { - continue; - } - log::trace!("sending {n} outgoing messages"); - for msg in messages_buf.drain(..n) { - let (msg_alloc, res) = send_message( - &state.database, - config, - serialize_buf, - msg.workload().zip(msg.num_rows()), - &mut ws, - msg - ).await; - serialize_buf = msg_alloc; - - if let Err(e) = res { - log::warn!("websocket send error: {e}"); - return; + maybe_message = messages.recv(), if !closed => { + let Some(message) = maybe_message else { + // The message sender was dropped, even though no close + // handshake is in progress. This should not normally happen, + // but initiating close seems like the correct thing to do. + log::warn!("message sender dropped without close handshake"); + if let Err(e) = ws.send(WsMessage::Close(None)).await { + log::warn!("error sending close frame: {e:#}"); + break; } + state.close(); + // Continue so that `ws_client_actor` keeps waiting for an + // acknowledgement from the client. + continue; + }; + log::trace!("sending outgoing message"); + let (msg_alloc, res) = send_message( + &state.database, + config, + serialize_buf, + message.workload().zip(message.num_rows()), + &mut ws, + message + ).await; + serialize_buf = msg_alloc; + + if let Err(e) = res { + log::warn!("websocket send error: {e}"); + return; } }, } @@ -1226,7 +1269,7 @@ mod tests { sink, stream, }; use pretty_assertions::assert_matches; - use spacetimedb::client::ClientName; + use spacetimedb::client::{messages::SerializableMessage, ClientName}; use tokio::time::sleep; use super::*; @@ -1404,10 +1447,15 @@ mod tests { async fn send_loop_terminates_when_unordered_closed() { let state = Arc::new(dummy_actor_state()); let (messages_tx, messages_rx) = mpsc::channel(64); - let messages = MeteredReceiver::new(messages_rx); let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); - let send_loop = ws_send_loop(state, ClientConfig::for_test(), sink::drain(), messages, unordered_rx); + let send_loop = ws_send_loop( + state, + ClientConfig::for_test(), + sink::drain(), + messages_rx, + unordered_rx, + ); pin_mut!(send_loop); assert!(is_pending(&mut send_loop).await); @@ -1422,14 +1470,13 @@ mod tests { async fn send_loop_close_message_closes_state_and_messages() { let state = Arc::new(dummy_actor_state()); let (messages_tx, messages_rx) = mpsc::channel(64); - let messages = MeteredReceiver::new(messages_rx); let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); let send_loop = ws_send_loop( state.clone(), ClientConfig::for_test(), sink::drain(), - messages, + messages_rx, unordered_rx, ); pin_mut!(send_loop); @@ -1470,24 +1517,23 @@ mod tests { })), ]; - for msg in input { + for message in input { let state = Arc::new(dummy_actor_state()); let (messages_tx, messages_rx) = mpsc::channel(64); - let messages = MeteredReceiver::new(messages_rx); let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); let send_loop = ws_send_loop( state.clone(), ClientConfig::for_test(), UnfeedableSink, - messages, + messages_rx, unordered_rx, ); pin_mut!(send_loop); - match msg { + match message { Either::Left(unordered) => unordered_tx.send(unordered).unwrap(), - Either::Right(msg) => messages_tx.send(msg).await.unwrap(), + Either::Right(message) => messages_tx.send(message).await.unwrap(), } send_loop.await; } @@ -1517,24 +1563,23 @@ mod tests { })), ]; - for msg in input { + for message in input { let state = Arc::new(dummy_actor_state()); let (messages_tx, messages_rx) = mpsc::channel(64); - let messages = MeteredReceiver::new(messages_rx); let (unordered_tx, unordered_rx) = mpsc::unbounded_channel(); let send_loop = ws_send_loop( state.clone(), ClientConfig::for_test(), UnflushableSink, - messages, + messages_rx, unordered_rx, ); pin_mut!(send_loop); - match msg { + match message { Either::Left(unordered) => unordered_tx.send(unordered).unwrap(), - Either::Right(msg) => messages_tx.send(msg).await.unwrap(), + Either::Right(message) => messages_tx.send(message).await.unwrap(), } send_loop.await; } diff --git a/crates/client-api/src/util.rs b/crates/client-api/src/util.rs index 91386986fe4..c38bf33c0ae 100644 --- a/crates/client-api/src/util.rs +++ b/crates/client-api/src/util.rs @@ -97,10 +97,10 @@ impl NameOrIdentity { pub async fn try_resolve( &self, ctx: &(impl ControlStateReadAccess + ?Sized), - ) -> axum::response::Result> { + ) -> anyhow::Result> { Ok(match self { Self::Identity(identity) => Ok(Identity::from(*identity)), - Self::Name(name) => ctx.lookup_identity(name.as_ref()).map_err(log_and_500)?.ok_or(name), + Self::Name(name) => ctx.lookup_identity(name.as_ref())?.ok_or(name), }) } @@ -108,7 +108,10 @@ impl NameOrIdentity { /// response if `self` is a [`NameOrIdentity::Name`] for which no /// corresponding [`Identity`] is found in the SpacetimeDB DNS. pub async fn resolve(&self, ctx: &(impl ControlStateReadAccess + ?Sized)) -> axum::response::Result { - self.try_resolve(ctx).await?.map_err(|_| StatusCode::NOT_FOUND.into()) + self.try_resolve(ctx) + .await + .map_err(log_and_500)? + .map_err(|_| StatusCode::NOT_FOUND.into()) } } diff --git a/crates/core/src/auth/mod.rs b/crates/core/src/auth/mod.rs index f9c381902e1..e1e38a667c4 100644 --- a/crates/core/src/auth/mod.rs +++ b/crates/core/src/auth/mod.rs @@ -15,6 +15,7 @@ pub struct JwtKeys { pub public: DecodingKey, pub public_pem: Box<[u8]>, pub private: EncodingKey, + pub private_pem: Box<[u8]>, pub kid: Option, } @@ -23,15 +24,17 @@ impl JwtKeys { /// respectively. /// /// The key files must be PEM encoded ECDSA P256 keys. - pub fn new(public_pem: impl Into>, private_pem: &[u8]) -> anyhow::Result { + pub fn new(public_pem: impl Into>, private_pem: impl Into>) -> anyhow::Result { let public_pem = public_pem.into(); + let private_pem = private_pem.into(); let public = DecodingKey::from_ec_pem(&public_pem)?; - let private = EncodingKey::from_ec_pem(private_pem)?; + let private = EncodingKey::from_ec_pem(&private_pem)?; Ok(Self { public, private, public_pem, + private_pem, kid: None, }) } @@ -75,7 +78,7 @@ pub struct EcKeyPair { impl TryFrom for JwtKeys { type Error = anyhow::Error; fn try_from(pair: EcKeyPair) -> anyhow::Result { - JwtKeys::new(pair.public_key_bytes, &pair.private_key_bytes) + JwtKeys::new(pair.public_key_bytes, pair.private_key_bytes) } } diff --git a/crates/core/src/client.rs b/crates/core/src/client.rs index 65c3a6847c5..c37d0f58b45 100644 --- a/crates/core/src/client.rs +++ b/crates/core/src/client.rs @@ -7,8 +7,8 @@ mod message_handlers; pub mod messages; pub use client_connection::{ - ClientConfig, ClientConnection, ClientConnectionSender, ClientSendError, DataMessage, MeteredDeque, - MeteredReceiver, MeteredSender, Protocol, + ClientConfig, ClientConnection, ClientConnectionReceiver, ClientConnectionSender, ClientSendError, DataMessage, + MeteredDeque, MeteredReceiver, MeteredSender, Protocol, }; pub use client_connection_index::ClientActorIndex; pub use message_handlers::{MessageExecutionError, MessageHandleError}; diff --git a/crates/core/src/client/client_connection.rs b/crates/core/src/client/client_connection.rs index b8a22b9d3e7..85d43482d16 100644 --- a/crates/core/src/client/client_connection.rs +++ b/crates/core/src/client/client_connection.rs @@ -9,6 +9,7 @@ use std::time::{Instant, SystemTime}; use super::messages::{OneOffQueryResponseMessage, SerializableMessage}; use super::{message_handlers, ClientActorId, MessageHandleError}; +use crate::db::relational_db::RelationalDB; use crate::error::DBError; use crate::host::module_host::ClientConnectedError; use crate::host::{ModuleHost, NoSuchModule, ReducerArgs, ReducerCallError, ReducerCallResult}; @@ -26,12 +27,14 @@ use spacetimedb_client_api_messages::websocket::{ BsatnFormat, CallReducerFlags, Compression, FormatSwitch, JsonFormat, SubscribeMulti, SubscribeSingle, Unsubscribe, UnsubscribeMulti, }; +use spacetimedb_durability::{DurableOffset, TxOffset}; use spacetimedb_lib::identity::RequestId; use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::Identity; use tokio::sync::mpsc::error::{SendError, TrySendError}; use tokio::sync::{mpsc, oneshot, watch}; use tokio::task::AbortHandle; +use tracing::{trace, warn}; #[derive(PartialEq, Eq, Clone, Copy, Hash, Debug)] pub enum Protocol { @@ -65,6 +68,10 @@ pub struct ClientConfig { /// rather than [`TransactionUpdateLight`]s on a successful update. // TODO(centril): As more knobs are added, make this into a bitfield (when there's time). pub tx_update_full: bool, + /// If `true`, the client requests to receive updates for transactions + /// confirmed to be durable. If `false`, updates will be delivered + /// immediately. + pub confirmed_reads: bool, } impl ClientConfig { @@ -73,16 +80,183 @@ impl ClientConfig { protocol: Protocol::Binary, compression: <_>::default(), tx_update_full: true, + confirmed_reads: false, } } } +/// A message to be sent to the client, along with the transaction offset it +/// was computed at, if available. +/// +// TODO: Consider a different name, "ClientUpdate" is used elsewhere already. +#[derive(Debug)] +struct ClientUpdate { + /// Transaction offset at which `message` was computed. + /// + /// This is only `Some` if `message` is a query result. + /// + /// If `Some` and [`ClientConfig::confirmed_reads`] is `true`, + /// [`ClientConnectionReceiver`] will delay delivery until the durable + /// offset of the database is equal to or greater than `tx_offset`. + pub tx_offset: Option, + /// Type-erased outgoing message. + pub message: SerializableMessage, +} + +/// Types with access to the [`DurableOffset`] of a database. +/// +/// Provided implementors are [`watch::Receiver`] and [`RelationalDB`]. +/// +/// The latter is mostly useful for tests, where no managed [`ModuleHost`] is +/// available, while the former supports module hotswapping. +pub trait DurableOffsetSupply: Send { + /// Obtain the current [`DurableOffset`] handle. + /// + /// Returns: + /// + /// - `Err(NoSuchModule)` if the database was shut down + /// - `Ok(None)` if the database is configured without durability + /// - `Ok(Some(DurableOffset))` otherwise + /// + fn durable_offset(&mut self) -> Result, NoSuchModule>; +} + +impl DurableOffsetSupply for watch::Receiver { + fn durable_offset(&mut self) -> Result, NoSuchModule> { + let module = if self.has_changed().map_err(|_| NoSuchModule)? { + self.borrow_and_update() + } else { + self.borrow() + }; + + Ok(module.replica_ctx().relational_db.durable_tx_offset()) + } +} + +impl DurableOffsetSupply for RelationalDB { + fn durable_offset(&mut self) -> Result, NoSuchModule> { + Ok(self.durable_tx_offset()) + } +} + +/// Receiving end of [`ClientConnectionSender`]. +/// +/// The [`ClientConnection`] actor reads messages from this channel and sends +/// them to the client over its websocket connection. +/// +/// The [`ClientConnectionReceiver`] takes care of confirmed reads semantics, +/// if requested by the client. +pub struct ClientConnectionReceiver { + confirmed_reads: bool, + channel: MeteredReceiver, + current: Option, + offset_supply: Box, +} + +impl ClientConnectionReceiver { + fn new( + confirmed_reads: bool, + channel: MeteredReceiver, + offset_supply: impl DurableOffsetSupply + 'static, + ) -> Self { + Self { + confirmed_reads, + channel, + current: None, + offset_supply: Box::new(offset_supply), + } + } + + /// Receive the next message from this channel. + /// + /// If this method returns `None`, the channel is closed and no more messages + /// are in the internal buffers. No more messages can ever be received from + /// the channel. + /// + /// Messages are returned immediately if: + /// + /// - The (internal) [`ClientUpdate`] does not have a `tx_offset` + /// (such as for error messages). + /// - The client hasn't requested confirmed reads (i.e. + /// [`ClientConfig::confirmed_reads`] is `false`). + /// - The database is configured to not persist transactions. + /// + /// Otherwise, the update's `tx_offset` is compared against the module's + /// durable offset. If the durable offset is behind the `tx_offset`, the + /// method waits until it catches up before returning the message. + /// + /// If the database is shut down while waiting for the durable offset, + /// `None` is returned. In this case, no more messages can ever be received + /// from the channel. + /// + /// # Cancel safety + /// + /// This method is cancel safe, as long as `self` is not dropped. + /// + /// If `recv` is used in a [`tokio::select!`] statement, it may get + /// cancelled while waiting for the durable offset to catch up. At this + /// point, it has already received a value from the underlying channel. + /// This value is stored internally, so calling `recv` again will not lose + /// data. + // + // TODO: Can we make a cancel-safe `recv_many` with confirmed reads semantics? + pub async fn recv(&mut self) -> Option { + let ClientUpdate { tx_offset, message } = match self.current.take() { + None => self.channel.recv().await?, + Some(update) => update, + }; + if !self.confirmed_reads { + return Some(message); + } + + if let Some(tx_offset) = tx_offset { + match self.offset_supply.durable_offset() { + Ok(Some(mut durable)) => { + // Store the current update in case we get cancelled while + // waiting for the durable offset. + self.current = Some(ClientUpdate { + tx_offset: Some(tx_offset), + message, + }); + trace!("waiting for offset {tx_offset} to become durable"); + durable + .wait_for(tx_offset) + .await + .inspect_err(|_| { + warn!("database went away while waiting for durable offset"); + }) + .ok()?; + self.current.take().map(|update| update.message) + } + // Database shut down or crashed. + Err(NoSuchModule) => None, + // In-memory database. + Ok(None) => Some(message), + } + } else { + Some(message) + } + } + + /// Close the receiver without dropping it. + /// + /// This is used to notify the [`ClientConnectionSender`] that the receiver + /// will not consume any more messages from the channel, usually because the + /// connection has been closed or is about to be closed. + /// + /// After calling this method, the sender will not be able to send more + /// messages, preventing the internal buffer from filling up. + pub fn close(&mut self) { + self.channel.close(); + } +} + #[derive(Debug)] pub struct ClientConnectionSender { pub id: ClientActorId, pub auth: ConnectionAuthCtx, pub config: ClientConfig, - sendtx: mpsc::Sender, + sendtx: mpsc::Sender, abort_handle: AbortHandle, cancelled: AtomicBool, @@ -139,15 +313,19 @@ pub enum ClientSendError { } impl ClientConnectionSender { - pub fn dummy_with_channel(id: ClientActorId, config: ClientConfig) -> (Self, MeteredReceiver) { - let (sendtx, rx) = mpsc::channel(1); + pub fn dummy_with_channel( + id: ClientActorId, + config: ClientConfig, + offset_supply: impl DurableOffsetSupply + 'static, + ) -> (Self, ClientConnectionReceiver) { + let (sendtx, rx) = mpsc::channel(CLIENT_CHANNEL_CAPACITY_TEST); // just make something up, it doesn't need to be attached to a real task let abort_handle = match tokio::runtime::Handle::try_current() { Ok(h) => h.spawn(async {}).abort_handle(), Err(_) => tokio::runtime::Runtime::new().unwrap().spawn(async {}).abort_handle(), }; - let rx = MeteredReceiver::new(rx); + let receiver = ClientConnectionReceiver::new(config.confirmed_reads, MeteredReceiver::new(rx), offset_supply); let cancelled = AtomicBool::new(false); let dummy_claims = SpacetimeIdentityClaims { identity: id.identity, @@ -166,11 +344,11 @@ impl ClientConnectionSender { cancelled, metrics: None, }; - (sender, rx) + (sender, receiver) } - pub fn dummy(id: ClientActorId, config: ClientConfig) -> Self { - Self::dummy_with_channel(id, config).0 + pub fn dummy(id: ClientActorId, config: ClientConfig, offset_supply: impl DurableOffsetSupply + 'static) -> Self { + Self::dummy_with_channel(id, config, offset_supply).0 } pub fn is_cancelled(&self) -> bool { @@ -179,11 +357,25 @@ impl ClientConnectionSender { /// Send a message to the client. For data-related messages, you should probably use /// `BroadcastQueue::send` to ensure that the client sees data messages in a consistent order. - pub fn send_message(&self, message: impl Into) -> Result<(), ClientSendError> { - self.send(message.into()) + /// + /// If `message` is the result of evaluating a query, then `tx_offset` should be + /// the TX offset of the database state against which the query was evaluated. + /// If `message` is not the result of evaluating a query (e.g. it reports an error), + /// `tx_offset` should be `None`. + /// For clients which have requested only confirmed durable reads, + /// the sender will delay sending `message` until the `tx_offset` is confirmed. + pub fn send_message( + &self, + tx_offset: Option, + message: impl Into, + ) -> Result<(), ClientSendError> { + self.send(ClientUpdate { + tx_offset, + message: message.into(), + }) } - fn send(&self, message: SerializableMessage) -> Result<(), ClientSendError> { + fn send(&self, message: ClientUpdate) -> Result<(), ClientSendError> { if self.cancelled.load(Relaxed) { return Err(ClientSendError::Cancelled); } @@ -192,7 +384,12 @@ impl ClientConnectionSender { Err(mpsc::error::TrySendError::Full(_)) => { // we've hit CLIENT_CHANNEL_CAPACITY messages backed up in // the channel, so forcibly kick the client - tracing::warn!(identity = %self.id.identity, connection_id = %self.id.connection_id, "client channel capacity exceeded"); + tracing::warn!( + identity = %self.id.identity, + connection_id = %self.id.connection_id, + confirmed_reads = self.config.confirmed_reads, + "client channel capacity exceeded" + ); self.abort_handle.abort(); self.cancelled.store(true, Ordering::Relaxed); return Err(ClientSendError::Cancelled); @@ -436,6 +633,9 @@ impl MeteredSender { // if a client racks up this many messages in the queue without ACK'ing // anything, we boot 'em. const CLIENT_CHANNEL_CAPACITY: usize = 16 * KB; +// use a smaller value for tests +const CLIENT_CHANNEL_CAPACITY_TEST: usize = 8; + const KB: usize = 1024; /// Value returned by [`ClientConnection::call_client_connected_maybe_reject`] @@ -480,7 +680,7 @@ impl ClientConnection { config: ClientConfig, replica_id: u64, mut module_rx: watch::Receiver, - actor: impl FnOnce(ClientConnection, MeteredReceiver) -> Fut, + actor: impl FnOnce(ClientConnection, ClientConnectionReceiver) -> Fut, _proof_of_client_connected_call: Connected, ) -> ClientConnection where @@ -492,7 +692,7 @@ impl ClientConnection { // them and stuff. Not right now though. let module = module_rx.borrow_and_update().clone(); - let (sendtx, sendrx) = mpsc::channel::(CLIENT_CHANNEL_CAPACITY); + let (sendtx, sendrx) = mpsc::channel::(CLIENT_CHANNEL_CAPACITY); let (fut_tx, fut_rx) = oneshot::channel::(); // weird dance so that we can get an abort_handle into ClientConnection @@ -515,7 +715,11 @@ impl ClientConnection { .abort_handle(); let metrics = ClientConnectionMetrics::new(database_identity, config.protocol); - let sendrx = MeteredReceiver::with_gauge(sendrx, metrics.sendtx_queue_size.clone()); + let receiver = ClientConnectionReceiver::new( + config.confirmed_reads, + MeteredReceiver::with_gauge(sendrx, metrics.sendtx_queue_size.clone()), + module_rx.clone(), + ); let sender = Arc::new(ClientConnectionSender { id, @@ -532,7 +736,7 @@ impl ClientConnection { module_rx, }; - let actor_fut = actor(this.clone(), sendrx); + let actor_fut = actor(this.clone(), receiver); // if this fails, the actor() function called .abort(), which like... okay, I guess? let _ = fut_tx.send(actor_fut); @@ -546,7 +750,7 @@ impl ClientConnection { module_rx: watch::Receiver, ) -> Self { Self { - sender: Arc::new(ClientConnectionSender::dummy(id, config)), + sender: Arc::new(ClientConnectionSender::dummy(id, config, module_rx.clone())), replica_id, module_rx, } @@ -735,3 +939,236 @@ impl ClientConnection { self.module().disconnect_client(self.id).await } } + +#[cfg(test)] +mod tests { + use core::fmt; + use std::pin::pin; + + use pretty_assertions::assert_matches; + + use super::*; + use crate::client::messages::{SubscriptionUpdateMessage, TransactionUpdateMessage}; + + #[derive(Clone)] + struct FakeDurableOffset { + channel: watch::Sender>, + closed: Arc, + } + + impl DurableOffsetSupply for FakeDurableOffset { + fn durable_offset(&mut self) -> Result, NoSuchModule> { + if self.closed.load(Ordering::Acquire) { + Err(NoSuchModule) + } else { + Ok(Some(self.channel.subscribe().into())) + } + } + } + + impl FakeDurableOffset { + fn new() -> Self { + let (tx, _) = watch::channel(None); + Self { + channel: tx, + closed: <_>::default(), + } + } + + fn mark_durable_at(&self, offset: TxOffset) { + self.channel.send_modify(|val| { + val.replace(offset); + }) + } + + fn close(&self) { + self.closed.store(true, Ordering::Release); + } + } + + /// [DurableOffsetSupply] that only stores the receiver side of a watch + /// channel initialized to some value. + /// + /// Calling `wait_for` will succeed while the provided value is smaller than + /// or equal to the stored value, but report the channel as closed once it + /// attempts to wait for a new value. + struct DisconnectedDurableOffset { + receiver: watch::Receiver>, + } + + impl DisconnectedDurableOffset { + fn new(offset: TxOffset) -> Self { + let (_, rx) = watch::channel(Some(offset)); + Self { receiver: rx } + } + } + + impl DurableOffsetSupply for DisconnectedDurableOffset { + fn durable_offset(&mut self) -> Result, NoSuchModule> { + Ok(Some(self.receiver.clone().into())) + } + } + + /// [DurableOffsetSupply] that always returns `Ok(None)`. + struct NoneDurableOffset; + + impl DurableOffsetSupply for NoneDurableOffset { + fn durable_offset(&mut self) -> Result, NoSuchModule> { + Ok(None) + } + } + + fn empty_tx_update() -> TransactionUpdateMessage { + TransactionUpdateMessage { + event: None, + database_update: SubscriptionUpdateMessage::default_for_protocol(Protocol::Binary, None), + } + } + + async fn assert_received_update(f: impl Future>) { + assert_matches!(f.await, Some(SerializableMessage::TxUpdate(_))); + } + + async fn assert_receiver_closed(f: impl Future>) { + assert_matches!(f.await, None); + } + + async fn assert_pending(f: &mut (impl Future + Unpin)) { + assert_matches!(futures::poll!(f), Poll::Pending); + } + + fn default_client( + offset_supply: impl DurableOffsetSupply + 'static, + ) -> (ClientConnectionSender, ClientConnectionReceiver) { + ClientConnectionSender::dummy_with_channel( + ClientActorId::for_test(Identity::ZERO), + ClientConfig { + confirmed_reads: false, + ..ClientConfig::for_test() + }, + offset_supply, + ) + } + + fn client_with_confirmed_reads( + offset_supply: impl DurableOffsetSupply + 'static, + ) -> (ClientConnectionSender, ClientConnectionReceiver) { + ClientConnectionSender::dummy_with_channel( + ClientActorId::for_test(Identity::ZERO), + ClientConfig { + confirmed_reads: true, + ..ClientConfig::for_test() + }, + offset_supply, + ) + } + + #[tokio::test] + async fn client_connection_receiver_waits_for_durable_offset() { + let offset = FakeDurableOffset::new(); + let (sender, mut receiver) = client_with_confirmed_reads(offset.clone()); + + for tx_offset in 0..10 { + sender.send_message(Some(tx_offset), empty_tx_update()).unwrap(); + let mut recv = pin!(receiver.recv()); + assert_pending(&mut recv).await; + offset.mark_durable_at(tx_offset); + assert_received_update(recv).await; + } + } + + #[tokio::test] + async fn client_connection_receiver_immediately_yields_message_if_already_durable() { + let offset = FakeDurableOffset::new(); + let (sender, mut receiver) = client_with_confirmed_reads(offset.clone()); + + for tx_offset in 0..10 { + offset.mark_durable_at(tx_offset); + sender.send_message(Some(tx_offset), empty_tx_update()).unwrap(); + assert_received_update(receiver.recv()).await; + } + } + + #[tokio::test] + async fn client_connection_receiver_ends_if_durable_offset_closed() { + let offset = FakeDurableOffset::new(); + let (sender, mut receiver) = client_with_confirmed_reads(offset.clone()); + + offset.close(); + sender.send_message(Some(42), empty_tx_update()).unwrap(); + assert_receiver_closed(receiver.recv()).await; + } + + #[tokio::test] + async fn client_connection_receiver_ends_if_durable_offset_dropped() { + const INITIAL_OFFSET: TxOffset = 1; + let offset = DisconnectedDurableOffset::new(INITIAL_OFFSET); + let (sender, mut receiver) = client_with_confirmed_reads(offset); + + for tx_offset in 0..=(INITIAL_OFFSET + 1) { + sender.send_message(Some(tx_offset), empty_tx_update()).unwrap(); + if tx_offset <= INITIAL_OFFSET { + assert_received_update(receiver.recv()).await; + } else { + assert_receiver_closed(receiver.recv()).await; + } + } + } + + #[tokio::test] + async fn client_connection_receiver_immediately_yields_message_if_sent_without_offset() { + let offset = FakeDurableOffset::new(); + let (sender, mut receiver) = client_with_confirmed_reads(offset.clone()); + + for _ in 0..10 { + sender.send_message(None, empty_tx_update()).unwrap(); + assert_received_update(receiver.recv()).await; + } + + offset.mark_durable_at(5); + + for _ in 0..10 { + sender.send_message(None, empty_tx_update()).unwrap(); + assert_received_update(receiver.recv()).await; + } + } + + #[tokio::test] + async fn client_connection_receiver_immediately_yields_message_for_client_without_confirmed_reads() { + let offset = FakeDurableOffset::new(); + let (sender, mut receiver) = default_client(offset.clone()); + + for tx_offset in 0..10 { + sender.send_message(Some(tx_offset), empty_tx_update()).unwrap(); + assert_received_update(receiver.recv()).await; + } + + offset.mark_durable_at(10); + + for tx_offset in 0..10 { + sender.send_message(Some(tx_offset), empty_tx_update()).unwrap(); + assert_received_update(receiver.recv()).await; + } + } + + #[tokio::test] + async fn client_connection_receiver_immediately_yields_message_without_durability() { + let (sender, mut receiver) = client_with_confirmed_reads(NoneDurableOffset); + + for tx_offset in 0..10 { + sender.send_message(Some(tx_offset), empty_tx_update()).unwrap(); + assert_received_update(receiver.recv()).await; + } + } + + #[tokio::test] + async fn client_connection_receiver_cancel_safety() { + let offset = FakeDurableOffset::new(); + let (sender, mut receiver) = client_with_confirmed_reads(offset.clone()); + + sender.send_message(Some(3), empty_tx_update()).unwrap(); + assert_pending(&mut pin!(receiver.recv())).await; + offset.mark_durable_at(3); + assert_received_update(receiver.recv()).await; + } +} diff --git a/crates/core/src/db/relational_db.rs b/crates/core/src/db/relational_db.rs index 5c878fbd25d..fb6db705147 100644 --- a/crates/core/src/db/relational_db.rs +++ b/crates/core/src/db/relational_db.rs @@ -32,7 +32,7 @@ use spacetimedb_datastore::{ }, traits::TxData, }; -use spacetimedb_durability::{self as durability, TxOffset}; +use spacetimedb_durability as durability; use spacetimedb_lib::db::auth::StAccess; use spacetimedb_lib::db::raw_def::v9::{btree, RawModuleDefV9Builder, RawSql}; use spacetimedb_lib::st_var::StVarValue; @@ -63,6 +63,8 @@ use std::path::Path; use std::sync::Arc; use tokio::sync::watch; +pub use durability::{DurableOffset, TxOffset}; + // NOTE(cloutiertyler): We should be using the associated types, but there is // a bug in the Rust compiler that prevents us from doing so. pub type MutTx = MutTxId; //::MutTx; @@ -369,7 +371,9 @@ impl RelationalDB { .as_ref() .map(|pair| pair.0.clone()) .as_deref() - .and_then(|durability| durability.durable_tx_offset()); + .map(|durability| durability.durable_tx_offset().get()) + .transpose()? + .flatten(); let (min_commitlog_offset, _) = history.tx_range_hint(); log::info!("[{database_identity}] DATABASE: durable_tx_offset is {durable_tx_offset:?}"); @@ -807,7 +811,7 @@ impl RelationalDB { } #[tracing::instrument(level = "trace", skip_all)] - pub fn rollback_mut_tx(&self, tx: MutTx) -> (TxMetrics, String) { + pub fn rollback_mut_tx(&self, tx: MutTx) -> (TxOffset, TxMetrics, String) { log::trace!("ROLLBACK MUT TX"); self.inner.rollback_mut_tx(tx) } @@ -819,18 +823,18 @@ impl RelationalDB { } #[tracing::instrument(level = "trace", skip_all)] - pub fn release_tx(&self, tx: Tx) -> (TxMetrics, String) { + pub fn release_tx(&self, tx: Tx) -> (TxOffset, TxMetrics, String) { log::trace!("RELEASE TX"); self.inner.release_tx(tx) } #[tracing::instrument(level = "trace", skip_all)] - pub fn commit_tx(&self, tx: MutTx) -> Result, DBError> { + pub fn commit_tx(&self, tx: MutTx) -> Result, DBError> { log::trace!("COMMIT MUT TX"); // TODO: Never returns `None` -- should it? let reducer_context = tx.ctx.reducer_context().cloned(); - let Some((tx_data, tx_metrics, reducer)) = self.inner.commit_mut_tx(tx)? else { + let Some((tx_offset, tx_data, tx_metrics, reducer)) = self.inner.commit_mut_tx(tx)? else { return Ok(None); }; @@ -840,7 +844,7 @@ impl RelationalDB { Self::do_durability(&**durability, reducer_context.as_ref(), &tx_data) } - Ok(Some((tx_data, tx_metrics, reducer))) + Ok(Some((tx_offset, tx_data, tx_metrics, reducer))) } #[tracing::instrument(level = "trace", skip_all)] @@ -916,6 +920,14 @@ impl RelationalDB { } } + /// Get the [`DurableOffset`] of this database, or `None` if this is an + /// in-memory instance. + pub fn durable_tx_offset(&self) -> Option { + self.durability + .as_ref() + .map(|durability| durability.durable_tx_offset()) + } + /// Decide based on the `committed_state.next_tx_offset` /// whether to request that the [`SnapshotWorker`] in `self` capture a snapshot of the database. /// @@ -1018,7 +1030,7 @@ impl RelationalDB { { let mut tx = self.begin_tx(workload); let res = f(&mut tx); - let (tx_metrics, reducer) = self.release_tx(tx); + let (_tx_offset, tx_metrics, reducer) = self.release_tx(tx); self.report_read_tx_metrics(reducer, tx_metrics); res } @@ -1029,11 +1041,11 @@ impl RelationalDB { E: From, { if res.is_err() { - let (tx_metrics, reducer) = self.rollback_mut_tx(tx); + let (_, tx_metrics, reducer) = self.rollback_mut_tx(tx); self.report_mut_tx_metrics(reducer, tx_metrics, None); } else { match self.commit_tx(tx).map_err(E::from)? { - Some((tx_data, tx_metrics, reducer)) => { + Some((_tx_offset, tx_data, tx_metrics, reducer)) => { self.report_mut_tx_metrics(reducer, tx_metrics, Some(tx_data)); } None => panic!("TODO: retry?"), @@ -1048,7 +1060,7 @@ impl RelationalDB { pub fn rollback_on_err(&self, tx: MutTx, res: Result) -> Result<(MutTx, A), E> { match res { Err(e) => { - let (tx_metrics, reducer) = self.rollback_mut_tx(tx); + let (_, tx_metrics, reducer) = self.rollback_mut_tx(tx); self.report_mut_tx_metrics(reducer, tx_metrics, None); Err(e) @@ -2159,7 +2171,7 @@ mod tests { fn table( name: &str, columns: ProductType, - f: impl FnOnce(RawTableDefBuilder) -> RawTableDefBuilder, + f: impl FnOnce(RawTableDefBuilder<'_>) -> RawTableDefBuilder, ) -> TableSchema { let mut builder = RawModuleDefV9Builder::new(); f(builder.build_table_with_new_type(name, columns, true)); diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs index eff3b835f02..84e321a261b 100644 --- a/crates/core/src/error.rs +++ b/crates/core/src/error.rs @@ -6,6 +6,7 @@ use std::sync::{MutexGuard, PoisonError}; use enum_as_inner::EnumAsInner; use hex::FromHexError; use spacetimedb_commitlog::repo::TxOffset; +use spacetimedb_durability::DurabilityExited; use spacetimedb_expr::errors::TypingError; use spacetimedb_lib::Identity; use spacetimedb_schema::error::ValidationErrors; @@ -144,6 +145,8 @@ pub enum DBError { }, #[error(transparent)] RestoreSnapshot(#[from] RestoreSnapshotError), + #[error(transparent)] + DurabilityGone(#[from] DurabilityExited), } impl DBError { diff --git a/crates/core/src/host/host_controller.rs b/crates/core/src/host/host_controller.rs index 85510b76053..4b739073ec8 100644 --- a/crates/core/src/host/host_controller.rs +++ b/crates/core/src/host/host_controller.rs @@ -21,6 +21,7 @@ use async_trait::async_trait; use durability::{Durability, EmptyHistory}; use log::{info, trace, warn}; use parking_lot::Mutex; +use spacetimedb_data_structures::error_stream::ErrorStream; use spacetimedb_data_structures::map::IntMap; use spacetimedb_datastore::db_metrics::data_size::DATA_SIZE_METRICS; use spacetimedb_datastore::db_metrics::DB_METRICS; @@ -30,6 +31,7 @@ use spacetimedb_lib::{hash_bytes, Identity}; use spacetimedb_paths::server::{ReplicaDir, ServerDataDir}; use spacetimedb_paths::FromPathUnchecked; use spacetimedb_sats::hash::Hash; +use spacetimedb_schema::auto_migrate::{ponder_migrate, AutoMigrateError, MigrationPolicy, PrettyPrintStyle}; use spacetimedb_schema::def::ModuleDef; use spacetimedb_table::page_pool::PagePool; use std::future::Future; @@ -355,6 +357,7 @@ impl HostController { host_type: HostType, replica_id: u64, program_bytes: Box<[u8]>, + policy: MigrationPolicy, ) -> anyhow::Result { let program = Program { hash: hash_bytes(&program_bytes), @@ -404,6 +407,7 @@ impl HostController { this.runtimes.clone(), host_type, program, + policy, this.energy_monitor.clone(), this.unregister_fn(replica_id), this.db_cores.take(), @@ -418,6 +422,32 @@ impl HostController { Ok(update_result) } + pub async fn migrate_plan( + &self, + database: Database, + host_type: HostType, + replica_id: u64, + program_bytes: Box<[u8]>, + style: PrettyPrintStyle, + ) -> anyhow::Result { + let program = Program { + hash: hash_bytes(&program_bytes), + bytes: program_bytes, + }; + trace!( + "migrate plan {}/{}: genesis={} update-to={}", + database.database_identity, + replica_id, + database.initial_program, + program.hash + ); + + let guard = self.acquire_read_lock(replica_id).await; + let host = guard.as_ref().ok_or(NoSuchModule)?; + + host.migrate_plan(host_type, program, style).await + } + /// Start the host `replica_id` and conditionally update it. /// /// If the host was not initialized before, it is initialized with the @@ -503,6 +533,7 @@ impl HostController { this.runtimes.clone(), host_type, program, + MigrationPolicy::Compatible, this.energy_monitor.clone(), this.unregister_fn(replica_id), this.db_cores.take(), @@ -773,6 +804,7 @@ async fn update_module( module: &ModuleHost, program: Program, old_module_info: Arc, + policy: MigrationPolicy, ) -> anyhow::Result { let addr = db.database_identity(); match stored_program_hash(db)? { @@ -783,7 +815,7 @@ async fn update_module( UpdateDatabaseResult::NoUpdateNeeded } else { info!("updating `{}` from {} to {}", addr, stored, program.hash); - module.update_database(program, old_module_info).await? + module.update_database(program, old_module_info, policy).await? }; Ok(res) @@ -1022,11 +1054,13 @@ impl Host { /// otherwise it stays the same. /// /// Either way, the [`UpdateDatabaseResult`] is returned. + #[allow(clippy::too_many_arguments)] async fn update_module( &mut self, runtimes: Arc, host_type: HostType, program: Program, + policy: MigrationPolicy, energy_monitor: Arc, on_panic: impl Fn() + Send + Sync + 'static, core: JobCore, @@ -1049,7 +1083,8 @@ impl Host { // Get the old module info to diff against when building a migration plan. let old_module_info = self.module.borrow().info.clone(); - let update_result = update_module(&replica_ctx.relational_db, &module, program, old_module_info).await?; + let update_result = + update_module(&replica_ctx.relational_db, &module, program, old_module_info, policy).await?; trace!("update result: {update_result:?}"); // Only replace the module + scheduler if the update succeeded. // Otherwise, we want the database to continue running with the old state. @@ -1063,6 +1098,30 @@ impl Host { Ok(update_result) } + /// Generate a migration plan for the given `program`. + async fn migrate_plan( + &self, + host_type: HostType, + program: Program, + style: PrettyPrintStyle, + ) -> anyhow::Result { + let old_module = self.module.borrow().info.clone(); + + let module_def = extract_schema(program.bytes, host_type).await?; + + let res = match ponder_migrate(&old_module.module_def, &module_def) { + Ok(plan) => MigratePlanResult::Success { + old_module_hash: old_module.module_hash, + new_module_hash: program.hash, + breaks_client: plan.breaks_client(), + plan: plan.pretty_print(style)?.into(), + }, + Err(e) => MigratePlanResult::AutoMigrationError(e), + }; + + Ok(res) + } + fn db(&self) -> &RelationalDB { &self.replica_ctx.relational_db } @@ -1075,6 +1134,16 @@ impl Drop for Host { } } +pub enum MigratePlanResult { + Success { + old_module_hash: Hash, + new_module_hash: Hash, + plan: Box, + breaks_client: bool, + }, + AutoMigrationError(ErrorStream), +} + const STORAGE_METERING_INTERVAL: Duration = Duration::from_secs(15); /// Periodically collect gauge stats and update prometheus metrics. diff --git a/crates/core/src/host/mod.rs b/crates/core/src/host/mod.rs index 1434f50917c..beb72fc2f99 100644 --- a/crates/core/src/host/mod.rs +++ b/crates/core/src/host/mod.rs @@ -25,8 +25,8 @@ mod wasm_common; pub use disk_storage::DiskStorage; pub use host_controller::{ - extract_schema, DurabilityProvider, ExternalDurability, ExternalStorage, HostController, ProgramStorage, - ReducerCallResult, ReducerOutcome, StartSnapshotWatcher, + extract_schema, DurabilityProvider, ExternalDurability, ExternalStorage, HostController, MigratePlanResult, + ProgramStorage, ReducerCallResult, ReducerOutcome, StartSnapshotWatcher, }; pub use module_host::{ModuleHost, NoSuchModule, ReducerCallError, UpdateDatabaseResult}; pub use scheduler::Scheduler; diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 5ca0ad4099b..31155712085 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -44,7 +44,7 @@ use spacetimedb_lib::Timestamp; use spacetimedb_primitives::TableId; use spacetimedb_query::compile_subscription; use spacetimedb_sats::ProductValue; -use spacetimedb_schema::auto_migrate::AutoMigrateError; +use spacetimedb_schema::auto_migrate::{AutoMigrateError, MigrationPolicy}; use spacetimedb_schema::def::deserialize::ReducerArgsDeserializeSeed; use spacetimedb_schema::def::{ModuleDef, ReducerDef, TableDef}; use spacetimedb_schema::schema::{Schema, TableSchema}; @@ -52,6 +52,7 @@ use spacetimedb_vm::relation::RelValue; use std::fmt; use std::sync::{Arc, Weak}; use std::time::{Duration, Instant}; +use tokio::sync::oneshot; #[derive(Debug, Default, Clone, From)] pub struct DatabaseUpdate { @@ -337,6 +338,7 @@ pub trait ModuleInstance: Send + 'static { &mut self, program: Program, old_module_info: Arc, + policy: MigrationPolicy, ) -> anyhow::Result; fn call_reducer(&mut self, tx: Option, params: CallReducerParams) -> ReducerCallResult; @@ -398,7 +400,7 @@ fn init_database( let rcr = match module_def.lifecycle_reducer(Lifecycle::Init) { None => { - if let Some((tx_data, tx_metrics, reducer)) = stdb.commit_tx(tx)? { + if let Some((_tx_offset, tx_data, tx_metrics, reducer)) = stdb.commit_tx(tx)? { stdb.report_mut_tx_metrics(reducer, tx_metrics, Some(tx_data)); } None @@ -461,8 +463,9 @@ impl ModuleInstance for AutoReplacingModuleInstance { &mut self, program: Program, old_module_info: Arc, + policy: MigrationPolicy, ) -> anyhow::Result { - let ret = self.inst.update_database(program, old_module_info); + let ret = self.inst.update_database(program, old_module_info, policy); self.check_trap(); ret } @@ -705,7 +708,7 @@ impl ModuleHost { // If we crash before committing, we need to ensure that the transaction is rolled back. // This is necessary to avoid leaving the database in an inconsistent state. log::debug!("call_identity_connected: rolling back transaction"); - let (metrics, reducer_name) = mut_tx.rollback(); + let (_, metrics, reducer_name) = mut_tx.rollback(); stdb.report_mut_tx_metrics(reducer_name, metrics, None); }); @@ -1086,9 +1089,10 @@ impl ModuleHost { &self, program: Program, old_module_info: Arc, + policy: MigrationPolicy, ) -> Result { self.call("", move |inst| { - inst.update_database(program, old_module_info) + inst.update_database(program, old_module_info, policy) }) .await? } @@ -1138,74 +1142,79 @@ impl ModuleHost { log::debug!("One-off query: {query}"); let metrics = self .on_module_thread("one_off_query", move || { - db.with_read_only(Workload::Sql, |tx| { - // We wrap the actual query in a closure so we can use ? to handle errors without making - // the entire transaction abort with an error. - let result: Result<(OneOffTable, ExecutionMetrics), anyhow::Error> = (|| { - let tx = SchemaViewer::new(tx, &auth); - - let ( - // A query may compile down to several plans. - // This happens when there are multiple RLS rules per table. - // The original query is the union of these plans. - plans, - _, - table_name, - _, - ) = compile_subscription(&query, &tx, &auth)?; - - // Optimize each fragment - let optimized = plans - .into_iter() - .map(|plan| plan.optimize()) - .collect::, _>>()?; - - check_row_limit( - &optimized, - &db, - &tx, - // Estimate the number of rows this query will scan - |plan, tx| estimate_rows_scanned(tx, plan), - &auth, - )?; - - let optimized = optimized - .into_iter() - // Convert into something we can execute - .map(PipelinedProject::from) - .collect::>(); - - // Execute the union and return the results - execute_plan::<_, F>(&optimized, &DeltaTx::from(&*tx)) - .map(|(rows, _, metrics)| (OneOffTable { table_name, rows }, metrics)) - .context("One-off queries are not allowed to modify the database") - })(); - - let total_host_execution_duration = timer.elapsed().into(); - let (message, metrics): (SerializableMessage, Option) = match result { - Ok((rows, metrics)) => ( - into_message(OneOffQueryResponseMessage { - message_id, - error: None, - results: vec![rows], - total_host_execution_duration, - }), - Some(metrics), - ), - Err(err) => ( - into_message(OneOffQueryResponseMessage { - message_id, - error: Some(format!("{err}")), - results: vec![], - total_host_execution_duration, - }), - None, - ), - }; - - subscriptions.send_client_message(client, message, tx)?; - Ok::, anyhow::Error>(metrics) - }) + let (tx_offset_sender, tx_offset_receiver) = oneshot::channel(); + let tx = scopeguard::guard(db.begin_tx(Workload::Sql), |tx| { + let (tx_offset, tx_metrics, reducer) = db.release_tx(tx); + let _ = tx_offset_sender.send(tx_offset); + db.report_read_tx_metrics(reducer, tx_metrics); + }); + + // We wrap the actual query in a closure so we can use ? to handle errors without making + // the entire transaction abort with an error. + let result: Result<(OneOffTable, ExecutionMetrics), anyhow::Error> = (|| { + let tx = SchemaViewer::new(&*tx, &auth); + + let ( + // A query may compile down to several plans. + // This happens when there are multiple RLS rules per table. + // The original query is the union of these plans. + plans, + _, + table_name, + _, + ) = compile_subscription(&query, &tx, &auth)?; + + // Optimize each fragment + let optimized = plans + .into_iter() + .map(|plan| plan.optimize()) + .collect::, _>>()?; + + check_row_limit( + &optimized, + &db, + &tx, + // Estimate the number of rows this query will scan + |plan, tx| estimate_rows_scanned(tx, plan), + &auth, + )?; + + let optimized = optimized + .into_iter() + // Convert into something we can execute + .map(PipelinedProject::from) + .collect::>(); + + // Execute the union and return the results + execute_plan::<_, F>(&optimized, &DeltaTx::from(&*tx)) + .map(|(rows, _, metrics)| (OneOffTable { table_name, rows }, metrics)) + .context("One-off queries are not allowed to modify the database") + })(); + + let total_host_execution_duration = timer.elapsed().into(); + let (message, metrics): (SerializableMessage, Option) = match result { + Ok((rows, metrics)) => ( + into_message(OneOffQueryResponseMessage { + message_id, + error: None, + results: vec![rows], + total_host_execution_duration, + }), + Some(metrics), + ), + Err(err) => ( + into_message(OneOffQueryResponseMessage { + message_id, + error: Some(format!("{err}")), + results: vec![], + total_host_execution_duration, + }), + None, + ), + }; + + subscriptions.send_client_message(client, message, (&*tx, tx_offset_receiver))?; + Ok::, anyhow::Error>(metrics) }) .await??; diff --git a/crates/core/src/host/v8/mod.rs b/crates/core/src/host/v8/mod.rs index 9cff2587320..92e1fb0c896 100644 --- a/crates/core/src/host/v8/mod.rs +++ b/crates/core/src/host/v8/mod.rs @@ -19,7 +19,9 @@ use ser::serialize_to_js; use spacetimedb_client_api_messages::energy::{EnergyQuanta, ReducerBudget}; use spacetimedb_datastore::locking_tx_datastore::MutTxId; use spacetimedb_datastore::traits::Program; -use spacetimedb_lib::{ConnectionId, Identity, RawModuleDef}; +use spacetimedb_lib::RawModuleDef; +use spacetimedb_lib::{ConnectionId, Identity}; +use spacetimedb_schema::auto_migrate::MigrationPolicy; use std::sync::{Arc, LazyLock}; use v8::{Context, ContextOptions, ContextScope, Function, HandleScope, Isolate, Local, Value}; @@ -137,9 +139,11 @@ impl ModuleInstance for JsInstance { &mut self, program: Program, old_module_info: Arc, + policy: MigrationPolicy, ) -> anyhow::Result { let replica_ctx = &self.replica_ctx; - self.common.update_database(replica_ctx, program, old_module_info) + self.common + .update_database(replica_ctx, program, old_module_info, policy) } fn call_reducer(&mut self, tx: Option, params: CallReducerParams) -> super::ReducerCallResult { diff --git a/crates/core/src/host/wasm_common/module_host_actor.rs b/crates/core/src/host/wasm_common/module_host_actor.rs index e8c3ca064b0..037a5f9bd11 100644 --- a/crates/core/src/host/wasm_common/module_host_actor.rs +++ b/crates/core/src/host/wasm_common/module_host_actor.rs @@ -1,6 +1,6 @@ use prometheus::{Histogram, IntCounter, IntGauge}; use spacetimedb_lib::db::raw_def::v9::Lifecycle; -use spacetimedb_schema::auto_migrate::ponder_migrate; +use spacetimedb_schema::auto_migrate::{MigratePlan, MigrationPolicy, MigrationPolicyError}; use std::sync::Arc; use std::time::Duration; use tracing::span::EnteredSpan; @@ -233,9 +233,11 @@ impl ModuleInstance for WasmModuleInstance { &mut self, program: Program, old_module_info: Arc, + policy: MigrationPolicy, ) -> anyhow::Result { let replica_ctx = &self.instance.instance_env().replica_ctx; - self.common.update_database(replica_ctx, program, old_module_info) + self.common + .update_database(replica_ctx, program, old_module_info, policy) } fn call_reducer(&mut self, tx: Option, params: CallReducerParams) -> ReducerCallResult { @@ -277,15 +279,24 @@ impl InstanceCommon { replica_ctx: &ReplicaContext, program: Program, old_module_info: Arc, + policy: MigrationPolicy, ) -> Result { let system_logger = replica_ctx.logger.system_logger(); let stdb = &replica_ctx.relational_db; - let plan = ponder_migrate(&old_module_info.module_def, &self.info.module_def); - let plan = match plan { + let plan: MigratePlan = match policy.try_migrate( + self.info.database_identity, + old_module_info.module_hash, + &old_module_info.module_def, + self.info.module_hash, + &self.info.module_def, + ) { Ok(plan) => plan, - Err(errs) => { - return Ok(UpdateDatabaseResult::AutoMigrateError(errs)); + Err(e) => { + return match e { + MigrationPolicyError::AutoMigrateFailure(e) => Ok(UpdateDatabaseResult::AutoMigrateError(e)), + _ => Ok(UpdateDatabaseResult::ErrorExecutingMigration(e.into())), + } } }; @@ -301,12 +312,12 @@ impl InstanceCommon { Err(e) => { log::warn!("Database update failed: {} @ {}", e, stdb.database_identity()); system_logger.warn(&format!("Database update failed: {e}")); - let (tx_metrics, reducer) = stdb.rollback_mut_tx(tx); + let (_, tx_metrics, reducer) = stdb.rollback_mut_tx(tx); stdb.report_mut_tx_metrics(reducer, tx_metrics, None); Ok(UpdateDatabaseResult::ErrorExecutingMigration(e)) } Ok(()) => { - if let Some((tx_data, tx_metrics, reducer)) = stdb.commit_tx(tx)? { + if let Some((_tx_offset, tx_data, tx_metrics, reducer)) = stdb.commit_tx(tx)? { stdb.report_mut_tx_metrics(reducer, tx_metrics, Some(tx_data)); } system_logger.info("Database updated"); @@ -615,7 +626,7 @@ fn commit_and_broadcast_event( .commit_and_broadcast_event(client, event, tx) .unwrap() { - Ok((event, _)) => event, + Ok(res) => res.event, Err(WriteConflict) => todo!("Write skew, you need to implement retries my man, T-dawg."), } } diff --git a/crates/core/src/messages/control_db.rs b/crates/core/src/messages/control_db.rs index 21a9155dc64..8299875e339 100644 --- a/crates/core/src/messages/control_db.rs +++ b/crates/core/src/messages/control_db.rs @@ -63,6 +63,10 @@ pub struct Node { /// /// If `None`, the node is not currently live. pub advertise_addr: Option, + /// The address this node is running its postgres API at. + /// + /// If `None`, the node is not currently live. + pub pg_addr: Option, } #[derive(Clone, PartialEq, Serialize, Deserialize)] pub struct NodeStatus { diff --git a/crates/core/src/sql/execute.rs b/crates/core/src/sql/execute.rs index d12e811ca7a..089148cdd2e 100644 --- a/crates/core/src/sql/execute.rs +++ b/crates/core/src/sql/execute.rs @@ -9,6 +9,7 @@ use crate::estimation::estimate_rows_scanned; use crate::host::module_host::{DatabaseTableUpdate, DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall}; use crate::host::ArgsTuple; use crate::subscription::module_subscription_actor::{ModuleSubscriptions, WriteConflict}; +use crate::subscription::module_subscription_manager::TransactionOffset; use crate::subscription::tx::DeltaTx; use crate::util::slow::SlowQueryLogger; use crate::vm::{check_row_limit, DbProgram, TxMode}; @@ -26,6 +27,7 @@ use spacetimedb_schema::relation::FieldName; use spacetimedb_vm::eval::run_ast; use spacetimedb_vm::expr::{CodeResult, CrudExpr, Expr}; use spacetimedb_vm::relation::MemTable; +use tokio::sync::oneshot; pub struct StmtResult { pub schema: ProductType, @@ -172,6 +174,11 @@ pub fn execute_sql_tx<'a>( } pub struct SqlResult { + /// The offset of the SQL operation's transaction. + /// + /// Used to determine visibility of the transaction wrt the durability + /// requirements requested by the caller. + pub tx_offset: TransactionOffset, pub rows: Vec, /// These metrics will be reported via `report_tx_metrics`. /// They should not be reported separately to avoid double counting. @@ -200,9 +207,12 @@ pub fn run( // and hence there are no deltas to process. let (tx_data, tx_metrics_mut, tx) = tx.commit_downgrade(Workload::Sql); - // Release the tx on drop, so that we record metrics. + let (tx_offset_send, tx_offset) = oneshot::channel(); + // Release the tx on drop, so that we record metrics + // and set the transaction offset. let mut tx = scopeguard::guard(tx, |tx| { - let (tx_metrics_downgrade, reducer) = db.release_tx(tx); + let (offset, tx_metrics_downgrade, reducer) = db.release_tx(tx); + let _ = tx_offset_send.send(offset); db.report_tx_metrics( reducer, Some(Arc::new(tx_data)), @@ -232,6 +242,7 @@ pub fn run( tx.metrics.merge(metrics); Ok(SqlResult { + tx_offset, rows, metrics: tx.metrics, }) @@ -252,10 +263,17 @@ pub fn run( if subs.is_none() { let metrics = tx.metrics; return db.commit_tx(tx).map(|tx_opt| { - if let Some((tx_data, tx_metrics, reducer)) = tx_opt { - db.report_mut_tx_metrics(reducer, tx_metrics, Some(tx_data)); + let (tx_offset, tx_data, tx_metrics, reducer) = tx_opt.unwrap(); + + let (tx_offset_sender, tx_offset_receiver) = oneshot::channel(); + let _ = tx_offset_sender.send(tx_offset); + + db.report_mut_tx_metrics(reducer, tx_metrics, Some(tx_data)); + SqlResult { + tx_offset: tx_offset_receiver, + rows: vec![], + metrics, } - SqlResult { rows: vec![], metrics } }); } @@ -289,7 +307,11 @@ pub fn run( Err(WriteConflict) => { todo!("See module_host_actor::call_reducer_with_tx") } - Ok(_) => Ok(SqlResult { rows: vec![], metrics }), + Ok(res) => Ok(SqlResult { + tx_offset: res.tx_offset, + rows: vec![], + metrics, + }), } } } diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 65211305898..63b56d54c07 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -1,7 +1,5 @@ use super::execution_unit::QueryHash; -use super::module_subscription_manager::{ - spawn_send_worker, BroadcastError, BroadcastQueue, Plan, SubscriptionGaugeStats, SubscriptionManager, -}; +use super::module_subscription_manager::{from_tx_offset, spawn_send_worker, BroadcastError, BroadcastQueue, Plan, SubscriptionGaugeStats, SubscriptionManager, TransactionOffset}; use super::query::compile_query_with_hashes; use super::tx::DeltaTx; use super::{collect_table_update, TableUpdateType}; @@ -22,18 +20,23 @@ use crate::vm::check_row_limit; use crate::worker_metrics::WORKER_METRICS; use parking_lot::RwLock; use prometheus::{Histogram, HistogramTimer, IntCounter, IntGauge}; +use scopeguard::ScopeGuard; use spacetimedb_client_api_messages::websocket::{ self as ws, BsatnFormat, FormatSwitch, JsonFormat, SubscribeMulti, SubscribeSingle, TableUpdate, Unsubscribe, UnsubscribeMulti, }; use spacetimedb_datastore::db_metrics::DB_METRICS; use spacetimedb_datastore::execution_context::{Workload, WorkloadType}; +use spacetimedb_datastore::locking_tx_datastore::datastore::TxMetrics; use spacetimedb_datastore::locking_tx_datastore::TxId; +use spacetimedb_datastore::traits::TxData; +use spacetimedb_durability::TxOffset; use spacetimedb_execution::pipelined::PipelinedProject; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::Identity; use std::{sync::Arc, time::Instant}; +use tokio::sync::oneshot; type Subscriptions = Arc>; @@ -123,6 +126,16 @@ impl SubscriptionMetrics { } } +/// Inner result type of [`ModuleSubscriptions::commit_and_broadcast_event`]. +pub type CommitAndBroadcastEventResult = Result; + +/// `Ok` side of a [`CommitAndBroadcastEventResult`]. +pub struct CommitAndBroadcastEventSuccess { + pub tx_offset: TransactionOffset, + pub event: Arc, + pub metrics: ExecutionMetrics, +} + type AssertTxFn = Arc; type SubscriptionUpdate = FormatSwitch, TableUpdate>; type FullSubscriptionUpdate = FormatSwitch, ws::DatabaseUpdate>; @@ -299,6 +312,7 @@ impl ModuleSubscriptions { let send_err_msg = |message| { self.broadcast_queue.send_client_message( sender.clone(), + None, SubscriptionMessage { request_id: Some(request.request_id), query_id: Some(request.query_id), @@ -316,10 +330,7 @@ impl ModuleSubscriptions { let hash = QueryHash::from_string(&sql, auth.caller, false); let hash_with_param = QueryHash::from_string(&sql, auth.caller, true); - let tx = scopeguard::guard(self.relational_db.begin_tx(Workload::Subscribe), |tx| { - let (tx_metrics, reducer) = self.relational_db.release_tx(tx); - self.relational_db.report_read_tx_metrics(reducer, tx_metrics); - }); + let (tx, tx_offset) = self.begin_tx(Workload::Subscribe); let existing_query = { let guard = self.subscriptions.read(); @@ -365,6 +376,7 @@ impl ModuleSubscriptions { // Holding a write lock on `self.subscriptions` would also be sufficient. let _ = self.broadcast_queue.send_client_message( sender.clone(), + Some(tx_offset), SubscriptionMessage { request_id: Some(request.request_id), query_id: Some(request.query_id), @@ -390,6 +402,7 @@ impl ModuleSubscriptions { let send_err_msg = |message| { self.broadcast_queue.send_client_message( sender.clone(), + None, SubscriptionMessage { request_id: Some(request.request_id), query_id: Some(request.query_id), @@ -418,10 +431,7 @@ impl ModuleSubscriptions { return Ok(None); }; - let tx = scopeguard::guard(self.relational_db.begin_tx(Workload::Unsubscribe), |tx| { - let (tx_metrics, reducer) = self.relational_db.release_tx(tx); - self.relational_db.report_read_tx_metrics(reducer, tx_metrics); - }); + let (tx, tx_offset) = self.begin_tx(Workload::Unsubscribe); let auth = AuthCtx::new(self.owner_identity, sender.id.identity); let (table_rows, metrics) = return_on_err_with_sql!( self.evaluate_initial_subscription(sender.clone(), query.clone(), &tx, &auth, TableUpdateType::Unsubscribe), @@ -438,6 +448,7 @@ impl ModuleSubscriptions { // Holding a write lock on `self.subscriptions` would also be sufficient. let _ = self.broadcast_queue.send_client_message( sender.clone(), + Some(tx_offset), SubscriptionMessage { request_id: Some(request.request_id), query_id: Some(request.query_id), @@ -464,6 +475,7 @@ impl ModuleSubscriptions { let send_err_msg = |message| { self.broadcast_queue.send_client_message( sender.clone(), + None, SubscriptionMessage { request_id: Some(request.request_id), query_id: Some(request.query_id), @@ -480,10 +492,7 @@ impl ModuleSubscriptions { let subscription_metrics = SubscriptionMetrics::new(&database_identity, &WorkloadType::Unsubscribe); // Always lock the db before the subscription lock to avoid deadlocks. - let tx = scopeguard::guard(self.relational_db.begin_tx(Workload::Unsubscribe), |tx| { - let (tx_metrics, reducer) = self.relational_db.release_tx(tx); - self.relational_db.report_read_tx_metrics(reducer, tx_metrics); - }); + let (tx, tx_offset) = self.begin_tx(Workload::Unsubscribe); let removed_queries = { let _compile_timer = subscription_metrics.compilation_time.start_timer(); @@ -527,6 +536,7 @@ impl ModuleSubscriptions { // Holding a write lock on `self.subscriptions` would also be sufficient. let _ = self.broadcast_queue.send_client_message( sender, + Some(tx_offset), SubscriptionMessage { request_id: Some(request.request_id), query_id: Some(request.query_id), @@ -576,10 +586,7 @@ impl ModuleSubscriptions { let auth = AuthCtx::new(self.owner_identity, sender); // We always get the db lock before the subscription lock to avoid deadlocks. - let tx = scopeguard::guard(self.relational_db.begin_tx(Workload::Subscribe), |tx| { - let (tx_metrics, reducer) = self.relational_db.release_tx(tx); - self.relational_db.report_read_tx_metrics(reducer, tx_metrics); - }); + let (tx, _tx_offset) = self.begin_tx(Workload::Subscribe); let compile_timer = metrics.compilation_time.start_timer(); @@ -631,9 +638,10 @@ impl ModuleSubscriptions { &self, recipient: Arc, message: impl Into, - _tx_id: &TxId, + (_tx, tx_offset): (&TxId, TransactionOffset), ) -> Result<(), BroadcastError> { - self.broadcast_queue.send_client_message(recipient, message) + self.broadcast_queue + .send_client_message(recipient, Some(tx_offset), message) } #[tracing::instrument(level = "trace", skip_all)] @@ -648,6 +656,7 @@ impl ModuleSubscriptions { let send_err_msg = |message| { let _ = self.broadcast_queue.send_client_message( sender.clone(), + None, SubscriptionMessage { request_id: Some(request.request_id), query_id: Some(request.query_id), @@ -678,10 +687,7 @@ impl ModuleSubscriptions { send_err_msg, None ); - let tx = scopeguard::guard(tx, |tx| { - let (tx_metrics, reducer) = self.relational_db.release_tx(tx); - self.relational_db.report_read_tx_metrics(reducer, tx_metrics); - }); + let (tx, tx_offset) = self.guard_tx(tx, <_>::default()); // We minimize locking so that other clients can add subscriptions concurrently. // We are protected from race conditions with broadcasts, because we have the db lock, @@ -738,6 +744,7 @@ impl ModuleSubscriptions { let _ = self.broadcast_queue.send_client_message( sender.clone(), + Some(tx_offset), SubscriptionMessage { request_id: Some(request.request_id), query_id: Some(request.query_id), @@ -772,10 +779,7 @@ impl ModuleSubscriptions { num_queries, &subscription_metrics, )?; - let tx = scopeguard::guard(tx, |tx| { - let (tx_metrics, reducer) = self.relational_db.release_tx(tx); - self.relational_db.report_read_tx_metrics(reducer, tx_metrics); - }); + let (tx, tx_offset) = self.guard_tx(tx, <_>::default()); check_row_limit( &queries, @@ -830,6 +834,7 @@ impl ModuleSubscriptions { // Holding a write lock on `self.subscriptions` would also be sufficient. let _ = self.broadcast_queue.send_client_message( sender, + Some(tx_offset), SubscriptionUpdateMessage { database_update, request_id: Some(subscription.request_id), @@ -854,7 +859,7 @@ impl ModuleSubscriptions { caller: Option>, mut event: ModuleEvent, tx: MutTx, - ) -> Result, ExecutionMetrics), WriteConflict>, DBError> { + ) -> Result { let database_identity = self.relational_db.database_identity(); let subscription_metrics = SubscriptionMetrics::new(&database_identity, &WorkloadType::Update); @@ -883,7 +888,7 @@ impl ModuleSubscriptions { // We don't need to do any subscription updates in this case, so we will exit early. let event = Arc::new(event); - let (tx_metrics, reducer) = stdb.rollback_mut_tx(tx); + let (tx_offset, tx_metrics, reducer) = stdb.rollback_mut_tx(tx); self.relational_db .report_tx_metrics(reducer, None, Some(tx_metrics), None); if let Some(client) = caller { @@ -892,34 +897,104 @@ impl ModuleSubscriptions { database_update: SubscriptionUpdateMessage::default_for_protocol(client.config.protocol, None), }; - let _ = self.broadcast_queue.send_client_message(client, message); + let _ = self.broadcast_queue.send_client_message(client, Some(from_tx_offset(tx_offset)), message); } else { log::trace!("Reducer failed but there is no client to send the failure to!") } - return Ok(Ok((event, ExecutionMetrics::default()))); + return Ok(Ok(CommitAndBroadcastEventSuccess { + tx_offset: from_tx_offset(tx_offset), + event, + metrics: ExecutionMetrics::default(), + })); } }; let event = Arc::new(event); // When we're done with this method, release the tx and report metrics. - let mut read_tx = scopeguard::guard(read_tx, |tx| { - let (tx_metrics_read, reducer) = self.relational_db.release_tx(tx); - self.relational_db.report_tx_metrics( - reducer, - Some(tx_data.clone()), - Some(tx_metrics_mut), - Some(tx_metrics_read), - ); - }); + let (extra_tx_offset_sender, extra_tx_offset) = oneshot::channel(); + let (mut read_tx, tx_offset) = self.guard_tx( + read_tx, + GuardTxOptions::full(extra_tx_offset_sender, Some(tx_data.clone()), tx_metrics_mut), + ); // Create the delta transaction we'll use to eval updates against. let delta_read_tx = DeltaTx::new(&read_tx, tx_data.as_ref(), subscriptions.index_ids_for_subscriptions()); + let update_metrics = subscriptions.eval_updates_sequential((&delta_read_tx, tx_offset), event.clone(), caller); + read_tx.metrics.merge(update_metrics); + Ok(Ok(CommitAndBroadcastEventSuccess { + tx_offset: extra_tx_offset, + event, + metrics: update_metrics, + })) + } - let update_metrics = subscriptions.eval_updates_sequential(&delta_read_tx, event.clone(), caller); + /// Helper that starts a new read transaction, and guards it using + /// [`Self::guard_tx`] with the default configuration. + fn begin_tx(&self, workload: Workload) -> (ScopeGuard, TransactionOffset) { + self.guard_tx(self.relational_db.begin_tx(workload), <_>::default()) + } - // Merge in the subscription evaluation metrics. - read_tx.metrics.merge(update_metrics); + /// Helper wrapping `tx` in a scopegard, with a configurable drop fn. + /// + /// By default, `tx` is released when the returned [`ScopeGuard`] is dropped, + /// and reports the transaction metrics via [`RelationalDB::report_tx_metrics`]. + /// The `tx_data` and `tx_metrics_mut` parameters are passed to the metrics + /// reporting method as-is; they can be used to report additional metrics + /// about a previous mutable transaction that was downgraded to `tx` after + /// committing. + /// + /// The method returns a [`ScopeGuard`] along with a [`TransactionOffset`]. + /// When the transaction commits, its transaction offset is sent to the + /// latter (a [`oneshot::Receiver`]). + /// If another receiver of the transaction offset is needed, its sending + /// side can be passed in as `extra_tx_offset_sender`. It will be sent the + /// offset as well. + fn guard_tx( + &self, + tx: TxId, + GuardTxOptions { + extra_tx_offset_sender, + tx_data, + tx_metrics_mut, + }: GuardTxOptions, + ) -> (ScopeGuard, TransactionOffset) { + let (offset_tx, offset_rx) = oneshot::channel(); + let guard = scopeguard::guard(tx, |tx| { + let (tx_offset, tx_metrics, reducer) = self.relational_db.release_tx(tx); + log::trace!("read tx released with offset {tx_offset}"); + let _ = offset_tx.send(tx_offset); + if let Some(extra) = extra_tx_offset_sender { + let _ = extra.send(tx_offset); + } + self.relational_db + .report_tx_metrics(reducer, tx_data, tx_metrics_mut, Some(tx_metrics)); + }); - Ok(Ok((event, update_metrics))) + (guard, offset_rx) + } +} + +/// Extra parameters for [`ModuleSubscriptions::guard_tx`]. +#[derive(Default)] +struct GuardTxOptions { + /// Sender for an extra [`oneshot::Receiver`] for the transaction offset. + extra_tx_offset_sender: Option>, + /// [`TxData`] of a preceding mutable transaction. + tx_data: Option>, + /// [`TxMetrics`] of a preceding mutable transaction. + tx_metrics_mut: Option, +} + +impl GuardTxOptions { + fn full( + extra_tx_offset_sender: oneshot::Sender, + tx_data: Option>, + tx_metrics_mut: TxMetrics, + ) -> Self { + Self { + extra_tx_offset_sender: extra_tx_offset_sender.into(), + tx_data, + tx_metrics_mut: tx_metrics_mut.into(), + } } } @@ -932,11 +1007,13 @@ mod tests { SerializableMessage, SubscriptionData, SubscriptionError, SubscriptionMessage, SubscriptionResult, SubscriptionUpdateMessage, TransactionUpdateMessage, }; - use crate::client::{ClientActorId, ClientConfig, ClientConnectionSender, ClientName, MeteredReceiver, Protocol}; + use crate::client::{ + ClientActorId, ClientConfig, ClientConnectionReceiver, ClientConnectionSender, ClientName, Protocol, + }; use crate::db::relational_db::tests_utils::{ - begin_mut_tx, begin_tx, insert, with_auto_commit, with_read_only, TestDB, + begin_mut_tx, begin_tx, insert, with_auto_commit, with_read_only, TempReplicaDir, TestDB, }; - use crate::db::relational_db::RelationalDB; + use crate::db::relational_db::{RelationalDB, Txdata}; use crate::error::DBError; use crate::host::module_host::{DatabaseUpdate, EventStatus, ModuleEvent, ModuleFunctionCall}; use crate::messages::websocket as ws; @@ -944,6 +1021,7 @@ mod tests { use crate::subscription::module_subscription_manager::{spawn_send_worker, SubscriptionManager}; use crate::subscription::query::compile_read_only_query; use crate::subscription::TableUpdateType; + use core::fmt; use hashbrown::HashMap; use itertools::Itertools; use pretty_assertions::assert_matches; @@ -952,7 +1030,9 @@ mod tests { CompressableQueryUpdate, Compression, FormatSwitch, QueryId, Subscribe, SubscribeMulti, SubscribeSingle, TableUpdate, Unsubscribe, UnsubscribeMulti, }; + use spacetimedb_commitlog::{commitlog, repo}; use spacetimedb_datastore::system_tables::{StRowLevelSecurityRow, ST_ROW_LEVEL_SECURITY_ID}; + use spacetimedb_durability::{Durability, EmptyHistory, TxOffset}; use spacetimedb_execution::dml::MutDatastore; use spacetimedb_lib::bsatn::ToBsatn; use spacetimedb_lib::db::auth::StAccess; @@ -962,9 +1042,14 @@ mod tests { use spacetimedb_lib::{error::ResultTest, AlgebraicType, Identity}; use spacetimedb_primitives::TableId; use spacetimedb_sats::product; + use std::future::Future; + use std::pin::pin; + use std::sync::RwLock; + use std::task::Poll; use std::time::Instant; use std::{sync::Arc, time::Duration}; use tokio::sync::mpsc::{self}; + use tokio::sync::watch; fn add_subscriber(db: Arc, sql: &str, assert: Option) -> Result<(), DBError> { // Create and enter a Tokio runtime to run the `ModuleSubscriptions`' background workers in parallel. @@ -973,7 +1058,7 @@ mod tests { let owner = Identity::from_byte_array([1; 32]); let client = ClientActorId::for_test(Identity::ZERO); let config = ClientConfig::for_test(); - let sender = Arc::new(ClientConnectionSender::dummy(client, config)); + let sender = Arc::new(ClientConnectionSender::dummy(client, config, (*db).clone())); let send_worker_queue = spawn_send_worker(None); let module_subscriptions = ModuleSubscriptions::new( db.clone(), @@ -990,12 +1075,88 @@ mod tests { Ok(()) } + /// A [`Durability`] for which the durable offset is marked manually. + struct ManualDurability { + commitlog: Arc>>, + durable_offset: watch::Sender>, + } + + impl ManualDurability { + #[allow(unused)] + fn mark_durable_at(&self, offset: TxOffset) { + assert!( + self.committed_offset().is_some_and(|committed| committed >= offset), + "given offset is not in the commitlog" + ); + self.durable_offset.send_modify(|val| { + val.replace(offset); + }); + } + + fn mark_durable(&self) { + if let Some(offset) = self.committed_offset() { + self.durable_offset.send_modify(|val| { + val.replace(offset); + }); + } + } + + fn committed_offset(&self) -> Option { + self.commitlog.read().unwrap().max_committed_offset() + } + } + + impl Durability for ManualDurability { + type TxData = Txdata; + + fn append_tx(&self, tx: Self::TxData) { + let mut commitlog = self.commitlog.write().unwrap(); + if let Err(tx) = commitlog.append(tx) { + commitlog.commit().expect("error flushing commitlog"); + commitlog.append(tx).expect("should be able to append after flush"); + } + commitlog.commit().expect("error flushing commitlog"); + } + + fn durable_tx_offset(&self) -> spacetimedb_durability::DurableOffset { + self.durable_offset.subscribe().into() + } + } + + impl Default for ManualDurability { + fn default() -> Self { + let (durable_offset, ..) = watch::channel(None); + Self { + commitlog: Arc::new(RwLock::new( + commitlog::Generic::open(repo::Memory::new(), <_>::default()).unwrap(), + )), + durable_offset, + } + } + } + /// An in-memory `RelationalDB` for testing fn relational_db() -> anyhow::Result> { let TestDB { db, .. } = TestDB::in_memory()?; Ok(Arc::new(db)) } + /// An in-memory `RelationalDB` with `ManualDurability`. + fn relational_db_with_manual_durability() -> anyhow::Result<(Arc, Arc)> { + let dir = TempReplicaDir::new()?; + let durability = Arc::new(ManualDurability::default()); + let db = TestDB::open_db( + &dir, + EmptyHistory::new(), + Some((durability.clone(), Arc::new(|| Ok(0)))), + None, + None, + 0, + )?; + + Ok((Arc::new(db), durability)) + } + /// A [SubscribeSingle] message for testing fn single_subscribe(sql: &str, query_id: u32) -> SubscribeSingle { SubscribeSingle { @@ -1068,27 +1229,57 @@ mod tests { } } + fn client_connection_with_config( + client_id: ClientActorId, + db: &RelationalDB, + config: ClientConfig, + ) -> (Arc, ClientConnectionReceiver) { + let (sender, receiver) = ClientConnectionSender::dummy_with_channel(client_id, config, db.clone()); + (Arc::new(sender), receiver) + } + /// Instantiate a client connection with compression fn client_connection_with_compression( client_id: ClientActorId, + db: &RelationalDB, compression: Compression, - ) -> (Arc, MeteredReceiver) { - let (sender, rx) = ClientConnectionSender::dummy_with_channel( + ) -> (Arc, ClientConnectionReceiver) { + client_connection_with_config( client_id, + db, ClientConfig { protocol: Protocol::Binary, compression, tx_update_full: true, + confirmed_reads: false, }, - ); - (Arc::new(sender), rx) + ) } /// Instantiate a client connection fn client_connection( client_id: ClientActorId, - ) -> (Arc, MeteredReceiver) { - client_connection_with_compression(client_id, Compression::None) + db: &RelationalDB, + ) -> (Arc, ClientConnectionReceiver) { + client_connection_with_compression(client_id, db, Compression::None) + } + + /// Instantiate a client connection with confirmed reads turned on or off. + fn client_connection_with_confirmed_reads( + client_id: ClientActorId, + db: &RelationalDB, + confirmed_reads: bool, + ) -> (Arc, ClientConnectionReceiver) { + client_connection_with_config( + client_id, + db, + ClientConfig { + protocol: Protocol::Binary, + compression: Compression::None, + tx_update_full: true, + confirmed_reads, + }, + ) } /// Insert rules into the RLS system table @@ -1161,13 +1352,13 @@ mod tests { /// Pull a message from receiver and assert that it is a `TxUpdate` with the expected rows async fn assert_tx_update_for_table( - rx: &mut MeteredReceiver, + rx: impl Future>, table_id: TableId, schema: &ProductType, inserts: impl IntoIterator, deletes: impl IntoIterator, ) { - match rx.recv().await { + match rx.await { Some(SerializableMessage::TxUpdate(TransactionUpdateMessage { database_update: SubscriptionUpdateMessage { @@ -1237,6 +1428,22 @@ mod tests { } } + /// Assert that the future `f` completes only after `durability` is marked + /// durable. + /// + /// Namely: + /// + /// - assert that polling `f` once returns [`Poll::Pending`] + /// - call `durability.mark_durable()` + /// - assert that polling `f` returns [`Poll::Ready`]. + /// + async fn assert_after_durable(durability: &ManualDurability, f: impl Future) { + let mut g = pin!(f); + assert_matches!(futures::poll!(&mut g), Poll::Pending); + durability.mark_durable(); + assert_matches!(futures::poll!(g), Poll::Ready(_)); + } + /// Commit a set of row updates and broadcast to subscribers fn commit_tx( db: &RelationalDB, @@ -1252,18 +1459,19 @@ mod tests { db.insert(&mut tx, table_id, &bsatn::to_vec(&row)?)?; } - let Ok(Ok((_, metrics))) = subs.commit_and_broadcast_event(None, module_event(), tx) else { + let Ok(Ok(success)) = subs.commit_and_broadcast_event(None, module_event(), tx) else { panic!("Encountered an error in `commit_and_broadcast_event`"); }; - Ok(metrics) + Ok(success.metrics) } #[test] fn test_subscribe_metrics() -> anyhow::Result<()> { + let db = relational_db()?; + let client_id = client_id_from_u8(1); - let (sender, _) = client_connection(client_id); + let (sender, _) = client_connection(client_id, &db); - let db = relational_db()?; let (subs, _runtime) = ModuleSubscriptions::for_test_new_runtime(db.clone()); // Create a table `t` with index on `id` @@ -1315,10 +1523,11 @@ mod tests { /// Test that clients receive error messages on subscribe #[tokio::test] async fn subscribe_single_error() -> anyhow::Result<()> { + let db = relational_db()?; + let client_id = client_id_from_u8(1); - let (tx, mut rx) = client_connection(client_id); + let (tx, mut rx) = client_connection(client_id, &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); db.create_table_for_test("t", &[("x", AlgebraicType::U8)], &[])?; @@ -1335,10 +1544,11 @@ mod tests { /// Test that clients receive error messages on subscribe #[tokio::test] async fn subscribe_multi_error() -> anyhow::Result<()> { + let db = relational_db()?; + let client_id = client_id_from_u8(1); - let (tx, mut rx) = client_connection(client_id); + let (tx, mut rx) = client_connection(client_id, &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); db.create_table_for_test("t", &[("x", AlgebraicType::U8)], &[])?; @@ -1355,10 +1565,11 @@ mod tests { /// Test that clients receive error messages on unsubscribe #[tokio::test] async fn unsubscribe_single_error() -> anyhow::Result<()> { + let db = relational_db()?; + let client_id = client_id_from_u8(1); - let (tx, mut rx) = client_connection(client_id); + let (tx, mut rx) = client_connection(client_id, &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); // Create a table `t` with an index on `id` @@ -1409,10 +1620,11 @@ mod tests { /// Needs a multi-threaded tokio runtime so that the module subscription worker can run in parallel. #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn unsubscribe_multi_error() -> anyhow::Result<()> { + let db = relational_db()?; + let client_id = client_id_from_u8(1); - let (tx, mut rx) = client_connection(client_id); + let (tx, mut rx) = client_connection(client_id, &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); // Create a table `t` with an index on `id` @@ -1463,10 +1675,11 @@ mod tests { /// Test that clients receive error messages on tx updates #[tokio::test] async fn tx_update_error() -> anyhow::Result<()> { + let db = relational_db()?; + let client_id = client_id_from_u8(1); - let (tx, mut rx) = client_connection(client_id); + let (tx, mut rx) = client_connection(client_id, &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); // Create two tables `t` and `s` with indexes on their `id` columns @@ -1518,6 +1731,8 @@ mod tests { /// Test that two clients can subscribe to a parameterized query and get the correct rows. #[tokio::test] async fn test_parameterized_subscription() -> anyhow::Result<()> { + let db = relational_db()?; + // Create identities for two different clients let id_for_a = identity_from_u8(1); let id_for_b = identity_from_u8(2); @@ -1526,10 +1741,9 @@ mod tests { let client_id_for_b = client_id_from_u8(2); // Establish a connection for each client - let (tx_for_a, mut rx_for_a) = client_connection(client_id_for_a); - let (tx_for_b, mut rx_for_b) = client_connection(client_id_for_b); + let (tx_for_a, mut rx_for_a) = client_connection(client_id_for_a, &db); + let (tx_for_b, mut rx_for_b) = client_connection(client_id_for_b, &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let schema = [("identity", AlgebraicType::identity())]; @@ -1576,14 +1790,16 @@ mod tests { let schema = ProductType::from([AlgebraicType::identity()]); // Both clients should only receive their identities and not the other's. - assert_tx_update_for_table(&mut rx_for_a, table_id, &schema, [product![id_for_a]], []).await; - assert_tx_update_for_table(&mut rx_for_b, table_id, &schema, [product![id_for_b]], []).await; + assert_tx_update_for_table(rx_for_a.recv(), table_id, &schema, [product![id_for_a]], []).await; + assert_tx_update_for_table(rx_for_b.recv(), table_id, &schema, [product![id_for_b]], []).await; Ok(()) } /// Test that two clients can subscribe to a table with RLS rules and get the correct rows #[tokio::test] async fn test_rls_subscription() -> anyhow::Result<()> { + let db = relational_db()?; + // Create identities for two different clients let id_for_a = identity_from_u8(1); let id_for_b = identity_from_u8(2); @@ -1592,10 +1808,9 @@ mod tests { let client_id_for_b = client_id_from_u8(2); // Establish a connection for each client - let (tx_for_a, mut rx_for_a) = client_connection(client_id_for_a); - let (tx_for_b, mut rx_for_b) = client_connection(client_id_for_b); + let (tx_for_a, mut rx_for_a) = client_connection(client_id_for_a, &db); + let (tx_for_b, mut rx_for_b) = client_connection(client_id_for_b, &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let schema = [("id", AlgebraicType::identity())]; @@ -1650,19 +1865,20 @@ mod tests { let schema = ProductType::from([AlgebraicType::identity()]); // Both clients should only receive their identities and not the other's. - assert_tx_update_for_table(&mut rx_for_a, w_id, &schema, [product![id_for_a]], []).await; - assert_tx_update_for_table(&mut rx_for_b, w_id, &schema, [product![id_for_b]], []).await; + assert_tx_update_for_table(rx_for_a.recv(), w_id, &schema, [product![id_for_a]], []).await; + assert_tx_update_for_table(rx_for_b.recv(), w_id, &schema, [product![id_for_b]], []).await; Ok(()) } /// Test that a client and the database owner can subscribe to the same query #[tokio::test] async fn test_rls_for_owner() -> anyhow::Result<()> { + let db = relational_db()?; + // Establish a connection for owner and client - let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(0)); - let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(1)); + let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(0), &db); + let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(1), &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); // Create table `t` @@ -1710,7 +1926,7 @@ mod tests { )?; assert_tx_update_for_table( - &mut rx_for_a, + rx_for_a.recv(), table_id, &schema, // The owner should receive both identities @@ -1720,7 +1936,7 @@ mod tests { .await; assert_tx_update_for_table( - &mut rx_for_b, + rx_for_b.recv(), table_id, &schema, // Client `b` should only receive its identity @@ -1735,10 +1951,11 @@ mod tests { /// Test that we do not send empty updates to clients #[tokio::test] async fn test_no_empty_updates() -> anyhow::Result<()> { + let db = relational_db()?; + // Establish a client connection - let (tx, mut rx) = client_connection(client_id_from_u8(1)); + let (tx, mut rx) = client_connection(client_id_from_u8(1), &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let schema = [("x", AlgebraicType::U8)]; @@ -1773,7 +1990,7 @@ mod tests { // If the server sends empty updates, this assertion will fail, // because we will receive one for the first transaction. - assert_tx_update_for_table(&mut rx, t_id, &schema, [product![0_u8]], []).await; + assert_tx_update_for_table(rx.recv(), t_id, &schema, [product![0_u8]], []).await; Ok(()) } @@ -1784,10 +2001,11 @@ mod tests { /// Needs a multi-threaded tokio runtime so that the module subscription worker can run in parallel. #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_no_compression_for_subscribe() -> anyhow::Result<()> { + let db = relational_db()?; + // Establish a client connection with compression - let (tx, mut rx) = client_connection_with_compression(client_id_from_u8(1), Compression::Brotli); + let (tx, mut rx) = client_connection_with_compression(client_id_from_u8(1), &db, Compression::Brotli); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let table_id = db.create_table_for_test("t", &[("x", AlgebraicType::U64)], &[])?; @@ -1828,10 +2046,11 @@ mod tests { /// Test that we receive subscription updates for DML #[tokio::test] async fn test_updates_for_dml() -> anyhow::Result<()> { + let db = relational_db()?; + // Establish a client connection - let (tx, mut rx) = client_connection(client_id_from_u8(1)); + let (tx, mut rx) = client_connection(client_id_from_u8(1), &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let schema = [("x", AlgebraicType::U8), ("y", AlgebraicType::U8)]; let t_id = db.create_table_for_test("t", &schema, &[])?; @@ -1856,17 +2075,17 @@ mod tests { )?; // Client should receive insert - assert_tx_update_for_table(&mut rx, t_id, &schema, [product![0_u8, 1_u8]], []).await; + assert_tx_update_for_table(rx.recv(), t_id, &schema, [product![0_u8, 1_u8]], []).await; run(&db, "UPDATE t SET y=2 WHERE x=0", auth, Some(&subs), &mut vec![])?; // Client should receive update - assert_tx_update_for_table(&mut rx, t_id, &schema, [product![0_u8, 2_u8]], [product![0_u8, 1_u8]]).await; + assert_tx_update_for_table(rx.recv(), t_id, &schema, [product![0_u8, 2_u8]], [product![0_u8, 1_u8]]).await; run(&db, "DELETE FROM t WHERE x=0", auth, Some(&subs), &mut vec![])?; // Client should receive delete - assert_tx_update_for_table(&mut rx, t_id, &schema, [], [product![0_u8, 2_u8]]).await; + assert_tx_update_for_table(rx.recv(), t_id, &schema, [], [product![0_u8, 2_u8]]).await; Ok(()) } @@ -1875,10 +2094,11 @@ mod tests { /// but we don't care about that for this test. #[tokio::test] async fn test_no_compression_for_update() -> anyhow::Result<()> { + let db = relational_db()?; + // Establish a client connection with compression - let (tx, mut rx) = client_connection_with_compression(client_id_from_u8(1), Compression::Brotli); + let (tx, mut rx) = client_connection_with_compression(client_id_from_u8(1), &db, Compression::Brotli); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let table_id = db.create_table_for_test("t", &[("x", AlgebraicType::U64)], &[])?; @@ -1925,10 +2145,11 @@ mod tests { #[tokio::test] async fn test_update_for_join() -> anyhow::Result<()> { async fn test_subscription_updates(queries: &[&'static str]) -> anyhow::Result<()> { + let db = relational_db()?; + // Establish a client connection - let (sender, mut rx) = client_connection(client_id_from_u8(1)); + let (sender, mut rx) = client_connection(client_id_from_u8(1), &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let p_schema = [("id", AlgebraicType::U64), ("signed_in", AlgebraicType::Bool)]; @@ -1962,7 +2183,7 @@ mod tests { // We should receive both matching player rows assert_tx_update_for_table( - &mut rx, + rx.recv(), p_id, &schema, [product![1_u64, true], product![2_u64, true]], @@ -1980,7 +2201,7 @@ mod tests { // We should receive an update for it because it is still matching assert_tx_update_for_table( - &mut rx, + rx.recv(), p_id, &schema, [product![2_u64, false]], @@ -1998,7 +2219,7 @@ mod tests { // We should receive an update for it because it is still matching assert_tx_update_for_table( - &mut rx, + rx.recv(), p_id, &schema, [product![2_u64, true]], @@ -2038,11 +2259,12 @@ mod tests { /// Needs a multi-threaded tokio runtime so that the module subscription worker can run in parallel. #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_query_pruning() -> anyhow::Result<()> { + let db = relational_db()?; + // Establish a connection for each client - let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1)); - let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2)); + let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1), &db); + let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2), &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let u_id = db.create_table_for_test( @@ -2133,7 +2355,7 @@ mod tests { let metrics = commit_tx(&db, &subs, [], [(v_id, product![2u64, 6u64, 6u64])])?; assert_tx_update_for_table( - &mut rx_for_a, + rx_for_a.recv(), u_id, &ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]), [product![2u64, 3u64, 3u64]], @@ -2154,7 +2376,7 @@ mod tests { )?; assert_tx_update_for_table( - &mut rx_for_b, + rx_for_b.recv(), u_id, &ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]), [product![1u64, 2u64, 3u64]], @@ -2179,9 +2401,10 @@ mod tests { /// Test that we do not evaluate queries that we know will not match row updates #[tokio::test] async fn test_join_pruning() -> anyhow::Result<()> { - let (tx, mut rx) = client_connection(client_id_from_u8(1)); - let db = relational_db()?; + + let (tx, mut rx) = client_connection(client_id_from_u8(1), &db); + let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let u_id = db.create_table_for_test_with_the_works( @@ -2250,7 +2473,7 @@ mod tests { // Insert a new row into `u` that joins with `x = 1` let metrics = commit_tx(&db, &subs, [], [(u_id, product![1u64, 2u64, 3u64])])?; - assert_tx_update_for_table(&mut rx, u_id, &schema, [product![1u64, 2u64, 3u64]], []).await; + assert_tx_update_for_table(rx.recv(), u_id, &schema, [product![1u64, 2u64, 3u64]], []).await; // We should only have evaluated a single query assert_eq!(metrics.delta_queries_evaluated, 1); @@ -2277,7 +2500,7 @@ mod tests { )?; // Results in a no-op - assert_tx_update_for_table(&mut rx, u_id, &schema, [], []).await; + assert_tx_update_for_table(rx.recv(), u_id, &schema, [], []).await; // We should have evaluated queries for `x = 1` and `x = 2` assert_eq!(metrics.delta_queries_evaluated, 2); @@ -2292,7 +2515,7 @@ mod tests { [(v_id, product![3u64, 4u64, 3u64]), (u_id, product![3u64, 4u64, 5u64])], )?; - assert_tx_update_for_table(&mut rx, u_id, &schema, [product![3u64, 4u64, 5u64]], []).await; + assert_tx_update_for_table(rx.recv(), u_id, &schema, [product![3u64, 4u64, 5u64]], []).await; // We should have evaluated queries for `x = 3` and `x = 4` assert_eq!(metrics.delta_queries_evaluated, 2); @@ -2306,7 +2529,7 @@ mod tests { [(v_id, product![3u64, 0u64, 3u64])], )?; - assert_tx_update_for_table(&mut rx, u_id, &schema, [], [product![3u64, 4u64, 5u64]]).await; + assert_tx_update_for_table(rx.recv(), u_id, &schema, [], [product![3u64, 4u64, 5u64]]).await; // We should only have evaluated the query for `x = 4` assert_eq!(metrics.delta_queries_evaluated, 1); @@ -2332,11 +2555,12 @@ mod tests { /// Test that one client subscribing does not affect another #[tokio::test] async fn test_subscribe_distinct_queries_same_plan() -> anyhow::Result<()> { + let db = relational_db()?; + // Establish a connection for each client - let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1)); - let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2)); + let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1), &db); + let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2), &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let u_id = db.create_table_for_test_with_the_works( @@ -2402,7 +2626,7 @@ mod tests { commit_tx(&db, &subs, [], [(u_id, product![1u64, 0u64, 0u64])])?; assert_tx_update_for_table( - &mut rx_for_a, + rx_for_a.recv(), u_id, &ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]), [product![1u64, 0u64, 0u64]], @@ -2411,7 +2635,7 @@ mod tests { .await; assert_tx_update_for_table( - &mut rx_for_b, + rx_for_b.recv(), u_id, &ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]), [product![1u64, 0u64, 0u64]], @@ -2425,11 +2649,12 @@ mod tests { /// Test that one client unsubscribing does not affect another #[tokio::test] async fn test_unsubscribe_distinct_queries_same_plan() -> anyhow::Result<()> { + let db = relational_db()?; + // Establish a connection for each client - let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1)); - let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2)); + let (tx_for_a, mut rx_for_a) = client_connection(client_id_from_u8(1), &db); + let (tx_for_b, mut rx_for_b) = client_connection(client_id_from_u8(2), &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let u_id = db.create_table_for_test_with_the_works( @@ -2504,7 +2729,7 @@ mod tests { let metrics = commit_tx(&db, &subs, [], [(u_id, product![1u64, 0u64, 0u64])])?; assert_tx_update_for_table( - &mut rx_for_a, + rx_for_a.recv(), u_id, &ProductType::from([AlgebraicType::U64, AlgebraicType::U64, AlgebraicType::U64]), [product![1u64, 0u64, 0u64]], @@ -2536,10 +2761,11 @@ mod tests { /// Needs a multi-threaded tokio runtime so that the module subscription worker can run in parallel. #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn test_query_pruning_for_empty_tables() -> anyhow::Result<()> { + let db = relational_db()?; + // Establish a client connection - let (tx, mut rx) = client_connection(client_id_from_u8(1)); + let (tx, mut rx) = client_connection(client_id_from_u8(1), &db); - let db = relational_db()?; let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); let schema = &[("id", AlgebraicType::U64), ("a", AlgebraicType::U64)]; @@ -2657,4 +2883,64 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_confirmed_reads() -> anyhow::Result<()> { + let (db, durability) = relational_db_with_manual_durability()?; + + let (tx_for_confirmed, mut rx_for_confirmed) = + client_connection_with_confirmed_reads(client_id_from_u8(1), &db, true); + let (tx_for_unconfirmed, mut rx_for_unconfirmed) = + client_connection_with_confirmed_reads(client_id_from_u8(2), &db, false); + + let subs = ModuleSubscriptions::for_test_enclosing_runtime(db.clone()); + let table = db.create_table_for_test("t", &[("x", AlgebraicType::U8)], &[])?; + let schema = ProductType::from([AlgebraicType::U8]); + + // Subscribe both clients. + subscribe_multi(&subs, &["select * from t"], tx_for_confirmed, &mut 0)?; + subscribe_multi(&subs, &["select * from t"], tx_for_unconfirmed, &mut 0)?; + + assert_matches!( + rx_for_unconfirmed.recv().await, + Some(SerializableMessage::Subscription(SubscriptionMessage { + result: SubscriptionResult::SubscribeMulti(_), + .. + })) + ); + assert_after_durable(&durability, async { + assert_matches!( + rx_for_confirmed.recv().await, + Some(SerializableMessage::Subscription(SubscriptionMessage { + result: SubscriptionResult::SubscribeMulti(_), + .. + })) + ); + }) + .await; + + // Insert a row. + let mut tx = begin_mut_tx(&db); + db.insert(&mut tx, table, &bsatn::to_vec(&product![1_u8])?)?; + assert!(matches!( + subs.commit_and_broadcast_event(None, module_event(), tx), + Ok(Ok(_)) + )); + // Insert another row, using SQL. + let auth = AuthCtx::new(identity_from_u8(0), identity_from_u8(0)); + run(&db, "INSERT INTO t (x) VALUES (2)", auth, Some(&subs), &mut vec![])?; + + // Unconfirmed client should have received both rows. + assert_tx_update_for_table(rx_for_unconfirmed.recv(), table, &schema, [product![1_u8]], []).await; + assert_tx_update_for_table(rx_for_unconfirmed.recv(), table, &schema, [product![2_u8]], []).await; + + // Confirmed client should receive the rows after the tx becomes durable. + assert_after_durable(&durability, async { + assert_tx_update_for_table(rx_for_confirmed.recv(), table, &schema, [product![1_u8]], []).await; + assert_tx_update_for_table(rx_for_confirmed.recv(), table, &schema, [product![2_u8]], []).await + }) + .await; + + Ok(()) + } } diff --git a/crates/core/src/subscription/module_subscription_manager.rs b/crates/core/src/subscription/module_subscription_manager.rs index 9820a739dc8..e2fbd3b375a 100644 --- a/crates/core/src/subscription/module_subscription_manager.rs +++ b/crates/core/src/subscription/module_subscription_manager.rs @@ -21,6 +21,7 @@ use spacetimedb_client_api_messages::websocket::{ }; use spacetimedb_data_structures::map::{Entry, IntMap}; use spacetimedb_datastore::locking_tx_datastore::state_view::StateView; +use spacetimedb_durability::TxOffset; use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::{AlgebraicValue, ConnectionId, Identity, ProductValue}; use spacetimedb_primitives::{ColId, IndexId, TableId}; @@ -29,7 +30,7 @@ use std::collections::BTreeMap; use std::fmt::Debug; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot}; /// Clients are uniquely identified by their Identity and ConnectionId. /// Identity is insufficient because different ConnectionIds can use the same Identity. @@ -548,12 +549,38 @@ impl SenderWithGauge { } } +/// The offset used to control visibility of the message if the client has +/// requested confirmed reads. +/// +/// [`SendWorkerMessage`]s are sent while holding the database lock, i.e. +/// without committing the transaction. When the transaction commits, the +/// message sender is expected to send the transaction offset along this channel. +/// +/// NOTE: If the send end is dropped before sending the offset, the +/// [`SendWorker`] will assume that the message sender was cancelled, and exit +/// itself. +pub type TransactionOffset = oneshot::Receiver; + +/// Create a TransactionOffset from a known TxOffset. +pub fn from_tx_offset(offset: TxOffset) -> TransactionOffset { + let (tx, rx) = oneshot::channel(); + let _ = tx.send(offset); + rx +} + + /// Message sent by the [`SubscriptionManager`] to the [`SendWorker`]. #[derive(Debug)] enum SendWorkerMessage { /// A transaction has completed and the [`SubscriptionManager`] has evaluated the incremental queries, /// so the [`SendWorker`] should broadcast them to clients. - Broadcast(ComputedQueries), + /// + /// The `tx_offset` of the transaction is used to control visibility of + /// the results if the client has requested confirmed reads. + Broadcast { + tx_offset: TransactionOffset, + queries: ComputedQueries, + }, /// A new client has been registered in the [`SubscriptionManager`], /// so the [`SendWorker`] should also record its existence. @@ -566,9 +593,14 @@ enum SendWorkerMessage { outbound_ref: Client, }, - // Send a message to a client. + /// Send a message to a client. + /// + /// In some cases, `message` may contain query results. In this case, + /// `tx_offset` is `Some`, and later used to control visibility of the + /// message if the the client has requested confirmed reads. SendMessage { recipient: Arc, + tx_offset: Option, message: SerializableMessage, }, @@ -1100,7 +1132,7 @@ impl SubscriptionManager { #[tracing::instrument(level = "trace", skip_all)] pub fn eval_updates_sequential( &self, - tx: &DeltaTx, + (tx, tx_offset): (&DeltaTx, TransactionOffset), event: Arc, caller: Option>, ) -> ExecutionMetrics { @@ -1266,12 +1298,15 @@ impl SubscriptionManager { // then return ASAP in order to unlock the datastore and start running the next transaction. // See comment on the `send_worker_tx` field in [`SubscriptionManager`] for more motivation. self.send_worker_queue - .send(SendWorkerMessage::Broadcast(ComputedQueries { - updates, - errs, - event, - caller, - })) + .send(SendWorkerMessage::Broadcast { + tx_offset, + queries: ComputedQueries { + updates, + errs, + event, + caller, + }, + }) .expect("send worker has panicked, or otherwise dropped its recv queue!"); drop(span); @@ -1370,10 +1405,12 @@ impl BroadcastQueue { pub fn send_client_message( &self, recipient: Arc, + tx_offset: Option, message: impl Into, ) -> Result<(), BroadcastError> { self.0.send(SendWorkerMessage::SendMessage { recipient, + tx_offset, message: message.into(), })?; Ok(()) @@ -1427,14 +1464,31 @@ impl SendWorker { self.clients .insert(client_id, SendWorkerClient { dropped, outbound_ref }); } - SendWorkerMessage::SendMessage { recipient, message } => { - let _ = recipient.send_message(message); - } + SendWorkerMessage::SendMessage { + recipient, + tx_offset, + message, + } => match tx_offset { + None => { + let _ = recipient.send_message(None, message); + } + Some(tx_offset) => { + let Ok(tx_offset) = tx_offset.await else { + tracing::error!("tx offset sender dropped, exiting send worker"); + return; + }; + let _ = recipient.send_message(Some(tx_offset), message); + } + }, SendWorkerMessage::RemoveClient(client_id) => { self.clients.remove(&client_id); } - SendWorkerMessage::Broadcast(queries) => { - self.send_one_computed_queries(queries); + SendWorkerMessage::Broadcast { tx_offset, queries } => { + let Ok(tx_offset) = tx_offset.await else { + tracing::error!("tx offset sender dropped, exiting send worker"); + return; + }; + self.send_one_computed_queries(tx_offset, queries); } } } @@ -1442,6 +1496,7 @@ impl SendWorker { fn send_one_computed_queries( &mut self, + tx_offset: TxOffset, ComputedQueries { updates, errs, @@ -1527,7 +1582,7 @@ impl SendWorker { event: Some(event.clone()), database_update, }; - send_to_client(&caller, message); + send_to_client(&caller, Some(tx_offset), message); } // Send all the other updates. @@ -1537,7 +1592,7 @@ impl SendWorker { // Conditionally send out a full update or a light one otherwise. let event = client.config.tx_update_full.then(|| event.clone()); let message = TransactionUpdateMessage { event, database_update }; - send_to_client(&client, message); + send_to_client(&client, Some(tx_offset), message); } // Put back the aggregation maps into the worker. @@ -1550,6 +1605,7 @@ impl SendWorker { client.dropped.store(true, Ordering::Release); send_to_client( &client.outbound_ref, + None, SubscriptionMessage { request_id: None, query_id: None, @@ -1565,8 +1621,13 @@ impl SendWorker { } } -fn send_to_client(client: &ClientConnectionSender, message: impl Into) { - if let Err(e) = client.send_message(message) { +fn send_to_client( + client: &ClientConnectionSender, + tx_offset: Option, + message: impl Into, +) { + tracing::trace!(client = %client.id, tx_offset, "send_to_client"); + if let Err(e) = client.send_message(tx_offset, message) { tracing::warn!(%client.id, "failed to send update message to client: {e}") } } @@ -1581,12 +1642,14 @@ mod tests { use spacetimedb_primitives::{ColId, TableId}; use spacetimedb_sats::product; use spacetimedb_subscription::SubscriptionPlan; + use tokio::sync::oneshot; use super::{Plan, SubscriptionManager}; use crate::db::relational_db::tests_utils::with_read_only; use crate::host::module_host::DatabaseTableUpdate; use crate::sql::ast::SchemaViewer; use crate::subscription::module_subscription_manager::ClientQueryId; + use crate::subscription::tx::DeltaTx; use crate::{ client::{ClientActorId, ClientConfig, ClientConnectionSender, ClientName}, db::relational_db::{tests_utils::TestDB, RelationalDB}, @@ -1617,7 +1680,7 @@ mod tests { (Identity::ZERO, ConnectionId::from_u128(connection_id)) } - fn client(connection_id: u128) -> ClientConnectionSender { + fn client(connection_id: u128, db: &RelationalDB) -> ClientConnectionSender { let (identity, connection_id) = id(connection_id); ClientConnectionSender::dummy( ClientActorId { @@ -1626,6 +1689,7 @@ mod tests { name: ClientName(0), }, ClientConfig::for_test(), + db.clone(), ) } @@ -1639,7 +1703,7 @@ mod tests { let hash = plan.hash(); let id = id(0); - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let runtime = tokio::runtime::Runtime::new().unwrap(); let _rt = runtime.enter(); @@ -1663,7 +1727,7 @@ mod tests { let plan = compile_plan(&db, sql)?; let hash = plan.hash(); - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let query_id: ClientQueryId = QueryId::new(1); @@ -1686,7 +1750,7 @@ mod tests { let plan = compile_plan(&db, sql)?; let hash = plan.hash(); - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let query_id: ClientQueryId = QueryId::new(1); @@ -1712,7 +1776,7 @@ mod tests { let sql = "select * from T"; let plan = compile_plan(&db, sql)?; - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let query_id: ClientQueryId = QueryId::new(1); @@ -1737,7 +1801,7 @@ mod tests { let plan = compile_plan(&db, sql)?; let hash = plan.hash(); - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let query_id: ClientQueryId = QueryId::new(1); @@ -1766,7 +1830,7 @@ mod tests { let plan = compile_plan(&db, sql)?; let hash = plan.hash(); - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let query_id: ClientQueryId = QueryId::new(1); @@ -1803,7 +1867,7 @@ mod tests { let plan = compile_plan(&db, sql)?; let hash = plan.hash(); - let clients = (0..3).map(|i| Arc::new(client(i))).collect::>(); + let clients = (0..3).map(|i| Arc::new(client(i, &db))).collect::>(); // All of the clients are using the same query id. let query_id: ClientQueryId = QueryId::new(1); @@ -1844,7 +1908,7 @@ mod tests { let plan = compile_plan(&db, sql)?; let hash = plan.hash(); - let clients = (0..3).map(|i| Arc::new(client(i))).collect::>(); + let clients = (0..3).map(|i| Arc::new(client(i, &db))).collect::>(); // All of the clients are using the same query id. let query_id: ClientQueryId = QueryId::new(1); @@ -1895,7 +1959,7 @@ mod tests { .map(|sql| compile_plan(&db, &sql)) .collect::>>()?; - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let runtime = tokio::runtime::Runtime::new().unwrap(); let _rt = runtime.enter(); @@ -1940,7 +2004,7 @@ mod tests { .map(|sql| compile_plan(&db, &sql)) .collect::>>()?; - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let runtime = tokio::runtime::Runtime::new().unwrap(); let _rt = runtime.enter(); @@ -1984,7 +2048,7 @@ mod tests { let table_id = create_table(&db, "t")?; - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let runtime = tokio::runtime::Runtime::new().unwrap(); let _rt = runtime.enter(); @@ -2057,7 +2121,7 @@ mod tests { let table_id = create_table(&db, "t")?; - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let runtime = tokio::runtime::Runtime::new().unwrap(); let _rt = runtime.enter(); @@ -2127,7 +2191,7 @@ mod tests { let t_id = db.create_table_for_test("t", &schema, &[0.into()])?; let s_id = db.create_table_for_test("s", &schema, &[0.into()])?; - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let runtime = tokio::runtime::Runtime::new().unwrap(); let _rt = runtime.enter(); @@ -2199,7 +2263,7 @@ mod tests { let sql = "select * from T"; let plan = compile_plan(&db, sql)?; - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let query_id: ClientQueryId = QueryId::new(1); @@ -2224,7 +2288,7 @@ mod tests { let sql = "select * from T"; let plan = compile_plan(&db, sql)?; - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let query_id: ClientQueryId = QueryId::new(1); @@ -2252,7 +2316,7 @@ mod tests { let hash = plan.hash(); let id = id(0); - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let runtime = tokio::runtime::Runtime::new().unwrap(); let _rt = runtime.enter(); @@ -2278,7 +2342,7 @@ mod tests { let hash = plan.hash(); let id = id(0); - let client = Arc::new(client(0)); + let client = Arc::new(client(0, &db)); let runtime = tokio::runtime::Runtime::new().unwrap(); let _rt = runtime.enter(); @@ -2310,10 +2374,10 @@ mod tests { let hash = plan.hash(); let id0 = id(0); - let client0 = Arc::new(client(0)); + let client0 = Arc::new(client(0, &db)); let id1 = id(1); - let client1 = Arc::new(client(1)); + let client1 = Arc::new(client(1, &db)); let runtime = tokio::runtime::Runtime::new().unwrap(); let _rt = runtime.enter(); @@ -2358,10 +2422,10 @@ mod tests { let hash_select1 = plan_select1.hash(); let id0 = id(0); - let client0 = Arc::new(client(0)); + let client0 = Arc::new(client(0, &db)); let id1 = id(1); - let client1 = Arc::new(client(1)); + let client1 = Arc::new(client(1, &db)); let runtime = tokio::runtime::Runtime::new().unwrap(); let _rt = runtime.enter(); @@ -2420,7 +2484,7 @@ mod tests { let id0 = Identity::ZERO; let client0 = ClientActorId::for_test(id0); let config = ClientConfig::for_test(); - let (client0, mut rx) = ClientConnectionSender::dummy_with_channel(client0, config); + let (client0, mut rx) = ClientConnectionSender::dummy_with_channel(client0, config, (*db).clone()); let runtime = tokio::runtime::Runtime::new().unwrap(); let _rt = runtime.enter(); @@ -2442,9 +2506,20 @@ mod tests { timer: None, }); - db.with_read_only(Workload::Update, |tx| { - subscriptions.eval_updates_sequential(&(&*tx).into(), event, Some(Arc::new(client0))) - }); + // This block ensures that the transaction is released before waiting + // for a message to appear on `rx`. + // The message won't be sent until the transaction offset is known, + // and it is known when the transaction commits. + { + let (offset_tx, offset_rx) = oneshot::channel(); + let tx = scopeguard::guard(db.begin_tx(Workload::Update), |tx| { + let (tx_offset, tx_metrics, reducer) = db.release_tx(tx); + let _ = offset_tx.send(tx_offset); + db.report_read_tx_metrics(reducer, tx_metrics); + }); + let delta_tx = DeltaTx::from(&*tx); + subscriptions.eval_updates_sequential((&delta_tx, offset_rx), event, Some(Arc::new(client0))); + } runtime.block_on(async move { tokio::time::timeout(Duration::from_millis(20), async move { diff --git a/crates/datastore/Cargo.toml b/crates/datastore/Cargo.toml index 0a46d472df6..af77913e067 100644 --- a/crates/datastore/Cargo.toml +++ b/crates/datastore/Cargo.toml @@ -50,4 +50,4 @@ spacetimedb-commitlog = { path = "../commitlog", features = ["test"] } # Also as dev-dependencies for use in _this_ crate's tests. proptest.workspace = true -pretty_assertions.workspace = true \ No newline at end of file +pretty_assertions.workspace = true diff --git a/crates/datastore/src/locking_tx_datastore/committed_state.rs b/crates/datastore/src/locking_tx_datastore/committed_state.rs index 5fc4818988d..f30417d85ee 100644 --- a/crates/datastore/src/locking_tx_datastore/committed_state.rs +++ b/crates/datastore/src/locking_tx_datastore/committed_state.rs @@ -25,6 +25,7 @@ use crate::{ use anyhow::anyhow; use core::{convert::Infallible, ops::RangeBounds}; use spacetimedb_data_structures::map::{HashSet, IntMap}; +use spacetimedb_durability::TxOffset; use spacetimedb_lib::{db::auth::StTableType, Identity}; use spacetimedb_primitives::{ColId, ColList, ColSet, IndexId, TableId}; use spacetimedb_sats::{algebraic_value::de::ValueDeserializer, memory_usage::MemoryUsage, Deserialize}; @@ -654,12 +655,13 @@ impl CommittedState { } /// Rolls back the changes immediately made to the committed state during a transaction. - pub(super) fn rollback(&mut self, seq_state: &mut SequencesState, tx_state: TxState) { + pub(super) fn rollback(&mut self, seq_state: &mut SequencesState, tx_state: TxState) -> TxOffset { // Roll back the changes in the reverse order in which they were made // so that e.g., the last change is undone first. for change in tx_state.pending_schema_changes.into_iter().rev() { self.rollback_pending_schema_change(seq_state, change); } + self.next_tx_offset.saturating_sub(1) } fn rollback_pending_schema_change( diff --git a/crates/datastore/src/locking_tx_datastore/datastore.rs b/crates/datastore/src/locking_tx_datastore/datastore.rs index 8dd2c897c65..029b9508360 100644 --- a/crates/datastore/src/locking_tx_datastore/datastore.rs +++ b/crates/datastore/src/locking_tx_datastore/datastore.rs @@ -372,9 +372,10 @@ impl Tx for Locking { /// allowing new mutable transactions to start if this was the last read-only transaction. /// /// Returns: + /// - [`TxOffset`], the smallest transaction offset visible to this transaction. /// - [`TxMetrics`], various measurements of the work performed by this transaction. /// - `String`, the name of the reducer which ran within this transaction. - fn release_tx(&self, tx: Self::Tx) -> (TxMetrics, String) { + fn release_tx(&self, tx: Self::Tx) -> (TxOffset, TxMetrics, String) { tx.release() } } @@ -887,11 +888,11 @@ impl MutTx for Locking { } } - fn rollback_mut_tx(&self, tx: Self::MutTx) -> (TxMetrics, String) { + fn rollback_mut_tx(&self, tx: Self::MutTx) -> (TxOffset, TxMetrics, String) { tx.rollback() } - fn commit_mut_tx(&self, tx: Self::MutTx) -> Result> { + fn commit_mut_tx(&self, tx: Self::MutTx) -> Result> { Ok(Some(tx.commit())) } } @@ -931,7 +932,7 @@ pub struct Replay { } impl Replay { - fn using_visitor(&self, f: impl FnOnce(&mut ReplayVisitor) -> T) -> T { + fn using_visitor(&self, f: impl FnOnce(&mut ReplayVisitor<'_, F>) -> T) -> T { let mut committed_state = self.committed_state.write_arc(); let mut visitor = ReplayVisitor { database_identity: &self.database_identity, @@ -1443,7 +1444,8 @@ mod tests { } fn commit(datastore: &Locking, tx: MutTxId) -> ResultTest { - Ok(datastore.commit_mut_tx(tx)?.expect("commit should produce `TxData`").0) + let (_, tx_data, _, _) = datastore.commit_mut_tx(tx)?.expect("commit should produce `TxData`"); + Ok(tx_data) } #[rustfmt::skip] diff --git a/crates/datastore/src/locking_tx_datastore/mut_tx.rs b/crates/datastore/src/locking_tx_datastore/mut_tx.rs index 94780a685c2..7645fed23f6 100644 --- a/crates/datastore/src/locking_tx_datastore/mut_tx.rs +++ b/crates/datastore/src/locking_tx_datastore/mut_tx.rs @@ -29,6 +29,7 @@ use core::ops::RangeBounds; use core::{cell::RefCell, mem}; use core::{iter, ops::Bound}; use smallvec::SmallVec; +use spacetimedb_durability::TxOffset; use spacetimedb_execution::{dml::MutDatastore, Datastore, DeltaStore, Row}; use spacetimedb_lib::{db::raw_def::v9::RawSql, metrics::ExecutionMetrics}; use spacetimedb_lib::{ @@ -1195,7 +1196,9 @@ impl MutTxId { /// - [`TxData`], the set of inserts and deletes performed by this transaction. /// - [`TxMetrics`], various measurements of the work performed by this transaction. /// - `String`, the name of the reducer which ran during this transaction. - pub(super) fn commit(mut self) -> (TxData, TxMetrics, String) { + pub(super) fn commit(mut self) -> (TxOffset, TxData, TxMetrics, String) { + self.committed_state_write_lock.next_tx_offset += 1; + let tx_offset = self.committed_state_write_lock.next_tx_offset; let tx_data = self.committed_state_write_lock.merge(self.tx_state, &self.ctx); // Compute and keep enough info that we can @@ -1213,7 +1216,20 @@ impl MutTxId { ); let reducer = self.ctx.into_reducer_name(); - (tx_data, tx_metrics, reducer) + // If the transaction didn't consume an offset (i.e. it was empty), + // report the previous offset. + // + // Note that technically the tx could have run against an empty database, + // in which case we'd wrongly return zero (a non-existent transaction). + // This doesn not happen in practice, however, as [RelationalDB::set_initialized] + // creates a transaction. + let tx_offset = if tx_offset == self.committed_state_write_lock.next_tx_offset { + tx_offset.saturating_sub(1) + } else { + tx_offset + }; + + (tx_offset, tx_data, tx_metrics, reducer) } /// Commits this transaction, applying its changes to the committed state. @@ -1258,8 +1274,9 @@ impl MutTxId { /// Returns: /// - [`TxMetrics`], various measurements of the work performed by this transaction. /// - `String`, the name of the reducer which ran during this transaction. - pub fn rollback(mut self) -> (TxMetrics, String) { - self.committed_state_write_lock + pub fn rollback(mut self) -> (TxOffset, TxMetrics, String) { + let offset = self + .committed_state_write_lock .rollback(&mut self.sequence_state_lock, self.tx_state); // Compute and keep enough info that we can @@ -1276,7 +1293,7 @@ impl MutTxId { &self.committed_state_write_lock, ); let reducer = self.ctx.into_reducer_name(); - (tx_metrics, reducer) + (offset, tx_metrics, reducer) } /// Roll back this transaction, discarding its changes. diff --git a/crates/datastore/src/locking_tx_datastore/tx.rs b/crates/datastore/src/locking_tx_datastore/tx.rs index 1fadd3e94be..d28e14c1806 100644 --- a/crates/datastore/src/locking_tx_datastore/tx.rs +++ b/crates/datastore/src/locking_tx_datastore/tx.rs @@ -6,6 +6,7 @@ use super::{ }; use crate::execution_context::ExecutionContext; use crate::locking_tx_datastore::state_view::IterTx; +use spacetimedb_durability::TxOffset; use spacetimedb_execution::Datastore; use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_primitives::{ColList, TableId}; @@ -13,8 +14,8 @@ use spacetimedb_sats::AlgebraicValue; use spacetimedb_schema::schema::TableSchema; use spacetimedb_table::blob_store::BlobStore; use spacetimedb_table::table::Table; -use std::num::NonZeroU64; use std::sync::Arc; +use std::{future, num::NonZeroU64}; use std::{ ops::RangeBounds, time::{Duration, Instant}, @@ -89,9 +90,18 @@ impl TxId { /// allowing new mutable transactions to start if this was the last read-only transaction. /// /// Returns: + /// - [`TxOffset`], the smallest transaction offset visible to this transaction. /// - [`TxMetrics`], various measurements of the work performed by this transaction. /// - `String`, the name of the reducer which ran within this transaction. - pub(super) fn release(self) -> (TxMetrics, String) { + pub(super) fn release(self) -> (TxOffset, TxMetrics, String) { + // A read tx doesn't consume `next_tx_offset`, so subtract one to obtain + // the offset that was visible to the transaction. + // + // Note that technically the tx could have run against an empty database, + // in which case we'd wrongly return zero (a non-existent transaction). + // This doesn not happen in practice, however, as [RelationalDB::set_initialized] + // creates a transaction. + let tx_offset = self.committed_state_shared_lock.next_tx_offset.saturating_sub(1); let tx_metrics = TxMetrics::new( &self.ctx, self.timer, @@ -102,7 +112,7 @@ impl TxId { &self.committed_state_shared_lock, ); let reducer = self.ctx.into_reducer_name(); - (tx_metrics, reducer) + (tx_offset, tx_metrics, reducer) } /// The Number of Distinct Values (NDV) for a column or list of columns, @@ -120,4 +130,8 @@ impl TxId { let (_, index) = table.get_index_by_cols(cols)?; NonZeroU64::new(index.num_keys() as u64) } + + pub fn tx_offset(&self) -> future::Ready { + future::ready(self.committed_state_shared_lock.next_tx_offset) + } } diff --git a/crates/datastore/src/traits.rs b/crates/datastore/src/traits.rs index 9626f992bae..d9625226b5f 100644 --- a/crates/datastore/src/traits.rs +++ b/crates/datastore/src/traits.rs @@ -9,6 +9,7 @@ use super::Result; use crate::execution_context::{ReducerContext, Workload}; use crate::system_tables::ST_TABLE_ID; use spacetimedb_data_structures::map::IntMap; +use spacetimedb_durability::TxOffset; use spacetimedb_lib::{hash_bytes, Identity}; use spacetimedb_primitives::*; use spacetimedb_sats::hash::Hash; @@ -177,13 +178,13 @@ pub struct TxData { deletes: BTreeMap>, /// Map of all `TableId`s in both `inserts` and `deletes` to their /// corresponding table name. + // TODO: Store table name as ref counted string. tables: IntMap, /// Tx offset of the transaction which performed these operations. /// /// `None` implies that `inserts` and `deletes` are both empty, /// but `Some` does not necessarily imply that either is non-empty. tx_offset: Option, - // TODO: Store an `Arc` or equivalent instead. } impl TxData { @@ -327,9 +328,19 @@ pub trait Tx { /// Release this read-only transaction. /// /// Returns: + /// - [`TxOffset`], the smallest transaction offset visible to this transaction. + /// + /// Note that, if the transaction was running under an isolation level + /// weaker than [`IsolationLevel::Snapshot`], it may have observed + /// transactions at a later offset than when it started. + /// + /// Implementations must uphold that the returned transaction offset + /// accounts for such read anomalies, i.e. the offset must include the + /// observed transactions. + /// /// - [`TxMetrics`], various measurements of the work performed by this transaction. /// - `String`, the name of the reducer which ran within this transaction. - fn release_tx(&self, tx: Self::Tx) -> (TxMetrics, String); + fn release_tx(&self, tx: Self::Tx) -> (TxOffset, TxMetrics, String); } pub trait MutTx { @@ -341,17 +352,27 @@ pub trait MutTx { /// Commits `tx`, applying its changes to the committed state. /// /// Returns: + /// - [`TxOffset`], the offset this transaction was committed at. + /// + /// Note that, if the transaction was running under an isolation level + /// weaker than [`IsolationLevel::Snapshot`], it may have observed + /// transactions at a later offset than when it started. + /// + /// Implementations must uphold that the returned transaction offset + /// accounts for such read anomalies, i.e. the offset must include the + /// observed transactions. + /// /// - [`TxData`], the set of inserts and deletes performed by this transaction. /// - [`TxMetrics`], various measurements of the work performed by this transaction. /// - `String`, the name of the reducer which ran during this transaction. - fn commit_mut_tx(&self, tx: Self::MutTx) -> Result>; + fn commit_mut_tx(&self, tx: Self::MutTx) -> Result>; /// Rolls back this transaction, discarding its changes. /// /// Returns: /// - [`TxMetrics`], various measurements of the work performed by this transaction. /// - `String`, the name of the reducer which ran within this transaction. - fn rollback_mut_tx(&self, tx: Self::MutTx) -> (TxMetrics, String); + fn rollback_mut_tx(&self, tx: Self::MutTx) -> (TxOffset, TxMetrics, String); } /// Standard metadata associated with a database. diff --git a/crates/durability/Cargo.toml b/crates/durability/Cargo.toml index 34816a05437..3fc9eb19e6d 100644 --- a/crates/durability/Cargo.toml +++ b/crates/durability/Cargo.toml @@ -14,6 +14,7 @@ log.workspace = true spacetimedb-commitlog.workspace = true spacetimedb-paths.workspace = true spacetimedb-sats.workspace = true +thiserror.workspace = true tokio.workspace = true tracing.workspace = true diff --git a/crates/durability/src/imp/local.rs b/crates/durability/src/imp/local.rs index 4a5c07feec5..1e2862aca92 100644 --- a/crates/durability/src/imp/local.rs +++ b/crates/durability/src/imp/local.rs @@ -3,10 +3,7 @@ use std::{ num::NonZeroU16, panic, sync::{ - atomic::{ - AtomicI64, AtomicU64, - Ordering::{Acquire, Relaxed, Release}, - }, + atomic::{AtomicU64, Ordering::Relaxed}, Arc, Weak, }, time::Duration, @@ -18,13 +15,13 @@ use log::{info, trace, warn}; use spacetimedb_commitlog::{error, payload::Txdata, Commit, Commitlog, Decoder, Encode, Transaction}; use spacetimedb_paths::server::CommitLogDir; use tokio::{ - sync::mpsc, + sync::{mpsc, watch}, task::{spawn_blocking, AbortHandle, JoinHandle}, time::{interval, MissedTickBehavior}, }; use tracing::instrument; -use crate::{Durability, History, TxOffset}; +use crate::{Durability, DurableOffset, History, TxOffset}; /// [`Local`] configuration. #[derive(Clone, Copy, Debug)] @@ -62,17 +59,7 @@ pub struct Local { clog: Arc>>, /// The durable transaction offset, as reported by the background /// [`FlushAndSyncTask`]. - /// - /// A negative number indicates that we haven't flushed yet, or that the - /// number overflowed. In either case, appending new transactions shall panic. - /// - /// The offset will be used by the datastore to squash durable transactions - /// into the committed state, thereby making them visible to durable-only - /// readers. - /// - /// We don't want to hang on to those transactions longer than needed, so - /// acquire / release or stronger should be used to prevent stale reads. - durable_offset: Arc, + durable_offset: watch::Receiver>, /// Backlog of transactions to be written to disk by the background /// [`PersisterTask`]. /// @@ -100,10 +87,7 @@ impl Local { let clog = Arc::new(Commitlog::open(root, opts.commitlog)?); let (queue, rx) = mpsc::unbounded_channel(); let queue_depth = Arc::new(AtomicU64::new(0)); - let offset = { - let offset = clog.max_committed_offset().map(|x| x as i64).unwrap_or(-1); - Arc::new(AtomicI64::new(offset)) - }; + let (durable_tx, durable_rx) = watch::channel(clog.max_committed_offset()); let persister_task = rt.spawn( PersisterTask { @@ -118,7 +102,7 @@ impl Local { FlushAndSyncTask { clog: Arc::downgrade(&clog), period: opts.sync_interval, - offset: offset.clone(), + offset: durable_tx, abort: persister_task.abort_handle(), } .run(), @@ -126,7 +110,7 @@ impl Local { Ok(Self { clog, - durable_offset: offset, + durable_offset: durable_rx, queue, queue_depth, persister_task, @@ -256,7 +240,7 @@ fn flush_error(e: io::Error) { struct FlushAndSyncTask { clog: Weak>>, period: Duration, - offset: Arc, + offset: watch::Sender>, /// Handle to abort the [`PersisterTask`] if fsync panics. abort: AbortHandle, } @@ -277,8 +261,7 @@ impl FlushAndSyncTask { }; // Skip if nothing changed. if let Some(committed) = clog.max_committed_offset() { - let durable = self.offset.load(Acquire); - if durable.is_positive() && committed == durable as _ { + if self.offset.borrow().is_some_and(|durable| durable == committed) { continue; } } @@ -297,8 +280,9 @@ impl FlushAndSyncTask { } Ok(Ok(Some(new_offset))) => { trace!("synced to offset {new_offset}"); - // NOTE: Overflow will make `durable_tx_offset` return `None` - self.offset.store(new_offset as i64, Release); + self.offset.send_modify(|val| { + val.replace(new_offset); + }); } // No data to flush. Ok(Ok(None)) => {} @@ -317,9 +301,8 @@ impl Durability for Local { self.queue_depth.fetch_add(1, Relaxed); } - fn durable_tx_offset(&self) -> Option { - let offset = self.durable_offset.load(Acquire); - (offset > -1).then_some(offset as u64) + fn durable_tx_offset(&self) -> DurableOffset { + self.durable_offset.clone().into() } } diff --git a/crates/durability/src/lib.rs b/crates/durability/src/lib.rs index 1ab5fac1bd4..ddda60423a0 100644 --- a/crates/durability/src/lib.rs +++ b/crates/durability/src/lib.rs @@ -1,5 +1,8 @@ use std::{iter, marker::PhantomData, sync::Arc}; +use thiserror::Error; +use tokio::sync::watch; + pub use spacetimedb_commitlog::{error, payload::Txdata, Decoder, Transaction}; mod imp; @@ -15,6 +18,71 @@ pub use imp::{local, Local}; /// of all offsets smaller than it. pub type TxOffset = u64; +#[derive(Debug, Error)] +#[error("the database's durability layer went away")] +pub struct DurabilityExited; + +/// Handle to the durable offset, obtained via [`Durability::durable_tx_offset`]. +/// +/// The handle can be used to read the current durable offset, or wait for a +/// provided offset to be reached. +/// +/// The handle is valid for as long as the [`Durability`] instance it was +/// obtained from is live, i.e. able to persist transactions. When the instance +/// shuts down or crashes, methods will return errors of type [`DurabilityExited`]. +pub struct DurableOffset { + // TODO: `watch::Receiver::wait_for` will hold a shared lock until all + // subscribers have seen the current value. Although it may skip entries, + // this may cause unacceptable contention. We may consider a custom watch + // channel that operates on an `AtomicU64` instead of an `RwLock`. + inner: watch::Receiver>, +} + +impl DurableOffset { + /// Get the current durable offset, or `None` if no transaction has been + /// made durable yet. + /// + /// Returns `Err` if the associated durablity is no longer live. + pub fn get(&self) -> Result, DurabilityExited> { + self.guard_closed().map(|()| self.inner.borrow().as_ref().copied()) + } + + /// Get the current durable offset, even if the associated durability is + /// no longer live. + pub fn last_seen(&self) -> Option { + self.inner.borrow().as_ref().copied() + } + + /// Wait for `offset` to become durable, i.e. + /// + /// ```ignore + /// self.get().unwrap().is_some_and(|durable| durable >= offset) + /// ``` + /// + /// Returns the actual durable offset at which above condition evaluated to + /// `true`, or an `Err` if the durability is no longer live. + /// + /// Returns immediately if the condition evaluates to `true` for the current + /// durable offset. + pub async fn wait_for(&mut self, offset: TxOffset) -> Result { + self.inner + .wait_for(|durable| durable.is_some_and(|val| val >= offset)) + .await + .map(|r| r.as_ref().copied().unwrap()) + .map_err(|_| DurabilityExited) + } + + fn guard_closed(&self) -> Result<(), DurabilityExited> { + self.inner.has_changed().map(drop).map_err(|_| DurabilityExited) + } +} + +impl From>> for DurableOffset { + fn from(inner: watch::Receiver>) -> Self { + Self { inner } + } +} + /// The durability API. /// /// NOTE: This is a preliminary definition, still under consideration. @@ -41,7 +109,7 @@ pub trait Durability: Send + Sync { /// A `None` return value indicates that the durable offset is not known, /// either because nothing has been persisted yet, or because the status /// cannot be retrieved. - fn durable_tx_offset(&self) -> Option; + fn durable_tx_offset(&self) -> DurableOffset; } /// Access to the durable history. diff --git a/crates/execution/src/lib.rs b/crates/execution/src/lib.rs index 8fb8e50e59b..ec083dbb4f5 100644 --- a/crates/execution/src/lib.rs +++ b/crates/execution/src/lib.rs @@ -255,7 +255,7 @@ impl<'a> Iterator for DeltaScanIter<'a> { /// Execute a query plan. /// The actual execution is driven by `f`. -pub fn execute_plan(plan: &ProjectPlan, tx: &T, f: impl Fn(PlanIter) -> R) -> Result +pub fn execute_plan(plan: &ProjectPlan, tx: &T, f: impl Fn(PlanIter<'_>) -> R) -> Result where T: Datastore + DeltaStore, { diff --git a/crates/pg/Cargo.toml b/crates/pg/Cargo.toml new file mode 100644 index 00000000000..dd49122dea0 --- /dev/null +++ b/crates/pg/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "spacetimedb-pg" +version.workspace = true +edition.workspace = true +rust-version.workspace = true +license-file = "LICENSE" +description = "Postgres wire protocol Server support for SpacetimeDB" + +[dependencies] +spacetimedb-client-api-messages.workspace = true +spacetimedb-client-api.workspace = true +spacetimedb-lib.workspace = true + +anyhow.workspace = true +async-trait.workspace = true +axum.workspace = true +futures.workspace = true +http.workspace = true +log.workspace = true +pgwire.workspace = true +thiserror.workspace = true +tokio.workspace = true diff --git a/crates/pg/LICENSE b/crates/pg/LICENSE new file mode 120000 index 00000000000..8540cf8a991 --- /dev/null +++ b/crates/pg/LICENSE @@ -0,0 +1 @@ +../../licenses/BSL.txt \ No newline at end of file diff --git a/crates/pg/README.md b/crates/pg/README.md new file mode 100644 index 00000000000..fc5b684dd86 --- /dev/null +++ b/crates/pg/README.md @@ -0,0 +1,3 @@ +> ⚠️ **Internal Crate** ⚠️ +> +> This crate is intended for internal use only. It is **not** stable and may change without notice. diff --git a/crates/pg/src/encoder.rs b/crates/pg/src/encoder.rs new file mode 100644 index 00000000000..f5a6ed990ed --- /dev/null +++ b/crates/pg/src/encoder.rs @@ -0,0 +1,301 @@ +use crate::pg_server::PgError; +use pgwire::api::portal::Format; +use pgwire::api::results::{DataRowEncoder, FieldInfo}; +use pgwire::api::Type; +use spacetimedb_lib::sats::satn::{PsqlChars, PsqlPrintFmt, PsqlType, TypedWriter}; +use spacetimedb_lib::sats::{satn, ValueWithType}; +use spacetimedb_lib::{ + ser, AlgebraicType, AlgebraicValue, ProductType, ProductTypeElement, ProductValue, TimeDuration, Timestamp, +}; +use std::borrow::Cow; +use std::sync::Arc; + +pub(crate) fn row_desc(schema: &ProductType, format: &Format) -> Arc> { + Arc::new( + schema + .elements + .iter() + .enumerate() + .map(|(pos, ty)| { + let field_name = ty.name.clone().map(Into::into).unwrap_or_else(|| format!("col_{pos}")); + let field_type = type_of(schema, ty); + FieldInfo::new(field_name, None, None, field_type, format.format_for(pos)) + }) + .collect(), + ) +} + +pub(crate) fn type_of(schema: &ProductType, ty: &ProductTypeElement) -> Type { + let format = PsqlPrintFmt::use_fmt(schema, ty, ty.name()); + match &ty.algebraic_type { + AlgebraicType::String => Type::VARCHAR, + AlgebraicType::Bool => Type::BOOL, + AlgebraicType::U8 | AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2, + AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4, + AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8, + AlgebraicType::U64 | AlgebraicType::I128 | AlgebraicType::U128 | AlgebraicType::I256 | AlgebraicType::U256 => { + Type::NUMERIC + } + AlgebraicType::F32 => Type::FLOAT4, + AlgebraicType::F64 => Type::FLOAT8, + AlgebraicType::Array(ty) => match *ty.elem_ty { + AlgebraicType::String => Type::VARCHAR_ARRAY, + AlgebraicType::Bool => Type::BOOL_ARRAY, + AlgebraicType::U8 => Type::BYTEA, + AlgebraicType::I8 | AlgebraicType::I16 => Type::INT2_ARRAY, + AlgebraicType::U16 | AlgebraicType::I32 => Type::INT4_ARRAY, + AlgebraicType::U32 | AlgebraicType::I64 => Type::INT8_ARRAY, + AlgebraicType::U64 + | AlgebraicType::I128 + | AlgebraicType::U128 + | AlgebraicType::I256 + | AlgebraicType::U256 => Type::NUMERIC_ARRAY, + _ => Type::ANYARRAY, + }, + AlgebraicType::Product(_) => match format { + PsqlPrintFmt::Hex => Type::BYTEA_ARRAY, + PsqlPrintFmt::Timestamp => Type::TIMESTAMP, + PsqlPrintFmt::Duration => Type::INTERVAL, + _ => Type::JSON, + }, + AlgebraicType::Sum(sum) if sum.is_simple_enum() => Type::ANYENUM, + AlgebraicType::Sum(_) => Type::JSON, + _ => Type::UNKNOWN, + } +} + +impl ser::Error for PgError { + fn custom(msg: T) -> Self { + PgError::Other(anyhow::anyhow!(msg.to_string())) + } +} + +pub(crate) struct PsqlFormatter<'a> { + pub(crate) encoder: &'a mut DataRowEncoder, +} + +impl TypedWriter for PsqlFormatter<'_> { + type Error = PgError; + + fn write(&mut self, value: W) -> Result<(), Self::Error> { + self.encoder.encode_field(&value.to_string())?; + Ok(()) + } + + fn write_bool(&mut self, value: bool) -> Result<(), Self::Error> { + self.encoder.encode_field(&value)?; + Ok(()) + } + + fn write_string(&mut self, value: &str) -> Result<(), Self::Error> { + self.encoder.encode_field(&value)?; + Ok(()) + } + + fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> { + self.encoder.encode_field(&value)?; + Ok(()) + } + + fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error> { + self.encoder.encode_field(&value)?; + Ok(()) + } + + fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error> { + self.encoder.encode_field(&value.to_rfc3339()?)?; + Ok(()) + } + + fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error> { + self.encoder.encode_field(&value.to_iso8601())?; + Ok(()) + } + + fn write_alt_record( + &mut self, + ty: &PsqlType, + value: &ValueWithType<'_, ProductValue>, + ) -> Result { + let json = satn::PsqlWrapper { ty: ty.clone(), value }.to_string(); + self.encoder.encode_field(&json)?; + Ok(true) + } + + fn write_record( + &mut self, + _fields: Vec<(Cow, PsqlType, ValueWithType)>, + ) -> Result<(), Self::Error> { + unreachable!("Use `write_alt_record` for records in PSQL format"); + } + + fn write_variant( + &mut self, + tag: u8, + ty: PsqlType, + name: Option<&str>, + value: ValueWithType, + ) -> Result<(), Self::Error> { + // Is a simple enum? + if let AlgebraicType::Sum(sum) = &ty.field.algebraic_type { + if sum.is_simple_enum() { + if let Some(variant_name) = name { + self.encoder.encode_field(&variant_name)?; + return Ok(()); + } + } + } + + let PsqlChars { start, sep, end, quote } = ty.client.format_chars(); + let name = name.map(Cow::from).unwrap_or_else(|| Cow::from(tag.to_string())); + let json = format!( + "{start}{quote}{name}{quote}{sep} {}{end}", + satn::PsqlWrapper { ty, value } + ); + self.encoder.encode_field(&json)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::pg_server::to_rows; + use futures::StreamExt; + use spacetimedb_client_api_messages::http::SqlStmtResult; + use spacetimedb_lib::sats::algebraic_value::Packed; + use spacetimedb_lib::sats::{i256, product, u256, AlgebraicType, ProductType, SumTypeVariant}; + use spacetimedb_lib::{ConnectionId, Identity}; + + async fn run(schema: ProductType, row: ProductValue) -> String { + let header = row_desc(&schema, &Format::UnifiedText); + + let stmt = SqlStmtResult { + schema, + rows: vec![row], + total_duration_micros: 0, + stats: Default::default(), + }; + let mut stream = to_rows(stmt, header).unwrap(); + let mut result = String::new(); + if let Some(row) = stream.next().await { + result = String::from_utf8_lossy(row.unwrap().data.freeze().as_ref()).to_string(); + } + result + } + + #[tokio::test] + async fn test_primitives() { + let schema = ProductType::from([ + AlgebraicType::U8, + AlgebraicType::I8, + AlgebraicType::I16, + AlgebraicType::U16, + AlgebraicType::I32, + AlgebraicType::U32, + AlgebraicType::I64, + AlgebraicType::U64, + AlgebraicType::I128, + AlgebraicType::U128, + AlgebraicType::I256, + AlgebraicType::U256, + AlgebraicType::F32, + AlgebraicType::F64, + AlgebraicType::String, + AlgebraicType::Bool, + ]); + let value = product![ + 1u8, + -1i8, + -2i16, + 3u16, + -4i32, + 5u32, + -6i64, + 7u64, + Packed::from(-8i128), + Packed::from(9u128), + i256::from(-10), + u256::from(11u128), + 12.34f32, + 56.78f64, + "test".to_string(), + true, + ]; + + let row = run(schema, value).await; + assert_eq!(row, "\0\0\0\u{1}1\0\0\0\u{2}-1\0\0\0\u{2}-2\0\0\0\u{1}3\0\0\0\u{2}-4\0\0\0\u{1}5\0\0\0\u{2}-6\0\0\0\u{1}7\0\0\0\u{2}-8\0\0\0\u{1}9\0\0\0\u{3}-10\0\0\0\u{2}11\0\0\0\u{5}12.34\0\0\0\u{5}56.78\0\0\0\u{4}test\0\0\0\u{1}t"); + } + + #[tokio::test] + async fn test_enum() { + let some = AlgebraicType::option(AlgebraicType::I64); + let schema = ProductType::from([some.clone(), some]); + let value = product![ + AlgebraicValue::sum(0, AlgebraicValue::I64(1)), // Some(1) + AlgebraicValue::sum(1, AlgebraicValue::unit()), // None + ]; + + let row = run(schema, value).await; + assert_eq!(row, "\0\0\0\u{b}{\"some\": 1}\0\0\0\u{c}{\"none\": {}}"); + + let color = AlgebraicType::Sum([SumTypeVariant::new_named(AlgebraicType::I64, "Gray")].into()); + let nested = AlgebraicType::option(color.clone()); + let schema = ProductType::from([color, nested]); + // {"Gray": 1}, {"some": {"Gray": 2}} + let value = product![ + AlgebraicValue::sum(0, AlgebraicValue::I64(1)), // Gray(1) + AlgebraicValue::sum(0, AlgebraicValue::sum(0, AlgebraicValue::I64(2))), // Some(Gray(2)) + ]; + let row = run(schema.clone(), value.clone()).await; + assert_eq!(row, "\0\0\0\u{b}{\"Gray\": 1}\0\0\0\u{15}{\"some\": {\"Gray\": 2}}"); + + // Now nested product + let product = AlgebraicType::product([ + ProductTypeElement::new(AlgebraicType::Product(schema), Some("x".into())), + ProductTypeElement::new(AlgebraicType::String, Some("y".into())), + ]); + let schema = ProductType::from([product.clone()]); + let value = product![AlgebraicValue::product(vec![ + value.into(), + AlgebraicValue::String("a".into()), + ])]; + let row = run(schema, value).await; + assert_eq!( + row, + "\0\0\0G{\"x\": {\"col_0\": {\"Gray\": 1}, \"col_1\": {\"some\": {\"Gray\": 2}}}, \"y\": \"a\"}" + ); + + // Now a simple enum + let names = AlgebraicType::simple_enum(["A", "B", "C"].into_iter()); + let schema = ProductType::from([names.clone(), names.clone(), names]); + let value = product![ + AlgebraicValue::enum_simple(0), // A + AlgebraicValue::enum_simple(1), // B + AlgebraicValue::enum_simple(2), // C + ]; + let row = run(schema, value).await; + assert_eq!(row, "\0\0\0\u{1}A\0\0\0\u{1}B\0\0\0\u{1}C"); + } + + #[tokio::test] + async fn test_special_types() { + let schema = ProductType::from([ + AlgebraicType::identity(), + AlgebraicType::connection_id(), + AlgebraicType::time_duration(), + AlgebraicType::timestamp(), + AlgebraicType::bytes(), + ]); + let value = product![ + Identity::ZERO, + ConnectionId::ZERO, + TimeDuration::from_micros(0), + Timestamp::from_micros_since_unix_epoch(1622545800000), + AlgebraicValue::Bytes("test".as_bytes().into()), + ]; + + let row = run(schema, value).await; + assert_eq!(row, "\0\0\0B\\x0000000000000000000000000000000000000000000000000000000000000000\0\0\0\"\\x00000000000000000000000000000000\0\0\0\u{3}P0D\0\0\0\u{1d}1970-01-19T18:42:25.800+00:00\0\0\0\n\\x74657374"); + } +} diff --git a/crates/pg/src/lib.rs b/crates/pg/src/lib.rs new file mode 100644 index 00000000000..c4466bbc50d --- /dev/null +++ b/crates/pg/src/lib.rs @@ -0,0 +1,2 @@ +mod encoder; +pub mod pg_server; diff --git a/crates/pg/src/pg_server.rs b/crates/pg/src/pg_server.rs new file mode 100644 index 00000000000..39b0fcacacd --- /dev/null +++ b/crates/pg/src/pg_server.rs @@ -0,0 +1,381 @@ +use std::fmt::Debug; +use std::sync::Arc; + +use crate::encoder::{row_desc, PsqlFormatter}; +use async_trait::async_trait; +use axum::body::to_bytes; +use axum::response::IntoResponse; +use futures::{stream, Sink}; +use futures::{SinkExt, Stream}; +use http::StatusCode; +use pgwire::api::auth::{ + finish_authentication, save_startup_parameters_to_metadata, DefaultServerParameterProvider, LoginInfo, + StartupHandler, +}; +use pgwire::api::portal::Format; +use pgwire::api::query::SimpleQueryHandler; +use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse, Response, Tag}; +use pgwire::api::{ClientInfo, METADATA_DATABASE}; +use pgwire::api::{PgWireConnectionState, PgWireServerHandlers}; +use pgwire::error::{ErrorInfo, PgWireError, PgWireResult}; +use pgwire::messages::data::DataRow; +use pgwire::messages::startup::Authentication; +use pgwire::messages::{PgWireBackendMessage, PgWireFrontendMessage}; +use pgwire::tokio::process_socket; +use spacetimedb_client_api::auth::validate_token; +use spacetimedb_client_api::routes::database; +use spacetimedb_client_api::routes::database::{SqlParams, SqlQueryParams}; +use spacetimedb_client_api::{ControlStateReadAccess, ControlStateWriteAccess, NodeDelegate}; +use spacetimedb_client_api_messages::http::SqlStmtResult; +use spacetimedb_client_api_messages::name::DatabaseName; +use spacetimedb_lib::sats::satn::{PsqlClient, TypedSerializer}; +use spacetimedb_lib::sats::{satn, Serialize, Typespace}; +use spacetimedb_lib::version::spacetimedb_lib_version; +use spacetimedb_lib::{Identity, ProductValue}; +use thiserror::Error; +use tokio::net::TcpListener; +use tokio::sync::{Mutex, Notify}; + +#[derive(Error, Debug)] +pub(crate) enum PgError { + #[error("(metadata) {0}")] + MetadataError(anyhow::Error), + #[error("(Sql) {0}")] + Sql(String), + #[error("Database name is required")] + DatabaseNameRequired, + #[error(transparent)] + Pg(#[from] PgWireError), + #[error("SSL is not supported by SpacetimeDB")] + SSLNotSupported, + #[error(transparent)] + Other(#[from] anyhow::Error), +} + +impl From for PgWireError { + fn from(err: PgError) -> Self { + if let PgError::Pg(err) = err { + err + } else { + PgWireError::ApiError(Box::new(err)) + } + } +} + +#[derive(Clone)] +struct Metadata { + database: String, + caller_identity: Identity, +} + +pub(crate) fn to_rows( + stmt: SqlStmtResult, + header: Arc>, +) -> Result>, PgError> { + let mut results = Vec::with_capacity(stmt.rows.len()); + let ty = Typespace::EMPTY.with_type(&stmt.schema); + + for row in stmt.rows { + let mut encoder = DataRowEncoder::new(header.clone()); + + for (idx, value) in ty.with_values(&row).enumerate() { + let ty = satn::PsqlType { + client: PsqlClient::Postgres, + tuple: ty.ty(), + field: &ty.ty().elements[idx], + idx, + }; + let mut fmt = PsqlFormatter { encoder: &mut encoder }; + value.serialize(TypedSerializer { ty: &ty, f: &mut fmt })?; + } + results.push(encoder.finish()); + } + Ok(stream::iter(results)) +} + +fn stats(stmt: &SqlStmtResult) -> String { + let mut info = Vec::new(); + if stmt.stats.rows_inserted != 0 { + info.push(format!("inserted: {}", stmt.stats.rows_inserted)); + } + if stmt.stats.rows_deleted != 0 { + info.push(format!("deleted: {}", stmt.stats.rows_deleted)); + } + if stmt.stats.rows_updated != 0 { + info.push(format!("updated: {}", stmt.stats.rows_updated)); + } + info.push(format!( + "server: {:.2?}", + std::time::Duration::from_micros(stmt.total_duration_micros) + )); + + info.join(", ") +} + +struct ResponseWrapper(T); +impl IntoResponse for ResponseWrapper { + fn into_response(self) -> axum::response::Response { + unreachable!("Blank impl to satisfy IntoResponse") + } +} + +async fn response(res: axum::response::Result, database: &str) -> Result { + match res.map(ResponseWrapper) { + Ok(sql) => Ok(sql.0), + err => { + let res = err.into_response(); + if res.status() == StatusCode::NOT_FOUND { + log::error!("PG: Database not found: {database}"); + return Err(PgWireError::UserError(Box::new(ErrorInfo::new( + "FATAL".to_string(), + "3D000".to_string(), + format!("database \"{database}\" does not exist"), + ))) + .into()); + } + let bytes = to_bytes(res.into_body(), usize::MAX) + .await + .map_err(|err| PgWireError::ApiError(Box::new(err)))?; + let err = String::from_utf8_lossy(&bytes); + log::error!("PG: Error for database {database}: {err}"); + Err(PgError::Sql(format!("{err}"))) + } + } +} + +struct PgSpacetimeDB { + ctx: T, + cached: Mutex>, + parameter_provider: DefaultServerParameterProvider, +} + +impl PgSpacetimeDB { + async fn exe_sql<'a>(&self, query: String) -> PgWireResult>> { + let params = self.cached.lock().await.clone().unwrap(); + let db = SqlParams { + name_or_identity: database::NameOrIdentity::Name(DatabaseName(params.database.clone())), + }; + + let sql = match response( + database::sql_direct( + self.ctx.clone(), + db, + SqlQueryParams { confirmed: true }, + params.caller_identity, + query.to_string(), + ) + .await, + ¶ms.database, + ) + .await + { + Ok(sql) => sql, + Err(PgError::Pg(PgWireError::UserError(err))) => { + return Ok(vec![Response::Error(err)]); + } + Err(err) => { + return Err(err.into()); + } + }; + + let mut result = Vec::with_capacity(sql.len()); + for sql_result in sql { + let header = row_desc(&sql_result.schema, &Format::UnifiedText); + if sql_result.rows.is_empty() && !query.to_uppercase().contains("SELECT") { + let tag = Tag::new(&stats(&sql_result)); + result.push(Response::Execution(tag)); + } else { + let rows = to_rows(sql_result, header.clone())?; + let q = QueryResponse::new(header, rows); + result.push(Response::Query(q)); + } + } + Ok(result) + } +} + +async fn close_client(client: &mut C, err: E) -> PgWireResult<()> +where + C: ClientInfo + Sink + Unpin + Send, + C::Error: Debug, + PgWireError: From<>::Error>, + pgwire::messages::response::ErrorResponse: From, +{ + let err = pgwire::messages::response::ErrorResponse::from(err); + client.feed(PgWireBackendMessage::ErrorResponse(err)).await?; + client.close().await?; + Ok(()) +} + +#[async_trait] +impl StartupHandler + for PgSpacetimeDB +{ + async fn on_startup(&self, client: &mut C, message: PgWireFrontendMessage) -> PgWireResult<()> + where + C: ClientInfo + Sink + Unpin + Send, + C::Error: Debug, + PgWireError: From<>::Error>, + { + match message { + PgWireFrontendMessage::Startup(ref startup) => { + save_startup_parameters_to_metadata(client, startup); + client.set_state(PgWireConnectionState::AuthenticationInProgress); + + let login_info = LoginInfo::from_client_info(client); + + if login_info.database().is_none() { + return Err(PgError::DatabaseNameRequired.into()); + } + + client + .send(PgWireBackendMessage::Authentication(Authentication::CleartextPassword)) + .await?; + } + PgWireFrontendMessage::PasswordMessageFamily(pwd) => { + let params = client.metadata(); + let param = |param: &str| { + params + .get(param) + .map(String::from) + .ok_or_else(|| PgError::MetadataError(anyhow::anyhow!("Missing parameter: {}", param))) + }; + + // We don't support `METADATA_USER` because we don't have a user management system. + let database = param(METADATA_DATABASE)?; + let pwd = pwd.into_password()?; + if let Ok(application_name) = param("application_name") { + log::info!("PG: Connecting to database: {database}, by {application_name}",); + } else { + log::info!("PG: Connecting to database: {database}"); + } + + let name = database::NameOrIdentity::Name(DatabaseName(database.clone())); + match response(name.resolve(&self.ctx).await, &database).await { + Ok(identity) => identity, + Err(PgError::Pg(PgWireError::UserError(err))) => { + return close_client(client, *err).await; + } + Err(err) => { + return Err(err.into()); + } + }; + + let caller_identity = match validate_token(&self.ctx, &pwd.password).await { + Ok(claims) => claims.identity, + Err(err) => { + log::error!( + "PG: Authentication failed for identity `{}` on database {database}: {err}", + pwd.password + ); + let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string()); + return close_client(client, err).await; + } + }; + + log::info!("PG: Connected to database: {database} using identity `{caller_identity}`"); + + let metadata = Metadata { + database, + caller_identity, + }; + self.cached.lock().await.clone_from(&Some(metadata)); + finish_authentication(client, &self.parameter_provider).await?; + } + PgWireFrontendMessage::SslRequest(_) => { + let err = PgError::SSLNotSupported; + log::error!("{err}"); + let err = ErrorInfo::new("FATAL".to_owned(), "28P01".to_owned(), err.to_string()); + return close_client(client, err).await; + } + // The other messages are for features not supported by SpacetimeDB, that are rejected by the parser. + _ => { + unreachable!("Unsupported startup message: {message:?}"); + } + } + Ok(()) + } +} + +#[async_trait] +impl SimpleQueryHandler + for PgSpacetimeDB +{ + async fn do_query<'a, C>(&self, _client: &mut C, query: &str) -> PgWireResult>> + where + C: ClientInfo + Unpin + Send + Sync, + { + self.exe_sql(query.to_string()).await + } +} + +#[derive(Clone)] +pub struct PgSpacetimeDBFactory { + handler: Arc>, +} + +impl PgSpacetimeDBFactory { + pub fn new(ctx: T) -> Self { + let mut parameter_provider = DefaultServerParameterProvider::default(); + parameter_provider.server_version = format!("spacetime {}", spacetimedb_lib_version()); + + Self { + handler: Arc::new(PgSpacetimeDB { + ctx, + // This is a placeholder, it will be set in the startup handler + cached: None.into(), + parameter_provider, + }), + } + } +} + +impl PgWireServerHandlers + for PgSpacetimeDBFactory +{ + fn simple_query_handler(&self) -> Arc { + self.handler.clone() + } + + // TODO: fn extended_query_handler(&self) -> Arc {} + + fn startup_handler(&self) -> Arc { + self.handler.clone() + } +} + +pub async fn start_pg( + shutdown: Arc, + ctx: T, + tcp: TcpListener, +) { + let factory = Arc::new(PgSpacetimeDBFactory::new(ctx)); + + log::debug!( + "PG: Starting SpacetimeDB Protocol listening on {}", + tcp.local_addr().unwrap() + ); + loop { + tokio::select! { + accept_result = tcp.accept() => { + match accept_result { + Ok((stream, _addr)) => { + let factory_ref = factory.clone(); + tokio::spawn(async move { + process_socket(stream, None, factory_ref).await.inspect_err(|err|{ + log::error!("PG: Error processing socket: {err:?}"); + }) + }); + } + Err(e) => { + log::error!("PG: Accept error: {e}"); + } + } + } + _ = shutdown.notified() => { + log::info!("PG: Shutting down PostgreSQL server."); + break; + } + } + } +} diff --git a/crates/sats/src/satn.rs b/crates/sats/src/satn.rs index 4462609520e..2a5f128262e 100644 --- a/crates/sats/src/satn.rs +++ b/crates/sats/src/satn.rs @@ -1,13 +1,14 @@ -use crate::de::DeserializeSeed; use crate::time_duration::TimeDuration; use crate::timestamp::Timestamp; -use crate::{i256, u256, AlgebraicValue, WithTypespace}; +use crate::{i256, u256, AlgebraicType, AlgebraicValue, ProductValue, Serialize, SumValue, ValueWithType}; use crate::{ser, ProductType, ProductTypeElement}; use core::fmt; use core::fmt::Write as _; -use derive_more::{From, Into}; +use derive_more::{Display, From, Into}; +use std::borrow::Cow; +use std::marker::PhantomData; -/// An extension trait for [`Serialize`](ser::Serialize) providing formatting methods. +/// An extension trait for [`Serialize`] providing formatting methods. pub trait Satn: ser::Serialize { /// Formats the value using the SATN data format into the formatter `f`. fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -18,9 +19,12 @@ pub trait Satn: ser::Serialize { /// Formats the value using the postgres SATN(PsqlFormatter { f }, /* PsqlType */) formatter `f`. fn fmt_psql(&self, f: &mut fmt::Formatter, ty: &PsqlType<'_>) -> fmt::Result { Writer::with(f, |f| { - self.serialize(PsqlFormatter { - fmt: SatnFormatter { f }, + self.serialize(TypedSerializer { ty, + f: &mut SqlFormatter { + fmt: SatnFormatter { f }, + ty, + }, }) })?; Ok(()) @@ -229,9 +233,30 @@ struct SatnFormatter<'a, 'f> { f: Writer<'a, 'f>, } +impl SatnFormatter<'_, '_> { + fn ser_variant( + &mut self, + _tag: u8, + name: Option<&str>, + value: &T, + ) -> Result<(), SatnError> { + write!(self, "(")?; + EntryWrapper::<','>::new(self.f.as_mut()).entry(|mut f| { + if let Some(name) = name { + write!(f, "{name}")?; + } + write!(f, " = ")?; + value.serialize(SatnFormatter { f })?; + Ok(()) + })?; + write!(self, ")")?; + + Ok(()) + } +} /// An error occurred during serialization to the SATS data format. #[derive(From, Into)] -struct SatnError(fmt::Error); +pub struct SatnError(fmt::Error); impl ser::Error for SatnError { fn custom(_msg: T) -> Self { @@ -331,20 +356,11 @@ impl<'a, 'f> ser::Serializer for SatnFormatter<'a, 'f> { fn serialize_variant( mut self, - _tag: u8, + tag: u8, name: Option<&str>, value: &T, ) -> Result { - write!(self, "(")?; - EntryWrapper::<','>::new(self.f.as_mut()).entry(|mut f| { - if let Some(name) = name { - write!(f, "{name}")?; - } - write!(f, " = ")?; - value.serialize(SatnFormatter { f })?; - Ok(()) - })?; - write!(self, ")") + self.ser_variant(tag, name, value) } } @@ -427,119 +443,41 @@ impl ser::SerializeNamedProduct for NamedFormatter<'_, '_> { } } -struct PsqlEntryWrapper<'a, 'f, const SEP: char> { - entry: EntryWrapper<'a, 'f, SEP>, - /// The index of the element. - idx: usize, - ty: &'a PsqlType<'a>, +/// Which client is used to format the `SQL` output? +#[derive(PartialEq, Copy, Clone, Debug)] +pub enum PsqlClient { + SpacetimeDB, + Postgres, } -/// Provides the data format for named products for `SQL`. -struct PsqlNamedFormatter<'a, 'f> { - /// The formatter for each element separating elements by a `,`. - f: PsqlEntryWrapper<'a, 'f, ','>, - /// If is not [Self::is_special] to control if we start with `(` - start: bool, - /// Remember what format we are using - use_fmt: PsqlPrintFmt, -} - -impl<'a, 'f> PsqlNamedFormatter<'a, 'f> { - pub fn new(ty: &'a PsqlType<'a>, f: Writer<'a, 'f>) -> Self { - Self { - start: true, - f: PsqlEntryWrapper { - entry: EntryWrapper::new(f), - idx: 0, - ty, - }, - // Will set later - use_fmt: PsqlPrintFmt::Satn, - } - } +pub struct PsqlChars { + pub start: char, + pub sep: &'static str, + pub end: char, + pub quote: &'static str, } -impl ser::SerializeNamedProduct for PsqlNamedFormatter<'_, '_> { - type Ok = (); - type Error = SatnError; - - fn serialize_element( - &mut self, - name: Option<&str>, - elem: &T, - ) -> Result<(), Self::Error> { - // For binary data & special types, output in `hex` format and skip the tagging of each value - // We need to check for both the enclosing(`self.f.ty`) type and the inner element(`name`) type. - self.use_fmt = self.f.ty.use_fmt(name); - let res = self.f.entry.entry(|mut f| { - let PsqlType { tuple, field, idx } = self.f.ty; - if !self.use_fmt.is_special() { - if self.start { - write!(f, "(")?; - self.start = false; - } - // Format the name or use the index if unnamed. - if let Some(name) = name { - write!(f, "{name}")?; - } else { - write!(f, "{idx}")?; - } - write!(f, " = ")?; - } - //Is a nested product type? - let (tuple, field, idx) = if let Some(product) = field.algebraic_type.as_product() { - (product, &product.elements[self.f.idx], self.f.idx) - } else { - (*tuple, *field, *idx) - }; - - elem.serialize(PsqlFormatter { - fmt: SatnFormatter { f }, - ty: &PsqlType { tuple, field, idx }, - })?; - - Ok(()) - }); - - // Advance to the next field. - if !self.use_fmt.is_special() { - self.f.idx += 1; - } - - res?; - - Ok(()) - } - - fn end(mut self) -> Result { - if !self.use_fmt.is_special() { - write!(self.f.entry.fmt, ")")?; +impl PsqlClient { + pub fn format_chars(&self) -> PsqlChars { + match self { + PsqlClient::SpacetimeDB => PsqlChars { + start: '(', + sep: " =", + end: ')', + quote: "", + }, + PsqlClient::Postgres => PsqlChars { + start: '{', + sep: ":", + end: '}', + quote: "\"", + }, } - Ok(()) - } -} - -/// Provides the data format for unnamed products for `SQL`. -struct PsqlSeqFormatter<'a, 'f> { - /// Delegates to the named format. - inner: PsqlNamedFormatter<'a, 'f>, -} - -impl ser::SerializeSeqProduct for PsqlSeqFormatter<'_, '_> { - type Ok = (); - type Error = SatnError; - - fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { - ser::SerializeNamedProduct::serialize_element(&mut self.inner, None, elem) - } - - fn end(self) -> Result { - ser::SerializeNamedProduct::end(self.inner) } } /// How format of the `SQL` output? -#[derive(PartialEq)] +#[derive(Debug, Copy, Clone, PartialEq, Display)] pub enum PsqlPrintFmt { /// Print as `hex` format Hex, @@ -552,46 +490,32 @@ pub enum PsqlPrintFmt { } impl PsqlPrintFmt { - fn is_special(&self) -> bool { + pub fn is_special(&self) -> bool { self != &PsqlPrintFmt::Satn } -} - -/// A wrapper that remember the `header` of the tuple/struct and the current field -#[derive(Debug, Clone)] -pub struct PsqlType<'a> { - /// The header of the tuple/struct - pub tuple: &'a ProductType, - /// The current field - pub field: &'a ProductTypeElement, - /// The index of the field in the tuple/struct - pub idx: usize, -} - -impl PsqlType<'_> { /// Returns if the type is a special type /// /// Is required to check both the enclosing type and the inner element type - fn use_fmt(&self, name: Option<&str>) -> PsqlPrintFmt { - if self.tuple.is_identity() - || self.tuple.is_connection_id() - || self.field.algebraic_type.is_identity() - || self.field.algebraic_type.is_connection_id() + pub fn use_fmt(tuple: &ProductType, field: &ProductTypeElement, name: Option<&str>) -> PsqlPrintFmt { + if tuple.is_identity() + || tuple.is_connection_id() + || field.algebraic_type.is_identity() + || field.algebraic_type.is_connection_id() || name.map(ProductType::is_identity_tag).unwrap_or_default() || name.map(ProductType::is_connection_id_tag).unwrap_or_default() { return PsqlPrintFmt::Hex; }; - if self.tuple.is_timestamp() - || self.field.algebraic_type.is_timestamp() + if tuple.is_timestamp() + || field.algebraic_type.is_timestamp() || name.map(ProductType::is_timestamp_tag).unwrap_or_default() { return PsqlPrintFmt::Timestamp; }; - if self.tuple.is_time_duration() - || self.field.algebraic_type.is_time_duration() + if tuple.is_time_duration() + || field.algebraic_type.is_time_duration() || name.map(ProductType::is_time_duration_tag).unwrap_or_default() { return PsqlPrintFmt::Duration; @@ -601,139 +525,386 @@ impl PsqlType<'_> { } } +/// A wrapper that remember the `header` of the tuple/struct and the current field +#[derive(Debug, Clone)] +pub struct PsqlType<'a> { + /// The client used to format the output + pub client: PsqlClient, + /// The header of the tuple/struct + pub tuple: &'a ProductType, + /// The current field + pub field: &'a ProductTypeElement, + /// The index of the field in the tuple/struct + pub idx: usize, +} + +impl PsqlType<'_> { + /// Returns if the type is a special type + /// + /// Is required to check both the enclosing type and the inner element type + pub fn use_fmt(&self) -> PsqlPrintFmt { + PsqlPrintFmt::use_fmt(self.tuple, self.field, None) + } +} + /// An implementation of [`Serializer`](ser::Serializer) for `SQL` output. -struct PsqlFormatter<'a, 'f> { +pub struct SqlFormatter<'a, 'f> { fmt: SatnFormatter<'a, 'f>, ty: &'a PsqlType<'a>, } -impl<'a, 'f> ser::Serializer for PsqlFormatter<'a, 'f> { +/// A trait for writing values, after the special types has been determined. +/// +/// This is used to write values that could have different representations depending on the output format, +/// as defined by [`PsqlClient`] and [`PsqlPrintFmt`]. +pub trait TypedWriter { + type Error: ser::Error; + + /// Writes a value using [`ser::Serializer`] + fn write(&mut self, value: W) -> Result<(), Self::Error>; + + // Values that need special handling: + + fn write_bool(&mut self, value: bool) -> Result<(), Self::Error>; + fn write_string(&mut self, value: &str) -> Result<(), Self::Error>; + fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error>; + fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error>; + fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error>; + fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error>; + /// Writes a value as an alternative record format, e.g., for use `JSON` inside `SQL`. + fn write_alt_record( + &mut self, + _ty: &PsqlType, + _value: &ValueWithType<'_, ProductValue>, + ) -> Result { + Ok(false) + } + + fn write_record( + &mut self, + fields: Vec<(Cow, PsqlType, ValueWithType)>, + ) -> Result<(), Self::Error>; + + fn write_variant( + &mut self, + tag: u8, + ty: PsqlType, + name: Option<&str>, + value: ValueWithType, + ) -> Result<(), Self::Error>; +} + +/// A formatter for arrays that uses the `TypedWriter` trait to write elements. +pub struct TypedArrayFormatter<'a, 'f, F> { + ty: &'a PsqlType<'a>, + f: &'f mut F, +} + +impl ser::SerializeArray for TypedArrayFormatter<'_, '_, F> { type Ok = (); - type Error = SatnError; - type SerializeArray = ArrayFormatter<'a, 'f>; - type SerializeSeqProduct = PsqlSeqFormatter<'a, 'f>; - type SerializeNamedProduct = PsqlNamedFormatter<'a, 'f>; + type Error = F::Error; + + fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { + elem.serialize(TypedSerializer { ty: self.ty, f: self.f })?; + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +/// A formatter for sequences that uses the `TypedWriter` trait to write elements. +pub struct TypedSeqFormatter<'a, 'f, F> { + ty: &'a PsqlType<'a>, + f: &'f mut F, +} + +impl ser::SerializeSeqProduct for TypedSeqFormatter<'_, '_, F> { + type Ok = (); + type Error = F::Error; + + fn serialize_element(&mut self, elem: &T) -> Result<(), Self::Error> { + elem.serialize(TypedSerializer { ty: self.ty, f: self.f })?; + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +/// A formatter for named products that uses the `TypedWriter` trait to write elements. +pub struct TypedNamedProductFormatter { + f: PhantomData, +} + +impl ser::SerializeNamedProduct for TypedNamedProductFormatter { + type Ok = (); + type Error = F::Error; + + fn serialize_element( + &mut self, + _name: Option<&str>, + _elem: &T, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn end(self) -> Result { + Ok(()) + } +} + +/// A serializer that uses the `TypedWriter` trait to serialize values +pub struct TypedSerializer<'a, 'f, F> { + pub ty: &'a PsqlType<'a>, + pub f: &'f mut F, +} + +impl<'a, 'f, F: TypedWriter> ser::Serializer for TypedSerializer<'a, 'f, F> { + type Ok = (); + type Error = F::Error; + type SerializeArray = TypedArrayFormatter<'a, 'f, F>; + type SerializeSeqProduct = TypedSeqFormatter<'a, 'f, F>; + type SerializeNamedProduct = TypedNamedProductFormatter; fn serialize_bool(self, v: bool) -> Result { - self.fmt.serialize_bool(v) + self.f.write_bool(v) } + fn serialize_u8(self, v: u8) -> Result { - self.fmt.serialize_u8(v) + self.f.write(v) } + fn serialize_u16(self, v: u16) -> Result { - self.fmt.serialize_u16(v) + self.f.write(v) } + fn serialize_u32(self, v: u32) -> Result { - self.fmt.serialize_u32(v) + self.f.write(v) } + fn serialize_u64(self, v: u64) -> Result { - self.fmt.serialize_u64(v) + self.f.write(v) } + fn serialize_u128(self, v: u128) -> Result { - match self.ty.use_fmt(None) { - PsqlPrintFmt::Hex => self.serialize_bytes(&v.to_be_bytes()), - _ => self.fmt.serialize_u128(v), + match self.ty.use_fmt() { + PsqlPrintFmt::Hex => self.f.write_hex(&v.to_be_bytes()), + _ => self.f.write(v), } } + fn serialize_u256(self, v: u256) -> Result { - match self.ty.use_fmt(None) { - PsqlPrintFmt::Hex => self.serialize_bytes(&v.to_be_bytes()), - _ => self.fmt.serialize_u256(v), + match self.ty.use_fmt() { + PsqlPrintFmt::Hex => self.f.write_hex(&v.to_be_bytes()), + _ => self.f.write(v), } } + fn serialize_i8(self, v: i8) -> Result { - self.fmt.serialize_i8(v) + self.f.write(v) } + fn serialize_i16(self, v: i16) -> Result { - self.fmt.serialize_i16(v) + self.f.write(v) } + fn serialize_i32(self, v: i32) -> Result { - self.fmt.serialize_i32(v) + self.f.write(v) } - fn serialize_i64(mut self, v: i64) -> Result { - match self.ty.use_fmt(None) { - PsqlPrintFmt::Duration => { - write!(self.fmt, "{}", TimeDuration::from_micros(v))?; - Ok(()) - } - PsqlPrintFmt::Timestamp => { - write!(self.fmt, "{}", Timestamp::from_micros_since_unix_epoch(v))?; - Ok(()) - } - _ => self.fmt.serialize_i64(v), + + fn serialize_i64(self, v: i64) -> Result { + match self.ty.use_fmt() { + PsqlPrintFmt::Duration => self.f.write_duration(TimeDuration::from_micros(v)), + PsqlPrintFmt::Timestamp => self.f.write_timestamp(Timestamp::from_micros_since_unix_epoch(v)), + _ => self.f.write(v), } } + fn serialize_i128(self, v: i128) -> Result { - self.fmt.serialize_i128(v) + self.f.write(v) } + fn serialize_i256(self, v: i256) -> Result { - self.fmt.serialize_i256(v) + self.f.write(v) } + fn serialize_f32(self, v: f32) -> Result { - self.fmt.serialize_f32(v) + self.f.write(v) } + fn serialize_f64(self, v: f64) -> Result { - self.fmt.serialize_f64(v) + self.f.write(v) } fn serialize_str(self, v: &str) -> Result { - self.fmt.serialize_str(v) + self.f.write_string(v) } fn serialize_bytes(self, v: &[u8]) -> Result { - self.fmt.serialize_bytes(v) + if self.ty.use_fmt() == PsqlPrintFmt::Satn { + self.f.write_hex(v) + } else { + self.f.write_bytes(v) + } } - fn serialize_array(self, len: usize) -> Result { - self.fmt.serialize_array(len) + fn serialize_array(self, _len: usize) -> Result { + Ok(TypedArrayFormatter { ty: self.ty, f: self.f }) } - fn serialize_seq_product(self, len: usize) -> Result { - Ok(PsqlSeqFormatter { - inner: self.serialize_named_product(len)?, - }) + fn serialize_seq_product(self, _len: usize) -> Result { + Ok(TypedSeqFormatter { ty: self.ty, f: self.f }) } fn serialize_named_product(self, _len: usize) -> Result { - Ok(PsqlNamedFormatter::new(self.ty, self.fmt.f)) + unreachable!("This should never be called, use `serialize_named_product_raw` instead."); + } + + fn serialize_named_product_raw(self, value: &ValueWithType<'_, ProductValue>) -> Result { + let val = &value.val.elements; + assert_eq!(val.len(), value.ty().elements.len()); + // If the value is a special type, we can write it directly + if self.ty.use_fmt().is_special() { + // Is a nested product type? + // We need to check for both the enclosing(`self.ty`) type and the inner element type. + let (tuple, field) = if let Some(product) = self.ty.field.algebraic_type.as_product() { + (product, &product.elements[0]) + } else { + (self.ty.tuple, self.ty.field) + }; + return value.val.serialize(TypedSerializer { + ty: &PsqlType { + client: self.ty.client, + tuple, + field, + idx: self.ty.idx, + }, + f: self.f, + }); + } + // Allow to switch to an alternative record format, for example to write a `JSON` record. + if self.f.write_alt_record(self.ty, value)? { + return Ok(()); + } + let mut record = Vec::with_capacity(val.len()); + + for (idx, (val, field)) in val.iter().zip(&*value.ty().elements).enumerate() { + let ty = PsqlType { + client: self.ty.client, + tuple: value.ty(), + field, + idx, + }; + record.push(( + field + .name() + .map(Cow::from) + .unwrap_or_else(|| Cow::from(format!("col_{idx}"))), + ty, + value.with(&field.algebraic_type, val), + )); + } + self.f.write_record(record) + } + + fn serialize_variant_raw(self, sum: &ValueWithType<'_, SumValue>) -> Result { + let sv = sum.value(); + let (tag, val) = (sv.tag, &*sv.value); + let var_ty = &sum.ty().variants[tag as usize]; // Extract the variant type by tag. + let product = ProductType::from([AlgebraicType::sum(sum.ty().clone())]); + let ty = PsqlType { + client: self.ty.client, + tuple: &product, + field: &product.elements[0], + idx: 0, + }; + self.f + .write_variant(tag, ty, var_ty.name(), sum.with(&var_ty.algebraic_type, val)) } - fn serialize_variant( + fn serialize_variant( self, - tag: u8, - name: Option<&str>, - value: &T, + _tag: u8, + _name: Option<&str>, + _value: &T, ) -> Result { - self.fmt.serialize_variant(tag, name, value) + unreachable!("Use `serialize_variant_raw` instead."); + } +} + +impl TypedWriter for SqlFormatter<'_, '_> { + type Error = SatnError; + + fn write(&mut self, value: W) -> Result<(), Self::Error> { + write!(self.fmt, "{value}") } - unsafe fn serialize_bsatn(self, ty: &Ty, bsatn: &[u8]) -> Result - where - for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into>, - { - // SAFETY: Forward caller requirements of this method to that we are calling. - unsafe { self.fmt.serialize_bsatn(ty, bsatn) } + fn write_bool(&mut self, value: bool) -> Result<(), Self::Error> { + write!(self.fmt, "{value}") } - unsafe fn serialize_bsatn_in_chunks<'c, Ty, I: Clone + Iterator>( - self, - ty: &Ty, - total_bsatn_len: usize, - bsatn: I, - ) -> Result - where - for<'b, 'de> WithTypespace<'b, Ty>: DeserializeSeed<'de, Output: Into>, - { - // SAFETY: Forward caller requirements of this method to that we are calling. - unsafe { self.fmt.serialize_bsatn_in_chunks(ty, total_bsatn_len, bsatn) } - } - - unsafe fn serialize_str_in_chunks<'c, I: Clone + Iterator>( - self, - total_len: usize, - string: I, - ) -> Result { - // SAFETY: Forward caller requirements of this method to that we are calling. - unsafe { self.fmt.serialize_str_in_chunks(total_len, string) } + fn write_string(&mut self, value: &str) -> Result<(), Self::Error> { + write!(self.fmt, "\"{value}\"") + } + + fn write_bytes(&mut self, value: &[u8]) -> Result<(), Self::Error> { + self.write_hex(value) + } + + fn write_hex(&mut self, value: &[u8]) -> Result<(), Self::Error> { + match self.ty.client { + PsqlClient::SpacetimeDB => write!(self.fmt, "0x{}", hex::encode(value)), + PsqlClient::Postgres => write!(self.fmt, "\"0x{}\"", hex::encode(value)), + } + } + + fn write_timestamp(&mut self, value: Timestamp) -> Result<(), Self::Error> { + match self.ty.client { + PsqlClient::SpacetimeDB => write!(self.fmt, "{}", value.to_rfc3339().unwrap()), + PsqlClient::Postgres => write!(self.fmt, "\"{}\"", value.to_rfc3339().unwrap()), + } + } + + fn write_duration(&mut self, value: TimeDuration) -> Result<(), Self::Error> { + match self.ty.client { + PsqlClient::SpacetimeDB => write!(self.fmt, "{value}"), + PsqlClient::Postgres => write!(self.fmt, "\"{}\"", value.to_iso8601()), + } + } + + fn write_record( + &mut self, + fields: Vec<(Cow, PsqlType<'_>, ValueWithType)>, + ) -> Result<(), Self::Error> { + let PsqlChars { start, sep, end, quote } = self.ty.client.format_chars(); + write!(self.fmt, "{start}")?; + for (idx, (name, ty, value)) in fields.into_iter().enumerate() { + if idx > 0 { + write!(self.fmt, ", ")?; + } + write!(self.fmt, "{quote}{name}{quote}{sep} ")?; + + // Serialize the value + value.serialize(TypedSerializer { ty: &ty, f: self })?; + } + write!(self.fmt, "{end}")?; + Ok(()) + } + + fn write_variant( + &mut self, + tag: u8, + ty: PsqlType, + name: Option<&str>, + value: ValueWithType, + ) -> Result<(), Self::Error> { + self.write_record(vec![( + name.map(Cow::from).unwrap_or_else(|| Cow::from(format!("col_{tag}"))), + ty, + value, + )]) } } diff --git a/crates/sats/src/ser.rs b/crates/sats/src/ser.rs index a9fd0fe01f8..4d56fc5aabc 100644 --- a/crates/sats/src/ser.rs +++ b/crates/sats/src/ser.rs @@ -6,7 +6,7 @@ mod impls; pub mod serde; use crate::de::DeserializeSeed; -use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter}; +use crate::{algebraic_value::ser::ValueSerializer, bsatn, buffer::BufWriter, ProductValue, SumValue, ValueWithType}; use crate::{AlgebraicValue, WithTypespace}; use core::marker::PhantomData; use core::{convert::Infallible, fmt}; @@ -117,6 +117,31 @@ pub trait Serializer: Sized { /// The argument is the number of fields in the product. fn serialize_named_product(self, len: usize) -> Result; + /// Serialize a product with named fields. + /// + /// Allow to override the default serialization for where we need to switch the output format, + /// see [`crate::satn::TypedWriter`]. + fn serialize_named_product_raw(self, value: &ValueWithType<'_, ProductValue>) -> Result { + let val = &value.val.elements; + assert_eq!(val.len(), value.ty().elements.len()); + let mut prod = self.serialize_named_product(val.len())?; + for (val, el_ty) in val.iter().zip(&*value.ty().elements) { + prod.serialize_element(el_ty.name(), &value.with(&el_ty.algebraic_type, val))? + } + prod.end() + } + + /// Serialize a sum value + /// + /// Allow to override the default serialization for where we need to switch the output format, + /// see [`crate::satn::TypedWriter`]. + fn serialize_variant_raw(self, sum: &ValueWithType<'_, SumValue>) -> Result { + let sv = sum.value(); + let (tag, val) = (sv.tag, &*sv.value); + let var_ty = &sum.ty().variants[tag as usize]; // Extract the variant type by tag. + self.serialize_variant(tag, var_ty.name(), &sum.with(&var_ty.algebraic_type, val)) + } + /// Serialize a sum value provided the chosen `tag`, `name`, and `value`. fn serialize_variant( self, diff --git a/crates/sats/src/ser/impls.rs b/crates/sats/src/ser/impls.rs index 914099e14c5..39d08d38d64 100644 --- a/crates/sats/src/ser/impls.rs +++ b/crates/sats/src/ser/impls.rs @@ -1,4 +1,4 @@ -use super::{Serialize, SerializeArray, SerializeNamedProduct, SerializeSeqProduct, Serializer}; +use super::{Serialize, SerializeArray, SerializeSeqProduct, Serializer}; use crate::{i256, u256}; use crate::{AlgebraicType, AlgebraicValue, ArrayValue, ProductValue, SumValue, ValueWithType, F32, F64}; use core::ops::Bound; @@ -190,19 +190,10 @@ impl_serialize!( } ); impl_serialize!([] ValueWithType<'_, SumValue>, (self, ser) => { - let sv = self.value(); - let (tag, val) = (sv.tag, &*sv.value); - let var_ty = &self.ty().variants[tag as usize]; // Extract the variant type by tag. - ser.serialize_variant(tag, var_ty.name(), &self.with(&var_ty.algebraic_type, val)) + ser.serialize_variant_raw(self) }); impl_serialize!([] ValueWithType<'_, ProductValue>, (self, ser) => { - let val = &self.value().elements; - assert_eq!(val.len(), self.ty().elements.len()); - let mut prod = ser.serialize_named_product(val.len())?; - for (val, el_ty) in val.iter().zip(&*self.ty().elements) { - prod.serialize_element(el_ty.name(), &self.with(&el_ty.algebraic_type, val))? - } - prod.end() + ser.serialize_named_product_raw(self) }); impl_serialize!([] ValueWithType<'_, ArrayValue>, (self, ser) => { let mut ty = &*self.ty().elem_ty; diff --git a/crates/sats/src/time_duration.rs b/crates/sats/src/time_duration.rs index bd8d4543fbc..f6d932c4bd8 100644 --- a/crates/sats/src/time_duration.rs +++ b/crates/sats/src/time_duration.rs @@ -82,6 +82,22 @@ impl TimeDuration { pub fn checked_sub(self, other: Self) -> Option { self.to_micros().checked_sub(other.to_micros()).map(Self::from_micros) } + + /// Generate an `iso8601` format string. + /// + /// This is the better supported format for use for the `pg wire protocol`. + /// + /// Example: + /// ```rust + /// use std::time::Duration; + /// use spacetimedb_sats::time_duration::TimeDuration; + /// assert_eq!( TimeDuration::from_micros(0).to_iso8601().as_str(), "P0D"); + /// assert_eq!( TimeDuration::from_micros(-1_000_000).to_iso8601().as_str(), "-PT1S"); + /// assert_eq!( TimeDuration::from_duration(Duration::from_secs(60 * 24)).to_iso8601().as_str(), "PT1440S"); + /// ``` + pub fn to_iso8601(self) -> String { + chrono::Duration::microseconds(self.to_micros()).to_string() + } } impl From for TimeDuration { diff --git a/crates/sats/src/timestamp.rs b/crates/sats/src/timestamp.rs index 5da2e73964a..50affcf6094 100644 --- a/crates/sats/src/timestamp.rs +++ b/crates/sats/src/timestamp.rs @@ -171,13 +171,17 @@ impl Timestamp { pub fn checked_sub_duration(&self, duration: Duration) -> Option { self.checked_sub(TimeDuration::from_duration(duration)) } - /// Returns an RFC 3339 and ISO 8601 date and time string such as `1996-12-19T16:39:57-08:00`. - pub fn to_rfc3339(&self) -> anyhow::Result { + + pub fn to_chrono_date_time(&self) -> anyhow::Result> { DateTime::from_timestamp_micros(self.to_micros_since_unix_epoch()) - .map(|t| t.to_rfc3339()) .ok_or_else(|| anyhow::anyhow!("Timestamp with i64 microseconds since Unix epoch overflows DateTime")) .with_context(|| self.to_micros_since_unix_epoch()) } + + /// Returns an RFC 3339 and ISO 8601 date and time string such as `1996-12-19T16:39:57-08:00`. + pub fn to_rfc3339(&self) -> anyhow::Result { + Ok(self.to_chrono_date_time()?.to_rfc3339()) + } } impl Add for Timestamp { diff --git a/crates/schema/src/auto_migrate.rs b/crates/schema/src/auto_migrate.rs index c1d7e71ae4d..7a0f665fe7b 100644 --- a/crates/schema/src/auto_migrate.rs +++ b/crates/schema/src/auto_migrate.rs @@ -8,13 +8,14 @@ use spacetimedb_data_structures::{ }; use spacetimedb_lib::{ db::raw_def::v9::{RawRowLevelSecurityDefV9, TableType}, - AlgebraicType, + hash_bytes, AlgebraicType, Identity, }; use spacetimedb_sats::{ layout::{HasLayout, SumTypeLayout}, WithTypespace, }; use termcolor_formatter::{ColorScheme, TermColorFormatter}; +use thiserror::Error; mod formatter; mod termcolor_formatter; @@ -50,9 +51,19 @@ impl<'def> MigratePlan<'def> { } } + pub fn breaks_client(&self) -> bool { + match self { + //TODO: fix it when support for manual migration plans is added. + MigratePlan::Manual(_) => true, + MigratePlan::Auto(plan) => plan + .steps + .iter() + .any(|step| matches!(step, AutoMigrateStep::DisconnectAllUsers)), + } + } + pub fn pretty_print(&self, style: PrettyPrintStyle) -> anyhow::Result { use PrettyPrintStyle::*; - match self { MigratePlan::Manual(_) => { anyhow::bail!("Manual migration plans are not yet supported for pretty printing.") @@ -73,6 +84,101 @@ impl<'def> MigratePlan<'def> { } } +/// A migration policy that determines whether a module update is allowed to break client compatibility. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum MigrationPolicy { + /// Migration must maintain backward compatibility with existing clients. + Compatible, + /// To use this, a valid [`MigrationToken`] must be provided. + /// The token is issued through the pre-publish API (see the `client-api` crate) + /// and proves that the publisher explicitly acknowledged the breaking change. + BreakClients(spacetimedb_lib::Hash), +} + +impl MigrationPolicy { + /// Verifies whether the given migration plan is allowed under the current policy. + /// + /// Returns `Ok(())` if allowed, otherwise an appropriate `MigrationPolicyError` + fn permits_plan(&self, plan: &MigratePlan<'_>, token: &MigrationToken) -> anyhow::Result<(), MigrationPolicyError> { + match self { + MigrationPolicy::Compatible => { + if plan.breaks_client() { + Err(MigrationPolicyError::ClientBreakingChangeDisallowed) + } else { + Ok(()) + } + } + MigrationPolicy::BreakClients(expected_hash) => { + if token.hash() == *expected_hash { + Ok(()) + } else { + Err(MigrationPolicyError::InvalidToken) + } + } + } + } + + /// Attempts to generate a migration plan and validate it under this policy. + /// + /// Fails if migration is not permitted by the policy or migration planning fails. + pub fn try_migrate<'def>( + &self, + database_identity: Identity, + old_module_hash: spacetimedb_lib::Hash, + old_module_def: &'def ModuleDef, + new_module_hash: spacetimedb_lib::Hash, + new_module_def: &'def ModuleDef, + ) -> anyhow::Result, MigrationPolicyError> { + let plan = ponder_migrate(old_module_def, new_module_def).map_err(MigrationPolicyError::AutoMigrateFailure)?; + + let token = MigrationToken { + database_identity, + old_module_hash, + new_module_hash, + }; + self.permits_plan(&plan, &token)?; + Ok(plan) + } +} + +#[derive(Debug, Error)] +pub enum MigrationPolicyError { + #[error("Automatic migration planning failed")] + AutoMigrateFailure(ErrorStream), + + #[error("Token provided is invalid or does not match expected hash")] + InvalidToken, + + #[error("Migration plan contains a client-breaking change which is disallowed under current policy")] + ClientBreakingChangeDisallowed, +} + +/// A token acknowledging a breaking migration. +/// +/// Note: This token is only intended as a UX safeguard, not as a security measure. +/// No secret is used in its generation, which means anyone can reproduce it given +/// the inputs. That is acceptable for our purposes since it only signals user intent, +/// not authorization. +pub struct MigrationToken { + pub database_identity: Identity, + pub old_module_hash: spacetimedb_lib::Hash, + pub new_module_hash: spacetimedb_lib::Hash, +} + +impl MigrationToken { + pub fn hash(&self) -> spacetimedb_lib::Hash { + hash_bytes( + format!( + "{}{}{}", + self.database_identity.to_hex(), + self.old_module_hash.to_hex(), + self.new_module_hash.to_hex() + ) + .as_str(), + ) + } +} + /// A plan for a manual migration. /// `new` must have a reducer marked with `Lifecycle::Update`. #[derive(Debug)] diff --git a/crates/standalone/Cargo.toml b/crates/standalone/Cargo.toml index 4e74fd7316d..e719721721d 100644 --- a/crates/standalone/Cargo.toml +++ b/crates/standalone/Cargo.toml @@ -27,7 +27,9 @@ spacetimedb-core.workspace = true spacetimedb-datastore.workspace = true spacetimedb-lib.workspace = true spacetimedb-paths.workspace = true +spacetimedb-pg.workspace = true spacetimedb-table.workspace = true +spacetimedb-schema.workspace = true anyhow.workspace = true async-trait.workspace = true diff --git a/crates/standalone/src/lib.rs b/crates/standalone/src/lib.rs index cb179bb258e..860dddd682f 100644 --- a/crates/standalone/src/lib.rs +++ b/crates/standalone/src/lib.rs @@ -5,7 +5,7 @@ pub mod version; use crate::control_db::ControlDb; use crate::subcommands::{extract_schema, start}; -use anyhow::{ensure, Context, Ok}; +use anyhow::{ensure, Context as _, Ok}; use async_trait::async_trait; use clap::{ArgMatches, Command}; use spacetimedb::client::ClientActorIndex; @@ -13,7 +13,8 @@ use spacetimedb::config::{CertificateAuthority, MetadataFile}; use spacetimedb::db::{self, relational_db}; use spacetimedb::energy::{EnergyBalance, EnergyQuanta, NullEnergyMonitor}; use spacetimedb::host::{ - DiskStorage, DurabilityProvider, ExternalDurability, HostController, StartSnapshotWatcher, UpdateDatabaseResult, + DiskStorage, DurabilityProvider, ExternalDurability, HostController, MigratePlanResult, StartSnapshotWatcher, + UpdateDatabaseResult, }; use spacetimedb::identity::Identity; use spacetimedb::messages::control_db::{Database, Node, Replica}; @@ -28,6 +29,7 @@ use spacetimedb_datastore::db_metrics::DB_METRICS; use spacetimedb_datastore::traits::Program; use spacetimedb_paths::server::{ModuleLogsDir, PidFile, ServerDataDir}; use spacetimedb_paths::standalone::StandaloneDataDirExt; +use spacetimedb_schema::auto_migrate::{MigrationPolicy, PrettyPrintStyle}; use spacetimedb_table::page_pool::PagePool; use std::sync::Arc; @@ -184,6 +186,7 @@ impl spacetimedb_client_api::ControlStateReadAccess for StandaloneEnv { id: 0, unschedulable: false, advertise_addr: Some("node:80".to_owned()), + pg_addr: Some("node:5432".to_owned()), })); } Ok(None) @@ -239,6 +242,7 @@ impl spacetimedb_client_api::ControlStateWriteAccess for StandaloneEnv { &self, publisher: &Identity, spec: spacetimedb_client_api::DatabaseDef, + policy: MigrationPolicy, ) -> anyhow::Result> { let existing_db = self.control_db.get_database_by_identity(&spec.database_identity)?; @@ -295,7 +299,7 @@ impl spacetimedb_client_api::ControlStateWriteAccess for StandaloneEnv { .await? .ok_or_else(|| anyhow::anyhow!("No leader for database"))?; let update_result = leader - .update(database, spec.host_type, spec.program_bytes.into()) + .update(database, spec.host_type, spec.program_bytes.into(), policy) .await?; if update_result.was_successful() { let replicas = self.control_db.get_replicas_by_database(database_id)?; @@ -345,6 +349,30 @@ impl spacetimedb_client_api::ControlStateWriteAccess for StandaloneEnv { } } + async fn migrate_plan( + &self, + spec: spacetimedb_client_api::DatabaseDef, + style: PrettyPrintStyle, + ) -> anyhow::Result { + let existing_db = self.control_db.get_database_by_identity(&spec.database_identity)?; + + match existing_db { + Some(db) => { + let host = self + .leader(db.id) + .await? + .ok_or_else(|| anyhow::anyhow!("No leader for database"))?; + self.host_controller + .migrate_plan(db, spec.host_type, host.replica_id, spec.program_bytes.into(), style) + .await + } + None => anyhow::bail!( + "Database `{}` does not exist", + spec.database_identity.to_abbreviated_hex() + ), + } + } + async fn delete_database(&self, caller_identity: &Identity, database_identity: &Identity) -> anyhow::Result<()> { let Some(database) = self.control_db.get_database_by_identity(database_identity)? else { return Ok(()); diff --git a/crates/standalone/src/subcommands/start.rs b/crates/standalone/src/subcommands/start.rs index 10a52830053..5615945e547 100644 --- a/crates/standalone/src/subcommands/start.rs +++ b/crates/standalone/src/subcommands/start.rs @@ -1,3 +1,4 @@ +use spacetimedb_pg::pg_server; use std::sync::Arc; use crate::{StandaloneEnv, StandaloneOptions}; @@ -176,12 +177,27 @@ pub async fn exec(args: &ArgMatches, db_cores: JobCores) -> anyhow::Result<()> { db_routes.root_post = db_routes.root_post.layer(DefaultBodyLimit::disable()); db_routes.db_put = db_routes.db_put.layer(DefaultBodyLimit::disable()); let extra = axum::Router::new().nest("/health", spacetimedb_client_api::routes::health::router()); - let service = router(&ctx, db_routes, extra).with_state(ctx); + let service = router(&ctx, db_routes, extra).with_state(ctx.clone()); let tcp = TcpListener::bind(listen_addr).await?; socket2::SockRef::from(&tcp).set_nodelay(true)?; - log::debug!("Starting SpacetimeDB listening on {}", tcp.local_addr().unwrap()); - axum::serve(tcp, service).await?; + log::debug!("Starting SpacetimeDB listening on {}", tcp.local_addr()?); + let pg_server_addr = format!("{}:5432", listen_addr.split(':').next().unwrap()); + let tcp_pg = TcpListener::bind(pg_server_addr).await?; + + let notify = Arc::new(tokio::sync::Notify::new()); + let shutdown_notify = notify.clone(); + tokio::select! { + _ = pg_server::start_pg(notify.clone(), ctx, tcp_pg) => {}, + _ = axum::serve(tcp, service).with_graceful_shutdown(async move { + shutdown_notify.notified().await; + }) => {}, + _ = tokio::signal::ctrl_c() => { + println!("Shutting down servers..."); + notify.notify_waiters(); // Notify all tasks + } + } + Ok(()) } diff --git a/crates/testing/src/modules.rs b/crates/testing/src/modules.rs index 1df5f6e4005..1e7fb48f637 100644 --- a/crates/testing/src/modules.rs +++ b/crates/testing/src/modules.rs @@ -12,6 +12,7 @@ use spacetimedb::Identity; use spacetimedb_client_api::auth::SpacetimeAuth; use spacetimedb_client_api::routes::subscribe::{generate_random_connection_id, WebSocketOptions}; use spacetimedb_paths::{RootDir, SpacetimePaths}; +use spacetimedb_schema::auto_migrate::MigrationPolicy; use spacetimedb_schema::def::ModuleDef; use tokio::runtime::{Builder, Runtime}; @@ -205,6 +206,7 @@ impl CompiledModule { num_replicas: None, host_type: HostType::Wasm, }, + MigrationPolicy::Compatible, ) .await .unwrap(); diff --git a/docker-compose.yml b/docker-compose.yml index 8e2dd39fb87..a31ed13d89a 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -25,6 +25,8 @@ services: - /stdb ports: - "3000:3000" + # Postgres + - "5432:5432" # Tracy - "8086:8086" entrypoint: cargo watch -i flamegraphs -i log.conf --why -C crates/standalone -x 'run start --data-dir=/stdb/data --jwt-pub-key-path=/etc/spacetimedb/id_ecdsa.pub --jwt-priv-key-path=/etc/spacetimedb/id_ecdsa' diff --git a/docs/docs/cli-reference.md b/docs/docs/cli-reference.md index a346ed23744..42d61171e49 100644 --- a/docs/docs/cli-reference.md +++ b/docs/docs/cli-reference.md @@ -214,6 +214,7 @@ Runs a SQL query on the database. WARNING: This command is UNSTABLE and subject ###### Options: * `--interactive` — Instead of using a query, run an interactive command prompt for `SQL` expressions +* `--confirmed` — Instruct the server to deliver only updates of confirmed transactions * `--anonymous` — Perform this action with an anonymous identity * `-s`, `--server ` — The nickname, host name or URL of the server hosting the database * `-y`, `--yes` — Run non-interactively wherever possible. This will answer "yes" to almost all prompts, but will sometimes answer "no" to preserve non-interactivity (e.g. when prompting whether to log in with spacetimedb.com). @@ -479,6 +480,7 @@ Subscribe to SQL queries on the database. WARNING: This command is UNSTABLE and * `-t`, `--timeout ` — The timeout, in seconds, after which to disconnect and stop receiving subscription messages. If `-n` is specified, it will stop after whichever one comes first. * `--print-initial-update` — Print the initial update for the queries. +* `--confirmed` — Instruct the server to deliver only updates of confirmed transactions * `--anonymous` — Perform this action with an anonymous identity * `-y`, `--yes` — Run non-interactively wherever possible. This will answer "yes" to almost all prompts, but will sometimes answer "no" to preserve non-interactivity (e.g. when prompting whether to log in with spacetimedb.com). * `-s`, `--server ` — The nickname, host name or URL of the server hosting the database diff --git a/sdks/rust/src/db_connection.rs b/sdks/rust/src/db_connection.rs index 10bcb149320..7306f45160e 100644 --- a/sdks/rust/src/db_connection.rs +++ b/sdks/rust/src/db_connection.rs @@ -962,6 +962,24 @@ but you must call one of them, or else the connection will never progress. self } + /// Sets whether to use confirmed reads. + /// + /// When enabled, the server will send query results only after they are + /// confirmed to be durable. + /// + /// What durable means depends on the server configuration: a single node + /// server may consider a transaction durable once it is `fsync`'ed to disk, + /// a cluster after some number of replicas have acknowledged that they + /// have stored the transaction. + /// + /// Note that enabling confirmed reads will increase the latency between a + /// reducer call and the corresponding subscription update arriving at the + /// client. + pub fn with_confirmed_reads(mut self, confirmed: bool) -> Self { + self.params.confirmed = confirmed; + self + } + /// Register a callback to run when the connection is successfully initiated. /// /// The callback will receive three arguments: diff --git a/sdks/rust/src/websocket.rs b/sdks/rust/src/websocket.rs index 0595b3dfcc9..09fa7680633 100644 --- a/sdks/rust/src/websocket.rs +++ b/sdks/rust/src/websocket.rs @@ -107,6 +107,8 @@ fn parse_scheme(scheme: Option) -> Result { pub(crate) struct WsParams { pub compression: Compression, pub light: bool, + /// `true` to enable confirmed reads for the connection. + pub confirmed: bool, } fn make_uri(host: Uri, db_name: &str, connection_id: Option, params: WsParams) -> Result { @@ -152,6 +154,11 @@ fn make_uri(host: Uri, db_name: &str, connection_id: Option, param path.push_str("&light=true"); } + // Enable confirmed reads if requested. + if params.confirmed { + path.push_str("&confirmed=true"); + } + parts.path_and_query = Some(path.parse().map_err(|source: InvalidUri| UriError::InvalidUri { source: Arc::new(source), })?); diff --git a/smoketests/__init__.py b/smoketests/__init__.py index 4c4dc6e8954..58e61ffa72d 100644 --- a/smoketests/__init__.py +++ b/smoketests/__init__.py @@ -118,6 +118,17 @@ def extract_fields(cmd_output, field_name): out.append(val) return out +def parse_sql_result(res: str) -> list[dict]: + """Parse tabular output from an SQL query into a list of dicts.""" + lines = res.splitlines() + headers = lines[0].split('|') if '|' in lines[0] else [lines[0]] + headers = [header.strip() for header in headers] + rows = [] + for row in lines[2:]: + cols = [col.strip() for col in row.split('|')] + rows.append(dict(zip(headers, cols))) + return rows + def extract_field(cmd_output, field_name): field, = extract_fields(cmd_output, field_name) return field @@ -232,11 +243,22 @@ def fingerprint(self): def new_identity(self): new_identity(self.__class__.config_path) - def subscribe(self, *queries, n): + def subscribe(self, *queries, n, confirmed = False): self._check_published() assert isinstance(n, int) - args = [SPACETIME_BIN, "--config-path", str(self.config_path),"subscribe", self.database_identity, "-t", "600", "-n", str(n), "--print-initial-update", "--", *queries] + args = [ + SPACETIME_BIN, + "--config-path", str(self.config_path), + "subscribe", self.database_identity, + "-t", "600", + "-n", str(n), + "--print-initial-update", + ] + if confirmed: + args.append("--confirmed") + args.extend(["--", *queries]) + fake_args = ["spacetime", *args[1:]] log_cmd(fake_args) diff --git a/smoketests/config.toml b/smoketests/config.toml index bc7409327ef..3bc097d0d79 100644 --- a/smoketests/config.toml +++ b/smoketests/config.toml @@ -1,5 +1,5 @@ default_server = "localhost" -spacetimedb_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJoZXhfaWRlbnRpdHkiOiJjMjAwYzc3NDY1NTE5MDM2MTE4M2JiNjFmMWMxYzY3NDUzMzYzY2MxMTY4MmM1NTUwNWZiNjdlYzI0ZWMyMWViIiwic3ViIjoiOTJlMmNkOGQtNTk5Ny00NjZlLWIwNmYtZDNjOGQ1NzU3ODI4IiwiaXNzIjoibG9jYWxob3N0IiwiYXVkIjpbInNwYWNldGltZWRiIl0sImlhdCI6MTc1MjA0NjgwMCwiZXhwIjpudWxsfQ.dgefoxC7eCOONVUufu2JTVFo9876zQ4Mqwm0ivZ0PQK7Hacm3Ip_xqyav4bilZ0vIEf8IM8AB0_xawk8WcbvMg" +spacetimedb_token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJoZXhfaWRlbnRpdHkiOiJjMjAwMTU3NGEwMjgyNDRjNzZhNTE1MjU1NGMzY2ZjMWJiNmIzNzZlNjY4YmU1Yjg2MzE0MDAyYWRmOTMyYWVlIiwic3ViIjoiYzYwOWJkYjUtMDAyNS00YzZkLWIyZTktOGYyODEwM2IzNWUzIiwiaXNzIjoibG9jYWxob3N0IiwiYXVkIjpbInNwYWNldGltZWRiIl0sImlhdCI6MTc1NjkwOTQ3NywiZXhwIjpudWxsfQ.t6Aobx9fTe6kwvq7H01-2RO7vdK4SjQB7Uw-Lh4Daz0lG43WzIw3oVG_65txqlsFSkpx40wYElByj4jMolutpA" [[server_configs]] nickname = "localhost" diff --git a/smoketests/tests/confirmed_reads.py b/smoketests/tests/confirmed_reads.py new file mode 100644 index 00000000000..4d8a844c4bf --- /dev/null +++ b/smoketests/tests/confirmed_reads.py @@ -0,0 +1,52 @@ +from .. import Smoketest, parse_sql_result + +# +# TODO: We only test that we can pass a --confirmed flag and that things +# appear to works as if we hadn't. Without controlling the server, we can't +# test that there is any difference in behavior. +# + +class ConfirmedReads(Smoketest): + def test_confirmed_reads_receive_updates(self): + """Tests that subscribing with confirmed=true receives updates""" + + sub = self.subscribe("select * from person", n = 2, confirmed = True) + self.call("add", "Horst") + self.spacetime( + "sql", + self.database_identity, + "insert into person (name) values ('Egon')") + + events = sub() + self.assertEqual([ + { + 'person': { + 'deletes': [], + 'inserts': [{'name': 'Horst'}] + } + }, + { + 'person': { + 'deletes': [], + 'inserts': [{'name': 'Egon'}] + } + } + ], events) + +class ConfirmedReadsSql(Smoketest): + def test_sql_with_confirmed_reads_receives_result(self): + """Tests that an SQL operations with confirmed=true returns a result""" + + self.spacetime( + "sql", + "--confirmed", + self.database_identity, + "insert into person (name) values ('Horst')") + + res = self.spacetime( + "sql", + "--confirmed", + self.database_identity, + "select * from person") + res = parse_sql_result(str(res)) + self.assertEqual([{'name': '"Horst"'}], res) diff --git a/smoketests/tests/pg_wire.py b/smoketests/tests/pg_wire.py new file mode 100644 index 00000000000..89c7880c33c --- /dev/null +++ b/smoketests/tests/pg_wire.py @@ -0,0 +1,293 @@ +from .. import Smoketest +import subprocess +import os +import tomllib +import psycopg2 + + +def psql(identity: str, sql: str, extra=None) -> str: + """Call `psql` and execute the given SQL statement.""" + if extra is None: + extra = dict() + result = subprocess.run( + ["psql", "-h", "127.0.0.1", "-p", "5432", "-U", "postgres", "-d", "quickstart", "--quiet", "-c", sql], + encoding="utf8", + env={**os.environ, **extra, "PGPASSWORD": identity}, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + + if result.stderr: + raise Exception(result.stderr.strip()) + return result.stdout.strip() + + +def connect_db(identity: str): + """Connect to the database using `psycopg2`.""" + conn = psycopg2.connect(host="127.0.0.1", port=5432, user="postgres", password=identity, dbname="quickstart") + conn.set_session(autocommit=True) # Disable automic transaction + return conn + + +class SqlFormat(Smoketest): + AUTOPUBLISH = False + MODULE_CODE = """ +use spacetimedb::sats::{i256, u256}; +use spacetimedb::{ConnectionId, Identity, ReducerContext, SpacetimeType, Table, Timestamp, TimeDuration}; + +#[derive(Copy, Clone)] +#[spacetimedb::table(name = t_ints, public)] +pub struct TInts { + i8: i8, + i16: i16, + i32: i32, + i64: i64, + i128: i128, + i256: i256, +} + +#[spacetimedb::table(name = t_ints_tuple, public)] +pub struct TIntsTuple { + tuple: TInts, +} + +#[derive(Copy, Clone)] +#[spacetimedb::table(name = t_uints, public)] +pub struct TUints { + u8: u8, + u16: u16, + u32: u32, + u64: u64, + u128: u128, + u256: u256, +} + +#[spacetimedb::table(name = t_uints_tuple, public)] +pub struct TUintsTuple { + tuple: TUints, +} + +#[derive(Clone)] +#[spacetimedb::table(name = t_others, public)] +pub struct TOthers { + bool: bool, + f32: f32, + f64: f64, + str: String, + bytes: Vec, + identity: Identity, + connection_id: ConnectionId, + timestamp: Timestamp, + duration: TimeDuration, +} + +#[spacetimedb::table(name = t_others_tuple, public)] +pub struct TOthersTuple { + tuple: TOthers +} + +#[derive(SpacetimeType, Debug, Clone, Copy)] +pub enum Action { + Inactive, + Active, +} + +#[derive(SpacetimeType, Debug, Clone, Copy)] +pub enum Color { + Gray(u8), +} + +#[derive(Copy, Clone)] +#[spacetimedb::table(name = t_simple_enum, public)] +pub struct TSimpleEnum { + id : u32, + action: Action, +} + +#[spacetimedb::table(name = t_enum, public)] +pub struct TEnum { + id : u32, + color: Color, +} + +#[spacetimedb::table(name = t_nested, public)] +pub struct TNested { + en: TEnum, + se: TSimpleEnum, + ints: TInts, +} + +#[spacetimedb::reducer] +pub fn test(ctx: &ReducerContext) { + let tuple = TInts { + i8: -25, + i16: -3224, + i32: -23443, + i64: -2344353, + i128: -234434897853, + i256: (-234434897853i128).into(), + }; + let ints = tuple; + ctx.db.t_ints().insert(tuple); + ctx.db.t_ints_tuple().insert(TIntsTuple { tuple }); + + let tuple = TUints { + u8: 105, + u16: 1050, + u32: 83892, + u64: 48937498, + u128: 4378528978889, + u256: 4378528978889u128.into(), + }; + ctx.db.t_uints().insert(tuple); + ctx.db.t_uints_tuple().insert(TUintsTuple { tuple }); + + let tuple = TOthers { + bool: true, + f32: 594806.58906, + f64: -3454353.345389043278459, + str: "This is spacetimedb".to_string(), + bytes: vec!(1, 2, 3, 4, 5, 6, 7), + identity: Identity::ONE, + connection_id: ConnectionId::ZERO, + timestamp: Timestamp::UNIX_EPOCH, + duration: TimeDuration::from_micros(1000 * 10000), + }; + ctx.db.t_others().insert(tuple.clone()); + ctx.db.t_others_tuple().insert(TOthersTuple { tuple }); + + ctx.db.t_simple_enum().insert(TSimpleEnum { id: 1, action: Action::Inactive }); + ctx.db.t_simple_enum().insert(TSimpleEnum { id: 2, action: Action::Active }); + + ctx.db.t_enum().insert(TEnum { id: 1, color: Color::Gray(128) }); + + ctx.db.t_nested().insert(TNested { + en: TEnum { id: 1, color: Color::Gray(128) }, + se: TSimpleEnum { id: 2, action: Action::Active }, + ints, + }); +} +""" + + def assertSql(self, token: str, sql: str, expected): + self.maxDiff = None + sql_out = psql(token, sql) + sql_out = "\n".join([line.rstrip() for line in sql_out.splitlines()]) + expected = "\n".join([line.rstrip() for line in expected.splitlines()]) + print(sql_out) + self.assertMultiLineEqual(sql_out, expected) + + def read_token(self): + """Read the token from the config file.""" + with open(self.config_path, "rb") as f: + config = tomllib.load(f) + return config['spacetimedb_token'] + + def test_sql_format(self): + """This test is designed to test calling `psql` to execute SQL statements""" + token = self.read_token() + self.publish_module("quickstart", clear=True) + + self.call("test") + + self.assertSql(token, "SELECT * FROM t_ints", """\ +i8 | i16 | i32 | i64 | i128 | i256 +-----+-------+--------+----------+---------------+--------------- + -25 | -3224 | -23443 | -2344353 | -234434897853 | -234434897853 +(1 row)""") + self.assertSql(token, "SELECT * FROM t_ints_tuple", """\ +tuple +--------------------------------------------------------------------------------------------------------- + {"i8": -25, "i16": -3224, "i32": -23443, "i64": -2344353, "i128": -234434897853, "i256": -234434897853} +(1 row)""") + self.assertSql(token, "SELECT * FROM t_uints", """\ +u8 | u16 | u32 | u64 | u128 | u256 +-----+------+-------+----------+---------------+--------------- + 105 | 1050 | 83892 | 48937498 | 4378528978889 | 4378528978889 +(1 row)""") + self.assertSql(token, "SELECT * FROM t_uints_tuple", """\ +tuple +------------------------------------------------------------------------------------------------------- + {"u8": 105, "u16": 1050, "u32": 83892, "u64": 48937498, "u128": 4378528978889, "u256": 4378528978889} +(1 row)""") + self.assertSql(token, "SELECT * FROM t_others", """\ +bool | f32 | f64 | str | bytes | identity | connection_id | timestamp | duration +------+-----------+---------------------+---------------------+------------------+--------------------------------------------------------------------+------------------------------------+---------------------------+---------- + t | 594806.56 | -3454353.3453890434 | This is spacetimedb | \\x01020304050607 | \\x0000000000000000000000000000000000000000000000000000000000000001 | \\x00000000000000000000000000000000 | 1970-01-01T00:00:00+00:00 | PT10S +(1 row)""") + self.assertSql(token, "SELECT * FROM t_others_tuple", """\ +tuple +--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + {"bool": true, "f32": 594806.56, "f64": -3454353.3453890434, "str": "This is spacetimedb", "bytes": "0x01020304050607", "identity": "0x0000000000000000000000000000000000000000000000000000000000000001", "connection_id": "0x00000000000000000000000000000000", "timestamp": "1970-01-01T00:00:00+00:00", "duration": "PT10S"} +(1 row)""") + self.assertSql(token, "SELECT * FROM t_simple_enum", """\ +id | action +----+---------- + 1 | Inactive + 2 | Active +(2 rows)""") + self.assertSql(token, "SELECT * FROM t_enum", """\ +id | color +----+--------------- + 1 | {"Gray": 128} +(1 row)""") + self.assertSql(token, "SELECT * FROM t_nested", """\ +en | se | ints +-----------------------------------+-------------------------------------+--------------------------------------------------------------------------------------------------------- + {"id": 1, "color": {"Gray": 128}} | {"id": 2, "action": {"Active": {}}} | {"i8": -25, "i16": -3224, "i32": -23443, "i64": -2344353, "i128": -234434897853, "i256": -234434897853} +(1 row)""") + + def test_sql_conn(self): + """This test is designed to test connecting to the database and executing queries using `psycopg2`""" + token = self.read_token() + self.publish_module("quickstart", clear=True) + self.call("test") + + conn = connect_db(token) + # Check prepared statements (faked by `psycopg2`) + with conn.cursor() as cur: + cur.execute("select * from t_uints where u8 = %s and u16 = %s", (105, 1050)) + rows = cur.fetchall() + self.assertEqual(rows[0], (105, 1050, 83892, 48937498, 4378528978889, 4378528978889)) + # Check long-lived connection + with conn.cursor() as cur: + for _ in range(10): + cur.execute("select count(*) as t from t_uints") + rows = cur.fetchall() + self.assertEqual(rows[0], (1,)) + conn.close() + + def test_failures(self): + """This test is designed to test failure cases""" + token = self.read_token() + self.publish_module("quickstart", clear=True) + + # Empty query + sql_out = psql(token, "") + self.assertEqual(sql_out, "") + + # Connection fails when `ssl` is required + for ssl_mode in ["require", "verify-ca", "verify-full"]: + with self.assertRaises(Exception) as cm: + psql(token, "SELECT * FROM t_uints", extra={"PGSSLMODE": ssl_mode}) + self.assertIn("not support SSL", str(cm.exception)) + + # But works with `ssl` is disabled or optional + for ssl_mode in ["disable", "allow", "prefer"]: + psql(token, "SELECT * FROM t_uints", extra={"PGSSLMODE": ssl_mode}) + + # Connection fails with invalid token + with self.assertRaises(Exception) as cm: + psql("invalid_token", "SELECT * FROM t_uints") + self.assertIn("Invalid token", str(cm.exception)) + + # Returns error for unsupported `sql` statements + with self.assertRaises(Exception) as cm: + psql(token, "SELECT CASE a WHEN 1 THEN 'one' ELSE 'other' END FROM t_uints") + self.assertIn("Unsupported", str(cm.exception)) + + # And prepared statements + with self.assertRaises(Exception) as cm: + psql(token, "SELECT * FROM t_uints where u8 = $1") + self.assertIn("Unsupported", str(cm.exception)) diff --git a/smoketests/tests/replication.py b/smoketests/tests/replication.py index eb988078d84..1936a6f0c42 100644 --- a/smoketests/tests/replication.py +++ b/smoketests/tests/replication.py @@ -1,4 +1,4 @@ -from .. import COMPOSE_FILE, Smoketest, requires_docker, spacetime +from .. import COMPOSE_FILE, Smoketest, requires_docker, spacetime, parse_sql_result from ..docker import DockerManager import time @@ -18,17 +18,6 @@ def retry(func: Callable, max_retries: int = 3, retry_delay: int = 2): print("Max retries reached. Skipping the exception.") return False -def parse_sql_result(res: str) -> list[dict]: - """Parse tabular output from an SQL query into a list of dicts.""" - lines = res.splitlines() - headers = lines[0].split('|') if '|' in lines[0] else [lines[0]] - headers = [header.strip() for header in headers] - rows = [] - for row in lines[2:]: - cols = [col.strip() for col in row.split('|')] - rows.append(dict(zip(headers, cols))) - return rows - def int_vals(rows: list[dict]) -> list[dict]: """For all dicts in list, cast all values in dict to int.""" return [{k: int(v) for k, v in row.items()} for row in rows] From 63f35a70f251cd1a7acd209819a52f0b38aff516 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Mon, 15 Sep 2025 10:43:44 -0400 Subject: [PATCH 14/17] Fix inserted line. --- .../core/src/subscription/module_subscription_actor.rs | 9 +++++++-- .../core/src/subscription/module_subscription_manager.rs | 1 - crates/datastore/src/locking_tx_datastore/mut_tx.rs | 1 - 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 63b56d54c07..4df0ff9306c 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -1,5 +1,8 @@ use super::execution_unit::QueryHash; -use super::module_subscription_manager::{from_tx_offset, spawn_send_worker, BroadcastError, BroadcastQueue, Plan, SubscriptionGaugeStats, SubscriptionManager, TransactionOffset}; +use super::module_subscription_manager::{ + from_tx_offset, spawn_send_worker, BroadcastError, BroadcastQueue, Plan, SubscriptionGaugeStats, + SubscriptionManager, TransactionOffset, +}; use super::query::compile_query_with_hashes; use super::tx::DeltaTx; use super::{collect_table_update, TableUpdateType}; @@ -897,7 +900,9 @@ impl ModuleSubscriptions { database_update: SubscriptionUpdateMessage::default_for_protocol(client.config.protocol, None), }; - let _ = self.broadcast_queue.send_client_message(client, Some(from_tx_offset(tx_offset)), message); + let _ = self + .broadcast_queue + .send_client_message(client, Some(from_tx_offset(tx_offset)), message); } else { log::trace!("Reducer failed but there is no client to send the failure to!") } diff --git a/crates/core/src/subscription/module_subscription_manager.rs b/crates/core/src/subscription/module_subscription_manager.rs index e2fbd3b375a..eab10f22aa9 100644 --- a/crates/core/src/subscription/module_subscription_manager.rs +++ b/crates/core/src/subscription/module_subscription_manager.rs @@ -568,7 +568,6 @@ pub fn from_tx_offset(offset: TxOffset) -> TransactionOffset { rx } - /// Message sent by the [`SubscriptionManager`] to the [`SendWorker`]. #[derive(Debug)] enum SendWorkerMessage { diff --git a/crates/datastore/src/locking_tx_datastore/mut_tx.rs b/crates/datastore/src/locking_tx_datastore/mut_tx.rs index 7645fed23f6..d9762dc27f1 100644 --- a/crates/datastore/src/locking_tx_datastore/mut_tx.rs +++ b/crates/datastore/src/locking_tx_datastore/mut_tx.rs @@ -1197,7 +1197,6 @@ impl MutTxId { /// - [`TxMetrics`], various measurements of the work performed by this transaction. /// - `String`, the name of the reducer which ran during this transaction. pub(super) fn commit(mut self) -> (TxOffset, TxData, TxMetrics, String) { - self.committed_state_write_lock.next_tx_offset += 1; let tx_offset = self.committed_state_write_lock.next_tx_offset; let tx_data = self.committed_state_write_lock.merge(self.tx_state, &self.ctx); From 7bdabc02228c60e85a3731ab672bf4a1a5784dba Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Tue, 16 Sep 2025 11:36:01 -0400 Subject: [PATCH 15/17] Update comment --- crates/core/src/host/module_host.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 31155712085..c1dda1d1e22 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -764,7 +764,8 @@ impl ModuleHost { // TODO: Is this being broadcast? Does it need to be, or are st_client table subscriptions // not allowed? - // I don't think it was being broadcast previously. + // I (jsdt) don't think it was being broadcast previously. See: + // https://github.com/clockworklabs/SpacetimeDB/issues/3130 stdb.finish_tx(ScopeGuard::into_inner(mut_tx), Ok(())) .map_err(|e: DBError| { log::error!("`call_identity_connected`: finish transaction failed: {e:#?}"); From fc7cd3d5596f62ede4d79cff6e0209dd0ec23ac7 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Thu, 18 Sep 2025 15:43:00 -0400 Subject: [PATCH 16/17] Remove commented out code. --- .../datastore/src/locking_tx_datastore/mut_tx.rs | 15 --------------- crates/datastore/src/system_tables.rs | 1 - 2 files changed, 16 deletions(-) diff --git a/crates/datastore/src/locking_tx_datastore/mut_tx.rs b/crates/datastore/src/locking_tx_datastore/mut_tx.rs index d9762dc27f1..4283b3e8aec 100644 --- a/crates/datastore/src/locking_tx_datastore/mut_tx.rs +++ b/crates/datastore/src/locking_tx_datastore/mut_tx.rs @@ -508,11 +508,6 @@ impl MutTxId { /// - The index metadata is inserted into the system tables (and other data structures reflecting them). /// - The returned ID is unique and is not `IndexId::SENTINEL`. pub fn create_index(&mut self, mut index_schema: IndexSchema, is_unique: bool) -> Result { - /* - if index_schema.index_id != IndexId::SENTINEL { - return Err(anyhow::anyhow!("`index_id` must be `IndexId::SENTINEL` in `{:#?}`", index_schema).into()); - } - */ let table_id = index_schema.table_id; if table_id == TableId::SENTINEL { return Err(anyhow::anyhow!("`table_id` must not be `TableId::SENTINEL` in `{:#?}`", index_schema).into()); @@ -975,16 +970,6 @@ impl MutTxId { /// - The constraint metadata is inserted into the system tables (and other data structures reflecting them). /// - The returned ID is unique and is not `constraintId::SENTINEL`. fn create_constraint(&mut self, mut constraint: ConstraintSchema) -> Result { - /* - if constraint.constraint_id != ConstraintId::SENTINEL { - return Err(anyhow::anyhow!( - "`constraint_id` must be `ConstraintId::SENTINEL` in `{:#?}`", - constraint - ) - .into()); - } - - */ if constraint.table_id == TableId::SENTINEL { return Err(anyhow::anyhow!("`table_id` must not be `TableId::SENTINEL` in `{:#?}`", constraint).into()); } diff --git a/crates/datastore/src/system_tables.rs b/crates/datastore/src/system_tables.rs index e7741c70d4d..1ecd69df51f 100644 --- a/crates/datastore/src/system_tables.rs +++ b/crates/datastore/src/system_tables.rs @@ -355,7 +355,6 @@ fn system_module_def() -> ModuleDef { // TODO: add empty unique constraint here, once we've implemented those. let st_connection_credentials_type = builder.add_type::(); - // let st_connection_credentials_unique_cols = [StConnectionCredentialsFields::ConnectionId]; builder .build_table( ST_CONNECTION_CREDENTIALS_NAME, From c80ac15afd61c10bd290e6662d26d21379e12438 Mon Sep 17 00:00:00 2001 From: Jeffrey Dallatezza Date: Fri, 19 Sep 2025 10:26:56 -0400 Subject: [PATCH 17/17] Add helper function --- crates/client-api/src/auth.rs | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/crates/client-api/src/auth.rs b/crates/client-api/src/auth.rs index f017ae4d55d..b9f091e6f08 100644 --- a/crates/client-api/src/auth.rs +++ b/crates/client-api/src/auth.rs @@ -89,6 +89,19 @@ pub struct SpacetimeAuth { pub jwt_payload: String, } +impl SpacetimeAuth { + pub fn new(creds: SpacetimeCreds, claims: SpacetimeIdentityClaims) -> Result { + let payload = creds + .extract_jwt_payload_string() + .ok_or_else(|| anyhow!("Failed to extract JWT payload"))?; + Ok(Self { + creds, + claims, + jwt_payload: payload, + }) + } +} + impl From for ConnectionAuthCtx { fn from(auth: SpacetimeAuth) -> Self { ConnectionAuthCtx { @@ -176,15 +189,7 @@ impl SpacetimeAuth { let (claims, token) = claims.encode_and_sign(ctx.jwt_auth_provider()).map_err(log_and_500)?; let creds = SpacetimeCreds::from_signed_token(token); // Pulling out the payload should never fail, since we just made it. - let payload = creds - .extract_jwt_payload_string() - .ok_or_else(|| log_and_500("internal error"))?; - - Ok(Self { - creds, - claims, - jwt_payload: payload, - }) + Self::new(creds, claims).map_err(log_and_500) } /// Get the auth credentials as headers to be returned from an endpoint.