Skip to content

Commit 3e5ec5d

Browse files
authored
Merge pull request #21 from CogStack/MetaCAT_upd
Meta cat upd
2 parents f214549 + 1a5d846 commit 3e5ec5d

File tree

6 files changed

+298
-192
lines changed

6 files changed

+298
-192
lines changed

.github/workflows/main.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@ on:
99
jobs:
1010
native-py:
1111

12-
runs-on: ubuntu-20.04
12+
runs-on: ubuntu-24.04
1313
strategy:
1414
matrix:
15-
python-version: [ '3.8', '3.9', '3.10', '3.11' ]
15+
python-version: [ '3.9', '3.10', '3.11', '3.12' ]
1616
max-parallel: 4
1717

1818
steps:

medcat/2_train_model/2_supervised_training/meta_annotation_training.ipynb

Lines changed: 216 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
"from medcat.cat import CAT\n",
1414
"from medcat.meta_cat import MetaCAT\n",
1515
"from medcat.config_meta_cat import ConfigMetaCAT\n",
16-
"from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBPE, TokenizerWrapperBERT\n",
17-
"from tokenizers import ByteLevelBPETokenizer"
16+
"from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT"
1817
]
1918
},
2019
{
@@ -31,82 +30,234 @@
3130
},
3231
{
3332
"cell_type": "markdown",
34-
"id": "5d0606ec",
33+
"id": "f310cef3",
3534
"metadata": {},
3635
"source": [
37-
"# Set parameters"
36+
"### Load the model pack with MetaCATs\n"
3837
]
3938
},
4039
{
4140
"cell_type": "code",
42-
"execution_count": 3,
41+
"execution_count": null,
4342
"id": "dd7a2e97",
4443
"metadata": {},
4544
"outputs": [],
4645
"source": [
47-
"# relative path to working_with_cogstack folder\n",
48-
"_rel_path = os.path.join(\"..\", \"..\", \"..\")\n",
49-
"# absolute path to working_with_cogstack folder\n",
50-
"base_path = os.path.abspath(_rel_path)\n",
51-
"# Load mct export\n",
52-
"ann_dir = os.path.join(base_path, \"data\", \"medcattrainer_export\")\n",
53-
"\n",
54-
"mctrainer_export_path = ann_dir + \"\" # name of your mct export\n",
55-
"\n",
46+
"model_pack = '<enter path to the model pack>' # .zip model pack location \n",
47+
"mctrainer_export = \"<enter mct export location>\" # name of your mct export"
48+
]
49+
},
50+
{
51+
"cell_type": "code",
52+
"execution_count": 4,
53+
"id": "921d5e9e",
54+
"metadata": {},
55+
"outputs": [],
56+
"source": [
5657
"# Load model\n",
57-
"model_dir = os.path.join(base_path, \"models\", \"modelpack\")\n",
58-
"modelpack = '' # name of modelpack\n",
59-
"model_pack_path = os.path.join(model_dir, modelpack)\n",
60-
" #output_modelpack = model_dir + f\"{today}_trained_model\"\n",
58+
"cat = CAT.load_model_pack(model_pack)"
59+
]
60+
},
61+
{
62+
"cell_type": "code",
63+
"execution_count": 5,
64+
"id": "b205d51b",
65+
"metadata": {},
66+
"outputs": [
67+
{
68+
"name": "stdout",
69+
"output_type": "stream",
70+
"text": [
71+
"There are: 3 meta cat models in this model pack.\n"
72+
]
73+
}
74+
],
75+
"source": [
6176
"\n",
62-
"# will be used to date the trained model\n",
63-
"today = str(date.today())\n",
64-
"today = today.replace(\"-\",\"\")\n",
77+
"# Check what meta cat models are in this model pack.\n",
78+
"print(f'There are: {len(cat._meta_cats)} meta cat models in this model pack.')"
79+
]
80+
},
81+
{
82+
"cell_type": "code",
83+
"execution_count": 6,
84+
"id": "31d7632a",
85+
"metadata": {},
86+
"outputs": [
87+
{
88+
"name": "stdout",
89+
"output_type": "stream",
90+
"text": [
91+
"{\n",
92+
" \"Category Name\": \"Temporality\",\n",
93+
" \"Description\": \"No description\",\n",
94+
" \"Classes\": {\n",
95+
" \"Past\": 0,\n",
96+
" \"Recent\": 1,\n",
97+
" \"Future\": 2\n",
98+
" },\n",
99+
" \"Model\": \"bert\"\n",
100+
"}\n"
101+
]
102+
}
103+
],
104+
"source": [
105+
"print(cat._meta_cats[0])"
106+
]
107+
},
108+
{
109+
"cell_type": "code",
110+
"execution_count": 7,
111+
"id": "e9180c4c",
112+
"metadata": {},
113+
"outputs": [
114+
{
115+
"name": "stdout",
116+
"output_type": "stream",
117+
"text": [
118+
"{\n",
119+
" \"Category Name\": \"Presence\",\n",
120+
" \"Description\": \"No description\",\n",
121+
" \"Classes\": {\n",
122+
" \"Hypothetical (N/A)\": 1,\n",
123+
" \"Not present (False)\": 0,\n",
124+
" \"Present (True)\": 2\n",
125+
" },\n",
126+
" \"Model\": \"bert\"\n",
127+
"}\n"
128+
]
129+
}
130+
],
131+
"source": [
132+
"print(cat._meta_cats[1])"
133+
]
134+
},
135+
{
136+
"cell_type": "code",
137+
"execution_count": 8,
138+
"id": "275ca9ff",
139+
"metadata": {},
140+
"outputs": [
141+
{
142+
"name": "stdout",
143+
"output_type": "stream",
144+
"text": [
145+
"{\n",
146+
" \"Category Name\": \"Experiencer\",\n",
147+
" \"Description\": \"No description\",\n",
148+
" \"Classes\": {\n",
149+
" \"Family\": 1,\n",
150+
" \"Other\": 0,\n",
151+
" \"Patient\": 2\n",
152+
" },\n",
153+
" \"Model\": \"bert\"\n",
154+
"}\n"
155+
]
156+
}
157+
],
158+
"source": [
159+
"print(cat._meta_cats[2])"
160+
]
161+
},
162+
{
163+
"cell_type": "markdown",
164+
"id": "3047b1d9",
165+
"metadata": {},
166+
"source": [
167+
"<b> NOTE: </b> \n",
168+
" The name for the classification task can vary. E.g: The Category Name for 'Experiencer' can be 'Subject', as it has been configured an annoated in MedCATTrainer this way, but the model expects 'Experiencer'\n",
169+
" \n",
170+
" To accomodate for this, we have a list that stores the variations for the alternate names. This attribute can be found under `mc.config.general.alternative_category_names`\n",
65171
"\n",
66-
"# Initialise meta_ann models\n",
67-
"if model_pack_path[-4:] == '.zip':\n",
68-
" base_dir_meta_models = model_pack_path[:-4]\n",
69-
"else:\n",
70-
" base_dir_meta_models = model_pack_path\n",
172+
"E.g. for Experiencer, it will be pre-loaded as alternative_category_names = ['Experiencer','Subject']\n",
71173
"\n",
72-
"# Iterate through the meta_models contained in the model\n",
73-
"meta_model_names = [] # These Meta_annotation tasks should correspond to the ones labelled in the mcttrainer export\n",
74-
"for dirpath, dirnames, filenames in os.walk(base_dir_meta_models):\n",
75-
" for dirname in dirnames:\n",
76-
" if dirname.startswith('meta_'):\n",
77-
" meta_model_names.append(dirname[5:])"
174+
"Set this list to ensure during training / fine-tuning the model is aware of alternative names for classes."
175+
]
176+
},
177+
{
178+
"cell_type": "code",
179+
"execution_count": null,
180+
"id": "1ca00fb0",
181+
"metadata": {},
182+
"outputs": [],
183+
"source": [
184+
"print(cat._meta_cats[0].config.general.alternative_category_names)"
78185
]
79186
},
80187
{
81188
"cell_type": "markdown",
82-
"id": "35aa5605",
189+
"id": "5dba296c",
83190
"metadata": {},
84191
"source": [
85-
"Before you run the next section please double check that the model meta_annotation names matches to those specified in the mct export.\n",
86-
"\n"
192+
"💡 In case you are using older modelpacks, the above field will be empty. In that case, "
193+
]
194+
},
195+
{
196+
"cell_type": "code",
197+
"execution_count": null,
198+
"id": "92e41964",
199+
"metadata": {},
200+
"outputs": [],
201+
"source": [
202+
"# Only run in case the above output is an empty list\n",
203+
"category_name_mapping = [[\"Presence\"],[\"Temporality\",\"Time\"],[\"Experiencer\",\"Subject\"]]\n",
204+
"lookup = {item: group for group in category_name_mapping for item in group}\n",
205+
"\n",
206+
"for meta_model in range(len(cat._meta_cats)):\n",
207+
" cat._meta_cats[meta_model].config.general.alternative_category_names = lookup.get(cat._meta_cats[meta_model].config.general.category_name)"
87208
]
88209
},
89210
{
90211
"cell_type": "markdown",
91-
"id": "8bf6f5c3",
212+
"id": "12e91f77",
92213
"metadata": {},
93214
"source": [
94-
"Depending on the model pack you have, please run the LSTM model or BERT model section. <br>\n",
95-
"If you are unsure, use this section to check the model type."
215+
"<b> NOTE: </b> \n",
216+
" The name for the classes can vary too. Some sites may have trained a MetaCAT model for the same task, but called a class value a slightly different name.\n",
217+
" \n",
218+
" E.g: For the Presence task, the class name can be 'Not present (False)' or 'False'\n",
219+
" \n",
220+
" To accomodate for this, we have a mapping that stores the variations for the alternate names. This attribute can be found under `mc.config.general.alternative_class_names`\n",
221+
"\n",
222+
" E.g. for Presence, it will be pre-loaded as alternative_class_names = [[\"Hypothetical (N/A)\",\"Hypothetical\"],[\"Not present (False)\",\"False\"],[\"Present (True)\",\"True\"]]"
223+
]
224+
},
225+
{
226+
"cell_type": "code",
227+
"execution_count": null,
228+
"id": "5f6b06e2",
229+
"metadata": {},
230+
"outputs": [],
231+
"source": [
232+
"print(cat._meta_cats[0].config.general.alternative_class_names)"
233+
]
234+
},
235+
{
236+
"cell_type": "markdown",
237+
"id": "3c97c986",
238+
"metadata": {},
239+
"source": [
240+
"💡 In case you are using older modelpacks, the above field will be empty. In that case, please run the following code:"
96241
]
97242
},
98243
{
99244
"cell_type": "code",
100245
"execution_count": null,
101-
"id": "2933f7e1",
246+
"id": "0fdfae70",
102247
"metadata": {},
103248
"outputs": [],
104249
"source": [
105-
"for meta_model in meta_model_names:\n",
106-
" config_file = os.path.join(base_dir_meta_models,\"meta_\"+meta_model,\"config.json\")\n",
107-
" with open(config_file, 'r') as jfile:\n",
108-
" config_dict = json.load(jfile)\n",
109-
" print(f\"Model used for meta_{meta_model}:\",config_dict['model']['model_name'])"
250+
"# Only run in case the above output is an empty list\n",
251+
"class_name_mapping = {\n",
252+
" \"Temporality\": [[\"Past\"], [\"Recent\", \"Present\"], [\"Future\"]],\n",
253+
" \"Time\": [[\"Past\"], [\"Recent\", \"Present\"], [\"Future\"]],\n",
254+
" \"Experiencer\": [[\"Family\"], [\"Other\"], [\"Patient\"]],\n",
255+
" \"Subject\": [[\"Family\"], [\"Other\"], [\"Patient\"]],\n",
256+
" \"Presence\": [[\"Hypothetical (N/A)\", \"Hypothetical\"], [\"Not present (False)\", \"False\"], [\"Present (True)\", \"True\"]]\n",
257+
"}\n",
258+
"\n",
259+
"for meta_model in range(len(cat._meta_cats)):\n",
260+
" cat._meta_cats[meta_model].config.general.alternative_class_names = class_name_mapping[cat._meta_cats[meta_model].config.general.category_name]"
110261
]
111262
},
112263
{
@@ -124,30 +275,31 @@
124275
"metadata": {},
125276
"outputs": [],
126277
"source": [
127-
"for meta_model in meta_model_names:\n",
128-
" \n",
129-
" # load the meta_model\n",
130-
" mc = MetaCAT.load(save_dir_path=os.path.join(base_dir_meta_models,\"meta_\"+meta_model))\n",
278+
"# Train the first meta cat model - 'Temporality' Task.\n",
279+
"meta_cat = cat._meta_cats[0]\n",
131280
"\n",
132-
" # changing parameters\n",
133-
" mc.config.train['nepochs'] = 15\n",
281+
"# to overwrite the existing model, resave the fine-tuned model with the same model pack dir\n",
282+
"meta_cat_task = meta_cat.config.general.category_name\n",
283+
"model_pack_dir = '<enter path to meta model pack>'\n",
284+
"save_dir_path = os.path.join(model_pack_dir,\"meta_\"+ meta_cat_task)\n",
134285
"\n",
135-
" save_dir_path= \"test_meta_\"+meta_model # Where to save the meta_model and results. \n",
136-
" #Ideally this should replace the meta_models inside the modelpack\n",
286+
"# to save the new model elsewhere, uncomment the below line\n",
287+
"#save_dir_path= \"test_meta_\"+meta_cat_task # Where to save the meta_model and results. \n",
137288
"\n",
138-
" # train the meta_model\n",
139-
" results = mc.train_from_json(mctrainer_export_path, save_dir_path=save_dir_path)\n",
140-
" \n",
141-
" # Save results\n",
142-
" json.dump(results['report'], open(os.path.join(save_dir_path,'meta_'+meta_model+'_results.json'), 'w'))"
289+
"# train the meta_model\n",
290+
"results = meta_cat.train_from_json(mctrainer_export, save_dir_path=save_dir_path)\n",
291+
"\n",
292+
"# Save results\n",
293+
"json.dump(results['report'], open(os.path.join(save_dir_path,'meta_'+meta_cat_task+'_results.json'), 'w'))"
143294
]
144295
},
145296
{
146297
"cell_type": "markdown",
147298
"id": "ab23e424",
148299
"metadata": {},
149300
"source": [
150-
"## If you dont have the model packs, and are training from scratch"
301+
"## If you dont have the model packs, and are training from scratch\n",
302+
"<b>⚠️This is very rare, it is recommended to always use the model packs and then fine-tune them</b>"
151303
]
152304
},
153305
{
@@ -167,23 +319,22 @@
167319
"\n",
168320
"tokenizer = TokenizerWrapperBERT.load(\"\", config.model['model_variant'])\n",
169321
"\n",
170-
"save_dir_path= \"test_meta\" # Where to save the meta_model and results. \n",
171-
"#Ideally this should replace the meta_models inside the modelpack\n",
322+
"save_dir_path= \"test_meta_\" + meta_cat_task # Where to save the meta_model and results. \n",
172323
"\n",
173324
"# Initialise and train meta_model\n",
174325
"mc = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)\n",
175-
"results = mc.train_from_json(mctrainer_export_path, save_dir_path=save_dir_path)\n",
326+
"results = mc.train_from_json(mctrainer_export, save_dir_path=save_dir_path)\n",
176327
"\n",
177328
"# Save results\n",
178-
"json.dump(results['report'], open(os.path.join(save_dir_path,'meta_'+meta_model+'_results.json'), 'w'))"
329+
"json.dump(results['report'], open(os.path.join(save_dir_path,'meta_' + meta_cat_task+'_results.json'), 'w'))"
179330
]
180331
}
181332
],
182333
"metadata": {
183334
"kernelspec": {
184-
"display_name": "Python 3",
335+
"display_name": "Python [conda env:cattrainer]",
185336
"language": "python",
186-
"name": "python3"
337+
"name": "conda-env-cattrainer-py"
187338
},
188339
"language_info": {
189340
"codemirror_mode": {
@@ -195,7 +346,7 @@
195346
"name": "python",
196347
"nbconvert_exporter": "python",
197348
"pygments_lexer": "ipython3",
198-
"version": "3.8.8"
349+
"version": "3.11.11"
199350
}
200351
},
201352
"nbformat": 4,

0 commit comments

Comments
 (0)