-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathcache.py
More file actions
159 lines (137 loc) · 5.35 KB
/
Copy pathcache.py
File metadata and controls
159 lines (137 loc) · 5.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
#!/usr/bin/env python3
"""
LLM Response Cache Module
Content-hash caching for LLM responses. Eliminates redundant API calls
when running the pipeline on unchanged GNN files.
Cache key: sha256(file_content + model_name + prompt_template)
Storage: output/13_llm_output/.cache/<hash>.json
"""
import hashlib
import json
import logging
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Optional, cast
# Module-level logger; LLMCache routes its cache diagnostics (hit/miss/write debug lines and warnings) through it.
logger: logging.Logger = logging.getLogger(name=__name__)
class LLMCache:
    """Content-addressed, on-disk cache for LLM prompt responses.

    Each entry is a single JSON file ``<sha256-hex>.json`` inside
    ``cache_dir``. The key is derived from the file content, the model name
    and the prompt text (see ``_make_key``). Hit, miss and write counters are
    kept per instance and reported by ``summary()``.
    """

    def __init__(
        self, cache_dir: Optional[Path] = None, base_output_dir: Optional[Path] = None
    ) -> None:
        """
        Initialize cache.

        Args:
            cache_dir: Explicit directory to store cached responses. Takes
                precedence over ``base_output_dir``.
            base_output_dir: If provided (and ``cache_dir`` is None), the
                cache root is resolved via
                ``pipeline.config.get_output_dir_for_script("13_llm.py",
                base_output_dir) / ".cache"``. This allows multi-workspace /
                multi-pipeline runs to avoid sharing one on-disk cache.
                If ``pipeline.config`` cannot be imported, falls back to
                ``base_output_dir / "13_llm_output" / ".cache"``.

        When both are None, uses ``output/13_llm_output/.cache`` relative to
        the current working directory. The chosen directory is created
        (with parents) if it does not exist.
        """
        if cache_dir is not None:
            self.cache_dir = Path(cache_dir)
        elif base_output_dir is not None:
            try:
                from pipeline.config import get_output_dir_for_script

                self.cache_dir = (
                    get_output_dir_for_script("13_llm.py", Path(base_output_dir))
                    / ".cache"
                )
            except ImportError:
                # pipeline package unavailable: use the conventional layout.
                self.cache_dir = Path(base_output_dir) / "13_llm_output" / ".cache"
        else:
            self.cache_dir = Path("output/13_llm_output/.cache")
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.hits = 0
        self.misses = 0
        self.writes = 0

    @staticmethod
    def _make_key(file_content: str, model_name: str, prompt_template: str) -> str:
        """Generate a deterministic cache key from content + model + prompt."""
        payload = f"{file_content}\x00{model_name}\x00{prompt_template}"
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def _cache_path(self, key: str) -> Path:
        """Return the file path for a given cache key."""
        return self.cache_dir / f"{key}.json"

    @staticmethod
    def _discard(path: Path) -> None:
        """Best-effort removal of *path*; failures are logged, never raised."""
        try:
            path.unlink(missing_ok=True)
        except OSError as exc:
            logger.warning("Could not remove cache file %s: %s", path, exc)

    def _read_response(self, key: str, path: Path) -> Optional[str]:
        """Return the response stored at *path*, or None if absent or unusable.

        Unparseable or malformed entries are deleted so they are rebuilt on
        the next ``put``; entries that merely cannot be read (e.g. permission
        errors) are left in place.
        """
        try:
            with open(path, "r", encoding="utf-8") as f:
                entry = json.load(f)
        except FileNotFoundError:
            return None
        except ValueError as exc:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors:
            # a truncated or garbled file.
            logger.warning("Corrupt cache entry %s…: %s", key[:12], exc)
            self._discard(path)
            return None
        except OSError as exc:
            logger.warning("Unreadable cache entry %s…: %s", key[:12], exc)
            return None

        response = entry.get("response") if isinstance(entry, dict) else None
        if not isinstance(response, str):
            # Valid JSON but not an entry this class wrote (wrong shape or no
            # string response): returning None as a "hit" would be
            # indistinguishable from a miss, so treat it as corrupt.
            logger.warning(
                "Corrupt cache entry %s…: missing string 'response'", key[:12]
            )
            self._discard(path)
            return None
        return response

    def get(
        self, file_content: str, model_name: str, prompt_template: str
    ) -> Optional[str]:
        """
        Look up a cached response.

        Args:
            file_content: Raw GNN file content.
            model_name: Ollama / provider model name.
            prompt_template: The prompt text sent to the LLM.

        Returns:
            Cached response string, or None on miss. A missing, unreadable or
            corrupt entry counts as a miss; corrupt entries are deleted.
        """
        key = self._make_key(file_content, model_name, prompt_template)
        response = self._read_response(key, self._cache_path(key))
        if response is None:
            self.misses += 1
            logger.debug("Cache MISS: %s…", key[:12])
            return None
        self.hits += 1
        logger.debug("Cache HIT: %s…", key[:12])
        return response

    def put(
        self,
        file_content: str,
        model_name: str,
        prompt_template: str,
        response: str,
    ) -> None:
        """
        Store a response in the cache.

        The entry is written to a uniquely named ``*.tmp`` sibling and then
        renamed over the final ``<key>.json`` path, so a concurrent ``get``
        never reads a half-written file. Write failures (``OSError``) are
        logged and swallowed; the cache is best-effort.

        Args:
            file_content: Raw GNN file content.
            model_name: Ollama / provider model name.
            prompt_template: The prompt text sent to the LLM.
            response: The LLM response to cache.
        """
        key = self._make_key(file_content, model_name, prompt_template)
        path = self._cache_path(key)
        entry: dict[str, Any] = {
            "key": key,
            "model": model_name,
            "timestamp": datetime.now().isoformat(),
            "content_length": len(file_content),
            "prompt_length": len(prompt_template),
            "response": response,
        }
        # Unique per call so concurrent writers of the same key never share a
        # temp file; the ".tmp" suffix keeps it out of the "*.json" globs.
        tmp_path = path.with_name(f"{key}.{uuid.uuid4().hex}.tmp")
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                json.dump(entry, f, indent=2)
            tmp_path.replace(path)
        except OSError as exc:
            logger.warning("Failed to write cache entry %s…: %s", key[:12], exc)
            return
        finally:
            # No-op after a successful rename; removes a partial file otherwise.
            self._discard(tmp_path)
        self.writes += 1
        logger.debug("Cache WRITE: %s…", key[:12])

    def summary(self) -> dict[str, Any]:
        """Return cache statistics for logging.

        The hit ratio is a percentage rounded to one decimal place (0.0 before
        any lookup). ``entries_on_disk`` counts the ``*.json`` files currently
        present in ``cache_dir``.
        """
        total = self.hits + self.misses
        ratio = (self.hits / total * 100) if total > 0 else 0.0
        return {
            "hits": self.hits,
            "misses": self.misses,
            "writes": self.writes,
            "hit_ratio_pct": round(ratio, 1),
            "cache_dir": str(self.cache_dir),
            "entries_on_disk": sum(1 for _ in self.cache_dir.glob("*.json")),
        }

    def clear(self) -> int:
        """Remove all cached entries. Returns count of entries removed.

        Entries that disappear concurrently or cannot be deleted are skipped
        (the latter with a warning) rather than aborting the sweep. Leftover
        ``*.tmp`` files from interrupted writes are also removed but are not
        counted as entries.
        """
        removed = 0
        for entry_path in self.cache_dir.glob("*.json"):
            try:
                entry_path.unlink()
            except FileNotFoundError:
                continue  # removed by someone else in the meantime
            except OSError as exc:
                logger.warning("Could not remove cache file %s: %s", entry_path, exc)
                continue
            removed += 1
        for tmp_path in self.cache_dir.glob("*.tmp"):
            self._discard(tmp_path)
        logger.info("Cache cleared: %d entries removed", removed)
        return removed