diff --git a/src/rate_limiter.rs b/src/rate_limiter.rs index 5268696..15a4913 100644 --- a/src/rate_limiter.rs +++ b/src/rate_limiter.rs @@ -1,8 +1,9 @@ //! Rate limiting for token processing. //! -//! Provides a fixed-window rate limiter with per-minute and per-hour limits +//! Provides a sliding-window rate limiter with per-minute and per-hour limits //! to prevent spam and replay attacks. +use std::collections::VecDeque; use std::hash::Hash; use std::num::NonZeroUsize; use std::time::Duration; @@ -58,21 +59,38 @@ impl RateLimitResult { /// Entry tracking rate limit counters for a single key. #[derive(Debug, Clone)] struct RateLimitEntry { - minute_count: u32, - minute_window_start: Instant, - hour_count: u32, - hour_window_start: Instant, + minute_hits: VecDeque, + hour_hits: VecDeque, } impl RateLimitEntry { - fn new(now: Instant) -> Self { + fn new() -> Self { Self { - minute_count: 0, - minute_window_start: now, - hour_count: 0, - hour_window_start: now, + minute_hits: VecDeque::new(), + hour_hits: VecDeque::new(), } } + + fn prune(&mut self, now: Instant) { + prune_hits(&mut self.minute_hits, now, Duration::from_secs(60)); + prune_hits(&mut self.hour_hits, now, Duration::from_secs(3600)); + } +} + +fn prune_hits(hits: &mut VecDeque, now: Instant, window: Duration) { + let old_len = hits.len(); + + while hits + .front() + .is_some_and(|hit| now.duration_since(*hit) >= window) + { + hits.pop_front(); + } + + let new_len = hits.len(); + if old_len > new_len && (new_len == 0 || hits.capacity() > new_len.saturating_mul(2).max(8)) { + hits.shrink_to_fit(); + } } /// Statistics returned from rate limiter cleanup. @@ -109,7 +127,12 @@ impl Default for RateLimitConfig { } } -/// Fixed-window rate limiter with per-minute and per-hour limits. +/// Sliding-window rate limiter with per-minute and per-hour limits. +/// +/// Each active entry stores admitted request timestamps for true sliding-window +/// enforcement, so memory scales with the number of admitted hits per key. With +/// the defaults, a hot key can hold up to roughly 5,240 `Instant` values across +/// the minute and hour windows before older hits are pruned. /// /// Uses a bounded cache to limit memory usage. When the cache is full, /// unknown keys are rejected until cleanup removes stale entries. This @@ -152,35 +175,25 @@ impl RateLimiter { if entries.len() >= entries.cap().get() { return RateLimitResult::ExceededCapacityLimit; } - entries.put(key.clone(), RateLimitEntry::new(now)); + entries.put(key.clone(), RateLimitEntry::new()); entries.get_mut(key).expect("just inserted") }; - // Reset minute window if expired - if now.duration_since(entry.minute_window_start) >= Duration::from_secs(60) { - entry.minute_count = 0; - entry.minute_window_start = now; - } - - // Reset hour window if expired - if now.duration_since(entry.hour_window_start) >= Duration::from_secs(3600) { - entry.hour_count = 0; - entry.hour_window_start = now; - } + entry.prune(now); // Check minute limit first (more likely to be hit) - if entry.minute_count >= self.max_per_minute { + if entry.minute_hits.len() >= self.max_per_minute as usize { return RateLimitResult::ExceededMinuteLimit; } // Check hour limit - if entry.hour_count >= self.max_per_hour { + if entry.hour_hits.len() >= self.max_per_hour as usize { return RateLimitResult::ExceededHourLimit; } // Increment counters - entry.minute_count += 1; - entry.hour_count += 1; + entry.minute_hits.push_back(now); + entry.hour_hits.push_back(now); RateLimitResult::Allowed } @@ -194,9 +207,9 @@ impl RateLimiter { pub async fn rollback_increment(&self, key: &K) { let mut entries = self.entries.write().await; let should_remove = if let Some(entry) = entries.get_mut(key) { - entry.minute_count = entry.minute_count.saturating_sub(1); - entry.hour_count = entry.hour_count.saturating_sub(1); - entry.minute_count == 0 && entry.hour_count == 0 + entry.minute_hits.pop_back(); + entry.hour_hits.pop_back(); + entry.minute_hits.is_empty() && entry.hour_hits.is_empty() } else { false }; @@ -222,7 +235,10 @@ impl RateLimiter { .iter() .rev() .take(CLEANUP_BATCH_SIZE) - .filter(|(_, entry)| now.duration_since(entry.hour_window_start) >= hour) + .filter(|(_, entry)| match entry.hour_hits.back() { + Some(hit) => now.duration_since(*hit) >= hour, + None => true, + }) .map(|(k, _)| k.clone()) .collect(); @@ -248,7 +264,7 @@ impl RateLimiter { let entries = self.entries.read().await; entries .peek(key) - .map(|entry| (entry.minute_count, entry.hour_count)) + .map(|entry| (entry.minute_hits.len() as u32, entry.hour_hits.len() as u32)) } } @@ -349,6 +365,37 @@ mod tests { assert!(limiter.check_and_increment(&1u64).await.is_allowed()); } + #[tokio::test] + async fn test_minute_limit_slides_across_window_boundary() { + tokio::time::pause(); + + let limiter: RateLimiter = RateLimiter::new(RateLimitConfig { + max_per_minute: 5, + max_per_hour: 100, + max_entries: 100, + }); + + assert!(limiter.check_and_increment(&1u64).await.is_allowed()); + + tokio::time::advance(Duration::from_millis(59_900)).await; + + for _ in 0..4 { + assert!(limiter.check_and_increment(&1u64).await.is_allowed()); + } + assert_eq!( + limiter.check_and_increment(&1u64).await, + RateLimitResult::ExceededMinuteLimit + ); + + tokio::time::advance(Duration::from_millis(100)).await; + + assert!(limiter.check_and_increment(&1u64).await.is_allowed()); + assert_eq!( + limiter.check_and_increment(&1u64).await, + RateLimitResult::ExceededMinuteLimit + ); + } + #[tokio::test] async fn test_hour_window_reset() { tokio::time::pause(); @@ -370,6 +417,37 @@ mod tests { assert!(limiter.check_and_increment(&1u64).await.is_allowed()); } + #[tokio::test] + async fn test_hour_limit_slides_across_window_boundary() { + tokio::time::pause(); + + let limiter: RateLimiter = RateLimiter::new(RateLimitConfig { + max_per_minute: 100, + max_per_hour: 5, + max_entries: 100, + }); + + assert!(limiter.check_and_increment(&1u64).await.is_allowed()); + + tokio::time::advance(Duration::from_millis(3_599_900)).await; + + for _ in 0..4 { + assert!(limiter.check_and_increment(&1u64).await.is_allowed()); + } + assert_eq!( + limiter.check_and_increment(&1u64).await, + RateLimitResult::ExceededHourLimit + ); + + tokio::time::advance(Duration::from_millis(100)).await; + + assert!(limiter.check_and_increment(&1u64).await.is_allowed()); + assert_eq!( + limiter.check_and_increment(&1u64).await, + RateLimitResult::ExceededHourLimit + ); + } + #[tokio::test] async fn test_independent_keys() { let limiter: RateLimiter = RateLimiter::new(RateLimitConfig {