21
21
from .string_distance import StringDistance
22
22
23
23
24
- class CharacterInsDelInterface :
24
+ def default_insertion_cost (char ):
25
+ return 1.0
25
26
26
- def deletion_cost (self , c ):
27
- raise NotImplementedError ()
28
27
29
- def insertion_cost ( self , c ):
30
- raise NotImplementedError ()
28
+ def default_deletion_cost ( char ):
29
+ return 1.0
31
30
32
31
33
- class CharacterSubstitutionInterface :
34
-
35
- def cost (self , c0 , c1 ):
36
- raise NotImplementedError ()
32
+ def default_substitution_cost (char_a , char_b ):
33
+ return 1.0
37
34
38
35
39
36
class WeightedLevenshtein (StringDistance ):
40
37
41
- def __init__ (self , character_substitution , character_ins_del = None ):
42
- self .character_ins_del = character_ins_del
43
- if character_substitution is None :
44
- raise TypeError ("Argument character_substitution is NoneType." )
45
- self .character_substitution = character_substitution
38
+ def __init__ (self ,
39
+ substitution_cost_fn = default_substitution_cost ,
40
+ insertion_cost_fn = default_insertion_cost ,
41
+ deletion_cost_fn = default_deletion_cost ,
42
+ ):
43
+ self .substitution_cost_fn = substitution_cost_fn
44
+ self .insertion_cost_fn = insertion_cost_fn
45
+ self .deletion_cost_fn = deletion_cost_fn
46
46
47
47
def distance (self , s0 , s1 ):
48
48
if s0 is None :
@@ -60,30 +60,20 @@ def distance(self, s0, s1):
60
60
61
61
v0 [0 ] = 0
62
62
for i in range (1 , len (v0 )):
63
- v0 [i ] = v0 [i - 1 ] + self ._insertion_cost (s1 [i - 1 ])
63
+ v0 [i ] = v0 [i - 1 ] + self .insertion_cost_fn (s1 [i - 1 ])
64
64
65
65
for i in range (len (s0 )):
66
66
s1i = s0 [i ]
67
- deletion_cost = self ._deletion_cost (s1i )
67
+ deletion_cost = self .deletion_cost_fn (s1i )
68
68
v1 [0 ] = v0 [0 ] + deletion_cost
69
69
70
70
for j in range (len (s1 )):
71
71
s2j = s1 [j ]
72
72
cost = 0
73
73
if s1i != s2j :
74
- cost = self .character_substitution . cost (s1i , s2j )
75
- insertion_cost = self ._insertion_cost (s2j )
74
+ cost = self .substitution_cost_fn (s1i , s2j )
75
+ insertion_cost = self .insertion_cost_fn (s2j )
76
76
v1 [j + 1 ] = min (v1 [j ] + insertion_cost , v0 [j + 1 ] + deletion_cost , v0 [j ] + cost )
77
77
v0 , v1 = v1 , v0
78
78
79
79
return v0 [len (s1 )]
80
-
81
- def _insertion_cost (self , c ):
82
- if self .character_ins_del is None :
83
- return 1.0
84
- return self .character_ins_del .insertion_cost (c )
85
-
86
- def _deletion_cost (self , c ):
87
- if self .character_ins_del is None :
88
- return 1.0
89
- return self .character_ins_del .deletion_cost (c )
0 commit comments