Skip to content

Commit a42ae74

Browse files
authored
Fix hyperopt tuning via fine_tune method. (#27)
1 parent c4538e9 commit a42ae74

File tree

6 files changed

+91
-22
lines changed

6 files changed

+91
-22
lines changed

poetry.lock

Lines changed: 10 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,7 @@ spacy = "^3.4.4"
3434
sklearn-crfsuite = "^0.3.6"
3535
joblib = "^1.2.0"
3636
scikit-learn = "^1.2.0"
37+
pytest = "^7.2.2"
3738

3839
[tool.poetry.group.dev.dependencies]
3940
autopep8 = "^2.0.1"

spacy_crfsuite/crf_extractor.py

Lines changed: 16 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,16 +2,22 @@
22
import joblib
33
import warnings
44
import sklearn_crfsuite
5+
import sklearn_crfsuite.metrics as _metrics
56

67
from collections import Counter
78
from pathlib import Path
89
from typing import Dict, Text, Any, Optional, List, Tuple, Union, Callable
910

1011
from spacy.language import Language
1112
from spacy.tokens.doc import Doc
12-
from sklearn.metrics import classification_report, f1_score
13-
14-
from spacy_crfsuite.bilou import entity_name_from_tag, bilou_prefix_from_tag, NO_ENTITY_TAG
13+
from sklearn.metrics import classification_report
14+
15+
from spacy_crfsuite.bilou import (
16+
entity_name_from_tag,
17+
bilou_prefix_from_tag,
18+
NO_ENTITY_TAG,
19+
BILOU_PREFIXES,
20+
)
1521
from spacy_crfsuite.compat import CRF
1622
from spacy_crfsuite.features import CRFToken, Featurizer
1723
from spacy_crfsuite.tokenizer import Token, SpacyTokenizer
@@ -254,8 +260,13 @@ def fine_tune(
254260

255261
X_train = [self._crf_tokens_to_features(sent) for sent in val_samples]
256262
y_train = [self._crf_tokens_to_tags(sent) for sent in val_samples]
257-
labels = list(set(itertools.chain.from_iterable(y_train)) - {NO_ENTITY_TAG})
258-
f1_scorer = make_scorer(f1_score, average="weighted", labels=labels)
263+
264+
labels = set(itertools.chain.from_iterable(y_train)) - {NO_ENTITY_TAG}
265+
labels = list(labels)
266+
267+
f1_scorer = make_scorer(
268+
_metrics.flat_f1_score, average="weighted", labels=labels, zero_division=1
269+
)
259270
rs = RandomizedSearchCV(
260271
crf,
261272
params_space,

spacy_crfsuite/test.py

Whitespace-only changes.

tests/test_example.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -40,10 +40,8 @@ def create_component(nlp, name):
4040
"George Walker Bush (born July 6, 1946) is an American politician and businessman "
4141
"who served as the 43rd president of the United States from 2001 to 2009.")
4242

43-
for ent in doc.ents:
44-
print(ent, "-", ent.label_)
45-
46-
# Output:
47-
# George Walker Bush - PER
48-
# American - MISC
49-
# United States - LOC
43+
assert [(ent.text, ent.label_) for ent in doc.ents] == [
44+
('George Walker Bush', 'PER'),
45+
('American', 'MISC'),
46+
('United States', 'LOC')
47+
]

tests/test_hyperopt.py

Lines changed: 59 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,59 @@
1+
import pytest
2+
3+
from spacy_crfsuite import CRFExtractor, read_file
4+
from spacy_crfsuite.tokenizer import SpacyTokenizer
5+
from spacy_crfsuite.train import gold_example_to_crf_tokens
6+
7+
8+
@pytest.fixture()
9+
def dev_examples(en_core_web_md):
10+
tokenizer = SpacyTokenizer(en_core_web_md)
11+
12+
dev_examples = [
13+
gold_example_to_crf_tokens(
14+
ex, tokenizer=tokenizer, use_dense_features=False, bilou=True
15+
) for ex in read_file("examples/restaurent_search.md")
16+
]
17+
18+
return dev_examples
19+
20+
21+
def test_hyperparam_optim(dev_examples):
22+
crf_extractor = CRFExtractor(component_config={
23+
"features": [
24+
[
25+
"low",
26+
"title",
27+
"upper",
28+
"pos",
29+
"pos2"
30+
],
31+
[
32+
"low",
33+
"bias",
34+
"prefix5",
35+
"prefix2",
36+
"suffix5",
37+
"suffix3",
38+
"suffix2",
39+
"upper",
40+
"title",
41+
"digit",
42+
"pos",
43+
"pos2"
44+
],
45+
[
46+
"low",
47+
"title",
48+
"upper",
49+
"pos",
50+
"pos2"
51+
],
52+
],
53+
"c1": 0.01,
54+
"c2": 0.22
55+
})
56+
57+
rs = crf_extractor.fine_tune(dev_examples, cv=5, n_iter=30, random_state=42)
58+
assert rs.best_params_ == {'c1': 0.029919384304340338, 'c2': 0.10056154322399698}
59+
assert rs.best_score_ == 0.39999999999999997

0 commit comments

Comments
 (0)