16 changes: 8 additions & 8 deletions deepmd/dpmodel/output_def.py
@@ -267,7 +267,7 @@ def __getitem__(
     def get_data(self) -> dict[str, OutputVariableDef]:
         return self.var_defs
 
-    def keys(self):  # noqa: ANN201
+    def keys(self):
         return self.var_defs.keys()
 
@@ -319,25 +319,25 @@ def get_data(
     ) -> dict[str, OutputVariableDef]:
         return self.var_defs
 
-    def keys(self):  # noqa: ANN201
+    def keys(self):
         return self.var_defs.keys()
 
-    def keys_outp(self):  # noqa: ANN201
+    def keys_outp(self):
         return self.def_outp.keys()
 
-    def keys_redu(self):  # noqa: ANN201
+    def keys_redu(self):
         return self.def_redu.keys()
 
-    def keys_derv_r(self):  # noqa: ANN201
+    def keys_derv_r(self):
         return self.def_derv_r.keys()
 
-    def keys_hess_r(self):  # noqa: ANN201
+    def keys_hess_r(self):
         return self.def_hess_r.keys()
 
-    def keys_derv_c(self):  # noqa: ANN201
+    def keys_derv_c(self):
         return self.def_derv_c.keys()
 
-    def keys_derv_c_redu(self):  # noqa: ANN201
+    def keys_derv_c_redu(self):
         return self.def_derv_c_redu.keys()
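
The keys()/keys_*() accessors above expose the output-variable definition dictionaries. A small usage sketch (not part of this diff; the "energy" definition and the exact constructor arguments are illustrative assumptions):

from deepmd.dpmodel.output_def import (
    FittingOutputDef,
    ModelOutputDef,
    OutputVariableDef,
)

# assume a single reducible, differentiable output named "energy"
fit_def = FittingOutputDef(
    [
        OutputVariableDef(
            "energy", [1], reducible=True, r_differentiable=True, c_differentiable=True
        )
    ]
)
model_def = ModelOutputDef(fit_def)
print(list(model_def.keys_outp()))  # ["energy"]
print(list(model_def.keys_redu()))  # reduced name, e.g. "energy_redu"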
58 changes: 29 additions & 29 deletions deepmd/dpmodel/utils/network.py
@@ -38,7 +38,7 @@
 )
 
 
-def sigmoid_t(x):  # noqa: ANN001, ANN201
+def sigmoid_t(x):
     """Sigmoid."""
     if array_api_compat.is_jax_array(x):
         from deepmd.jax.env import (
@@ -55,7 +55,7 @@ class Identity(NativeOP):
     def __init__(self) -> None:
         super().__init__()
 
-    def call(self, x):  # noqa: ANN001, ANN201
+    def call(self, x):
         """The Identity operation layer."""
         return x
 
@@ -260,7 +260,7 @@ def dim_out(self) -> int:
         return self.w.shape[1]
 
     @support_array_api(version="2022.12")
-    def call(self, x):  # noqa: ANN001, ANN201
+    def call(self, x):
         """Forward pass.
 
         Parameters
@@ -301,22 +301,22 @@ def get_activation_fn(activation_function: str) -> Callable[[np.ndarray], np.ndarray]:
     activation_function = activation_function.lower()
     if activation_function == "tanh":
 
-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             xp = array_api_compat.array_namespace(x)
             return xp.tanh(x)
 
         return fn
     elif activation_function == "relu":
 
-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             xp = array_api_compat.array_namespace(x)
             # https://stackoverflow.com/a/47936476/9567349
             return x * xp.astype(x > 0, x.dtype)
 
         return fn
     elif activation_function in ("gelu", "gelu_tf"):
 
-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
             return (
@@ -328,7 +328,7 @@ def fn(x):  # noqa: ANN001, ANN202
         return fn
     elif activation_function == "relu6":
 
-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
             return xp.where(
@@ -338,22 +338,22 @@ def fn(x):  # noqa: ANN001, ANN202
         return fn
     elif activation_function == "softplus":
 
-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             xp = array_api_compat.array_namespace(x)
             # generated by GitHub Copilot
             return xp.log(1 + xp.exp(x))
 
         return fn
     elif activation_function == "sigmoid":
 
-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             # generated by GitHub Copilot
             return sigmoid_t(x)
 
         return fn
     elif activation_function == "silu":
 
-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             # generated by GitHub Copilot
             return x * sigmoid_t(x)
 
@@ -362,13 +362,13 @@ def fn(x):  # noqa: ANN001, ANN202
         "custom_silu"
     ):
 
-        def sigmoid(x):  # noqa: ANN001, ANN202
+        def sigmoid(x):
            return 1 / (1 + np.exp(-x))
 
-        def silu(x):  # noqa: ANN001, ANN202
+        def silu(x):
            return x * sigmoid(x)
 
-        def silu_grad(x):  # noqa: ANN001, ANN202
+        def silu_grad(x):
            sig = sigmoid(x)
            return sig + x * sig * (1 - sig)
 
@@ -380,7 +380,7 @@ def silu_grad(x):  # noqa: ANN001, ANN202
         slope = float(silu_grad(threshold))
         const = float(silu(threshold))
 
-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             xp = array_api_compat.array_namespace(x)
             return xp.where(
                 x < threshold,
@@ -391,7 +391,7 @@ def fn(x):  # noqa: ANN001, ANN202
         return fn
     elif activation_function.lower() in ("none", "linear"):
 
-        def fn(x):  # noqa: ANN001, ANN202
+        def fn(x):
             return x
 
         return fn
@@ -535,7 +535,7 @@ def __getitem__(self, key: str) -> Any:
     def dim_out(self) -> int:
         return self.w.shape[0]
 
-    def call(self, x):  # noqa: ANN001, ANN201
+    def call(self, x):
         """Forward pass.
 
         Parameters
@@ -552,11 +552,11 @@ def call(self, x):  # noqa: ANN001, ANN201
         return y
 
     @staticmethod
-    def layer_norm_numpy(  # noqa: ANN205
-        x,  # noqa: ANN001
+    def layer_norm_numpy(
+        x,
         shape: tuple[int, ...],
-        weight=None,  # noqa: ANN001
-        bias=None,  # noqa: ANN001
+        weight=None,
+        bias=None,
         eps: float = 1e-5,
     ):
         xp = array_api_compat.array_namespace(x)
@@ -633,7 +633,7 @@ def check_shape_consistency(self) -> None:
                     f"output {self.layers[ii].dim_out}",
                 )
 
-    def call(self, x):  # noqa: ANN001, ANN202
+    def call(self, x):
         """Forward pass.
 
         Parameters
@@ -650,7 +650,7 @@ def call(self, x):  # noqa: ANN001, ANN202
             x = layer(x)
         return x
 
-    def call_until_last(self, x):  # noqa: ANN001, ANN202
+    def call_until_last(self, x):
         """Return the output before last layer.
 
         Parameters
@@ -1025,9 +1025,9 @@ def deserialize(cls, data: dict) -> "NetworkCollection":
         return cls(**data)
 
 
-def aggregate(  # noqa: ANN201
-    data,  # noqa: ANN001
-    owners,  # noqa: ANN001
+def aggregate(
+    data,
+    owners,
     average: bool = True,
     num_owner: Optional[int] = None,
 ):
@@ -1065,10 +1065,10 @@ def aggregate(  # noqa: ANN201
     return output
 
 
-def get_graph_index(  # noqa: ANN201
-    nlist,  # noqa: ANN001
-    nlist_mask,  # noqa: ANN001
-    a_nlist_mask,  # noqa: ANN001
+def get_graph_index(
+    nlist,
+    nlist_mask,
+    a_nlist_mask,
     nall: int,
     use_loc_mapping: bool = True,
 ):
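
The get_activation_fn factory above resolves names case-insensitively and returns a callable that dispatches on the array namespace of its input, so the same function works for NumPy, JAX, and other array-API arrays. A minimal usage sketch (not part of this diff; the sample input is an illustrative assumption):

import numpy as np

from deepmd.dpmodel.utils.network import get_activation_fn

relu = get_activation_fn("relu")
x = np.array([-1.0, 0.0, 2.0])
# the relu branch above computes x * astype(x > 0, x.dtype)
print(relu(x))  # expected: [0. 0. 2.]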
28 changes: 16 additions & 12 deletions deepmd/pd/entrypoints/main.py
@@ -7,6 +7,7 @@
     Path,
 )
 from typing import (
+    Any,
     Optional,
     Union,
 )
@@ -80,15 +81,15 @@
 
 
 def get_trainer(
-    config,
-    init_model=None,
-    restart_model=None,
-    finetune_model=None,
-    force_load=False,
-    init_frz_model=None,
-    shared_links=None,
-    finetune_links=None,
-):
+    config: dict[str, Any],
+    init_model: Optional[str] = None,
+    restart_model: Optional[str] = None,
+    finetune_model: Optional[str] = None,
+    force_load: bool = False,
+    init_frz_model: Optional[str] = None,
+    shared_links: Optional[dict[str, Any]] = None,
+    finetune_links: Optional[dict[str, Any]] = None,
+) -> training.Trainer:
     multi_task = "model_dict" in config.get("model", {})
 
     # Initialize DDP
@@ -98,8 +99,11 @@ def get_trainer(
         fleet.init(is_collective=True)
 
     def prepare_trainer_input_single(
-        model_params_single, data_dict_single, rank=0, seed=None
-    ):
+        model_params_single: dict[str, Any],
+        data_dict_single: dict[str, Any],
+        rank: int = 0,
+        seed: Optional[int] = None,
+    ) -> tuple[Any, Any, Any, Optional[Any]]:
         training_dataset_params = data_dict_single["training_data"]
         validation_dataset_params = data_dict_single.get("validation_data", None)
         validation_systems = (
@@ -535,7 +539,7 @@ def change_bias(
     log.info(f"Saved model to {output_path}")
 
 
-def main(args: Optional[Union[list[str], argparse.Namespace]] = None):
+def main(args: Optional[Union[list[str], argparse.Namespace]] = None) -> None:
     if not isinstance(args, argparse.Namespace):
         FLAGS = parse_args(args=args)
     else:
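
The annotated get_trainer signature keeps every optional argument defaulted, so existing call sites are unaffected. A hedged sketch of a plain invocation (the input file name is a hypothetical placeholder, and run() is assumed to match the Trainer interface of the other backends):

import json

from deepmd.pd.entrypoints.main import get_trainer

with open("input.json") as f:  # hypothetical DeePMD-kit training config
    config = json.load(f)

# config is dict[str, Any]; all other parameters are Optional with defaults
trainer = get_trainer(config)
trainer.run()  # assumed entry point, mirroring the PyTorch backend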
9 changes: 6 additions & 3 deletions deepmd/pd/infer/inference.py
@@ -3,6 +3,9 @@
 from copy import (
     deepcopy,
 )
+from typing import (
+    Optional,
+)
 
 import paddle
 
@@ -23,9 +26,9 @@
 class Tester:
     def __init__(
         self,
-        model_ckpt,
-        head=None,
-    ):
+        model_ckpt: str,
+        head: Optional[str] = None,
+    ) -> None:
         """Construct a DeePMD tester.
 
         Args:
7 changes: 5 additions & 2 deletions deepmd/pd/loss/loss.py
@@ -3,6 +3,9 @@
     ABC,
     abstractmethod,
 )
+from typing import (
+    NoReturn,
+)
 
 import paddle
 
@@ -15,11 +18,11 @@
 
 
 class TaskLoss(paddle.nn.Layer, ABC, make_plugin_registry("loss")):
-    def __init__(self, **kwargs):
+    def __init__(self) -> None:
         """Construct loss."""
         super().__init__()
 
-    def forward(self, input_dict, model, label, natoms, learning_rate):
+    def forward(self, input_dict: dict[str, paddle.Tensor], model: paddle.nn.Layer, label: dict[str, paddle.Tensor], natoms: int, learning_rate: float) -> NoReturn:
         """Return loss ."""
         raise NotImplementedError
 
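
Because the base forward is now annotated -> NoReturn and simply raises NotImplementedError, concrete losses are expected to override it. A minimal hypothetical subclass, for illustration only (the three-tuple return mirrors the energy losses in this backend; that convention is an assumption, not part of this diff):

import paddle

from deepmd.pd.loss.loss import TaskLoss


class ToyEnergyLoss(TaskLoss):
    """Hypothetical MSE loss on a single 'energy' label."""

    def forward(self, input_dict, model, label, natoms, learning_rate):
        model_pred = model(**input_dict)
        diff = model_pred["energy"] - label["energy"]
        loss = paddle.mean(diff * diff)
        # (model_pred, loss, more_loss) is the assumed logging convention
        return model_pred, loss, {"rmse_e": paddle.sqrt(loss)}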
12 changes: 11 additions & 1 deletion deepmd/pd/model/atomic_model/energy_atomic_model.py
@@ -1,4 +1,8 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
+from typing import (
+    Union,
+)
+
 from deepmd.pd.model.task.ener import (
     EnergyFittingNet,
     InvarFitting,
@@ -10,7 +14,13 @@
 
 
 class DPEnergyAtomicModel(DPAtomicModel):
-    def __init__(self, descriptor, fitting, type_map, **kwargs):
+    def __init__(
+        self,
+        descriptor: object,
+        fitting: Union[EnergyFittingNet, InvarFitting],
+        type_map: list[str],
+        **kwargs: object,
+    ) -> None:
         assert isinstance(fitting, EnergyFittingNet) or isinstance(
             fitting, InvarFitting
         )
4 changes: 2 additions & 2 deletions deepmd/pd/model/model/dp_model.py
@@ -47,11 +47,11 @@ def update_sel(
         )
         return local_jdata_cpy, min_nbor_dist
 
-    def get_fitting_net(self):
+    def get_fitting_net(self) -> object:
         """Get the fitting network."""
         return self.atomic_model.fitting_net
 
-    def get_descriptor(self):
+    def get_descriptor(self) -> object:
         """Get the descriptor."""
         return self.atomic_model.descriptor
 
4 changes: 2 additions & 2 deletions deepmd/pd/model/task/ener.py
@@ -48,8 +48,8 @@ def __init__(
         mixed_types: bool = True,
         seed: Optional[Union[int, list[int]]] = None,
         type_map: Optional[list[str]] = None,
-        **kwargs,
-    ):
+        **kwargs: object,
+    ) -> None:
         super().__init__(
             "energy",
             ntypes,