diff --git a/Cargo.toml b/Cargo.toml index aa85a82..9658ec1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["codegen", "examples", "performance_measurement", "performance_measur [package] name = "worktable" -version = "0.6.11" +version = "0.6.12" edition = "2024" authors = ["Handy-caT"] license = "MIT" @@ -16,13 +16,14 @@ perf_measurements = ["dep:performance_measurement", "dep:performance_measurement # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +worktable_codegen = { path = "codegen", version = "0.6.12" } + eyre = "0.6.12" derive_more = { version = "1.0.0", features = ["from", "error", "display", "into"] } tokio = { version = "1", features = ["full"] } tracing = "0.1" rkyv = { version = "0.8.9", features = ["uuid-1"] } lockfree = { version = "0.5.1" } -worktable_codegen = { path = "codegen", version = "0.6.11" } fastrand = "2.3.0" futures = "0.3.30" uuid = { version = "1.10.0", features = ["v4", "v7"] } diff --git a/codegen/Cargo.toml b/codegen/Cargo.toml index d5585bd..6c80721 100644 --- a/codegen/Cargo.toml +++ b/codegen/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "worktable_codegen" -version = "0.6.11" +version = "0.6.12" edition = "2024" license = "MIT" description = "WorkTable codegeneration crate" diff --git a/codegen/src/worktable/generator/queries/in_place.rs b/codegen/src/worktable/generator/queries/in_place.rs index e673974..126097a 100644 --- a/codegen/src/worktable/generator/queries/in_place.rs +++ b/codegen/src/worktable/generator/queries/in_place.rs @@ -117,15 +117,15 @@ impl Generator { by: #by_type, ) -> eyre::Result<()> { let pk: #pk_type = by.into(); + let lock = { + #custom_lock + }; let link = self .0 .pk_map .get(&pk) .map(|v| v.get().value) .ok_or(WorkTableError::NotFound)?; - let lock = { - #custom_lock - }; unsafe { self.0 .data diff --git a/codegen/src/worktable/generator/queries/locks.rs b/codegen/src/worktable/generator/queries/locks.rs index 5be6c40..34ba174 100644 --- a/codegen/src/worktable/generator/queries/locks.rs +++ b/codegen/src/worktable/generator/queries/locks.rs @@ -132,7 +132,7 @@ impl Generator { #[allow(clippy::mutable_key_type)] let (locks, op_lock) = lock_guard.lock(lock_id); drop(lock_guard); - futures::future::join_all(locks.iter().map(|l| l.as_ref()).collect::>()).await; + futures::future::join_all(locks.iter().map(|l| l.wait()).collect::>()).await; op_lock } else { @@ -147,7 +147,7 @@ impl Generator { drop(old_lock_guard); drop(guard); - futures::future::join_all(locks.iter().map(|l| l.as_ref()).collect::>()).await; + futures::future::join_all(locks.iter().map(|l| l.wait()).collect::>()).await; } op_lock @@ -166,8 +166,7 @@ impl Generator { #[allow(clippy::mutable_key_type)] let (locks, op_lock) = lock_guard.#ident(lock_id); drop(lock_guard); - futures::future::join_all(locks.iter().map(|l| l.as_ref()).collect::>()).await; - + futures::future::join_all(locks.iter().map(|l| l.wait()).collect::>()).await; op_lock } else { let mut lock = #lock_ident::new(); @@ -182,7 +181,7 @@ impl Generator { drop(old_lock_guard); drop(guard); - futures::future::join_all(locks.iter().map(|l| l.as_ref()).collect::>()).await; + futures::future::join_all(locks.iter().map(|l| l.wait()).collect::>()).await; } op_lock diff --git a/codegen/src/worktable/generator/queries/update.rs b/codegen/src/worktable/generator/queries/update.rs index 566e82d..f75178d 100644 --- a/codegen/src/worktable/generator/queries/update.rs +++ b/codegen/src/worktable/generator/queries/update.rs @@ -518,6 +518,7 @@ impl Generator { locks.insert(pk, op_lock); } + let links: Vec<_> = self.0.indexes.#index.get(#by).map(|(_, l)| *l).collect(); let mut pk_to_unlock: std::collections::HashMap<_, std::sync::Arc> = std::collections::HashMap::new(); let op_id = OperationId::Multi(uuid::Uuid::now_v7()); for link in links.into_iter() { @@ -612,6 +613,11 @@ impl Generator { #custom_lock }; + let link = self.0.indexes.#index + .get(#by) + .map(|kv| kv.get().value) + .ok_or(WorkTableError::NotFound)?; + let op_id = OperationId::Single(uuid::Uuid::now_v7()); #size_check #diff_process diff --git a/src/lock/map.rs b/src/lock/map.rs index f998481..d898806 100644 --- a/src/lock/map.rs +++ b/src/lock/map.rs @@ -49,14 +49,12 @@ where LockType: RowLock, { let mut set = self.map.write(); - let remove = if let Some(lock) = set.get(key) { - !lock.read().await.is_locked() - } else { - false - }; - - if remove { - set.remove(key); + if let Some(lock) = set.get(key).cloned() { + if let Ok(guard) = lock.try_read() { + if !guard.is_locked() { + set.remove(key); + } + } } } diff --git a/src/lock/mod.rs b/src/lock/mod.rs index 48e7cea..f4e075c 100644 --- a/src/lock/mod.rs +++ b/src/lock/mod.rs @@ -5,19 +5,20 @@ use std::future::Future; use std::hash::{Hash, Hasher}; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; use std::task::{Context, Poll}; use derive_more::From; use futures::task::AtomicWaker; - pub use map::LockMap; +use parking_lot::Mutex; pub use row_lock::RowLock; #[derive(Debug)] pub struct Lock { id: u16, - locked: AtomicBool, - waker: AtomicWaker, + locked: Arc, + wakers: Mutex>>, } impl PartialEq for Lock { @@ -38,32 +39,59 @@ impl Lock { pub fn new(id: u16) -> Self { Self { id, - locked: AtomicBool::from(true), - waker: AtomicWaker::new(), + locked: Arc::new(AtomicBool::from(true)), + wakers: Mutex::new(vec![]), } } + pub fn id(&self) -> u16 { + self.id + } + pub fn unlock(&self) { - self.locked.store(false, Ordering::Release); - self.waker.wake() + self.locked.store(false, Ordering::Relaxed); + let guard = self.wakers.lock(); + for w in guard.iter() { + w.wake() + } } pub fn lock(&self) { - self.locked.store(true, Ordering::Release); - self.waker.wake() + self.locked.store(true, Ordering::Relaxed); } pub fn is_locked(&self) -> bool { - self.locked.load(Ordering::Acquire) + self.locked.load(Ordering::Relaxed) + } + + pub fn wait(&self) -> LockWait { + let mut guard = self.wakers.lock(); + let waker = Arc::new(AtomicWaker::new()); + guard.push(waker.clone()); + LockWait { + locked: self.locked.clone(), + waker, + } } } -impl Future for &Lock { +#[derive(Debug)] +pub struct LockWait { + locked: Arc, + waker: Arc, +} + +impl Future for LockWait { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.as_ref().waker.register(cx.waker()); - if self.locked.load(Ordering::Acquire) { + if !self.locked.load(Ordering::Relaxed) { + return Poll::Ready(()); + } + + self.waker.register(cx.waker()); + + if self.locked.load(Ordering::Relaxed) { Poll::Pending } else { Poll::Ready(()) diff --git a/tests/worktable/in_place.rs b/tests/worktable/in_place.rs index e389e44..b0e51fa 100644 --- a/tests/worktable/in_place.rs +++ b/tests/worktable/in_place.rs @@ -14,6 +14,7 @@ worktable!( val1: u64, val2: i16, another: String, + something: u64, }, queries: { in_place: { @@ -22,6 +23,7 @@ worktable!( } update: { AnotherById(another) by id, + SomethingById(something) by id, } } ); @@ -35,6 +37,7 @@ async fn test_update_val_by_id() -> eyre::Result<()> { val1: 0, val2: 0, another: "another".to_string(), + something: 0, }; let pk = table.insert(row)?; for _ in 0..10000 { @@ -56,6 +59,7 @@ async fn test_update_val2_by_id() -> eyre::Result<()> { val1: 0, val2: 0, another: "another".to_string(), + something: 0, }; let pk = table.insert(row)?; for _ in 0..100 { @@ -69,7 +73,7 @@ async fn test_update_val2_by_id() -> eyre::Result<()> { } #[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_update_val_by_id_multi_thread() -> eyre::Result<()> { +async fn test_update_val_by_id_two_thread() -> eyre::Result<()> { let table = Arc::new(TestWorkTable::default()); let row = TestRow { id: table.get_next_pk().0, @@ -77,6 +81,7 @@ async fn test_update_val_by_id_multi_thread() -> eyre::Result<()> { val1: 0, val2: 0, another: "another".to_string(), + something: 0, }; let pk = table.insert(row)?; let shared_table = table.clone(); @@ -99,11 +104,117 @@ async fn test_update_val_by_id_multi_thread() -> eyre::Result<()> { Ok(()) } -#[tokio::test(flavor = "multi_thread", worker_threads = 2)] -async fn test_update_in_place_and_usual_multithread() -> eyre::Result<()> { +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn test_update_val_and_val2_by_id_four_thread() -> eyre::Result<()> { + let table = Arc::new(TestWorkTable::default()); + let row = TestRow { + id: table.get_next_pk().0, + val: 0, + val1: 0, + val2: 0, + another: "another".to_string(), + something: 0, + }; + let pk = table.insert(row)?; + let shared_table = table.clone(); + let h1 = tokio::spawn(async move { + for _ in 0..10_000 { + shared_table + .update_val_by_id_in_place(|val| *val += 1, pk.0) + .await + .unwrap() + } + }); + let shared_table = table.clone(); + let h2 = tokio::spawn(async move { + for _ in 0..10_000 { + shared_table + .update_val_2_by_id_in_place(|val| *val += 1, pk.0) + .await + .unwrap() + } + }); + let shared_table = table.clone(); + let h3 = tokio::spawn(async move { + for _ in 0..10_000 { + shared_table + .update_val_by_id_in_place(|val| *val += 1, pk.0) + .await + .unwrap() + } + }); + for _ in 0..10_000 { + table + .update_val_2_by_id_in_place(|val| *val += 1, pk.0) + .await? + } + h1.await?; + h2.await?; + h3.await?; + let row = table.select(pk).unwrap(); + assert_eq!(row.val, 20_000); + assert_eq!(row.val2, 20_000); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 4)] +async fn test_update_val_by_id_four_thread() -> eyre::Result<()> { + let table = Arc::new(TestWorkTable::default()); + let row = TestRow { + id: table.get_next_pk().0, + val: 0, + val1: 0, + val2: 0, + another: "another".to_string(), + something: 0, + }; + let pk = table.insert(row)?; + let shared_table = table.clone(); + let h1 = tokio::spawn(async move { + for _ in 0..10_000 { + shared_table + .update_val_by_id_in_place(|val| *val += 1, pk.0) + .await + .unwrap() + } + }); + let shared_table = table.clone(); + let h2 = tokio::spawn(async move { + for _ in 0..10_000 { + shared_table + .update_val_by_id_in_place(|val| *val += 1, pk.0) + .await + .unwrap() + } + }); + let shared_table = table.clone(); + let h3 = tokio::spawn(async move { + for _ in 0..10_000 { + shared_table + .update_val_by_id_in_place(|val| *val += 1, pk.0) + .await + .unwrap() + } + }); + for _ in 0..10_000 { + table + .update_val_by_id_in_place(|val| *val += 1, pk.0) + .await? + } + h1.await?; + h2.await?; + h3.await?; + let row = table.select(pk).unwrap(); + assert_eq!(row.val, 40_000); + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 3)] +async fn test_update_in_place_and_update_sized_multithread() -> eyre::Result<()> { let table = Arc::new(TestWorkTable::default()); let i_state = Arc::new(Mutex::new(HashMap::new())); let val_state = Arc::new(Mutex::new(HashMap::new())); + let val2_state = Arc::new(Mutex::new(HashMap::new())); for i in 0..100 { let row = TestRow { id: table.get_next_pk().into(), @@ -111,13 +222,14 @@ async fn test_update_in_place_and_usual_multithread() -> eyre::Result<()> { val1: 0, val2: 0, another: format!("another_{i}"), + something: 0, }; let _ = table.insert(row.clone())?; } let shared = table.clone(); let shared_val_state = val_state.clone(); - let h = tokio::spawn(async move { - for _ in 0..50_000 { + let h1 = tokio::spawn(async move { + for _ in 0..15_000 { let val = fastrand::i64(..); let id_to_update = fastrand::u64(0..=99); shared @@ -133,8 +245,115 @@ async fn test_update_in_place_and_usual_multithread() -> eyre::Result<()> { } } }); + let shared = table.clone(); + let shared_val2_state = val2_state.clone(); + let h2 = tokio::spawn(async move { + for _ in 0..15_000 { + let val = fastrand::i16(..); + let id_to_update = fastrand::u64(0..=99); + shared + .update_val_2_by_id_in_place(|v| *v = val.into(), id_to_update) + .await + .unwrap(); + { + let mut guard = shared_val2_state.lock(); + guard + .entry(id_to_update) + .and_modify(|v| *v = val) + .or_insert(val); + } + } + }); tokio::time::sleep(Duration::from_micros(20)).await; - for _ in 0..1000 { + for _ in 0..5000 { + let val = fastrand::u64(..); + let id_to_update = fastrand::u64(0..=99); + table + .update_something_by_id(SomethingByIdQuery { something: val }, id_to_update.into()) + .await?; + { + let mut guard = i_state.lock(); + guard + .entry(id_to_update) + .and_modify(|v| *v = val) + .or_insert(val); + } + } + h1.await?; + h2.await?; + + for (id, smth) in i_state.lock_arc().iter() { + let row = table.select((*id).into()).unwrap(); + assert_eq!(&row.something, smth); + } + for (id, val) in val2_state.lock_arc().iter() { + let row = table.select((*id).into()).unwrap(); + assert_eq!(&row.val2, val); + } + for (id, val) in val_state.lock_arc().iter() { + let row = table.select((*id).into()).unwrap(); + assert_eq!(&row.val, val); + } + Ok(()) +} + +#[tokio::test(flavor = "multi_thread", worker_threads = 3)] +async fn test_update_in_place_and_update_unsized_multithread() -> eyre::Result<()> { + let table = Arc::new(TestWorkTable::default()); + let i_state = Arc::new(Mutex::new(HashMap::new())); + let val_state = Arc::new(Mutex::new(HashMap::new())); + let val2_state = Arc::new(Mutex::new(HashMap::new())); + for i in 0..100 { + let row = TestRow { + id: table.get_next_pk().into(), + val: 0, + val1: 0, + val2: 0, + another: format!("another_{i}"), + something: 0, + }; + let _ = table.insert(row.clone())?; + } + let shared = table.clone(); + let shared_val_state = val_state.clone(); + let h1 = tokio::spawn(async move { + for _ in 0..15_000 { + let val = fastrand::i64(..); + let id_to_update = fastrand::u64(0..=99); + shared + .update_val_by_id_in_place(|v| *v = val.into(), id_to_update) + .await + .unwrap(); + { + let mut guard = shared_val_state.lock(); + guard + .entry(id_to_update) + .and_modify(|v| *v = val) + .or_insert(val); + } + } + }); + let shared = table.clone(); + let shared_val2_state = val2_state.clone(); + let h2 = tokio::spawn(async move { + for _ in 0..15_000 { + let val = fastrand::i16(..); + let id_to_update = fastrand::u64(0..=99); + shared + .update_val_2_by_id_in_place(|v| *v = val.into(), id_to_update) + .await + .unwrap(); + { + let mut guard = shared_val2_state.lock(); + guard + .entry(id_to_update) + .and_modify(|v| *v = val) + .or_insert(val); + } + } + }); + tokio::time::sleep(Duration::from_micros(20)).await; + for _ in 0..5000 { let val = fastrand::u64(..); let id_to_update = fastrand::u64(0..=99); table @@ -152,17 +371,29 @@ async fn test_update_in_place_and_usual_multithread() -> eyre::Result<()> { .and_modify(|v| *v = format!("another_{val}")) .or_insert(format!("another_{val}")); } - tokio::time::sleep(Duration::from_micros(5)).await; } - h.await?; + h1.await?; + h2.await?; - for (id, another) in i_state.lock_arc().iter() { + for (id, smth) in i_state.lock_arc().iter() { let row = table.select((*id).into()).unwrap(); - assert_eq!(&row.another, another); + assert_eq!(&row.another, smth); } + let mut errors = 0; + for (id, val) in val2_state.lock_arc().iter() { + let row = table.select((*id).into()).unwrap(); + if &row.val2 != val { + errors += 1; + } + } + assert_eq!(errors, 0); + let mut errors = 0; for (id, val) in val_state.lock_arc().iter() { let row = table.select((*id).into()).unwrap(); - assert_eq!(&row.val, val); + if &row.val != val { + errors += 1; + } } + assert_eq!(errors, 0); Ok(()) }