"""
Abstract base class for energy calculators.
This module provides the base class for implementing energy calculators.
It supports a GPU worker/feeder pattern: when running with more MPI ranks
than GPUs, only a subset of ranks (workers) load models onto GPUs.
The remaining ranks (feeders) send structures to workers via MPI.
This source code is licensed under the BSD-3-Clause license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
__author__ = ["Yi Yang", "Rithwik Tom"]
__email__ = "yiy5@andrew.cmu.edu"
__group__ = "https://www.noamarom.com/"
import abc
from collections import deque
from typing import Any, Callable, Optional
from mpi4py import MPI
from ase import Atoms
from gnrs.core.gpu import GPUDeviceManager, TAG_WORK_DATA, TAG_WORK_RESULT, TAG_SHUTDOWN
class EnergyCalculatorABC(abc.ABC):
    """
    Abstract base class for energy calculators.

    Supports two execution modes:

    1. Direct mode (ranks <= GPUs or CPU-only calculators):
       Every rank loads the model and computes locally.
    2. Worker/feeder mode (ranks > GPUs, GPU-based calculators):
       Worker ranks (one per GPU) load models. Feeder ranks send structures
       to their assigned worker and receive results back.

    Subclasses implement ``initialize``, ``compute`` and ``finalize``.
    ``compute`` is expected to store its result in
    ``xtal.info[self.energy_name]``; the skip checks and the worker reply
    both read that key.
    """

    # Subclasses set this to True to have ``__init__`` create a
    # GPUDeviceManager and, when that manager reports feeder ranks, to
    # switch ``run_batch`` to the worker/feeder protocol.
    requires_gpu: bool = False

    def __init__(self, comm: MPI.Comm, task_settings: dict, energy_name: str) -> None:
        """
        Initialize the energy calculations.

        Args:
            comm: MPI communicator
            task_settings: Task settings. For GPU calculators the optional
                key ``"max_workers_per_gpu"`` (default 1) is forwarded to
                the GPU device manager.
            energy_name: Key under which the energy is stored in
                ``xtal.info``
        """
        self.comm = comm
        self.rank = comm.Get_rank()
        self.size = comm.Get_size()
        self.is_master = self.rank == 0
        self.tsk_set = task_settings
        self.energy_name = energy_name
        self.calc = None
        self._gpu_mgr = None
        self._use_worker_feeder = False

        if self.requires_gpu:
            max_workers = task_settings.get("max_workers_per_gpu", 1)
            self._gpu_mgr = GPUDeviceManager(
                comm, max_workers_per_gpu=max_workers,
            )
            # Worker/feeder mode is only needed when some ranks have no GPU
            # slot of their own.
            self._use_worker_feeder = self._gpu_mgr.num_feeders > 0

    @property
    def device(self) -> str:
        """
        Torch device string for this rank.

        Returns:
            The device reported by the GPU manager, or ``"cpu"`` when no
            GPU manager was created.
        """
        if self._gpu_mgr is not None:
            return self._gpu_mgr.device
        return "cpu"

    def run(self, xtal: Atoms) -> None:
        """
        Run the energy calculation on a single structure (direct mode only).

        Structures that already carry ``energy_name`` in ``xtal.info`` are
        left untouched.

        Args:
            xtal: Crystal structure
        """
        if self.energy_name in xtal.info:
            return
        self.initialize()
        self.compute(xtal)
        self.finalize()

    def run_batch(
        self,
        structs: dict[str, Atoms],
        on_structure_done: Optional[Callable[[], None]] = None,
    ) -> None:
        """
        Run energy calculations on a batch of structures.

        Every rank of the communicator must call this method: it ends with
        a barrier on ``self.comm``.

        Args:
            structs: structure dictionary
            on_structure_done: used for checkpoint saves. Called once after
                each structure that is actually computed; structures that
                already carry ``energy_name`` are skipped without a call,
                in every execution mode.
        """
        if not self._use_worker_feeder:
            for xtal in structs.values():
                # Skip here (and not only inside ``run``) so the callback
                # fires only for real work, matching the worker and feeder
                # loops.
                if self.energy_name in xtal.info:
                    continue
                self.run(xtal)
                if on_structure_done is not None:
                    on_structure_done()
        elif self._gpu_mgr.is_worker:
            self._worker_loop(structs, on_structure_done)
        else:
            self._feeder_loop(structs, on_structure_done)
        self.comm.Barrier()

    def _worker_loop(
        self,
        local_structs: dict[str, Atoms],
        on_structure_done: Optional[Callable[[], None]],
    ) -> None:
        """
        GPU worker: interleave local computation with feeder requests.

        The loop ends once the local queue is empty and every feeder
        assigned to this worker has sent its shutdown message.

        Args:
            local_structs: structures owned by this rank
            on_structure_done: optional callback, called after each locally
                computed structure (not for structures served to feeders)
        """
        # Feeders are spread round-robin over the workers.
        # NOTE(review): comparing against ``self.rank`` assumes the worker
        # ranks are 0..num_workers-1 and that this mapping agrees with
        # ``GPUDeviceManager.assigned_worker()`` -- confirm.
        num_workers = self._gpu_mgr.num_workers
        my_feeders = {
            feeder_rank
            for feeder_rank in self._gpu_mgr.feeder_ranks
            if (feeder_rank - num_workers) % num_workers == self.rank
        }

        local_queue: deque[Atoms] = deque(
            xtal for xtal in local_structs.values()
            if self.energy_name not in xtal.info
        )

        while local_queue or my_feeders:
            # Answer waiting feeders first so they are not held up behind
            # this rank's own work.
            served = self._drain_feeder_requests(my_feeders)

            if local_queue:
                xtal = local_queue.popleft()
                self.run(xtal)
                if on_structure_done is not None:
                    on_structure_done()
                continue

            # No local work left: block for the next message instead of
            # spinning on iprobe.
            if my_feeders and not served:
                status = MPI.Status()
                data = self.comm.recv(
                    source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status,
                )
                self._handle_worker_msg(
                    data, status.Get_source(), status.Get_tag(), my_feeders,
                )

    def _drain_feeder_requests(self, active_feeders: set[int]) -> bool:
        """
        Non-blocking: serve all pending feeder messages.

        Args:
            active_feeders: feeder ranks still being served; updated in
                place when a shutdown message is handled

        Returns:
            True if at least one message was processed
        """
        served_any = False
        while True:
            status = MPI.Status()
            has_msg = self.comm.iprobe(
                source=MPI.ANY_SOURCE, tag=MPI.ANY_TAG, status=status,
            )
            if not has_msg:
                break
            # Receive exactly the message that was probed.
            source = status.Get_source()
            tag = status.Get_tag()
            data = self.comm.recv(source=source, tag=tag)
            self._handle_worker_msg(data, source, tag, active_feeders)
            served_any = True
        return served_any

    def _handle_worker_msg(
        self, data: Any, source: int, tag: int, active_feeders: set[int],
    ) -> None:
        """
        Process a single message received by a worker.

        A shutdown message removes ``source`` from ``active_feeders``. A
        work message is computed and answered with ``(name, energy)``.
        Messages carrying any other tag are ignored.

        Args:
            data: message payload; ``(name, xtal)`` for work messages
            source: rank that sent the message
            tag: MPI tag of the message
            active_feeders: feeder ranks still being served; mutated in place
        """
        if tag == TAG_SHUTDOWN:
            active_feeders.discard(source)
            return

        if tag == TAG_WORK_DATA:
            name, xtal = data
            self.initialize()
            self.compute(xtal)
            self.finalize()
            # NOTE(review): falls back to 0 when compute() did not store an
            # energy, so the feeder cannot tell a failure from a true zero.
            energy = xtal.info.get(self.energy_name, 0)
            self.comm.send((name, energy), dest=source, tag=TAG_WORK_RESULT)

    def _feeder_loop(
        self,
        local_structs: dict[str, Atoms],
        on_structure_done: Optional[Callable[[], None]],
    ) -> None:
        """
        CPU feeder: delegate GPU computation to assigned worker.

        Structures are sent one at a time; the feeder waits for each result
        before sending the next one, and sends a shutdown message when done.

        Args:
            local_structs: structures owned by this rank
            on_structure_done: optional callback, called after each result
                has been stored
        """
        worker = self._gpu_mgr.assigned_worker()
        for name, xtal in local_structs.items():
            if self.energy_name in xtal.info:
                continue
            self.comm.send((name, xtal), dest=worker, tag=TAG_WORK_DATA)
            _, energy = self.comm.recv(source=worker, tag=TAG_WORK_RESULT)
            xtal.info[self.energy_name] = energy
            if on_structure_done is not None:
                on_structure_done()
        # Always sent, even if nothing needed computing, so the worker can
        # drop this feeder from its active set.
        self.comm.send(None, dest=worker, tag=TAG_SHUTDOWN)

    @abc.abstractmethod
    def initialize(self) -> None:
        """
        Initialize the energy calculations.
        """
        pass

    def get_calculator(self) -> Any:
        """
        Returns the calculator.
        """
        return self.calc

    @abc.abstractmethod
    def compute(self, xtal: Atoms) -> None:
        """
        Compute the energy.

        Args:
            xtal: Crystal structure
        """
        pass

    @abc.abstractmethod
    def finalize(self) -> None:
        """
        Finalize the energy calculations.
        """
        pass