diff --git a/nsrdb/data_model/variable_factory.py b/nsrdb/data_model/variable_factory.py index bb706e81..89c2509f 100755 --- a/nsrdb/data_model/variable_factory.py +++ b/nsrdb/data_model/variable_factory.py @@ -24,6 +24,7 @@ class VarFactory: """Factory pattern to retrieve ancillary variable helper objects.""" # mapping of NSRDB variable names to helper objects + # TODO: split this up into psm and mlcloud variables MAPPING: ClassVar = { 'asymmetry': AsymVar, 'air_temperature': MerraVar, diff --git a/nsrdb/preprocessing/base_data_model.py b/nsrdb/preprocessing/base_data_model.py new file mode 100644 index 00000000..63459a1b --- /dev/null +++ b/nsrdb/preprocessing/base_data_model.py @@ -0,0 +1,376 @@ +"""Shared utilities for converting satellite data to UWISC format.""" + +import logging +import os +import re +from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import cached_property +from glob import glob + +import numpy as np +import pandas as pd +from rex.utilities.solar_position import SolarPosition + +DROP_VARS = ['relative_time'] + +UWISC_CLOUD_TYPE = { + 'N/A': -15, + 'Clear': 0, + 'Probably Clear': 1, + 'Fog': 2, + 'Water': 3, + 'Super-Cooled Water': 4, + 'Mixed': 5, + 'Opaque Ice': 6, + 'Cirrus': 7, + 'Overlapping': 8, + 'Overshooting': 9, + 'Unknown': 10, + 'Dust': 11, + 'Smoke': 12, +} + + +def expand_input_patterns(input_pattern): + """Expand one or more glob patterns into a de-duplicated file list.""" + patterns = ( + [input_pattern] + if isinstance(input_pattern, str) + else list(input_pattern) + ) + files = [] + for pattern in patterns: + files.extend(glob(pattern, recursive=True)) + return list(dict.fromkeys(files)) + + +def run_data_model_jobs( + data_model_class, + input_pattern, + output_pattern, + *, + max_workers=None, + group_inputs=None, + logger=None, +): + """Run a data-model conversion job over expanded input patterns.""" + logger = logger or logging.getLogger(data_model_class.__module__) + files = expand_input_patterns(input_pattern) + job_inputs = group_inputs(files) if group_inputs is not None else files + + if max_workers == 1: + for input_data in job_inputs: + data_model_class.run(input_data, output_pattern) + else: + with ThreadPoolExecutor(max_workers=max_workers) as executor: + futures = {} + for input_data in job_inputs: + future = executor.submit( + data_model_class.run, + input_data, + output_pattern, + ) + futures[future] = input_data + + for future in as_completed(futures): + try: + future.result() + except Exception as error: + logger.error( + 'Error processing file(s): %s', + futures[future], + ) + logger.exception(error) + + logger.info('Finished converting %s files.', len(files)) + + +class BaseUwiscDataModel: + """Shared preprocessing pipeline for UWISC-format conversion.""" + + NAME_MAP = {} + CLOUD_TYPE_MAP = {} + CLOUD_TYPE_SOURCE_VAR = None + TIMESTAMP_PATTERN = r'.*_([0-9]+).([0-9]+).([0-9]+).\w+' + + def __init__(self, input_data, output_pattern): + self.input_data = input_data + self.output_pattern = output_pattern + + @staticmethod + def _as_list(input_data): + """Normalize a single input or grouped inputs into a list.""" + if isinstance(input_data, (list, tuple)): + return list(input_data) + return [input_data] + + @cached_property + def input_files(self): + """Get the normalized list of input files.""" + return self._as_list(self.input_data) + + @classmethod + def get_primary_input_file(cls, input_files): + """Get the primary input file used for naming outputs.""" + return input_files[0] + + @cached_property + def primary_input_file(self): + """Get the primary input file used for naming outputs.""" + return self.get_primary_input_file(self.input_files) + + @classmethod + def parse_timestamp(cls, input_file): + """Parse a timestamp tuple from an input file path.""" + ts = re.match(cls.TIMESTAMP_PATTERN, input_file).groups() + year, doy, hour = ts + hour = hour[:2] + minute = hour[2:] if len(hour) > 2 else '00' + secs = '000' + return year, doy, hour, minute, secs + + @classmethod + def parse_timestamp_string(cls, input_file): + """Parse the output timestamp string from an input file path.""" + return f's{"".join(cls.parse_timestamp(input_file))}' + + @cached_property + def timestamp(self): + """Get the parsed timestamp tuple for the primary input file.""" + return self.parse_timestamp(self.primary_input_file) + + @cached_property + def timestamp_string(self): + """Get the parsed output timestamp string for the primary input.""" + return self.parse_timestamp_string(self.primary_input_file) + + @cached_property + def time_index(self): + """Get a single-step time index for the primary input file.""" + year, doy, hour, minute, _ = self.timestamp + timestamp = pd.to_datetime( + f'{year}{doy}{hour}{minute}00', format='%Y%j%H%M%S' + ) + return pd.DatetimeIndex([timestamp]) + + @cached_property + def output_file(self): + """Get output file name for the configured output pattern.""" + year, doy, *_ = self.timestamp + return self.output_pattern.format( + year=year, + doy=doy, + timestamp=self.timestamp_string, + ) + + @classmethod + def open_dataset(cls, input_data): + """Open the raw input data as an xarray dataset.""" + raise NotImplementedError + + @cached_property + def ds(self): + """Get xarray dataset for raw input data.""" + return self.open_dataset(self.input_data) + + @classmethod + def transform_raw_data(cls, ds): + """Apply any subclass-specific preprocessing to the raw dataset.""" + return ds + + def get_solar_zenith(self, ds): + """Derive the solar zenith angle for the dataset.""" + lats = ds['latitude'].values + lons = ds['longitude'].values + return SolarPosition._zenith( + self.time_index, lats, lons + ) + + def get_solar_azimuth(self, ds): + """Derive the solar azimuth angle for the dataset.""" + lats = ds['latitude'].values + lons = ds['longitude'].values + return SolarPosition._azimuth( + self.time_index, lats, lons + ) + + @classmethod + def rename_vars(cls, ds): + """Rename variables to uwisc conventions.""" + for current_name, target_name in cls.NAME_MAP.items(): + if current_name in ds.data_vars: + ds = ds.rename({current_name: target_name}) + return ds + + @classmethod + def drop_vars(cls, ds): + """Drop variables that are not part of the UWISC output schema.""" + for var_name in DROP_VARS: + if var_name in ds.data_vars: + ds = ds.drop_vars(var_name) + return ds + + @staticmethod + def _rename_spatial_dims(ds): + """Rename alternative spatial dimensions to the shared convention.""" + rename_map = {} + if 'Lines' in ds.dims: + rename_map['Lines'] = 'south_north' + if 'Pixels' in ds.dims: + rename_map['Pixels'] = 'west_east' + if 'dim_y' in ds.dims: + rename_map['dim_y'] = 'south_north' + if 'dim_x' in ds.dims: + rename_map['dim_x'] = 'west_east' + if rename_map: + ds = ds.rename(rename_map) + return ds + + @staticmethod + def _promote_lat_lon_coords(ds): + """Ensure latitude and longitude are stored as coordinates.""" + sdims = ('south_north', 'west_east') + if ('lat' in ds.coords or 'lat' in ds.data_vars) and ( + 'latitude' not in ds.coords and 'latitude' not in ds.data_vars + ): + ds = ds.rename({'lat': 'latitude', 'lon': 'longitude'}) + + if ds['latitude'].ndim == 1 and ds['longitude'].ndim == 1: + ds['south_north'] = ds['latitude'] + ds['west_east'] = ds['longitude'] + + lons, lats = np.meshgrid(ds['longitude'], ds['latitude']) + ds = ds.assign_coords({ + 'latitude': (sdims, lats), + 'longitude': (sdims, lons), + }) + + coord_names = [ + name for name in ('latitude', 'longitude') if name in ds.data_vars + ] + if coord_names: + ds = ds.set_coords(coord_names) + return ds + + def remap_dims(self, ds): + """Rename dims and coords to standards and build 2D lat/lon grids.""" + for var_name in ds.data_vars: + single_ts = ( + 'time' in ds[var_name].dims + and ds[var_name].transpose('time', ...).shape[0] == 1 + ) + if single_ts and var_name != 'reference_time': + ds[var_name] = ( + ('south_north', 'west_east'), + ds[var_name].isel(time=0).data, + ) + + ref_time = ds.attrs.get('reference_time', None) + if ref_time is not None: + time_index = pd.DatetimeIndex([ref_time]).values + else: + time_index = self.time_index.values + ds = self._rename_spatial_dims(ds) + ds = self._promote_lat_lon_coords(ds) + ds = ds.assign_coords({'time': ('time', time_index)}) + return ds + + @classmethod + def fill_missing_vars(cls, ds): + """Fill any missing variables with NaN arrays.""" + for var_name in cls.NAME_MAP: + if var_name not in ds.data_vars: + ds[var_name] = ( + ('south_north', 'west_east'), + np.full( + (ds.sizes['south_north'], ds.sizes['west_east']), + np.nan, + ), + ) + return ds + + def derive_solar_angles(self, ds): + """Derive solar angles if not already present in the dataset.""" + if 'solar_zenith_angle' not in ds.data_vars: + ds['solar_zenith_angle'] = ( + ('south_north', 'west_east'), + self.get_solar_zenith(ds), + ) + if 'solar_azimuth_angle' not in ds.data_vars: + ds['solar_azimuth_angle'] = ( + ('south_north', 'west_east'), + self.get_solar_azimuth(ds), + ) + return ds + + @classmethod + def remap_cloud_phase(cls, ds): + """Map source cloud phase flags to UWISC cloud types.""" + if cls.CLOUD_TYPE_SOURCE_VAR is None: + return ds + + cloud_type_name = cls.NAME_MAP[cls.CLOUD_TYPE_SOURCE_VAR] + cloud_type = ds[cloud_type_name].values.copy() + for value, cloud_source in cls.CLOUD_TYPE_MAP.items(): + cloud_type = np.where( + ds[cloud_type_name].values.astype(int) == int(value), + UWISC_CLOUD_TYPE[cloud_source], + cloud_type, + ) + ds[cloud_type_name] = (ds[cloud_type_name].dims, cloud_type) + return ds + + @classmethod + def derive_stdevs(cls, ds): + """Derive standard deviations used as training features.""" + for var_name in ('refl_0_65um_nom', 'temp_11_0um_nom'): + stddev = ( + ds[var_name] + .rolling( + south_north=3, + west_east=3, + center=True, + min_periods=1, + ) + .std() + ) + ds[f'{var_name}_stddev_3x3'] = stddev + return ds + + def process_dataset(self, ds): + """Run the shared UWISC preprocessing pipeline on a dataset.""" + ds = self.remap_dims(ds) + ds = self.fill_missing_vars(ds) + ds = self.transform_raw_data(ds) + ds = self.rename_vars(ds) + ds = self.drop_vars(ds) + ds = self.remap_cloud_phase(ds) + ds = self.derive_stdevs(ds) + ds = self.derive_solar_angles(ds) + return ds + + @classmethod + def write_output(cls, ds, output_file): + """Write converted dataset to the final output file.""" + os.makedirs(os.path.dirname(output_file), exist_ok=True) + ds = ds.transpose('south_north', 'west_east', ...) + ds.load().to_netcdf(output_file, format='NETCDF4', engine='h5netcdf') + + @classmethod + def run(cls, input_data, output_pattern): + """Run the conversion routine and write the converted dataset.""" + logger = logging.getLogger(cls.__module__) + data_model = cls(input_data, output_pattern) + + if os.path.exists(data_model.output_file): + logger.info( + '%s already exists. Skipping conversion.', + data_model.output_file, + ) + return + + logger.info('Geting xarray dataset for %s', input_data) + ds = data_model.process_dataset(data_model.ds) + + logger.info('Writing converted file to %s', data_model.output_file) + cls.write_output(ds, data_model.output_file) diff --git a/nsrdb/preprocessing/gk2a_data_model.py b/nsrdb/preprocessing/gk2a_data_model.py new file mode 100644 index 00000000..cabf0979 --- /dev/null +++ b/nsrdb/preprocessing/gk2a_data_model.py @@ -0,0 +1,330 @@ +"""Convert GK2A data to UWISC format.""" + +import argparse +import logging +import os +import re +from contextlib import suppress +from glob import glob + +import numpy as np +import pandas as pd +import xarray as xr +from rex import init_logger + +from nsrdb.preprocessing.base_data_model import ( + BaseUwiscDataModel, + run_data_model_jobs, +) + +init_logger('nsrdb', log_level='DEBUG') +init_logger(__name__, log_level='DEBUG') + +logger = logging.getLogger(__name__) + +NAME_MAP = { + 'sw038': 'temp_3_75um_nom', # brightness temperature at 3.75 um (K) + 'ir112': 'temp_11_0um_nom', # brightness temperature at 11.0 um (K) + 'vi006': 'refl_0_65um_nom', # visible reflectance at 0.65 um (%) + 'COT': 'cld_opd_dcomp', # cloud optical thickness + 'CER': 'cld_reff_dcomp', # cloud effective radius + 'CTP': 'cld_press_acha', # cloud top pressure + 'CTH': 'cld_height_acha', # cloud top height + 'CP': 'cloud_type', # cloud phase + 'VZA': 'sensor_zenith_angle', + 'VAZ': 'sensor_azimuth_angle', +} + +GK2A_CLOUD_TYPE = { + 'Clear': 0, + 'Water phase': 1, + 'Ice phase': 2, + 'Uncertain phase': 6, +} + +CLOUD_TYPE_MAP = { + 0: 'Clear', + 1: 'Water', + 2: 'Opaque Ice', + 6: 'Unknown', +} + +# from https://nmsc.kma.go.kr/upload/resource/data/gk2a/20190415_GK-2A_AMI_Conversion_Table_v3.0.zip +VAR_CONSTANTS = { + 'vi006': { + 'gain': 0.154856294393539, + 'offset': -6.194244384765620, + 'cprime': 0.0019244840, + }, + 'sw038': { + 'gain': -0.00108296517282724000, + 'offset': 17.69998741149900000000, + 'center_wn': 2612.677373521110, + 'c0': -0.447843939824124, + 'c1': 1.000655680903890, + 'c2': -6.338240899124480e-08, + }, + 'ir112': { + 'gain': -0.02167448587715620000, + 'offset': 176.71343994140600000000, + 'center_wn': 891.713057301260, + 'c0': -0.249111718496148, + 'c1': 1.001211668737560, + 'c2': -1.131679640116650e-06, + }, +} + + +class Gk2aDataModel(BaseUwiscDataModel): + """Class to handle conversion of gk2a data to standard uwisc style format + for NSRDB pipeline""" + + NAME_MAP = NAME_MAP + CLOUD_TYPE_MAP = CLOUD_TYPE_MAP + CLOUD_TYPE_SOURCE_VAR = 'CP' + + @staticmethod + def count_to_rad(ds, var, gain, offset): + """Convert raw counts to radiance using gain and offset.""" + return ds[var] * gain + offset + + @classmethod + def count_to_refl(cls, ds, var, gain, offset, cprime): + """Convert raw counts to reflectance / albedo percent.""" + rad = cls.count_to_rad(ds, var, gain, offset) + albedo = rad * cprime * 100 + return albedo + + @staticmethod + def rad_to_temp(rad, center_wn): + """Convert radiance to brightness temperature using Planck's law.""" + h = 6.62607015e-34 + c = 2.99792458e8 + k = 1.380649e-23 + + wn_m = 100 * center_wn + temp = (h * c) / ( + wn_m * k * np.log((2 * h * c**2 * wn_m**3) / (rad * 1e-5) + 1) + ) + return temp + + @classmethod + def count_to_temp(cls, ds, var, *, center_wn, gain, offset, c0, c1, c2): + """Convert raw counts to brightness temperature.""" + rad = cls.count_to_rad(ds, var, gain, offset) + te = cls.rad_to_temp(rad, center_wn) + tb = c0 + c1 * te + c2 * te**2 + return tb + + @classmethod + def transform_raw_data(cls, ds): + """Convert raw counts to radiance for IR and visible channels.""" + temp_vars = ( + var for var in NAME_MAP if NAME_MAP[var].startswith('temp') + ) + refl_vars = ( + var for var in NAME_MAP if NAME_MAP[var].startswith('refl') + ) + for var in temp_vars: + constants = VAR_CONSTANTS[var] + ds[var] = cls.count_to_temp(ds, var, **constants) + for var in refl_vars: + constants = VAR_CONSTANTS[var] + ds[var] = cls.count_to_refl(ds, var, **constants) + return ds + + @classmethod + def get_primary_input_file(cls, input_files): + """Get the timestamped input file used for naming outputs.""" + for file in input_files: + if re.search(r'(\d{12})(?=\.[^.]+$)', os.path.basename(file)): + return file + return input_files[0] + + @staticmethod + def _normalize_channel_name(channel_name): + """Normalize a channel name attribute to a lowercase variable name.""" + if isinstance(channel_name, (list, tuple, np.ndarray)): + channel_name = channel_name[0] + if hasattr(channel_name, 'item') and not isinstance(channel_name, str): + with suppress(ValueError): + channel_name = channel_name.item() + if isinstance(channel_name, bytes): + channel_name = channel_name.decode() + return str(channel_name).lower() + + @classmethod + def _rename_image_variable(cls, ds): + """Rename image pixel values using the embedded channel name attr.""" + if 'image_pixel_values' not in ds.data_vars: + return ds + + channel_name = ds['image_pixel_values'].attrs.get('channel_name') + if channel_name is None: + msg = 'image_pixel_values variable is missing channel_name attr' + raise KeyError(msg) + + return ds.rename({ + 'image_pixel_values': cls._normalize_channel_name(channel_name) + }) + + @staticmethod + def _normalize_spatial_dims(ds): + """Rename equivalent spatial dims so datasets can be merged.""" + rename_map = {} + if 'x' in ds.dims: + rename_map['x'] = 'dim_x' + if 'y' in ds.dims: + rename_map['y'] = 'dim_y' + if rename_map: + ds = ds.rename(rename_map) + return ds + + @staticmethod + def _get_spatial_shape(ds): + """Get the standard spatial shape for a dataset if present.""" + if 'dim_y' in ds.dims and 'dim_x' in ds.dims: + return ds.sizes['dim_y'], ds.sizes['dim_x'] + return None + + @classmethod + def _coarsen_highres_dataset(cls, ds, target_shape): + """Coarsen 4x higher-resolution datasets down to the target shape.""" + shape = cls._get_spatial_shape(ds) + if shape is None or shape == target_shape: + return ds + + y_size, x_size = shape + target_y, target_x = target_shape + y_factor = y_size // target_y + x_factor = x_size // target_x + if ( + y_size == target_y * 4 + and x_size == target_x * 4 + and y_factor == 4 + and x_factor == 4 + ): + return ds.coarsen(dim_y=4, dim_x=4, boundary='trim').mean( + keep_attrs=True + ) + + msg = ( + 'Cannot align dataset with spatial shape ' + f'{shape} to target shape {target_shape}' + ) + raise ValueError(msg) + + @classmethod + def parse_timestamp(cls, input_file): + """Parse the GK2A timestamp tuple from an input file path.""" + basename = os.path.basename(input_file) + match = re.search(r'(\d{12})(?=\.[^.]+$)', basename) + timestamp = pd.to_datetime(match.group(1), format='%Y%m%d%H%M') + year = timestamp.strftime('%Y') + doy = timestamp.strftime('%j') + hour = timestamp.strftime('%H') + minute = timestamp.strftime('%M') + secs = '000' + return year, doy, hour, minute, secs + + @classmethod + def open_dataset(cls, input_files): + """Get xarray dataset for raw input file""" + return cls.combine_files(input_files) + + @classmethod + def get_files_from_timestamp(cls, timestamp): + """Get list of files needed for given timestamp. This is needed to + combine different channels, which are stored in separate files.""" + year, doy, hour, minute, _ = timestamp + file_pattern = f'*{year}.{doy}.{hour}{minute}*.nc' + files = glob(os.path.join(os.path.dirname(__file__), file_pattern)) + return files + + @classmethod + def combine_files(cls, files): + """Combine multiple files into one dataset. This is needed to combine + different channels, which are stored in separate files.""" + ds_list = [] + for file in files: + ds = xr.open_dataset(file, format='NETCDF4', engine='h5netcdf') + ds = cls._rename_image_variable(ds) + ds = cls._normalize_spatial_dims(ds) + ds_list.append(ds) + + spatial_shapes = [ + shape for ds in ds_list if (shape := cls._get_spatial_shape(ds)) + ] + if spatial_shapes: + target_shape = min(spatial_shapes) + ds_list = [ + cls._coarsen_highres_dataset(ds, target_shape) + for ds in ds_list + ] + + return xr.merge( + ds_list, + compat='override', + combine_attrs='drop_conflicts', + ) + + @classmethod + def run(cls, input_files, output_pattern): + """Run conversion routine and write converted dataset.""" + return super().run(input_files, output_pattern) + + +def group_files_by_timestamp(files): + """Group files by timestamp, which is needed to combine different channels, + which are stored in separate files.""" + groups = {} + untimestamped = [] + for file in files: + match = re.search(r'(\d{12})(?=\.[^.]+$)', os.path.basename(file)) + if match is None: + untimestamped.append(file) + continue + + timestamp = f's{match.group(1)}' + if timestamp not in groups: + groups[timestamp] = [] + groups[timestamp].append(file) + + if not groups: + return [untimestamped.copy()] if untimestamped else [] + + return [group + untimestamped for group in groups.values()] + + +if __name__ == '__main__': + default_output_pattern = '/projects/pxs/GK2A/standardized/{year}' + default_output_pattern += '/{doy}/gk2a_{timestamp}.nc' + parser = argparse.ArgumentParser() + parser.add_argument( + 'input_pattern', + type=str, + nargs='+', + help="""File pattern for input_files. e.g. + /projects/pxs/GK2A/2025/**/*.nc""", + ) + parser.add_argument( + '-output_pattern', + type=str, + default=default_output_pattern, + help='File pattern for output files.', + ) + parser.add_argument( + '-max_workers', + type=int, + default=10, + help='Number of workers to use for parallel file conversion', + ) + args = parser.parse_args() + run_data_model_jobs( + Gk2aDataModel, + args.input_pattern, + args.output_pattern, + max_workers=args.max_workers, + group_inputs=group_files_by_timestamp, + logger=logger, + ) diff --git a/nsrdb/preprocessing/nasa_data_model.py b/nsrdb/preprocessing/nasa_data_model.py index 9ad26120..a7eff157 100644 --- a/nsrdb/preprocessing/nasa_data_model.py +++ b/nsrdb/preprocessing/nasa_data_model.py @@ -2,24 +2,20 @@ import argparse import logging -import os -import re -from concurrent.futures import ThreadPoolExecutor, as_completed -from functools import cached_property -from glob import glob -import numpy as np -import pandas as pd import xarray as xr from rex import init_logger +from nsrdb.preprocessing.base_data_model import ( + BaseUwiscDataModel, + run_data_model_jobs, +) + init_logger('nsrdb', log_level='DEBUG') init_logger(__name__, log_level='DEBUG') logger = logging.getLogger(__name__) -DROP_VARS = ['relative_time'] - NAME_MAP = { 'BT_3.75um': 'temp_3_75um_nom', 'BT_10.8um': 'temp_11_0um_nom', @@ -34,23 +30,6 @@ 'relative_azimuth': 'solar_azimuth_angle', } -UWISC_CLOUD_TYPE = { - 'N/A': -15, - 'Clear': 0, - 'Probably Clear': 1, - 'Fog': 2, - 'Water': 3, - 'Super-Cooled Water': 4, - 'Mixed': 5, - 'Opaque Ice': 6, - 'Cirrus': 7, - 'Overlapping': 8, - 'Overshooting': 9, - 'Unknown': 10, - 'Dust': 11, - 'Smoke': 12, -} - NASA_CLOUD_TYPE = { 'Clear sky snow/ice': 0, 'Water cloud': 1, @@ -76,197 +55,28 @@ } -class NasaDataModel: +class NasaDataModel(BaseUwiscDataModel): """Class to handle conversion of nasa data to standard uwisc style format for NSRDB pipeline""" - def __init__(self, input_file, output_pattern): - """ - Parameters - ---------- - input_file : str - e.g. "./2017/01/01/nacomposite_2017.001.0000.nc" - output_pattern : str - Needs to include year, doy, and timestamp format keys. - e.g. "./{year}/{doy}/nacomposite_{timestamp}.nc" - """ - self.input_file = input_file - self.output_pattern = output_pattern - - @cached_property - def timestamp(self): - """Get year, doy, hour from input file name. + NAME_MAP = NAME_MAP + CLOUD_TYPE_MAP = CLOUD_TYPE_MAP + CLOUD_TYPE_SOURCE_VAR = 'cloud_phase' - TODO: Should get this from relative_time variables to be more precise - """ - match_pattern = r'.*_([0-9]+).([0-9]+).([0-9]+).\w+' - ts = re.match(match_pattern, self.input_file).groups() - year, doy, hour = ts - secs = '000' - return year, doy, hour, secs - - @cached_property - def output_file(self): - """Get output file name for given output pattern.""" - year, doy, _, _ = self.timestamp - return self.output_pattern.format( - year=year, doy=doy, timestamp=f's{"".join(self.timestamp)}' - ) - - @cached_property - def ds(self): - """Get xarray dataset for raw input file""" + @classmethod + def open_dataset(cls, input_file): + """Get xarray dataset for raw input file.""" return xr.open_mfdataset( - self.input_file, + input_file, **{'group': 'map_data', 'decode_times': False}, format='NETCDF4', engine='h5netcdf', ) - @classmethod - def rename_vars(cls, ds): - """Rename variables to uwisc conventions""" - for k, v in NAME_MAP.items(): - if k in ds.data_vars: - ds = ds.rename({k: v}) - return ds - - @classmethod - def drop_vars(cls, ds): - """Drop list of variables""" - for v in DROP_VARS: - if v in ds.data_vars: - ds = ds.drop_vars(v) - return ds - - @classmethod - def remap_dims(cls, ds): - """Rename dims and coords to standards. Make lat / lon into 2d arrays, - as expected by cloud regridding routine.""" - - sdims = ('south_north', 'west_east') - for var in ds.data_vars: - single_ts = ( - 'time' in ds[var].dims - and ds[var].transpose('time', ...).shape[0] == 1 - ) - if single_ts and var != 'reference_time': - ds[var] = (sdims, ds[var].isel(time=0).data) - - ref_time = ds.attrs.get('reference_time', None) - if ref_time is not None: - ti = pd.DatetimeIndex([ref_time]).values - ds = ds.assign_coords({'time': ('time', ti)}) - if 'Lines' in ds.dims: - ds = ds.swap_dims({'Lines': 'south_north', 'Pixels': 'west_east'}) - if 'lat' in ds.coords: - ds = ds.rename({'lat': 'latitude', 'lon': 'longitude'}) - - ds['south_north'] = ds['latitude'] - ds['west_east'] = ds['longitude'] - - lons, lats = np.meshgrid(ds['longitude'], ds['latitude']) - ds = ds.assign_coords({ - 'latitude': (sdims, lats), - 'longitude': (sdims, lons), - }) - return ds - - @classmethod - def remap_cloud_phase(cls, ds): - """Map nasa cloud phase flags to uwisc values.""" - ct_name = NAME_MAP['cloud_phase'] - cloud_type = ds[ct_name].values.copy() - for val, cs_type in CLOUD_TYPE_MAP.items(): - cloud_type = np.where( - ds[ct_name].values.astype(int) == int(val), - UWISC_CLOUD_TYPE[cs_type], - cloud_type, - ) - ds[ct_name] = (ds[ct_name].dims, cloud_type) - return ds - - @classmethod - def derive_stdevs(cls, ds): - """Derive standard deviations of some variables, which are used as - training features.""" - for var_name in ('refl_0_65um_nom', 'temp_11_0um_nom'): - stddev = ( - ds[var_name] - .rolling( - south_north=3, west_east=3, center=True, min_periods=1 - ) - .std() - ) - ds[f'{var_name}_stddev_3x3'] = stddev - return ds - - @classmethod - def write_output(cls, ds, output_file): - """Write converted dataset to output_file.""" - os.makedirs(os.path.dirname(output_file), exist_ok=True) - ds = ds.transpose('south_north', 'west_east', ...) - ds.load().to_netcdf(output_file, format='NETCDF4', engine='h5netcdf') - @classmethod def run(cls, input_file, output_pattern): """Run conversion routine and write converted dataset.""" - dm = cls(input_file, output_pattern) - - if os.path.exists(dm.output_file): - logger.info( - '%s already exists. Skipping conversion.', dm.output_file - ) - else: - logger.info('Geting xarray dataset for %s', input_file) - ds = dm.ds - - logger.info('Remapping dimensions.') - ds = dm.remap_dims(ds) - - logger.info('Renaming variables.') - ds = dm.rename_vars(ds) - - logger.info('Dropping some variables.') - ds = dm.drop_vars(ds) - - logger.info('Remapping cloud type values.') - ds = dm.remap_cloud_phase(ds) - - logger.info('Deriving some stddev variables.') - ds = dm.derive_stdevs(ds) - - logger.info('Writing converted file to %s', dm.output_file) - dm.write_output(ds, dm.output_file) - - -def run_jobs(input_pattern, output_pattern, max_workers=None): - """Run multiple file conversion jobs""" - - files = glob(input_pattern) - - if max_workers == 1: - for file in files: - NasaDataModel.run(input_file=file, output_pattern=output_pattern) - else: - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = {} - for file in files: - fut = executor.submit( - NasaDataModel.run, - input_file=file, - output_pattern=output_pattern, - ) - futures[fut] = file - - for future in as_completed(futures): - try: - future.result() - except Exception as e: - logger.error('Error processing file: %s', futures[future]) - logger.exception(e) - - logger.info('Finished converting %s files.', len(files)) + return super().run(input_file, output_pattern) if __name__ == '__main__': @@ -276,6 +86,7 @@ def run_jobs(input_pattern, output_pattern, max_workers=None): parser.add_argument( 'input_pattern', type=str, + nargs='+', help="""File pattern for input_files. e.g. /projects/pxs/nasa_polar/2023/*/*/*.nc""", ) @@ -292,8 +103,10 @@ def run_jobs(input_pattern, output_pattern, max_workers=None): help='Number of workers to use for parallel file conversion', ) args = parser.parse_args() - run_jobs( - input_pattern=args.input_pattern, - output_pattern=args.output_pattern, + run_data_model_jobs( + NasaDataModel, + args.input_pattern, + args.output_pattern, max_workers=args.max_workers, + logger=logger, ) diff --git a/nsrdb/solar_position/spa.py b/nsrdb/solar_position/spa.py index 0730f7c1..fe32a623 100755 --- a/nsrdb/solar_position/spa.py +++ b/nsrdb/solar_position/spa.py @@ -1205,7 +1205,7 @@ def elevation(cls, time_index, lat_lon, elev=0, delta_t=None): @classmethod def azimuth(cls, time_index, lat_lon, elev=0, delta_t=None): """ - Compute the solar elevation + Compute the solar azimuth Parameters ---------- @@ -1231,7 +1231,7 @@ def azimuth(cls, time_index, lat_lon, elev=0, delta_t=None): @classmethod def zenith(cls, time_index, lat_lon, elev=0, delta_t=None): """ - Compute the solar elevation + Compute the solar zenith Parameters ---------- diff --git a/tests/preproc/conftest.py b/tests/preproc/conftest.py new file mode 100644 index 00000000..3232f2e8 --- /dev/null +++ b/tests/preproc/conftest.py @@ -0,0 +1,35 @@ +"""Shared pytest fixtures for preprocessing tests.""" + +import pytest + + +@pytest.fixture +def make_nested_files(tmp_path): + """Create nested files beneath a temporary directory.""" + + def _make(*relative_paths): + files = [] + for relative_path in relative_paths: + file_path = tmp_path / relative_path + file_path.parent.mkdir(parents=True, exist_ok=True) + file_path.touch() + files.append(str(file_path)) + return files + + return _make + + +@pytest.fixture +def collect_run_calls(monkeypatch): + """Patch a data model run method and collect its calls.""" + + def _collect(data_model_class): + calls = [] + + def fake_run(input_data, output_pattern): + calls.append((input_data, output_pattern)) + + monkeypatch.setattr(data_model_class, 'run', fake_run) + return calls + + return _collect diff --git a/tests/preproc/test_base_data_model.py b/tests/preproc/test_base_data_model.py new file mode 100644 index 00000000..c6f8e018 --- /dev/null +++ b/tests/preproc/test_base_data_model.py @@ -0,0 +1,102 @@ +"""Tests for shared preprocessing base model helpers.""" + +import numpy as np +import pytest +import xarray as xr + +from nsrdb.preprocessing.base_data_model import ( + BaseUwiscDataModel, + expand_input_patterns, + run_data_model_jobs, +) + + +class DummyDataModel(BaseUwiscDataModel): + """Minimal test double for the shared base preprocessing model.""" + + +class DummyRunner: + """Minimal runner target for shared run-data-model job tests.""" + + calls = [] + + @classmethod + def run(cls, input_data, output_pattern): + """Collect runner calls for assertions.""" + cls.calls.append((input_data, output_pattern)) + + +@pytest.mark.parametrize( + ('y_dim', 'x_dim'), + [ + ('dim_y', 'dim_x'), + ('Lines', 'Pixels'), + ], +) +def test_remap_dims_normalizes_spatial_dims_and_coords(y_dim, x_dim): + """remap_dims should normalize time, dims, and lat/lon coordinates.""" + model = DummyDataModel( + '/tmp/input_2025.001.12.nc', + '/tmp/out/{year}/{doy}/{timestamp}.nc', + ) + ds = xr.Dataset( + { + 'foo': ( + ('time', y_dim, x_dim), + np.ones((1, 2, 2)), + ), + 'latitude': (y_dim, np.array([10.0, 20.0])), + 'longitude': (x_dim, np.array([30.0, 40.0])), + }, + ) + + remapped = model.remap_dims(ds) + + assert remapped.indexes['time'][0] == model.time_index[0] + assert remapped['foo'].dims == ('south_north', 'west_east') + assert 'latitude' in remapped.coords + assert 'longitude' in remapped.coords + assert remapped['latitude'].dims == ('south_north', 'west_east') + assert remapped['longitude'].dims == ('south_north', 'west_east') + + +def test_expand_input_patterns_handles_multiple_recursive_globs( + make_nested_files, +): + """Multiple recursive glob patterns should be expanded into one list.""" + file_1, file_2 = make_nested_files( + 'set_1/a/file_1.nc', + 'set_2/b/file_2.nc', + ) + + root = file_1.split('/set_1/', 1)[0] + files = expand_input_patterns([ + f'{root}/set_1/**/*.nc', + f'{root}/set_2/**/*.nc', + ]) + + assert files == [file_1, file_2] + + +def test_run_data_model_jobs_accepts_multiple_glob_patterns( + make_nested_files, +): + """Shared run_data_model_jobs should handle multiple glob inputs.""" + file_1, file_2 = make_nested_files( + 'set_1/a/file_1.nc', + 'set_2/b/file_2.nc', + ) + DummyRunner.calls = [] + + root = file_1.split('/set_1/', 1)[0] + run_data_model_jobs( + DummyRunner, + [f'{root}/set_1/**/*.nc', f'{root}/set_2/**/*.nc'], + '/tmp/out/{year}/{doy}/file_{timestamp}.nc', + max_workers=1, + ) + + assert DummyRunner.calls == [ + (file_1, '/tmp/out/{year}/{doy}/file_{timestamp}.nc'), + (file_2, '/tmp/out/{year}/{doy}/file_{timestamp}.nc'), + ] diff --git a/tests/preproc/test_gk2a_data_model.py b/tests/preproc/test_gk2a_data_model.py new file mode 100644 index 00000000..d6fa976b --- /dev/null +++ b/tests/preproc/test_gk2a_data_model.py @@ -0,0 +1,107 @@ +"""Tests for GK2A preprocessing helpers.""" + +from pathlib import Path + +import numpy as np +import xarray as xr + +from nsrdb.preprocessing.gk2a_data_model import ( + Gk2aDataModel, + group_files_by_timestamp, +) + + +def test_group_files_by_timestamp_includes_untimestamped_files(): + """Untimestamped files should be included with each timestamp group.""" + files = [ + '/tmp/vi006_202501010700.nc', + '/tmp/ir112_202501010700.nc', + '/tmp/vi006_202501010710.nc', + '/tmp/static_mask.nc', + ] + + groups = group_files_by_timestamp(files) + groups = sorted(sorted(group) for group in groups) + + assert groups == [ + sorted([ + '/tmp/ir112_202501010700.nc', + '/tmp/static_mask.nc', + '/tmp/vi006_202501010700.nc', + ]), + sorted([ + '/tmp/static_mask.nc', + '/tmp/vi006_202501010710.nc', + ]), + ] + + +def test_group_files_by_timestamp_handles_only_untimestamped_files(): + """A list without timestamps should remain a single group.""" + files = ['/tmp/static_mask.nc', '/tmp/terrain.nc'] + + assert group_files_by_timestamp(files) == [files] + + +def test_gk2a_output_filename_uses_year_doy_hour_minute_seconds(): + """GK2A output filenames should use NSRDB-style timestamp strings.""" + data_model = Gk2aDataModel( + ['/tmp/vi006_202501010700.nc', '/tmp/static_mask.nc'], + '/tmp/out/{year}/{doy}/gk2a_{timestamp}.nc', + ) + + assert data_model.timestamp_string == 's20250010700000' + assert ( + data_model.output_file == '/tmp/out/2025/001/gk2a_s20250010700000.nc' + ) + + +def test_combine_files_coarsens_and_renames_image_variables(tmp_path): + """Combine mixed-resolution files into one aligned dataset.""" + low_res = xr.Dataset({ + 'COT': (('dim_y', 'dim_x'), np.array([[1.0, 2.0], [3.0, 4.0]])) + }) + low_res_file = Path(tmp_path / 'cot_202501010700.nc') + low_res.to_netcdf(low_res_file, engine='h5netcdf') + + high_res = xr.Dataset({ + 'image_pixel_values': ( + ('y', 'x'), + np.kron( + np.array([[10.0, 20.0], [30.0, 40.0]]), + np.ones((4, 4)), + ), + ) + }) + high_res['image_pixel_values'].attrs['channel_name'] = 'VI006' + high_res_file = Path(tmp_path / 'vi006_202501010700.nc') + high_res.to_netcdf(high_res_file, engine='h5netcdf') + + combined = Gk2aDataModel.combine_files([ + str(low_res_file), + str(high_res_file), + ]) + + assert set(combined.data_vars) == {'COT', 'vi006'} + assert combined['vi006'].dims == ('dim_y', 'dim_x') + np.testing.assert_allclose( + combined['vi006'].values, + np.array([[10.0, 20.0], [30.0, 40.0]]), + ) + + +def test_gk2a_remap_cloud_phase_uses_cp_source_var(): + """GK2A cloud phase remapping should use the CP source variable mapping.""" + ds = xr.Dataset({ + 'cloud_type': ( + ('dim_y', 'dim_x'), + np.array([[0, 1], [2, 6]]), + ) + }) + + remapped = Gk2aDataModel.remap_cloud_phase(ds) + + np.testing.assert_array_equal( + remapped['cloud_type'].values, + np.array([[0, 3], [6, 10]]), + )