Skip to content

Commit

Permalink
cleanup workdir correctly (stac-utils#70)
Browse files Browse the repository at this point in the history
* cleanup workdir correctly
  • Loading branch information
Phil Varner authored Dec 20, 2023
1 parent cbf9d00 commit 45b12e2
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 38 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,17 @@ repos:
args: [--ignore-words=.codespellignore]
types_or: [jupyter, markdown, python, shell]
- repo: https://github.com/psf/black
rev: 23.9.1
rev: 23.12.0
hooks:
- id: black
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.6.0
rev: v1.7.1
hooks:
- id: mypy
additional_dependencies:
- pytest
- types-setuptools == 65.7.0.3
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.292
rev: v0.1.8
hooks:
- id: ruff
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## Unreleased - TBD

### Changed

- handler now explicitly calls performs workdir cleanup
- workdir cleanup is correctly defensive and logs errors

## [v0.2.0] - 2023-11-16

### Changed
Expand Down
65 changes: 37 additions & 28 deletions stactask/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import json
import logging
import os
import sys
import warnings
from abc import ABC, abstractmethod
Expand Down Expand Up @@ -64,13 +65,9 @@ def __init__(
skip_upload: bool = False,
skip_validation: bool = False,
):
# set up logger
self.logger = logging.getLogger(self.name)

# set this to avoid confusion in destructor if called during validation
self._save_workdir = True

# validate input payload...or not
# validate input payload... or not
if not skip_validation:
if not self.validate(payload):
raise FailedValidation()
Expand All @@ -90,12 +87,6 @@ def __init__(
# if a workdir was specified we don't want to rm by default
self._save_workdir = save_workdir if save_workdir is not None else True

def __del__(self) -> None:
# remove work directory if not running locally
if not self._save_workdir:
self.logger.debug("Removing work directory %s", self._workdir)
rmtree(self._workdir)

@property
def process_definition(self) -> Dict[str, Any]:
process = self._payload.get("process", {})
Expand Down Expand Up @@ -198,6 +189,21 @@ def add_software_version_to_item(cls, item: Dict[str, Any]) -> Dict[str, Any]:
item["properties"]["processing:software"] = {cls.name: cls.version}
return item

def cleanup_workdir(self) -> None:
"""Remove work directory if configured not to save it"""
try:
if (
not self._save_workdir
and self._workdir
and os.path.exists(self._workdir)
):
self.logger.debug("Removing work directory %s", self._workdir)
rmtree(self._workdir)
except Exception as e:
self.logger.warning(
"Failed removing work directory %s: %s", self._workdir, e
)

def assign_collections(self) -> None:
"""Assigns new collection names based on"""
for i, (coll, expr) in itertools.product(
Expand Down Expand Up @@ -305,24 +311,27 @@ def post_process_item(self, item: Dict[str, Any]) -> Dict[str, Any]:

@classmethod
def handler(cls, payload: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
if "href" in payload or "url" in payload:
# read input
with fsspec.open(payload.get("href", payload.get("url"))) as f:
payload = json.loads(f.read())

task = cls(payload, **kwargs)
try:
items = list()
for item in task.process(**task.parameters):
items.append(task.post_process_item(item))

task._payload["features"] = items
task.assign_collections()

return task._payload
except Exception as err:
task.logger.error(err, exc_info=True)
raise err
if "href" in payload or "url" in payload:
# read input
with fsspec.open(payload.get("href", payload.get("url"))) as f:
payload = json.loads(f.read())

task = cls(payload, **kwargs)
try:
items = list()
for item in task.process(**task.parameters):
items.append(task.post_process_item(item))

task._payload["features"] = items
task.assign_collections()

return task._payload
except Exception as err:
task.logger.error(err, exc_info=True)
raise err
finally:
task.cleanup_workdir()

@classmethod
def parse_args(cls, args: List[str]) -> Dict[str, Any]:
Expand Down
14 changes: 7 additions & 7 deletions tests/test_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,15 @@ def test_edit_items2(nothing_task: Task) -> None:

@pytest.mark.parametrize("save_workdir", [False, True, None])
def test_tmp_workdir(items: Dict[str, Any], save_workdir: Optional[bool]) -> None:
nothing_task = NothingTask(items, save_workdir=save_workdir)
t = NothingTask(items, save_workdir=save_workdir)
expected = save_workdir if save_workdir is not None else False
assert nothing_task._save_workdir is expected
workdir = nothing_task._workdir
assert t._save_workdir is expected
workdir = t._workdir
assert workdir.parts[-1].startswith("tmp")
assert workdir.is_absolute() is True
assert workdir.is_dir() is True
del nothing_task
assert workdir.is_dir() is expected
t.cleanup_workdir()
assert workdir.exists() is expected


@pytest.mark.parametrize("save_workdir", [False, True, None])
Expand All @@ -80,8 +80,8 @@ def test_workdir(
assert workdir.parts[-1] == "test_task"
assert workdir.is_absolute() is True
assert workdir.is_dir() is True
del t
assert workdir.is_dir() is expected
t.cleanup_workdir()
assert workdir.exists() is expected


def test_parameters(items: Dict[str, Any]) -> None:
Expand Down

0 comments on commit 45b12e2

Please sign in to comment.