-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathfixed_tagger.py
More file actions
228 lines (193 loc) · 8.22 KB
/
fixed_tagger.py
File metadata and controls
228 lines (193 loc) · 8.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
"""
fixed_tagger.py — User-configurable keyword/pattern-based tagger.
Reads tags.yaml (or a path given at construction). Hot-reloads on change.
Optionally loads per-user tags from a user tag file and merges them with
system tags (user tags override system tags on name collision).
Returns TagAssignment objects compatible with the ensemble tagger interface.
"""
import re
import threading
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional
try:
import yaml
YAML_AVAILABLE = True
except ImportError:
YAML_AVAILABLE = False
from features import MessageFeatures
from tagger import TagAssignment, _strip_metadata
DEFAULT_TAGS_PATH = Path(__file__).parent / "tags.yaml"
USER_TAGS_DIR = Path.home() / ".tag-context" / "tags.user"
@dataclass
class TagSpec:
    """A single tag rule: what to look for and how to score a hit.

    Only ``name`` is required. Every other field defaults to the value the
    YAML parser in this module falls back to when the corresponding key is
    absent, so a spec can be built by hand without spelling out all fields.
    Field order is unchanged, so positional construction keeps working.
    """

    # Tag name reported when the rule fires.
    name: str
    # Keywords to search for in the message text.
    keywords: List[str] = field(default_factory=list)
    # Pre-compiled regular expressions searched against the message text.
    patterns: List[re.Pattern] = field(default_factory=list)
    # True: every keyword and pattern must hit; False: any single hit is enough.
    requires_all: bool = False
    # Confidence value attached to this tag when it fires.
    confidence: float = 1.0
    # Whether the rule is active.
    enabled: bool = True
def _parse_tag_specs(data: dict) -> List[TagSpec]:
    """Parse a tags YAML data dict into a list of TagSpec objects.

    Args:
        data: Mapping loaded from a tags YAML file, expected to carry a
            ``tags`` list of entry mappings. ``None`` (an empty YAML
            document) is treated as "no tags".

    Returns:
        One TagSpec per enabled entry, in file order. Keywords are
        lower-cased; patterns are compiled case-insensitive and multiline.

    Raises:
        KeyError: If an enabled entry has no ``name`` key.
    """
    # Function-scope import keeps this block self-contained.
    import logging

    log = logging.getLogger(__name__)
    specs: List[TagSpec] = []

    # ``or`` instead of a .get() default: a key that is present but null
    # (e.g. a bare ``tags:`` line) yields None, which must not be iterated.
    for entry in (data or {}).get("tags") or []:
        if not entry.get("enabled", True):
            continue

        compiled = []
        for raw in entry.get("patterns") or []:
            try:
                compiled.append(re.compile(raw, re.IGNORECASE | re.MULTILINE))
            except (re.error, TypeError) as exc:
                # One bad pattern must not take down the whole config, but
                # dropping it silently hides the mistake from the author.
                log.warning("Skipping invalid pattern %r for tag %r: %s",
                            raw, entry.get("name"), exc)

        specs.append(TagSpec(
            name=entry["name"],
            # str(): YAML turns bare scalars such as 404 into non-strings,
            # which have no .lower().
            keywords=[str(k).lower() for k in entry.get("keywords") or []],
            patterns=compiled,
            requires_all=entry.get("requires_all", False),
            confidence=entry.get("confidence", 1.0),
            enabled=True,
        ))
    return specs
class FixedTagger:
    """
    Keyword/pattern-based tagger driven by a YAML config.

    Hot-reloads: if tags.yaml or the user tag file mtime changes, reloads
    automatically without restarting the service.

    User tags are loaded from ~/.tag-context/tags.user/<channel_label>.yaml
    when a channel_label is provided (see ``for_channel``). User tags are
    merged with system tags; user tags take precedence on name collision.

    Production tagger for the tag-context system.
    """

    def __init__(self, config_path: Optional[Path] = None,
                 user_tags_path: Optional[Path] = None,
                 reload_interval: float = 30.0) -> None:
        """Create the tagger and perform the initial load.

        Args:
            config_path: System tags YAML; defaults to DEFAULT_TAGS_PATH.
            user_tags_path: Optional per-user tags YAML merged over the
                system tags.
            reload_interval: Stored on the instance only.

        Raises:
            ImportError: If pyyaml is not installed.
            RuntimeError: If the initial load fails.
        """
        self._path = config_path or DEFAULT_TAGS_PATH
        self._user_tags_path: Optional[Path] = user_tags_path
        # NOTE(review): stored but not consulted anywhere in this class;
        # mtimes are checked on every assign() call. Presumably meant to
        # throttle those checks - confirm before relying on it.
        self._reload_interval = reload_interval
        self._tags: List[TagSpec] = []
        self._mtime: float = 0.0
        self._user_mtime: float = 0.0
        # Becomes True after the first successful load. Distinguishes
        # "never loaded" from "loaded a config with zero enabled tags".
        self._loaded = False
        self._lock = threading.RLock()
        self._load()

    @classmethod
    def for_channel(cls, channel_label: Optional[str],
                    config_path: Optional[Path] = None,
                    reload_interval: float = 30.0) -> "FixedTagger":
        """
        Convenience factory: creates a FixedTagger that merges system tags
        with the user tags for the given channel_label.

        If channel_label is None/empty or the user tag file doesn't exist at
        construction time, returns a tagger with system tags only.
        """
        user_tags_path: Optional[Path] = None
        if channel_label:
            # NOTE(review): channel_label is interpolated into the file name
            # as-is; presumably callers pass trusted labels - verify.
            candidate = USER_TAGS_DIR / f"{channel_label}.yaml"
            if candidate.exists():
                user_tags_path = candidate
        return cls(config_path=config_path, user_tags_path=user_tags_path,
                   reload_interval=reload_interval)

    def _load(self) -> None:
        """Load or reload tags from YAML (system + optional user tags).

        Does nothing when neither file's mtime changed since the last
        successful load. A failure after at least one successful load is
        swallowed and the previously loaded tags stay in effect.

        Raises:
            ImportError: If pyyaml is not installed.
            RuntimeError: If loading fails and nothing was loaded before.
        """
        if not YAML_AVAILABLE:
            raise ImportError("pyyaml required for FixedTagger: pip install pyyaml")

        with self._lock:
            try:
                sys_mtime = self._path.stat().st_mtime

                user_path = self._user_tags_path
                if user_path is not None and not user_path.exists():
                    user_path = None
                user_mtime: float = 0.0
                if user_path is not None:
                    user_mtime = user_path.stat().st_mtime

                if sys_mtime == self._mtime and user_mtime == self._user_mtime:
                    return  # no change

                # Binary mode lets the YAML loader detect the encoding itself
                # instead of depending on the platform's locale default.
                with self._path.open("rb") as f:
                    sys_data = yaml.safe_load(f)
                # An empty YAML file loads as None; treat it as "no tags",
                # the same way the user file is handled below.
                sys_specs = _parse_tag_specs(sys_data or {})

                user_specs: List[TagSpec] = []
                if user_path is not None:
                    with user_path.open("rb") as f:
                        user_data = yaml.safe_load(f)
                    user_specs = _parse_tag_specs(user_data or {})

                # Merge: user tags override system tags on name collision.
                merged: Dict[str, TagSpec] = {s.name: s for s in sys_specs}
                merged.update((s.name, s) for s in user_specs)

                self._tags = list(merged.values())
                self._mtime = sys_mtime
                self._user_mtime = user_mtime
                self._loaded = True
            except Exception as e:
                # On reload failure, keep existing tags. Keyed on "ever
                # loaded" rather than "has tags", so a valid config with no
                # enabled tags does not turn later reload errors into
                # exceptions raised from assign().
                if not self._loaded:
                    raise RuntimeError(
                        f"Failed to load tags from {self._path}: {e}"
                    ) from e

    def _maybe_reload(self) -> None:
        """Check if configs have changed and reload if needed.

        OSError from stat() (e.g. the file is missing) is ignored so the
        current tags stay in effect.
        """
        try:
            sys_mtime = self._path.stat().st_mtime
            user_mtime: float = 0.0
            if self._user_tags_path and self._user_tags_path.exists():
                user_mtime = self._user_tags_path.stat().st_mtime
            if sys_mtime != self._mtime or user_mtime != self._user_mtime:
                self._load()
        except OSError:
            pass

    def assign(self, features: MessageFeatures,
               user_text: str, assistant_text: str) -> TagAssignment:
        """Apply fixed rules. Hot-reloads config if changed.

        Both texts are passed through ``_strip_metadata``, joined with a
        space and lower-cased before matching. ``features`` is accepted but
        not used by this tagger.

        Returns:
            A TagAssignment with the sorted names of all tags that fired,
            the mean of their confidences (0.0 when none fired), and one
            ``fixed:<name>`` rule id per fired tag.
        """
        self._maybe_reload()

        user_text = _strip_metadata(user_text)
        assistant_text = _strip_metadata(assistant_text)
        combined = (user_text + " " + assistant_text).lower()

        fired_tags = []
        fired_rules = []
        confidences = []

        with self._lock:
            for spec in self._tags:
                if self._matches(spec, combined):
                    fired_tags.append(spec.name)
                    fired_rules.append(f"fixed:{spec.name}")
                    confidences.append(spec.confidence)

        avg_conf = sum(confidences) / len(confidences) if confidences else 0.0
        return TagAssignment(
            tags=sorted(fired_tags),
            confidence=avg_conf,
            rules_fired=fired_rules,
        )

    @staticmethod
    def _keyword_regex(kw: str) -> str:
        """Return the regex source used to find keyword *kw* in text."""
        # Multi-word keywords tolerate any run of whitespace between words.
        parts = kw.split() if " " in kw else [kw]
        body = r"\s+".join(re.escape(part) for part in parts)

        first = parts[0][:1] if parts else ""
        last = parts[-1][-1:] if parts else ""
        # Only anchor with \b on an edge that is a word character. Next to
        # punctuation, \b demands an adjacent word character, so a keyword
        # like "c++" could never match "c++ " and ".net" could never match
        # at the start of a word.
        lead = "" if first and not re.match(r"\w", first) else r"\b"
        trail = "" if last and not re.match(r"\w", last) else r"\b"
        return lead + body + trail

    def _matches(self, spec: TagSpec, combined: str) -> bool:
        """Return True when *spec* fires on the lower-cased text *combined*.

        With ``requires_all`` False, the first keyword or pattern hit is
        enough. With ``requires_all`` True, every keyword and every pattern
        must hit. A spec with no keywords and no patterns never fires.
        """
        hits = 0

        for kw in spec.keywords:
            if re.search(self._keyword_regex(kw), combined):
                if not spec.requires_all:
                    return True  # OR logic: one hit is enough
                hits += 1

        for pat in spec.patterns:
            if pat.search(combined):
                if not spec.requires_all:
                    return True
                hits += 1

        if not spec.requires_all:
            return False

        expected = len(spec.keywords) + len(spec.patterns)
        # expected > 0 guards against vacuous truth: a rule-less spec (for
        # example one whose patterns were all dropped as invalid) must not
        # tag every message.
        return expected > 0 and hits >= expected

    @property
    def tag_names(self) -> List[str]:
        """Return list of active tag names."""
        with self._lock:
            return [t.name for t in self._tags]