LBFGS batch size fix (#440)
* General cleanup

* Batchwise alpha/beta/rho computation

* Refactors the LBFGS step so it no longer takes an extra step after satisfying fmax

* Brings back update_graph; commented out of the main loop

* Adds shape comment

* Bugfix for nondeterministic fmax convergence

* Flattens energies per the OCP 2.0 changes

* Switches the convergence check to less-than (lt) logic

* Updates relaxation defaults

* Updates relaxation hyperparameter defaults

---------

Co-authored-by: Nima Shoghi <[email protected]>
Co-authored-by: Abhishek Das <[email protected]>
Co-authored-by: Muhammed Shuaibi <[email protected]>
Former-commit-id: 45779ead79cd383e9c405375c0278cd22afc6850
4 people authored May 6, 2024
1 parent 8a003db commit 8f20f8e
Showing 6 changed files with 39 additions and 24 deletions.
configs/ocp_example.yml (5 additions, 3 deletions)
@@ -88,6 +88,8 @@ task:
   num_relaxation_batches: 5
   # Max no. of steps to run relaxations for.
   relaxation_steps: 300
+  # Structure relaxation is terminated when this max force is achieved.
+  relaxation_fmax: 0.02
   # Whether to save out the positions.
   write_pos: True # True or False
   # Path to initial structures to run relaxations on. Same as the IS2RE set.
@@ -99,10 +101,10 @@ task:
   shard: 0 # int (optional)
   relax_opt:
     name: lbfgs
-    maxstep: 0.04
+    maxstep: 0.2
     memory: 50
-    damping: 1.0
-    alpha: 70.0
+    damping: 1.2
+    alpha: 80.0
     # Directory to save out trajectories (.traj files) in.
     traj_dir: path/to/traj/directory
     # Whether to save out the full trajectory or just the initial+final frames
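A note on the maxstep default raised here: maxstep caps the longest per-atom displacement taken in a single relaxation step, and the optimizer's determine_step helper (visible in the lbfgs_torch.py diff below) applies that cap per system across the batch. Below is a minimal pure-torch sketch of the capping idea; the function name is illustrative, and a Python loop stands in for the torch_scatter max-reduction used in the real optimizer.

import torch

def cap_steps(dr: torch.Tensor, batch_idx: torch.Tensor, maxstep: float = 0.2) -> torch.Tensor:
    # dr: (n_atoms, 3) proposed displacements; batch_idx: (n_atoms,) system ids.
    # Scale each system's displacements so its longest atomic step is <= maxstep.
    steplengths = dr.norm(dim=1)
    for sys_id in batch_idx.unique():
        mask = batch_idx == sys_id
        longest = steplengths[mask].max()
        if longest > maxstep:
            dr[mask] = dr[mask] * (maxstep / longest)
    return dr
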
ocpmodels/common/relaxation/ase_utils.py (1 addition, 1 deletion)
@@ -43,7 +43,7 @@ def batch_to_atoms(batch):
     positions = torch.split(batch.pos, natoms)
     tags = torch.split(batch.tags, natoms)
     cells = batch.cell
-    energies = batch.y.tolist()
+    energies = batch.y.view(-1).tolist()

     atoms_objects = []
     for idx in range(n_systems):
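Context for the one-line change above: after the OCP 2.0 output changes referenced in the commit message, batch.y can arrive shaped (n_systems, 1) rather than (n_systems,), and calling .tolist() on the 2-D tensor yields nested lists. A quick illustration with hypothetical values:

import torch

y = torch.tensor([[-1.23], [-4.56]])  # energies shaped (n_systems, 1)
y.tolist()           # [[-1.23], [-4.56]] -- one nested list per structure
y.view(-1).tolist()  # [-1.23, -4.56] -- one float per structure, as consumers expect
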
ocpmodels/common/relaxation/ml_relaxation.py (3 additions, 3 deletions)
@@ -59,10 +59,10 @@ def ml_relax(
         optimizer = LBFGS(
             batch,
             calc,
-            maxstep=relax_opt.get("maxstep", 0.04),
+            maxstep=relax_opt.get("maxstep", 0.2),
             memory=relax_opt["memory"],
-            damping=relax_opt.get("damping", 1.0),
-            alpha=relax_opt.get("alpha", 70.0),
+            damping=relax_opt.get("damping", 1.2),
+            alpha=relax_opt.get("alpha", 80.0),
             device=device,
             save_full_traj=save_full_traj,
             traj_dir=Path(traj_dir) if traj_dir is not None else None,
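These fallbacks only take effect when a key is absent from relax_opt; values set in the YAML still win. A trivial sketch of the dict.get behavior with a hypothetical config:

relax_opt = {"name": "lbfgs", "memory": 50}  # hypothetical: no maxstep/damping/alpha set
relax_opt.get("maxstep", 0.2)   # -> 0.2, the new default
relax_opt.get("damping", 1.2)   # -> 1.2
relax_opt.get("alpha", 80.0)    # -> 80.0
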
ocpmodels/common/relaxation/optimizers/lbfgs_torch.py (27 additions, 13 deletions)
@@ -88,7 +88,7 @@ def check_convergence(self, iteration, forces=None, energy=None):
         # (batch_size) -> (nAtoms)
         max_forces = max_forces_[self.batch.batch]

-        return max_forces.ge(self.fmax), energy, forces
+        return max_forces.lt(self.fmax), energy, forces

     def run(self, fmax, steps):
         self.fmax = fmax
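Note the semantic flip in this hunk: check_convergence previously returned a mask of structures still above fmax (an update mask, via ge); it now returns the converged mask directly (lt), which run() inverts to obtain the update mask. With a threshold of 0.02:

import torch

max_forces = torch.tensor([0.50, 0.01, 0.30])
max_forces.ge(0.02)  # old: tensor([ True, False,  True]) -- still needs updates
max_forces.lt(0.02)  # new: tensor([False,  True, False]) -- converged
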
@@ -109,9 +109,17 @@ def run(self, fmax, steps):

         iteration = 0
         converged = False
+        converged_mask = torch.zeros_like(
+            self.batch.atomic_numbers, device=self.device
+        ).bool()
         while iteration < steps and not converged:
-            update_mask, energy, forces = self.check_convergence(iteration)
-            converged = torch.all(torch.logical_not(update_mask))
+            _converged_mask, energy, forces = self.check_convergence(iteration)
+            # Models like GemNet-OC can have random noise in their predictions.
+            # Here we ensure atom positions are not being updated after already
+            # hitting the desired convergence criteria.
+            converged_mask = torch.logical_or(converged_mask, _converged_mask)
+            converged = torch.all(converged_mask)
+            update_mask = torch.logical_not(converged_mask)

             if self.trajectories is not None and (
                 self.save_full or converged or iteration == steps - 1 or iteration == 0
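This is the fix for the nondeterministic-fmax bug called out in the commit message: the convergence mask is accumulated with logical_or, so once a structure converges it stays frozen even if a later noisy prediction pushes its max force back above fmax. A minimal sketch with four hypothetical structures (the real mask is per-atom, broadcast from per-system max forces):

import torch

converged_mask = torch.zeros(4, dtype=torch.bool)
step1 = torch.tensor([False, True, False, False])  # structure 1 converges
step2 = torch.tensor([False, False, True, False])  # structure 1 reports unconverged (noise)
for step_mask in (step1, step2):
    converged_mask = torch.logical_or(converged_mask, step_mask)
converged_mask  # tensor([False,  True,  True, False]): structure 1 stays converged
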
@@ -145,6 +153,9 @@ def step(
         forces: torch.Tensor | None,
         update_mask: torch.Tensor,
     ) -> None:
+        def _batched_dot(x: torch.Tensor, y: torch.Tensor):
+            return scatter((x * y).sum(dim=-1), self.batch.batch, reduce="sum")
+
         def determine_step(dr):
             steplengths = torch.norm(dr, dim=1)
             longest_steps = scatter(steplengths, self.batch.batch, reduce="max")
@@ -163,29 +174,32 @@ def determine_step(dr):

         # Update s, y, rho
         if iteration > 0:
-            s0 = (r - self.r0).flatten()
+            s0 = r - self.r0
             self.s.append(s0)

-            y0 = -(forces - self.f0).flatten()
+            y0 = -(forces - self.f0)
             self.y.append(y0)

-            self.rho.append(1.0 / torch.dot(y0, s0))
+            self.rho.append(1.0 / _batched_dot(y0, s0))

         loopmax = min(self.memory, iteration)
-        alpha = forces.new_empty(loopmax)
-        q = -forces.flatten()
+        alpha = forces.new_empty(loopmax, self.batch.natoms.shape[0])
+        q = -forces

         for i in range(loopmax - 1, -1, -1):
-            alpha[i] = self.rho[i] * torch.dot(self.s[i], q)  # b
-            q -= alpha[i] * self.y[i]
+            alpha[i] = self.rho[i] * _batched_dot(self.s[i], q)  # b
+            q -= alpha[i][self.batch.batch, ..., None] * self.y[i]

         z = self.H0 * q
         for i in range(loopmax):
-            beta = self.rho[i] * torch.dot(self.y[i], z)
-            z += self.s[i] * (alpha[i] - beta)
+            beta = self.rho[i] * _batched_dot(self.y[i], z)
+            z += self.s[i] * (
+                alpha[i][self.batch.batch, ..., None]
+                - beta[self.batch.batch, ..., None]
+            )

         # descent direction
-        p = -z.reshape((-1, 3))
+        p = -z
         dr = determine_step(p)
         if torch.abs(dr).max() < 1e-7:
             # Same configuration again (maybe a restart):
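The crux of the batch-size fix is visible in this hunk: the old code flattened s, y, and q across the whole batch and took a single torch.dot, which mixed curvature information between systems whenever a batch held more than one structure. _batched_dot instead reduces per system, and the per-system scalars (alpha, beta, rho) are gathered back onto atom rows through self.batch.batch. A self-contained pure-torch sketch of the same idea; index_add_ stands in for torch_scatter's scatter(..., reduce="sum"), and all shapes here are illustrative:

import torch

def batched_dot(x: torch.Tensor, y: torch.Tensor, batch_idx: torch.Tensor, n_systems: int):
    # Per-system dot product: sum of x_i . y_i over each system's atoms.
    per_atom = (x * y).sum(dim=-1)  # (n_atoms,)
    return torch.zeros(n_systems).index_add_(0, batch_idx, per_atom)

x = torch.randn(5, 3)                      # 5 atoms spread over 2 systems
y = torch.randn(5, 3)
batch_idx = torch.tensor([0, 0, 0, 1, 1])  # atom -> system id
rho_inv = batched_dot(x, y, batch_idx, 2)  # shape (2,): one scalar per system
scale = rho_inv[batch_idx].unsqueeze(-1)   # shape (5, 1): gathered back to atoms,
                                           # as in alpha[i][self.batch.batch, ..., None]
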
ocpmodels/models/gemnet_oc/gemnet_oc.py (1 addition, 2 deletions)
@@ -861,8 +861,7 @@ def subselect_edges(
         empty_image = subgraph["num_neighbors"] == 0
         if torch.any(empty_image):
             raise ValueError(
-                f"An image has no neighbors: id={data.id[empty_image]}, "
-                f"sid={data.sid[empty_image]}, fid={data.fid[empty_image]}"
+                f"An image has no neighbors: sid={data.sid[empty_image]}"
             )
         return subgraph

ocpmodels/trainers/ocp_trainer.py (2 additions, 2 deletions)
@@ -568,8 +568,8 @@ def run_relaxations(self, split="val"):
             relaxed_batch = ml_relax(
                 batch=batch,
                 model=self,
-                steps=self.config["task"].get("relaxation_steps", 200),
-                fmax=self.config["task"].get("relaxation_fmax", 0.0),
+                steps=self.config["task"].get("relaxation_steps", 300),
+                fmax=self.config["task"].get("relaxation_fmax", 0.02),
                 relax_opt=self.config["task"]["relax_opt"],
                 save_full_traj=self.config["task"].get("save_full_traj", True),
                 device=self.device,
