-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathensemble.py
More file actions
169 lines (139 loc) · 5.95 KB
/
ensemble.py
File metadata and controls
169 lines (139 loc) · 5.95 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
"""
ensemble.py — Ensemble tagger wrapper for the tag-context system.
Supports weighted voting across multiple tagging strategies. Production uses
"fixed" mode only with FixedTagger for keyword/pattern matching.
"""
import logging
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional, Tuple

from features import MessageFeatures
from tagger import TagAssignment, CORE_TAGS
from quality import QualityAgent
from tag_registry import get_registry
from fixed_tagger import FixedTagger
import config
@dataclass
class TaggerEntry:
    """
    One tagger registered with an EnsembleTagger.

    Attributes:
        tagger_id: Identifier used for per-tagger attribution and for
            looking up fitness scores.
        assign_fn: Called as ``assign_fn(features, user_text, assistant_text)``;
            the ensemble accepts either a TagAssignment or an iterable of tags.
        weight: Voting weight; may be overwritten from quality-agent fitness.
    """

    # Identifier used in attribution and fitness lookups.
    tagger_id: str
    # The tagging callable itself.
    assign_fn: Callable[[MessageFeatures, str, str], TagAssignment]
    # Relative voting weight (defaults to 1.0).
    weight: float = field(default=1.0)
@dataclass
class EnsembleResult:
    """
    Outcome of a single ensemble tagging pass.

    Attributes:
        tags: Tags that passed the vote threshold.
        confidence: Aggregate confidence for the accepted tags.
        per_tagger: Maps each tagger_id to the tags that tagger returned.
        tag_votes: Maps every proposed tag to its summed weighted vote.
    """

    # Accepted tags.
    tags: List[str]
    # Aggregate confidence score.
    confidence: float
    # tagger_id -> tags it contributed.
    per_tagger: Dict[str, List[str]]
    # tag -> weighted vote score.
    tag_votes: Dict[str, float]
class EnsembleTagger:
    """
    FixedTagger-compatible wrapper (formerly a multi-tagger ensemble).

    Originally designed to run multiple taggers with weighted voting, but
    since the GP tagger (DEAP-based) never worked, this now effectively
    wraps a single FixedTagger. The ensemble infrastructure (weight updates,
    per-tagger attribution, vote thresholds) is preserved for future use
    if additional taggers are added.

    Voting: each registered tagger casts one vote per distinct tag it emits.
    The vote is worth the tagger's weight divided by the sum of all
    registered weights. A tag is accepted when its summed vote is at least
    ``vote_threshold`` and the tag registry reports it as active. Weights
    can be refreshed from the quality agent's fitness scores via
    ``update_weights``.
    """

    # Logger for taggers that raise during ``assign``; failures are reported
    # rather than silently discarded.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        quality_agent: Optional[QualityAgent] = None,
        vote_threshold: float = 0.4,
        min_weight: float = 0.1,
    ) -> None:
        """
        Args:
            quality_agent: Source of per-tagger fitness scores for
                ``update_weights``. If None, weights stay as registered.
            vote_threshold: Minimum normalised vote a tag needs to be accepted.
            min_weight: Lower bound applied to fitness-derived weights in
                ``update_weights``.
        """
        self._taggers: List[TaggerEntry] = []
        self._qa = quality_agent
        self._vote_threshold = vote_threshold
        self._min_weight = min_weight

    def register(self, tagger_id: str,
                 assign_fn: Callable[[MessageFeatures, str, str], TagAssignment],
                 initial_weight: float = 1.0) -> None:
        """
        Add a tagger to the ensemble.

        Args:
            tagger_id: Name used for attribution and fitness lookups.
            assign_fn: Called as ``assign_fn(features, user_text, assistant_text)``;
                may return a TagAssignment or any iterable of tags.
            initial_weight: Voting weight until ``update_weights`` replaces it.
        """
        self._taggers.append(TaggerEntry(
            tagger_id=tagger_id,
            assign_fn=assign_fn,
            weight=initial_weight,
        ))

    def update_weights(self, last_n: int = 20) -> None:
        """
        Update tagger weights from quality agent scores.

        Each weight becomes ``max(min_weight, fitness)``. This is a no-op
        when no quality agent was supplied.

        Args:
            last_n: Forwarded to ``QualityAgent.fitness``. Presumably this is
                the number of recent evaluations to consider (TODO confirm).
        """
        if self._qa is None:
            return
        for entry in self._taggers:
            fitness = self._qa.fitness(entry.tagger_id, last_n)
            entry.weight = max(self._min_weight, fitness)

    def assign(self, features: MessageFeatures,
               user_text: str, assistant_text: str) -> EnsembleResult:
        """
        Run all taggers, aggregate via weighted vote, prune low-confidence tags.

        Args:
            features: Message features forwarded to every tagger.
            user_text: User message text forwarded to every tagger.
            assistant_text: Assistant message text forwarded to every tagger.

        Returns:
            An EnsembleResult containing:
            - the accepted tags, sorted;
            - their mean vote as ``confidence`` (0.0 if none are accepted);
            - each tagger's raw output;
            - the vote of every proposed tag, including rejected ones.

        A tagger that raises is logged and treated as returning no tags. Its
        weight still counts toward the normalising total.
        """
        if not self._taggers:
            return EnsembleResult(
                tags=[], confidence=0.0, per_tagger={}, tag_votes={},
            )

        per_tagger: Dict[str, List[str]] = {}
        tag_votes: Dict[str, float] = {}
        total_weight = sum(e.weight for e in self._taggers)

        for entry in self._taggers:
            try:
                result = entry.assign_fn(features, user_text, assistant_text)
                tags = result.tags if isinstance(result, TagAssignment) else list(result)
            except Exception:
                # Best-effort: one broken tagger must not abort tagging,
                # but the failure should be visible instead of vanishing.
                self._logger.exception(
                    "Tagger %r raised; treating its output as no tags",
                    entry.tagger_id,
                )
                tags = []
            per_tagger[entry.tagger_id] = tags
            normalised_weight = entry.weight / total_weight if total_weight > 0 else 0.0
            # Count each distinct tag once per tagger. A tagger that repeats
            # a tag must not multiply its vote beyond its own weight.
            for tag in dict.fromkeys(tags):
                tag_votes[tag] = tag_votes.get(tag, 0.0) + normalised_weight

        # Threshold filter: keep only tags with sufficient weighted support
        # that the registry reports as active (core + candidate).
        registry = get_registry()
        active_tags = registry.get_active_tags()
        accepted_tags = sorted(
            tag for tag, vote in tag_votes.items()
            if vote >= self._vote_threshold and tag in active_tags
        )

        # Aggregate confidence: mean vote score of the accepted tags.
        if accepted_tags:
            confidence = sum(tag_votes[t] for t in accepted_tags) / len(accepted_tags)
        else:
            confidence = 0.0

        return EnsembleResult(
            tags=accepted_tags,
            confidence=confidence,
            per_tagger=per_tagger,
            tag_votes=tag_votes,
        )

    def explain(self, result: EnsembleResult) -> str:
        """
        Return a human-readable explanation of an ensemble tagging.

        Each accepted tag is listed with its vote and contributing taggers.
        Rejected tags (positive vote but not accepted) are listed after,
        split into two groups:
        - those whose vote fell below this tagger's ``vote_threshold``;
        - those that met it but were still dropped (presumably because the
          registry did not list them as active when ``assign`` ran).
        """
        lines = [f"Ensemble: {len(result.tags)} tags accepted (threshold={self._vote_threshold})"]
        for tag in result.tags:
            vote = result.tag_votes.get(tag, 0)
            sources = [tid for tid, tags in result.per_tagger.items() if tag in tags]
            lines.append(f" {tag:<25} vote={vote:.2f} from: {', '.join(sources)}")

        accepted = set(result.tags)
        rejected = sorted(t for t, v in result.tag_votes.items()
                          if t not in accepted and v > 0)
        below = [t for t in rejected if result.tag_votes[t] < self._vote_threshold]
        inactive = [t for t in rejected if result.tag_votes[t] >= self._vote_threshold]

        if below:
            lines.append(" Pruned (below threshold):")
            for tag in below:
                lines.append(f" {tag:<23} vote={result.tag_votes[tag]:.2f}")
        if inactive:
            # These met the threshold, so labelling them "below threshold"
            # would be misleading.
            lines.append(" Pruned (not an active tag):")
            for tag in inactive:
                lines.append(f" {tag:<23} vote={result.tag_votes[tag]:.2f}")
        return "\n".join(lines)
def build_ensemble(
    mode: Optional[str] = None,
    quality_agent: Optional[QualityAgent] = None,
    vote_threshold: float = 0.4,
) -> EnsembleTagger:
    """
    Construct the production tagger: an EnsembleTagger around one FixedTagger.

    Only "fixed" mode (keyword/pattern matching driven by
    ``config.TAGS_CONFIG``, presumably tags.yaml) is supported. No other
    tagger is registered, so no DEAP dependency is needed.

    Args:
        mode: Tagger mode; falls back to ``config.TAGGER_MODE`` when falsy.
        quality_agent: Optional quality agent passed to the ensemble.
        vote_threshold: Minimum weighted vote for a tag to be accepted.

    Returns:
        An EnsembleTagger with the FixedTagger registered under ``"fixed"``
        at weight 1.0.

    Raises:
        ValueError: If the resolved mode is anything other than ``"fixed"``.
    """
    resolved_mode = mode if mode else config.TAGGER_MODE
    if resolved_mode != "fixed":
        raise ValueError(f"Only 'fixed' mode is supported. Got: {resolved_mode}")

    tagger = EnsembleTagger(quality_agent=quality_agent, vote_threshold=vote_threshold)
    # The keyword/pattern tagger is the sole voter.
    fixed_tagger = FixedTagger(config.TAGS_CONFIG)
    tagger.register("fixed", fixed_tagger.assign, initial_weight=1.0)
    return tagger