Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/cover_class/config/static.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,6 @@ datasets:
- /some/path
water:
- http:example.com
- /some/path
- /some/path
left-edge-correction:
- /some/path/to/hdf5: [100, 115, 200]
16 changes: 15 additions & 1 deletion src/cover_class/static/preprocessor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Tuple, List
from torch import FloatTensor, Tensor
import torch
from numpy.typing import NDArray
Expand All @@ -19,3 +19,17 @@ def interior_interpolation(

def convolve(data_matrix:FloatTensor) -> FloatTensor: ... # type: ignore

def left_edge_scale(
data_matrix: NDArray[np.float32],
left_edges: List[int]
) -> None:
"""
This function takes in the data matrix and the left edges of the edge discontinuity.
It then scales the left side of the spectra to be on the same magnitude.
As such, there will be cumulative scaling of the left-portion of the data matrix until the
edges are all sequentially processed/corrected.
"""
for edge in sorted(left_edges):
denom = np.where(data_matrix[:, edge] == 0, 1e-8, data_matrix[:, edge]) # division by 0 protection
scaling_factors = data_matrix[:, edge+1] / denom
data_matrix[:, :edge+1] *= scaling_factors[:, np.newaxis]
9 changes: 7 additions & 2 deletions src/cover_class/static/retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import requests # type: ignore[import]

from cover_class.utils import read_config
from cover_class.static.preprocessor import interior_interpolation
from cover_class.static.preprocessor import interior_interpolation, left_edge_scale


def download(uri: str) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
Expand Down Expand Up @@ -85,14 +85,19 @@ def generate_hdf5_from_config(config_path:str) -> None:
ds = config['datasets']
outdir = ds['output-directory']
assert Path(outdir).is_dir(), f"'output-directory': {outdir} is not a directory"
edges = config.get('left-edge-correction', dict({}))

for d in (ds_classes := ds['classes']):
if ds_classes[d] == None: continue
for location in ds_classes[d]:
# 1. get the wavelength and spectra from the locations
# 1a. get the wavelength and spectra from the locations
if Path(location).is_file(): file_wavelengths, spectra = vfs_csv(location)
else: file_wavelengths, spectra = download(location)

# [Optional] 1b. correct for the left edges
if location in edges:
left_edge_scale(spectra, edges[location])

# 2. interpolate the wavelengths
spectra = spectra[~np.isnan(spectra).any(axis=1)]
spectra_interp, target_wavelengths = interior_interpolation(spectra, file_wavelengths)
Expand Down
29 changes: 29 additions & 0 deletions tests/static/retrieval_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import io
import contextlib
from copy import deepcopy

from cover_class.static import retrieval # type: ignore[import]
MODULE = "cover_class.static.retrieval"
Expand Down Expand Up @@ -95,6 +96,34 @@ def test_vfs_csv(self):
np.testing.assert_allclose(wl, wls)
np.testing.assert_allclose(sp, spectra)

def test_left_edge_correction(self):
data = np.array([
[1.0, 2.0, 4.0, 8.0],
[2.0, 4.0, 8.0, 16.0],
], dtype=np.float32)
d0 = deepcopy(data)
d1 = deepcopy(data)

left_edges = [1]
expected = np.array([
[2.0, 4.0, 4.0, 8.0],
[4.0, 8.0, 8.0, 16.0],
], dtype=np.float32)
retrieval.left_edge_scale(d0, left_edges)
np.testing.assert_allclose(d0, expected, rtol=1e-6)

left_edges = [1,2]
expected = np.array([
[4.0, 8.0, 8.0, 8.0],
[8.0, 16.0, 16.0, 16.0],
], dtype=np.float32)
retrieval.left_edge_scale(d1, left_edges)
np.testing.assert_allclose(d1, expected, rtol=1e-6)

# test out of order indices
left_edges = [2, 1]
retrieval.left_edge_scale(data, left_edges)
np.testing.assert_allclose(data, expected, rtol=1e-6)

if __name__ == "__main__":
unittest.main()