# Source code for gnrs.optimize.lbfgs
"""
This module provides a wrapper around the LBFGS implementation from ASE.
This source code is licensed under the BSD-3-Clause license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
__author__ = ["Yi Yang", "Rithwik Tom"]
__email__ = "yiy5@andrew.cmu.edu"
__group__ = "https://www.noamarom.com/"
from ase.atoms import Atoms
from ase.optimize import LBFGS
from ase.filters import FrechetCellFilter
from ase.constraints import FixSymmetry
from gnrs.core.optimizer import GeometryOptimizerABC
[docs]
class LBFGSOptimizer(GeometryOptimizerABC):
    """
    Limited-memory BFGS optimization using ASE's LBFGS implementation.

    Attributes:
        fmax: Maximum force tolerance for convergence criterion
        steps: Maximum number of optimization steps to perform
        fix_sym: Whether to fix the symmetry of the structure
        cell_opt: Whether to optimize the cell parameters as well
        opt_name: Name of the optimizer for storing in the crystal info
        converged: Whether the optimization successfully converged
    """

    def __init__(self, *args) -> None:
        """
        Set up the optimizer from the task settings.

        The keys ``fmax``, ``steps``, ``fix_sym`` and ``cell_opt`` are popped
        out of ``self.tsk_set``; whatever remains in ``self.tsk_set`` is later
        forwarded to ASE's ``LBFGS`` constructor as keyword arguments.

        Args:
            *args: Positional arguments passed unchanged to the parent
                constructor (presumably the one that populates
                ``self.tsk_set`` -- confirm against GeometryOptimizerABC).

        Raises:
            KeyError: If any of the four required keys is missing from
                ``self.tsk_set``.
        """
        super().__init__(*args)
        self.opt_name = "lbfgs"
        # update() reads this flag, so give it a value even before
        # optimize() has been run.
        self.converged = False
        self.fmax = self.tsk_set.pop("fmax")
        self.steps = self.tsk_set.pop("steps")
        self.fix_sym = self.tsk_set.pop("fix_sym")
        self.cell_opt = self.tsk_set.pop("cell_opt")

    def optimize(self, xtal: Atoms) -> None:
        """
        Performs geometry optimization using LBFGS algorithm.

        The structure is modified in place. The outcome is stored in
        ``self.converged``; any ``Exception`` raised while the optimizer is
        running is treated as a failed (unconverged) optimization instead of
        being propagated.

        Args:
            xtal: ASE Atoms object
        """
        # Assign the calculator to the structure
        xtal.calc = self.energy_calc

        if self.fix_sym:
            # NOTE: set_constraint replaces any constraints already on xtal.
            xtal.set_constraint(FixSymmetry(xtal))

        # With cell optimization the filter is optimized instead of the bare
        # structure, so that the cell is relaxed along with the positions.
        target = FrechetCellFilter(xtal) if self.cell_opt else xtal
        dyn = LBFGS(target, master=True, logfile="lbfgs.log", **self.tsk_set)

        try:
            self.converged = dyn.run(fmax=self.fmax, steps=self.steps)
        except Exception:
            # Best effort: a failing run is recorded as "not converged".
            self.converged = False
        finally:
            # Remove the constraints. Done in ``finally`` so the structure is
            # also cleaned up when a non-Exception error (e.g. Ctrl-C)
            # interrupts the run.
            if self.fix_sym:
                del xtal.constraints

    def update(self, xtal: Atoms) -> None:
        """
        Update the optimizer with the new structure.

        After delegating to the parent ``update``, two entries are written to
        ``xtal.info``:

        * ``"<opt_name>_<energy_method>"``: the potential energy of the
          structure, or ``0`` if the energy evaluation fails.
        * ``"<opt_name>"``: ``"converged"`` or ``"unconverged"`` depending on
          ``self.converged``.

        Args:
            xtal: ASE Atoms object
        """
        super().update(xtal)

        energy_key = f"{self.opt_name}_{self.energy_method}"
        try:
            xtal.info[energy_key] = xtal.get_potential_energy()
        except Exception:
            # Best effort: keep going with a placeholder energy, but do not
            # swallow KeyboardInterrupt/SystemExit like a bare except would.
            xtal.info[energy_key] = 0

        xtal.info[self.opt_name] = "converged" if self.converged else "unconverged"