Skip to content

Commit

Permalink
Added ability to input filepath for saving/reading interpolator file
Browse files Browse the repository at this point in the history
  • Loading branch information
kthyng committed Jan 30, 2025
1 parent c1e74c7 commit e38cd94
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 57 deletions.
33 changes: 18 additions & 15 deletions particle_tracking_manager/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,20 +135,23 @@ def main():
"output_file"
] = f"output-results_{datetime.utcnow():%Y-%m-%dT%H%M:%SZ}.nc"

log_file = args.kwargs["output_file"].replace(".nc", ".log")
# log_file = args.kwargs["output_file"].replace(".nc", ".log")

# Convert the string representation of the dictionary to an actual dictionary
# not clear why I can't use `args.plots` in here but it isn't working
plots = ast.literal_eval(parser.parse_args().plots)
if parser.parse_args().plots is not None:
plots = ast.literal_eval(parser.parse_args().plots)
else:
plots = None

# Create a file handler
file_handler = logging.FileHandler(log_file)
# # Create a file handler
# file_handler = logging.FileHandler(log_file)

# Create a formatter and add it to the handler
formatter = logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
file_handler.setFormatter(formatter)
# # Create a formatter and add it to the handler
# formatter = logging.Formatter(
# "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
# )
# file_handler.setFormatter(formatter)

m = ptm.OpenDriftModel(**args.kwargs, plots=plots)

Expand All @@ -160,10 +163,10 @@ def main():

else:

# Add the handler to the logger
m.logger.addHandler(file_handler)
# # Add the handler to the logger
# m.logger.addHandler(file_handler)

m.logger.info(f"filename: {args.kwargs['output_file']}")
# m.logger.info(f"filename: {args.kwargs['output_file']}")

m.add_reader()
print(m.drift_model_config())
Expand All @@ -173,6 +176,6 @@ def main():

print(m.outfile_name)

# Remove the handler at the end of the loop
m.logger.removeHandler(file_handler)
file_handler.close()
# # Remove the handler at the end of the loop
# m.logger.removeHandler(file_handler)
# file_handler.close()
2 changes: 1 addition & 1 deletion particle_tracking_manager/models/opendrift/opendrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def __init__(
],
biodegradation: bool = config_model["biodegradation"]["default"],
log: str = config_model["log"]["default"],
plots: dict = config_model["plots"]["default"],
plots: Optional[dict] = config_model["plots"]["default"],
**kw,
) -> None:
"""Inputs for OpenDrift model."""
Expand Down
104 changes: 66 additions & 38 deletions particle_tracking_manager/the_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ class ParticleTrackingManager:
Name of input/output module type to use for writing Lagrangian model output. Default is "netcdf".
use_cache : bool
Set to True to use cache for saving interpolators, by default True.
interpolator_filename : Optional[Union[pathlib.Path,str]], optional
Filename to save interpolators to, by default None. The full path should be given, but no suffix.
Use this to either read from an existing file at a non-default location or to save to a
non-default location. If None and use_cache==True, the filename is set to a default name in an
`appdirs` cache directory.
Notes
-----
Expand Down Expand Up @@ -198,6 +203,9 @@ def __init__(
output_file: Optional[str] = config_ptm["output_file"]["default"],
output_format: str = config_ptm["output_format"]["default"],
use_cache: bool = config_ptm["use_cache"]["default"],
interpolator_filename: Optional[Union[pathlib.Path, str]] = config_ptm[
"interpolator_filename"
]["default"],
**kw,
) -> None:
"""Inputs necessary for any particle tracking."""
Expand Down Expand Up @@ -231,24 +239,16 @@ def __init__(

self.output_file_initial = None

# Set all attributes which will trigger some checks and changes in __setattr__
# these will also update "value" in the config dict
for key in sig.parameters.keys():
# no need to run through for init if value is None (already set to None)
if locals()[key] is not None:
self.__setattr__(key, locals()[key])

self.kw = kw
if output_file is None:
output_file = f"output-results_{datetime.datetime.now():%Y-%m-%dT%H%M:%SZ}"

if self.__dict__["output_file"] is None:
self.__dict__[
"output_file"
] = f"output-results_{datetime.datetime.now():%Y-%m-%dT%H%M:%SZ}"
# want output_file to not include any suffix
output_file = output_file.rstrip(".nc").rstrip(".parq")

## set up log for this simulation
# Create a file handler
assert self.__dict__["output_file"] is not None
logfile_name = self.__dict__["output_file"] + ".log"
assert output_file is not None
logfile_name = output_file + ".log"
self.file_handler = logging.FileHandler(logfile_name)
self.logfile_name = logfile_name

Expand All @@ -264,6 +264,56 @@ def __init__(
self.logger.info(f"filename: {logfile_name}")
##

if interpolator_filename is not None and not use_cache:
raise ValueError(
"If interpolator_filename is input, use_cache must be True."
)

# deal with caching/interpolators
# save interpolators to save time
if use_cache:
cache_dir = pathlib.Path(
appdirs.user_cache_dir(
appname="particle-tracking-manager",
appauthor="axiom-data-science",
)
)
cache_dir.mkdir(parents=True, exist_ok=True)
cache_dir = cache_dir
if interpolator_filename is None:
interpolator_filename = cache_dir / pathlib.Path(
f"{ocean_model}_interpolator"
).with_suffix(".pickle")
else:
interpolator_filename = pathlib.Path(interpolator_filename).with_suffix(
".pickle"
)
interpolator_filename = str(interpolator_filename)
self.save_interpolator = True
# if interpolator_filename already exists, load that
if pathlib.Path(interpolator_filename).exists():
self.logger.info(
f"Loading the interpolator from {interpolator_filename}."
)
else:
self.logger.info(
f"A new interpolator will be saved to {interpolator_filename}."
)
else:
self.save_interpolator = False
# this is already None
# self.interpolator_filename = None
self.logger.info("Interpolators will not be saved.")

# Set all attributes which will trigger some checks and changes in __setattr__
# these will also update "value" in the config dict
for key in sig.parameters.keys():
# no need to run through for init if value is None (already set to None)
if locals()[key] is not None:
self.__setattr__(key, locals()[key])

self.kw = kw

def __setattr_model__(self, name: str, value) -> None:
"""Implement this in model class to add specific __setattr__ there too."""
pass
Expand Down Expand Up @@ -383,7 +433,8 @@ def __setattr__(self, name: str, value) -> None:
# by this point, output_file should already be a filename like what is
# available here, from OpenDrift (if run from there)
if self.output_file is not None:
output_file = self.output_file.rstrip(".nc")
output_file = self.output_file
# output_file = self.output_file.rstrip(".nc")
else:
output_file = (
f"output-results_{datetime.datetime.now():%Y-%m-%dT%H%M:%SZ}"
Expand Down Expand Up @@ -477,29 +528,6 @@ def __setattr__(self, name: str, value) -> None:
)
self.seed_seafloor = False

# save interpolators to save time
if name == "use_cache":
if value:
cache_dir = pathlib.Path(
appdirs.user_cache_dir(
appname="particle-tracking-manager",
appauthor="axiom-data-science",
)
)
cache_dir.mkdir(parents=True, exist_ok=True)
self.cache_dir = cache_dir
self.interpolator_filename = cache_dir / pathlib.Path(
f"{self.ocean_model}_interpolator"
)
self.save_interpolator = True
self.logger.info(
f"Interpolators will be saved to {self.interpolator_filename}."
)
else:
self.save_interpolator = False
self.interpolator_filename = None
self.logger.info("Interpolators will not be saved.")

# if reader, lon, and lat set, check inputs
if (
name == "has_added_reader"
Expand Down
6 changes: 6 additions & 0 deletions particle_tracking_manager/the_manager_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -193,5 +193,11 @@
"default": true,
"description": "Set to True to use cache for storing interpolators. This saves time on repeat simulations, may be used for other items in the future.",
"ptm_level": 3
},
"interpolator_filename": {
"type": "str",
"default": "None",
    "description": "Filename to save the interpolator to. The full path should be given, but no suffix. Use this to either read from an existing file at a non-default location or to save to a non-default location. If None and use_cache==True, a default name is used.",
"ptm_level": 3
}
}
19 changes: 19 additions & 0 deletions tests/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,5 +387,24 @@ def test_start_time_none(self):
self.m.seed()


def test_interpolator_filename():
with pytest.raises(ValueError):
m = ptm.OpenDriftModel(interpolator_filename="test", use_cache=False)

m = ptm.OpenDriftModel(interpolator_filename="test")
assert m.interpolator_filename == "test.pickle"


def test_log_name():
m = ptm.OpenDriftModel(output_file="newtest")
assert m.logfile_name == "newtest.log"

m = ptm.OpenDriftModel(output_file="newtest.nc")
assert m.logfile_name == "newtest.log"

m = ptm.OpenDriftModel(output_file="newtest.parq")
assert m.logfile_name == "newtest.log"


if __name__ == "__main__":
unittest.main()
23 changes: 20 additions & 3 deletions tests/test_realistic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Test realistic scenarios, which are slower."""

import pickle

import pytest
import xarray as xr

Expand Down Expand Up @@ -42,16 +44,31 @@ def test_run_parquet():
def test_run_netcdf():
"""Set up and run."""

import tempfile

import xroms

seeding_kwargs = dict(lon=-90, lat=28.7, number=1)
manager = ptm.OpenDriftModel(
**seeding_kwargs, use_static_masks=True, steps=2, output_format="netcdf"
)
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
manager = ptm.OpenDriftModel(
**seeding_kwargs,
use_static_masks=True,
steps=2,
output_format="netcdf",
use_cache=True,
interpolator_filename=temp_file.name
)
url = xroms.datasets.CLOVER.fetch("ROMS_example_full_grid.nc")
ds = xr.open_dataset(url, decode_times=False)
manager.add_reader(ds=ds, name="txla")
manager.seed()
manager.run()

assert "nc" in manager.o.outfile_name
assert manager.interpolator_filename == temp_file.name + ".pickle"

    # Confirm the saved interpolator pickle contains the expected spline data
with open(manager.interpolator_filename, "rb") as file:
data = pickle.load(file)
assert "spl_x" in data
assert "spl_y" in data

0 comments on commit e38cd94

Please sign in to comment.