Skip to content

Commit 1d0d0e6

Browse files
authored
change .to_disk() to also persist component config (#22)
* fix issue with persistence to disk * bump version -> 1.5.0
1 parent 14ff9b0 commit 1d0d0e6

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

spacy_crfsuite/about.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
__title__ = "spacy_crfsuite"
2-
__version__ = "1.4.0"
2+
__version__ = "1.5.0"
33
__summary__ = "spaCy pipeline component for CRF entity extraction"
44
__author__ = "Tal Almagor"
55
__email__ = "[email protected]"

spacy_crfsuite/crf_extractor.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,10 +99,15 @@ def from_disk(self, path: Union[Path, str] = "model.pkl") -> "CRFExtractor":
9999
Raises:
100100
IOError
101101
"""
102-
ent_tagger = joblib.load(path)
103-
assert isinstance(ent_tagger, sklearn_crfsuite.estimator.CRF)
102+
state_dict = joblib.load(path)
103+
104+
if isinstance(state_dict, sklearn_crfsuite.estimator.CRF):
105+
self.ent_tagger = state_dict
106+
107+
elif isinstance(state_dict, dict):
108+
for attr, attr_value in state_dict.items():
109+
setattr(self, attr, attr_value)
104110

105-
self.ent_tagger = ent_tagger
106111
return self
107112

108113
def to_disk(self, path: Union[Path, str] = "model.pkl") -> None:
@@ -115,7 +120,12 @@ def to_disk(self, path: Union[Path, str] = "model.pkl") -> None:
115120
RuntimeError - if entity tagger is not fitted for runtime.
116121
"""
117122
self._check_runtime()
118-
joblib.dump(self.ent_tagger, path)
123+
124+
state_dict = {
125+
"component_config": self.component_config,
126+
"ent_tagger": self.ent_tagger,
127+
}
128+
joblib.dump(state_dict, path)
119129

120130
def use_dense_features(self) -> bool:
121131
"""A predicate to test if dense features should be used, according

0 commit comments

Comments
 (0)