Skip to content


WIP: have separated ocean_model pretty well
Browse files Browse the repository at this point in the history
Next is to finish going through and fix the config.
Also need to figure out how to handle when user inputs their own model Dataset.
  • Loading branch information
kthyng committed Mar 4, 2025
1 parent abdb4a3 commit 313288c
Show file tree
Hide file tree
Showing 11 changed files with 375 additions and 751 deletions.
405 changes: 0 additions & 405 deletions particle_tracking_manager/

This file was deleted.

2 changes: 1 addition & 1 deletion particle_tracking_manager/
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from pydantic.fields import FieldInfo
from typing_extensions import Self

from .utils import calc_known_horizontal_diffusivity
# from .utils import calc_known_horizontal_diffusivity
from .models.opendrift.utils import make_nwgoa_kerchunk, make_ciofs_kerchunk
from .config_the_manager import TheManagerConfig

Expand Down
288 changes: 211 additions & 77 deletions particle_tracking_manager/

Large diffs are not rendered by default.

34 changes: 17 additions & 17 deletions particle_tracking_manager/
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from pydantic.fields import FieldInfo
from typing_extensions import Self

from .utils import calc_known_horizontal_diffusivity
# from .utils import calc_known_horizontal_diffusivity
from .models.opendrift.utils import make_nwgoa_kerchunk, make_ciofs_kerchunk
import logging

Expand Down Expand Up @@ -53,30 +53,30 @@ class SeedFlagEnum(str, Enum):

# Enum for "ocean_model"
class OceanModelEnum(str, Enum):
# # Enum for "ocean_model"
# class OceanModelEnum(str, Enum):

class TheManagerConfig(BaseModel):
model: ModelEnum = Field(ModelEnum.opendrift, description="Lagrangian model software to use for simulation.", ptm_level=1)
lon: float = Field(-151.0, ge=-180, le=180, description="Central longitude for seeding drifters. Only used if `seed_flag==\"elements\"`.", ptm_level=1, units="degrees_east")
lat: float = Field(58.0, ge=-90, le=90, description="Central latitude for seeding drifters. Only used if `seed_flag==\"elements\"`.", ptm_level=1, units="degrees_north")
lon: Optional[float] = Field(-151.0, ge=-180, le=180, description="Central longitude for seeding drifters. Only used if `seed_flag==\"elements\"`.", ptm_level=1, units="degrees_east")
lat: Optional[float] = Field(58.0, ge=-90, le=90, description="Central latitude for seeding drifters. Only used if `seed_flag==\"elements\"`.", ptm_level=1, units="degrees_north")
geojson: Optional[dict] = Field(None, description="GeoJSON describing a polygon within which to seed drifters. To use this parameter, also have `seed_flag==\"geojson\"`.", ptm_level=1)
seed_flag: SeedFlagEnum = Field(SeedFlagEnum.elements, description="Method for seeding drifters. Options are \"elements\" or \"geojson\". If \"elements\", seed drifters at or around a single point defined by lon and lat. If \"geojson\", seed drifters within a polygon described by a GeoJSON object.", ptm_level=1)
number: int = Field(100, description="Number of drifters to seed.", ptm_level=1, od_mapping="seed:number")
start_time: datetime = Field(datetime(2022,1,1), description="Start time for drifter simulation.", ptm_level=1)
start_time: Optional[datetime] = Field(datetime(2022,1,1), description="Start time for drifter simulation.", ptm_level=1)
start_time_end: Optional[datetime] = Field(None, description="If used, this creates a range of start times for drifters, starting with `start_time` and ending with `start_time_end`. Drifters will be initialized linearly between the two start times.", ptm_level=2)
run_forward: bool = Field(True, description="Run forward in time.", ptm_level=2)
time_step: int = Field(300, ge=1, le=86400, description="Interval between particles updates, in seconds.", ptm_level=3, units="seconds")
time_step_output: int = Field(3600, ge=1, le=604800, description="Time step at which element properties are stored and eventually written to file. Must be a multiple of time_step.", ptm_level=3, units="seconds")
steps: Optional[int] = Field(None, ge=1, le=10000, description="Maximum number of steps. End of simulation will be start_time + steps * time_step.", ptm_level=1)
duration: Optional[timedelta] = Field(None, description="The length of the simulation. steps, end_time, or duration must be input by user.", ptm_level=1)
end_time: Optional[datetime] = Field(None, description="The end of the simulation. steps, end_time, or duration must be input by user.", ptm_level=1)
ocean_model: OceanModelEnum = Field(OceanModelEnum.CIOFSOP, description="Name of ocean model to use for driving drifter simulation.", ptm_level=1)
# ocean_model: OceanModelEnum = Field(OceanModelEnum.CIOFSOP, description="Name of ocean model to use for driving drifter simulation.", ptm_level=1)
# NWGOA_time_range: Optional[datetime] = Field(None, description="Time range for NWGOA ocean model.", ptm_level=1)
# CIOFS_time_range: Optional[datetime] = Field(None, description="Time range for CIOFS ocean model.", ptm_level=1)
# CIOFSOP_time_range: Optional[datetime] = Field(None, description="Time range for CIOFSOP ocean model.", ptm_level=1)
Expand All @@ -90,15 +90,15 @@ class TheManagerConfig(BaseModel):
surface_only: Optional[bool] = Field(None, description="Set to True to keep drifters at the surface.", ptm_level=1)
do3D: bool = Field(False, description="Set to True to run drifters in 3D, by default False. This is overridden if surface_only==True.", ptm_level=1)
vertical_mixing: bool = Field(False, description="Set to True to activate vertical mixing in the simulation.", ptm_level=2)
z: float = Field(0, ge=-100000, le=0, description="Depth of the drifters.", ptm_level=1, od_mapping="seed:z")
z: Optional[float] = Field(0, ge=-100000, le=0, description="Depth of the drifters. None to use `seed_seafloor` flag.", ptm_level=1, od_mapping="seed:z")
seed_seafloor: bool = Field(False, description="Set to True to seed drifters on the seafloor.", ptm_level=2, od_mapping="seed:seafloor")
use_static_masks: bool = Field(True, description="Set to True to use static masks for known models instead of wetdry masks.", ptm_level=3)
# output_file: Optional[str] = Field(None, description="Name of file to write output to. If None, default name is used.", ptm_level=3)
# output_format: OutputFormatEnum = Field(OutputFormatEnum.netcdf, description="Output file format. Options are \"netcdf\" or \"parquet\".", ptm_level=2)
use_cache: bool = Field(True, description="Set to True to use cache for storing interpolators.", ptm_level=3)
wind_drift_factor: float = Field(0.02, description="Wind drift factor for the drifters.", ptm_level=2, od_mapping="seed:wind_drift_factor")
wind_drift_factor: Optional[float] = Field(0.02, description="Wind drift factor for the drifters.", ptm_level=2, od_mapping="seed:wind_drift_factor")
stokes_drift: bool = Field(True, description="Set to True to enable Stokes drift.", ptm_level=2, od_mapping="drift:stokes_drift")
horizontal_diffusivity: Optional[float] = Field(None, description="Horizontal diffusivity for the simulation.", ptm_level=2, od_mapping="drift:horizontal_diffusivity")
# horizontal_diffusivity: Optional[float] = Field(None, description="Horizontal diffusivity for the simulation.", ptm_level=2, od_mapping="drift:horizontal_diffusivity")
# log_level: LogLevelEnum = Field(LogLevelEnum.INFO, description="Log verbosity", ptm_level=3)

class Config:
Expand Down Expand Up @@ -257,7 +257,7 @@ def check_config_do3D(self) -> Self:

# elif self.ocean_model in _KNOWN_MODELS:

# hdiff = 0 # TODO: UPDATE THIS BACK calc_known_horizontal_diffusivity(self.ocean_model)
# hdiff = calc_known_horizontal_diffusivity(self.ocean_model)
# f"Setting horizontal_diffusivity parameter to one tuned to reader model of value {hdiff}."
# )
Expand All @@ -281,11 +281,11 @@ def check_config_do3D(self) -> Self:
def check_config_ocean_model_local(self) -> Self:
if self.ocean_model_local:
f"Using local output for ocean_model {self.ocean_model}"
"Using local output for ocean_model."
f"Using remote output for ocean_model {self.ocean_model}"
"Using remote output for ocean_model."
return self

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class OpenDriftConfig(BaseModel):

wind_uncertainty: float = Field(default=0.0, value=0.0, od_mapping="drift:wind_uncertainty", ptm_level=2)

wind_drift_depth: float = Field(default=0.02, od_mapping="drift:wind_drift_depth", ptm_level=3)
wind_drift_depth: Optional[float] = Field(default=0.02, od_mapping="drift:wind_drift_depth", ptm_level=3)

vertical_mixing_timestep: int = Field(default=60, od_mapping="vertical_mixing:timestep", ptm_level=3)

Expand Down
41 changes: 7 additions & 34 deletions particle_tracking_manager/models/opendrift/
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from opendrift.readers import reader_ROMS_native

# from ...config_replacement import OpenDriftConfig
from ...config_ocean_model import _KNOWN_MODELS
from .config_opendrift import OpenDriftConfig
from ...the_manager import ParticleTrackingManager
from ...config_logging import LoggerMethods
Expand Down Expand Up @@ -148,15 +149,14 @@ def __init__(self, **kwargs):
# output_file was altered in PTM when setting up logger, so want to use
# that version.
# kwargs.update({"output_file": self.output_file})
keys_from_the_manager = ["use_cache", "stokes_drift", "do3D", "wind_drift_factor", "use_static_masks", "vertical_mixing", "ocean_model"]
keys_from_the_manager = ["use_cache", "stokes_drift", "do3D", "wind_drift_factor", "use_static_masks", "vertical_mixing"]
inputs = {key: getattr(self.manager_config,key) for key in keys_from_the_manager}
keys_from_ocean_model = ["model_drop_vars"]
keys_from_ocean_model = ["model_drop_vars", "ocean_model"]
inputs.update({key: getattr(self.ocean_model,key) for key in keys_from_ocean_model})
self.config = OpenDriftConfig(**inputs) # this runs both OpenDriftConfig and PTMConfig
# logger = self.config.logger # this is where logger is expected to be found
# import pdb; pdb.set_trace()

self._KNOWN_MODELS = self.manager_config.model_json_schema()['$defs']['OceanModelEnum']["enum"]

# self._setup_interpolator()

Expand All @@ -170,6 +170,7 @@ def __init__(self, **kwargs):

# LoggerMethods().merge_with_opendrift_log(logger)

# TODO: move these so they aren't initialized during __init__
Expand All @@ -189,37 +190,11 @@ def __init__(self, **kwargs):
self.checked_plot = False

# def _setup_interpolator(self):
# """Setup interpolator."""
# # TODO: this isn't working correctly at the moment

# if self.config.use_cache:
# # TODO: fix this for Ahmad
# cache_dir = Path(appdirs.user_cache_dir(appname="particle-tracking-manager", appauthor="axiom-data-science"))
# cache_dir.mkdir(parents=True, exist_ok=True)
# if self.config.interpolator_filename is None:
# self.config.interpolator_filename = cache_dir / Path(f"{}_interpolator").with_suffix(".pickle")
# else:
# self.config.interpolator_filename = Path(self.config.interpolator_filename).with_suffix(".pickle")
# self.save_interpolator = True

# # change interpolator_filename to string
# self.config.interpolator_filename = str(self.config.interpolator_filename)

# if Path(self.config.interpolator_filename).exists():
#"Loading the interpolator from {self.config.interpolator_filename}.")
# else:
#"A new interpolator will be saved to {self.config.interpolator_filename}.")
# else:
# self.save_interpolator = False
#"Interpolators will not be saved.")

def _create_opendrift_model_object(self):
# do this right away so I can query the object
# we don't actually input output_format here because we first output to netcdf, then
# resave as parquet after adding in extra config
# TODO: should drift_model be instantiated in OpenDriftConfig or here?
# import pdb; pdb.set_trace()
log_level = logger.level
if self.config.drift_model == "Leeway":
from opendrift.models.leeway import Leeway
Expand Down Expand Up @@ -259,7 +234,6 @@ def _update_od_config_from_this_config(self):
This uses the metadata key "od_mapping" to map from the PTM parameter
name to the OpenDrift parameter name.
# import pdb; pdb.set_trace()

for key in self.config.model_fields:
if getattr(self.config.model_fields[key], "json_schema_extra") is not None:
Expand All @@ -269,7 +243,6 @@ def _update_od_config_from_this_config(self):
self.o._config[od_key]["value"] = getattr(self.config, key)

def _modify_opendrift_model_object(self):
# import pdb; pdb.set_trace()

# TODO: where to put these things
# turn on other things if using stokes_drift
Expand Down Expand Up @@ -322,7 +295,7 @@ def add_reader(
# TODO: have standard_name_mapping as an initial input only with initial call to OpenDrift?
# TODO: have ds as an initial input for user-input ds?
if (
self.manager_config.ocean_model not in self._KNOWN_MODELS
self.manager_config.ocean_model not in _KNOWN_MODELS
and self.manager_config.ocean_model != "test"
and ds is None
Expand Down Expand Up @@ -350,7 +323,7 @@ def add_reader(

# TODO: the stuff in apply_user_input_ocean_model_specific_changes can be moved to OceanModelConfig
# validation I think
if self.manager_config.ocean_model not in self._KNOWN_MODELS and self.manager_config.ocean_model != "test":
if self.manager_config.ocean_model not in _KNOWN_MODELS and self.manager_config.ocean_model != "test":
ds = apply_user_input_ocean_model_specific_changes(ds, self.manager_config.use_static_mask)

self.ds = ds
Expand Down
81 changes: 48 additions & 33 deletions particle_tracking_manager/models/opendrift/
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from pathlib import Path
import xarray as xr
import datetime
from datetime import datetime
import logging
import pandas as pd

Expand All @@ -14,7 +14,7 @@
from kerchunk.combine import MultiZarrToZarr

def narrow_dataset_to_simulation_time(ds: xr.Dataset, start_time: datetime.datetime, end_time: datetime.datetime) -> xr.Dataset:
def narrow_dataset_to_simulation_time(ds: xr.Dataset, start_time: datetime, end_time: datetime) -> xr.Dataset:
"""Narrow the dataset to the simulation time."""
units = ds.ocean_time.attrs["units"]
Expand Down Expand Up @@ -114,31 +114,37 @@ def make_ciofs_kerchunk(start, end, name):

fs2 = fsspec.filesystem("") # local file system to save final jsons to

# select the single file Jsons to combine
# json_list = sorted(
# fs2.glob(f"{output_dir_single_files}/*.json")
# ) # combine single json files

if name in ["ciofs", "ciofs_fresh"]:
json_list = fs2.glob(f"{output_dir_single_files}/*.json")
# json_list = sorted(
# fs2.glob(f"{output_dir_single_files}/*.json")
# ) # combine single json files

# base for matching
def base_str(a_time):
return f"{output_dir_single_files}/{a_time}_*.json"
date_format = "%Y_0%j"

elif name == "aws_ciofs_with_angle":

# base for matching
def base_str(a_time):
return f"{output_dir_single_files}/ciofs_{a_time}-*.json"
date_format = "ciofs_%Y-%m-%d"
raise ValueError(f"Name {name} not recognized")

# only glob start and end year files, order isn't important
json_list = fs2.glob(base_str(start[:4]))
if end[:4] != start[:4]:
json_list += fs2.glob(base_str(end[:4]))

# forward in time
if end[:4] > start[:4]:
json_list = [
j for j in json_list if Path(j).stem >= start and Path(j).stem <= end
j for j in json_list if datetime.strptime(Path(j).stem, date_format).isoformat() >= start and datetime.strptime(Path(j).stem, date_format).isoformat() <= end
elif name == "aws_ciofs_with_angle":
json_list = fs2.glob(f"{output_dir_single_files}/ciofs_*.json")
# json_list = sorted(
# fs2.glob(f"{output_dir_single_files}/ciofs_*.json")
# ) # combine single json files
# backward in time
elif end[:4] < start[:4]:
json_list = [
for j in json_list
if Path(j).stem.split("_")[1] >= start and Path(j).stem.split("_")[1] <= end
j for j in json_list if datetime.strptime(Path(j).stem, date_format).isoformat() <= start and datetime.strptime(Path(j).stem, date_format).isoformat() >= end
raise ValueError(f"Name {name} not recognized")

if json_list == []:
raise ValueError(
Expand Down Expand Up @@ -278,17 +284,26 @@ def make_nwgoa_kerchunk(start, end):

fs2 = fsspec.filesystem("") # local file system to save final jsons to

# select the single file Jsons to combine
json_list = fs2.glob(f"{output_dir_single_files}/nwgoa*.json") # combine single json files
# json_list = sorted(
# fs2.glob(f"{output_dir_single_files}/nwgoa*.json")
# ) # combine single json files
json_list = [
for j in json_list
if Path(j).stem.split("nwgoa_")[1] >= start
and Path(j).stem.split("nwgoa_")[1] <= end
# base for matching
def base_str(a_time):
return f"{output_dir_single_files}/nwgoa_{a_time}-*.json"
date_format = "nwgoa_%Y-%m-%d"

# only glob start and end year files, order isn't important
json_list = fs2.glob(base_str(start[:4]))
if end[:4] != start[:4]:
json_list += fs2.glob(base_str(end[:4]))

# forward in time
if end[:4] > start[:4]:
json_list = [
j for j in json_list if datetime.strptime(Path(j).stem, date_format).isoformat() >= start and datetime.strptime(Path(j).stem, date_format).isoformat() <= end
# backward in time
elif end[:4] < start[:4]:
json_list = [
j for j in json_list if datetime.strptime(Path(j).stem, date_format).isoformat() <= start and datetime.strptime(Path(j).stem, date_format).isoformat() >= end

if json_list == []:
raise ValueError(
Expand Down

0 comments on commit 313288c

Please sign in to comment.