Source code for gnrs.descriptor.descriptor_task
"""
This module provides the DescriptorEvaluationTask class for evaluating descriptors.
This source code is licensed under the BSD-3-Clause license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
# Module metadata: authors, contact e-mail and research-group URL.
__author__ = ["Yi Yang", "Rithwik Tom"]
__email__ = "yiy5@andrew.cmu.edu"
__group__ = "https://www.noamarom.com/"
import os
import logging
import importlib
import numpy as np
from mpi4py import MPI
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import gnrs.parallel as gp
import gnrs.output as gout
from gnrs.core.task import TaskABC
from gnrs.parallel.io import write_parallel
# Descriptor names listed in the error log when a requested descriptor
# cannot be imported or resolved.
AVAILABLE_DESCRIPTORS = ["ACSF"]
# Logger used by everything in this module.
logger = logging.getLogger("DescriptorTask")
[docs]
class DescriptorEvaluationTask(TaskABC):
    """
    Task that evaluates a crystal structure descriptor.

    The descriptor implementation is looked up by name when the task is
    constructed, from a ``gnrs.descriptor.<name>`` module. In addition to
    the task lifecycle hooks, the class provides helpers that fit a PCA
    model or a standard scaler on the master rank and apply the fitted
    model to the structures held by each rank.
    """
[docs]
def __init__(
    self,
    comm: MPI.Comm,
    config: dict,
    gnrs_info: dict,
    descriptor: str
) -> None:
    """Set up the task and resolve the requested descriptor.

    The lower-cased descriptor name is used both as the descriptor name
    and as the task name. The implementation is expected in the module
    ``gnrs.descriptor.<name>`` under the attribute ``<NAME>Descriptor``.

    Args:
        comm: MPI communicator
        config: Config dictionary
        gnrs_info: Genarris info dictionary
        descriptor: Descriptor class name

    Raises:
        ImportError: The descriptor module cannot be imported.
        AttributeError: The module lacks the expected descriptor attribute.
    """
    super().__init__(comm, config, gnrs_info)

    lowered = descriptor.lower()
    self.desc_name = lowered
    self.task_name = lowered
    self.desc_file = "gnrs.descriptor." + lowered
    self.desc_class = descriptor.upper() + "Descriptor"
    self.explain_variance: float | None = None
    self.desc = None

    # Resolve the implementation eagerly so an unknown descriptor fails
    # at construction time; both errors are logged and then re-raised.
    try:
        self.desc = getattr(
            importlib.import_module(self.desc_file), self.desc_class
        )
    except (ImportError, AttributeError) as err:
        for message in (
            f"Unable to find requested descriptor: {str(err)}",
            f"Available descriptors: {AVAILABLE_DESCRIPTORS}",
        ):
            logger.error(message)
        raise
[docs]
def initialize(self) -> None:
    """
    Run the base-class initialization and log the start of the task.

    The base class receives the task name together with a title built
    from the descriptor name.
    """
    super().initialize(
        self.task_name, f"Descriptor Evaluation: {self.desc_name}"
    )
    logger.info(f"Starting descriptor evaluation task: {self.desc_name}")
[docs]
def pack_settings(self) -> dict:
    """
    Build the settings dictionary used for descriptor evaluation.

    The result starts with the molecule path taken from the Genarris info
    dictionary and is then filled with this task's section of the config.
    A ``molecule_path`` entry in that section overrides the default one.

    Returns:
        Task settings dictionary
    """
    return {
        "molecule_path": self.gnrs_info["molecule_path"],
        **self.config[self.task_name],
    }
[docs]
def print_settings(self, task_set: dict) -> None:
    """
    Print the task settings by delegating to the base class.

    Args:
        task_set: Task settings dictionary
    """
    # No descriptor-specific formatting: the base class does the printing.
    super().print_settings(task_set)
[docs]
def create_folders(self) -> None:
    """
    Create the folders for this task by delegating to the base class.
    """
    # No descriptor-specific folders: the base class handles creation.
    super().create_folders()
[docs]
def collect_results(self) -> None:
    """
    Save the structures held by this task.

    The structures in ``self.structs`` are handed to ``write_parallel``
    together with the target path ``self.struct_path``.
    """
    destination, structures = self.struct_path, self.structs
    write_parallel(destination, structures)
[docs]
def analyze(self) -> None:
    """
    Analysis hook of the task; intentionally does nothing here.
    """
    return None
[docs]
def finalize(self) -> None:
    """
    Log completion of the descriptor evaluation and finalize the task.

    Finalization itself is delegated to the base class, which receives
    the task name.
    """
    done_message = f"Completed {self.desc_name} descriptor evaluation"
    logger.info(done_message)
    super().finalize(self.task_name)
def _pca_compression(self) -> None:
    """
    Fit a PCA model on all descriptors and store compressed copies.

    Every rank sends the first row of each local descriptor to rank 0.
    The master rank stacks the rows, fits a whitening PCA seeded with
    ``gp.base_seed`` and logs the explained variance. The fitted model is
    broadcast to all ranks, and each local structure gets its compressed
    descriptor stored under ``"<task_name>_pca"``.

    NOTE(review): the number of components is read from
    ``self.n_components``, which is not assigned in the code visible
    here -- confirm where it is set before calling this method.
    """
    logger.info("Performing PCA compression")
    key = self.task_name
    rows_here = np.array(
        [entry.info[key][0, :] for entry in self.structs.values()]
    )
    gathered = self.comm.gather(rows_here, root=0)

    # Only the master rank fits the model; the others receive it below.
    pca = None
    if self.is_master:
        # Flatten per-rank arrays into one list of rows so that ranks
        # without any structures contribute nothing to the stack.
        rows = []
        for chunk in gathered:
            if chunk is not None:
                rows.extend(chunk)
        features = np.vstack(rows)
        n_samples, n_features = features.shape
        logger.info(f"PCA input: {n_samples} samples with {n_features} features")
        pca = PCA(
            n_components=self.n_components,
            whiten=True,
            random_state=gp.base_seed,
        ).fit(features)
        explained_var = pca.explained_variance_ratio_.sum()
        logger.info(f"PCA compression: {pca.n_components_} components explain {explained_var:.4f} of variance")
    pca = self.comm.bcast(pca, root=0)

    compressed_key = f"{key}_pca"
    for entry in self.structs.values():
        # The whole descriptor is flattened into a single sample here.
        # Presumably descriptors are stored as (1, n_features) -- verify.
        entry.info[compressed_key] = pca.transform(
            entry.info[key].reshape(1, -1)
        )
def _clean_full_descriptor(self) -> None:
    """
    Drop the uncompressed descriptor from every local structure.

    The entry stored under the task name is deleted from each structure's
    ``info`` mapping to save memory.

    Raises:
        KeyError: A structure has no entry under the task name.
    """
    key = self.task_name
    for info in (entry.info for entry in self.structs.values()):
        del info[key]
def _standardize(self) -> None:
    """
    Standardize the descriptor of every local structure in place.

    Every rank sends the first row of each local descriptor, cast to
    float32, to rank 0. The master rank fits a ``StandardScaler`` on the
    stacked rows and broadcasts it. Each rank then replaces the
    descriptor stored under the task name with its scaled version.

    No scratch file is written: the features are gathered straight from
    memory, so nothing is left behind in the working directory and a
    rank that holds no structures simply contributes no rows.
    """
    local_features = np.array(
        [xtal.info[self.task_name][0, :] for xtal in self.structs.values()]
    )
    # Keep the float32 cast so the fitted statistics stay as before. The
    # cast is done in memory because the gather pickles the array anyway;
    # a memmap-backed copy only left a stray file on disk and could fail
    # to be created for an empty array.
    local_features = local_features.astype(np.float32)
    all_features = self.comm.gather(local_features, root=0)

    if self.is_master:
        # Flatten per-rank arrays into rows so empty ranks add nothing.
        features = np.vstack([
            feat for sublist in all_features if sublist is not None
            for feat in sublist
        ])
        scaler = StandardScaler().fit(features)
    else:
        scaler = None
    scaler = self.comm.bcast(scaler, root=0)

    for xtal in self.structs.values():
        feature = xtal.info[self.task_name]
        # The whole descriptor is flattened into a single sample here.
        # Presumably descriptors are stored as (1, n_features) -- verify.
        xtal.info[self.task_name] = scaler.transform(feature.reshape(1, -1))