Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 110 additions & 32 deletions src/rate_limiter.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Rate limiting for token processing.
//!
//! Provides a fixed-window rate limiter with per-minute and per-hour limits
//! Provides a sliding-window rate limiter with per-minute and per-hour limits
//! to prevent spam and replay attacks.

use std::collections::VecDeque;
use std::hash::Hash;
use std::num::NonZeroUsize;
use std::time::Duration;
Expand Down Expand Up @@ -58,21 +59,38 @@ impl RateLimitResult {
/// Entry tracking rate limit counters for a single key.
#[derive(Debug, Clone)]
struct RateLimitEntry {
minute_count: u32,
minute_window_start: Instant,
hour_count: u32,
hour_window_start: Instant,
minute_hits: VecDeque<Instant>,
hour_hits: VecDeque<Instant>,
}

impl RateLimitEntry {
fn new(now: Instant) -> Self {
fn new() -> Self {
Self {
minute_count: 0,
minute_window_start: now,
hour_count: 0,
hour_window_start: now,
minute_hits: VecDeque::new(),
hour_hits: VecDeque::new(),
}
}

fn prune(&mut self, now: Instant) {
prune_hits(&mut self.minute_hits, now, Duration::from_secs(60));
prune_hits(&mut self.hour_hits, now, Duration::from_secs(3600));
}
}

fn prune_hits(hits: &mut VecDeque<Instant>, now: Instant, window: Duration) {
let old_len = hits.len();

while hits
.front()
.is_some_and(|hit| now.duration_since(*hit) >= window)
{
hits.pop_front();
}

let new_len = hits.len();
if old_len > new_len && (new_len == 0 || hits.capacity() > new_len.saturating_mul(2).max(8)) {
hits.shrink_to_fit();
}
}

/// Statistics returned from rate limiter cleanup.
Expand Down Expand Up @@ -109,7 +127,12 @@ impl Default for RateLimitConfig {
}
}

/// Fixed-window rate limiter with per-minute and per-hour limits.
/// Sliding-window rate limiter with per-minute and per-hour limits.
///
/// Each active entry stores admitted request timestamps for true sliding-window
/// enforcement, so memory scales with the number of admitted hits per key. With
/// the defaults, a hot key can hold up to roughly 5,240 `Instant` values across
/// the minute and hour windows before older hits are pruned.
///
/// Uses a bounded cache to limit memory usage. When the cache is full,
/// unknown keys are rejected until cleanup removes stale entries. This
Expand Down Expand Up @@ -152,35 +175,25 @@ impl<K: Hash + Eq + Clone + Send + Sync + 'static> RateLimiter<K> {
if entries.len() >= entries.cap().get() {
return RateLimitResult::ExceededCapacityLimit;
}
entries.put(key.clone(), RateLimitEntry::new(now));
entries.put(key.clone(), RateLimitEntry::new());
entries.get_mut(key).expect("just inserted")
};

// Reset minute window if expired
if now.duration_since(entry.minute_window_start) >= Duration::from_secs(60) {
entry.minute_count = 0;
entry.minute_window_start = now;
}

// Reset hour window if expired
if now.duration_since(entry.hour_window_start) >= Duration::from_secs(3600) {
entry.hour_count = 0;
entry.hour_window_start = now;
}
entry.prune(now);

// Check minute limit first (more likely to be hit)
if entry.minute_count >= self.max_per_minute {
if entry.minute_hits.len() >= self.max_per_minute as usize {
return RateLimitResult::ExceededMinuteLimit;
}

// Check hour limit
if entry.hour_count >= self.max_per_hour {
if entry.hour_hits.len() >= self.max_per_hour as usize {
return RateLimitResult::ExceededHourLimit;
}

// Increment counters
entry.minute_count += 1;
entry.hour_count += 1;
entry.minute_hits.push_back(now);
entry.hour_hits.push_back(now);

RateLimitResult::Allowed
}
Expand All @@ -194,9 +207,9 @@ impl<K: Hash + Eq + Clone + Send + Sync + 'static> RateLimiter<K> {
pub async fn rollback_increment(&self, key: &K) {
let mut entries = self.entries.write().await;
let should_remove = if let Some(entry) = entries.get_mut(key) {
entry.minute_count = entry.minute_count.saturating_sub(1);
entry.hour_count = entry.hour_count.saturating_sub(1);
entry.minute_count == 0 && entry.hour_count == 0
entry.minute_hits.pop_back();
entry.hour_hits.pop_back();
entry.minute_hits.is_empty() && entry.hour_hits.is_empty()
} else {
false
};
Expand All @@ -222,7 +235,10 @@ impl<K: Hash + Eq + Clone + Send + Sync + 'static> RateLimiter<K> {
.iter()
.rev()
.take(CLEANUP_BATCH_SIZE)
.filter(|(_, entry)| now.duration_since(entry.hour_window_start) >= hour)
.filter(|(_, entry)| match entry.hour_hits.back() {
Some(hit) => now.duration_since(*hit) >= hour,
None => true,
})
.map(|(k, _)| k.clone())
.collect();

Expand All @@ -248,7 +264,7 @@ impl<K: Hash + Eq + Clone + Send + Sync + 'static> RateLimiter<K> {
let entries = self.entries.read().await;
entries
.peek(key)
.map(|entry| (entry.minute_count, entry.hour_count))
.map(|entry| (entry.minute_hits.len() as u32, entry.hour_hits.len() as u32))
}
}

Expand Down Expand Up @@ -349,6 +365,37 @@ mod tests {
assert!(limiter.check_and_increment(&1u64).await.is_allowed());
}

#[tokio::test]
async fn test_minute_limit_slides_across_window_boundary() {
tokio::time::pause();

let limiter: RateLimiter<u64> = RateLimiter::new(RateLimitConfig {
max_per_minute: 5,
max_per_hour: 100,
max_entries: 100,
});

assert!(limiter.check_and_increment(&1u64).await.is_allowed());

tokio::time::advance(Duration::from_millis(59_900)).await;

for _ in 0..4 {
assert!(limiter.check_and_increment(&1u64).await.is_allowed());
}
assert_eq!(
limiter.check_and_increment(&1u64).await,
RateLimitResult::ExceededMinuteLimit
);

tokio::time::advance(Duration::from_millis(100)).await;

assert!(limiter.check_and_increment(&1u64).await.is_allowed());
assert_eq!(
limiter.check_and_increment(&1u64).await,
RateLimitResult::ExceededMinuteLimit
);
}

#[tokio::test]
async fn test_hour_window_reset() {
tokio::time::pause();
Expand All @@ -370,6 +417,37 @@ mod tests {
assert!(limiter.check_and_increment(&1u64).await.is_allowed());
}

#[tokio::test]
async fn test_hour_limit_slides_across_window_boundary() {
tokio::time::pause();

let limiter: RateLimiter<u64> = RateLimiter::new(RateLimitConfig {
max_per_minute: 100,
max_per_hour: 5,
max_entries: 100,
});

assert!(limiter.check_and_increment(&1u64).await.is_allowed());

tokio::time::advance(Duration::from_millis(3_599_900)).await;

for _ in 0..4 {
assert!(limiter.check_and_increment(&1u64).await.is_allowed());
}
assert_eq!(
limiter.check_and_increment(&1u64).await,
RateLimitResult::ExceededHourLimit
);

tokio::time::advance(Duration::from_millis(100)).await;

assert!(limiter.check_and_increment(&1u64).await.is_allowed());
assert_eq!(
limiter.check_and_increment(&1u64).await,
RateLimitResult::ExceededHourLimit
);
}

#[tokio::test]
async fn test_independent_keys() {
let limiter: RateLimiter<u64> = RateLimiter::new(RateLimitConfig {
Expand Down