diff --git a/Cargo.lock b/Cargo.lock index ee3283c920..db4f2aef1a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3807,6 +3807,7 @@ dependencies = [ "cli-table", "colored", "console-subscriber", + "futures", "golem-api-grpc", "golem-common", "golem-service-base", @@ -3821,6 +3822,7 @@ dependencies = [ "once_cell", "postgres", "redis", + "scylla", "serde 1.0.210", "serde_json", "serde_yaml", @@ -4020,6 +4022,7 @@ dependencies = [ "redis", "ringbuf", "rustls 0.23.14", + "scylla", "serde 1.0.210", "serde_json", "sqlx", @@ -4374,6 +4377,12 @@ version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" +[[package]] +name = "histogram" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12cb882ccb290b8646e554b157ab0b71e64e8d5bef775cd66b6531e52d302669" + [[package]] name = "hkdf" version = "0.12.4" @@ -5378,6 +5387,15 @@ dependencies = [ "hashbrown 0.15.0", ] +[[package]] +name = "lz4_flex" +version = "0.11.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75761162ae2b0e580d7e7c390558127e5f01b4194debd6221fd8c207fc80e3f5" +dependencies = [ + "twox-hash", +] + [[package]] name = "mach2" version = "0.4.2" @@ -7038,6 +7056,15 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rand_pcg" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59cad018caf63deb318e5a4586d99a24424a364f40f1e5778c29aca23f4fc73e" +dependencies = [ + "rand_core", +] + [[package]] name = "rand_xorshift" version = "0.3.0" @@ -7620,6 +7647,66 @@ dependencies = [ "untrusted", ] +[[package]] +name = "scylla" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8139623d3fb0c8205b15e84fa587f3aa0ba61f876c19a9157b688f7c1763a7c5" +dependencies = [ + "arc-swap", + "async-trait", + "byteorder", + "bytes 1.7.2", + "chrono", + "dashmap", + "futures", + "hashbrown 0.14.5", + "histogram", + "itertools 0.13.0", + "lazy_static 1.5.0", + "lz4_flex", + "rand", + "rand_pcg", + "scylla-cql", + "scylla-macros", + "smallvec", + "snap", + "socket2 0.5.7", + "thiserror", + "tokio", + "tracing", + "uuid", +] + +[[package]] +name = "scylla-cql" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de7020bcd1f6fdbeaed356cd426bf294b2071bd7120d48d2e8e319295e2acdcd" +dependencies = [ + "async-trait", + "byteorder", + "bytes 1.7.2", + "lz4_flex", + "scylla-macros", + "snap", + "thiserror", + "tokio", + "uuid", +] + +[[package]] +name = "scylla-macros" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3859b6938663fc5062e3b26f3611649c9bd26fb252e85f6fdfa581e0d2ce74b6" +dependencies = [ + "darling 0.20.10", + "proc-macro2", + "quote", + "syn 2.0.79", +] + [[package]] name = "sec1" version = "0.3.0" @@ -8139,6 +8226,12 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7c388c1b5e93756d0c740965c41e8822f866621d41acbdf6336a6a168f8840c" +[[package]] +name = "snap" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b6b67fb9a61334225b5b790716f609cd58395f895b3fe8b328786812a40bc3b" + [[package]] name = "socket2" version = "0.4.10" @@ -9303,6 +9396,16 @@ dependencies = [ "utf-8", ] +[[package]] +name = "twox-hash" +version = "1.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" +dependencies = [ + "cfg-if", + "static_assertions", +] + [[package]] name = "typenum" version = "1.17.0" diff --git a/Cargo.toml b/Cargo.toml index 9a7cec8725..409e5007b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -200,6 +200,7 @@ wasmtime = { version = "=21.0.1", features = ["component-model"] } wasmtime-wasi = { version = "=21.0.1" } wasmtime-wasi-http = { version = "=21.0.1" } webpki-roots = { version = "0.26.0" } +scylla = "0.14.0" [patch.crates-io] wasmtime = { git = "https://github.com/golemcloud/wasmtime.git", branch = "golem-wasmtime-v21.0.1" } diff --git a/golem-common/src/config.rs b/golem-common/src/config.rs index 580c04a5d7..30022a6b65 100644 --- a/golem-common/src/config.rs +++ b/golem-common/src/config.rs @@ -17,6 +17,7 @@ use figment::providers::{Env, Format, Serialized, Toml}; use figment::value::Value; use figment::Figment; use serde::{Deserialize, Serialize}; +use std::net::SocketAddr; use std::path::{Path, PathBuf}; use std::time::Duration; use url::Url; @@ -467,3 +468,31 @@ pub struct DbPostgresConfig { pub max_connections: u32, pub schema: Option, } + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct CassandraConfig { + pub nodes: Vec, + #[serde(default = "default_cassandra_keyspace")] + pub keyspace: String, + pub tracing: bool, + pub pool_size_per_host: usize, + pub username: Option, + pub password: Option, +} + +fn default_cassandra_keyspace() -> String { + String::from("__golem") +} + +impl Default for CassandraConfig { + fn default() -> Self { + Self { + nodes: vec!["127.0.0.1:9042".parse().unwrap()], + keyspace: default_cassandra_keyspace(), + tracing: false, + pool_size_per_host: 3, + username: None, + password: None, + } + } +} diff --git a/golem-test-framework/Cargo.toml b/golem-test-framework/Cargo.toml index ba62399cba..48e402583c 100644 --- a/golem-test-framework/Cargo.toml +++ b/golem-test-framework/Cargo.toml @@ -27,6 +27,7 @@ cli-table = { workspace = true } chrono = { workspace = true } colored = "2.1.0" console-subscriber = { workspace = true } +futures = { workspace = true } itertools = { workspace = true } k8s-openapi = { workspace = true } kill_tree = { version = "0.2.4", features = ["tokio"] } @@ -35,6 +36,7 @@ kube-derive = { workspace = true } once_cell = { workspace = true } postgres = { workspace = true } redis = { workspace = true } +scylla = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } serde_yaml = { workspace = true } diff --git a/golem-test-framework/src/components/cassandra/docker.rs b/golem-test-framework/src/components/cassandra/docker.rs new file mode 100644 index 0000000000..57e19eb1ab --- /dev/null +++ b/golem-test-framework/src/components/cassandra/docker.rs @@ -0,0 +1,82 @@ +// Copyright 2024 Golem Cloud +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use super::Cassandra; +use crate::components::{DOCKER, NETWORK}; +use std::sync::atomic::{AtomicBool, Ordering}; +use testcontainers::{Container, GenericImage, RunnableImage}; +use tracing::info; + +pub struct DockerCassandra { + container: Container<'static, GenericImage>, + keep_container: bool, + valid: AtomicBool, + public_port: u16, +} + +impl DockerCassandra { + const NAME: &'static str = "golem_cassandra"; + + pub fn new(keep_container: bool) -> Self { + let image = GenericImage::new("cassandra", "latest") + .with_exposed_port(super::DEFAULT_PORT) + .with_wait_for(testcontainers::core::WaitFor::message_on_stdout( + "Starting listening for CQL clients on", + )); + let cassandra_image: RunnableImage<_> = RunnableImage::from(image) + .with_container_name(Self::NAME) + .with_network(NETWORK); + + let container = DOCKER.run(cassandra_image); + let public_port: u16 = container.get_host_port_ipv4(super::DEFAULT_PORT); + + DockerCassandra { + container, + keep_container, + valid: AtomicBool::new(true), + public_port, + } + } +} + +impl Cassandra for DockerCassandra { + fn assert_valid(&self) { + if !self.valid.load(Ordering::Acquire) { + std::panic!("Cassandra has been closed") + } + } + + fn private_known_nodes(&self) -> Vec { + vec![format!("{}:{}", Self::NAME, super::DEFAULT_PORT)] + } + + fn kill(&self) { + info!("Stopping Cassandra container"); + if self.keep_container { + self.container.stop() + } else { + self.container.rm() + } + } + + fn public_known_nodes(&self) -> Vec { + vec![format!("localhost:{}", self.public_port)] + } +} + +impl Drop for DockerCassandra { + fn drop(&mut self) { + self.kill() + } +} diff --git a/golem-test-framework/src/components/cassandra/mod.rs b/golem-test-framework/src/components/cassandra/mod.rs new file mode 100644 index 0000000000..2ff7877f3b --- /dev/null +++ b/golem-test-framework/src/components/cassandra/mod.rs @@ -0,0 +1,71 @@ +// Copyright 2024 Golem Cloud +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use scylla::{transport::session::PoolSize, Session, SessionBuilder}; +use std::{num::NonZeroUsize, sync::Arc}; +use tonic::async_trait; + +pub mod docker; + +#[async_trait] +pub trait Cassandra { + fn assert_valid(&self); + + fn private_known_nodes(&self) -> Vec; + + fn public_known_nodes(&self) -> Vec; + + fn kill(&self); + + async fn try_get_session(&self, keyspace: Option<&str>) -> Result, String> { + let mut session_builder = SessionBuilder::new() + .known_nodes(self.public_known_nodes()) + .pool_size(PoolSize::PerHost(NonZeroUsize::new(10).unwrap())); + + if let Some(keyspace) = keyspace { + session_builder = session_builder.use_keyspace(keyspace, false) + }; + + let session = session_builder + .build() + .await + .map_err(|e| e.to_string()) + .unwrap(); + + Ok(Arc::new(session)) + } + + async fn get_session(&self, keyspace: Option<&str>) -> Arc { + self.assert_valid(); + self.try_get_session(keyspace).await.unwrap() + } + + async fn flush_keyspace(&self, keyspace: &str) { + let session = self.get_session(Some(keyspace)).await; + session + .query_unpaged(format!("TRUNCATE {}.{};", keyspace, "kv_store"), &[]) + .await + .unwrap(); + session + .query_unpaged(format!("TRUNCATE {}.{};", keyspace, "kv_sets"), &[]) + .await + .unwrap(); + session + .query_unpaged(format!("TRUNCATE {}.{};", keyspace, "kv_sorted_sets"), &[]) + .await + .unwrap(); + } +} + +const DEFAULT_PORT: u16 = 9042; diff --git a/golem-test-framework/src/components/mod.rs b/golem-test-framework/src/components/mod.rs index 95aee27c2a..db9054ea7e 100644 --- a/golem-test-framework/src/components/mod.rs +++ b/golem-test-framework/src/components/mod.rs @@ -26,6 +26,7 @@ use tracing::{error, warn, Level}; use golem_api_grpc::proto::grpc::health::v1::health_check_response::ServingStatus; use golem_api_grpc::proto::grpc::health::v1::HealthCheckRequest; +pub mod cassandra; pub mod component_compilation_service; pub mod component_service; mod docker; diff --git a/golem-test-framework/src/config/cli.rs b/golem-test-framework/src/config/cli.rs index e512534952..9effa9c4c3 100644 --- a/golem-test-framework/src/config/cli.rs +++ b/golem-test-framework/src/config/cli.rs @@ -1085,6 +1085,12 @@ impl TestDependencies for CliTestDependencies { fn worker_executor_cluster(&self) -> Arc { self.worker_executor_cluster.clone() } + + fn cassandra( + &self, + ) -> Option> { + None + } } #[allow(dead_code)] diff --git a/golem-test-framework/src/config/env.rs b/golem-test-framework/src/config/env.rs index 86649818ee..302605e4a3 100644 --- a/golem-test-framework/src/config/env.rs +++ b/golem-test-framework/src/config/env.rs @@ -13,9 +13,10 @@ // limitations under the License. use std::path::{Path, PathBuf}; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use crate::components; +use crate::components::cassandra::Cassandra; use crate::components::component_compilation_service::docker::DockerComponentCompilationService; use crate::components::component_compilation_service::spawned::SpawnedComponentCompilationService; use crate::components::component_compilation_service::ComponentCompilationService; @@ -79,7 +80,7 @@ impl EnvBasedTestDependenciesConfig { self.keep_docker_containers = keep_docker_containers } - if let Some(redis_port) = opt_env_var("REDIS_KEY_PREFIX") { + if let Some(redis_port) = opt_env_var("REDIS_PORT") { self.redis_port = redis_port.parse().expect("Failed to parse REDIS_PORT"); } @@ -151,6 +152,7 @@ pub struct EnvBasedTestDependencies { component_compilation_service: Arc, worker_service: Arc, worker_executor_cluster: Arc, + cassandra: RwLock>>, } impl EnvBasedTestDependencies { @@ -406,6 +408,12 @@ impl EnvBasedTestDependencies { } } + fn make_cassandra( + _config: Arc, + ) -> RwLock>> { + RwLock::new(None) + } + pub async fn new(config: EnvBasedTestDependenciesConfig) -> Self { let config = Arc::new(config); @@ -461,6 +469,8 @@ impl EnvBasedTestDependencies { let redis_monitor = redis_monitor_join.await.expect("Failed to join"); + let cassandra = Self::make_cassandra(config.clone()); + Self { config: config.clone(), rdb, @@ -471,6 +481,7 @@ impl EnvBasedTestDependencies { component_compilation_service, worker_service, worker_executor_cluster, + cassandra, } } } @@ -513,6 +524,13 @@ impl TestDependencies for EnvBasedTestDependencies { fn worker_executor_cluster(&self) -> Arc { self.worker_executor_cluster.clone() } + + fn cassandra(&self) -> Option> { + match self.cassandra.read() { + Ok(cassandra) => cassandra.as_ref().map(|c| c.clone()), + Err(poison_err) => poison_err.into_inner().as_ref().map(|c| c.clone()), + } + } } fn opt_env_var(name: &str) -> Option { diff --git a/golem-test-framework/src/config/mod.rs b/golem-test-framework/src/config/mod.rs index c1f5b116f6..16c7bdeef1 100644 --- a/golem-test-framework/src/config/mod.rs +++ b/golem-test-framework/src/config/mod.rs @@ -15,6 +15,7 @@ use std::path::PathBuf; use std::sync::Arc; +use crate::components::cassandra::Cassandra; use crate::components::component_compilation_service::ComponentCompilationService; pub use cli::{CliParams, CliTestDependencies, CliTestService}; pub use env::EnvBasedTestDependencies; @@ -44,6 +45,7 @@ pub trait TestDependencies { ) -> Arc; fn worker_service(&self) -> Arc; fn worker_executor_cluster(&self) -> Arc; + fn cassandra(&self) -> Option>; fn kill_all(&self) { self.worker_executor_cluster().kill_all(); @@ -54,6 +56,9 @@ pub trait TestDependencies { self.rdb().kill(); self.redis_monitor().kill(); self.redis().kill(); + if let Some(c) = self.cassandra() { + c.kill() + } } } diff --git a/golem-worker-executor-base/Cargo.toml b/golem-worker-executor-base/Cargo.toml index 66fc8ba6eb..2b8b471fb8 100644 --- a/golem-worker-executor-base/Cargo.toml +++ b/golem-worker-executor-base/Cargo.toml @@ -87,6 +87,7 @@ wasmtime-wasi-http = { workspace = true } windows-sys = "0.52.0" zstd = "0.13" sqlx = { workspace = true } +scylla = { workspace = true } [dev-dependencies] golem-test-framework = { path = "../golem-test-framework", version = "0.0.0" } diff --git a/golem-worker-executor-base/src/lib.rs b/golem-worker-executor-base/src/lib.rs index 1d5088c8ca..7bab48ed54 100644 --- a/golem-worker-executor-base/src/lib.rs +++ b/golem-worker-executor-base/src/lib.rs @@ -72,6 +72,8 @@ use humansize::{ISizeFormatter, BINARY}; use nonempty_collections::NEVec; use prometheus::Registry; use std::sync::Arc; +use storage::cassandra::CassandraSession; +use storage::keyvalue::cassandra::CassandraKeyValueStorage; use storage::keyvalue::sqlite::SqliteKeyValueStorage; use storage::sqlite_types::SqlitePool; use tokio::runtime::Handle; @@ -197,6 +199,18 @@ pub trait Bootstrap { Arc::new(SqliteKeyValueStorage::new(pool.clone())); (None, key_value_storage) } + KeyValueStorageConfig::Cassandra(cassandra) => { + info!( + "Using Cassandra for key-value storage at {:?}", + cassandra.nodes + ); + let session = CassandraSession::configured(cassandra) + .await + .map_err(|err| anyhow!(err))?; + let key_value_storage: Arc = + Arc::new(CassandraKeyValueStorage::new(session.clone())); + (None, key_value_storage) + } }; let indexed_storage: Arc = match &golem_config diff --git a/golem-worker-executor-base/src/services/golem_config.rs b/golem-worker-executor-base/src/services/golem_config.rs index db57856d89..894328f575 100644 --- a/golem-worker-executor-base/src/services/golem_config.rs +++ b/golem-worker-executor-base/src/services/golem_config.rs @@ -24,7 +24,8 @@ use serde::{Deserialize, Serialize}; use url::Url; use golem_common::config::{ - ConfigExample, ConfigLoader, DbSqliteConfig, HasConfigExamples, RedisConfig, RetryConfig, + CassandraConfig, ConfigExample, ConfigLoader, DbSqliteConfig, HasConfigExamples, RedisConfig, + RetryConfig, }; use golem_common::tracing::TracingConfig; @@ -253,6 +254,7 @@ pub struct OplogConfig { pub enum KeyValueStorageConfig { Redis(RedisConfig), Sqlite(DbSqliteConfig), + Cassandra(CassandraConfig), InMemory, } diff --git a/golem-worker-executor-base/src/storage/cassandra.rs b/golem-worker-executor-base/src/storage/cassandra.rs new file mode 100644 index 0000000000..7857f7e3ae --- /dev/null +++ b/golem-worker-executor-base/src/storage/cassandra.rs @@ -0,0 +1,220 @@ +// Copyright 2024 Golem Cloud +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use futures::StreamExt; +use golem_common::config::CassandraConfig; +use golem_common::metrics::db::{record_db_failure, record_db_success}; +use scylla::batch::{Batch, BatchType}; +use scylla::prepared_statement::PreparedStatement; +use scylla::serialize::row::SerializeRow; +use scylla::transport::errors::QueryError; +use scylla::FromRow; +use scylla::{transport::session::PoolSize, Session, SessionBuilder}; +use std::fmt::Debug; +use std::time::Instant; +use std::{num::NonZeroUsize, sync::Arc}; + +#[derive(Debug, Clone)] +pub struct CassandraSession { + pub session: Arc, + pub keyspace: String, + pub set_tracing: bool, +} + +impl CassandraSession { + pub fn new(session: Arc, set_tracing: bool, keyspace: &str) -> Self { + CassandraSession { + session, + keyspace: String::from(keyspace), + set_tracing, + } + } + + pub async fn configured(config: &CassandraConfig) -> Result { + let mut session_builder = SessionBuilder::new() + .known_nodes_addr(config.nodes.iter()) + .pool_size(PoolSize::PerHost( + NonZeroUsize::new(config.pool_size_per_host).unwrap(), + )) + .use_keyspace(&config.keyspace, false); + + if let (Some(username), Some(password)) = + (config.username.as_ref(), config.password.as_ref()) + { + session_builder = session_builder.user(username, password); + } + + let session = session_builder.build().await.map_err(|e| e.to_string())?; + + Ok(CassandraSession { + session: Arc::new(session), + keyspace: config.keyspace.clone(), + set_tracing: config.tracing, + }) + } + + pub fn with(&self, svc_name: &'static str, api_name: &'static str) -> CassandraLabelledApi { + CassandraLabelledApi { + svc_name, + api_name, + cassandra: self.clone(), + } + } +} + +pub struct CassandraLabelledApi { + svc_name: &'static str, + api_name: &'static str, + pub cassandra: CassandraSession, +} + +impl CassandraLabelledApi { + fn record( + &self, + start: Instant, + cmd_name: &'static str, + result: Result, + ) -> Result { + let end = Instant::now(); + match result { + Ok(result) => { + record_db_success( + "cassandra", + self.svc_name, + self.api_name, + cmd_name, + end.duration_since(start), + ); + Ok(result) + } + Err(err) => { + record_db_failure("cassandra", self.svc_name, self.api_name, cmd_name); + Err(err) + } + } + } + + async fn statement(&self, query_text: &str) -> PreparedStatement { + let mut statement = self.cassandra.session.prepare(query_text).await.unwrap(); + statement.set_tracing(self.cassandra.set_tracing); + statement + } + + pub async fn perform_query( + &self, + cmd_name: &'static str, + query: String, + values: impl SerializeRow, + ) -> Result<(), QueryError> { + let start = Instant::now(); + self.record( + start, + cmd_name, + self.cassandra + .session + .execute_unpaged(&self.statement(&query).await, values) + .await, + ) + .map(|_| ()) + } + + pub async fn perform_batch( + &self, + cmd_name: &'static str, + query: String, + values: Vec, + ) -> Result<(), QueryError> { + let mut batch: Batch = Batch::new(BatchType::Logged); + + let start = Instant::now(); + for _ in 1..=values.len() { + batch.append_statement(self.statement(&query).await); + } + + let mut batch: Batch = self.cassandra.session.prepare_batch(&batch).await?; + + batch.set_tracing(self.cassandra.set_tracing); + + self.record( + start, + cmd_name, + self.cassandra.session.batch(&batch, &values).await, + ) + .map(|_| ()) + } + + pub async fn maybe_row( + &self, + cmd_name: &'static str, + query: String, + values: impl SerializeRow, + map_to_value: F, + ) -> Result, QueryError> + where + RowT: FromRow, + T: Debug, + F: FnOnce(RowT) -> T, + { + let start = Instant::now(); + + self.record( + start, + cmd_name, + self.cassandra + .session + .execute_unpaged(&self.statement(&query).await, values) + .await? + .maybe_first_row_typed::() + .map_err(|e| QueryError::InvalidMessage(e.to_string())) + .map(|opt_row| opt_row.map(map_to_value)), + ) + } + + pub async fn get_rows( + &self, + cmd_name: &'static str, + query: String, + values: impl SerializeRow, + mut map_to_value: F, + ) -> Result, QueryError> + where + RowT: FromRow, + T: Debug, + F: FnMut(RowT) -> T, + { + let start = Instant::now(); + + let mut rows = self + .cassandra + .session + .execute_iter(self.statement(&query).await, &values) + .await? + .into_typed::(); + + let mut result = Vec::new(); + while let Some(row) = rows.next().await { + match row { + Ok(row) => result.push(map_to_value(row)), + Err(err) => { + return self.record( + start, + cmd_name, + Err(QueryError::InvalidMessage(err.to_string())), + ) + } + } + } + self.record(start, cmd_name, Ok(result)) + } +} diff --git a/golem-worker-executor-base/src/storage/keyvalue/cassandra.rs b/golem-worker-executor-base/src/storage/keyvalue/cassandra.rs new file mode 100644 index 0000000000..8c63db9ef4 --- /dev/null +++ b/golem-worker-executor-base/src/storage/keyvalue/cassandra.rs @@ -0,0 +1,610 @@ +// Copyright 2024 Golem Cloud +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use crate::storage::{ + cassandra::CassandraSession, + keyvalue::{KeyValueStorage, KeyValueStorageNamespace}, +}; +use async_trait::async_trait; +use bytes::Bytes; +use scylla::{ + prepared_statement::PreparedStatement, query::Query, serialize::row::SerializeRow, + transport::errors::QueryError, FromRow, +}; +use serde::Deserialize; +use std::fmt::Debug; +use std::{collections::HashMap, iter::repeat}; + +#[derive(Debug)] +pub struct CassandraKeyValueStorage { + session: CassandraSession, +} + +impl CassandraKeyValueStorage { + pub async fn create_schema(&self) -> Result<(), String> { + self.session.session.query_unpaged( + Query::new( + format!("CREATE KEYSPACE IF NOT EXISTS {} WITH REPLICATION = {{ 'class' : 'SimpleStrategy', 'replication_factor' : 1 }};", self.session.keyspace), + ), + &[], + ).await + .map_err(|e| e.to_string())?; + + self.session + .session + .query_unpaged( + Query::new(format!( + r#" + CREATE TABLE IF NOT EXISTS {}.kv_store ( + namespace TEXT, + key TEXT, + value BLOB, + PRIMARY KEY (namespace, key) + );"#, + self.session.keyspace + )), + &[], + ) + .await + .map_err(|e| e.to_string())?; + + self.session + .session + .query_unpaged( + Query::new(format!( + r#" + CREATE TABLE IF NOT EXISTS {}.kv_sets ( + namespace TEXT, + key TEXT, + value BLOB, + PRIMARY KEY ((namespace, key), value) + );"#, + self.session.keyspace + )), + &[], + ) + .await + .map_err(|e| e.to_string())?; + + self.session + .session + .query_unpaged( + Query::new(format!( + r#" + CREATE TABLE IF NOT EXISTS {}.kv_sorted_sets ( + namespace TEXT, + key TEXT, + score DOUBLE, + value BLOB, + PRIMARY KEY ((namespace, key), score, value) + );"#, + self.session.keyspace + )), + &[], + ) + .await + .map_err(|e| e.to_string()) + .map(|_| ()) + } + + pub fn new(session: CassandraSession) -> Self { + Self { session } + } + + fn to_string(&self, ns: KeyValueStorageNamespace) -> String { + match ns { + KeyValueStorageNamespace::Worker => "worker".to_string(), + KeyValueStorageNamespace::Promise => "promise".to_string(), + KeyValueStorageNamespace::Schedule => "schedule".to_string(), + KeyValueStorageNamespace::UserDefined { account_id, bucket } => { + format!("user-defined:{account_id}:{bucket}") + } + } + } + + async fn statement(&self, query_text: &str) -> PreparedStatement { + let mut statement = self.session.session.prepare(query_text).await.unwrap(); + statement.set_tracing(self.session.set_tracing); + statement + } + + async fn maybe_row( + &self, + query: String, + values: impl SerializeRow, + map_to_value: F, + ) -> Result, QueryError> + where + RowT: FromRow, + T: Debug, + F: FnOnce(RowT) -> T, + { + self.session + .session + .execute_unpaged(&self.statement(&query).await, values) + .await? + .maybe_first_row_typed::() + .map_err(|e| QueryError::InvalidMessage(e.to_string())) + .map(|opt_row| opt_row.map(map_to_value)) + } +} + +#[derive(FromRow, Debug, Deserialize)] +struct ValueRow { + value: Vec, +} + +impl ValueRow { + fn into_bytes(self) -> Bytes { + Bytes::from(self.value) + } +} + +#[derive(FromRow, Debug, Deserialize)] +struct KeyValueRow { + key: String, + value: Vec, +} + +impl KeyValueRow { + fn into_pair(self) -> (String, Bytes) { + (self.key, Bytes::from(self.value)) + } +} +#[derive(FromRow, Debug, Deserialize)] +struct ScoreValueRow { + score: f64, + value: Vec, +} + +impl ScoreValueRow { + fn into_pair(self) -> (f64, Bytes) { + (self.score, Bytes::from(self.value)) + } +} + +#[async_trait] +impl KeyValueStorage for CassandraKeyValueStorage { + async fn set( + &self, + svc_name: &'static str, + api_name: &'static str, + _entity_name: &'static str, + namespace: KeyValueStorageNamespace, + key: &str, + value: &[u8], + ) -> Result<(), String> { + let query = format!( + "INSERT INTO {}.kv_store (namespace, key, value) VALUES (?, ?, ?);", + self.session.keyspace.clone() + ); + self.session + .with(svc_name, api_name) + .perform_query("set", query, (&self.to_string(namespace), key, value)) + .await + .map_err(|e| e.to_string()) + } + + async fn set_many( + &self, + svc_name: &'static str, + api_name: &'static str, + _entity_name: &'static str, + namespace: KeyValueStorageNamespace, + pairs: &[(&str, &[u8])], + ) -> Result<(), String> { + let query = format!( + "INSERT INTO {}.kv_store (namespace, key, value) VALUES (?, ?, ?)", + self.session.keyspace + ); + let namespace = self.to_string(namespace); + let values = pairs + .iter() + .map(|(field_key, field_value)| (&namespace, *field_key, *field_value)) + .collect::>(); + + self.session + .with(svc_name, api_name) + .perform_batch("set_many", query, values) + .await + .map_err(|e| e.to_string()) + } + + async fn set_if_not_exists( + &self, + svc_name: &'static str, + api_name: &'static str, + _entity_name: &'static str, + namespace: KeyValueStorageNamespace, + key: &str, + value: &[u8], + ) -> Result { + let exists_query = format!( + "SELECT value FROM {}.kv_store WHERE namespace = ? AND key = ? LIMIT 1;", + self.session.keyspace + ); + let namespace = self.to_string(namespace); + let not_exists = self + .maybe_row(exists_query, (&namespace, key), |r: ValueRow| r) + .await + .map_or(true, |opt| opt.is_none()); + + let insert_query = format!( + "INSERT INTO {}.kv_store (namespace, key, value) VALUES (?, ?, ?) IF NOT EXISTS;", + self.session.keyspace + ); + + self.session + .with(svc_name, api_name) + .perform_query("set_if_not_exists", insert_query, (&namespace, key, value)) + .await + .map_err(|e| e.to_string()) + .map(|_| not_exists) + } + + async fn get( + &self, + svc_name: &'static str, + api_name: &'static str, + _entity_name: &'static str, + namespace: KeyValueStorageNamespace, + key: &str, + ) -> Result, String> { + let query = format!( + "SELECT value FROM {}.kv_store WHERE namespace = ? AND key = ?;", + self.session.keyspace + ); + + self.session + .with(svc_name, api_name) + .maybe_row( + "get", + query, + (self.to_string(namespace), key), + |row: ValueRow| row.into_bytes(), + ) + .await + .map_err(|e| e.to_string()) + } + + async fn get_many( + &self, + svc_name: &'static str, + api_name: &'static str, + _entity_name: &'static str, + namespace: KeyValueStorageNamespace, + keys: Vec, + ) -> Result>, String> { + let placeholders: String = repeat("?").take(keys.len()).collect::>().join(", "); + let query = format!( + "SELECT key, value FROM {}.kv_store WHERE namespace = ? AND key IN ({});", + self.session.keyspace, placeholders + ); + + let parameters: Vec = vec![self.to_string(namespace)] + .into_iter() + .chain(keys) + .collect(); + + let result = self + .session + .with(svc_name, api_name) + .get_rows("get_many", query, ¶meters, |row: KeyValueRow| { + row.into_pair() + }) + .await + .map_err(|e| e.to_string()) + .unwrap(); + + let mut result_map = result.into_iter().collect::>(); + + let keys = parameters[1..].to_vec(); + let values = keys + .into_iter() + .map(|key| result_map.remove(&key)) + .collect::>>(); + + Ok(values) + } + + async fn del( + &self, + svc_name: &'static str, + api_name: &'static str, + namespace: KeyValueStorageNamespace, + key: &str, + ) -> Result<(), String> { + let query = format!( + "DELETE FROM {}.kv_store WHERE namespace = ? AND key = ?;", + self.session.keyspace + ); + + self.session + .with(svc_name, api_name) + .perform_query("del", query, (&self.to_string(namespace), key)) + .await + .map_err(|e| e.to_string()) + } + + async fn del_many( + &self, + svc_name: &'static str, + api_name: &'static str, + namespace: KeyValueStorageNamespace, + keys: Vec, + ) -> Result<(), String> { + let placeholders: String = repeat("?").take(keys.len()).collect::>().join(", "); + let query = format!( + "DELETE FROM {}.kv_store WHERE namespace = ? AND key IN ({});", + self.session.keyspace, placeholders + ); + + let parameters: Vec = vec![self.to_string(namespace)] + .into_iter() + .chain(keys) + .collect(); + + self.session + .with(svc_name, api_name) + .perform_query("del_many", query, ¶meters) + .await + .map_err(|e| e.to_string()) + } + + async fn exists( + &self, + svc_name: &'static str, + api_name: &'static str, + namespace: KeyValueStorageNamespace, + key: &str, + ) -> Result { + let query = format!( + "SELECT value FROM {}.kv_store WHERE namespace = ? AND key = ? LIMIT 1;", + self.session.keyspace + ); + + self.session + .with(svc_name, api_name) + .maybe_row( + "exists", + query, + (self.to_string(namespace), key), + |row: ValueRow| row, + ) + .await + .map(|opt| opt.is_some()) + .map_err(|e| e.to_string()) + } + + async fn keys( + &self, + svc_name: &'static str, + api_name: &'static str, + namespace: KeyValueStorageNamespace, + ) -> Result, String> { + let query = format!( + "SELECT key FROM {}.kv_store WHERE namespace = ?;", + self.session.keyspace + ); + + self.session + .with(svc_name, api_name) + .get_rows( + "keys", + query, + (self.to_string(namespace),), + |row: (String,)| row.0, + ) + .await + .map_err(|e| e.to_string()) + } + + async fn add_to_set( + &self, + svc_name: &'static str, + api_name: &'static str, + _entity_name: &'static str, + namespace: KeyValueStorageNamespace, + key: &str, + value: &[u8], + ) -> Result<(), String> { + let query = format!( + "INSERT INTO {}.kv_sets (namespace, key, value) VALUES (?, ?, ?);", + self.session.keyspace + ); + + self.session + .with(svc_name, api_name) + .perform_query( + "add_to_set", + query, + (&self.to_string(namespace), key, value), + ) + .await + .map_err(|e| e.to_string()) + } + + async fn remove_from_set( + &self, + svc_name: &'static str, + api_name: &'static str, + _entity_name: &'static str, + namespace: KeyValueStorageNamespace, + key: &str, + value: &[u8], + ) -> Result<(), String> { + let query = format!( + "DELETE FROM {}.kv_sets WHERE namespace = ? AND key = ? AND value = ?;", + self.session.keyspace + ); + + self.session + .with(svc_name, api_name) + .perform_query( + "remove_from_set", + query, + (&self.to_string(namespace), key, value), + ) + .await + .map_err(|e| e.to_string()) + } + + async fn members_of_set( + &self, + svc_name: &'static str, + api_name: &'static str, + _entity_name: &'static str, + namespace: KeyValueStorageNamespace, + key: &str, + ) -> Result, String> { + let query = format!( + "SELECT value FROM {}.kv_sets WHERE namespace = ? AND key = ?;", + self.session.keyspace + ); + + self.session + .with(svc_name, api_name) + .get_rows( + "members_of_set", + query, + (&self.to_string(namespace), key), + |row: ValueRow| row.into_bytes(), + ) + .await + .map_err(|e| e.to_string()) + } + + async fn add_to_sorted_set( + &self, + svc_name: &'static str, + api_name: &'static str, + entity_name: &'static str, + namespace: KeyValueStorageNamespace, + key: &str, + score: f64, + value: &[u8], + ) -> Result<(), String> { + self.remove_from_sorted_set( + svc_name, + api_name, + entity_name, + namespace.clone(), + key, + value, + ) + .await?; + let insert_statement = format!( + "INSERT INTO {}.kv_sorted_sets (namespace, key, score, value) VALUES (?, ?, ?, ?);", + self.session.keyspace + ); + + self.session + .with(svc_name, api_name) + .perform_query( + "add_to_sorted_set", + insert_statement, + (&self.to_string(namespace), key, score, value), + ) + .await + .map_err(|e| e.to_string()) + } + + async fn remove_from_sorted_set( + &self, + svc_name: &'static str, + api_name: &'static str, + _entity_name: &'static str, + namespace: KeyValueStorageNamespace, + key: &str, + value: &[u8], + ) -> Result<(), String> { + let get_score = format!( + "SELECT score, value FROM {}.kv_sorted_sets WHERE namespace = ? AND key = ? AND value = ? ALLOW FILTERING;", + self.session.keyspace + ); + let namespace = self.to_string(namespace); + match self + .maybe_row(get_score, (&namespace, key, value), |row: ScoreValueRow| { + row.score + }) + .await + .map_err(|e| e.to_string())? + { + None => Ok(()), + Some(score) => { + let delete_statement = format!("DELETE FROM {}.kv_sorted_sets WHERE namespace = ? AND key = ? AND score = ? AND value = ?;", self.session.keyspace); + self.session + .with(svc_name, api_name) + .perform_query( + "remove_from_sorted_set", + delete_statement, + (&namespace, key, score, value), + ) + .await + .map_err(|e| e.to_string()) + } + } + } + + async fn get_sorted_set( + &self, + svc_name: &'static str, + api_name: &'static str, + _entity_name: &'static str, + namespace: KeyValueStorageNamespace, + key: &str, + ) -> Result, String> { + let query = format!( + "SELECT score, value FROM {}.kv_sorted_sets WHERE namespace = ? AND key = ? ORDER BY score ASC;", + self.session.keyspace + ); + + self.session + .with(svc_name, api_name) + .get_rows( + "get_sorted_set", + query, + (&self.to_string(namespace), key), + |row: ScoreValueRow| row.into_pair(), + ) + .await + .map_err(|e| e.to_string()) + } + + async fn query_sorted_set( + &self, + svc_name: &'static str, + api_name: &'static str, + _entity_name: &'static str, + namespace: KeyValueStorageNamespace, + key: &str, + min: f64, + max: f64, + ) -> Result, String> { + let query = format!( + "SELECT score, value FROM {}.kv_sorted_sets WHERE namespace = ? AND key = ? AND score >= ? AND score <= ? ORDER BY score ASC;", + self.session.keyspace + ); + self.session + .with(svc_name, api_name) + .get_rows( + "query_sorted_set", + query, + (&self.to_string(namespace), key, min, max), + |row: ScoreValueRow| row.into_pair(), + ) + .await + .map_err(|e| e.to_string()) + } +} diff --git a/golem-worker-executor-base/src/storage/keyvalue/mod.rs b/golem-worker-executor-base/src/storage/keyvalue/mod.rs index 9e443fe8ea..7df6fae0b3 100644 --- a/golem-worker-executor-base/src/storage/keyvalue/mod.rs +++ b/golem-worker-executor-base/src/storage/keyvalue/mod.rs @@ -12,6 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. +pub mod cassandra; pub mod memory; pub mod redis; pub mod sqlite; diff --git a/golem-worker-executor-base/src/storage/mod.rs b/golem-worker-executor-base/src/storage/mod.rs index 522dd2afa9..c0e766e069 100644 --- a/golem-worker-executor-base/src/storage/mod.rs +++ b/golem-worker-executor-base/src/storage/mod.rs @@ -13,6 +13,7 @@ // limitations under the License. pub mod blob; +pub mod cassandra; pub mod indexed; pub mod keyvalue; pub mod sqlite_types; diff --git a/golem-worker-executor-base/tests/common/mod.rs b/golem-worker-executor-base/tests/common/mod.rs index a04b94fe77..671e6a4250 100644 --- a/golem-worker-executor-base/tests/common/mod.rs +++ b/golem-worker-executor-base/tests/common/mod.rs @@ -2,6 +2,7 @@ use anyhow::Error; use async_trait::async_trait; use ctor::ctor; +use golem_test_framework::components::cassandra::Cassandra; use golem_wasm_rpc::wasmtime::ResourceStore; use golem_wasm_rpc::{Uri, Value}; use prometheus::Registry; @@ -235,6 +236,10 @@ impl TestDependencies for TestWorkerExecutor { fn worker_executor_cluster(&self) -> Arc { self.deps.worker_executor_cluster() } + + fn cassandra(&self) -> Option> { + self.deps.cassandra() + } } impl Drop for TestWorkerExecutor { diff --git a/golem-worker-executor-base/tests/key_value_storage.rs b/golem-worker-executor-base/tests/key_value_storage.rs index 00ca903b63..080d7d7c86 100644 --- a/golem-worker-executor-base/tests/key_value_storage.rs +++ b/golem-worker-executor-base/tests/key_value_storage.rs @@ -19,6 +19,8 @@ use golem_common::redis::RedisPool; use golem_test_framework::components::redis::Redis; use golem_test_framework::components::redis_monitor::RedisMonitor; use golem_test_framework::config::TestDependencies; +use golem_worker_executor_base::storage::cassandra::CassandraSession; +use golem_worker_executor_base::storage::keyvalue::cassandra::CassandraKeyValueStorage; use golem_worker_executor_base::storage::keyvalue::memory::InMemoryKeyValueStorage; use golem_worker_executor_base::storage::keyvalue::redis::RedisKeyValueStorage; use golem_worker_executor_base::storage::keyvalue::sqlite::SqliteKeyValueStorage; @@ -111,6 +113,35 @@ pub(crate) async fn sqlite_storage() -> impl GetKeyValueStorage { SqliteKeyValueStorageWrapper { kvs } } +struct CassandraKeyValueStorageWrapper { + kvs: CassandraKeyValueStorage, +} + +impl GetKeyValueStorage for CassandraKeyValueStorageWrapper { + fn get_key_value_storage(&self) -> &dyn KeyValueStorage { + &self.kvs + } +} + +pub(crate) async fn cassandra_storage() -> impl GetKeyValueStorage { + if let Some(cassandra) = BASE_DEPS.cassandra() { + cassandra.assert_valid(); + let test_keyspace = format!("golem_test_{}", &Uuid::new_v4().to_string()[..8]); + let session = cassandra.get_session(None).await; + let cassandra_session = CassandraSession::new(session, true, &test_keyspace); + + let kvs = CassandraKeyValueStorage::new(cassandra_session); + if let Err(err_msg) = kvs.create_schema().await { + cassandra.kill(); + panic!("Cannot create schema : {}", err_msg); + } + + CassandraKeyValueStorageWrapper { kvs } + } else { + panic!("Cassandra is not configured"); + } +} + pub fn ns() -> KeyValueStorageNamespace { KeyValueStorageNamespace::Worker } @@ -774,3 +805,10 @@ test_kv_storage!( crate::key_value_storage::ns2, crate::key_value_storage::ns ); + +test_kv_storage!( + cassandra, + crate::key_value_storage::cassandra_storage, + crate::key_value_storage::ns2, + crate::key_value_storage::ns +); diff --git a/golem-worker-executor-base/tests/lib.rs b/golem-worker-executor-base/tests/lib.rs index 611b8316b5..969c26d5e4 100644 --- a/golem-worker-executor-base/tests/lib.rs +++ b/golem-worker-executor-base/tests/lib.rs @@ -17,6 +17,8 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use ctor::{ctor, dtor}; +use golem_test_framework::components::cassandra::docker::DockerCassandra; +use golem_test_framework::components::cassandra::Cassandra; use tracing::Level; use golem_common::tracing::{init_tracing_with_default_debug_env_filter, TracingConfig}; @@ -67,6 +69,7 @@ pub(crate) struct WorkerExecutorPerTestDependencies { worker_service: Arc, component_service: Arc, component_directory: PathBuf, + cassandra: Arc, } impl TestDependencies for WorkerExecutorPerTestDependencies { @@ -107,6 +110,10 @@ impl TestDependencies for WorkerExecutorPerTestDependencies { fn worker_executor_cluster(&self) -> Arc { panic!("Not supported") } + + fn cassandra(&self) -> Option> { + Some(self.cassandra.clone()) + } } struct WorkerExecutorTestDependencies { @@ -114,6 +121,7 @@ struct WorkerExecutorTestDependencies { redis_monitor: Arc, component_service: Arc, component_directory: PathBuf, + cassandra: Arc, } impl WorkerExecutorTestDependencies { @@ -131,11 +139,13 @@ impl WorkerExecutorTestDependencies { let component_service: Arc = Arc::new( FileSystemComponentService::new(Path::new("data/components")), ); + let cassandra = Arc::new(DockerCassandra::new(false)); Self { redis, redis_monitor, component_directory, component_service, + cassandra, } } @@ -166,6 +176,7 @@ impl WorkerExecutorTestDependencies { worker_service, component_service: self.component_service().clone(), component_directory: self.component_directory.clone(), + cassandra: self.cassandra.clone(), } } } @@ -208,6 +219,10 @@ impl TestDependencies for WorkerExecutorTestDependencies { fn worker_executor_cluster(&self) -> Arc { panic!("Not supported") } + + fn cassandra(&self) -> Option> { + Some(self.cassandra.clone()) + } } #[ctor] @@ -219,6 +234,9 @@ unsafe fn drop_base_deps() { let base_deps_ptr = base_deps_ptr as *mut WorkerExecutorTestDependencies; (*base_deps_ptr).redis().kill(); (*base_deps_ptr).redis_monitor().kill(); + if let Some(c) = (*base_deps_ptr).cassandra() { + c.kill() + } } struct Tracing;