Skip to content

Commit 44f4615

Browse files
committed
update test and replace_aug
1 parent 25c7b9f commit 44f4615

File tree

4 files changed

+75
-3
lines changed

4 files changed

+75
-3
lines changed

TODO.txt

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
- 标注数据集 docker 镜像维护
2-
- 数据增强 - 实体替换
32
- 数据增强 - 随机替换非关键字符
43
- 数据增强 - 其他
54
- 时间解析

jionlp/textaug/replace_entity.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,11 @@ def _augment_one(self, text, entities):
9898
if self.random.random() < self.replace_ratio:
9999
# 将该实体从词典中随机选择一个做替换
100100
orig_entity = self.random.choice(entities)
101-
new_entity_text = self.random.choice(
102-
list(self.entities_dict[orig_entity['type']].keys()))
101+
102+
candidate_list = list(self.entities_dict[orig_entity['type']].keys())
103+
if len(candidate_list) == 0:
104+
continue
105+
new_entity_text = self.random.choice(candidate_list)
103106

104107
orig_len = len(orig_entity['text'])
105108
new_len = len(new_entity_text)

test/test_main.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
2+
import unittest
3+
4+
from test_text_aug import TestTextAug
5+
6+
7+
if __name__ == '__main__':
    # Imported here because it is only needed when this file is run as a script.
    import sys

    # Build the suite by hand so that only the explicitly listed test cases
    # are executed.
    suite = unittest.TestSuite()
    suite.addTests([TestTextAug('test_ReplaceEntity')])

    runner = unittest.TextTestRunner(verbosity=1)
    result = runner.run(suite)

    # Propagate the outcome through the exit status: without this the script
    # exits with 0 even when a test fails, which hides failures from shells
    # and CI jobs that only look at the return code.
    sys.exit(0 if result.wasSuccessful() else 1)
17+
18+
19+

test/test_text_aug.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
2+
3+
import unittest
4+
5+
import jionlp as jio
6+
7+
8+
class TestTextAug(unittest.TestCase):
    """Tests for the text data-augmentation tools."""

    def test_ReplaceEntity(self):
        """Compare the output of ``jio.ReplaceEntity`` with fixed expectations."""

        # Replacement vocabulary keyed by entity type. Each inner dict maps a
        # candidate entity string to an integer (presumably a count or weight;
        # confirm against ReplaceEntity).
        vocabulary = {
            "Person": {"马成宇": 1},
            "Company": {"百度": 4, "国力教育公司": 1},
            "Organization": {"延平区人民法院": 1}
        }

        # Input sequence-labelling sample: the raw text and its annotated
        # entities.
        source_text = '腾讯致力于解决冲突,阿里巴巴致力于玩。小马爱玩。'
        source_entities = [
            {'type': 'Company', 'text': '腾讯', 'offset': (0, 2)},
            {'type': 'Company', 'text': '阿里巴巴', 'offset': (10, 14)},
            {'type': 'Person', 'text': '小马', 'offset': (19, 21)}]

        augmenter = jio.ReplaceEntity(vocabulary)
        augmented_texts, augmented_entities = augmenter(
            source_text, source_entities)

        # Expected results. In the expected data the replaced entities carry
        # list offsets while the other entities carry tuple offsets;
        # assertEqual treats the two as different, so the literals below keep
        # that distinction exactly.
        expected_texts = [
            '腾讯致力于解决冲突,国力教育公司致力于玩。小马爱玩。',
            '百度致力于解决冲突,阿里巴巴致力于玩。小马爱玩。',
            '腾讯致力于解决冲突,阿里巴巴致力于玩。马成宇爱玩。']
        expected_entities = [
            [{'type': 'Company', 'text': '腾讯', 'offset': (0, 2)},
             {'text': '国力教育公司', 'type': 'Company', 'offset': [10, 16]},
             {'text': '小马', 'type': 'Person', 'offset': (21, 23)}],
            [{'text': '百度', 'type': 'Company', 'offset': [0, 2]},
             {'text': '阿里巴巴', 'type': 'Company', 'offset': (10, 14)},
             {'text': '小马', 'type': 'Person', 'offset': (19, 21)}],
            [{'type': 'Company', 'text': '腾讯', 'offset': (0, 2)},
             {'type': 'Company', 'text': '阿里巴巴', 'offset': (10, 14)},
             {'text': '马成宇', 'type': 'Person', 'offset': [19, 22]}]]

        self.assertEqual(augmented_texts, expected_texts)
        self.assertEqual(augmented_entities, expected_entities)
45+
46+
# def test_
47+
48+
49+
50+
51+

0 commit comments

Comments
 (0)