diff --git a/src/cover_class/config/static.yml b/src/cover_class/config/static.yml
index f23fba1..dd4fad2 100644
--- a/src/cover_class/config/static.yml
+++ b/src/cover_class/config/static.yml
@@ -21,4 +21,6 @@ datasets:
   - /some/path
   water:
   - http:example.com
-  - /some/path
\ No newline at end of file
+  - /some/path
+left-edge-correction:
+  /some/path/to/hdf5: [100, 115, 200]
\ No newline at end of file
diff --git a/src/cover_class/static/preprocessor.py b/src/cover_class/static/preprocessor.py
index 8e3aaa9..c0847f7 100644
--- a/src/cover_class/static/preprocessor.py
+++ b/src/cover_class/static/preprocessor.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Tuple, List
 from torch import FloatTensor, Tensor
 import torch
 from numpy.typing import NDArray
@@ -19,3 +19,17 @@ def interior_interpolation(
 
 def convolve(data_matrix:FloatTensor) -> FloatTensor:
     ... # type: ignore
+def left_edge_scale(
+    data_matrix: NDArray[np.float32],
+    left_edges: List[int]
+    ) -> None:
+    """
+    This function takes in the data matrix and the left edges of the edge discontinuity.
+    It then scales the left side of the spectra, in place, to be on the same magnitude.
+    As such, there will be cumulative scaling of the left-portion of the data matrix until the
+    edges are all sequentially processed/corrected.
+    """
+    for edge in sorted(left_edges):
+        denom = np.where(data_matrix[:, edge] == 0, 1e-8, data_matrix[:, edge])  # division by 0 protection
+        scaling_factors = data_matrix[:, edge+1] / denom
+        data_matrix[:, :edge+1] *= scaling_factors[:, np.newaxis]
diff --git a/src/cover_class/static/retrieval.py b/src/cover_class/static/retrieval.py
index ce59b03..f58db88 100644
--- a/src/cover_class/static/retrieval.py
+++ b/src/cover_class/static/retrieval.py
@@ -9,7 +9,7 @@ import requests # type: ignore[import]
 
 
 from cover_class.utils import read_config
-from cover_class.static.preprocessor import interior_interpolation
+from cover_class.static.preprocessor import interior_interpolation, left_edge_scale
 
 
 def download(uri: str) -> Tuple[NDArray[np.float32], NDArray[np.float32]]:
@@ -85,14 +85,19 @@ def generate_hdf5_from_config(config_path:str) -> None:
     ds = config['datasets']
     outdir = ds['output-directory']
     assert Path(outdir).is_dir(), f"'output-directory': {outdir} is not a directory"
+    edges = config.get('left-edge-correction', {})
     for d in (ds_classes := ds['classes']):
         if ds_classes[d] == None:
             continue
         for location in ds_classes[d]:
-            # 1. get the wavelength and spectra from the locations
+            # 1a. get the wavelength and spectra from the locations
             if Path(location).is_file():
                 file_wavelengths, spectra = vfs_csv(location)
             else:
                 file_wavelengths, spectra = download(location)
+            # [Optional] 1b. correct for the left edges
+            if location in edges:
+                left_edge_scale(spectra, edges[location])
+            # 2. interpolate the wavelengths
             spectra = spectra[~np.isnan(spectra).any(axis=1)]
             spectra_interp, target_wavelengths = interior_interpolation(spectra, file_wavelengths)
diff --git a/tests/static/retrieval_test.py b/tests/static/retrieval_test.py
index 6b31a95..1e56807 100644
--- a/tests/static/retrieval_test.py
+++ b/tests/static/retrieval_test.py
@@ -6,6 +6,7 @@ import torch
 import io
 import contextlib
+from copy import deepcopy
 from cover_class.static import retrieval  # type: ignore[import]
 
 MODULE = "cover_class.static.retrieval"
 
@@ -95,6 +96,34 @@ def test_vfs_csv(self):
         np.testing.assert_allclose(wl, wls)
         np.testing.assert_allclose(sp, spectra)
 
+    def test_left_edge_correction(self):
+        data = np.array([
+            [1.0, 2.0, 4.0, 8.0],
+            [2.0, 4.0, 8.0, 16.0],
+        ], dtype=np.float32)
+        d0 = deepcopy(data)
+        d1 = deepcopy(data)
+
+        left_edges = [1]
+        expected = np.array([
+            [2.0, 4.0, 4.0, 8.0],
+            [4.0, 8.0, 8.0, 16.0],
+        ], dtype=np.float32)
+        retrieval.left_edge_scale(d0, left_edges)
+        np.testing.assert_allclose(d0, expected, rtol=1e-6)
+
+        left_edges = [1,2]
+        expected = np.array([
+            [4.0, 8.0, 8.0, 8.0],
+            [8.0, 16.0, 16.0, 16.0],
+        ], dtype=np.float32)
+        retrieval.left_edge_scale(d1, left_edges)
+        np.testing.assert_allclose(d1, expected, rtol=1e-6)
+
+        # test out of order indices
+        left_edges = [2, 1]
+        retrieval.left_edge_scale(data, left_edges)
+        np.testing.assert_allclose(data, expected, rtol=1e-6)
 
 if __name__ == "__main__":
     unittest.main()