Source code for gnrs.cluster.selection.window
"""
This module provides the WindowSelection class for selecting structures within a energy window.
This source code is licensed under the BSD-3-Clause license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
__author__ = ["Yi Yang", "Rithwik Tom"]
__email__ = "yiy5@andrew.cmu.edu"
__group__ = "https://www.noamarom.com/"
import re
import logging
import gnrs.output as gout
from gnrs.core.selection import SelectionABC
from gnrs.gnrsutil.core import eV2kJ
logger = logging.getLogger("WindowSelection")
[docs]
class WINDOWSelection(SelectionABC):
"""
Selection class that selects structures within a energy window from clusters.
"""
[docs]
def initialize(self) -> None:
"""
Initialize the window selection process.
"""
self.logger = logger
super().initialize()
# Number of structures to select from each cluster
self.n_structs = self.settings.get("n_structs_per_cluster", None)
# Energy range above minimum energy to include structures
self.e_window = self.settings.get("energy_window", None)
if self.settings["final_num_cluster"] is not None:
self.n_clusters = self.settings["final_num_cluster"]
else:
self.n_clusters = self.settings["n_clusters"]
logger.debug(f"Final cluster {self.n_clusters}")
self.clst_name = self.settings["cluster_name"]
self.filter = self.settings["filter"]
self.Z = self.settings["z"]
gout.emit(f"Selecting structures within {self.e_window} energy window from clusters...")
logger.info(f"Selecting structures within {self.e_window} energy window from clusters")
[docs]
def finalize(self) -> None:
"""
Finalize the selection process.
"""
logger.info(f"Completed selecting structures within {self.e_window} energy window from clusters")
gout.emit(f"Completed selecting structures within {self.e_window} energy window from clusters.\n")
gout.emit("")
[docs]
def select(self, struct_dict: dict) -> None:
"""
Select structures within energy window from each cluster.
Args:
struct_dict: Crystals dictionary
"""
self.struct_dict = struct_dict
sel_xtals = []
for idx in range(self.n_clusters):
clst_structs = self.get_energy(idx)
clst_structs = self.get_range_allranks(clst_structs)
sel_xtals.extend(clst_structs)
logger.info(f"Total selected structures: {len(sel_xtals)}")
self.comm.Barrier()
self._filter_xtals(sel_xtals)
self.comm.Barrier()
[docs]
def get_energy(self, idx: int) -> list:
"""
Get energy for structures in the specified cluster.
Args:
idx: Cluster index
Returns:
List of [name, energy] pairs for structures in cluster
"""
structs = []
for _id, xtal in self.struct_dict.items():
xtal_cluster = int(re.search(r'\d+', xtal.info[self.clst_name]).group())
if idx == xtal_cluster:
e = float(xtal.info[self.filter])
structs.append([_id, e])
return structs
[docs]
def get_window_across_ranks(self, structs: list) -> list:
"""
Get structures within energy window across all ranks.
Args:
structs: List of [xtal_id, energy] pairs for structures in a cluster
Returns:
List of xtal ids within energy window
"""
structs = self.comm.gather(structs, root=0)
if self.is_master:
# Flatten list of lists from all ranks
all_structs = [s for rank_structs in structs if rank_structs
for s in rank_structs]
if not all_structs:
logger.debug("No structures found in cluster")
selected = []
else:
# Find minimum energy and calculate relative lattice energies
min_e = min(s[1] for s in all_structs)
rel_energies = sorted([(_id, eV2kJ((e - min_e) / self.Z))
for _id, e in all_structs], key=lambda x: x[1])
# Select structures within energy window
selected = [_id for _id, rel_e in rel_energies
if 0 <= rel_e <= self.e_window]
# Limit number of structures if specified
if self.n_structs is not None:
selected = selected[:min(len(selected), self.n_structs)]
logger.info(f"Selected {len(selected)} structures within lattice energy "
f"window of {self.e_window:.3f} kJ/mol")
else:
selected = None
selected = self.comm.bcast(selected, root=0)
return selected
def _filter_xtals(self, keep_ids: list) -> None:
"""
Remove structures not in selected list.
Args:
keep_ids: List of structure names to keep
"""
for _id in list(self.struct_dict.keys()):
if _id not in keep_ids:
del self.struct_dict[_id]