From e7ee5a4a077c2179d6037ba634240ce7ebcb76ef Mon Sep 17 00:00:00 2001 From: Uma Kadam <90750049+umak1106@users.noreply.github.com> Date: Mon, 12 Sep 2022 13:46:26 +0530 Subject: [PATCH] Type hints for lib module (#3729) * Type hints for mdamath.py * fix errors in mdamath.py * changed input from NDarray to arraylike * Added type annotations to init.py Added type annotations to init to avoid mypy from raising errors when other modules are type checked . * Allowing mypy to type check lib module * Update changes in init.py * Update __init__.py * type hints for pkdtree * Fix all errros in pkdtree.py * Chage npt.NDArray to np.ndarray * Update pkdtree.py * Update pkdtree.py * Update pkdtree.py * Update NeighborSearch.py * Update NeighborSearch.py * Update pkdtree.py Co-authored-by: Jonathan Barnoud --- mypy.ini | 8 ----- package/MDAnalysis/__init__.py | 3 +- package/MDAnalysis/lib/NeighborSearch.py | 16 +++++++-- package/MDAnalysis/lib/mdamath.py | 30 ++++++++++------- package/MDAnalysis/lib/pkdtree.py | 42 ++++++++++++++++-------- 5 files changed, 62 insertions(+), 37 deletions(-) diff --git a/mypy.ini b/mypy.ini index d5e2a652e78..d87115bc20e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -18,9 +18,6 @@ ignore_errors = True [mypy-MDAnalysis.core.*] ignore_errors = True -[mypy-MDAnalysis.lib.*] -ignore_errors = True - [mypy-MDAnalysis.selections.*] ignore_errors = True @@ -47,8 +44,3 @@ ignore_errors = True [mypy-MDAnalysis.version] ignore_errors = True - -[mypy-MDAnalysis.*] -ignore_errors = True - - diff --git a/package/MDAnalysis/__init__.py b/package/MDAnalysis/__init__.py index 902a1a467e6..b72908d765c 100644 --- a/package/MDAnalysis/__init__.py +++ b/package/MDAnalysis/__init__.py @@ -179,8 +179,7 @@ _CONVERTERS: Dict = {} # Registry of TopologyAttributes _TOPOLOGY_ATTRS: Dict = {} # {attrname: cls} -_TOPOLOGY_TRANSPLANTS: Dict = {} -# {name: [attrname, method, transplant class]} +_TOPOLOGY_TRANSPLANTS: Dict = {} # {name: [attrname, method, transplant class]} _TOPOLOGY_ATTRNAMES: Dict = {} # {lower case name w/o _ : name} diff --git a/package/MDAnalysis/lib/NeighborSearch.py b/package/MDAnalysis/lib/NeighborSearch.py index f581b72b352..572973afcd1 100644 --- a/package/MDAnalysis/lib/NeighborSearch.py +++ b/package/MDAnalysis/lib/NeighborSearch.py @@ -31,6 +31,9 @@ import numpy as np from MDAnalysis.lib.distances import capped_distance from MDAnalysis.lib.util import unique_int_1d +from MDAnalysis.core.groups import AtomGroup, SegmentGroup, ResidueGroup +import numpy.typing as npt +from typing import Optional, Union, List class AtomNeighborSearch(object): @@ -41,7 +44,8 @@ class AtomNeighborSearch(object): :class:`~MDAnalysis.lib.distances.capped_distance`. """ - def __init__(self, atom_group, box=None): + def __init__(self, atom_group: AtomGroup, + box: Optional[npt.ArrayLike] = None) -> None: """ Parameters @@ -58,7 +62,10 @@ def __init__(self, atom_group, box=None): self._u = atom_group.universe self._box = box - def search(self, atoms, radius, level='A'): + def search(self, atoms: AtomGroup, + radius: float, + level: str = 'A' + ) -> Optional[Union[AtomGroup, ResidueGroup, SegmentGroup]]: """ Return all atoms/residues/segments that are within *radius* of the atoms in *atoms*. @@ -102,7 +109,10 @@ def search(self, atoms, radius, level='A'): unique_idx = unique_int_1d(np.asarray(pairs[:, 1], dtype=np.intp)) return self._index2level(unique_idx, level) - def _index2level(self, indices, level): + def _index2level(self, + indices: List[int], + level: str + ) -> Union[AtomGroup, ResidueGroup, SegmentGroup]: """Convert list of atom_indices in a AtomGroup to either the Atoms or segments/residues containing these atoms. diff --git a/package/MDAnalysis/lib/mdamath.py b/package/MDAnalysis/lib/mdamath.py index ecf0577669c..fff63d73bd5 100644 --- a/package/MDAnalysis/lib/mdamath.py +++ b/package/MDAnalysis/lib/mdamath.py @@ -63,11 +63,13 @@ from . import util from ._cutil import (make_whole, find_fragments, _sarrus_det_single, _sarrus_det_multiple) +import numpy.typing as npt +from typing import Union # geometric functions -def norm(v): +def norm(v: npt.ArrayLike) -> float: r"""Calculate the norm of a vector v. .. math:: v = \sqrt{\mathbf{v}\cdot\mathbf{v}} @@ -90,7 +92,8 @@ def norm(v): return np.sqrt(np.dot(v, v)) -def normal(vec1, vec2): +# typing: numpy +def normal(vec1: npt.ArrayLike, vec2: npt.ArrayLike) -> np.ndarray: r"""Returns the unit vector normal to two vectors. .. math:: @@ -110,7 +113,8 @@ def normal(vec1, vec2): return normal / n -def pdot(a, b): +# typing: numpy +def pdot(a: npt.ArrayLike, b: npt.ArrayLike) -> np.ndarray: """Pairwise dot product. ``a`` must be the same shape as ``b``. @@ -127,7 +131,8 @@ def pdot(a, b): return np.einsum('ij,ij->i', a, b) -def pnorm(a): +# typing: numpy +def pnorm(a: npt.ArrayLike) -> np.ndarray: """Euclidean norm of each vector in a matrix Parameters @@ -141,7 +146,7 @@ def pnorm(a): return pdot(a, a)**0.5 -def angle(a, b): +def angle(a: npt.ArrayLike, b: npt.ArrayLike) -> float: """Returns the angle between two vectors in radians .. versionchanged:: 0.11.0 @@ -156,7 +161,7 @@ def angle(a, b): return np.arccos(x) -def stp(vec1, vec2, vec3): +def stp(vec1: npt.ArrayLike, vec2: npt.ArrayLike, vec3: npt.ArrayLike) -> float: r"""Takes the scalar triple product of three vectors. Returns the volume *V* of the parallel epiped spanned by the three @@ -172,7 +177,7 @@ def stp(vec1, vec2, vec3): return np.dot(vec3, np.cross(vec1, vec2)) -def dihedral(ab, bc, cd): +def dihedral(ab: npt.ArrayLike, bc: npt.ArrayLike, cd: npt.ArrayLike) -> float: r"""Returns the dihedral angle in radians between vectors connecting A,B,C,D. The dihedral measures the rotation around bc:: @@ -194,7 +199,8 @@ def dihedral(ab, bc, cd): return (x if stp(ab, bc, cd) <= 0.0 else -x) -def sarrus_det(matrix): +# typing: numpy +def sarrus_det(matrix: np.ndarray) -> Union[float, np.ndarray]: """Computes the determinant of a 3x3 matrix according to the `rule of Sarrus`_. @@ -239,7 +245,8 @@ def sarrus_det(matrix): return _sarrus_det_multiple(m.reshape((-1, 3, 3))).reshape(shape[:-2]) -def triclinic_box(x, y, z): +# typing: numpy +def triclinic_box(x: npt.ArrayLike, y: npt.ArrayLike, z: npt.ArrayLike) -> np.ndarray: """Convert the three triclinic box vectors to ``[lx, ly, lz, alpha, beta, gamma]``. @@ -301,7 +308,8 @@ def triclinic_box(x, y, z): return np.zeros(6, dtype=np.float32) -def triclinic_vectors(dimensions, dtype=np.float32): +# typing: numpy +def triclinic_vectors(dimensions: npt.ArrayLike, dtype: npt.DTypeLike = np.float32) -> np.ndarray: """Convert ``[lx, ly, lz, alpha, beta, gamma]`` to a triclinic matrix representation. @@ -399,7 +407,7 @@ def triclinic_vectors(dimensions, dtype=np.float32): return box_matrix -def box_volume(dimensions): +def box_volume(dimensions: npt.ArrayLike) -> float: """Return the volume of the unitcell described by `dimensions`. The volume is computed as the product of the box matrix trace, with the diff --git a/package/MDAnalysis/lib/pkdtree.py b/package/MDAnalysis/lib/pkdtree.py index 432674ea555..d41c675110a 100644 --- a/package/MDAnalysis/lib/pkdtree.py +++ b/package/MDAnalysis/lib/pkdtree.py @@ -37,6 +37,8 @@ from .util import unique_rows from MDAnalysis.lib.distances import apply_PBC +import numpy.typing as npt +from typing import Optional, ClassVar __all__ = [ 'PeriodicKDTree' @@ -61,7 +63,8 @@ class PeriodicKDTree(object): :func:`MDAnalysis.lib.distances.undo_augment` function. """ - def __init__(self, box=None, leafsize=10): + + def __init__(self, box: npt.ArrayLike = None, leafsize: int = 10) -> None: """ Parameters @@ -82,7 +85,7 @@ def __init__(self, box=None, leafsize=10): self.dim = 3 # 3D systems self.box = box self._built = False - self.cutoff = None + self.cutoff: Optional[float] = None @property def pbc(self): @@ -95,7 +98,7 @@ def pbc(self): """ return self.box is not None - def set_coords(self, coords, cutoff=None): + def set_coords(self, coords: npt.ArrayLike, cutoff: Optional[float] = None) -> None: """Constructs KDTree from the coordinates Wrapping of coordinates to the primary unit cell is enforced @@ -126,23 +129,24 @@ def set_coords(self, coords, cutoff=None): MDAnalysis.lib.distances.augment_coordinates """ - # If no cutoff distance is provided but PBC aware - if self.pbc and (cutoff is None): - raise RuntimeError('Provide a cutoff distance' - ' with tree.set_coords(...)') # set coords dtype to float32 # augment coordinates will work only with float32 coords = np.asarray(coords, dtype=np.float32) + # If no cutoff distance is provided but PBC aware if self.pbc: self.cutoff = cutoff + if cutoff is None: + raise RuntimeError('Provide a cutoff distance' + ' with tree.set_coords(...)') + # Bring the coordinates in the central cell self.coords = apply_PBC(coords, self.box) # generate duplicate images self.aug, self.mapping = augment_coordinates(self.coords, self.box, - self.cutoff) + cutoff) # Images + coords self.all_coords = np.concatenate([self.coords, self.aug]) self.ckdt = cKDTree(self.all_coords, leafsize=self.leafsize) @@ -155,7 +159,8 @@ def set_coords(self, coords, cutoff=None): self.ckdt = cKDTree(self.coords, self.leafsize) self._built = True - def search(self, centers, radius): + # typing: numpy + def search(self, centers: npt.ArrayLike, radius: float) -> np.ndarray: """Search all points within radius from centers and their periodic images. All the centers coordinates are wrapped around the central cell @@ -179,6 +184,9 @@ def search(self, centers, radius): # Sanity check if self.pbc: + if self.cutoff is None: + raise ValueError( + "Cutoff needs to be provided when working with PBC.") if self.cutoff < radius: raise RuntimeError('Set cutoff greater or equal to the radius.') # Bring all query points to the central cell @@ -202,17 +210,19 @@ def search(self, centers, radius): self._indices = np.asarray(unique_int_1d(self._indices)) return self._indices - def get_indices(self): + # typing: numpy + def get_indices(self) -> np.ndarray: """Return the neighbors from the last query. Returns ------ - indices : list + indices : NDArray neighbors for the last query points and search radius """ return self._indices - def search_pairs(self, radius): + # typing: numpy + def search_pairs(self, radius: float) -> np.ndarray: """Search all the pairs within a specified radius Parameters @@ -229,6 +239,9 @@ def search_pairs(self, radius): raise RuntimeError(' Unbuilt Tree. Run tree.set_coords(...)') if self.pbc: + if self.cutoff is None: + raise ValueError( + "Cutoff needs to be provided when working with PBC.") if self.cutoff < radius: raise RuntimeError('Set cutoff greater or equal to the radius.') @@ -245,7 +258,7 @@ def search_pairs(self, radius): pairs = unique_rows(pairs) return pairs - def search_tree(self, centers, radius): + def search_tree(self, centers: npt.ArrayLike, radius: float) -> np.ndarray: """ Searches all the pairs within `radius` between `centers` and ``coords`` @@ -285,6 +298,9 @@ class initialization # Sanity check if self.pbc: + if self.cutoff is None: + raise ValueError( + "Cutoff needs to be provided when working with PBC.") if self.cutoff < radius: raise RuntimeError('Set cutoff greater or equal to the radius.') # Bring all query points to the central cell