@@ -99,10 +99,15 @@ def from_disk(self, path: Union[Path, str] = "model.pkl") -> "CRFExtractor":
99
99
Raises:
100
100
IOError
101
101
"""
102
- ent_tagger = joblib .load (path )
103
- assert isinstance (ent_tagger , sklearn_crfsuite .estimator .CRF )
102
+ state_dict = joblib .load (path )
103
+
104
+ if isinstance (state_dict , sklearn_crfsuite .estimator .CRF ):
105
+ self .ent_tagger = state_dict
106
+
107
+ elif isinstance (state_dict , dict ):
108
+ for attr , attr_value in state_dict .items ():
109
+ setattr (self , attr , attr_value )
104
110
105
- self .ent_tagger = ent_tagger
106
111
return self
107
112
108
113
def to_disk (self , path : Union [Path , str ] = "model.pkl" ) -> None :
@@ -115,7 +120,12 @@ def to_disk(self, path: Union[Path, str] = "model.pkl") -> None:
115
120
RuntimeError - if entity tagger is not fitted for runtime.
116
121
"""
117
122
self ._check_runtime ()
118
- joblib .dump (self .ent_tagger , path )
123
+
124
+ state_dict = {
125
+ "component_config" : self .component_config ,
126
+ "ent_tagger" : self .ent_tagger ,
127
+ }
128
+ joblib .dump (state_dict , path )
119
129
120
130
def use_dense_features (self ) -> bool :
121
131
"""A predicate to test if dense features should be used, according
0 commit comments