"""
Abstract base class for geometry optimization.
This module provides the base class for implementing geometry optimization.
This source code is licensed under the BSD-3-Clause license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
# Module authorship / contact metadata.
__author__ = ["Yi Yang", "Rithwik Tom"]
__email__ = "yiy5@andrew.cmu.edu"
__group__ = "https://www.noamarom.com/"
import abc
import logging
from collections import deque
from typing import Any, Callable, Optional

import numpy as np
from mpi4py import MPI
from ase.atoms import Atoms

from gnrs.core.gpu import GPUDeviceManager
# Logger for this module, registered under the fixed name "optimizer".
logger = logging.getLogger("optimizer")
# MPI tags specific to optimization worker/feeder (offset from energy tags)
# Feeder -> worker: a (name, structure) pair to optimize.
TAG_OPT_DATA = 200
# Worker -> feeder: the reply for one previously sent structure.
TAG_OPT_RESULT = 201
# Feeder -> worker: this feeder has no more structures to send.
TAG_OPT_SHUTDOWN = 202
[docs]
class GeometryOptimizerABC(abc.ABC):
    """
    Abstract base class for geometry optimization methods.

    Supports three execution modes:

    1. Direct mode (ranks <= GPUs or CPU-only calculators):
       Every rank runs the optimizer locally.
    2. Worker/feeder mode (ranks > GPUs, GPU-based calculators):
       Worker ranks run optimizations on GPU. Feeder ranks send
       structures to their assigned worker and receive optimized
       results back.
    3. Serial DFT mode:
       Only rank 0 runs optimizations. Results are broadcast to all
       ranks. Use when the DFT binary needs the full allocation.

    In every mode, a structure whose optimization raises ``ValueError`` or
    ``RuntimeError`` is logged with a warning and removed from the batch
    dictionary on the rank that owns it. Any other exception propagates.

    All optimizers should inherit this class and implement ``optimize``
    and ``update``.
    """

    # Exception types treated as "this structure failed, keep going".
    _recoverable_errors = (ValueError, RuntimeError)

    def __init__(
        self,
        comm: MPI.Comm,
        task_set: dict,
        opt_name: str = "relax",
        energy_method: str | None = None,
        energy_calc: Any | None = None,
        gpu_mgr: GPUDeviceManager | None = None,
        dft_serial_mode: bool = False,
    ) -> None:
        """
        Initialize the geometry optimizer.

        Args:
            comm: MPI communicator for parallel computation
            task_set: Optimization settings
            opt_name: Optimizer name; also the ``xtal.info`` key whose
                presence marks a structure as already optimized in
                worker/feeder mode
            energy_method: Energy calculation method
            energy_calc: ASE calculator
            gpu_mgr: GPU device manager; worker/feeder mode is enabled when
                it is given and reports at least one feeder rank
            dft_serial_mode: If True, only rank 0 runs optimizations and
                results are broadcast to all ranks.
        """
        self.opt_name = opt_name
        self.comm = comm
        self.rank = comm.Get_rank()
        self.size = comm.Get_size()
        self.is_master = self.rank == 0
        self.tsk_set = task_set
        self.energy_method = energy_method
        self.energy_calc = energy_calc
        self.converged = False
        self._gpu_mgr = gpu_mgr
        self._use_worker_feeder = (
            gpu_mgr is not None and gpu_mgr.num_feeders > 0
        )
        self._dft_serial_mode = dft_serial_mode

    def run(self, xtal: Atoms) -> None:
        """
        Run the optimization workflow on a single structure.

        1. Initialize
        2. Perform optimization
        3. Update structure information
        4. Finalize

        Args:
            xtal: ASE Atoms object representing the crystal structure
        """
        self.initialize()
        self.optimize(xtal)
        self.update(xtal)
        self.finalize(xtal)

    def run_batch(
        self,
        structs: dict[str, Atoms],
        on_structure_done: Optional[Callable[[], None]] = None,
    ) -> None:
        """
        Run optimization on a batch of structures.

        Dispatches to the execution mode chosen at construction time and
        ends with a barrier on the communicator, so it must be called by
        every rank. Structures that fail with ``ValueError`` or
        ``RuntimeError`` are removed from ``structs``.

        Args:
            structs: structure dictionary (name -> Atoms), updated in place
            on_structure_done: called after each successful optimization;
                used for checkpoint saves
        """
        if self._dft_serial_mode:
            self._serial_dft_batch(structs, on_structure_done)
        elif not self._use_worker_feeder:
            self._direct_batch(structs, on_structure_done)
        elif self._gpu_mgr.is_worker:
            self._worker_loop(structs, on_structure_done)
        else:
            self._feeder_loop(structs, on_structure_done)
        self.comm.Barrier()

    def _try_run(self, name: str, xtal: Atoms) -> bool:
        """
        Run the workflow on one structure, tolerating recoverable failures.

        Returns:
            True on success; False if ``run`` raised ``ValueError`` or
            ``RuntimeError`` (the failure is logged as a warning).
        """
        try:
            self.run(xtal)
        except self._recoverable_errors as err:
            logger.warning("Optimization failed for %s: %s", name, err)
            return False
        return True

    @staticmethod
    def _snapshot(xtal: Atoms) -> tuple[dict, np.ndarray, np.ndarray]:
        """
        Copy the optimized state (info, positions, cell) out of a structure.
        """
        return (
            xtal.info.copy(),
            np.array(xtal.positions),
            np.array(xtal.cell),
        )

    @staticmethod
    def _apply_snapshot(
        xtal: Atoms,
        info: dict,
        positions: np.ndarray,
        cell: np.ndarray,
    ) -> None:
        """
        Write a state produced by ``_snapshot`` back onto a structure.
        """
        xtal.info.update(info)
        xtal.positions = positions
        xtal.cell = cell

    def _direct_batch(
        self,
        structs: dict[str, Atoms],
        on_structure_done: Optional[Callable[[], None]],
    ) -> None:
        """
        Direct mode: optimize every local structure on this rank.
        """
        # NOTE(review): unlike worker/feeder mode, this path does not skip
        # structures that already carry ``opt_name`` in ``info`` -- confirm
        # whether that is intended.
        failed = []
        # Iterate over a copy so the callback may touch ``structs`` safely.
        for name, xtal in list(structs.items()):
            if not self._try_run(name, xtal):
                failed.append(name)
                continue
            if on_structure_done is not None:
                on_structure_done()
        for name in failed:
            del structs[name]

    def _serial_dft_batch(
        self,
        structs: dict[str, Atoms],
        on_structure_done: Optional[Callable[[], None]],
    ) -> None:
        """
        Serial DFT mode: only rank 0 runs optimizations.

        Every rank sends its structures to rank 0, which optimizes them one
        by one and broadcasts the outcomes. Each rank then applies the
        results to its own structures and drops the ones that failed.
        """
        all_items = self.comm.gather(list(structs.items()), root=0)
        payload: tuple[dict, set] | None = None
        if self.is_master:
            flat = [(n, x) for rank_items in all_items for n, x in rank_items]
            logger.info(
                "dft_mode=serial: rank 0 optimizing %d structures",
                len(flat),
            )
            results: dict[str, tuple[dict, np.ndarray, np.ndarray]] = {}
            failed: set[str] = set()
            for name, xtal in flat:
                # A failure must not escape here: the other ranks are
                # already waiting in the broadcast below and would hang.
                if not self._try_run(name, xtal):
                    failed.add(name)
                    continue
                results[name] = self._snapshot(xtal)
                if on_structure_done is not None:
                    on_structure_done()
            payload = (results, failed)
        results, failed = self.comm.bcast(payload, root=0)
        for name in list(structs):
            if name in failed:
                del structs[name]
            elif name in results:
                self._apply_snapshot(structs[name], *results[name])

    def _worker_loop(
        self,
        local_structs: dict[str, Atoms],
        on_structure_done: Optional[Callable[[], None]],
    ) -> None:
        """
        GPU worker: interleave local computation with feeder requests.

        The loop ends once the local queue is empty and every feeder
        assigned to this worker has sent its shutdown message.
        """
        gpu = self._gpu_mgr
        # Feeders are assigned to workers round-robin by feeder index.
        my_feeders = {
            feeder_rank
            for feeder_rank in gpu.feeder_ranks
            if (feeder_rank - gpu.num_workers) % gpu.num_workers == self.rank
        }
        # Only structures without ``opt_name`` in ``info`` still need work.
        local_queue: deque[tuple[str, Atoms]] = deque(
            (name, xtal)
            for name, xtal in local_structs.items()
            if self.opt_name not in xtal.info
        )
        while local_queue or my_feeders:
            served = self._drain_feeder_requests(my_feeders)
            if local_queue:
                name, xtal = local_queue.popleft()
                if self._try_run(name, xtal):
                    if on_structure_done is not None:
                        on_structure_done()
                else:
                    local_structs.pop(name, None)
                continue
            if my_feeders and not served:
                # Nothing local left and nothing pending: block until a
                # feeder sends work or a shutdown.
                status = MPI.Status()
                data = self.comm.recv(
                    source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status,
                )
                self._handle_worker_msg(
                    data, status.Get_source(), status.Get_tag(), my_feeders,
                )

    def _drain_feeder_requests(self, active_feeders: set[int]) -> bool:
        """
        Non-blocking: serve all pending feeder messages.

        Args:
            active_feeders: feeder ranks not yet shut down; updated in place

        Returns:
            True if at least one message was processed
        """
        served_any = False
        while True:
            status = MPI.Status()
            has_msg = self.comm.iprobe(
                source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status,
            )
            if not has_msg:
                break
            # Receive exactly the message that was probed.
            data = self.comm.recv(
                source=status.Get_source(), tag=status.Get_tag(),
            )
            self._handle_worker_msg(
                data, status.Get_source(), status.Get_tag(), active_feeders,
            )
            served_any = True
        return served_any

    def _handle_worker_msg(
        self,
        data: object,
        source: int,
        tag: int,
        active_feeders: set[int],
    ) -> None:
        """
        Process a single message received by a worker from a feeder.

        A shutdown message removes the sender from ``active_feeders``. A data
        message is optimized and answered with
        ``(name, info, positions, cell)``; if the optimization fails with a
        recoverable error the reply is ``(name, None, None, None)`` so the
        feeder is never left waiting. Any other tag is logged and ignored.
        """
        if tag == TAG_OPT_SHUTDOWN:
            active_feeders.discard(source)
            return
        if tag != TAG_OPT_DATA:
            logger.warning(
                "Worker rank %d ignoring message with unexpected tag %d "
                "from rank %d",
                self.rank, tag, source,
            )
            return
        name, xtal = data
        if self._try_run(name, xtal):
            result = (name, *self._snapshot(xtal))
        else:
            result = (name, None, None, None)
        self.comm.send(result, dest=source, tag=TAG_OPT_RESULT)

    def _feeder_loop(
        self,
        local_structs: dict[str, Atoms],
        on_structure_done: Optional[Callable[[], None]],
    ) -> None:
        """
        CPU feeder: delegate optimization to assigned GPU worker.

        Structures are sent one at a time and the feeder waits for each
        reply. Structures reported as failed by the worker are removed from
        ``local_structs``. A shutdown message is sent when done.
        """
        worker = self._gpu_mgr.assigned_worker()
        failed = []
        for name, xtal in list(local_structs.items()):
            if self.opt_name in xtal.info:
                continue
            self.comm.send((name, xtal), dest=worker, tag=TAG_OPT_DATA)
            _, info, positions, cell = self.comm.recv(
                source=worker, tag=TAG_OPT_RESULT,
            )
            if info is None:
                # The worker signals a failed optimization with empty fields.
                logger.warning(
                    "Optimization failed for %s on worker rank %d",
                    name, worker,
                )
                failed.append(name)
                continue
            self._apply_snapshot(xtal, info, positions, cell)
            if on_structure_done is not None:
                on_structure_done()
        self.comm.send(None, dest=worker, tag=TAG_OPT_SHUTDOWN)
        for name in failed:
            del local_structs[name]

    def initialize(self) -> None:
        """
        Initialize for optimization.

        Called at the start of every ``run``; the default does nothing.
        """

    @abc.abstractmethod
    def optimize(self, xtal: Atoms) -> None:
        """
        Perform optimization.

        Args:
            xtal: ASE Atoms object
        """

    @abc.abstractmethod
    def update(self, xtal: Atoms) -> None:
        """
        Update the geometry and add energy information.

        Args:
            xtal: ASE Atoms object
        """

    def finalize(self, xtal: Atoms) -> None:
        """
        Finalize the optimization and clean up.

        The default detaches the calculator from the structure.

        Args:
            xtal: ASE Atoms object
        """
        xtal.calc = None