pd: add flag CINN_ALLOW_DYNAMIC_SHAPE for better performance with dynamic shape #4826

Merged
4 changes: 2 additions & 2 deletions deepmd/main.py
@@ -112,7 +112,7 @@ def main_parser() -> argparse.ArgumentParser:
     if default_backend not in BACKEND_TABLE.keys():
         raise ValueError(
             f"Unknown backend {default_backend}. "
-            "Please set DP_BACKEND to either tensorflow or pytorch."
+            "Please set DP_BACKEND to either tensorflow, pytorch, or paddle."
         )

     parser_backend = parser.add_mutually_exclusive_group()
@@ -312,7 +312,7 @@ def main_parser() -> argparse.ArgumentParser:
"--output",
type=str,
default="frozen_model",
help="Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth",
help="Filename (prefix) of the output model file. TensorFlow backend: suffix is .pb; PyTorch backend: suffix is .pth; Paddle backend: suffix is .json and .pdiparams",
)
parser_frz.add_argument(
"-n",
97 changes: 57 additions & 40 deletions deepmd/pd/train/training.py
@@ -54,6 +54,7 @@
 )
 from deepmd.pd.utils.env import (
     CINN,
+    CINN_ALLOW_DYNAMIC_SHAPE,
     DEFAULT_PRECISION,
     DEVICE,
     JIT,
@@ -609,49 +610,65 @@
         )

         backend = "CINN" if CINN else None
-        # NOTE: This is a trick to decide the right input_spec for wrapper.forward
-        _, label_dict, _ = self.get_data(is_train=True)
-
-        # Define specification templates
-        spec_templates = {
-            "find_box": np.float32(1.0),
-            "find_coord": np.float32(1.0),
-            "find_numb_copy": np.float32(0.0),
-            "numb_copy": static.InputSpec([1, 1], "int64", name="numb_copy"),
-            "find_energy": np.float32(1.0),
-            "energy": static.InputSpec([1, 1], "float64", name="energy"),
-            "find_force": np.float32(1.0),
-            "force": static.InputSpec([1, -1, 3], "float64", name="force"),
-            "find_virial": np.float32(0.0),
-            "virial": static.InputSpec([1, 9], "float64", name="virial"),
-            "natoms": static.InputSpec([1, -1], "int32", name="natoms"),
-        }
-        # Build spec only for keys present in sample data
-        label_dict_spec = {
-            k: spec_templates[k] for k in label_dict.keys() if k in spec_templates
-        }
-        self.wrapper.forward = jit.to_static(
-            backend=backend,
-            input_spec=[
-                static.InputSpec([1, -1, 3], "float64", name="coord"),  # coord
-                static.InputSpec([1, -1], "int32", name="atype"),  # atype
-                None,  # spin
-                static.InputSpec([1, 9], "float64", name="box"),  # box
-                static.InputSpec([], "float64", name="cur_lr"),  # cur_lr
-                label_dict_spec,  # label,
-                # None,  # task_key
-                # False,  # inference_only
-                # False,  # do_atomic_virial
-                # None,  # fparam
-                # None,  # aparam
-            ],
-            full_graph=True,
-        )(self.wrapper.forward)
+        if CINN_ALLOW_DYNAMIC_SHAPE:
+            # Build spec only for keys present in sample data
+            # NOTE: This is a trick to decide the right input_spec for wrapper.forward
+            _, label_dict, _ = self.get_data(is_train=True)
+            # Define specification templates
+            spec_templates = {
+                "find_box": np.float32(1.0),
+                "find_coord": np.float32(1.0),
+                "find_numb_copy": np.float32(0.0),
+                "numb_copy": static.InputSpec([1, 1], "int64", name="numb_copy"),
+                "find_energy": np.float32(1.0),
+                "energy": static.InputSpec([1, 1], "float64", name="energy"),
+                "find_force": np.float32(1.0),
+                "force": static.InputSpec([1, -1, 3], "float64", name="force"),
+                "find_virial": np.float32(0.0),
+                "virial": static.InputSpec([1, 9], "float64", name="virial"),
+                "natoms": static.InputSpec([1, -1], "int32", name="natoms"),
+            }
+            label_dict_spec = {
+                k: spec_templates[k]
+                for k in label_dict.keys()
+                if k in spec_templates
+            }
+            self.wrapper.forward = jit.to_static(
+                backend=backend,
+                input_spec=[
+                    static.InputSpec([1, -1, 3], "float64", name="coord"),  # coord
+                    static.InputSpec([1, -1], "int32", name="atype"),  # atype
+                    None,  # spin
+                    static.InputSpec([1, 9], "float64", name="box"),  # box
+                    static.InputSpec([], "float64", name="cur_lr"),  # cur_lr
+                    label_dict_spec,  # label,
+                    # None,  # task_key
+                    # False,  # inference_only
+                    # False,  # do_atomic_virial
+                    # None,  # fparam
+                    # None,  # aparam
+                ],
+                full_graph=True,
+            )(self.wrapper.forward)
+        else:
+            self.wrapper.forward = jit.to_static(full_graph=True, backend=backend)(
+                self.wrapper.forward
+            )

         log.info(
-            "Enable CINN during training, there may be some additional "
-            "compilation time in the first traning step."
+            "[CINN] Enable CINN during training, there may be some additional "
+            "compilation time in the first training step."
         )
+        if not CINN_ALLOW_DYNAMIC_SHAPE:
+            log.info(
+                "[CINN] Dynamic shape is disabled (CINN_ALLOW_DYNAMIC_SHAPE=0). "
+                "Make sure the input batch shapes are fixed during training. "
+                "This is recommended for optimal performance, e.g., as in examples/water."
+            )
+            log.info(
+                "[CINN] If batch data from your dataset(s) has varying input shapes, consider setting "
+                "CINN_ALLOW_DYNAMIC_SHAPE=1 to enable dynamic shape support."
+            )

         if dist.is_available() and dist.is_initialized():
             # DDP will guarantee the model parameters are identical across all processes
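The key mechanism in the hunk above is Paddle's `InputSpec`: a `-1` dimension declares that axis symbolic, so CINN compiles one graph that serves any number of atoms, whereas the `else` branch omits `input_spec` and lets the shapes of the first traced batch pin the graph. Below is a minimal, self-contained sketch of that pattern — illustrative only, not deepmd code; `ToyModel` and its shapes are invented, and `backend="CINN"` is left out so it runs without a CINN-enabled build:

```python
import paddle
from paddle.static import InputSpec


class ToyModel(paddle.nn.Layer):
    # Stand-in for the training wrapper: coord has shape [1, natoms, 3].
    def forward(self, coord):
        # Reduce over the (possibly varying) atom axis -> shape [1].
        return coord.square().sum(axis=[1, 2])


model = ToyModel()
# -1 marks the atom axis as dynamic: one compiled graph serves any natoms.
# Passing a concrete size here instead would pin the graph to that shape.
model.forward = paddle.jit.to_static(
    input_spec=[InputSpec([1, -1, 3], "float64", name="coord")],
    full_graph=True,
)(model.forward)

print(model(paddle.rand([1, 8, 3], dtype="float64")))   # triggers compilation
print(model(paddle.rand([1, 16, 3], dtype="float64")))  # reuses the same graph
```

With `CINN_ALLOW_DYNAMIC_SHAPE=0` the diff instead applies `jit.to_static` with no `input_spec` at all, trading that flexibility for the tighter optimization CINN can apply to fully static shapes.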
9 changes: 9 additions & 0 deletions deepmd/pd/utils/env.py
@@ -69,6 +69,14 @@ def to_bool(flag: int | bool | str) -> bool:
"installation or recompiling with CINN enabled."
)

# NOTE: Allow the CINN compiler to optimize inputs with dynamic shapes,
# may lead to a slight performance decrease compared to static shapes.

# If you can confirm that the shape of the input tensors will not change,
# you can set it to False to further enhance performance.
# Otherwise, please use the default value(True) to improve runtime compatibility.
CINN_ALLOW_DYNAMIC_SHAPE = to_bool(os.environ.get("CINN_ALLOW_DYNAMIC_SHAPE", True))

CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory
ENERGY_BIAS_TRAINABLE = True
CUSTOM_OP_USE_JIT = to_bool(os.environ.get("CUSTOM_OP_USE_JIT", False))
@@ -199,6 +207,7 @@ def enable_prim(enable: bool = True):
 __all__ = [
     "CACHE_PER_SYS",
     "CINN",
+    "CINN_ALLOW_DYNAMIC_SHAPE",
     "CUSTOM_OP_USE_JIT",
     "DEFAULT_PRECISION",
     "DEVICE",
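The body of `to_bool` is not part of this diff; the sketch below is a hypothetical stand-in (the real helper in `deepmd/pd/utils/env.py` may differ in detail) showing how the flag's default and the `CINN_ALLOW_DYNAMIC_SHAPE=0` opt-out are expected to behave:

```python
import os


def to_bool(flag):
    # Hypothetical re-implementation for illustration only.
    if isinstance(flag, bool):
        return flag
    if isinstance(flag, int):
        return flag != 0
    value = str(flag).strip().lower()
    if value in ("1", "true", "yes", "on"):
        return True
    if value in ("0", "false", "no", "off"):
        return False
    raise ValueError(f"Cannot interpret {flag!r} as a boolean")


# Unset -> default True: dynamic shapes allowed.
print(to_bool(os.environ.get("CINN_ALLOW_DYNAMIC_SHAPE", True)))   # True
os.environ["CINN_ALLOW_DYNAMIC_SHAPE"] = "0"
# Explicit opt-out for fixed-shape datasets such as examples/water.
print(to_bool(os.environ.get("CINN_ALLOW_DYNAMIC_SHAPE", True)))   # False
```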
3 changes: 3 additions & 0 deletions doc/train/training.md
@@ -34,7 +34,10 @@ $ dp --pd train input.json

 # [experimental] training model with CINN compiler for better performance,
 # see: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/paddle_v3_features/cinn_cn.html
+## If the shapes of batch input data are dynamic during training (default):
 $ CINN=1 dp --pd train input.json
+## If the shapes of batch input data are fixed during training (e.g., examples/water):
+$ CINN=1 CINN_ALLOW_DYNAMIC_SHAPE=0 dp --pd train input.json
 ```

 :::