Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 4a4e86f

Browse files
committedFeb 22, 2024·
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent aa825b5 commit 4a4e86f

File tree

2 files changed

+20
-14
lines changed

2 files changed

+20
-14
lines changed
 

‎docs/partial-inputs-flood-tutorial.ipynb

+11-13
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@
410410
"multi_model = CLAYModule.load_from_checkpoint(\n",
411411
" CKPT_PATH,\n",
412412
" mask_ratio=0.0,\n",
413-
" band_groups={\"rgb\": (2, 1, 0), \"nir\": (3,), \"swir\": (4,5)},\n",
413+
" band_groups={\"rgb\": (2, 1, 0), \"nir\": (3,), \"swir\": (4, 5)},\n",
414414
" bands=6,\n",
415415
" strict=False, # ignore the extra parameters in the checkpoint\n",
416416
")\n",
@@ -427,7 +427,6 @@
427427
" 2893.86, # nir\n",
428428
" 2303.00, # swir16\n",
429429
" 1807.79, # swir22\n",
430-
" \n",
431430
" ]\n",
432431
" STD = [\n",
433432
" 2026.96, # red\n",
@@ -472,7 +471,7 @@
472471
" batch[\"pixels\"] = batch[\"pixels\"].to(multi_model.device)\n",
473472
" # Pass just the specific band through the model\n",
474473
" batch[\"timestep\"] = batch[\"timestep\"].to(multi_model.device)\n",
475-
" batch[\"date\"] = batch[\"date\"] #.to(multi_model.device)\n",
474+
" batch[\"date\"] = batch[\"date\"] # .to(multi_model.device)\n",
476475
" batch[\"latlon\"] = batch[\"latlon\"].to(multi_model.device)\n",
477476
"\n",
478477
" # Pass pixels, latlon, timestep through the encoder to create encoded patches\n",
@@ -651,17 +650,17 @@
651650
"plt.xticks(rotation=-30)\n",
652651
"# All points\n",
653652
"plt.scatter(tss, pca_result, color=\"blue\")\n",
654-
"#plt.scatter(stack.time, pca_result, color=\"blue\")\n",
653+
"# plt.scatter(stack.time, pca_result, color=\"blue\")\n",
655654
"\n",
656655
"# Cloudy images\n",
657656
"plt.scatter(tss[7], pca_result[7], color=\"green\")\n",
658657
"plt.scatter(tss[8], pca_result[8], color=\"green\")\n",
659-
"#plt.scatter(stack.time[7], pca_result[7], color=\"green\")\n",
660-
"#plt.scatter(stack.time[8], pca_result[8], color=\"green\")\n",
658+
"# plt.scatter(stack.time[7], pca_result[7], color=\"green\")\n",
659+
"# plt.scatter(stack.time[8], pca_result[8], color=\"green\")\n",
661660
"\n",
662661
"# After flood\n",
663662
"plt.scatter(tss[-7:], pca_result[-7:], color=\"red\")\n",
664-
"#plt.scatter(stack.time[-7:], pca_result[-7:], color=\"red\")"
663+
"# plt.scatter(stack.time[-7:], pca_result[-7:], color=\"red\")"
665664
]
666665
},
667666
{
@@ -691,7 +690,6 @@
691690
"outputs": [],
692691
"source": [
693692
"from sklearn.manifold import TSNE\n",
694-
"from sklearn.ensemble import IsolationForest\n",
695693
"\n",
696694
"# Perform t-SNE on the embeddings\n",
697695
"tsne = TSNE(n_components=2, perplexity=5)\n",
@@ -722,11 +720,11 @@
722720
"\n",
723721
"# Annotate each point with the corresponding date\n",
724722
"for i, (x, y) in enumerate(zip(X_tsne[:, 0], X_tsne[:, 1])):\n",
725-
" plt.annotate(f'{tss[i]}', (x, y))\n",
726-
" \n",
727-
"plt.title('t-SNE Visualization')\n",
728-
"plt.xlabel('t-SNE Component 1')\n",
729-
"plt.ylabel('t-SNE Component 2')\n",
723+
" plt.annotate(f\"{tss[i]}\", (x, y))\n",
724+
"\n",
725+
"plt.title(\"t-SNE Visualization\")\n",
726+
"plt.xlabel(\"t-SNE Component 1\")\n",
727+
"plt.ylabel(\"t-SNE Component 2\")\n",
730728
"plt.show()"
731729
]
732730
},

‎src/datamodule.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
os.environ["GDAL_DISABLE_READDIR_ON_OPEN"] = "EMPTY_DIR"
2020
os.environ["GDAL_HTTP_MERGE_CONSECUTIVE_RANGES"] = "YES"
2121

22+
2223
# %%
2324
# Regular torch Dataset
2425
class ClayDataset(Dataset):
@@ -60,7 +61,14 @@ def read_chip(self, chip_path):
6061

6162
# read timestep & normalize
6263
date = chip.tags()["date"] # YYYY-MM-DD
63-
year, month, day, year_non_norm, month_non_norm, day_non_norm = self.normalize_timestamp(date)
64+
(
65+
year,
66+
month,
67+
day,
68+
year_non_norm,
69+
month_non_norm,
70+
day_non_norm,
71+
) = self.normalize_timestamp(date)
6472

6573
# read lat,lon from UTM to WGS84 & normalize
6674
bounds = chip.bounds # xmin, ymin, xmax, ymax

0 commit comments

Comments
 (0)
Please sign in to comment.