Skip to content

Implement geomopt Constraints #313

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 55 additions & 4 deletions janus_core/calculations/geom_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
from __future__ import annotations

from collections.abc import Callable
import inspect
from pathlib import Path
from typing import Any
import warnings

from ase import Atoms, filters, units
from ase import Atoms, constraints, filters, units
from ase.filters import FrechetCellFilter
from ase.io import read
import ase.optimize
Expand Down Expand Up @@ -80,6 +81,10 @@ class GeomOpt(BaseCalculation):
Deprecated. Please use `filter_class`.
filter_kwargs
Keyword arguments to pass to filter_class. Default is {}.
constraint_func
Constraint function, or name of function from ase.constraints. Default is None.
constraint_kwargs
Keyword arguments to pass to constraint_func. Default is {}.
optimizer
Optimization function, or name of function from ase.optimize. Default is
`LBFGS`.
Expand Down Expand Up @@ -119,6 +124,8 @@ def __init__(
filter_class: Callable | str | None = FrechetCellFilter,
filter_func: Callable | str | None = None,
filter_kwargs: dict[str, Any] | None = None,
constraint_func: Callable | str | None = None,
constraint_kwargs: dict[str, Any] | None = None,
optimizer: Callable | str = LBFGS,
opt_kwargs: ASEOptArgs | None = None,
write_results: bool = False,
Expand Down Expand Up @@ -177,6 +184,11 @@ def __init__(
Deprecated. Please use `filter_class`.
filter_kwargs
Keyword arguments to pass to filter_class. Default is {}.
constraint_func
Constraint function, or name of function from ase.constraints. Default is
None.
constraint_kwargs
Keyword arguments to pass to constraint_func. Default is {}.
optimizer
Optimization function, or name of function from ase.optimize. Default is
`LBFGS`.
Expand All @@ -194,9 +206,21 @@ def __init__(
"filename" keyword is inferred from `file_prefix` if not given.
Default is {}.
"""
read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs = (
(
read_kwargs,
constraint_kwargs,
filter_kwargs,
opt_kwargs,
write_kwargs,
traj_kwargs,
) = list(
none_to_dict(
read_kwargs, filter_kwargs, opt_kwargs, write_kwargs, traj_kwargs
read_kwargs,
constraint_kwargs,
filter_kwargs,
opt_kwargs,
write_kwargs,
traj_kwargs,
)
)

Expand All @@ -211,6 +235,8 @@ def __init__(
self.angle_tolerance = angle_tolerance
self.filter_class = filter_class
self.filter_kwargs = filter_kwargs
self.constraint_func = constraint_func
self.constraint_kwargs = constraint_kwargs
self.optimizer = optimizer
self.opt_kwargs = opt_kwargs
self.write_results = write_results
Expand Down Expand Up @@ -301,12 +327,29 @@ def output_files(self) -> None:
"trajectory": self.traj_kwargs.get("filename"),
}

def _set_mandatory_constraint_kwargs(self) -> None:
"""
Inspect constraint class for mandatory arguments.

For now we are just looking for the "atoms" parameter of FixSymmetry
"""
parameters = inspect.signature(self.constraint_func.__init__).parameters
if "atoms" in parameters:
self.constraint_kwargs["atoms"] = self.struct

def set_optimizer(self) -> None:
"""Set optimizer for geometry optimization."""
self._set_functions()
if self.logger:
self.logger.info("Using optimizer: %s", self.optimizer.__name__)

if self.constraint_func is not None:
self._set_mandatory_constraint_kwargs()
self.struct.set_constraint(self.constraint_func(**self.constraint_kwargs))

if self.logger:
self.logger.info("Using constraint: %s", self.constraint_func.__name__)

if self.filter_class is not None:
if "scalar_pressure" in self.filter_kwargs:
self.filter_kwargs["scalar_pressure"] *= units.GPa
Expand All @@ -332,13 +375,21 @@ def set_optimizer(self) -> None:
self.dyn = self.optimizer(self.struct, **self.opt_kwargs)

def _set_functions(self) -> None:
"""Set optimizer and filter."""
"""Set optimizer, constraint and filter functions."""
if isinstance(self.optimizer, str):
try:
self.optimizer = getattr(ase.optimize, self.optimizer)
except AttributeError as e:
raise AttributeError(f"No such optimizer: {self.optimizer}") from e

if self.constraint_func is not None and isinstance(self.constraint_func, str):
try:
self.constraint_func = getattr(constraints, self.constraint_func)
except AttributeError as e:
raise AttributeError(
f"No such constraint: {self.constraint_func}"
) from e

if self.filter_class is not None and isinstance(self.filter_class, str):
try:
self.filter_class = getattr(filters, self.filter_class)
Expand Down
12 changes: 11 additions & 1 deletion janus_core/cli/geomopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,10 @@ def geomopt(
str | None,
Option(help="Deprecated. Please use --filter", rich_help_panel="Calculation"),
] = None,
constraint_func: Annotated[
str,
Option(help="Name of ASE constraint function to use."),
] = None,
pressure: Annotated[
float,
Option(
Expand Down Expand Up @@ -212,7 +216,11 @@ def geomopt(
Name of filter from ase.filters to wrap around atoms. If using
--opt-cell-lengths or --opt-cell-fully, defaults to `FrechetCellFilter`.
filter_func
Deprecated. Please use `filter_class`.
Deprecated. Please use `--filter_class`.
constraint_func
Name of constraint function from ase.constraints, to apply constraints
to atoms. Parameters should be included as a "constraint_kwargs" dict
within "minimize_kwargs". Default is None.
pressure
Scalar pressure when optimizing cell geometry, in GPa. Passed to the filter
function if either `opt_cell_lengths` or `opt_cell_fully` is True. Default is
Expand Down Expand Up @@ -293,6 +301,7 @@ def geomopt(
# Check optimized structure path not duplicated
if "filename" in write_kwargs:
raise ValueError("'filename' must be passed through the --out option")

if out:
write_kwargs["filename"] = out

Expand Down Expand Up @@ -339,6 +348,7 @@ def geomopt(
"symmetrize": symmetrize,
"symmetry_tolerance": symmetry_tolerance,
"file_prefix": file_prefix,
"constraint_func": constraint_func,
**opt_cell_fully_dict,
**minimize_kwargs,
"write_results": True,
Expand Down
Loading