Skip to content

Commit 3e40ef6

Browse files
youkaichao, Aminsed
authored and committed
[Core] Prefix cache: frequency- and cost-aware eviction (opt-in)
Signed-off-by: Amin Sedaghat <[email protected]>
1 parent 361a746 commit 3e40ef6

File tree

9 files changed

+294
-15
lines changed

9 files changed

+294
-15
lines changed

requirements/test.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1225,7 +1225,6 @@ typeshed-client==2.8.2
12251225
# via jsonargparse
12261226
typing-extensions==4.15.0
12271227
# via
1228-
# aiosignal
12291228
# albumentations
12301229
# alembic
12311230
# chz
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import time
5+
6+
import pytest
7+
8+
from vllm.v1.core.eviction_policies import FrequencyCostEvictionPolicy
9+
from vllm.v1.core.kv_cache_utils import KVCacheBlock
10+
11+
pytestmark = pytest.mark.cpu_test
12+
13+
14+
def test_frequency_cost_eviction_orders_by_score():
    """Cached-free blocks are evicted in ascending retention-score order.

    With equal block size the score reduces to ``access_count / age``, so the
    three blocks below score roughly 0.1, 0.2 and 2.0 (times a constant cost
    factor) and must be evicted in that order.
    """
    policy = FrequencyCostEvictionPolicy(block_size=16, alpha=2.0)

    now = time.monotonic()
    # (age in seconds, access count) for block ids 0, 1 and 2.
    for i, (age, access) in enumerate([(10.0, 1), (5.0, 1), (5.0, 10)]):
        b = KVCacheBlock(block_id=i)
        # mark as free and cached by simulating a non-None hash
        b._block_hash = b"dummy_hash"  # type: ignore[attr-defined]
        b.ref_cnt = 0
        # manually set tracking attributes used by the policy
        b.first_access_ts = now - age  # type: ignore[attr-defined]
        b.access_count = access  # type: ignore[attr-defined]
        policy.on_block_release(b)

    evicted = policy.get_eviction_candidates(3)
    # Oldest/least-used block (id 0) goes first and the most frequently
    # accessed block (id 2) is retained longest. Asserting the full order
    # checks the ranking itself, which a set comparison would not.
    assert evicted == [0, 1, 2]
36+
37+
38+
def test_policy_remove_block():
    """A block dropped via ``remove_block`` is never offered for eviction."""
    policy = FrequencyCostEvictionPolicy(block_size=16)

    # Build a block that looks cached (non-None hash) and free (ref_cnt == 0),
    # with access statistics the policy can score.
    block = KVCacheBlock(block_id=42)
    block._block_hash = b"dummy"  # type: ignore[attr-defined]
    block.ref_cnt = 0
    block.first_access_ts = time.monotonic() - 1.0  # type: ignore[attr-defined]
    block.access_count = 5  # type: ignore[attr-defined]

    # Track it, then immediately untrack it.
    policy.on_block_release(block)
    policy.remove_block(block)

    # The untracked block must not come back as a candidate.
    candidates = policy.get_eviction_candidates(1)
    assert 42 not in candidates

vllm/config/cache.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
CacheDType = Literal["auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"]
2525
MambaDType = Literal["auto", "float32"]
2626
PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"]
27+
PrefixCacheEvictionPolicy = Literal["lru", "frequency_cost"]
2728

2829

2930
@config
@@ -126,6 +127,17 @@ class CacheConfig:
126127
gpu_memory_utilization. Note that kv_cache_memory_bytes
127128
(when not-None) ignores gpu_memory_utilization"""
128129

130+
# Eviction policy for prefix caching (experimental, opt-in)
131+
prefix_cache_eviction_policy: PrefixCacheEvictionPolicy = "lru"
132+
"""Eviction policy for prefix caching free cached blocks. Default is LRU.
133+
Set to "frequency_cost" to enable frequency × cost-aware eviction."""
134+
135+
eviction_cost_alpha: float = 2.0
136+
"""Alpha exponent for the compute cost term (block_size^alpha)."""
137+
138+
eviction_time_decay: float = 0.0
139+
"""Optional exponential time decay factor for the frequency term."""
140+
129141
def compute_hash(self) -> str:
130142
"""
131143
WARNING: Whenever a new field is added to this config,

vllm/engine/arg_utils.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,12 @@ class EngineArgs:
417417
prefix_caching_hash_algo: PrefixCachingHashAlgo = (
418418
CacheConfig.prefix_caching_hash_algo
419419
)
420+
# Eviction policy flags for prefix caching
421+
prefix_cache_eviction_policy: Literal["lru", "frequency_cost"] = (
422+
CacheConfig.prefix_cache_eviction_policy
423+
)
424+
eviction_cost_alpha: float = CacheConfig.eviction_cost_alpha
425+
eviction_time_decay: float = CacheConfig.eviction_time_decay
420426
disable_sliding_window: bool = ModelConfig.disable_sliding_window
421427
disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
422428
swap_space: float = CacheConfig.swap_space
@@ -881,6 +887,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
881887
cache_group.add_argument(
882888
"--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
883889
)
890+
cache_group.add_argument(
891+
"--prefix-cache-eviction-policy",
892+
**cache_kwargs["prefix_cache_eviction_policy"],
893+
)
894+
cache_group.add_argument(
895+
"--eviction-cost-alpha", **cache_kwargs["eviction_cost_alpha"]
896+
)
897+
cache_group.add_argument(
898+
"--eviction-time-decay", **cache_kwargs["eviction_time_decay"]
899+
)
884900
cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
885901
cache_group.add_argument(
886902
"--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
@@ -1386,6 +1402,9 @@ def create_engine_config(
13861402
sliding_window=sliding_window,
13871403
enable_prefix_caching=self.enable_prefix_caching,
13881404
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
1405+
prefix_cache_eviction_policy=self.prefix_cache_eviction_policy,
1406+
eviction_cost_alpha=self.eviction_cost_alpha,
1407+
eviction_time_decay=self.eviction_time_decay,
13891408
cpu_offload_gb=self.cpu_offload_gb,
13901409
calculate_kv_scales=self.calculate_kv_scales,
13911410
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,

vllm/v1/core/block_pool.py

Lines changed: 93 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
KVCacheEvent,
1212
)
1313
from vllm.logger import init_logger
14+
from vllm.v1.core.eviction_policies import FrequencyCostEvictionPolicy
1415
from vllm.v1.core.kv_cache_utils import (
1516
BlockHash,
1617
BlockHashWithGroupId,
@@ -166,6 +167,25 @@ def __init__(
166167
self.enable_kv_cache_events = enable_kv_cache_events
167168
self.kv_event_queue: list[KVCacheEvent] = []
168169

170+
# Optional frequency-cost policy (set via configure_eviction_policy)
171+
self._policy: FrequencyCostEvictionPolicy | None = None
172+
173+
def configure_eviction_policy(
    self,
    policy: str,
    *,
    block_size: int,
    alpha: float = 2.0,
    time_decay: float = 0.0,
) -> None:
    """Select the eviction policy used for cached-free blocks.

    Args:
        policy: ``"lru"`` clears any configured policy object (``_policy``
            becomes ``None``); ``"frequency_cost"`` installs a
            ``FrequencyCostEvictionPolicy``.
        block_size: Forwarded to the frequency-cost policy as ``block_size``.
        alpha: Forwarded to the frequency-cost policy as ``alpha``.
        time_decay: Forwarded to the frequency-cost policy as
            ``time_decay_factor``.

    Raises:
        ValueError: If ``policy`` is not a recognised policy name. Previously
            an unknown name silently selected LRU, which hid typos such as
            ``"frequency-cost"``.
    """
    if policy == "frequency_cost":
        self._policy = FrequencyCostEvictionPolicy(
            block_size=block_size, alpha=alpha, time_decay_factor=time_decay
        )
    elif policy == "lru":
        self._policy = None
    else:
        raise ValueError(
            f"Unknown prefix cache eviction policy {policy!r}; "
            "expected 'lru' or 'frequency_cost'"
        )
188+
169189
def get_cached_block(
170190
self, block_hash: BlockHash, kv_cache_group_ids: list[int]
171191
) -> list[KVCacheBlock] | None:
@@ -278,19 +298,65 @@ def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
278298
if num_blocks > self.get_num_free_blocks():
279299
raise ValueError(f"Cannot get {num_blocks} free blocks from the pool")
280300

281-
ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
282-
283-
# In order to only iterate the list once, we duplicated code a bit
301+
# Fast path: no policy configured -> original LRU behavior
302+
if self._policy is None:
303+
ret: list[KVCacheBlock] = self.free_block_queue.popleft_n(num_blocks)
304+
if self.enable_caching:
305+
for block in ret:
306+
self._maybe_evict_cached_block(block)
307+
assert block.ref_cnt == 0
308+
block.ref_cnt += 1
309+
else:
310+
for block in ret:
311+
assert block.ref_cnt == 0
312+
block.ref_cnt += 1
313+
return ret
314+
315+
# Policy path: prefer non-cached free blocks from LRU head, then
316+
# choose cached-free blocks via policy ranking.
317+
selected: list[KVCacheBlock] = []
318+
deferred_cached: list[KVCacheBlock] = []
319+
320+
while len(selected) < num_blocks:
321+
# Exhausted free blocks -> impossible due to initial check
322+
blk = self.free_block_queue.popleft()
323+
if blk.block_hash is None:
324+
selected.append(blk)
325+
else:
326+
# remove from policy to avoid selecting it immediately
327+
if self._policy is not None:
328+
self._policy.remove_block(blk)
329+
deferred_cached.append(blk)
330+
if self.get_num_free_blocks() == 0 and len(selected) < num_blocks:
331+
break
332+
333+
if len(selected) < num_blocks:
334+
need = num_blocks - len(selected)
335+
# Ask policy for global cached-free candidates by block_id
336+
ids = self._policy.get_eviction_candidates(need)
337+
for block_id in ids:
338+
blk = self.blocks[block_id]
339+
# Remove from free list if still present
340+
if blk.prev_free_block is not None and blk.next_free_block is not None:
341+
self.free_block_queue.remove(blk)
342+
# Evict hash later below
343+
selected.append(blk)
344+
345+
# Return deferred cached blocks to the free list tail to keep queue sound
346+
for blk in deferred_cached:
347+
self.free_block_queue.append(blk)
348+
349+
# Finalize selection: evict hashes for cached blocks; inc ref_cnt
284350
if self.enable_caching:
285-
for block in ret:
286-
self._maybe_evict_cached_block(block)
287-
assert block.ref_cnt == 0
288-
block.ref_cnt += 1
351+
for blk in selected:
352+
self._maybe_evict_cached_block(blk)
353+
assert blk.ref_cnt == 0
354+
blk.ref_cnt += 1
289355
else:
290-
for block in ret:
291-
assert block.ref_cnt == 0
292-
block.ref_cnt += 1
293-
return ret
356+
for blk in selected:
357+
assert blk.ref_cnt == 0
358+
blk.ref_cnt += 1
359+
return selected
294360

295361
def _maybe_evict_cached_block(self, block: KVCacheBlock) -> bool:
296362
"""
@@ -342,7 +408,11 @@ def touch(self, blocks: tuple[Sequence[KVCacheBlock], ...]) -> None:
342408
# candidate), so remove it.
343409
if block.ref_cnt == 0 and not block.is_null:
344410
self.free_block_queue.remove(block)
411+
if self._policy is not None:
412+
self._policy.remove_block(block)
345413
block.ref_cnt += 1
414+
if self._policy is not None:
415+
self._policy.on_block_access(block)
346416

347417
def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
348418
"""Free a list of blocks. The blocks should be ordered by their
@@ -356,9 +426,15 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
356426
blocks_list = list(ordered_blocks)
357427
for block in blocks_list:
358428
block.ref_cnt -= 1
359-
self.free_block_queue.append_n(
360-
[block for block in blocks_list if block.ref_cnt == 0 and not block.is_null]
361-
)
429+
freed = [
430+
block for block in blocks_list if block.ref_cnt == 0 and not block.is_null
431+
]
432+
self.free_block_queue.append_n(freed)
433+
if self._policy is not None:
434+
for block in freed:
435+
# Track only cached-free blocks
436+
if block.block_hash is not None:
437+
self._policy.on_block_release(block)
362438

363439
def reset_prefix_cache(self) -> bool:
364440
"""Reset prefix cache. This function may be used in RLHF
@@ -390,6 +466,9 @@ def reset_prefix_cache(self) -> bool:
390466
if self.enable_kv_cache_events:
391467
self.kv_event_queue.append(AllBlocksCleared())
392468

469+
if self._policy is not None:
470+
self._policy.reset()
471+
393472
return True
394473

395474
def get_num_free_blocks(self) -> int:
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
from .frequency_cost import FrequencyCostEvictionPolicy
5+
6+
__all__ = ["FrequencyCostEvictionPolicy"]
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
import heapq
5+
import math
6+
import time
7+
8+
from vllm.v1.core.kv_cache_utils import KVCacheBlock
9+
10+
11+
class FrequencyCostEvictionPolicy:
    """Min-heap policy over cached-free blocks by retention score.

    The block with the lowest score is offered for eviction first, where
    ``score = (access_count / age) * block_size ** alpha``, optionally
    damped by ``exp(-time_decay_factor * age)``.

    Implementation notes:
    - Uses lazy deletion with an `entry_finder` dict to avoid in-place heap
      edits. Stale heap tuples are skipped on pop and are pruned in bulk once
      they outnumber live entries, so the heap stays proportional to the
      number of tracked blocks even when blocks are released and re-touched
      many times without any eviction happening.
    - Scores are computed once, at the moment a block becomes cached-free.
    - `block_size` is provided once at init; not stored on each block.
    - This class tracks only blocks that are both free (ref_cnt==0) and cached
      (i.e., have a non-None block_hash).
    """

    # Number of stale heap tuples tolerated (on top of 2x the live entries)
    # before the heap is rebuilt. Keeps compaction amortised O(1) per push.
    _COMPACT_SLACK = 64

    def __init__(
        self,
        block_size: int,
        alpha: float = 2.0,
        time_decay_factor: float = 0.0,
        min_time_window: float = 1.0,
    ) -> None:
        """Create an empty policy.

        Args:
            block_size: Base of the cost term.
            alpha: Exponent applied to ``block_size`` for the cost term.
            time_decay_factor: When > 0, the frequency term is multiplied by
                ``exp(-time_decay_factor * age)``.
            min_time_window: Lower bound (seconds) for the age used in the
                frequency computation.
        """
        self.block_size = block_size
        self.alpha = alpha
        self.time_decay_factor = time_decay_factor
        self.min_time_window = min_time_window

        # Heap entries: (score, counter, block_id). The counter makes every
        # tuple unique and breaks score ties in insertion order.
        self._heap: list[tuple[float, int, int]] = []
        # block_id -> the single heap entry that is currently valid for it.
        self._entry_finder: dict[int, tuple[float, int, int]] = {}
        self._counter = 0

    def _score(self, block: KVCacheBlock) -> float:
        """Return the retention score of ``block`` (higher = keep longer)."""
        # If the block was never accessed through prefix hits, treat as lowest
        # value.
        first_ts = getattr(block, "first_access_ts", None)
        access_count = getattr(block, "access_count", 0) or 0
        if first_ts is None:
            return 0.0
        now = time.monotonic()
        dt = max(self.min_time_window, now - first_ts)
        if self.time_decay_factor > 0.0:
            w = math.exp(-self.time_decay_factor * dt)
            freq = (access_count * w) / dt
        else:
            freq = access_count / dt
        cost = float(self.block_size) ** self.alpha
        # Clamp so a single hot block cannot produce an unbounded score.
        return min(freq * cost, 1e15)

    def _compact_if_needed(self) -> None:
        """Rebuild the heap from live entries when stale tuples dominate."""
        if len(self._heap) > 2 * len(self._entry_finder) + self._COMPACT_SLACK:
            self._heap = list(self._entry_finder.values())
            heapq.heapify(self._heap)

    def _add(self, block: KVCacheBlock) -> None:
        """Start tracking ``block`` if it is free and cached."""
        # Only track cached-free blocks
        if block.ref_cnt != 0 or block.block_hash is None:
            return
        score = self._score(block)
        self._counter += 1
        entry = (score, self._counter, block.block_id)
        # Overwriting the finder entry invalidates any older tuple for the
        # same block that is still sitting in the heap.
        self._entry_finder[block.block_id] = entry
        heapq.heappush(self._heap, entry)
        self._compact_if_needed()

    def on_block_access(self, block: KVCacheBlock) -> None:
        """Record one access on ``block`` (sets/updates its tracking attrs)."""
        # Minimal tracking on access for frequency stats
        first_ts = getattr(block, "first_access_ts", None)
        if first_ts is None:
            block.first_access_ts = time.monotonic()
        block.access_count = (getattr(block, "access_count", 0) or 0) + 1

    def on_block_release(self, block: KVCacheBlock) -> None:
        """Notify the policy that ``block`` became cached-free."""
        self._add(block)

    def get_eviction_candidates(self, num_blocks: int) -> list[int]:
        """Pop up to ``num_blocks`` block ids, lowest score first.

        Returned blocks are no longer tracked. Fewer ids (possibly none) are
        returned when not enough tracked blocks remain.
        """
        out: list[int] = []
        while self._heap and len(out) < num_blocks:
            entry = heapq.heappop(self._heap)
            block_id = entry[2]
            # Skip tuples invalidated by remove_block() or a later re-release.
            if self._entry_finder.get(block_id) == entry:
                del self._entry_finder[block_id]
                out.append(block_id)
        return out

    def remove_block(self, block: KVCacheBlock) -> None:
        """Stop tracking ``block``; a no-op if it is not tracked."""
        # Lazy deletion: ensure future pops skip this block
        self._entry_finder.pop(block.block_id, None)

    def reset(self) -> None:
        """Forget every tracked block."""
        self._heap.clear()
        self._entry_finder.clear()

    @property
    def name(self) -> str:
        """Human-readable policy name including the configured alpha."""
        return f"FrequencyCost(alpha={self.alpha})"

0 commit comments

Comments
 (0)