Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve Full Trajectory Information in AnalysisBase #4892

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
9 changes: 7 additions & 2 deletions package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,18 @@ The rules for this file:


-------------------------------------------------------------------------------
??/??/?? IAlibay, orbeckst

??/??/?? IAlibay, orbeckst, yuxuanzhuang

* 2.10.0

Fixes

Enhancements
* Added _run_slicer to make it possible to retrieve information
of the full trajectory slice in AnalysisBase. (Issue #4891 PR #4892)
* Added _run_frame_index to to keep track of frame iteration
number for the full trajectory slice in `single_frame` in AnalysisBase
(Issue #4891 PR #4892)

Changes

Expand Down Expand Up @@ -64,6 +68,7 @@ Changes
* Codebase is now formatted with black (version `24`) (PR #4886)



11/11/24 IAlibay, HeetVekariya, marinegor, lilyminium, RMeli,
ljwoods2, aditya292002, pstaerk, PicoCentauri, BFedder,
tyler.je.reddy, SampurnaM, leonwehrhan, kainszs, orionarcher,
Expand Down
79 changes: 77 additions & 2 deletions package/MDAnalysis/analysis/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,8 @@
import logging
import warnings
from functools import partial
from typing import Iterable, Union
from typing import Iterable, Union, Optional, List
from dataclasses import dataclass

import numpy as np
from .. import coordinates
Expand All @@ -170,6 +171,29 @@
logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class RunConfig:
"""Stores user-provided arguments for `run()`."""
start: Optional[int] = None
stop: Optional[int] = None
step: Optional[int] = None
frames: Optional[np.ndarray] = None
backend: Optional[Union[str, BackendBase]] = None
n_workers: Optional[int] = None
n_parts: Optional[int] = None
unsupported_backend: bool = False


@dataclass
class RunState:
"""Stores runtime-generated attributes that can be used
during the analysis."""
slicer: Optional[Union[slice, np.ndarray]] = None
n_frames: Optional[int] = None
computation_groups: Optional[List[np.ndarray]] = None
frame_index: Optional[int] = None


class AnalysisBase(object):
r"""Base class for defining multi-frame analysis

Expand Down Expand Up @@ -441,8 +465,14 @@
each of the workers and gets executed twice: one time in
:meth:`_setup_frames` for the whole trajectory, second time in
:meth:`_compute` for each of the computation groups.

.. versionchanged:: 2.9.0
Add `self._run_slicer` attribute to store the slicer for the
whole trajectory being analyzed.
"""
slicer = self._define_run_frames(trajectory, start, stop, step, frames)
self.run_state.slicer = slicer
self.run_state.n_frames = len(trajectory[slicer])
self._prepare_sliced_trajectory(slicer)

def _single_frame(self):
Expand All @@ -452,6 +482,10 @@
Attributes accessible during your calculations:

- ``self._frame_index``: index of the frame in results array
Note that this is not the same as the frame number in the trajectory
- ``self._run_frame_index``: index of the frame in the trajectory
This is useful for parallel runs, where you can't rely on the
`self._frame_index`.
- ``self._ts`` -- Timestep instance
- ``self._sliced_trajectory`` -- trajectory that you're iterating over
- ``self.results`` -- :class:`MDAnalysis.analysis.results.Results` instance
Expand Down Expand Up @@ -537,6 +571,7 @@
)
):
self._frame_index = idx # accessed later by subclasses
self.run_state.frame_index = indexed_frames[idx, 0]
self._ts = ts
self.frames[idx] = ts.frame
self.times[idx] = ts.time
Expand Down Expand Up @@ -778,7 +813,7 @@
By default, performs calculations in a serial fashion.
Otherwise, user can choose a backend: ``str`` is matched to a
builtin backend (one of ``serial``, ``multiprocessing`` and
``dask``), or a :class:`MDAnalysis.analysis.results.BackendBase`
``dask``), or a :class:`MDAnalysis.analysis.backends.BackendBase`
subclass.

.. versionadded:: 2.8.0
Expand Down Expand Up @@ -854,6 +889,17 @@
f"{executor.n_workers=} is greater than {n_parts=}"
)
)

self._run_config = RunConfig(
start=start,
stop=stop,
step=step,
frames=frames,
backend=backend,
n_workers=n_workers,
n_parts=n_parts,
unsupported_backend=unsupported_backend,
)

# start preparing the run
worker_func = partial(
Expand All @@ -871,6 +917,7 @@
computation_groups = self._setup_computation_groups(
start=start, stop=stop, step=step, frames=frames, n_parts=n_parts
)
self.run_state.computation_groups = computation_groups

# get all results from workers in other processes.
# we need `AnalysisBase` classes
Expand Down Expand Up @@ -902,6 +949,34 @@
.. versionadded:: 2.8.0
"""
return ResultsGroup(lookup=None)

@property
def run_config(self) -> RunConfig:
"""Stores user-provided arguments for `run()`.
It includes `start`, `stop`, `step`, `frames`, `backend`, `n_workers`,
`n_parts` and `unsupported_backend` attributes.
"""
return self._run_config

Check warning on line 959 in package/MDAnalysis/analysis/base.py

View check run for this annotation

Codecov / codecov/patch

package/MDAnalysis/analysis/base.py#L959

Added line #L959 was not covered by tests

@property
def run_state(self) -> RunState:
"""
Stores runtime-generated attributes that can be used
during the analysis.

It includes `slicer`, `n_frames`, `computation_groups` and `frame_index`
attributes.

The `slicer`, `n_frames`, `frame_index` attributes are used to store the
information for the whole trajectory being analyzed.
They are different from e.g. `self.n_frames` which is used to store the
information for the current computation group being analyzed.
"""
# lazy initialization for the OldAPIAnalysis as it doesn't have the same
# `__init__` function as the current AnalysisBase.
if not hasattr(self, "_run_state"):
self._run_state = RunState()
return self._run_state


class AnalysisFromFunction(AnalysisBase):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,22 +118,29 @@ For MDAnalysis developers
From a developer point of view, there are a few methods that are important in
order to understand how parallelization is implemented:

#. :meth:`MDAnalysis.analysis.base.AnalysisBase._define_run_frames`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._setup_frames`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._prepare_sliced_trajectory`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._configure_backend`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._setup_computation_groups`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._compute`
#. :meth:`MDAnalysis.analysis.base.AnalysisBase._get_aggregator`

The first two methods share the functionality of :meth:`_setup_frames`.
:meth:`_define_run_frames` is run once during analysis, as it checks that input
parameters `start`, `stop`, `step` or `frames` are consistent with the given
trajectory and prepares the ``slicer`` object that defines the iteration
pattern through the trajectory. :meth:`_prepare_sliced_trajectory` assigns to
:meth:`_setup_frames` is run once during analysis :attr:`run()`, as it checks that input
parameters :attr:`start`, :attr:`stop`, :attr:`step` or :attr:`frames` are consistent with the given
trajectory and prepares the :attr:`slicer` object that defines the iteration
pattern through the trajectory with :meth:`_define_run_frames`.
The attribute :attr:`self._run_slicer` is assigned based on the `slicer`.
Users can later access the full sliced trajectory being analyzed via
:attr:`self._trajectory[self._run_slicer]`.

:meth:`_prepare_sliced_trajectory` assigns to
the :attr:`self._sliced_trajectory` attribute, computes the number of frames in
it, and fills the :attr:`self.frames` and :attr:`self.times` arrays. In case
the computation will be later split between other processes, this method will
be called again on each of the computation groups.
be called again on each of the computation groups. In parallel analysis,
:attr:`self._sliced_trajectory` represents a split of the original sliced
trajectory, and :attr:`self.n_frames` is the number of frames in each split
computation group (not the total number of frames in the sliced trajectory).

The method :meth:`_configure_backend` performs basic health checks for a given
analysis class -- namely, it compares a given backend (if it's a :class:`str`
Expand All @@ -155,7 +162,14 @@ analysis get initialized with the :meth:`_prepare` method. Then the function
iterates over :attr:`self._sliced_trajectory`, assigning
:attr:`self._frame_index` and :attr:`self._ts` as frame index (within a
computation group) and timestamp, and also setting respective
:attr:`self.frames` and :attr:`self.times` array values.
:attr:`self.frames` and :attr:`self.times` array values. Additionally,
:attr:`self._run_frame_index` is assigned the run frame index
within the full sliced trajectory (:attr:`self._trajectory[self._run_slicer]`)
that is being analyzed.
This run frame index is particularly useful for analyses requiring it, such as
:class:`MDAnalysis.analysis.diffusionmap.DistanceMatrix` that needs to know the
frame index in the trajectory sliced that is being analyzed.
See :ref:`retrieving-correct-frame-index` for more details.

After :meth:`_compute` has finished, the main analysis instance calls the
:meth:`_get_aggregator` method, which merges the :attr:`self.results`
Expand Down Expand Up @@ -357,6 +371,82 @@ In this way, you will override the check for supported backends.
with a supported backend. When reporting *always mention if you used*
``unsupported_backend=True``.

.. _retrieving-correct-frame-index:
Retrieving correct frame index in parallel analysis
===================================================

To retrieve the correct frame index during parallel analysis, use the
:attr:`self._run_frame_index` attribute. This attribute represents the correct
frame index within the full sliced trajectory
(:attr:`self._trajectory[self._run_slicer]`).

For an example illustrating when to use :attr:`_frame_index` versus
:attr:`_run_frame_index` and :attr:`self._run_slicer`,
see the following code snippet:

.. code-block:: python

from MDAnalysis.analysis.base import AnalysisBase
from MDAnalysis.analysis.results import ResultsGroup

class MyAnalysis(AnalysisBase):
_analysis_algorithm_is_parallelizable = True

@classmethod
def get_supported_backends(cls):
"""Define the supported backends for the analysis."""
return ('serial', 'multiprocessing', 'dask')

def _prepare(self):
"""Initialize result attributes and compute frame count."""
self.results.frame_index = []
self.results.run_frame_index = []
self.results.n_frames = []
self.results.run_n_frames = []
self.run_n_frames = len(self._trajectory[self._run_slicer])

def _single_frame(self):
"""Process a single frame during the analysis."""
frame_index = self._frame_index
run_frame_index = self._run_frame_index

# Append results for the current frame
self.results.frame_index.append(frame_index)
self.results.run_frame_index.append(run_frame_index)
self.results.n_frames.append(self.n_frames)
self.results.run_n_frames.append(self.run_n_frames)

def _get_aggregator(self):
"""Return an aggregator to combine results from multiple workers."""
return ResultsGroup(
lookup={
'frame_index': ResultsGroup.flatten_sequence,
'run_frame_index': ResultsGroup.flatten_sequence,
'n_frames': ResultsGroup.flatten_sequence,
'run_n_frames': ResultsGroup.flatten_sequence,
}
)

# Example usage: serial analysis
ana = MyAnalysis(u.trajectory)
ana.run(step=2)
print(ana.results)
# Output:
# {'frame_index': [0, 1, 2, 3, 4],
# 'run_frame_index': [0, 1, 2, 3, 4],
# 'n_frames': [5, 5, 5, 5, 5],
# 'run_n_frames': [5, 5, 5, 5, 5]}

# Example usage: parallel analysis
ana = MyAnalysis(u.trajectory)
ana.run(step=2, backend='dask', n_workers=2)
print(ana.results)
# Output:
# {'frame_index': [0, 1, 2, 0, 1],
# 'run_frame_index': [0, 1, 2, 3, 4],
# 'n_frames': [3, 3, 3, 2, 2],
# 'run_n_frames': [5, 5, 5, 5, 5]}


.. rubric:: References
.. footbibliography::
Expand Down
29 changes: 27 additions & 2 deletions testsuite/MDAnalysisTests/analysis/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,36 @@ def __init__(self, reader, **kwargs):

def _prepare(self):
self.results.found_frames = []
self.results.frame_index = []
self.results.run_frame_index = []
self.results.n_frames = []
self.results.run_n_frames = []

# self.n_frames is defined elsewhere
self.run_n_frames = len(self._trajectory[self.run_state.slicer])

def _single_frame(self):
frame_index = self._frame_index
run_frame_index = self.run_state.frame_index

self.results.found_frames.append(self._ts.frame)
self.results.frame_index.append(frame_index)
self.results.run_frame_index.append(run_frame_index)
self.results.n_frames.append(self.n_frames)
self.results.run_n_frames.append(self.run_n_frames)

def _conclude(self):
self.found_frames = list(self.results.found_frames)

def _get_aggregator(self):
return base.ResultsGroup(
{"found_frames": base.ResultsGroup.ndarray_hstack}
{
"found_frames": base.ResultsGroup.ndarray_hstack,
"frame_index": base.ResultsGroup.ndarray_hstack,
"run_frame_index": base.ResultsGroup.ndarray_hstack,
"n_frames": base.ResultsGroup.ndarray_hstack,
"run_n_frames": base.ResultsGroup.ndarray_hstack,
}
)


Expand Down Expand Up @@ -450,12 +470,17 @@ def test_frames_times(client_FrameAnalysis):
start=1, stop=8, step=2, **client_FrameAnalysis
)
frames = np.array([1, 3, 5, 7])
assert an.n_frames == len(frames)
n_frames = len(frames)
frame_indices = np.arange(n_frames)

assert an.n_frames == n_frames
assert_equal(an.found_frames, frames)
assert_equal(an.frames, frames, err_msg=FRAMES_ERR)
assert_allclose(
an.times, frames * 100, rtol=0, atol=1.5e-4, err_msg=TIMES_ERR
)
assert_equal(an.results.run_frame_index, frame_indices)
assert_equal(an.results.run_n_frames, [n_frames] * n_frames)


def test_verbose(u):
Expand Down
Loading