Source code for gnrs.cluster.selection.center

"""
This module provides the CENTERSelection class for selecting the center of a cluster.

This source code is licensed under the BSD-3-Clause license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations

__author__ = ["Yi Yang", "Rithwik Tom"]
__email__ = "yiy5@andrew.cmu.edu"
__group__ = "https://www.noamarom.com/"

import re
import logging

import gnrs.output as gout
from gnrs.core.selection import SelectionABC

logger = logging.getLogger("CenterSelection")


class CENTERSelection(SelectionABC):
    """
    Selection class that identifies and keeps only the center crystals from clusters.

    This class can select centers based on either:
        1. Crystals marked as centers during clustering
        2. Crystals with minimum value of a specified property within each cluster

    The mode is chosen by the ``filter`` setting: ``'center'`` (the default)
    uses the center flag stored in the cluster label; any other value is
    used as a key into ``xtal.info`` and the crystal with the smallest value
    of that key is kept for each cluster.
    """

    # Compiled once; the cluster index is taken to be the first run of
    # digits found in a crystal's cluster label.
    _CLUSTER_ID_RE = re.compile(r"\d+")

    def initialize(self) -> None:
        """
        Initialize the center selection process.

        Reads ``final_num_cluster`` (falling back to ``n_clusters`` when it
        is ``None``), ``cluster_name`` and the optional ``filter`` setting,
        and announces the start of the selection.
        """
        self.logger = logger
        super().initialize()
        self.settings["property_name"] = None

        if self.settings["final_num_cluster"] is not None:
            self.num_clusters = self.settings["final_num_cluster"]
        else:
            self.num_clusters = self.settings["n_clusters"]
        logger.debug(f"Final cluster {self.num_clusters}")

        self.cluster_name = self.settings["cluster_name"]
        self.filter = self.settings.get("filter", 'center')

        gout.emit(f"Selecting cluster centers using {self.filter}...")
        logger.info(f"Selecting cluster centers using {self.filter}...")

    def finalize(self) -> None:
        """
        Finalize the selection process.
        """
        logger.info(f"Completed selecting cluster centers using {self.filter}.")
        gout.emit(f"Completed selecting cluster centers using {self.filter}.\n")
        gout.emit("")

    def select(self, struct_dict: dict) -> None:
        """
        Select crystals based on configured criteria.

        The dictionary is modified in place: every crystal that is not
        selected is deleted from it.

        Args:
            struct_dict: Crystals dictionary
        """
        self.struct_dict = struct_dict

        if self.filter != 'center':
            # Select based on minimum property value
            self._select_by_property()
        else:
            # Select based on center flag from clustering
            self._select_by_center_flag()

    def _select_by_property(self) -> None:
        """
        Select crystals with minimum property value in each cluster.
        """
        min_xtals = []

        # Every rank must walk the same cluster indices in the same order,
        # because _get_min_across_ranks performs collective communication.
        for idx in range(self.num_clusters):
            local_min = self._get_min_property_xtal(idx)
            winner = self._get_min_across_ranks(local_min)
            if winner is not None:
                min_xtals.append(winner)

        self.comm.Barrier()
        self._filter_xtals(min_xtals)
        self.comm.Barrier()

    def _select_by_center_flag(self) -> None:
        """
        Select crystals marked as centers during clustering.
        """
        # Gather centers from all ranks
        center_xtals = self._gather_center_xtals()

        self.comm.Barrier()
        self._filter_xtals(center_xtals)
        self.comm.Barrier()

    def _cluster_index(self, xtal) -> int | None:
        """
        Extract the cluster index from a crystal's cluster label.

        Args:
            xtal: Crystal object exposing an ``info`` mapping

        Returns:
            The first run of digits in the label as an int, or None when the
            crystal has no label or the label contains no digits.
        """
        label = xtal.info.get(self.cluster_name)
        if label is None:
            return None
        match = self._CLUSTER_ID_RE.search(str(label))
        return int(match.group()) if match else None

    def _get_min_property_xtal(self, idx: int) -> list:
        """
        Find crystal with minimum property value in specified cluster on this rank.

        Crystals without a parseable cluster label are skipped instead of
        raising, matching how the center-flag mode treats a missing label.
        Ties on the property value are broken by the smallest crystal ID,
        the same rule used when combining ranks, so the result does not
        depend on how crystals are distributed over ranks.

        Args:
            idx: Cluster index

        Returns:
            List containing [xtal_id, property_value] or empty list
        """
        best = []
        best_value = float('inf')

        for _id, xtal in self.struct_dict.items():
            if self._cluster_index(xtal) != idx:
                continue

            # A missing property is left to raise KeyError: it means the
            # configured filter name does not match the data.
            value = float(xtal.info[self.filter])
            is_tie_with_smaller_id = bool(best) and value == best_value and _id < best[0]
            if value < best_value or is_tie_with_smaller_id:
                best_value = value
                best = [_id, value]

        return best

    def _gather_center_xtals(self) -> list:
        """
        Gather all crystals marked as centers across all ranks.

        Returns:
            List of crystal names that are centers
        """
        local_centers = [
            _id
            for _id, xtal in self.struct_dict.items()
            if "center" in xtal.info.get(self.cluster_name, "")
        ]

        gathered = self.comm.gather(local_centers, root=0)

        if self.is_master:
            center_xtals = [_id for per_rank in gathered for _id in per_rank]
        else:
            center_xtals = None

        return self.comm.bcast(center_xtals, root=0)

    def _get_min_across_ranks(self, min_list: list) -> str | None:
        """
        Find crystal with minimum property value across all ranks.

        This is a collective operation (gather followed by broadcast), so it
        must be called by every rank.

        Args:
            min_list: [xtal_id, property_value], or an empty list when this
                rank holds no crystal of the cluster

        Returns:
            ID of crystal with minimum property value, or None when no rank
            contributed a candidate
        """
        all_min_lists = self.comm.gather(min_list, root=0)

        if self.is_master:
            candidates = [entry for entry in all_min_lists if entry]
            # Smallest value wins; equal values are resolved by smallest ID.
            if candidates:
                min_all_ranks = min(candidates, key=lambda entry: (entry[1], entry[0]))
            else:
                min_all_ranks = []
        else:
            min_all_ranks = None

        min_all_ranks = self.comm.bcast(min_all_ranks, root=0)

        if min_all_ranks:
            return min_all_ranks[0]
        return None

    def _filter_xtals(self, keep_ids: list) -> None:
        """
        Remove all crystals except those in the keep list.

        Args:
            keep_ids: List of crystal IDs to keep
        """
        # Set lookup keeps this linear in the number of crystals.
        keep = set(keep_ids)
        # Iterate over a snapshot of the keys because entries are deleted.
        for _id in list(self.struct_dict):
            if _id not in keep:
                del self.struct_dict[_id]