Skip to content

Commit 2a01c99

Browse files
authored
fix compat issue with sklearn (#19)
see ``https://github.com/TeamHG-Memex/sklearn-crfsuite/pull/69/files`` for more details.
1 parent 536523f commit 2a01c99

File tree

2 files changed

+211
-2
lines changed

2 files changed

+211
-2
lines changed

spacy_crfsuite/compat.py

Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
from sklearn_crfsuite import CRF as _CRF
2+
3+
4+
class CRF(_CRF):
    """
    python-crfsuite wrapper with interface similar to scikit-learn.
    It allows using a familiar fit/predict interface and scikit-learn
    model selection utilities (cross-validation, hyperparameter optimization).

    Unlike pycrfsuite.Trainer / pycrfsuite.Tagger this object is picklable;
    on-disk files are managed automatically.

    This subclass works around a compatibility issue with scikit-learn
    (see https://github.com/TeamHG-Memex/sklearn-crfsuite/pull/69): it adds
    the read-only ``model_filename`` and ``keep_tempfiles`` properties, both
    backed by ``self.modelfile``.

    Parameters
    ----------
    algorithm : str, optional (default='lbfgs')
        Training algorithm. Allowed values:

        * ``'lbfgs'`` - Gradient descent using the L-BFGS method
        * ``'l2sgd'`` - Stochastic Gradient Descent with L2 regularization term
        * ``'ap'`` - Averaged Perceptron
        * ``'pa'`` - Passive Aggressive (PA)
        * ``'arow'`` - Adaptive Regularization Of Weight Vector (AROW)

    min_freq : float, optional (default=0)
        Cut-off threshold for occurrence
        frequency of a feature. CRFsuite will ignore features whose
        frequencies of occurrences in the training data are no greater
        than `min_freq`. The default is no cut-off.

    all_possible_states : bool, optional (default=False)
        Specify whether CRFsuite generates state features that do not even
        occur in the training data (i.e., negative state features).
        When True, CRFsuite generates state features that associate all of
        possible combinations between attributes and labels.

        Suppose that the numbers of attributes and labels are A and L
        respectively, this function will generate (A * L) features.
        Enabling this function may improve the labeling accuracy because
        the CRF model can learn the condition where an item is not predicted
        to its reference label. However, this function may also increase
        the number of features and slow down the training process
        drastically. This function is disabled by default.

    all_possible_transitions : bool, optional (default=False)
        Specify whether CRFsuite generates transition features that
        do not even occur in the training data (i.e., negative transition
        features). When True, CRFsuite generates transition features that
        associate all of possible label pairs. Suppose that the number
        of labels in the training data is L, this function will
        generate (L * L) transition features.
        This function is disabled by default.

    c1 : float, optional (default=0)
        The coefficient for L1 regularization.
        If a non-zero value is specified, CRFsuite switches to the
        Orthant-Wise Limited-memory Quasi-Newton (OWL-QN) method.
        The default value is zero (no L1 regularization).

        Supported training algorithms: lbfgs

    c2 : float, optional (default=1.0)
        The coefficient for L2 regularization.

        Supported training algorithms: l2sgd, lbfgs

    max_iterations : int, optional (default=None)
        The maximum number of iterations for optimization algorithms.
        Default value depends on training algorithm:

        * lbfgs - unlimited;
        * l2sgd - 1000;
        * ap - 100;
        * pa - 100;
        * arow - 100.

    num_memories : int, optional (default=6)
        The number of limited memories for approximating the inverse hessian
        matrix.

        Supported training algorithms: lbfgs

    epsilon : float, optional (default=1e-5)
        The epsilon parameter that determines the condition of convergence.

        Supported training algorithms: ap, arow, lbfgs, pa

    period : int, optional (default=10)
        The duration of iterations to test the stopping criterion.

        Supported training algorithms: l2sgd, lbfgs

    delta : float, optional (default=1e-5)
        The threshold for the stopping criterion; an iteration stops
        when the improvement of the log likelihood over the last
        `period` iterations is no greater than this threshold.

        Supported training algorithms: l2sgd, lbfgs

    linesearch : str, optional (default='MoreThuente')
        The line search algorithm used in L-BFGS updates. Allowed values:

        * ``'MoreThuente'`` - More and Thuente's method;
        * ``'Backtracking'`` - backtracking method with regular Wolfe condition;
        * ``'StrongBacktracking'`` - backtracking method with strong Wolfe
          condition.

        Supported training algorithms: lbfgs

    max_linesearch : int, optional (default=20)
        The maximum number of trials for the line search algorithm.

        Supported training algorithms: lbfgs

    calibration_eta : float, optional (default=0.1)
        The initial value of learning rate (eta) used for calibration.

        Supported training algorithms: l2sgd

    calibration_rate : float, optional (default=2.0)
        The rate of increase/decrease of learning rate for calibration.

        Supported training algorithms: l2sgd

    calibration_samples : int, optional (default=1000)
        The number of instances used for calibration.
        The calibration routine randomly chooses instances no larger
        than `calibration_samples`.

        Supported training algorithms: l2sgd

    calibration_candidates : int, optional (default=10)
        The number of candidates of learning rate.
        The calibration routine terminates after finding
        `calibration_samples` candidates of learning rates
        that can increase log-likelihood.

        Supported training algorithms: l2sgd

    calibration_max_trials : int, optional (default=20)
        The maximum number of trials of learning rates for calibration.
        The calibration routine terminates after trying
        `calibration_max_trials` candidate values of learning rates.

        Supported training algorithms: l2sgd

    pa_type : int, optional (default=1)
        The strategy for updating feature weights. Allowed values:

        * 0 - PA without slack variables;
        * 1 - PA type I;
        * 2 - PA type II.

        Supported training algorithms: pa

    c : float, optional (default=1)
        Aggressiveness parameter (used only for PA-I and PA-II).
        This parameter controls the influence of the slack term on the
        objective function.

        Supported training algorithms: pa

    error_sensitive : bool, optional (default=True)
        If this parameter is True, the optimization routine includes
        into the objective function the square root of the number of
        incorrect labels predicted by the model.

        Supported training algorithms: pa

    averaging : bool, optional (default=True)
        If this parameter is True, the optimization routine computes
        the average of feature weights at all updates in the training
        process (similarly to Averaged Perceptron).

        Supported training algorithms: pa

    variance : float, optional (default=1)
        The initial variance of every feature weight.
        The algorithm initializes a vector of feature weights as
        a multivariate Gaussian distribution with mean 0
        and variance `variance`.

        Supported training algorithms: arow

    gamma : float, optional (default=1)
        The tradeoff between loss function and changes of feature weights.

        Supported training algorithms: arow

    verbose : bool, optional (default=False)
        Enable trainer verbose mode.

    model_filename : str, optional (default=None)
        A path to an existing CRFSuite model.
        This parameter allows loading and using existing crfsuite models.

        By default, model files are created automatically and saved
        in temporary locations; the preferred way to save/load CRF models
        is to use pickle (or its alternatives like joblib).

    """

    @property
    def model_filename(self):
        """Return ``self.modelfile.name``.

        If ``self.modelfile`` is falsy (e.g. ``None``), that falsy value is
        returned unchanged instead of raising ``AttributeError``.
        """
        return self.modelfile and self.modelfile.name

    @property
    def keep_tempfiles(self):
        """Return ``self.modelfile.keep_tempfiles``.

        If ``self.modelfile`` is falsy (e.g. ``None``), that falsy value is
        returned unchanged instead of raising ``AttributeError``.
        """
        return self.modelfile and self.modelfile.keep_tempfiles

spacy_crfsuite/crf_extractor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
11
import itertools
22
import joblib
33
import warnings
4+
import sklearn_crfsuite
45

56
from collections import Counter
67
from pathlib import Path
78
from typing import Dict, Text, Any, Optional, List, Tuple, Union, Callable
89

9-
from sklearn_crfsuite import CRF, metrics
1010
from spacy.language import Language
1111
from spacy.tokens.doc import Doc
1212
from sklearn.metrics import classification_report, f1_score
1313

1414
from spacy_crfsuite.bilou import entity_name_from_tag, bilou_prefix_from_tag, NO_ENTITY_TAG
15+
from spacy_crfsuite.compat import CRF
1516
from spacy_crfsuite.features import CRFToken, Featurizer
1617
from spacy_crfsuite.tokenizer import Token, SpacyTokenizer
1718
from spacy_crfsuite.utils import override_defaults
@@ -99,7 +100,7 @@ def from_disk(self, path: Union[Path, str] = "model.pkl") -> "CRFExtractor":
99100
IOError
100101
"""
101102
ent_tagger = joblib.load(path)
102-
assert isinstance(ent_tagger, CRF)
103+
assert isinstance(ent_tagger, sklearn_crfsuite.estimator.CRF)
103104

104105
self.ent_tagger = ent_tagger
105106
return self

0 commit comments

Comments
 (0)