Source code for gnrs.descriptor.acsf
"""
This module provides the ACSF descriptor implementation.
This source code is licensed under the BSD-3-Clause license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
__author__ = ["Yi Yang", "Rithwik Tom"]
__email__ = "yiy5@andrew.cmu.edu"
__group__ = "https://www.noamarom.com/"
from mpi4py import MPI
from ase import Atoms
from ase.io import read
from dscribe.descriptors import ACSF
from gnrs.core.descriptor import DescriptorABC
[docs]
class ACSFDescriptor(DescriptorABC):
"""
Computes Atom-Centered Symmetry Function (ACSF) descriptors.
ACSFs can be used to represent the local environment near an atom
by using a fingerprint composed of the output of multiple two- and
three-body functions that can be customized to detect specific structural features.
"""
[docs]
def __init__(self, comm: MPI.Comm, task_settings: dict) -> None:
"""
Initialize the ACSF descriptor calculator.
Args:
comm: MPI communicator for parallel computation
task_settings: Task settings Dictionary for ACSF descriptor
"""
super().__init__(comm, task_settings)
mol_path = self.tsk_set.get("molecule_path", None)
r_cut = self.tsk_set.get("r_cut", None)
g2_params = self.tsk_set.get("g2_params", None)
g3_params = self.tsk_set.get("g3_params", None)
g4_params = self.tsk_set.get("g4_params", None)
self.vector_pooling = self.tsk_set.get("vector_pooling", None)
species_list = []
unique_species = set()
mol_len = 0
for mpth in mol_path:
mol = read(mpth, parallel=False)
symbols = list(mol.symbols)
species_list.extend(symbols)
unique_species.update(symbols)
mol_len += len(symbols)
self.mol_len = mol_len
# Initialize ACSF descriptor
self.acsf = ACSF(
r_cut=r_cut,
g2_params=g2_params,
g3_params=g3_params,
g4_params=g4_params,
species=unique_species,
periodic=True
)
[docs]
def initialize(self) -> None:
"""
Initialization step.
"""
pass
[docs]
def compute(self, xtal: Atoms) -> None:
"""
Compute ACSF descriptors for the given crystal structure.
Args:
xtal: Crystal structure to compute descriptors.
"""
acsf_xtal = self.acsf.create(
xtal,
centers=None,
n_jobs=1,
verbose=False
)
if self.vector_pooling is None:
xtal.info["acsf"] = acsf_xtal[:self.mol_len].reshape(1, -1)
elif self.vector_pooling == "mean":
xtal.info["acsf"] = acsf_xtal.mean(axis=0, keepdims=True)
elif self.vector_pooling == "sum":
xtal.info["acsf"] = acsf_xtal.sum(axis=0, keepdims=True)
elif self.vector_pooling == "max":
xtal.info["acsf"] = acsf_xtal.max(axis=0, keepdims=True)
del acsf_xtal
return
[docs]
def finalize(self) -> None:
"""
Finalization step.
"""
pass