Improve logging and add dry-run feature.
GeigerJ2 committed Feb 5, 2025
1 parent e09e078 commit 48acce7
Showing 5 changed files with 161 additions and 45 deletions.
36 changes: 25 additions & 11 deletions src/aiida/cmdline/commands/cmd_profile.py
@@ -20,6 +20,7 @@
from aiida.common import exceptions
from aiida.manage.configuration import Profile, create_profile, get_config
from aiida.tools.dumping import ProcessDumper, ProfileDumper
from aiida.tools.dumping.logger import DumpLogger


@verdi.group('profile')
@@ -306,6 +307,7 @@ def profile_mirror(
):
"""Dump all data in an AiiDA profile's storage to disk."""

import json
from datetime import datetime
from pathlib import Path

@@ -319,17 +321,6 @@
if path is None:
path = Path.cwd() / f'{profile.name}-mirror'

# TODO: Implement proper dry-run feature
dry_run_message = f"Dry run for dumping of profile `{profile.name}`'s data at path: `{path}`.\n"
dry_run_message += 'Only directories will be created.'

if dry_run:
echo.echo_report(dry_run_message)
return

else:
echo.echo_report(f"Dumping of profile `{profile.name}`'s data at path: `{path}`.")

SAFEGUARD_FILE: str = '.verdi_profile_mirror' # noqa: N806
safeguard_file_path: Path = path / SAFEGUARD_FILE

@@ -349,6 +340,24 @@
except IndexError:
last_dump_time = None

if dry_run:
node_counts = ProfileDumper._get_number_of_nodes_to_dump(last_dump_time)
node_counts_str = ' & '.join(f'{count} {node_type}' for node_type, count in node_counts.items())
dry_run_message = f'Dry run for mirroring of profile `{profile.name}`: {node_counts_str} to dump.\n'
echo.echo_report(dry_run_message)
return

echo.echo_report(f"Dumping of profile `{profile.name}`'s data at path: `{path}`.")

if incremental:
msg = 'Incremental dumping selected. Will update directory.'
echo.echo_report(msg)

try:
dump_logger = DumpLogger.from_file(dump_parent_path=path)
except (json.JSONDecodeError, OSError):
dump_logger = DumpLogger(dump_parent_path=path)

base_dumper = BaseDumper(
dump_parent_path=path,
overwrite=overwrite,
@@ -368,6 +377,7 @@
profile_dumper = ProfileDumper(
base_dumper=base_dumper,
process_dumper=process_dumper,
dump_logger=dump_logger,
groups=groups,
organize_by_groups=organize_by_groups,
deduplicate=deduplicate,
@@ -381,3 +391,7 @@
last_dump_time = datetime.now().astimezone()
with safeguard_file_path.open('a') as fhandle:
fhandle.write(f'Last profile mirror time: {last_dump_time.isoformat()}\n')

dump_logger.save_log()

echo.echo_report(f'Dumped {dump_logger.counter} new nodes.')
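
For illustration, a minimal sketch of how the dry-run report above is assembled; the dictionary keys and counts are hypothetical placeholders for whatever `ProfileDumper._get_number_of_nodes_to_dump` returns:

```python
# Hypothetical keys and counts, standing in for the return value of
# ProfileDumper._get_number_of_nodes_to_dump(last_dump_time).
node_counts = {'calculations': 12, 'workflows': 3}

# Same join as in `profile_mirror` above.
node_counts_str = ' & '.join(f'{count} {node_type}' for node_type, count in node_counts.items())
print(f'Dry run for mirroring of profile `my-profile`: {node_counts_str} to dump.')
# -> Dry run for mirroring of profile `my-profile`: 12 calculations & 3 workflows to dump.
```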
55 changes: 36 additions & 19 deletions src/aiida/tools/dumping/collection.py
@@ -11,13 +11,14 @@
from __future__ import annotations

import os
from datetime import datetime
from functools import cached_property
from pathlib import Path

from aiida import orm
from aiida.common.log import AIIDA_LOGGER
from aiida.tools.dumping.base import BaseDumper
from aiida.tools.dumping.logger import DumpLogger
from aiida.tools.dumping.logger import DumpLog, DumpLogger
from aiida.tools.dumping.process import ProcessDumper
from aiida.tools.dumping.utils import filter_by_last_dump_time

@@ -47,28 +48,38 @@ def __init__(

self.base_dumper = base_dumper or BaseDumper()
self.process_dumper = process_dumper or ProcessDumper()
self.dump_logger = dump_logger or DumpLogger()
self.dump_logger = dump_logger or DumpLogger(dump_parent_path=self.base_dumper.dump_parent_path)

# Properly set the `output_path` attribute

self.output_path = Path(output_path or self.base_dumper.dump_parent_path)

@cached_property
def nodes(self):
def nodes(self) -> list[str]:
return self._get_nodes()

def _get_nodes(self):
def _get_nodes(self) -> list[str]:
nodes: list[str] | None = None
if isinstance(self.collection, orm.Group):
nodes: list[str] = list(self.collection.nodes)
nodes = [n.uuid for n in list(self.collection.nodes)]
elif isinstance(self.collection, list) and len(self.collection) > 0:
if all(isinstance(n, str) for n in self.collection):
nodes = self.collection
else:
msg = 'A collection of nodes must be passed via their UUIDs.'
raise TypeError(msg)
else:
nodes = []

return filter_by_last_dump_time(nodes=nodes, last_dump_time=self.base_dumper.last_dump_time)
filtered_nodes = filter_by_last_dump_time(nodes=nodes, last_dump_time=self.base_dumper.last_dump_time)
return filtered_nodes

def _should_dump_processes(self, nodes: list[orm.Node] | None = None) -> bool:
def _should_dump_processes(self, nodes: list[str] | None = None) -> bool:
test_nodes = nodes or self.nodes
return len([node for node in test_nodes if isinstance(node, orm.ProcessNode)]) > 0
return len([node for node in test_nodes if isinstance(orm.load_node(node), orm.ProcessNode)]) > 0

def _get_processes(self):
nodes = self.nodes
nodes = [orm.load_node(n) for n in self.nodes]
workflows = [node for node in nodes if isinstance(node, orm.WorkflowNode)]

# Make sure that only top-level workflows are dumped in their own directories when de-duplication is enabled
@@ -99,26 +110,31 @@ def _dump_processes(self):
self._dump_workflows()

def _dump_calculations(self):
if len(self.calculations) == 0:
return
calculations_path = self.output_path / 'calculations'
dumped_calculations = {}

for calculation in self.calculations:
calculation_dumper = self.process_dumper

calculation_dump_path = calculations_path / calculation_dumper._generate_default_dump_path(
process_node=calculation, prefix=''
process_node=calculation, prefix=None
)

if calculation.caller is None:
# or (calculation.caller is not None and not self.deduplicate):
calculation_dumper._dump_calculation(calculation_node=calculation, output_path=calculation_dump_path)

dumped_calculations[calculation.uuid] = calculation_dump_path
dumped_calculations[calculation.uuid] = DumpLog(
path=calculation_dump_path,
time=datetime.now().astimezone(),
)

self.dump_logger.update_calculations(dumped_calculations)

def _dump_workflows(self):
# workflow_nodes = get_nodes_from_db(aiida_node_type=orm.WorkflowNode, with_group=self.group, flat=True)
def _dump_workflows(self) -> None:
if len(self.workflows) == 0:
return
workflow_path = self.output_path / 'workflows'
workflow_path.mkdir(exist_ok=True, parents=True)
dumped_workflows = {}
@@ -130,22 +146,23 @@ def _dump_workflows(self):
process_node=workflow, prefix=None
)

logged_workflows = self.dump_logger.get_logs()['workflows']
logged_workflows = self.dump_logger.get_log()['workflows']

if self.deduplicate and workflow.uuid in logged_workflows.keys():
os.symlink(
src=logged_workflows[workflow.uuid],
src=logged_workflows[workflow.uuid].path,
dst=workflow_dump_path,
)
else:
workflow_dumper._dump_workflow(
workflow_node=workflow,
output_path=workflow_dump_path,
# link_calculations=not self.deduplicate,
# link_calculations_dir=self.output_path / 'calculations',
)

dumped_workflows[workflow.uuid] = workflow_dump_path
dumped_workflows[workflow.uuid] = DumpLog(
path=workflow_dump_path,
time=datetime.now().astimezone(),
)

self.dump_logger.update_workflows(dumped_workflows)

98 changes: 88 additions & 10 deletions src/aiida/tools/dumping/logger.py
@@ -1,18 +1,96 @@
import json
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import TypeAlias


@dataclass
class DumpLog:
"""Represents a single dump log entry."""

path: Path
time: datetime


DumpDict: TypeAlias = dict[str, DumpLog]


class DumpLogger:
def __init__(self):
self.log_dict: dict[str, dict[str, Path]] = {'calculations': {}, 'workflows': {}}
"""Main logger class using dataclasses for better structure."""

def update_calculations(self, new_calculations: dict[str, Path]):
"""Update the log with new calculations."""
self.log_dict['calculations'].update(new_calculations)
DUMP_FILE: str = '.dump_log.json'

def update_workflows(self, new_workflows: dict[str, Path]):
"""Update the log with new workflows."""
self.log_dict['workflows'].update(new_workflows)
def __init__(
self,
dump_parent_path: Path,
calculations: DumpDict | None = None,
workflows: DumpDict | None = None,
counter: int = 0,
) -> None:
self.dump_parent_path = dump_parent_path
self.calculations = calculations or {}
self.workflows = workflows or {}
self.counter = counter

def get_logs(self):
@property
def dump_file(self) -> Path:
"""Get the path to the dump file."""
return self.dump_parent_path / self.DUMP_FILE

def update_calculations(self, new_calculations: DumpDict) -> None:
"""Update the calculations log."""
self.calculations.update(new_calculations)
self.counter += len(new_calculations)

def update_workflows(self, new_workflows: DumpDict) -> None:
"""Update the workflows log."""
self.workflows.update(new_workflows)
self.counter += len(new_workflows)

def get_log(self) -> dict[str, DumpDict]:
"""Retrieve the current state of the log."""
return self.log_dict
return {'calculations': self.calculations, 'workflows': self.workflows}

def save_log(self) -> None:
"""Save the log to a JSON file."""

def serialize_logs(logs: DumpDict) -> dict:
serialized = {}
for uuid, entry in logs.items():
serialized[uuid] = {'path': str(entry.path), 'time': entry.time.isoformat()}
return serialized

log_dict = {
'calculations': serialize_logs(self.calculations),
'workflows': serialize_logs(self.workflows),
}

with self.dump_file.open('w', encoding='utf-8') as f:
json.dump(log_dict, f, indent=4)

@classmethod
def from_file(cls, dump_parent_path: Path) -> 'DumpLogger':
"""Alternative constructor to load from an existing JSON file."""
instance = cls(dump_parent_path=dump_parent_path)

if not instance.dump_file.exists():
return instance

try:
with instance.dump_file.open('r', encoding='utf-8') as f:
data = json.load(f)

def deserialize_logs(category_data: dict) -> DumpDict:
deserialized = {}
for uuid, entry in category_data.items():
deserialized[uuid] = DumpLog(path=Path(entry['path']), time=datetime.fromisoformat(entry['time']))
return deserialized

instance.calculations = deserialize_logs(data['calculations'])
instance.workflows = deserialize_logs(data['workflows'])

except (json.JSONDecodeError, OSError):
raise

return instance
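
A minimal usage sketch of the new `DumpLogger` API shown above (the directory and UUID are hypothetical): entries are recorded as `DumpLog` objects, persisted to `.dump_log.json`, and reloaded via the `from_file` constructor on a later incremental run.

```python
import tempfile
from datetime import datetime
from pathlib import Path

from aiida.tools.dumping.logger import DumpLog, DumpLogger

# Hypothetical dump directory and node UUID, for illustration only.
parent = Path(tempfile.mkdtemp())
uuid = '11111111-2222-3333-4444-555555555555'

logger = DumpLogger(dump_parent_path=parent)
logger.update_calculations(
    {uuid: DumpLog(path=parent / 'calculations' / 'my-calc', time=datetime.now().astimezone())}
)
logger.save_log()  # writes `.dump_log.json` into `parent`

# On a later run, reload the log from disk and inspect it.
reloaded = DumpLogger.from_file(dump_parent_path=parent)
assert uuid in reloaded.get_log()['calculations']
```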
12 changes: 11 additions & 1 deletion src/aiida/tools/dumping/profile.py
@@ -45,7 +45,7 @@ def __init__(

self.base_dumper = base_dumper or BaseDumper()
self.process_dumper = process_dumper or ProcessDumper()
self.dump_logger = dump_logger or DumpLogger()
self.dump_logger = dump_logger or DumpLogger(dump_parent_path=self.base_dumper.dump_parent_path)

# Load the profile
if isinstance(profile, str):
@@ -145,3 +145,13 @@ def _get_no_group_nodes(self) -> list[str]:
nodes = filter_by_last_dump_time(nodes=nodes, last_dump_time=self.base_dumper.last_dump_time)

return nodes

@staticmethod
def _get_number_of_nodes_to_dump(last_dump_time) -> dict[str, int]:
result = {}
for node_type in (orm.CalculationNode, orm.WorkflowNode):
qb = orm.QueryBuilder().append(node_type, project=['uuid'])
nodes = cast(list[str], qb.all(flat=True))
nodes = filter_by_last_dump_time(nodes=nodes, last_dump_time=last_dump_time)
result[node_type.class_node_type.split('.')[-2] + 's'] = len(nodes)
return result
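
As a small aside, the dictionary keys returned by `_get_number_of_nodes_to_dump` are derived from the node type string; a sketch with an illustrative (not necessarily real) `class_node_type` value:

```python
# Illustrative only: real `class_node_type` values are AiiDA ORM internals not
# shown in this diff. The key is the second-to-last dot-separated component of
# the type string, pluralised with a trailing 's'.
type_string = 'process.workflow.workflow.WorkflowNode.'  # hypothetical value
key = type_string.split('.')[-2] + 's'
print(key)  # -> WorkflowNodes
```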
5 changes: 1 addition & 4 deletions src/aiida/tools/dumping/utils.py
@@ -54,10 +54,7 @@

# Case 1: Non-empty directory and overwrite is False
if not is_empty and not overwrite:
if incremental:
msg = f'Incremental dumping selected. Will update directory `{path_to_validate}` with new data.'
logger.report(msg)
else:
if not incremental:
msg = f'Path `{path_to_validate}` already exists, and neither overwrite nor incremental is enabled.'
raise FileExistsError(msg)

