-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcache.go
More file actions
213 lines (182 loc) · 5.38 KB
/
cache.go
File metadata and controls
213 lines (182 loc) · 5.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
package guardrails
import (
"context"
"sync"
"time"
"github.com/initializ/guardrails/models"
)
// PolicyCache manages cached guardrail policies with version-based invalidation.
//
// Entries are keyed by cacheKey ("entityType:orgID:entityID"). Two lookup flows
// share the same map:
//   - Get (legacy MongoDB flow) re-checks the version through loader once an
//     entry is older than ttl.
//   - GetOrCompile (caller-passes-config flow) compares against the version the
//     caller supplies and does not use loader or ttl.
//
// Create instances with NewPolicyCache: the zero value has a nil map, which
// panics on the first store. A PolicyCache contains a sync.RWMutex and must not
// be copied after first use.
type PolicyCache struct {
cache map[string]*CacheEntry // key: "entityType:orgID:entityID"
mu sync.RWMutex // guards cache; in-place writes to entries' CachedAt/LastAccessed are also made under it
loader PolicyLoader // optional, only needed for legacy MongoDB flow (Get calls it on every cache miss)
ttl time.Duration // max cache age before version check
}
// CacheEntry represents a cached compiled policy together with the metadata
// PolicyCache uses to decide when to revalidate or evict it.
//
// In this file, Policy and Version are set only when an entry is created (a
// new policy replaces the whole entry). CachedAt and LastAccessed are updated
// in place while PolicyCache.mu is held for writing.
type CacheEntry struct {
Policy *EffectivePolicy // compiled policy handed back to callers
Version int64 // config version the policy was compiled from; compared with the loader's version (Get) or the caller's configVersion (GetOrCompile)
CachedAt time.Time // when the entry was stored or last revalidated; Get compares its age with the cache ttl
LastAccessed time.Time // last time the entry was served; Cleanup evicts entries idle longer than maxAge
}
// NewPolicyCache returns an empty, ready-to-use PolicyCache.
//
// loader may be nil when only GetOrCompile is used; Get requires a real
// loader. ttl bounds how long Get serves an entry before it re-checks the
// version with the loader.
func NewPolicyCache(loader PolicyLoader, ttl time.Duration) *PolicyCache {
	pc := &PolicyCache{loader: loader, ttl: ttl}
	pc.cache = make(map[string]*CacheEntry)
	return pc
}
// cacheKey builds the map key for one entity in the form
// "entityType:orgID:entityID". The entity type is part of the key so that
// different entity kinds that share an ID do not collide.
func cacheKey(entityType EntityType, orgID, entityID string) string {
	const sep = ":"
	return string(entityType) + sep + orgID + sep + entityID
}
// Get returns the effective policy for an entity using the legacy
// loader-backed (MongoDB) flow.
//
// Lookup order:
//  1. An entry cached less than ttl ago is returned without contacting the loader.
//  2. Otherwise the loader's current version is fetched. If that call fails,
//     a cached (stale) entry is returned when one exists and the error is
//     dropped; without a cached entry the error is returned.
//  3. If the cached version still matches, the entry's timestamps are
//     refreshed and it is returned.
//  4. Otherwise the policy is reloaded and recompiled via loadAndCache.
//
// Get calls c.loader on every miss, so the cache must have been built with a
// non-nil loader to use this method.
func (c *PolicyCache) Get(ctx context.Context, entityType EntityType, entityID, orgID string) (*EffectivePolicy, error) {
	key := cacheKey(entityType, orgID, entityID)

	// CachedAt is rewritten in place under the write lock (version-match branch
	// below), so the freshness check must run while the lock is held.
	c.mu.RLock()
	entry, exists := c.cache[key]
	fresh := exists && time.Since(entry.CachedAt) < c.ttl
	c.mu.RUnlock()

	if fresh {
		c.mu.Lock()
		entry.LastAccessed = time.Now()
		c.mu.Unlock()
		return entry.Policy, nil
	}

	// Cache miss or stale entry: validate against the loader's version.
	currentVersion, err := c.loader.GetVersion(ctx, entityType, entityID, orgID)
	if err != nil {
		// Best effort: prefer serving a stale policy over failing the request.
		if exists {
			return entry.Policy, nil
		}
		return nil, err
	}

	// Version and Policy are only set when an entry is created, so reading them
	// without the lock is safe.
	if exists && entry.Version == currentVersion {
		now := time.Now()
		c.mu.Lock()
		entry.CachedAt = now
		entry.LastAccessed = now
		c.mu.Unlock()
		return entry.Policy, nil
	}

	return c.loadAndCache(ctx, entityType, entityID, orgID, key)
}
// GetOrCompile returns the cached policy for an entity when its cached version
// equals configVersion. Otherwise it compiles sg with CompileStructuredGuardrails,
// stores the result under configVersion, and returns it. Compilation errors are
// returned unchanged, and nothing is cached in that case.
//
// This is the primary method for the caller-passes-config flow. It never
// consults the loader or the ttl.
//
// Compilation runs without holding the lock, so concurrent misses on the same
// key may each compile; the last one to store wins.
func (c *PolicyCache) GetOrCompile(entityType EntityType, orgID, entityID string,
	configVersion int64, sg *models.StructuredGuardrails) (*EffectivePolicy, error) {
	key := cacheKey(entityType, orgID, entityID)

	c.mu.RLock()
	cached, found := c.cache[key]
	c.mu.RUnlock()

	if !found || cached.Version != configVersion {
		compiled, compileErr := CompileStructuredGuardrails(entityType, orgID, entityID, configVersion, false, sg)
		if compileErr != nil {
			return nil, compileErr
		}

		c.mu.Lock()
		c.cache[key] = &CacheEntry{
			Policy:       compiled,
			Version:      configVersion,
			CachedAt:     time.Now(),
			LastAccessed: time.Now(),
		}
		c.mu.Unlock()
		return compiled, nil
	}

	// Hit on the expected version: only the access time changes.
	c.mu.Lock()
	cached.LastAccessed = time.Now()
	c.mu.Unlock()
	return cached.Policy, nil
}
// loadAndCache reloads the policy for key through the loader, compiles it,
// stores the result, and returns it. It is the slow path of Get.
//
// Loader calls (GetVersion, LoadPolicy, CompilePolicy) run without holding
// c.mu, so a slow backend does not stall every other cache operation.
// Concurrent reloads of the same key may therefore all reach the loader; the
// last one to store wins.
//
// Errors from LoadPolicy or CompilePolicy are returned unchanged, and nothing
// is cached in that case.
func (c *PolicyCache) loadAndCache(ctx context.Context, entityType EntityType, entityID, orgID, key string) (*EffectivePolicy, error) {
	// Double-check: another goroutine may have refreshed this key since the
	// caller looked. Reuse the entry only when the version lookup succeeds; the
	// version returned alongside an error is not meaningful.
	c.mu.RLock()
	entry, exists := c.cache[key]
	c.mu.RUnlock()
	if exists {
		if version, err := c.loader.GetVersion(ctx, entityType, entityID, orgID); err == nil && entry.Version == version {
			return entry.Policy, nil
		}
	}

	// Load from the backing store.
	config, err := c.loader.LoadPolicy(ctx, entityType, entityID, orgID)
	if err != nil {
		return nil, err
	}

	// Compile the policy.
	policy, err := c.loader.CompilePolicy(config)
	if err != nil {
		return nil, err
	}

	// Only the map update itself needs the write lock.
	now := time.Now()
	c.mu.Lock()
	c.cache[key] = &CacheEntry{
		Policy:       policy,
		Version:      config.ConfigVersion,
		CachedAt:     now,
		LastAccessed: now,
	}
	c.mu.Unlock()
	return policy, nil
}
// Invalidate drops the cached policy for a single entity. It is a no-op when
// nothing is cached for that entity.
func (c *PolicyCache) Invalidate(entityType EntityType, orgID, entityID string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.cache, cacheKey(entityType, orgID, entityID))
}
// InvalidateOrg removes every cached policy, of any entity type, that belongs
// to orgID. Which keys belong to an org is decided by containsOrgID.
func (c *PolicyCache) InvalidateOrg(orgID string) {
	c.mu.Lock()
	defer c.mu.Unlock()

	for key := range c.cache {
		// Deleting the current entry while ranging over a map is permitted in Go.
		if containsOrgID(key, orgID) {
			delete(c.cache, key)
		}
	}
}
// containsOrgID reports whether key, built by cacheKey as
// "entityType:orgID:entityID", belongs to orgID.
//
// The entity type is taken to be everything before the first ':'. This
// assumes EntityType values never contain ':' (TODO confirm). The key matches
// when the remainder starts with orgID followed by ':'.
//
// Unlike comparing only the text between the first two colons, this still
// matches when orgID itself contains ':'. The trade-off is that org "a" also
// matches keys of an org named "a:..."; for a cache that only causes an extra
// reload, whereas a missed match would keep serving a stale policy.
func containsOrgID(key, orgID string) bool {
	for i := 0; i < len(key); i++ {
		if key[i] != ':' {
			continue
		}
		// First ':' found: rest is "orgID:entityID" for a well-formed key.
		rest := key[i+1:]
		n := len(orgID)
		return len(rest) > n && rest[:n] == orgID && rest[n] == ':'
	}
	return false
}
// Cleanup evicts every entry that has not been served for longer than maxAge,
// measured from LastAccessed. It is meant to be called periodically by the
// owner of the cache.
func (c *PolicyCache) Cleanup(maxAge time.Duration) {
	c.mu.Lock()
	defer c.mu.Unlock()

	now := time.Now()
	for k, e := range c.cache {
		if now.Sub(e.LastAccessed) <= maxAge {
			continue // recently used; keep it
		}
		delete(c.cache, k)
	}
}
// Size reports how many entries are currently cached.
func (c *PolicyCache) Size() int {
	c.mu.RLock()
	n := len(c.cache)
	c.mu.RUnlock()
	return n
}