Skip to content

Commit fd626d4

Browse files
gustavocidornelaswhoseoyster
authored andcommitted
Completes UNB-2161 - Adding explainability scores fails when running pytorch example
1 parent 5b74443 commit fd626d4

File tree

1 file changed

+45
-45
lines changed

1 file changed

+45
-45
lines changed

examples/text-classification/pytorch/pytorch.ipynb

Lines changed: 45 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
},
4848
{
4949
"cell_type": "code",
50-
"execution_count": 2,
50+
"execution_count": 1,
5151
"id": "fb8d43e9",
5252
"metadata": {},
5353
"outputs": [],
@@ -67,7 +67,7 @@
6767
},
6868
{
6969
"cell_type": "code",
70-
"execution_count": 3,
70+
"execution_count": 2,
7171
"id": "2a5d7425",
7272
"metadata": {},
7373
"outputs": [
@@ -86,7 +86,7 @@
8686
" 'Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.')]"
8787
]
8888
},
89-
"execution_count": 3,
89+
"execution_count": 2,
9090
"metadata": {},
9191
"output_type": "execute_result"
9292
}
@@ -106,7 +106,7 @@
106106
},
107107
{
108108
"cell_type": "code",
109-
"execution_count": 4,
109+
"execution_count": 3,
110110
"id": "4ba18c72",
111111
"metadata": {},
112112
"outputs": [],
@@ -123,7 +123,7 @@
123123
},
124124
{
125125
"cell_type": "code",
126-
"execution_count": 5,
126+
"execution_count": 4,
127127
"id": "7d03c0fe",
128128
"metadata": {},
129129
"outputs": [],
@@ -134,7 +134,7 @@
134134
},
135135
{
136136
"cell_type": "code",
137-
"execution_count": 6,
137+
"execution_count": 5,
138138
"id": "2a43dfd9",
139139
"metadata": {},
140140
"outputs": [],
@@ -156,7 +156,7 @@
156156
},
157157
{
158158
"cell_type": "code",
159-
"execution_count": 7,
159+
"execution_count": 6,
160160
"id": "225c9a01",
161161
"metadata": {},
162162
"outputs": [],
@@ -182,7 +182,7 @@
182182
},
183183
{
184184
"cell_type": "code",
185-
"execution_count": 8,
185+
"execution_count": 7,
186186
"id": "da8b2413",
187187
"metadata": {},
188188
"outputs": [],
@@ -225,7 +225,7 @@
225225
},
226226
{
227227
"cell_type": "code",
228-
"execution_count": 9,
228+
"execution_count": 8,
229229
"id": "53def141",
230230
"metadata": {},
231231
"outputs": [],
@@ -248,7 +248,7 @@
248248
},
249249
{
250250
"cell_type": "code",
251-
"execution_count": 10,
251+
"execution_count": 9,
252252
"id": "6ff581cd",
253253
"metadata": {},
254254
"outputs": [
@@ -257,16 +257,16 @@
257257
"output_type": "stream",
258258
"text": [
259259
"| epoch 1 | 500/ 1782 batches | accuracy 0.684\n",
260-
"| epoch 1 | 1000/ 1782 batches | accuracy 0.854\n",
261-
"| epoch 1 | 1500/ 1782 batches | accuracy 0.876\n",
260+
"| epoch 1 | 1000/ 1782 batches | accuracy 0.855\n",
261+
"| epoch 1 | 1500/ 1782 batches | accuracy 0.877\n",
262262
"-----------------------------------------------------------\n",
263-
"| end of epoch 1 | time: 12.70s | valid accuracy 0.885 \n",
263+
"| end of epoch 1 | time: 14.62s | valid accuracy 0.884 \n",
264264
"-----------------------------------------------------------\n",
265265
"| epoch 2 | 500/ 1782 batches | accuracy 0.900\n",
266-
"| epoch 2 | 1000/ 1782 batches | accuracy 0.898\n",
267-
"| epoch 2 | 1500/ 1782 batches | accuracy 0.901\n",
266+
"| epoch 2 | 1000/ 1782 batches | accuracy 0.896\n",
267+
"| epoch 2 | 1500/ 1782 batches | accuracy 0.904\n",
268268
"-----------------------------------------------------------\n",
269-
"| end of epoch 2 | time: 13.83s | valid accuracy 0.901 \n",
269+
"| end of epoch 2 | time: 14.14s | valid accuracy 0.876 \n",
270270
"-----------------------------------------------------------\n"
271271
]
272272
}
@@ -317,7 +317,7 @@
317317
},
318318
{
319319
"cell_type": "code",
320-
"execution_count": 11,
320+
"execution_count": 10,
321321
"id": "3a668a76",
322322
"metadata": {},
323323
"outputs": [
@@ -326,7 +326,7 @@
326326
"output_type": "stream",
327327
"text": [
328328
"Checking the results of test dataset.\n",
329-
"test accuracy 0.896\n"
329+
"test accuracy 0.876\n"
330330
]
331331
}
332332
],
@@ -366,7 +366,7 @@
366366
},
367367
{
368368
"cell_type": "code",
369-
"execution_count": 12,
369+
"execution_count": 11,
370370
"id": "compressed-occupation",
371371
"metadata": {},
372372
"outputs": [],
@@ -386,24 +386,24 @@
386386
},
387387
{
388388
"cell_type": "code",
389-
"execution_count": 13,
389+
"execution_count": 12,
390390
"id": "19408128",
391391
"metadata": {},
392392
"outputs": [
393393
{
394394
"name": "stdout",
395395
"output_type": "stream",
396396
"text": [
397-
"Created your project. Navigate to http://localhost:8000/projects/3 to see it in the UI.\n"
397+
"Found your project. Navigate to http://localhost:8000/projects/30 to see it.\n"
398398
]
399399
}
400400
],
401401
"source": [
402402
"from unboxapi.tasks import TaskType\n",
403403
"\n",
404-
"project = client.create_project(name=\"Text classification with PyTorch\",\n",
405-
" task_type=TaskType.TextClassification,\n",
406-
" description=\"Evaluating NN for text classification\")"
404+
"project = client.create_or_load_project(name=\"Text classification with PyTorch\",\n",
405+
" task_type=TaskType.TextClassification,\n",
406+
" description=\"Evaluating NN for text classification\")"
407407
]
408408
},
409409
{
@@ -461,16 +461,16 @@
461461
},
462462
{
463463
"cell_type": "code",
464-
"execution_count": 16,
464+
"execution_count": 13,
465465
"id": "supposed-survey",
466466
"metadata": {},
467467
"outputs": [],
468468
"source": [
469-
"def predict_proba(model, texts, tokenizer, vocab):\n",
469+
"def predict_proba(model, texts, tokenizer_fn, vocab):\n",
470470
" with torch.no_grad():\n",
471471
" texts = [\n",
472472
" torch.tensor(\n",
473-
" [vocab[token] for token in tokenizer(text)]\n",
473+
" [vocab[token] for token in tokenizer_fn(text)]\n",
474474
" ) \n",
475475
" for text in texts]\n",
476476
" text_list = torch.tensor(torch.cat(texts)).long()\n",
@@ -495,7 +495,7 @@
495495
},
496496
{
497497
"cell_type": "code",
498-
"execution_count": 17,
498+
"execution_count": 14,
499499
"id": "north-valuation",
500500
"metadata": {},
501501
"outputs": [],
@@ -520,34 +520,34 @@
520520
},
521521
{
522522
"cell_type": "code",
523-
"execution_count": 18,
523+
"execution_count": 15,
524524
"id": "comprehensive-jenny",
525525
"metadata": {},
526526
"outputs": [
527527
{
528528
"name": "stderr",
529529
"output_type": "stream",
530530
"text": [
531-
"/var/folders/9z/j3bd32nd47j_l0thnbj6vbnw0000gn/T/ipykernel_21571/785500925.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
531+
"/var/folders/9z/j3bd32nd47j_l0thnbj6vbnw0000gn/T/ipykernel_22576/710996952.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
532532
" text_list = torch.tensor(torch.cat(texts)).long()\n",
533-
"/var/folders/9z/j3bd32nd47j_l0thnbj6vbnw0000gn/T/ipykernel_21571/785500925.py:17: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
533+
"/var/folders/9z/j3bd32nd47j_l0thnbj6vbnw0000gn/T/ipykernel_22576/710996952.py:17: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n",
534534
" return sm(output).numpy().tolist()\n"
535535
]
536536
},
537537
{
538538
"data": {
539539
"text/plain": [
540-
"[[0.004791636019945145,\n",
541-
" 0.9912257790565491,\n",
542-
" 0.0018143205670639873,\n",
543-
" 0.0021683182567358017],\n",
544-
" [0.009553060866892338,\n",
545-
" 0.9899933934211731,\n",
546-
" 6.70066146994941e-05,\n",
547-
" 0.0003865564940497279]]"
540+
"[[0.012467482127249241,\n",
541+
" 0.9524526596069336,\n",
542+
" 0.0024990958627313375,\n",
543+
" 0.03258078917860985],\n",
544+
" [0.024693824350833893,\n",
545+
" 0.9746410846710205,\n",
546+
" 1.4036187167221215e-05,\n",
547+
" 0.0006511638639494777]]"
548548
]
549549
},
550-
"execution_count": 18,
550+
"execution_count": 15,
551551
"metadata": {},
552552
"output_type": "execute_result"
553553
}
@@ -566,23 +566,23 @@
566566
},
567567
{
568568
"cell_type": "code",
569-
"execution_count": 19,
569+
"execution_count": 17,
570570
"id": "f0b3eb3f",
571571
"metadata": {},
572572
"outputs": [
573573
{
574574
"name": "stdout",
575575
"output_type": "stream",
576576
"text": [
577-
"[2022-08-08 10:56:08,400] WARNING - BentoML by default does not include spacy and torchvision package when using PytorchModelArtifact. To make sure BentoML bundle those packages if they are required for your model, either import those packages in BentoService definition file or manually add them via `@env(pip_packages=['torchvision'])` when defining a BentoService\n",
577+
"[2022-08-24 10:54:58,678] WARNING - BentoML by default does not include spacy and torchvision package when using PytorchModelArtifact. To make sure BentoML bundle those packages if they are required for your model, either import those packages in BentoService definition file or manually add them via `@env(pip_packages=['torchvision'])` when defining a BentoService\n",
578578
"Bundling model and artifacts...\n"
579579
]
580580
},
581581
{
582582
"name": "stderr",
583583
"output_type": "stream",
584584
"text": [
585-
"/Users/gustavocid/miniconda3/envs/pytorch-notebook/lib/python3.8/site-packages/bentoml/frameworks/pytorch.py:162: ResourceWarning: unclosed file <_io.BufferedWriter name='/private/var/folders/9z/j3bd32nd47j_l0thnbj6vbnw0000gn/T/bentoml-temp-kjhae170/TemplateModel/artifacts/model.pt'>\n",
585+
"/Users/gustavocid/miniconda3/envs/unbox-examples/lib/python3.8/site-packages/bentoml/frameworks/pytorch.py:162: ResourceWarning: unclosed file <_io.BufferedWriter name='/private/var/folders/9z/j3bd32nd47j_l0thnbj6vbnw0000gn/T/bentoml-temp-3ct5zn56/TemplateModel/artifacts/model.pt'>\n",
586586
" return cloudpickle.dump(self._model, open(self._file_path(dst), \"wb\"))\n",
587587
"ResourceWarning: Enable tracemalloc to get the object allocation traceback\n"
588588
]
@@ -598,15 +598,15 @@
598598
"source": [
599599
"from unboxapi.models import ModelType\n",
600600
"\n",
601-
"model = project.add_model(\n",
601+
"ml_model = project.add_model(\n",
602602
" function=predict_proba, \n",
603603
" model=model,\n",
604604
" model_type=ModelType.pytorch,\n",
605605
" class_names=['world', 'sports', 'business', 'sci/tec'],\n",
606606
" name='pytorch 4',\n",
607607
" commit_message='this is my pytorch model',\n",
608608
" requirements_txt_file='requirements.txt',\n",
609-
" tokenizer=tokenizer,\n",
609+
" tokenizer_fn=tokenizer,\n",
610610
" vocab=vocab,\n",
611611
")"
612612
]

0 commit comments

Comments
 (0)