# Source code for gnrs.core.gpu
"""
GPU device management for MPI-parallel workloads.
Implements a worker/feeder pattern where only a subset of MPI ranks
(GPU workers) load models onto GPUs, while the remaining ranks (feeders)
send structures to workers via MPI for computation. This avoids GPU OOM
when running with many MPI ranks and few GPUs.
This source code is licensed under the BSD-3-Clause license found in the
LICENSE file in the root directory of this source tree.
"""
from __future__ import annotations
__author__ = ["Yi Yang"]
__email__ = "yiy5@andrew.cmu.edu"
__group__ = "https://www.noamarom.com/"
import logging
from typing import Optional
import torch
from mpi4py import MPI
logger = logging.getLogger("gpu")

# Message tags for worker/feeder communication over MPI, numbered
# consecutively from 100.
(
    TAG_WORK_REQUEST,
    TAG_WORK_DATA,
    TAG_WORK_RESULT,
    TAG_SHUTDOWN,
) = range(100, 104)
[docs]
class GPUDeviceManager:
    """
    Manages GPU device allocation across MPI ranks.

    Partitions ranks into GPU workers and CPU feeders. Workers are assigned
    to GPUs. Feeders offload computation to workers via MPI.

    The lowest ``num_workers`` ranks are the workers; every higher rank is a
    feeder. When no CUDA device is visible, every rank is treated as a worker
    that runs on the CPU, so there are no feeders.

    Typical usage in HPC:
        - 1 GPU node with 1-4 GPUs, 32-128 CPU cores
        - Workers: 1 per GPU (or configurable)
        - Feeders: all remaining ranks
    """

    def __init__(
        self,
        comm: MPI.Comm,
        max_workers_per_gpu: int = 1,
    ) -> None:
        """
        Initialize GPU device manager.

        Args:
            comm: MPI communicator.
            max_workers_per_gpu: Maximum number of worker ranks per GPU.
                Must be at least 1.

        Raises:
            ValueError: If max_workers_per_gpu is smaller than 1.
        """
        # A non-positive limit would leave a GPU run with zero (or a negative
        # number of) workers, which later surfaces as a ZeroDivisionError in
        # assigned_worker() and as inconsistent rank lists. Fail early instead.
        if max_workers_per_gpu < 1:
            raise ValueError(
                f"max_workers_per_gpu must be >= 1, got {max_workers_per_gpu}"
            )

        self.comm = comm
        self.rank = comm.Get_rank()
        self.size = comm.Get_size()
        self.num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
        self.max_workers_per_gpu = max_workers_per_gpu

        if self.num_gpus == 0:
            # CPU-only run: every rank is a worker, nobody is a feeder.
            self._num_workers = self.size
        else:
            # Never create more workers than there are ranks.
            self._num_workers = min(
                self.num_gpus * self.max_workers_per_gpu,
                self.size,
            )

        # Workers are the lowest-numbered ranks.
        self._is_worker = self.rank < self._num_workers
        self._device: Optional[str] = None
        self._assign_device()
        logger.info(
            f"GPU Device Manager: gpus={self.num_gpus} workers={self._num_workers} feeders={self.num_feeders}"
        )

    def _assign_device(self) -> None:
        """
        Assign a CUDA device to worker ranks, CPU to feeders.

        Workers also fall back to the CPU when no GPU was detected. GPU
        workers are spread over the devices by ``rank % num_gpus`` and the
        chosen device is made the current CUDA device of this process.
        """
        if not self._is_worker or self.num_gpus == 0:
            self._device = "cpu"
            return
        gpu_id = self.rank % self.num_gpus
        self._device = f"cuda:{gpu_id}"
        torch.cuda.set_device(gpu_id)

    @property
    def device(self) -> str:
        """
        The torch device string for this rank.
        """
        return self._device

    @property
    def is_worker(self) -> bool:
        """
        Whether this rank is a GPU worker.
        """
        return self._is_worker

    @property
    def is_feeder(self) -> bool:
        """
        Whether this rank is a CPU feeder.
        """
        return not self._is_worker

    @property
    def num_workers(self) -> int:
        """
        Total number of GPU worker ranks.
        """
        return self._num_workers

    @property
    def num_feeders(self) -> int:
        """
        Total number of CPU feeder ranks.
        """
        return self.size - self._num_workers

    @property
    def worker_ranks(self) -> list[int]:
        """
        List of all worker rank IDs.
        """
        return list(range(self._num_workers))

    @property
    def feeder_ranks(self) -> list[int]:
        """
        List of all feeder rank IDs.
        """
        return list(range(self._num_workers, self.size))

    def assigned_worker(self) -> int:
        """
        Return the worker rank this rank sends its work to.

        A worker is its own target. A feeder is mapped round-robin onto the
        workers according to its position among the feeders.

        Returns:
            Worker rank ID
        """
        if self._is_worker:
            return self.rank
        feeder_index = self.rank - self._num_workers
        # Worker ranks are 0..num_workers-1, so the remainder is a rank ID.
        return feeder_index % self._num_workers