Skip to content

Commit

Permalink
WIP: have separated ocean_model pretty well
Browse files Browse the repository at this point in the history
Next is to finish going through test_opendrift.py and fix the opendrift.py config.
Also need to figure out how to handle the case where a user inputs their own model Dataset.
  • Loading branch information
kthyng committed Mar 4, 2025
1 parent abdb4a3 commit 313288c
Show file tree
Hide file tree
Showing 11 changed files with 375 additions and 751 deletions.
405 changes: 0 additions & 405 deletions particle_tracking_manager/config.py

This file was deleted.

2 changes: 1 addition & 1 deletion particle_tracking_manager/config_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pydantic.fields import FieldInfo
from typing_extensions import Self

from .utils import calc_known_horizontal_diffusivity
# from .utils import calc_known_horizontal_diffusivity
from .models.opendrift.utils import make_nwgoa_kerchunk, make_ciofs_kerchunk
from .config_the_manager import TheManagerConfig

Expand Down
288 changes: 211 additions & 77 deletions particle_tracking_manager/config_ocean_model.py

Large diffs are not rendered by default.

34 changes: 17 additions & 17 deletions particle_tracking_manager/config_the_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pydantic.fields import FieldInfo
from typing_extensions import Self

from .utils import calc_known_horizontal_diffusivity
# from .utils import calc_known_horizontal_diffusivity
from .models.opendrift.utils import make_nwgoa_kerchunk, make_ciofs_kerchunk
import logging

Expand Down Expand Up @@ -53,30 +53,30 @@ class SeedFlagEnum(str, Enum):
# CRITICAL = "CRITICAL"


# Enum for "ocean_model"
class OceanModelEnum(str, Enum):
NWGOA = "NWGOA"
CIOFS = "CIOFS"
CIOFSOP = "CIOFSOP"
CIOFSFRESH = "CIOFSFRESH"
# # Enum for "ocean_model"
# class OceanModelEnum(str, Enum):
# NWGOA = "NWGOA"
# CIOFS = "CIOFS"
# CIOFSOP = "CIOFSOP"
# CIOFSFRESH = "CIOFSFRESH"


class TheManagerConfig(BaseModel):
model: ModelEnum = Field(ModelEnum.opendrift, description="Lagrangian model software to use for simulation.", ptm_level=1)
lon: float = Field(-151.0, ge=-180, le=180, description="Central longitude for seeding drifters. Only used if `seed_flag==\"elements\"`.", ptm_level=1, units="degrees_east")
lat: float = Field(58.0, ge=-90, le=90, description="Central latitude for seeding drifters. Only used if `seed_flag==\"elements\"`.", ptm_level=1, units="degrees_north")
lon: Optional[float] = Field(-151.0, ge=-180, le=180, description="Central longitude for seeding drifters. Only used if `seed_flag==\"elements\"`.", ptm_level=1, units="degrees_east")
lat: Optional[float] = Field(58.0, ge=-90, le=90, description="Central latitude for seeding drifters. Only used if `seed_flag==\"elements\"`.", ptm_level=1, units="degrees_north")
geojson: Optional[dict] = Field(None, description="GeoJSON describing a polygon within which to seed drifters. To use this parameter, also have `seed_flag==\"geojson\"`.", ptm_level=1)
seed_flag: SeedFlagEnum = Field(SeedFlagEnum.elements, description="Method for seeding drifters. Options are \"elements\" or \"geojson\". If \"elements\", seed drifters at or around a single point defined by lon and lat. If \"geojson\", seed drifters within a polygon described by a GeoJSON object.", ptm_level=1)
number: int = Field(100, description="Number of drifters to seed.", ptm_level=1, od_mapping="seed:number")
start_time: datetime = Field(datetime(2022,1,1), description="Start time for drifter simulation.", ptm_level=1)
start_time: Optional[datetime] = Field(datetime(2022,1,1), description="Start time for drifter simulation.", ptm_level=1)
start_time_end: Optional[datetime] = Field(None, description="If used, this creates a range of start times for drifters, starting with `start_time` and ending with `start_time_end`. Drifters will be initialized linearly between the two start times.", ptm_level=2)
run_forward: bool = Field(True, description="Run forward in time.", ptm_level=2)
time_step: int = Field(300, ge=1, le=86400, description="Interval between particles updates, in seconds.", ptm_level=3, units="seconds")
time_step_output: int = Field(3600, ge=1, le=604800, description="Time step at which element properties are stored and eventually written to file. Must be a multiple of time_step.", ptm_level=3, units="seconds")
steps: Optional[int] = Field(None, ge=1, le=10000, description="Maximum number of steps. End of simulation will be start_time + steps * time_step.", ptm_level=1)
duration: Optional[timedelta] = Field(None, description="The length of the simulation. steps, end_time, or duration must be input by user.", ptm_level=1)
end_time: Optional[datetime] = Field(None, description="The end of the simulation. steps, end_time, or duration must be input by user.", ptm_level=1)
ocean_model: OceanModelEnum = Field(OceanModelEnum.CIOFSOP, description="Name of ocean model to use for driving drifter simulation.", ptm_level=1)
# ocean_model: OceanModelEnum = Field(OceanModelEnum.CIOFSOP, description="Name of ocean model to use for driving drifter simulation.", ptm_level=1)
# NWGOA_time_range: Optional[datetime] = Field(None, description="Time range for NWGOA ocean model.", ptm_level=1)
# CIOFS_time_range: Optional[datetime] = Field(None, description="Time range for CIOFS ocean model.", ptm_level=1)
# CIOFSOP_time_range: Optional[datetime] = Field(None, description="Time range for CIOFSOP ocean model.", ptm_level=1)
Expand All @@ -90,15 +90,15 @@ class TheManagerConfig(BaseModel):
surface_only: Optional[bool] = Field(None, description="Set to True to keep drifters at the surface.", ptm_level=1)
do3D: bool = Field(False, description="Set to True to run drifters in 3D, by default False. This is overridden if surface_only==True.", ptm_level=1)
vertical_mixing: bool = Field(False, description="Set to True to activate vertical mixing in the simulation.", ptm_level=2)
z: float = Field(0, ge=-100000, le=0, description="Depth of the drifters.", ptm_level=1, od_mapping="seed:z")
z: Optional[float] = Field(0, ge=-100000, le=0, description="Depth of the drifters. None to use `seed_seafloor` flag.", ptm_level=1, od_mapping="seed:z")
seed_seafloor: bool = Field(False, description="Set to True to seed drifters on the seafloor.", ptm_level=2, od_mapping="seed:seafloor")
use_static_masks: bool = Field(True, description="Set to True to use static masks for known models instead of wetdry masks.", ptm_level=3)
# output_file: Optional[str] = Field(None, description="Name of file to write output to. If None, default name is used.", ptm_level=3)
# output_format: OutputFormatEnum = Field(OutputFormatEnum.netcdf, description="Output file format. Options are \"netcdf\" or \"parquet\".", ptm_level=2)
use_cache: bool = Field(True, description="Set to True to use cache for storing interpolators.", ptm_level=3)
wind_drift_factor: float = Field(0.02, description="Wind drift factor for the drifters.", ptm_level=2, od_mapping="seed:wind_drift_factor")
wind_drift_factor: Optional[float] = Field(0.02, description="Wind drift factor for the drifters.", ptm_level=2, od_mapping="seed:wind_drift_factor")
stokes_drift: bool = Field(True, description="Set to True to enable Stokes drift.", ptm_level=2, od_mapping="drift:stokes_drift")
horizontal_diffusivity: Optional[float] = Field(None, description="Horizontal diffusivity for the simulation.", ptm_level=2, od_mapping="drift:horizontal_diffusivity")
# horizontal_diffusivity: Optional[float] = Field(None, description="Horizontal diffusivity for the simulation.", ptm_level=2, od_mapping="drift:horizontal_diffusivity")
# log_level: LogLevelEnum = Field(LogLevelEnum.INFO, description="Log verbosity", ptm_level=3)

class Config:
Expand Down Expand Up @@ -257,7 +257,7 @@ def check_config_do3D(self) -> Self:

# elif self.ocean_model in _KNOWN_MODELS:

# hdiff = 0 # TODO: UPDATE THIS BACK calc_known_horizontal_diffusivity(self.ocean_model)
# hdiff = calc_known_horizontal_diffusivity(self.ocean_model)
# logger.info(
# f"Setting horizontal_diffusivity parameter to one tuned to reader model of value {hdiff}."
# )
Expand All @@ -281,11 +281,11 @@ def check_config_do3D(self) -> Self:
def check_config_ocean_model_local(self) -> Self:
if self.ocean_model_local:
logger.info(
f"Using local output for ocean_model {self.ocean_model}"
"Using local output for ocean_model."
)
else:
logger.info(
f"Using remote output for ocean_model {self.ocean_model}"
"Using remote output for ocean_model."
)
return self

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class OpenDriftConfig(BaseModel):

wind_uncertainty: float = Field(default=0.0, value=0.0, od_mapping="drift:wind_uncertainty", ptm_level=2)

wind_drift_depth: float = Field(default=0.02, od_mapping="drift:wind_drift_depth", ptm_level=3)
wind_drift_depth: Optional[float] = Field(default=0.02, od_mapping="drift:wind_drift_depth", ptm_level=3)

vertical_mixing_timestep: int = Field(default=60, od_mapping="vertical_mixing:timestep", ptm_level=3)

Expand Down
41 changes: 7 additions & 34 deletions particle_tracking_manager/models/opendrift/opendrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from opendrift.readers import reader_ROMS_native

# from ...config_replacement import OpenDriftConfig
from ...config_ocean_model import _KNOWN_MODELS
from .config_opendrift import OpenDriftConfig
from ...the_manager import ParticleTrackingManager
from ...config_logging import LoggerMethods
Expand Down Expand Up @@ -148,15 +149,14 @@ def __init__(self, **kwargs):
# output_file was altered in PTM when setting up logger, so want to use
# that version.
# kwargs.update({"output_file": self.output_file})
keys_from_the_manager = ["use_cache", "stokes_drift", "do3D", "wind_drift_factor", "use_static_masks", "vertical_mixing", "ocean_model"]
keys_from_the_manager = ["use_cache", "stokes_drift", "do3D", "wind_drift_factor", "use_static_masks", "vertical_mixing"]
inputs = {key: getattr(self.manager_config,key) for key in keys_from_the_manager}
keys_from_ocean_model = ["model_drop_vars"]
keys_from_ocean_model = ["model_drop_vars", "ocean_model"]
inputs.update({key: getattr(self.ocean_model,key) for key in keys_from_ocean_model})
inputs.update(kwargs)
self.config = OpenDriftConfig(**inputs) # this runs both OpenDriftConfig and PTMConfig
# logger = self.config.logger # this is where logger is expected to be found
# import pdb; pdb.set_trace()

self._KNOWN_MODELS = self.manager_config.model_json_schema()['$defs']['OceanModelEnum']["enum"]

# self._setup_interpolator()

Expand All @@ -170,6 +170,7 @@ def __init__(self, **kwargs):

# LoggerMethods().merge_with_opendrift_log(logger)

# TODO: move these so they aren't initialized during __init__
self._create_opendrift_model_object()
self._update_od_config_from_this_config()
self._modify_opendrift_model_object()
Expand All @@ -189,37 +190,11 @@ def __init__(self, **kwargs):
self.checked_plot = False


# def _setup_interpolator(self):
# """Setup interpolator."""
# # TODO: this isn't working correctly at the moment

# if self.config.use_cache:
# # TODO: fix this for Ahmad
# cache_dir = Path(appdirs.user_cache_dir(appname="particle-tracking-manager", appauthor="axiom-data-science"))
# cache_dir.mkdir(parents=True, exist_ok=True)
# if self.config.interpolator_filename is None:
# self.config.interpolator_filename = cache_dir / Path(f"{self.manager_config.ocean_model.name}_interpolator").with_suffix(".pickle")
# else:
# self.config.interpolator_filename = Path(self.config.interpolator_filename).with_suffix(".pickle")
# self.save_interpolator = True

# # change interpolator_filename to string
# self.config.interpolator_filename = str(self.config.interpolator_filename)

# if Path(self.config.interpolator_filename).exists():
# logger.info(f"Loading the interpolator from {self.config.interpolator_filename}.")
# else:
# logger.info(f"A new interpolator will be saved to {self.config.interpolator_filename}.")
# else:
# self.save_interpolator = False
# logger.info("Interpolators will not be saved.")

def _create_opendrift_model_object(self):
# do this right away so I can query the object
# we don't actually input output_format here because we first output to netcdf, then
# resave as parquet after adding in extra config
# TODO: should drift_model be instantiated in OpenDriftConfig or here?
# import pdb; pdb.set_trace()
log_level = logger.level
if self.config.drift_model == "Leeway":
from opendrift.models.leeway import Leeway
Expand Down Expand Up @@ -259,7 +234,6 @@ def _update_od_config_from_this_config(self):
This uses the metadata key "od_mapping" to map from the PTM parameter
name to the OpenDrift parameter name.
"""
# import pdb; pdb.set_trace()

for key in self.config.model_fields:
if getattr(self.config.model_fields[key], "json_schema_extra") is not None:
Expand All @@ -269,7 +243,6 @@ def _update_od_config_from_this_config(self):
self.o._config[od_key]["value"] = getattr(self.config, key)

def _modify_opendrift_model_object(self):
# import pdb; pdb.set_trace()

# TODO: where to put these things
# turn on other things if using stokes_drift
Expand Down Expand Up @@ -322,7 +295,7 @@ def add_reader(
# TODO: have standard_name_mapping as an initial input only with initial call to OpenDrift?
# TODO: has ds as an initial input for user-input ds?
if (
self.manager_config.ocean_model not in self._KNOWN_MODELS
self.manager_config.ocean_model not in _KNOWN_MODELS
and self.manager_config.ocean_model != "test"
and ds is None
):
Expand Down Expand Up @@ -350,7 +323,7 @@ def add_reader(

# TODO: the stuff in apply_user_input_ocean_model_specific_changes can be moved to OceanModelConfig
# validation I think
if self.manager_config.ocean_model not in self._KNOWN_MODELS and self.manager_config.ocean_model != "test":
if self.manager_config.ocean_model not in _KNOWN_MODELS and self.manager_config.ocean_model != "test":
ds = apply_user_input_ocean_model_specific_changes(ds, self.manager_config.use_static_mask)

self.ds = ds
Expand Down
81 changes: 48 additions & 33 deletions particle_tracking_manager/models/opendrift/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pathlib import Path
import xarray as xr
import datetime
from datetime import datetime
import logging
import pandas as pd

Expand All @@ -14,7 +14,7 @@
from kerchunk.combine import MultiZarrToZarr


def narrow_dataset_to_simulation_time(ds: xr.Dataset, start_time: datetime.datetime, end_time: datetime.datetime) -> xr.Dataset:
def narrow_dataset_to_simulation_time(ds: xr.Dataset, start_time: datetime, end_time: datetime) -> xr.Dataset:
"""Narrow the dataset to the simulation time."""
try:
units = ds.ocean_time.attrs["units"]
Expand Down Expand Up @@ -114,31 +114,37 @@ def make_ciofs_kerchunk(start, end, name):

fs2 = fsspec.filesystem("") # local file system to save final jsons to

# select the single file Jsons to combine
# json_list = sorted(
# fs2.glob(f"{output_dir_single_files}/*.json")
# ) # combine single json files

if name in ["ciofs", "ciofs_fresh"]:
json_list = fs2.glob(f"{output_dir_single_files}/*.json")
# json_list = sorted(
# fs2.glob(f"{output_dir_single_files}/*.json")
# ) # combine single json files

# base for matching
def base_str(a_time):
    """Return the glob pattern matching single-file JSONs whose stem begins with *a_time*.

    CIOFS / CIOFS-fresh file stems have the form ``<time>_*`` (e.g. ``2003_0123``).
    """
    pattern = f"{output_dir_single_files}/{a_time}_*.json"
    return pattern
date_format = "%Y_0%j"

elif name == "aws_ciofs_with_angle":

# base for matching
def base_str(a_time):
    """Return the glob pattern for AWS CIOFS single-file JSONs for *a_time*.

    File stems have the form ``ciofs_<time>-*`` (e.g. ``ciofs_2023-01-01``).
    """
    pattern = f"{output_dir_single_files}/ciofs_{a_time}-*.json"
    return pattern
date_format = "ciofs_%Y-%m-%d"
else:
raise ValueError(f"Name {name} not recognized")

# only glob start and end year files, order isn't important
json_list = fs2.glob(base_str(start[:4]))
if end[:4] != start[:4]:
json_list += fs2.glob(base_str(end[:4]))

# forward in time
if end[:4] > start[:4]:
json_list = [
j for j in json_list if Path(j).stem >= start and Path(j).stem <= end
j for j in json_list if datetime.strptime(Path(j).stem, date_format).isoformat() >= start and datetime.strptime(Path(j).stem, date_format).isoformat() <= end
]
elif name == "aws_ciofs_with_angle":
json_list = fs2.glob(f"{output_dir_single_files}/ciofs_*.json")
# json_list = sorted(
# fs2.glob(f"{output_dir_single_files}/ciofs_*.json")
# ) # combine single json files
# backward in time
elif end[:4] < start[:4]:
json_list = [
j
for j in json_list
if Path(j).stem.split("_")[1] >= start and Path(j).stem.split("_")[1] <= end
j for j in json_list if datetime.strptime(Path(j).stem, date_format).isoformat() <= start and datetime.strptime(Path(j).stem, date_format).isoformat() >= end
]
else:
raise ValueError(f"Name {name} not recognized")

if json_list == []:
raise ValueError(
Expand Down Expand Up @@ -278,17 +284,26 @@ def make_nwgoa_kerchunk(start, end):

fs2 = fsspec.filesystem("") # local file system to save final jsons to

# select the single file Jsons to combine
json_list = fs2.glob(f"{output_dir_single_files}/nwgoa*.json") # combine single json files
# json_list = sorted(
# fs2.glob(f"{output_dir_single_files}/nwgoa*.json")
# ) # combine single json files
json_list = [
j
for j in json_list
if Path(j).stem.split("nwgoa_")[1] >= start
and Path(j).stem.split("nwgoa_")[1] <= end
]
# base for matching
def base_str(a_time):
    """Return the glob pattern for NWGOA single-file JSONs for *a_time*.

    File stems have the form ``nwgoa_<time>-*`` (e.g. ``nwgoa_1999-02-01``).
    """
    pattern = f"{output_dir_single_files}/nwgoa_{a_time}-*.json"
    return pattern
date_format = "nwgoa_%Y-%m-%d"

# only glob start and end year files, order isn't important
json_list = fs2.glob(base_str(start[:4]))
if end[:4] != start[:4]:
json_list += fs2.glob(base_str(end[:4]))

# forward in time
if end[:4] > start[:4]:
json_list = [
j for j in json_list if datetime.strptime(Path(j).stem, date_format).isoformat() >= start and datetime.strptime(Path(j).stem, date_format).isoformat() <= end
]
# backward in time
elif end[:4] < start[:4]:
json_list = [
j for j in json_list if datetime.strptime(Path(j).stem, date_format).isoformat() <= start and datetime.strptime(Path(j).stem, date_format).isoformat() >= end
]

if json_list == []:
raise ValueError(
Expand Down
Loading

0 comments on commit 313288c

Please sign in to comment.