|
13 | 13 | "from medcat.cat import CAT\n",
|
14 | 14 | "from medcat.meta_cat import MetaCAT\n",
|
15 | 15 | "from medcat.config_meta_cat import ConfigMetaCAT\n",
|
16 |
| - "from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBPE, TokenizerWrapperBERT\n", |
17 |
| - "from tokenizers import ByteLevelBPETokenizer" |
| 16 | + "from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT" |
18 | 17 | ]
|
19 | 18 | },
|
20 | 19 | {
|
|
31 | 30 | },
|
32 | 31 | {
|
33 | 32 | "cell_type": "markdown",
|
34 |
| - "id": "5d0606ec", |
| 33 | + "id": "f310cef3", |
35 | 34 | "metadata": {},
|
36 | 35 | "source": [
|
37 |
| - "# Set parameters" |
| 36 | + "### Load the model pack with MetaCATs\n" |
38 | 37 | ]
|
39 | 38 | },
|
40 | 39 | {
|
41 | 40 | "cell_type": "code",
|
42 |
| - "execution_count": 3, |
| 41 | + "execution_count": null, |
43 | 42 | "id": "dd7a2e97",
|
44 | 43 | "metadata": {},
|
45 | 44 | "outputs": [],
|
46 | 45 | "source": [
|
47 |
| - "# relative path to working_with_cogstack folder\n", |
48 |
| - "_rel_path = os.path.join(\"..\", \"..\", \"..\")\n", |
49 |
| - "# absolute path to working_with_cogstack folder\n", |
50 |
| - "base_path = os.path.abspath(_rel_path)\n", |
51 |
| - "# Load mct export\n", |
52 |
| - "ann_dir = os.path.join(base_path, \"data\", \"medcattrainer_export\")\n", |
53 |
| - "\n", |
54 |
| - "mctrainer_export_path = ann_dir + \"\" # name of your mct export\n", |
55 |
| - "\n", |
| 46 | + "model_pack = '<enter path to the model pack>' # .zip model pack location \n", |
| 47 | + "mctrainer_export = \"<enter mct export location>\" # name of your mct export" |
| 48 | + ] |
| 49 | + }, |
| 50 | + { |
| 51 | + "cell_type": "code", |
| 52 | + "execution_count": 4, |
| 53 | + "id": "921d5e9e", |
| 54 | + "metadata": {}, |
| 55 | + "outputs": [], |
| 56 | + "source": [ |
56 | 57 | "# Load model\n",
|
57 |
| - "model_dir = os.path.join(base_path, \"models\", \"modelpack\")\n", |
58 |
| - "modelpack = '' # name of modelpack\n", |
59 |
| - "model_pack_path = os.path.join(model_dir, modelpack)\n", |
60 |
| - " #output_modelpack = model_dir + f\"{today}_trained_model\"\n", |
| 58 | + "cat = CAT.load_model_pack(model_pack)" |
| 59 | + ] |
| 60 | + }, |
| 61 | + { |
| 62 | + "cell_type": "code", |
| 63 | + "execution_count": 5, |
| 64 | + "id": "b205d51b", |
| 65 | + "metadata": {}, |
| 66 | + "outputs": [ |
| 67 | + { |
| 68 | + "name": "stdout", |
| 69 | + "output_type": "stream", |
| 70 | + "text": [ |
| 71 | + "There are: 3 meta cat models in this model pack.\n" |
| 72 | + ] |
| 73 | + } |
| 74 | + ], |
| 75 | + "source": [ |
61 | 76 | "\n",
|
62 |
| - "# will be used to date the trained model\n", |
63 |
| - "today = str(date.today())\n", |
64 |
| - "today = today.replace(\"-\",\"\")\n", |
| 77 | + "# Check what meta cat models are in this model pack.\n", |
| 78 | + "print(f'There are: {len(cat._meta_cats)} meta cat models in this model pack.')" |
| 79 | + ] |
| 80 | + }, |
| 81 | + { |
| 82 | + "cell_type": "code", |
| 83 | + "execution_count": 6, |
| 84 | + "id": "31d7632a", |
| 85 | + "metadata": {}, |
| 86 | + "outputs": [ |
| 87 | + { |
| 88 | + "name": "stdout", |
| 89 | + "output_type": "stream", |
| 90 | + "text": [ |
| 91 | + "{\n", |
| 92 | + " \"Category Name\": \"Temporality\",\n", |
| 93 | + " \"Description\": \"No description\",\n", |
| 94 | + " \"Classes\": {\n", |
| 95 | + " \"Past\": 0,\n", |
| 96 | + " \"Recent\": 1,\n", |
| 97 | + " \"Future\": 2\n", |
| 98 | + " },\n", |
| 99 | + " \"Model\": \"bert\"\n", |
| 100 | + "}\n" |
| 101 | + ] |
| 102 | + } |
| 103 | + ], |
| 104 | + "source": [ |
| 105 | + "print(cat._meta_cats[0])" |
| 106 | + ] |
| 107 | + }, |
| 108 | + { |
| 109 | + "cell_type": "code", |
| 110 | + "execution_count": 7, |
| 111 | + "id": "e9180c4c", |
| 112 | + "metadata": {}, |
| 113 | + "outputs": [ |
| 114 | + { |
| 115 | + "name": "stdout", |
| 116 | + "output_type": "stream", |
| 117 | + "text": [ |
| 118 | + "{\n", |
| 119 | + " \"Category Name\": \"Presence\",\n", |
| 120 | + " \"Description\": \"No description\",\n", |
| 121 | + " \"Classes\": {\n", |
| 122 | + " \"Hypothetical (N/A)\": 1,\n", |
| 123 | + " \"Not present (False)\": 0,\n", |
| 124 | + " \"Present (True)\": 2\n", |
| 125 | + " },\n", |
| 126 | + " \"Model\": \"bert\"\n", |
| 127 | + "}\n" |
| 128 | + ] |
| 129 | + } |
| 130 | + ], |
| 131 | + "source": [ |
| 132 | + "print(cat._meta_cats[1])" |
| 133 | + ] |
| 134 | + }, |
| 135 | + { |
| 136 | + "cell_type": "code", |
| 137 | + "execution_count": 8, |
| 138 | + "id": "275ca9ff", |
| 139 | + "metadata": {}, |
| 140 | + "outputs": [ |
| 141 | + { |
| 142 | + "name": "stdout", |
| 143 | + "output_type": "stream", |
| 144 | + "text": [ |
| 145 | + "{\n", |
| 146 | + " \"Category Name\": \"Experiencer\",\n", |
| 147 | + " \"Description\": \"No description\",\n", |
| 148 | + " \"Classes\": {\n", |
| 149 | + " \"Family\": 1,\n", |
| 150 | + " \"Other\": 0,\n", |
| 151 | + " \"Patient\": 2\n", |
| 152 | + " },\n", |
| 153 | + " \"Model\": \"bert\"\n", |
| 154 | + "}\n" |
| 155 | + ] |
| 156 | + } |
| 157 | + ], |
| 158 | + "source": [ |
| 159 | + "print(cat._meta_cats[2])" |
| 160 | + ] |
| 161 | + }, |
| 162 | + { |
| 163 | + "cell_type": "markdown", |
| 164 | + "id": "3047b1d9", |
| 165 | + "metadata": {}, |
| 166 | + "source": [ |
| 167 | + "<b> NOTE: </b> \n", |
| 168 | + " The name for the classification task can vary. E.g: The Category Name for 'Experiencer' can be 'Subject', as it has been configured an annoated in MedCATTrainer this way, but the model expects 'Experiencer'\n", |
| 169 | + " \n", |
| 170 | + " To accomodate for this, we have a list that stores the variations for the alternate names. This attribute can be found under `mc.config.general.alternative_category_names`\n", |
65 | 171 | "\n",
|
66 |
| - "# Initialise meta_ann models\n", |
67 |
| - "if model_pack_path[-4:] == '.zip':\n", |
68 |
| - " base_dir_meta_models = model_pack_path[:-4]\n", |
69 |
| - "else:\n", |
70 |
| - " base_dir_meta_models = model_pack_path\n", |
| 172 | + "E.g. for Experiencer, it will be pre-loaded as alternative_category_names = ['Experiencer','Subject']\n", |
71 | 173 | "\n",
|
72 |
| - "# Iterate through the meta_models contained in the model\n", |
73 |
| - "meta_model_names = [] # These Meta_annotation tasks should correspond to the ones labelled in the mcttrainer export\n", |
74 |
| - "for dirpath, dirnames, filenames in os.walk(base_dir_meta_models):\n", |
75 |
| - " for dirname in dirnames:\n", |
76 |
| - " if dirname.startswith('meta_'):\n", |
77 |
| - " meta_model_names.append(dirname[5:])" |
| 174 | + "Set this list to ensure during training / fine-tuning the model is aware of alternative names for classes." |
| 175 | + ] |
| 176 | + }, |
| 177 | + { |
| 178 | + "cell_type": "code", |
| 179 | + "execution_count": null, |
| 180 | + "id": "1ca00fb0", |
| 181 | + "metadata": {}, |
| 182 | + "outputs": [], |
| 183 | + "source": [ |
| 184 | + "print(cat._meta_cats[0].config.general.alternative_category_names)" |
78 | 185 | ]
|
79 | 186 | },
|
80 | 187 | {
|
81 | 188 | "cell_type": "markdown",
|
82 |
| - "id": "35aa5605", |
| 189 | + "id": "5dba296c", |
83 | 190 | "metadata": {},
|
84 | 191 | "source": [
|
85 |
| - "Before you run the next section please double check that the model meta_annotation names matches to those specified in the mct export.\n", |
86 |
| - "\n" |
| 192 | + "💡 In case you are using older modelpacks, the above field will be empty. In that case, " |
| 193 | + ] |
| 194 | + }, |
| 195 | + { |
| 196 | + "cell_type": "code", |
| 197 | + "execution_count": null, |
| 198 | + "id": "92e41964", |
| 199 | + "metadata": {}, |
| 200 | + "outputs": [], |
| 201 | + "source": [ |
| 202 | + "# Only run in case the above output is an empty list\n", |
| 203 | + "category_name_mapping = [[\"Presence\"],[\"Temporality\",\"Time\"],[\"Experiencer\",\"Subject\"]]\n", |
| 204 | + "lookup = {item: group for group in category_name_mapping for item in group}\n", |
| 205 | + "\n", |
| 206 | + "for meta_model in range(len(cat._meta_cats)):\n", |
| 207 | + " cat._meta_cats[meta_model].config.general.alternative_category_names = lookup.get(cat._meta_cats[meta_model].config.general.category_name)" |
87 | 208 | ]
|
88 | 209 | },
|
89 | 210 | {
|
90 | 211 | "cell_type": "markdown",
|
91 |
| - "id": "8bf6f5c3", |
| 212 | + "id": "12e91f77", |
92 | 213 | "metadata": {},
|
93 | 214 | "source": [
|
94 |
| - "Depending on the model pack you have, please run the LSTM model or BERT model section. <br>\n", |
95 |
| - "If you are unsure, use this section to check the model type." |
| 215 | + "<b> NOTE: </b> \n", |
| 216 | + " The name for the classes can vary too. Some sites may have trained a MetaCAT model for the same task, but called a class value a slightly different name.\n", |
| 217 | + " \n", |
| 218 | + " E.g: For the Presence task, the class name can be 'Not present (False)' or 'False'\n", |
| 219 | + " \n", |
| 220 | + " To accomodate for this, we have a mapping that stores the variations for the alternate names. This attribute can be found under `mc.config.general.alternative_class_names`\n", |
| 221 | + "\n", |
| 222 | + " E.g. for Presence, it will be pre-loaded as alternative_class_names = [[\"Hypothetical (N/A)\",\"Hypothetical\"],[\"Not present (False)\",\"False\"],[\"Present (True)\",\"True\"]]" |
| 223 | + ] |
| 224 | + }, |
| 225 | + { |
| 226 | + "cell_type": "code", |
| 227 | + "execution_count": null, |
| 228 | + "id": "5f6b06e2", |
| 229 | + "metadata": {}, |
| 230 | + "outputs": [], |
| 231 | + "source": [ |
| 232 | + "print(cat._meta_cats[0].config.general.alternative_class_names)" |
| 233 | + ] |
| 234 | + }, |
| 235 | + { |
| 236 | + "cell_type": "markdown", |
| 237 | + "id": "3c97c986", |
| 238 | + "metadata": {}, |
| 239 | + "source": [ |
| 240 | + "💡 In case you are using older modelpacks, the above field will be empty. In that case, please run the following code:" |
96 | 241 | ]
|
97 | 242 | },
|
98 | 243 | {
|
99 | 244 | "cell_type": "code",
|
100 | 245 | "execution_count": null,
|
101 |
| - "id": "2933f7e1", |
| 246 | + "id": "0fdfae70", |
102 | 247 | "metadata": {},
|
103 | 248 | "outputs": [],
|
104 | 249 | "source": [
|
105 |
| - "for meta_model in meta_model_names:\n", |
106 |
| - " config_file = os.path.join(base_dir_meta_models,\"meta_\"+meta_model,\"config.json\")\n", |
107 |
| - " with open(config_file, 'r') as jfile:\n", |
108 |
| - " config_dict = json.load(jfile)\n", |
109 |
| - " print(f\"Model used for meta_{meta_model}:\",config_dict['model']['model_name'])" |
| 250 | + "# Only run in case the above output is an empty list\n", |
| 251 | + "class_name_mapping = {\n", |
| 252 | + " \"Temporality\": [[\"Past\"], [\"Recent\", \"Present\"], [\"Future\"]],\n", |
| 253 | + " \"Time\": [[\"Past\"], [\"Recent\", \"Present\"], [\"Future\"]],\n", |
| 254 | + " \"Experiencer\": [[\"Family\"], [\"Other\"], [\"Patient\"]],\n", |
| 255 | + " \"Subject\": [[\"Family\"], [\"Other\"], [\"Patient\"]],\n", |
| 256 | + " \"Presence\": [[\"Hypothetical (N/A)\", \"Hypothetical\"], [\"Not present (False)\", \"False\"], [\"Present (True)\", \"True\"]]\n", |
| 257 | + "}\n", |
| 258 | + "\n", |
| 259 | + "for meta_model in range(len(cat._meta_cats)):\n", |
| 260 | + " cat._meta_cats[meta_model].config.general.alternative_class_names = class_name_mapping[cat._meta_cats[meta_model].config.general.category_name]" |
110 | 261 | ]
|
111 | 262 | },
|
112 | 263 | {
|
|
124 | 275 | "metadata": {},
|
125 | 276 | "outputs": [],
|
126 | 277 | "source": [
|
127 |
| - "for meta_model in meta_model_names:\n", |
128 |
| - " \n", |
129 |
| - " # load the meta_model\n", |
130 |
| - " mc = MetaCAT.load(save_dir_path=os.path.join(base_dir_meta_models,\"meta_\"+meta_model))\n", |
| 278 | + "# Train the first meta cat model - 'Temporality' Task.\n", |
| 279 | + "meta_cat = cat._meta_cats[0]\n", |
131 | 280 | "\n",
|
132 |
| - " # changing parameters\n", |
133 |
| - " mc.config.train['nepochs'] = 15\n", |
| 281 | + "# to overwrite the existing model, resave the fine-tuned model with the same model pack dir\n", |
| 282 | + "meta_cat_task = meta_cat.config.general.category_name\n", |
| 283 | + "model_pack_dir = '<enter path to meta model pack>'\n", |
| 284 | + "save_dir_path = os.path.join(model_pack_dir,\"meta_\"+ meta_cat_task)\n", |
134 | 285 | "\n",
|
135 |
| - " save_dir_path= \"test_meta_\"+meta_model # Where to save the meta_model and results. \n", |
136 |
| - " #Ideally this should replace the meta_models inside the modelpack\n", |
| 286 | + "# to save the new model elsewhere, uncomment the below line\n", |
| 287 | + "#save_dir_path= \"test_meta_\"+meta_cat_task # Where to save the meta_model and results. \n", |
137 | 288 | "\n",
|
138 |
| - " # train the meta_model\n", |
139 |
| - " results = mc.train_from_json(mctrainer_export_path, save_dir_path=save_dir_path)\n", |
140 |
| - " \n", |
141 |
| - " # Save results\n", |
142 |
| - " json.dump(results['report'], open(os.path.join(save_dir_path,'meta_'+meta_model+'_results.json'), 'w'))" |
| 289 | + "# train the meta_model\n", |
| 290 | + "results = meta_cat.train_from_json(mctrainer_export, save_dir_path=save_dir_path)\n", |
| 291 | + "\n", |
| 292 | + "# Save results\n", |
| 293 | + "json.dump(results['report'], open(os.path.join(save_dir_path,'meta_'+meta_cat_task+'_results.json'), 'w'))" |
143 | 294 | ]
|
144 | 295 | },
|
145 | 296 | {
|
146 | 297 | "cell_type": "markdown",
|
147 | 298 | "id": "ab23e424",
|
148 | 299 | "metadata": {},
|
149 | 300 | "source": [
|
150 |
| - "## If you dont have the model packs, and are training from scratch" |
| 301 | + "## If you dont have the model packs, and are training from scratch\n", |
| 302 | + "<b>⚠️This is very rare, it is recommended to always use the model packs and then fine-tune them</b>" |
151 | 303 | ]
|
152 | 304 | },
|
153 | 305 | {
|
|
167 | 319 | "\n",
|
168 | 320 | "tokenizer = TokenizerWrapperBERT.load(\"\", config.model['model_variant'])\n",
|
169 | 321 | "\n",
|
170 |
| - "save_dir_path= \"test_meta\" # Where to save the meta_model and results. \n", |
171 |
| - "#Ideally this should replace the meta_models inside the modelpack\n", |
| 322 | + "save_dir_path= \"test_meta_\" + meta_cat_task # Where to save the meta_model and results. \n", |
172 | 323 | "\n",
|
173 | 324 | "# Initialise and train meta_model\n",
|
174 | 325 | "mc = MetaCAT(tokenizer=tokenizer, embeddings=None, config=config)\n",
|
175 |
| - "results = mc.train_from_json(mctrainer_export_path, save_dir_path=save_dir_path)\n", |
| 326 | + "results = mc.train_from_json(mctrainer_export, save_dir_path=save_dir_path)\n", |
176 | 327 | "\n",
|
177 | 328 | "# Save results\n",
|
178 |
| - "json.dump(results['report'], open(os.path.join(save_dir_path,'meta_'+meta_model+'_results.json'), 'w'))" |
| 329 | + "json.dump(results['report'], open(os.path.join(save_dir_path,'meta_' + meta_cat_task+'_results.json'), 'w'))" |
179 | 330 | ]
|
180 | 331 | }
|
181 | 332 | ],
|
182 | 333 | "metadata": {
|
183 | 334 | "kernelspec": {
|
184 |
| - "display_name": "Python 3", |
| 335 | + "display_name": "Python [conda env:cattrainer]", |
185 | 336 | "language": "python",
|
186 |
| - "name": "python3" |
| 337 | + "name": "conda-env-cattrainer-py" |
187 | 338 | },
|
188 | 339 | "language_info": {
|
189 | 340 | "codemirror_mode": {
|
|
195 | 346 | "name": "python",
|
196 | 347 | "nbconvert_exporter": "python",
|
197 | 348 | "pygments_lexer": "ipython3",
|
198 |
| - "version": "3.8.8" |
| 349 | + "version": "3.11.11" |
199 | 350 | }
|
200 | 351 | },
|
201 | 352 | "nbformat": 4,
|
|
0 commit comments