Source code for gnrs.parallel.structs
"""
This module contains the DistributedStructs class, which is used to handle distributed structure dictionaries.
This source code is licensed under the BSD-3-Clause license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
__author__ = ["Yi Yang", "Rithwik Tom"]
__email__ = "yiy5@andrew.cmu.edu"
__group__ = "https://www.noamarom.com/"
import json
import logging
import os
import numpy as np
from pathlib import Path
from ase.atoms import Atoms
from ase.io.jsonio import decode, encode
import gnrs.parallel as gp
# Module-level logger; each DistributedStructs instance also stores a
# reference to it as ``self.logger`` in ``__init__``.
logger = logging.getLogger("DistributedStructs")
[docs]
class DistributedStructs:
    """
    Contains functions for handling distributed structure dictionaries.

    Every rank holds its own ``structs`` dictionary (name -> ASE Atoms).
    Methods that call into ``gp.comm`` (``get_num_structs``,
    ``find_matches``, ``collect_property``, ``get_statistics``,
    ``checkpoint_load`` and ``redistribute``) are collective: all ranks
    must call them, otherwise the communication will not complete.
    """

    def __init__(self, structs: dict | None):
        """
        Initialize with a dictionary of structures.

        No validation is performed on ``structs``; it is stored as given.

        Args:
            structs: Dictionary mapping structure names to ASE Atoms objects.
                ``None`` is tolerated and treated as "no structures on
                this rank" by the methods of this class.
        """
        self.structs = structs
        self.logger = logger

    def _local_structs(self) -> dict:
        """Return this rank's structures, treating ``None`` as an empty dict."""
        return self.structs if self.structs is not None else {}

    @staticmethod
    def _split_evenly(items: list, num_parts: int) -> list:
        """
        Cut ``items`` into ``num_parts`` contiguous slices whose lengths
        differ by at most one; the first ``len(items) % num_parts`` slices
        receive the extra element.
        """
        base, remainder = divmod(len(items), num_parts)
        slices = []
        start = 0
        for part in range(num_parts):
            end = start + base + (1 if part < remainder else 0)
            slices.append(items[start:end])
            start = end
        return slices

    def get_num_structs(self) -> int:
        """
        Get the total number of structures in a distributed
        structures dictionary.

        Collective: the per-rank counts are gathered on rank 0, summed
        there, and the total is broadcast back to every rank.

        Returns:
            Total number of structures across all ranks
        """
        num_each_rank = len(self._local_structs())
        total_num_list = gp.comm.gather(num_each_rank, root=0)

        total_num = None
        if gp.is_master:
            total_num = sum(total_num_list)
            logger.debug(f"xtal distribution across cores: {total_num_list}")
        return gp.comm.bcast(total_num, root=0)

    def find_matches(self, target: Atoms, settings: dict | None = None) -> list:
        """
        Runs pymatgen duplicate checks on a distributed struct dictionary.

        Collective: each rank compares ``target`` against its own
        structures, then the per-rank results are all-gathered so every
        rank returns the same flattened list.

        Args:
            target: Target structure to be matched
            settings: Keyword arguments for pymatgen StructureMatcher.
                Defaults to ``stol=0.5, ltol=0.5, angle_tol=10``.

        Returns:
            List of matching structure IDs
        """
        # Imported lazily so pymatgen is only required when matching is used.
        from pymatgen.io.ase import AseAtomsAdaptor
        from pymatgen.analysis.structure_matcher import StructureMatcher

        pmg_target = AseAtomsAdaptor.get_structure(target)
        if settings is None:
            settings = {"stol": 0.5, "ltol": 0.5, "angle_tol": 10}
        matcher = StructureMatcher(**settings)

        # A rank without structures contributes an empty list instead of
        # raising, so it still takes part in the allgather below.
        local_matches = [
            name
            for name, xtal in self._local_structs().items()
            if matcher.fit(pmg_target, AseAtomsAdaptor.get_structure(xtal))
        ]

        # Combine match list and flatten
        gathered = gp.comm.allgather(local_matches)
        match_list = [name for rank_matches in gathered for name in rank_matches]
        logger.info(f"Matched with {len(match_list)} structures")
        logger.debug(f"Matched structures = {match_list}")
        return match_list

    def collect_property(self, prpty: str, ptype: str = "info") -> list | None:
        """
        Collects the property of all the structures into a list.

        Collective: per-rank lists are gathered on rank 0 and flattened
        there.

        Args:
            prpty: Property name to collect
            ptype: Property type, either 'info' (read ``xtal.info.get(prpty)``)
                or 'method' (call ``xtal.<prpty>()``)

        Returns:
            List of property values on master rank; on other ranks, whatever
            ``gp.comm.gather`` returns there (``None`` per the original
            docstring)

        Raises:
            ValueError: If ``ptype`` is neither 'info' nor 'method'
        """
        # Validate before touching any structure so that every rank fails the
        # same way; otherwise ranks holding no structures would skip the error
        # and wait in gather() for ranks that have already raised.
        if ptype not in ("info", "method"):
            raise ValueError(f"ptype must be 'info' or 'method', got {ptype!r}")

        structs = self._local_structs()
        if ptype == "method":
            local_values = [getattr(xtal, prpty)() for xtal in structs.values()]
        else:
            local_values = [xtal.info.get(prpty) for xtal in structs.values()]

        # Combine
        gathered = gp.comm.gather(local_values, root=0)
        if gp.is_master:
            return [value for rank_values in gathered for value in rank_values]
        return gathered

    def get_statistics(self, prpty: str, ptype: str = "info") -> dict | None:
        """
        Gets the statistics on a property of interest.

        Collective, because it calls ``collect_property``.

        Args:
            prpty: Property name to analyze
            ptype: Get property from either 'info' or 'method'

        Returns:
            Dictionary with keys "Minimum", "Maximum", "Average" and "Std"
            on master rank, None on other ranks

        Raises:
            ValueError: On the master rank, if no values were collected
        """
        prpty_list = self.collect_property(prpty, ptype)
        if not gp.is_master:
            return None

        prpty_array = np.array(prpty_list)
        if prpty_array.size == 0:
            # NumPy's min/max would raise ValueError here anyway; raise the
            # same exception type with a message that names the property.
            raise ValueError(
                f"Cannot compute statistics for {prpty!r}: no values collected"
            )
        # NOTE(review): with ptype='info', a structure lacking the property
        # contributes None, which the reductions below presumably cannot
        # handle - confirm whether such entries should be filtered upstream.
        return {
            "Minimum": np.min(prpty_array),
            "Maximum": np.max(prpty_array),
            "Average": np.average(prpty_array),
            "Std": np.std(prpty_array),
        }

    def find_spg(self, tol: float = 0.001) -> None:
        """
        Finds the space group of all structures on this rank.
        Space group number is stored in info["spg"]

        Args:
            tol: Tolerance for symmetry finding (passed as ``symprec``)
        """
        from ase.spacegroup.spacegroup import get_spacegroup

        for struct in self._local_structs().values():
            struct.info["spg"] = get_spacegroup(struct, symprec=tol).no

    def checkpoint_save(self, path: str) -> None:
        """
        Checkpoints partially done calculation for restart.
        This routine doesn't communicate to others so that
        ranks can execute independently.

        The rank's structures are written as JSON to ``<path>/<rank>.save``.
        The directory is created if missing, and the file is written to a
        temporary name and then renamed, so an interrupted save never
        leaves a truncated ``.save`` file for ``checkpoint_load`` to read.

        Args:
            path: Directory path to save checkpoint
        """
        structs_str = {
            name: encode(xtal) for name, xtal in self._local_structs().items()
        }

        # An empty path means the current directory, as with os.path.join.
        os.makedirs(path or os.curdir, exist_ok=True)
        save_path = os.path.join(path, f"{gp.rank}.save")
        # The ".tmp" suffix keeps the partial file out of the "*.save" glob
        # used by checkpoint_load.
        tmp_path = f"{save_path}.tmp"
        with open(tmp_path, "w") as chk:
            json.dump(structs_str, chk)
        os.replace(tmp_path, save_path)

    def checkpoint_load(self, path: str) -> None:
        """
        Loads checkpoint. Unlike save, load is a collective
        and blocking operation.

        Rank 0 lists every ``*.save`` file under ``path`` (recursively) and
        broadcasts the list; files are dealt round-robin to the ranks, then
        the loaded structures are rebalanced with ``redistribute``.

        Args:
            path: Directory path to load checkpoint from
        """
        # Get all checkpoint files; sorted so the file-to-rank assignment does
        # not depend on filesystem listing order.
        checkpoints = None
        if gp.is_master:
            checkpoints = sorted(Path(path).rglob("*.save"))

        # Bcast and partition among ranks
        checkpoints = gp.comm.bcast(checkpoints, root=0)
        saved_structs = {}
        for checkpoint in checkpoints[gp.rank::gp.size]:
            with open(checkpoint, "r") as chk:
                saved_str = json.load(chk)
            saved_structs.update(
                {name: decode(xtal) for name, xtal in saved_str.items()}
            )

        self.structs = saved_structs
        self.redistribute()
        # get_num_structs() is collective. It sits inside the f-string so it is
        # evaluated on every rank regardless of log level; do not guard this
        # line with a level check.
        self.logger.debug(f"Read {self.get_num_structs()} from checkpoint")

    def redistribute(self) -> None:
        """
        Redistribute structures such that all ranks have almost
        equal number of structures. Helpful for balancing load
        across cores

        Collective: all dictionaries are gathered on rank 0, merged (a name
        present on several ranks keeps the value from the highest rank),
        split into ``gp.size`` slices and scattered back.
        """
        allstructs = gp.comm.gather(self.structs, root=0)

        scatter_list = None
        # Assemble the list to be scattered
        if gp.is_master:
            # Combine all dictionaries; ranks holding None or {} add nothing.
            combined_structs = {}
            for struct_dict in allstructs:
                if struct_dict:
                    combined_structs.update(struct_dict)

            # Split dict into list of dicts
            items = list(combined_structs.items())
            scatter_list = [
                dict(chunk) for chunk in self._split_evenly(items, gp.size)
            ]
            # After reversing, the slices carrying the extra element land on
            # the highest ranks rather than on rank 0.
            # NOTE(review): presumably intended to keep the master lighter -
            # confirm.
            scatter_list.reverse()

        self.structs = gp.comm.scatter(scatter_list, root=0)