Skip to content

Commit f37e391

Browse files
committed
Simplify weight calculation
1 parent 030bc9a commit f37e391

File tree

2 files changed: 20 additions and 36 deletions.

2 files changed: 20 additions and 36 deletions.

strsimpy/weighted_levenshtein.py

Lines changed: 18 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -21,28 +21,28 @@
2121
from .string_distance import StringDistance
2222

2323

24-
class CharacterInsDelInterface:
24+
def default_insertion_cost(char):
25+
return 1.0
2526

26-
def deletion_cost(self, c):
27-
raise NotImplementedError()
2827

29-
def insertion_cost(self, c):
30-
raise NotImplementedError()
28+
def default_deletion_cost(char):
29+
return 1.0
3130

3231

33-
class CharacterSubstitutionInterface:
34-
35-
def cost(self, c0, c1):
36-
raise NotImplementedError()
32+
def default_substitution_cost(char_a, char_b):
33+
return 1.0
3734

3835

3936
class WeightedLevenshtein(StringDistance):
4037

41-
def __init__(self, character_substitution, character_ins_del=None):
42-
self.character_ins_del = character_ins_del
43-
if character_substitution is None:
44-
raise TypeError("Argument character_substitution is NoneType.")
45-
self.character_substitution = character_substitution
38+
def __init__(self,
39+
substitution_cost_fn=default_substitution_cost,
40+
insertion_cost_fn=default_insertion_cost,
41+
deletion_cost_fn=default_deletion_cost,
42+
):
43+
self.substitution_cost_fn = substitution_cost_fn
44+
self.insertion_cost_fn = insertion_cost_fn
45+
self.deletion_cost_fn = deletion_cost_fn
4646

4747
def distance(self, s0, s1):
4848
if s0 is None:
@@ -60,30 +60,20 @@ def distance(self, s0, s1):
6060

6161
v0[0] = 0
6262
for i in range(1, len(v0)):
63-
v0[i] = v0[i - 1] + self._insertion_cost(s1[i - 1])
63+
v0[i] = v0[i - 1] + self.insertion_cost_fn(s1[i - 1])
6464

6565
for i in range(len(s0)):
6666
s1i = s0[i]
67-
deletion_cost = self._deletion_cost(s1i)
67+
deletion_cost = self.deletion_cost_fn(s1i)
6868
v1[0] = v0[0] + deletion_cost
6969

7070
for j in range(len(s1)):
7171
s2j = s1[j]
7272
cost = 0
7373
if s1i != s2j:
74-
cost = self.character_substitution.cost(s1i, s2j)
75-
insertion_cost = self._insertion_cost(s2j)
74+
cost = self.substitution_cost_fn(s1i, s2j)
75+
insertion_cost = self.insertion_cost_fn(s2j)
7676
v1[j + 1] = min(v1[j] + insertion_cost, v0[j + 1] + deletion_cost, v0[j] + cost)
7777
v0, v1 = v1, v0
7878

7979
return v0[len(s1)]
80-
81-
def _insertion_cost(self, c):
82-
if self.character_ins_del is None:
83-
return 1.0
84-
return self.character_ins_del.insertion_cost(c)
85-
86-
def _deletion_cost(self, c):
87-
if self.character_ins_del is None:
88-
return 1.0
89-
return self.character_ins_del.deletion_cost(c)

strsimpy/weighted_levenshtein_test.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -20,19 +20,13 @@
2020

2121
import unittest
2222

23-
from .weighted_levenshtein import WeightedLevenshtein, CharacterSubstitutionInterface
24-
25-
26-
class CharSub(CharacterSubstitutionInterface):
27-
28-
def cost(self, c0, c1):
29-
return 1.0
23+
from .weighted_levenshtein import WeightedLevenshtein
3024

3125

3226
class TestWeightedLevenshtein(unittest.TestCase):
3327

3428
def test_weighted_levenshtein(self):
35-
a = WeightedLevenshtein(character_substitution=CharSub())
29+
a = WeightedLevenshtein()
3630
s0 = ""
3731
s1 = ""
3832
s2 = "上海"

0 commit comments

Comments
 (0)