Skip to content

Commit

Permalink
WIP: about to make some ocean_model changes
Browse files Browse the repository at this point in the history
  • Loading branch information
kthyng committed Mar 5, 2025
1 parent 313288c commit 4c60fb2
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 225 deletions.
3 changes: 3 additions & 0 deletions particle_tracking_manager/config_ocean_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ class BaseOceanModel(BaseModel):
ocean_model_local: bool = Field(True, description="Set to True to use local ocean model data, False for remote access.")
end_time: datetime
horizontal_diffusivity: Optional[float] = Field(None, description="Horizontal diffusivity for the simulation.", ptm_level=2, od_mapping="drift:horizontal_diffusivity")
# TODO: Move functions for manipulating ocean model dataset to here and store ds, allowing user to input ds directly
# and avoid some of the initial checks as needed.


def open_dataset(self, drop_vars: list) -> xr.Dataset:
"""Open an xarray dataset
Expand Down
125 changes: 65 additions & 60 deletions particle_tracking_manager/models/opendrift/opendrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,8 @@ def __init__(self, **kwargs):

# OpenDriftConfig, _KNOWN_MODELS = setup_opendrift_config(**kwargs)

# OpenDriftConfig is a subclass of PTMConfig so it knows about all the
# PTMConfig parameters. PTMConfig is run with OpenDriftConfig.
# OpenDriftConfig is a subclass of TheManagerConfig so it knows about all the
# TheManagerConfig parameters. TheManagerConfig is run with OpenDriftConfig.
# output_file was altered in PTM when setting up logger, so want to use
# that version.
# kwargs.update({"output_file": self.output_file})
Expand All @@ -154,7 +154,7 @@ def __init__(self, **kwargs):
keys_from_ocean_model = ["model_drop_vars", "ocean_model"]
inputs.update({key: getattr(self.ocean_model,key) for key in keys_from_ocean_model})
inputs.update(kwargs)
self.config = OpenDriftConfig(**inputs) # this runs both OpenDriftConfig and PTMConfig
self.config = OpenDriftConfig(**inputs) # this runs both OpenDriftConfig and TheManagerConfig
# logger = self.config.logger # this is where logger is expected to be found
# import pdb; pdb.set_trace()

Expand Down Expand Up @@ -226,21 +226,24 @@ def _create_opendrift_model_object(self):
self.o = o

def _update_od_config_from_this_config(self):
"""Update OpenDrift's config with OpenDriftConfig and PTMConfig.
"""Update OpenDrift's config with OpenDriftConfig and TheManagerConfig.
Update the default value in OpenDrift's config dict with the
config value from OpenDriftConfig (which includes PTMConfig).
config value from OpenDriftConfig, TheManagerConfig, OceanModelConfig, and SetupOutputFiles.
This uses the metadata key "od_mapping" to map from the PTM parameter
name to the OpenDrift parameter name.
"""

for key in self.config.model_fields:
if getattr(self.config.model_fields[key], "json_schema_extra") is not None:
if "od_mapping" in self.config.model_fields[key].json_schema_extra:
od_key = self.config.model_fields[key].json_schema_extra["od_mapping"]
if od_key in self.o._config:# and od_key is not None:
self.o._config[od_key]["value"] = getattr(self.config, key)
# import pdb; pdb.set_trace()

base_models_to_check = [self.files, self.manager_config, self.ocean_model, self.config]
for base_model in base_models_to_check:
for key in base_model.model_fields:
if getattr(base_model.model_fields[key], "json_schema_extra") is not None:
if "od_mapping" in base_model.model_fields[key].json_schema_extra:
od_key = base_model.model_fields[key].json_schema_extra["od_mapping"]
if od_key in self.o._config:# and od_key is not None:
self.o._config[od_key]["value"] = getattr(base_model, key)

def _modify_opendrift_model_object(self):

Expand Down Expand Up @@ -295,8 +298,8 @@ def add_reader(
# TODO: have standard_name_mapping as an initial input only with initial call to OpenDrift?
# TODO: has ds as an initial input for user-input ds?
if (
self.manager_config.ocean_model not in _KNOWN_MODELS
and self.manager_config.ocean_model != "test"
self.ocean_model.ocean_model not in _KNOWN_MODELS
and self.ocean_model.ocean_model != "test"
and ds is None
):
raise ValueError(
Expand All @@ -307,9 +310,9 @@ def add_reader(

if ds is not None:
if name is None:
self.manager_config.ocean_model = "user_input"
self.ocean_model.ocean_model = "user_input"
else:
self.manager_config.ocean_model = name
self.ocean_model.ocean_model = name

# TODO: do I still need a pathway for ocean_model of "test"?
# TODO: move tests from test_manager to other files
Expand All @@ -319,22 +322,22 @@ def add_reader(
ds = narrow_dataset_to_simulation_time(ds, self.manager_config.start_time, self.manager_config.end_time)
logger.info("Narrowed model output to simulation time")

ds = apply_known_ocean_model_specific_changes(ds, self.manager_config.ocean_model, self.manager_config.use_static_masks)
ds = apply_known_ocean_model_specific_changes(ds, self.ocean_model.ocean_model, self.manager_config.use_static_masks)

# TODO: the stuff in apply_user_input_ocean_model_specific_changes can be moved to OceanModelConfig
# validation I think
if self.manager_config.ocean_model not in _KNOWN_MODELS and self.manager_config.ocean_model != "test":
if self.ocean_model.ocean_model not in _KNOWN_MODELS and self.ocean_model.ocean_model != "test":
ds = apply_user_input_ocean_model_specific_changes(ds, self.manager_config.use_static_mask)

self.ds = ds

# if self.manager_config.ocean_model == "test":
# if self.ocean_model.ocean_model == "test":
# pass
# # oceanmodel_lon0_360 = True
# # loc = "test"
# # kwargs_xarray = dict()

# elif self.manager_config.ocean_model is not None or ds is not None:
# elif self.ocean_model.ocean_model is not None or ds is not None:
# # pass

# # TODO: should I change to computed_fields and where should this go?
Expand Down Expand Up @@ -404,7 +407,7 @@ def add_reader(
# # "Dropping mask_rho, mask_u, mask_v, mask_psi because using wetdry masks instead."
# # )

# # if self.manager_config.ocean_model == "NWGOA":
# # if self.ocean_model.ocean_model == "NWGOA":
# # oceanmodel_lon0_360 = True

# # standard_name_mapping.update(
Expand All @@ -425,7 +428,7 @@ def add_reader(
# # "snow_thick",
# # ]

# # if self.manager_config.ocean_model_local:
# # if self.ocean_model.ocean_model_local:

# # if self.config.start_time is None:
# # raise ValueError(
Expand All @@ -442,15 +445,15 @@ def add_reader(
# # "http://xpublish-nwgoa.srv.axds.co/datasets/nwgoa_all/zarr/"
# # )

# # elif "CIOFS" in self.manager_config.ocean_model:
# # elif "CIOFS" in self.ocean_model.ocean_model:
# # oceanmodel_lon0_360 = False

# # drop_vars += [
# # "wetdry_mask_psi",
# # ]
# # if self.manager_config.ocean_model == "CIOFS":
# # if self.ocean_model.ocean_model == "CIOFS":

# # if self.manager_config.ocean_model_local:
# # if self.ocean_model.ocean_model_local:

# # if self.config.start_time is None:
# # raise ValueError(
Expand All @@ -463,9 +466,9 @@ def add_reader(
# # )
# # loc_remote = "http://xpublish-ciofs.srv.axds.co/datasets/ciofs_hindcast/zarr/"

# # elif self.manager_config.ocean_model == "CIOFSFRESH":
# # elif self.ocean_model.ocean_model == "CIOFSFRESH":

# # if self.manager_config.ocean_model_local:
# # if self.ocean_model.ocean_model_local:

# # if self.config.start_time is None:
# # raise ValueError(
Expand All @@ -479,7 +482,7 @@ def add_reader(
# # )
# # loc_remote = None

# # elif self.manager_config.ocean_model == "CIOFSOP":
# # elif self.ocean_model.ocean_model == "CIOFSOP":

# # standard_name_mapping.update(
# # {
Expand All @@ -488,7 +491,7 @@ def add_reader(
# # }
# # )

# # if self.manager_config.ocean_model_local:
# # if self.ocean_model.ocean_model_local:

# # if self.config.start_time is None:
# # raise ValueError(
Expand All @@ -504,7 +507,7 @@ def add_reader(

# # loc_remote = "https://thredds.aoos.org/thredds/dodsC/AWS_CIOFS.nc"

# # if self.manager_config.ocean_model == "user_input":
# # if self.ocean_model.ocean_model == "user_input":

# # # check for case that self.config.use_static_masks False (which is the default)
# # # but user input doesn't have wetdry masks
Expand All @@ -518,7 +521,7 @@ def add_reader(

# # # if local and not a user-input ds
# # if ds is None:
# # if self.manager_config.ocean_model_local:
# # if self.ocean_model.ocean_model_local:

# # ds = xr.open_dataset(
# # self.config.loc_local,
Expand All @@ -536,7 +539,7 @@ def add_reader(
# # else:
# # if ".nc" in self.config.loc_remote:

# # if self.manager_config.ocean_model == "CIOFSFRESH":
# # if self.ocean_model.ocean_model == "CIOFSFRESH":
# # raise NotImplementedError

# # ds = xr.open_dataset(
Expand All @@ -558,12 +561,12 @@ def add_reader(
# # )

# # # For NWGOA, need to calculate wetdry mask from a variable
# # if self.manager_config.ocean_model == "NWGOA" and not self.config.use_static_masks:
# # if self.ocean_model.ocean_model == "NWGOA" and not self.config.use_static_masks:
# # ds["wetdry_mask_rho"] = (~ds.zeta.isnull()).astype(int)

# # # For CIOFSOP need to rename u/v to have "East" and "North" in the variable names
# # # so they aren't rotated in the ROMS reader (the standard names have to be x/y not east/north)
# # elif self.manager_config.ocean_model == "CIOFSOP":
# # elif self.ocean_model.ocean_model == "CIOFSOP":
# # ds = ds.rename_vars({"urot": "u_eastward", "vrot": "v_northward"})
# # # grid = xr.open_dataset("/mnt/vault/ciofs/HINDCAST/nos.ciofs.romsgrid.nc")
# # # ds["angle"] = grid["angle"]
Expand Down Expand Up @@ -611,15 +614,15 @@ def add_reader(
# # )
reader = reader_ROMS_native.Reader(
filename=ds,
name=self.manager_config.ocean_model,
name=self.ocean_model.ocean_model,
standard_name_mapping=self.ocean_model.standard_name_mapping,
save_interpolator=self.config.save_interpolator,
interpolator_filename=self.config.interpolator_filename,
)

self.o.add_reader([reader])
self.reader = reader
# can find reader at manager.o.env.readers[self.manager_config.ocean_model]
# can find reader at manager.o.env.readers[self.ocean_model.ocean_model]

# self.oceanmodel_lon0_360 = oceanmodel_lon0_360

Expand Down Expand Up @@ -794,7 +797,7 @@ def _model_config(self):
@property
def all_config(self):
"""Combined dict of this class config and OpenDrift native config."""

# TODO: update this
if self._all_config is None:
self._all_config = {**self._model_config, **self.config.dict()}
return self._all_config
Expand Down Expand Up @@ -910,12 +913,12 @@ def drift_model_config(self, prefix=""):

def show_all_config(
self,
key=None,
# key=None,
prefix="",
level=None,
ptm_level=None,
substring="",
excludestring="excludestring",
level=[1,2,3],
# ptm_level=None,
# substring="",
# excludestring="excludestring",
) -> dict:
"""Show configuring for the drift model selected in configuration.
Expand Down Expand Up @@ -989,31 +992,33 @@ def show_all_config(
"""

if key is not None:
prefix = key
return self.o.get_configspec(prefix=prefix, level=level)

output = self.get_configspec(
prefix=prefix,
level=level,
ptm_level=ptm_level,
substring=substring,
excludestring=excludestring,
)
import pdb; pdb.set_trace()
if key is not None:
if key in output:
return output[key]
else:
return output
else:
return output
# if key is not None:
# prefix = key

# output = self.get_configspec(
# prefix=prefix,
# level=level,
# ptm_level=ptm_level,
# substring=substring,
# excludestring=excludestring,
# )
# import pdb; pdb.set_trace()
# if key is not None:
# if key in output:
# return output[key]
# else:
# return output
# else:
# return output

def reader_metadata(self, key):
"""allow manager to query reader metadata."""

if not self.state.has_added_reader:
raise ValueError("reader has not been added yet.")
return self.o.env.readers[self.manager_config.ocean_model].__dict__[key]
return self.o.env.readers[self.ocean_model.ocean_model].__dict__[key]

# @property
# def outfile_name(self):
Expand Down
36 changes: 24 additions & 12 deletions tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,28 +374,40 @@ def test_misc_parameters():
"""Test values of parameters being input."""

m = TestParticleTrackingManager(steps=1, start_time="2022-01-01",
horizontal_diffusivity=1,
# horizontal_diffusivity=1,
number=100, time_step=5,
wind_drift_factor=0.04,
stokes_drift=False, log="DEBUG",)

assert m.manager_config.horizontal_diffusivity == 1
# assert m.manager_config.horizontal_diffusivity == 1
assert m.manager_config.number == 100
assert m.manager_config.time_step == 5
assert m.manager_config.wind_drift_factor == 0.04


def test_horizontal_diffusivity_logic():
"""Check logic for using default horizontal diff values for known models."""
# def test_horizontal_diffusivity_logic():
# """Check logic for using default horizontal diff values for known models."""

m = TestParticleTrackingManager(ocean_model="NWGOA", steps=1, start_time="2007-01-01")
assert m.manager_config.horizontal_diffusivity == 150.0 # known grid values
# m = TestParticleTrackingManager(ocean_model="NWGOA", steps=1, start_time="2007-01-01")
# assert m.manager_config.horizontal_diffusivity == 150.0 # known grid values

m = TestParticleTrackingManager(ocean_model="CIOFS", steps=1, start_time="2020-01-01")
assert m.manager_config.horizontal_diffusivity == 10.0 # known grid values
# m = TestParticleTrackingManager(ocean_model="CIOFS", steps=1, start_time="2020-01-01")
# assert m.manager_config.horizontal_diffusivity == 10.0 # known grid values

m = TestParticleTrackingManager(ocean_model="CIOFSOP", horizontal_diffusivity=11, steps=1)
assert m.manager_config.horizontal_diffusivity == 11.0 # user-selected value
# m = TestParticleTrackingManager(ocean_model="CIOFSOP", horizontal_diffusivity=11, steps=1)
# assert m.manager_config.horizontal_diffusivity == 11.0 # user-selected value

m = TestParticleTrackingManager(ocean_model="CIOFSOP", steps=1)
assert m.manager_config.horizontal_diffusivity == 10.0 # known grid values
# m = TestParticleTrackingManager(ocean_model="CIOFSOP", steps=1)
# assert m.manager_config.horizontal_diffusivity == 10.0 # known grid values




def test_output_file():
"""make sure output file is parquet if output_format is parquet"""

m = TestParticleTrackingManager(output_format="parquet", steps=1)
assert m.files.output_file.endswith(".parquet")

m = TestParticleTrackingManager(output_format="netcdf", steps=1)
assert m.files.output_file.endswith(".nc")
Loading

0 comments on commit 4c60fb2

Please sign in to comment.