Skip to content

Commit 8c9a861

Browse files
committed
Add unit tests for the cache
1 parent d364306 commit 8c9a861

File tree

1 file changed

+223
-6
lines changed

1 file changed

+223
-6
lines changed

neo4j/src/driver/home_db_cache.rs

+223-6
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,11 @@ impl Default for HomeDbCache {
4646
impl HomeDbCache {
4747
pub(super) fn new(max_size: usize) -> Self {
4848
let max_size_f64 = max_size as f64;
49-
let prune_size = usize::min(max_size, (max_size_f64 * 0.01).log(max_size_f64) as usize);
49+
let mut prune_size = (0.01 * max_size_f64 * max_size_f64.ln()) as usize;
50+
prune_size = usize::min(prune_size, max_size);
51+
if prune_size == 0 && max_size > 0 {
52+
prune_size = 1; // ensure at least one entry is pruned
53+
}
5054
HomeDbCache {
5155
cache: Mutex::new(HashMap::with_capacity(max_size)),
5256
config: HomeDbCacheConfig {
@@ -221,11 +225,8 @@ impl SessionAuthKey {
221225
}
222226

223227
impl HomeDbCacheKey {
224-
pub(super) fn new(
225-
imp_user: Option<&Arc<String>>,
226-
session_auth: Option<&Arc<AuthToken>>,
227-
) -> Self {
228-
if let Some(user) = imp_user {
228+
pub(super) fn new(user: Option<&Arc<String>>, session_auth: Option<&Arc<AuthToken>>) -> Self {
229+
if let Some(user) = user {
229230
HomeDbCacheKey::FixedUser(Arc::clone(user))
230231
} else if let Some(auth) = session_auth {
231232
if let Some(ValueSend::String(scheme)) = auth.data.get("scheme") {
@@ -247,3 +248,219 @@ struct HomeDbCacheEntry {
247248
database: Arc<String>,
248249
last_used: Instant,
249250
}
251+
252+
#[cfg(test)]
253+
mod test {
254+
use rstest::*;
255+
256+
use crate::value::time;
257+
use crate::value_map;
258+
259+
use super::*;
260+
261+
#[rstest]
262+
#[case(HashMap::new(), HashMap::new())]
263+
#[case(
264+
value_map!({
265+
"list": [1, 1.5, ValueSend::Null, "string", true],
266+
"principal": "user",
267+
"map": value_map!({
268+
"nested": value_map!({
269+
"key": "value",
270+
"when": time::LocalDateTime::new(
271+
time::Date::from_ymd_opt(2021, 1, 1).unwrap(),
272+
time::LocalTime::from_hms_opt(12, 0, 0).unwrap(),
273+
),
274+
}),
275+
"point": spatial::Cartesian2D::new(1.0, 2.0),
276+
"key": "value",
277+
}),
278+
"nan": ValueSend::Float(f64::NAN),
279+
"foo": "bar",
280+
}),
281+
value_map!({
282+
"foo": "bar",
283+
"principal": "user",
284+
"nan": ValueSend::Float(f64::NAN),
285+
"list": [1, 1.5, ValueSend::Null, "string", true],
286+
"map": value_map!({
287+
"key": "value",
288+
"nested": value_map!({
289+
"key": "value",
290+
"when": time::LocalDateTime::new(
291+
time::Date::from_ymd_opt(2021, 1, 1).unwrap(),
292+
time::LocalTime::from_hms_opt(12, 0, 0).unwrap(),
293+
),
294+
}),
295+
"point": spatial::Cartesian2D::new(1.0, 2.0),
296+
}),
297+
})
298+
)]
299+
fn test_cache_key_equality(
300+
#[case] a: HashMap<String, ValueSend>,
301+
#[case] b: HashMap<String, ValueSend>,
302+
) {
303+
let auth1 = Arc::new(AuthToken { data: a });
304+
let auth2 = Arc::new(AuthToken { data: b });
305+
let key1 = HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::clone(&auth1)));
306+
let key2 = HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::clone(&auth2)));
307+
assert_eq!(key1, key1);
308+
assert_eq!(key1, key2);
309+
assert_eq!(key2, key1);
310+
assert_eq!(key2, key2);
311+
312+
let mut hasher1 = std::hash::DefaultHasher::new();
313+
let mut hasher2 = std::hash::DefaultHasher::new();
314+
key1.hash(&mut hasher1);
315+
key2.hash(&mut hasher2);
316+
assert_eq!(hasher1.finish(), hasher2.finish());
317+
}
318+
319+
#[rstest]
320+
#[case(value_map!({"principal": "user"}), value_map!({"principal": "admin"}))]
321+
#[case(value_map!({"int": 1}), value_map!({"int": 2}))]
322+
#[case(value_map!({"int": 1}), value_map!({"int": 1.0}))]
323+
#[case(value_map!({"zero": 0.0}), value_map!({"zero": -0.0}))]
324+
#[case(value_map!({"large": f64::INFINITY}), value_map!({"large": f64::NEG_INFINITY}))]
325+
#[case(value_map!({"nan": f64::NAN}), value_map!({"nan": -f64::NAN}))]
326+
#[case(value_map!({"int": 1}), value_map!({"int": "1"}))]
327+
#[case(value_map!({"list": [1, 2]}), value_map!({"list": [2, 1]}))]
328+
fn test_cache_key_inequality(
329+
#[case] a: HashMap<String, ValueSend>,
330+
#[case] b: HashMap<String, ValueSend>,
331+
) {
332+
let auth1 = Arc::new(AuthToken { data: a });
333+
let auth2 = Arc::new(AuthToken { data: b });
334+
let key1 = HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::clone(&auth1)));
335+
let key2 = HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::clone(&auth2)));
336+
assert_ne!(key1, key2);
337+
}
338+
339+
fn fixed_user_key(user: &str) -> HomeDbCacheKey {
340+
HomeDbCacheKey::FixedUser(Arc::new(user.to_string()))
341+
}
342+
343+
fn auth_basic(principal: &str) -> AuthToken {
344+
AuthToken {
345+
data: value_map!({
346+
"scheme": "basic",
347+
"principal": principal,
348+
"credentials": "password",
349+
}),
350+
}
351+
}
352+
353+
fn any_auth_key() -> HomeDbCacheKey {
354+
HomeDbCacheKey::SessionAuth(SessionAuthKey(Arc::new(AuthToken {
355+
data: Default::default(),
356+
})))
357+
}
358+
359+
#[rstest]
360+
#[case(None, None, HomeDbCacheKey::DriverUser)]
361+
#[case(Some("user"), None, fixed_user_key("user"))]
362+
#[case(Some("user"), Some(auth_basic("user2")), fixed_user_key("user"))]
363+
#[case(
364+
None,
365+
Some(AuthToken::new_basic_auth("user2", "password")),
366+
fixed_user_key("user2")
367+
)]
368+
#[case(
369+
None,
370+
Some(AuthToken::new_basic_auth_with_realm("user2", "password", "my-realm")),
371+
fixed_user_key("user2")
372+
)]
373+
#[case(None, Some(AuthToken::new_basic_auth("", "empty")), fixed_user_key(""))]
374+
#[case(None, Some(AuthToken::new_none_auth()), any_auth_key())]
375+
#[case(None, Some(AuthToken::new_bearer_auth("token123")), any_auth_key())]
376+
#[case(None, Some(AuthToken::new_kerberos_auth("token123")), any_auth_key())]
377+
#[case(
378+
None,
379+
Some(AuthToken::new_custom_auth(None, None, None, None, None)),
380+
any_auth_key()
381+
)]
382+
#[case(
383+
None,
384+
Some(AuthToken::new_custom_auth(
385+
Some("principal".into()),
386+
Some("credentials".into()),
387+
Some("realm".into()),
388+
Some("scheme".into()),
389+
Some(value_map!({"key": "value"})),
390+
)),
391+
any_auth_key()
392+
)]
393+
fn test_cache_key_new(
394+
#[case] user: Option<&str>,
395+
#[case] session_auth: Option<AuthToken>,
396+
#[case] expected: HomeDbCacheKey,
397+
) {
398+
let user = user.map(String::from).map(Arc::new);
399+
let session_auth = session_auth.map(Arc::new);
400+
let expected = match expected {
401+
HomeDbCacheKey::SessionAuth(_) => HomeDbCacheKey::SessionAuth(SessionAuthKey(
402+
Arc::clone(session_auth.as_ref().unwrap()),
403+
)),
404+
_ => expected,
405+
};
406+
assert_eq!(
407+
HomeDbCacheKey::new(user.as_ref(), session_auth.as_ref()),
408+
expected
409+
);
410+
}
411+
412+
#[rstest]
413+
#[case(0, 0)]
414+
#[case(1, 1)]
415+
#[case(5, 1)]
416+
#[case(50, 1)]
417+
#[case(60, 2)]
418+
#[case(100, 4)]
419+
#[case(200, 10)]
420+
#[case(1_000, 69)]
421+
#[case(10_000, 921)]
422+
#[case(100_000, 11_512)]
423+
#[case(1_000_000, 138_155)]
424+
fn test_cache_pruning_size(#[case] max_size: usize, #[case] expected: usize) {
425+
let cache = HomeDbCache::new(max_size);
426+
assert_eq!(cache.config.prune_size, expected);
427+
}
428+
429+
#[test]
430+
fn test_pruning() {
431+
const SIZE: usize = 200;
432+
const PRUNE_SIZE: usize = 10;
433+
let cache = HomeDbCache::new(SIZE);
434+
// sanity check
435+
assert_eq!(cache.config.prune_size, PRUNE_SIZE);
436+
437+
let users: Vec<_> = (0..=SIZE).map(|i| Arc::new(format!("user{i}"))).collect();
438+
let keys: Vec<_> = (0..=SIZE)
439+
.map(|i| HomeDbCacheKey::new(Some(&users[i]), None))
440+
.collect();
441+
let entries: Vec<_> = (0..=SIZE).map(|i| Arc::new(format!("db{i}"))).collect();
442+
443+
// WHEN: cache is filled to the max
444+
for i in 0..SIZE {
445+
cache.update(keys[i].clone(), Arc::clone(&entries[i]));
446+
}
447+
// THEN: no entry has been removed
448+
for i in 0..SIZE {
449+
assert_eq!(cache.get(&keys[i]), Some(Arc::clone(&entries[i])));
450+
}
451+
452+
// WHEN: The oldest entry is touched
453+
cache.get(&keys[0]);
454+
// AND: cache is filled with one more entry
455+
cache.update(keys[SIZE].clone(), Arc::clone(&entries[SIZE]));
456+
// THEN: the oldest PRUNE_SIZE entries (2nd to (PRUNE_SIZE + 1)th) are pruned
457+
for key in keys.iter().skip(1).take(PRUNE_SIZE) {
458+
assert_eq!(cache.get(key), None);
459+
}
460+
// AND: the rest of the entries are still in the cache
461+
assert_eq!(cache.get(&keys[0]), Some(Arc::clone(&entries[0])));
462+
for i in PRUNE_SIZE + 2..=SIZE {
463+
assert_eq!(cache.get(&keys[i]), Some(Arc::clone(&entries[i])));
464+
}
465+
}
466+
}

0 commit comments

Comments
 (0)