Source code for gnrs.deduplication.dedup
"""
Duplicate structure removal using pymatgen StructureMatcher.
Structures are grouped by space group for computational efficiency,
then within each space group a reference structure is broadcast to all MPI ranks
and compared against the remaining candidates in parallel.
This source code is licensed under the BSD-3-Clause license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
__author__ = ["Yi Yang"]
__email__ = "yiy5@andrew.cmu.edu"
__group__ = "https://www.noamarom.com/"
import logging
import random
from collections import defaultdict
from ase.atoms import Atoms
from pymatgen.analysis.structure_matcher import StructureMatcher
from pymatgen.io.ase import AseAtomsAdaptor
import gnrs.parallel as gp
logger = logging.getLogger("dedup")
[docs]
def group_by_spg(structs: dict[str, Atoms]) -> dict[int, dict[str, Atoms]]:
"""
Group structures by space group.
Args:
structs: {name: Atoms}.
Returns:
{spg: {name: Atoms, ...}}.
"""
groups: dict[int, dict[str, Atoms]] = defaultdict(dict)
for name, xtal in structs.items():
spg = xtal.info.get("spg")
groups[spg][name] = xtal
return groups
def _select(
candidates: dict[str, Atoms],
energy_key: str | None
) -> str:
"""
Select one structure from a set of duplicates.
If energy_key is provided, the lowest-energy structure is chosen.
Otherwise a random one is chosen.
Args:
candidates: {name: Atoms} duplicates.
energy_key: Key in Atoms.info for energy, or None.
Returns:
Name of the chosen structure.
"""
if energy_key is not None:
energies = []
for name, xtal in candidates.items():
e = xtal.info.get(energy_key)
if e is not None:
energies.append((name, float(e)))
if len(energies) == len(candidates):
return min(energies, key=lambda x: x[1])[0]
return random.choice(sorted(candidates.keys()))
def _scatter_structs(pool: dict[str, Atoms]) -> dict[str, Atoms]:
"""
Master scatters a dict of structures evenly across ranks.
"""
scatter_list = None
if gp.is_master:
items = list(pool.items())
n = len(items)
per_rank = n // gp.size
remainder = n % gp.size
scatter_list = []
start = 0
for r in range(gp.size):
chunk = per_rank + (1 if r < remainder else 0)
scatter_list.append(dict(items[start : start + chunk]))
start += chunk
return gp.comm.scatter(scatter_list, root=0)
[docs]
def dedup_group(
pool: dict[str, Atoms],
matcher: StructureMatcher,
spg: int | None,
energy_key: str | None,
) -> dict[str, Atoms]:
"""
Remove duplicates from a space group in parallel.
1. Master picks one candidate from the pool and broadcasts its
pymatgen Structure to all ranks.
2. The remaining structures are scattered across ranks; each rank
tests matcher.fit(candidate, local_struct) in parallel.
3. Match results are gathered. Master collects the duplicate
cluster, selects the best structure, and removes
duplicates from the pool until the pool is empty.
Args:
pool: {name: Atoms} — all structures in this space group
(only meaningful on master; ignored on workers).
matcher: Configured StructureMatcher instance.
spg: Space group.
energy_key: Key in Atoms.info for energy, or None.
Returns:
{name: Atoms} — unique structures in the space group.
"""
kept = {}
while True:
n_rem = len(pool) if gp.is_master else 0
n_rem = gp.comm.bcast(n_rem, root=0)
if n_rem == 0:
break
if gp.is_master:
ref_name = next(iter(pool))
ref_xtal = pool.pop(ref_name)
pmg_ref = AseAtomsAdaptor.get_structure(ref_xtal)
else:
ref_name = None
ref_xtal = None
pmg_ref = None
ref_name = gp.comm.bcast(ref_name, root=0)
ref_xtal = gp.comm.bcast(ref_xtal, root=0)
pmg_ref = gp.comm.bcast(pmg_ref, root=0)
local_chunk = _scatter_structs(pool if gp.is_master else {})
local_matches = []
for name, xtal in local_chunk.items():
pmg_xtal = AseAtomsAdaptor.get_structure(xtal)
if matcher.fit(pmg_ref, pmg_xtal):
local_matches.append(name)
all_matches = gp.comm.gather(local_matches, root=0)
if gp.is_master:
match_names = set()
for sublist in all_matches:
match_names.update(sublist)
cluster = {ref_name: ref_xtal}
for mn in match_names:
cluster[mn] = pool.pop(mn)
best = _select(cluster, energy_key)
kept[best] = cluster[best]
logger.debug(
f"SPG {spg}: remaining pool: {len(pool)}"
)
kept = gp.comm.bcast(kept, root=0)
return kept