Source code for gnrs.core.registry
"""
Task registry for Genarris.
This source code is licensed under the BSD-3-Clause license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
import importlib
# Task-type key -> (module path, class name).
# resolve_task() unpacks each pair into _import_class(), so a module is only
# imported at the moment a task that needs it is resolved.
_TASK_TYPES = {
"generation": ("gnrs.generation", "StructureGenerationTask"),
"energy": ("gnrs.energy", "EnergyCalculationTask"),
"optimize": ("gnrs.optimize", "GeometryOptimizationTask"),
"descriptor": ("gnrs.descriptor", "DescriptorEvaluationTask"),
"cluster": ("gnrs.cluster", "ClusterSelectionTask"),
"dedup": ("gnrs.deduplication", "DuplicateRemovalTask"),
}
# Energy method names. resolve_task() accepts them on their own (resolved to
# the "energy" task type) or as the suffix of "<optimizer>_<energy method>".
_ENERGY_METHODS = {"maceoff", "uma", "aimnet", "aims", "vasp", "dftb"}
# Optimizer names that must be combined with an energy method, e.g. "bfgs_uma".
_OPTIMIZERS = {"bfgs", "lbfgs"}
# Optimizer names that resolve to the "optimize" task type on their own,
# without an energy-method suffix.
_RIGID_PRESS_OPTIMIZERS = {"rigid_press", "symm_rigid_press"}
# Clustering method names; used as the prefix of "<clusterer>_<selector>".
_CLUSTERERS = {"ap", "kmeans"}
# Selection names; used as the suffix of "<clusterer>_<selector>".
_SELECTORS = {"center", "window"}
# Descriptor names; each resolves to the "descriptor" task type.
_DESCRIPTORS = {"acsf"}
[docs]
def resolve_task(task_name: str):
    """
    Map a task name from the config onto a task class and its extra arguments.

    The name is stripped of surrounding whitespace and lower-cased, then
    compared against the known task families in a fixed order; the first
    family that matches decides the result.

    Args:
        task_name: Task name from the config file.

    Returns:
        A ``(task_class, extra_args)`` pair. ``extra_args`` is a tuple holding
        the name fragments recognised while matching: empty for "generation"
        and "dedup", one element for single-word tasks, two elements for
        compound ``<prefix>_<suffix>`` tasks.

    Raises:
        ValueError: If the name does not match any registered task type.
    """
    key = task_name.strip().lower()

    def load(kind):
        # Import the task class registered under ``kind`` only when needed.
        return _import_class(*_TASK_TYPES[kind])

    def split_compound(heads, tails):
        # Return (head, tail) when key reads "<head>_<tail>" and both parts
        # are registered; otherwise None.
        for head in heads:
            lead = head + "_"
            if key.startswith(lead):
                tail = key[len(lead):]
                if tail in tails:
                    return head, tail
        return None

    # Structure generation.
    if key == "generation":
        return load("generation"), ()

    # Rigid press optimizers: rigid_press, symm_rigid_press.
    if key in _RIGID_PRESS_OPTIMIZERS:
        return load("optimize"), (key,)

    # Duplicate removal.
    if key == "dedup":
        return load("dedup"), ()

    # Optimizer + energy method: bfgs_maceoff, lbfgs_uma, ...
    optimizer_and_method = split_compound(_OPTIMIZERS, _ENERGY_METHODS)
    if optimizer_and_method is not None:
        return load("optimize"), optimizer_and_method

    # Single-point energy: maceoff, uma, vasp, ...
    if key in _ENERGY_METHODS:
        return load("energy"), (key,)

    # Descriptor: acsf, ...
    if key in _DESCRIPTORS:
        return load("descriptor"), (key,)

    # Clustering + selection: ap_center, kmeans_window, ...
    clusterer_and_selector = split_compound(_CLUSTERERS, _SELECTORS)
    if clusterer_and_selector is not None:
        return load("cluster"), clusterer_and_selector

    # The message reports the name exactly as the caller supplied it.
    raise ValueError(
        f"Unknown task: {task_name}. Could not resolve to any registered task type."
    )
def _import_class(module_path: str, class_name: str):
    """
    Import a module by its dotted path and fetch one attribute from it.

    Args:
        module_path: Dotted import path of the module to load.
        class_name: Name of the attribute (expected to be a class) to fetch
            from the imported module.

    Returns:
        The object bound to ``class_name`` in the imported module.

    Raises:
        ImportError: Propagated from ``importlib.import_module`` when the
            module cannot be imported.
        AttributeError: Propagated from ``getattr`` when the module has no
            attribute named ``class_name``.
    """
    return getattr(importlib.import_module(module_path), class_name)