Skip to content

Commit 01fa040

Browse files
committed
Fix thread-safety of own_back<T>() when use lender feature, closing issue #14
1 parent d186822 commit 01fa040

File tree

4 files changed

+66
-15
lines changed

4 files changed

+66
-15
lines changed

src/lender.rs

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ pub(super) fn lent_type_of<T>(pointer: *const T) -> Option<TypeId> {
2727
lent_pointers.get(&(pointer as usize)).copied()
2828
}
2929

30+
fn writable_lent_pointers() -> RwLockWriteGuard<'static, HashMap<usize, TypeId>> {
31+
let Ok(lent_pointers) = LENT_POINTERS.write() else {
32+
log::error!("RwLock poisoned, it is not possible to add or remove pointers");
33+
unreachable!("RwLock poisoned, it is not possible to add or remove pointers");
34+
};
35+
36+
lent_pointers
37+
}
38+
3039
/// Use only when lend memory as a [`raw`](crate::raw) pointer.
3140
///
3241
/// # Panics
@@ -53,15 +62,18 @@ pub(super) fn lend<T: 'static>(pointer: *const T) -> Result<(), PointerError> {
5362
/// If the [`RwLock`] used is poisoned, but it only happens if a panic happens
5463
/// while holding it. And it's specially reviewed and in a small module to
5564
/// avoid panics while holding it.
56-
pub(super) fn retrieve<T>(pointer: *const T) {
57-
writable_lent_pointers().remove(&(pointer as usize));
58-
}
59-
60-
fn writable_lent_pointers() -> RwLockWriteGuard<'static, HashMap<usize, TypeId>> {
61-
let Ok(lent_pointers) = LENT_POINTERS.write() else {
62-
log::error!("RwLock poisoned, it is not possible to add or remove pointers");
63-
unreachable!();
64-
};
65-
66-
lent_pointers
65+
pub(super) fn retrieve<T: 'static>(pointer: *const T) -> Result<(), PointerError> {
66+
match writable_lent_pointers().remove(&(pointer as usize)) {
67+
Some(type_id) if type_id != TypeId::of::<T>() => {
68+
log::error!(
69+
"Using a pointer with a different type as an opaque pointer to Rust's data"
70+
);
71+
Err(PointerError::InvalidType)
72+
}
73+
None => {
74+
log::error!("Using an invalid pointer as an opaque pointer to Rust's data");
75+
Err(PointerError::Invalid)
76+
}
77+
_ => Ok(()),
78+
}
6779
}

src/lib.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,14 @@ pub fn raw<T: 'static>(data: T) -> Result<*mut T, PointerError> {
6363
#[allow(clippy::not_unsafe_ptr_arg_deref)]
6464
pub unsafe fn own_back<T: 'static>(pointer: *mut T) -> Result<T, PointerError> {
6565
validation::not_null_pointer(pointer)?;
66+
67+
// TODO: Simplify and optimize this.
6668
#[cfg(all(feature = "std", feature = "lender"))]
6769
validation::lent_pointer(pointer)?;
68-
let boxed = { Box::from_raw(pointer) };
69-
7070
#[cfg(all(feature = "std", feature = "lender"))]
71-
lender::retrieve(pointer);
71+
lender::retrieve(pointer)?;
7272

73-
Ok(*boxed)
73+
Ok(*Box::from_raw(pointer))
7474
}
7575

7676
/// Reference to an object but without to own it.

tests/pointer.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,16 @@ fn own_back_invalid_type() {
4343
unsafe { opaque_pointer::own_back(pointer as *mut TestIt).unwrap() };
4444
}
4545

46+
#[cfg(all(feature = "std", feature = "lender"))]
47+
#[test]
48+
fn own_back_twice_error() {
49+
let pointer = opaque_pointer::raw(TestIt::new(2)).unwrap();
50+
unsafe { opaque_pointer::own_back(pointer).unwrap() };
51+
52+
let error = unsafe { opaque_pointer::own_back(pointer).unwrap_err() };
53+
assert_eq!(error, PointerError::Invalid);
54+
}
55+
4656
#[test]
4757
fn immutable_reference() {
4858
let pointer = opaque_pointer::raw(TestIt::new(2)).unwrap();

tests/thread_safety.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
use std::thread;
2+
use std::time::Duration;
3+
use opaque_pointer;
4+
5+
#[cfg(all(feature = "std", feature = "lender"))]
6+
#[test]
7+
fn own_back() {
8+
let for_test = 0;
9+
let pointer = opaque_pointer::raw(for_test).unwrap();
10+
11+
let mut threads = Vec::new();
12+
for _ in 0..1000 {
13+
let pointer = pointer as usize;
14+
threads.push(thread::spawn(move || {
15+
thread::sleep(Duration::from_millis(5));
16+
unsafe { opaque_pointer::own_back(pointer as *mut i32).is_ok() }
17+
}));
18+
}
19+
20+
// If all works well, only one thread will be able to own_back the pointer.
21+
let mut counter = 0;
22+
for thread in threads {
23+
if let Ok(true) = thread.join() {
24+
counter += 1;
25+
}
26+
}
27+
28+
assert_eq!(1, counter);
29+
}

0 commit comments

Comments
 (0)