Commit dd6792a
Refactor the model to make it usable without the server
Currently, the `SemanticMatch` and `EquivalenceTable` classes can only be used effectively together with the server. This commit introduces a new and improved `SemanticMatchDictStore`, which implements the algorithm for searching matches, and refactors the service to use this new class. Furthermore, it cleans up the service to be used in a more pythonic way, eliminating the need for the `service_model` module.
1 parent ce50050 commit dd6792a
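
In practice, the refactor enables usage like the following sketch, assuming the classes from the diff below are imported from `semantic_matcher.model` (the semantic IDs and scores are illustrative):

    from semantic_matcher.model import SemanticMatch, SemanticMatchDictStore

    store = SemanticMatchDictStore([
        SemanticMatch(base_semantic_id="A", match_semantic_id="B", score=0.4),
        SemanticMatch(base_semantic_id="B", match_semantic_id="C", score=0.5),
    ])
    # Resolve direct and transitive matches of "A" whose multiplied score is >= 0.1:
    for match in store.get_matches("A", min_score=0.1):
        print(match.match_semantic_id, match.score)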

File tree

7 files changed: +603, -55 lines

semantic_matcher/algorithm.py

Lines changed: 129 additions & 0 deletions

@@ -0,0 +1,129 @@
from typing import List, Tuple
import heapq

import networkx as nx


class SemanticMatchGraph(nx.DiGraph):
    def __init__(self):
        super().__init__()

    def add_semantic_match(self,
                           base_semantic_id: str,
                           match_semantic_id: str,
                           score: float):
        self.add_edge(
            u_of_edge=base_semantic_id,
            v_of_edge=match_semantic_id,
            weight=score,
        )


class MatchResult:
    base_semantic_id: str
    match_semantic_id: str
    score: float
    path: List[str]  # The path of `semantic_id`s that the algorithm took

    def __init__(self,
                 base_semantic_id: str,
                 match_semantic_id: str,
                 score: float,
                 path: List[str]):
        self.base_semantic_id = base_semantic_id
        self.match_semantic_id = match_semantic_id
        self.score = score
        self.path = path

    def __repr__(self) -> str:
        return f"{' -> '.join(self.path + [self.match_semantic_id])} = {self.score}"


def find_semantic_matches(
        graph: SemanticMatchGraph,
        semantic_id: str,
        min_score: float = 0.5
) -> List[MatchResult]:
    """
    Find semantic matches for a given node with a minimum score threshold.

    Args:
        graph (nx.DiGraph): The directed graph with weighted edges.
        semantic_id (str): The starting semantic_id.
        min_score (float): The minimum similarity score to consider.
            This value is necessary to ensure that the search terminates, even on sufficiently large graphs.

    Returns:
        List[MatchResult]:
            A list of MatchResults, sorted by their score with the highest score first.
    """
    if semantic_id not in graph:
        return []

    # We need to make sure that all possible paths starting from the given semantic_id are explored.
    # To achieve this, we use the concept of a "priority queue". While we could use a simple FIFO list of matches
    # to explore, a priority queue ensures that elements with the highest priority are processed first, regardless
    # of when they were added. This way, we end up with an already sorted result, with the highest match at the
    # beginning of the list. As the implementation of this abstract data structure, we choose a "max-heap".
    # However, Python's standard library only provides an efficient "min-heap" (`heapq`), so we use that and
    # negate the score values.
    # We initialize the priority queue:
    pq: List[Tuple[float, str, List[str]]] = [(-1.0, semantic_id, [])]  # (neg_score, node, path)
    # The queue entries are structured as follows:
    # - `neg_score`: The negated score of the match
    # - `node`: The `match_semantic_id` of the match
    # - `path`: The path between the `semantic_id` and the `match_semantic_id`

    # Prepare the result list
    results: List[MatchResult] = []

    # Run the priority queue until all possible paths have been explored.
    # This means in each iteration:
    # - We pop the top element of the queue, as it's the next highest semantic match we want to explore
    # - If the match has a score higher than or equal to the given `min_score`, we add it to the results
    # - We add all connected `semantic_id`s to the priority queue to be treated next
    # - We continue with the next element of the queue
    while pq:
        # Get the highest-score match from the queue
        neg_score, node, path = heapq.heappop(pq)
        score = -neg_score  # Convert back to positive

        # Store the result if it is above the threshold (except for the start node)
        if node != semantic_id and score >= min_score:
            results.append(MatchResult(
                base_semantic_id=semantic_id,
                match_semantic_id=node,
                score=score,
                path=path
            ))

        # Traverse to the neighboring and therefore connected `semantic_id`s
        for neighbor, edge_data in graph[node].items():
            new_score: float = score * edge_data["weight"]  # Multiplicative propagation

            # Prevent loops by ensuring we do not revisit the start node after the first iteration
            if neighbor == semantic_id:
                continue  # Avoid re-exploring the start node

            # We add the newly found `semantic_id`s to the queue, to be explored next in order of their score
            if new_score >= min_score:
                heapq.heappush(pq, (-new_score, neighbor, path + [node]))  # Push updated path

    return results


if __name__ == "__main__":
    # Create graph
    G = SemanticMatchGraph()
    G.add_edge("A", "B", weight=0.8)
    G.add_edge("B", "C", weight=0.7)
    G.add_edge("C", "D", weight=0.9)
    G.add_edge("B", "D", weight=0.6)

    # Find matches for "A"
    matches: List[MatchResult] = find_semantic_matches(G, "A", min_score=0)

    # Print results
    for match in matches:
        print(match)
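
For reference, running the module directly walks every path from "A" in the demo graph; the output is (scores are printed as raw floats, so binary rounding noise will appear, e.g. 0.56 prints as 0.5599999999999999):

    A -> B = 0.8
    A -> B -> C = 0.56
    A -> B -> C -> D = 0.504
    A -> B -> D = 0.48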

semantic_matcher/model.py

Lines changed: 155 additions & 4 deletions

@@ -1,20 +1,171 @@
-from typing import Dict, List
+import json
+import copy
+from typing import Dict, List, Set, Optional, Iterable
 
 from pydantic import BaseModel
 
 
+# Todo: Adapt to use the new algorithm
+
+
 class SemanticMatch(BaseModel):
     """
     A semantic match, mapping two semanticIDs with a matching score. Can be imagined as a weighted graph with
     `base_semantic_id` ---`score`---> `match_semantic_id`
 
-    Todo: Think about static and TTL, but that is optimization
-    Todo: Maybe we want to have the matching method as debug information
+    :cvar base_semantic_id:
+    :cvar match_semantic_id:
+    :cvar score: The semantic similarity score, a float between 0 and 1
+    :cvar path: Optionally, if the `SemanticMatch` did not come from a source but was inferred from other
+        `SemanticMatch`es, `path` stores the `SemanticMatch`es it was inferred from
+    :cvar meta_information: Optional meta information, such as the source of the `SemanticMatch`
     """
     base_semantic_id: str
     match_semantic_id: str
     score: float
-    meta_information: Dict
+    path: Optional[List["SemanticMatch"]] = None
+    meta_information: Optional[Dict] = None
+
+    def __hash__(self):
+        # `path` is a list and `meta_information` may be `None`, so both are converted into
+        # hashable representations first:
+        return hash((
+            self.base_semantic_id,
+            self.match_semantic_id,
+            self.score,
+            tuple(self.path) if self.path is not None else None,
+            frozenset(self.meta_information.items()) if self.meta_information is not None else None,
+        ))
+
+    @classmethod
+    def combine_semantic_matches(cls, first: "SemanticMatch", second: "SemanticMatch") -> "SemanticMatch":
+        """
+        Construct a new `SemanticMatch` by combining two `SemanticMatch`es.
+
+        Given the following situation:
+            A --0.4--> B
+            B --0.5--> C
+        this constructs a new `SemanticMatch`:
+            A --(0.4*0.5)--> C
+        while updating the `path` information of the new `SemanticMatch`.
+
+        :param first: First `SemanticMatch`
+        :param second: Second `SemanticMatch`. Note that `second.base_semantic_id` needs to be the same
+            as `first.match_semantic_id`
+        :return: The combined `SemanticMatch`
+        """
+        if not first.match_semantic_id == second.base_semantic_id:
+            raise KeyError(f"Cannot combine. `first.match_semantic_id` ({first.match_semantic_id}) does not "
+                           f"fit `second.base_semantic_id` ({second.base_semantic_id}).")
+        if second.path:
+            new_path = copy.copy(second.path)
+            new_path.insert(0, second)
+        else:
+            new_path = [second]
+        return SemanticMatch(
+            base_semantic_id=first.base_semantic_id,
+            match_semantic_id=second.match_semantic_id,
+            score=first.score*second.score,
+            path=new_path,
+        )
+
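A worked example of the combination rule described in the docstring above (the IDs and scores are illustrative):

    a_b = SemanticMatch(base_semantic_id="A", match_semantic_id="B", score=0.4)
    b_c = SemanticMatch(base_semantic_id="B", match_semantic_id="C", score=0.5)
    a_c = SemanticMatch.combine_semantic_matches(a_b, b_c)
    assert a_c.base_semantic_id == "A" and a_c.match_semantic_id == "C"
    assert a_c.score == 0.4 * 0.5
    assert a_c.path == [b_c]  # A->C was inferred via the match B->C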
+
+class SemanticMatchDictStore:
+    """
+    A collection of `SemanticMatch`es, stored in a Dict, where the key is the `base_semantic_id` and the value is
+    the set of `SemanticMatch`es with that `base_semantic_id`. This allows for efficient resolution of the
+    `SemanticMatch`es of a `base_semantic_id`.
+    """
+    def __init__(self, matches: Iterable[SemanticMatch]):
+        self._store: Dict[str, Set[SemanticMatch]] = {}
+        for x in matches:
+            self.add(x)
+
+    def add(self, match: SemanticMatch) -> None:
+        """
+        Add a `SemanticMatch` to the store
+        """
+        if match.base_semantic_id in self._store:
+            self._store[match.base_semantic_id].add(match)
+        else:
+            self._store[match.base_semantic_id] = {match}
+
+    def discard(self, match: SemanticMatch) -> None:
+        """
+        Discard a `SemanticMatch` from the store
+        """
+        # First, we remove the `SemanticMatch` from the set of matches for that `base_semantic_id`
+        self._store[match.base_semantic_id].discard(match)
+        # Then, if there are no `SemanticMatch`es left for that `base_semantic_id`, we remove the Dict entry
+        if not len(self._store[match.base_semantic_id]):
+            self._store.pop(match.base_semantic_id)
+
+    def get_all_matches(self) -> Set[SemanticMatch]:
+        """
+        Return a set of all `SemanticMatch`es currently inside the store
+        """
+        all_matches: Set[SemanticMatch] = set()
+        for i in self._store.values():
+            all_matches.update(i)
+        return all_matches
+
+    def get_matches(self, semantic_id: str, min_score: Optional[float] = None) -> Set[SemanticMatch]:
+        """
+        Return all `SemanticMatch`es of a given `semantic_id` currently inside the store that have a score
+        higher than or equal to `min_score`.
+        This is a recursive function that also queries the matches of the matches, as long as the multiplicative
+        score of the matches is still higher than or equal to `min_score`.
+        """
+        matches: Set[SemanticMatch] = set()  # This is our return Set
+
+        # First, we check on the current level
+        current_matches_with_any_score = self._store.get(semantic_id, set())
+        current_matches = {
+            match for match in current_matches_with_any_score if min_score is None or match.score >= min_score
+        }
+        # We can already update our return Set, since we know that the `current_matches` will definitely be inside
+        matches.update(current_matches)
+
+        # Now we do the same query for each of the current_matches that have a score higher than or equal
+        # to min_score
+        # Todo: We currently have a loop in here that we need to break
+        for match in current_matches:
+            # We calculate the new minimal score.
+            # The unified score is multiplicative: score(A->B) * score(B->C)
+            # This score should be higher than or equal to the requested min_score:
+            #     score(A->B) * score(B->C) >= min_score
+            # score(A->B) is well known, as it is the `match.score`
+            #     => score(B->C) >= (min_score/score(A->B))
+            if min_score is not None:
+                new_min_score = min_score/match.score
+            else:
+                new_min_score = min_score
+            # Here's the recursive function call; we do the same thing again with the matches of the matched
+            # `semantic_id` and the updated `min_score`:
+            new_matches = self.get_matches(semantic_id=match.match_semantic_id, min_score=new_min_score)
+            # These new matches are not relative to the original `base_semantic_id`, so we need to create new
+            # `SemanticMatch`es, which also store the path.
+            for new_match in new_matches:
+                matches.add(SemanticMatch.combine_semantic_matches(
+                    first=match,
+                    second=new_match
+                ))
+
+        # In the end, we return our return Set
+        return matches
+
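To make the threshold propagation concrete, a small sketch continuing the example above (IDs and scores illustrative; it assumes the recursion follows `match.match_semantic_id` as written here):

    store = SemanticMatchDictStore([a_b, b_c])  # A --0.4--> B, B --0.5--> C
    # B->C is followed because 0.5 >= 0.1/0.4, so the inferred A->C (score 0.2) is included:
    assert {m.match_semantic_id for m in store.get_matches("A", min_score=0.1)} == {"B", "C"}
    # With min_score=0.3, a second hop would need a score of at least 0.3/0.4 (~0.75), so only A->B remains:
    assert {m.match_semantic_id for m in store.get_matches("A", min_score=0.3)} == {"B"}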
+    def to_file(self, filename: str) -> None:
+        matches: List[Dict] = [match.model_dump() for match in self.get_all_matches()]
+        with open(filename, "w") as file:
+            json.dump(matches, file, indent=4)
+
+    @classmethod
+    def from_file(cls, filename: str) -> "SemanticMatchDictStore":
+        with open(filename, "r") as file:
+            matches_data = json.load(file)
+        matches = [SemanticMatch(**match_dict) for match_dict in matches_data]
+        return cls(matches)
+
+    def __len__(self) -> int:
+        length = 0
+        for i in self._store.values():
+            length += len(i)
+        return length
 
 
 class EquivalenceTable(BaseModel):
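
Since `to_file` dumps the matches via `model_dump` and `from_file` re-validates them through the `SemanticMatch` constructor, a JSON round trip should preserve the store. A sketch continuing the example above (the filename is illustrative):

    store.to_file("matches.json")
    restored = SemanticMatchDictStore.from_file("matches.json")
    assert len(restored) == len(store)
    assert restored.get_all_matches() == store.get_all_matches()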
