|
410 | 410 | "multi_model = CLAYModule.load_from_checkpoint(\n",
|
411 | 411 | " CKPT_PATH,\n",
|
412 | 412 | " mask_ratio=0.0,\n",
|
413 |
| - " band_groups={\"rgb\": (2, 1, 0), \"nir\": (3,), \"swir\": (4,5)},\n", |
| 413 | + " band_groups={\"rgb\": (2, 1, 0), \"nir\": (3,), \"swir\": (4, 5)},\n", |
414 | 414 | " bands=6,\n",
|
415 | 415 | " strict=False, # ignore the extra parameters in the checkpoint\n",
|
416 | 416 | ")\n",
|
|
427 | 427 | " 2893.86, # nir\n",
|
428 | 428 | " 2303.00, # swir16\n",
|
429 | 429 | " 1807.79, # swir22\n",
|
430 |
| - " \n", |
431 | 430 | " ]\n",
|
432 | 431 | " STD = [\n",
|
433 | 432 | " 2026.96, # red\n",
|
|
472 | 471 | " batch[\"pixels\"] = batch[\"pixels\"].to(multi_model.device)\n",
|
473 | 472 | " # Pass just the specific band through the model\n",
|
474 | 473 | " batch[\"timestep\"] = batch[\"timestep\"].to(multi_model.device)\n",
|
475 |
| - " batch[\"date\"] = batch[\"date\"] #.to(multi_model.device)\n", |
| 474 | + " batch[\"date\"] = batch[\"date\"] # .to(multi_model.device)\n", |
476 | 475 | " batch[\"latlon\"] = batch[\"latlon\"].to(multi_model.device)\n",
|
477 | 476 | "\n",
|
478 | 477 | " # Pass pixels, latlon, timestep through the encoder to create encoded patches\n",
|
|
651 | 650 | "plt.xticks(rotation=-30)\n",
|
652 | 651 | "# All points\n",
|
653 | 652 | "plt.scatter(tss, pca_result, color=\"blue\")\n",
|
654 |
| - "#plt.scatter(stack.time, pca_result, color=\"blue\")\n", |
| 653 | + "# plt.scatter(stack.time, pca_result, color=\"blue\")\n", |
655 | 654 | "\n",
|
656 | 655 | "# Cloudy images\n",
|
657 | 656 | "plt.scatter(tss[7], pca_result[7], color=\"green\")\n",
|
658 | 657 | "plt.scatter(tss[8], pca_result[8], color=\"green\")\n",
|
659 |
| - "#plt.scatter(stack.time[7], pca_result[7], color=\"green\")\n", |
660 |
| - "#plt.scatter(stack.time[8], pca_result[8], color=\"green\")\n", |
| 658 | + "# plt.scatter(stack.time[7], pca_result[7], color=\"green\")\n", |
| 659 | + "# plt.scatter(stack.time[8], pca_result[8], color=\"green\")\n", |
661 | 660 | "\n",
|
662 | 661 | "# After flood\n",
|
663 | 662 | "plt.scatter(tss[-7:], pca_result[-7:], color=\"red\")\n",
|
664 |
| - "#plt.scatter(stack.time[-7:], pca_result[-7:], color=\"red\")" |
| 663 | + "# plt.scatter(stack.time[-7:], pca_result[-7:], color=\"red\")" |
665 | 664 | ]
|
666 | 665 | },
|
667 | 666 | {
|
|
691 | 690 | "outputs": [],
|
692 | 691 | "source": [
|
693 | 692 | "from sklearn.manifold import TSNE\n",
|
694 |
| - "from sklearn.ensemble import IsolationForest\n", |
695 | 693 | "\n",
|
696 | 694 | "# Perform t-SNE on the embeddings\n",
|
697 | 695 | "tsne = TSNE(n_components=2, perplexity=5)\n",
|
|
722 | 720 | "\n",
|
723 | 721 | "# Annotate each point with the corresponding date\n",
|
724 | 722 | "for i, (x, y) in enumerate(zip(X_tsne[:, 0], X_tsne[:, 1])):\n",
|
725 |
| - " plt.annotate(f'{tss[i]}', (x, y))\n", |
726 |
| - " \n", |
727 |
| - "plt.title('t-SNE Visualization')\n", |
728 |
| - "plt.xlabel('t-SNE Component 1')\n", |
729 |
| - "plt.ylabel('t-SNE Component 2')\n", |
| 723 | + " plt.annotate(f\"{tss[i]}\", (x, y))\n", |
| 724 | + "\n", |
| 725 | + "plt.title(\"t-SNE Visualization\")\n", |
| 726 | + "plt.xlabel(\"t-SNE Component 1\")\n", |
| 727 | + "plt.ylabel(\"t-SNE Component 2\")\n", |
730 | 728 | "plt.show()"
|
731 | 729 | ]
|
732 | 730 | },
|
|
0 commit comments