# Source code for gnrs.parallel.io

"""
This module provides functionality for reading and writing parallel data.

This source code is licensed under the BSD-3-Clause license found in the
LICENSE file in the root directory of this source tree.
"""

from __future__ import annotations
__author__ = ["Yi Yang", "Rithwik Tom"]
__email__ = "yiy5@andrew.cmu.edu"
__group__ = "https://www.noamarom.com/"

import logging
import random

from ase import Atoms
from ase.io.jsonio import encode, decode

import gnrs.parallel as gp

# Module-level logger shared by every function in this file; it is registered
# under the fixed name "parallel_io" rather than the module's __name__.
logger = logging.getLogger("parallel_io")


def read_geometry_out(file_path: str) -> dict:
    """
    Load a multi-structure geometry file on the master and share it out.

    The master reads the whole file, splits it on the
    ``####### END STRUCTURE #######`` marker, drops whatever follows the
    final marker, and partitions the remaining pieces with
    ``_make_scatterable_form``. The partition is handed to
    ``gp.comm.scatter`` and the calling process converts the pieces it
    receives into Atoms objects with ``str2atoms``.

    Args:
        file_path: Path to the geometry output file

    Returns:
        Dictionary mapping random hexadecimal IDs to the Atoms objects
        built by the calling process
    """
    partitioned = None
    if gp.is_master:
        with open(file_path, "r") as handle:
            raw_text = handle.read()
        pieces = raw_text.split("####### END STRUCTURE #######")
        # The text after the last marker is not a structure.
        partitioned = _make_scatterable_form(pieces[:-1])

    local_pieces = gp.comm.scatter(partitioned, root=0)

    # Parse everything first, then hand out IDs, so the random generator is
    # only consulted once all structures were built successfully.
    parsed = []
    for piece in local_pieces:
        if piece is None:
            continue
        parsed.append(str2atoms(piece.split("\n")))

    struct_dict = {}
    for atoms in parsed:
        # 60 random bits rendered as lowercase hex serve as the structure ID
        struct_dict[f"{random.getrandbits(60):x}"] = atoms
    return struct_dict
def str2atoms(geometry_str: list) -> Atoms:
    """
    Build an Atoms object from the lines of an aims-format geometry.

    Three kinds of lines are recognised; everything else is ignored:

    * lines containing ``lattice_vector``: tokens 1-3 become a cell vector
    * lines whose first token is ``atom``: tokens 1-3 are the position and
      token 4 is the species
    * lines containing ``SPGLIB_detected_spacegroup``: the last token is
      read as the integer space group

    Args:
        geometry_str: List of strings, one per line of geometry data

    Returns:
        ASE Atoms object with periodic boundary conditions enabled; the
        space group, when one was found, is stored in ``info["spg"]``
    """
    symbols = []
    lattice = []
    coords = []
    spacegroup = None

    for raw in geometry_str:
        tokens = raw.split()
        if len(tokens) == 0:
            # blank / whitespace-only line
            continue

        if "lattice_vector" in raw:
            lattice.append([float(tok) for tok in tokens[1:4]])
            continue

        if tokens[0] == "atom":
            coords.append([float(tok) for tok in tokens[1:4]])
            symbols.append(tokens[4])
            continue

        if "SPGLIB_detected_spacegroup" in raw:
            spacegroup = int(tokens[-1])

    # Species are concatenated into a single formula-like string for ASE.
    crystal = Atoms(
        symbols="".join(symbols),
        positions=coords,
        cell=lattice,
        pbc=True,
    )
    if spacegroup is not None:
        crystal.info["spg"] = spacegroup
    return crystal
def write_parallel(
    file_path: str, struct_dict: dict, gather: bool = True, mode: str = "w"
) -> None:
    """
    Convert structures to JSON strings, gather and store to file.

    Each structure becomes one line of the form ``"<id>": <json>,``; the
    lines are wrapped in ``{`` / ``}`` and the trailing comma of the final
    entry is removed. Only the master opens the file.

    When ``gather`` is true, every process takes part in the gather, even
    one whose ``struct_dict`` is empty. Returning before the collective on
    such a process would leave the remaining processes waiting in it.

    Args:
        file_path: Path to output file
        struct_dict: Dictionary of structures to write (may be empty or None)
        gather: Whether to gather data from all processes
        mode: File opening mode
    """
    # Convert to list of JSON strings, one line per structure
    str_list = [
        f'"{k}": {encode(v)},\n' for k, v in (struct_dict or {}).items()
    ]

    if gather:
        # Collective call: must be reached by all processes.
        gathered = gp.comm.gather(str_list, root=0)
        if gp.is_master:
            # Flatten the per-process lists into one list of lines
            str_list = [s for sublist in gathered for s in sublist]

    # On the master after a gather this is the global list, so the file is
    # skipped only when no process contributed anything.
    if not str_list:
        logger.info("No structures to write!")
        return

    if not gp.is_master:
        return

    if gather:
        logger.info(f"Writing {len(str_list)} structures to file")

    # Drop the ",\n" that terminates the final entry
    str_list[-1] = str_list[-1][:-2]

    with open(file_path, mode) as wfile:
        wfile.write("{\n")
        wfile.writelines(str_list)
        wfile.write("\n}")
def read_parallel(file_path: str, scatter: bool = True) -> dict:
    """
    Reads JSON database of structures.

    The file is expected to hold one ``"<id>": <json>`` entry per line
    between an opening ``{`` line and a closing ``}`` line, which is the
    layout ``write_parallel`` produces. The master reads the lines and
    partitions them with ``_make_scatterable_form``; the partition is passed
    to ``gp.comm.scatter`` and each process decodes the lines it receives.

    Args:
        file_path: Path to JSON file
        scatter: If true, each process keeps only the structures it decoded.
            If false, the decoded structures are gathered on the master;
            the other processes return an empty dictionary.

    Returns:
        Dictionary mapping IDs to Atoms objects
    """
    logger.info(f"Reading structures from {file_path}")

    if gp.is_master:
        with open(file_path, "r") as rfile:
            str_list = rfile.readlines()
        # Remove {} and add comma to the last element so that every entry
        # ends with two characters to strip (",\n" or "\n,")
        str_list = str_list[1:-1]
        if str_list:
            str_list[-1] = str_list[-1] + ","
        str_list = _make_scatterable_form(str_list)
    else:
        str_list = None
    str_list = gp.comm.scatter(str_list, root=0)

    # Construct struct_list as [id, Atoms] pairs
    struct_list = []
    for str_struct in str_list:
        if str_struct is None:
            continue
        # Split on the first colon only: the JSON payload contains colons too
        s_id, s = str_struct.split(":", 1)
        s_id = s_id.strip('"')
        s = s[:-2]  # Remove comma and newline
        struct_list.append([s_id, decode(s)])

    if not scatter:
        gathered = gp.comm.gather(struct_list, root=0)
        # A rooted gather delivers the combined list to the root only; the
        # other processes receive None, which is treated here as "nothing"
        # instead of being iterated (that raised TypeError).
        struct_list = [item for sublist in (gathered or []) for item in sublist]

    return {s_id: atoms for s_id, atoms in struct_list}
def _make_scatterable_form(str_list: list) -> list:
    """
    Split a list into ``gp.size`` contiguous sublists of near-equal length.

    The first ``len(str_list) % gp.size`` sublists receive one extra
    element. No padding is inserted: when there are fewer items than
    ``gp.size``, the trailing sublists are empty.

    Args:
        str_list: List of strings to distribute

    Returns:
        List of exactly ``gp.size`` sublists that together contain the
        items of ``str_list`` in their original order
    """
    base, extra = divmod(len(str_list), gp.size)

    chunks = []
    start = 0
    for rank in range(gp.size):
        # A running offset avoids re-summing the chunk sizes for every rank
        stop = start + base + (1 if rank < extra else 0)
        chunks.append(str_list[start:stop])
        start = stop
    return chunks