|
47 | 47 | },
|
48 | 48 | {
|
49 | 49 | "cell_type": "code",
|
50 |
| - "execution_count": 2, |
| 50 | + "execution_count": 1, |
51 | 51 | "id": "fb8d43e9",
|
52 | 52 | "metadata": {},
|
53 | 53 | "outputs": [],
|
|
67 | 67 | },
|
68 | 68 | {
|
69 | 69 | "cell_type": "code",
|
70 |
| - "execution_count": 3, |
| 70 | + "execution_count": 2, |
71 | 71 | "id": "2a5d7425",
|
72 | 72 | "metadata": {},
|
73 | 73 | "outputs": [
|
|
86 | 86 | " 'Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world oil prices, toppling records and straining wallets, present a new economic menace barely three months before the US presidential elections.')]"
|
87 | 87 | ]
|
88 | 88 | },
|
89 |
| - "execution_count": 3, |
| 89 | + "execution_count": 2, |
90 | 90 | "metadata": {},
|
91 | 91 | "output_type": "execute_result"
|
92 | 92 | }
|
|
106 | 106 | },
|
107 | 107 | {
|
108 | 108 | "cell_type": "code",
|
109 |
| - "execution_count": 4, |
| 109 | + "execution_count": 3, |
110 | 110 | "id": "4ba18c72",
|
111 | 111 | "metadata": {},
|
112 | 112 | "outputs": [],
|
|
123 | 123 | },
|
124 | 124 | {
|
125 | 125 | "cell_type": "code",
|
126 |
| - "execution_count": 5, |
| 126 | + "execution_count": 4, |
127 | 127 | "id": "7d03c0fe",
|
128 | 128 | "metadata": {},
|
129 | 129 | "outputs": [],
|
|
134 | 134 | },
|
135 | 135 | {
|
136 | 136 | "cell_type": "code",
|
137 |
| - "execution_count": 6, |
| 137 | + "execution_count": 5, |
138 | 138 | "id": "2a43dfd9",
|
139 | 139 | "metadata": {},
|
140 | 140 | "outputs": [],
|
|
156 | 156 | },
|
157 | 157 | {
|
158 | 158 | "cell_type": "code",
|
159 |
| - "execution_count": 7, |
| 159 | + "execution_count": 6, |
160 | 160 | "id": "225c9a01",
|
161 | 161 | "metadata": {},
|
162 | 162 | "outputs": [],
|
|
182 | 182 | },
|
183 | 183 | {
|
184 | 184 | "cell_type": "code",
|
185 |
| - "execution_count": 8, |
| 185 | + "execution_count": 7, |
186 | 186 | "id": "da8b2413",
|
187 | 187 | "metadata": {},
|
188 | 188 | "outputs": [],
|
|
225 | 225 | },
|
226 | 226 | {
|
227 | 227 | "cell_type": "code",
|
228 |
| - "execution_count": 9, |
| 228 | + "execution_count": 8, |
229 | 229 | "id": "53def141",
|
230 | 230 | "metadata": {},
|
231 | 231 | "outputs": [],
|
|
248 | 248 | },
|
249 | 249 | {
|
250 | 250 | "cell_type": "code",
|
251 |
| - "execution_count": 10, |
| 251 | + "execution_count": 9, |
252 | 252 | "id": "6ff581cd",
|
253 | 253 | "metadata": {},
|
254 | 254 | "outputs": [
|
|
257 | 257 | "output_type": "stream",
|
258 | 258 | "text": [
|
259 | 259 | "| epoch 1 | 500/ 1782 batches | accuracy 0.684\n",
|
260 |
| - "| epoch 1 | 1000/ 1782 batches | accuracy 0.854\n", |
261 |
| - "| epoch 1 | 1500/ 1782 batches | accuracy 0.876\n", |
| 260 | + "| epoch 1 | 1000/ 1782 batches | accuracy 0.855\n", |
| 261 | + "| epoch 1 | 1500/ 1782 batches | accuracy 0.877\n", |
262 | 262 | "-----------------------------------------------------------\n",
|
263 |
| - "| end of epoch 1 | time: 12.70s | valid accuracy 0.885 \n", |
| 263 | + "| end of epoch 1 | time: 14.62s | valid accuracy 0.884 \n", |
264 | 264 | "-----------------------------------------------------------\n",
|
265 | 265 | "| epoch 2 | 500/ 1782 batches | accuracy 0.900\n",
|
266 |
| - "| epoch 2 | 1000/ 1782 batches | accuracy 0.898\n", |
267 |
| - "| epoch 2 | 1500/ 1782 batches | accuracy 0.901\n", |
| 266 | + "| epoch 2 | 1000/ 1782 batches | accuracy 0.896\n", |
| 267 | + "| epoch 2 | 1500/ 1782 batches | accuracy 0.904\n", |
268 | 268 | "-----------------------------------------------------------\n",
|
269 |
| - "| end of epoch 2 | time: 13.83s | valid accuracy 0.901 \n", |
| 269 | + "| end of epoch 2 | time: 14.14s | valid accuracy 0.876 \n", |
270 | 270 | "-----------------------------------------------------------\n"
|
271 | 271 | ]
|
272 | 272 | }
|
|
317 | 317 | },
|
318 | 318 | {
|
319 | 319 | "cell_type": "code",
|
320 |
| - "execution_count": 11, |
| 320 | + "execution_count": 10, |
321 | 321 | "id": "3a668a76",
|
322 | 322 | "metadata": {},
|
323 | 323 | "outputs": [
|
|
326 | 326 | "output_type": "stream",
|
327 | 327 | "text": [
|
328 | 328 | "Checking the results of test dataset.\n",
|
329 |
| - "test accuracy 0.896\n" |
| 329 | + "test accuracy 0.876\n" |
330 | 330 | ]
|
331 | 331 | }
|
332 | 332 | ],
|
|
366 | 366 | },
|
367 | 367 | {
|
368 | 368 | "cell_type": "code",
|
369 |
| - "execution_count": 12, |
| 369 | + "execution_count": 11, |
370 | 370 | "id": "compressed-occupation",
|
371 | 371 | "metadata": {},
|
372 | 372 | "outputs": [],
|
|
386 | 386 | },
|
387 | 387 | {
|
388 | 388 | "cell_type": "code",
|
389 |
| - "execution_count": 13, |
| 389 | + "execution_count": 12, |
390 | 390 | "id": "19408128",
|
391 | 391 | "metadata": {},
|
392 | 392 | "outputs": [
|
393 | 393 | {
|
394 | 394 | "name": "stdout",
|
395 | 395 | "output_type": "stream",
|
396 | 396 | "text": [
|
397 |
| - "Created your project. Navigate to http://localhost:8000/projects/3 to see it in the UI.\n" |
| 397 | + "Found your project. Navigate to http://localhost:8000/projects/30 to see it.\n" |
398 | 398 | ]
|
399 | 399 | }
|
400 | 400 | ],
|
401 | 401 | "source": [
|
402 | 402 | "from unboxapi.tasks import TaskType\n",
|
403 | 403 | "\n",
|
404 |
| - "project = client.create_project(name=\"Text classification with PyTorch\",\n", |
405 |
| - " task_type=TaskType.TextClassification,\n", |
406 |
| - " description=\"Evaluating NN for text classification\")" |
| 404 | + "project = client.create_or_load_project(name=\"Text classification with PyTorch\",\n", |
| 405 | + " task_type=TaskType.TextClassification,\n", |
| 406 | + " description=\"Evaluating NN for text classification\")" |
407 | 407 | ]
|
408 | 408 | },
|
409 | 409 | {
|
|
461 | 461 | },
|
462 | 462 | {
|
463 | 463 | "cell_type": "code",
|
464 |
| - "execution_count": 16, |
| 464 | + "execution_count": 13, |
465 | 465 | "id": "supposed-survey",
|
466 | 466 | "metadata": {},
|
467 | 467 | "outputs": [],
|
468 | 468 | "source": [
|
469 |
| - "def predict_proba(model, texts, tokenizer, vocab):\n", |
| 469 | + "def predict_proba(model, texts, tokenizer_fn, vocab):\n", |
470 | 470 | " with torch.no_grad():\n",
|
471 | 471 | " texts = [\n",
|
472 | 472 | " torch.tensor(\n",
|
473 |
| - " [vocab[token] for token in tokenizer(text)]\n", |
| 473 | + " [vocab[token] for token in tokenizer_fn(text)]\n", |
474 | 474 | " ) \n",
|
475 | 475 | " for text in texts]\n",
|
476 | 476 | " text_list = torch.tensor(torch.cat(texts)).long()\n",
|
|
495 | 495 | },
|
496 | 496 | {
|
497 | 497 | "cell_type": "code",
|
498 |
| - "execution_count": 17, |
| 498 | + "execution_count": 14, |
499 | 499 | "id": "north-valuation",
|
500 | 500 | "metadata": {},
|
501 | 501 | "outputs": [],
|
|
520 | 520 | },
|
521 | 521 | {
|
522 | 522 | "cell_type": "code",
|
523 |
| - "execution_count": 18, |
| 523 | + "execution_count": 15, |
524 | 524 | "id": "comprehensive-jenny",
|
525 | 525 | "metadata": {},
|
526 | 526 | "outputs": [
|
527 | 527 | {
|
528 | 528 | "name": "stderr",
|
529 | 529 | "output_type": "stream",
|
530 | 530 | "text": [
|
531 |
| - "/var/folders/9z/j3bd32nd47j_l0thnbj6vbnw0000gn/T/ipykernel_21571/785500925.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", |
| 531 | + "/var/folders/9z/j3bd32nd47j_l0thnbj6vbnw0000gn/T/ipykernel_22576/710996952.py:8: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n", |
532 | 532 | " text_list = torch.tensor(torch.cat(texts)).long()\n",
|
533 |
| - "/var/folders/9z/j3bd32nd47j_l0thnbj6vbnw0000gn/T/ipykernel_21571/785500925.py:17: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", |
| 533 | + "/var/folders/9z/j3bd32nd47j_l0thnbj6vbnw0000gn/T/ipykernel_22576/710996952.py:17: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.\n", |
534 | 534 | " return sm(output).numpy().tolist()\n"
|
535 | 535 | ]
|
536 | 536 | },
|
537 | 537 | {
|
538 | 538 | "data": {
|
539 | 539 | "text/plain": [
|
540 |
| - "[[0.004791636019945145,\n", |
541 |
| - " 0.9912257790565491,\n", |
542 |
| - " 0.0018143205670639873,\n", |
543 |
| - " 0.0021683182567358017],\n", |
544 |
| - " [0.009553060866892338,\n", |
545 |
| - " 0.9899933934211731,\n", |
546 |
| - " 6.70066146994941e-05,\n", |
547 |
| - " 0.0003865564940497279]]" |
| 540 | + "[[0.012467482127249241,\n", |
| 541 | + " 0.9524526596069336,\n", |
| 542 | + " 0.0024990958627313375,\n", |
| 543 | + " 0.03258078917860985],\n", |
| 544 | + " [0.024693824350833893,\n", |
| 545 | + " 0.9746410846710205,\n", |
| 546 | + " 1.4036187167221215e-05,\n", |
| 547 | + " 0.0006511638639494777]]" |
548 | 548 | ]
|
549 | 549 | },
|
550 |
| - "execution_count": 18, |
| 550 | + "execution_count": 15, |
551 | 551 | "metadata": {},
|
552 | 552 | "output_type": "execute_result"
|
553 | 553 | }
|
|
566 | 566 | },
|
567 | 567 | {
|
568 | 568 | "cell_type": "code",
|
569 |
| - "execution_count": 19, |
| 569 | + "execution_count": 17, |
570 | 570 | "id": "f0b3eb3f",
|
571 | 571 | "metadata": {},
|
572 | 572 | "outputs": [
|
573 | 573 | {
|
574 | 574 | "name": "stdout",
|
575 | 575 | "output_type": "stream",
|
576 | 576 | "text": [
|
577 |
| - "[2022-08-08 10:56:08,400] WARNING - BentoML by default does not include spacy and torchvision package when using PytorchModelArtifact. To make sure BentoML bundle those packages if they are required for your model, either import those packages in BentoService definition file or manually add them via `@env(pip_packages=['torchvision'])` when defining a BentoService\n", |
| 577 | + "[2022-08-24 10:54:58,678] WARNING - BentoML by default does not include spacy and torchvision package when using PytorchModelArtifact. To make sure BentoML bundle those packages if they are required for your model, either import those packages in BentoService definition file or manually add them via `@env(pip_packages=['torchvision'])` when defining a BentoService\n", |
578 | 578 | "Bundling model and artifacts...\n"
|
579 | 579 | ]
|
580 | 580 | },
|
581 | 581 | {
|
582 | 582 | "name": "stderr",
|
583 | 583 | "output_type": "stream",
|
584 | 584 | "text": [
|
585 |
| - "/Users/gustavocid/miniconda3/envs/pytorch-notebook/lib/python3.8/site-packages/bentoml/frameworks/pytorch.py:162: ResourceWarning: unclosed file <_io.BufferedWriter name='/private/var/folders/9z/j3bd32nd47j_l0thnbj6vbnw0000gn/T/bentoml-temp-kjhae170/TemplateModel/artifacts/model.pt'>\n", |
| 585 | + "/Users/gustavocid/miniconda3/envs/unbox-examples/lib/python3.8/site-packages/bentoml/frameworks/pytorch.py:162: ResourceWarning: unclosed file <_io.BufferedWriter name='/private/var/folders/9z/j3bd32nd47j_l0thnbj6vbnw0000gn/T/bentoml-temp-3ct5zn56/TemplateModel/artifacts/model.pt'>\n", |
586 | 586 | " return cloudpickle.dump(self._model, open(self._file_path(dst), \"wb\"))\n",
|
587 | 587 | "ResourceWarning: Enable tracemalloc to get the object allocation traceback\n"
|
588 | 588 | ]
|
|
598 | 598 | "source": [
|
599 | 599 | "from unboxapi.models import ModelType\n",
|
600 | 600 | "\n",
|
601 |
| - "model = project.add_model(\n", |
| 601 | + "ml_model = project.add_model(\n", |
602 | 602 | " function=predict_proba, \n",
|
603 | 603 | " model=model,\n",
|
604 | 604 | " model_type=ModelType.pytorch,\n",
|
605 | 605 | " class_names=['world', 'sports', 'business', 'sci/tec'],\n",
|
606 | 606 | " name='pytorch 4',\n",
|
607 | 607 | " commit_message='this is my pytorch model',\n",
|
608 | 608 | " requirements_txt_file='requirements.txt',\n",
|
609 |
| - " tokenizer=tokenizer,\n", |
| 609 | + " tokenizer_fn=tokenizer,\n", |
610 | 610 | " vocab=vocab,\n",
|
611 | 611 | ")"
|
612 | 612 | ]
|
|
0 commit comments