Source code for mridc.utils.distributed

# encoding: utf-8
__author__ = "Dimitrios Karkalousos"

# Taken and adapted from: https://github.com/NVIDIA/NeMo/blob/main/nemo/utils/distributed.py

import os

import torch

from mridc.utils import logging


def initialize_distributed(args, backend="nccl"):
    """
    Initialize distributed training.

    Parameters
    ----------
    args: The arguments object.
    backend: The backend to use. default: "nccl"

    Returns
    -------
    local_rank: The local rank of the process.
    rank: The rank of the process.
    world_size: The number of processes.
    """
    # Get local rank in case it is provided.
    local_rank = args.local_rank

    # Get rank and world size.
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))

    logging.info(
        f"Initializing torch.distributed with local_rank: {local_rank}, rank: {rank}, world_size: {world_size}"
    )

    # Set the device id.
    device = rank % torch.cuda.device_count()
    if local_rank is not None:
        device = local_rank
    torch.cuda.set_device(device)

    # Call the init process.
    init_method = "tcp://"
    master_ip = os.getenv("MASTER_ADDR", "localhost")
    master_port = os.getenv("MASTER_PORT", "6000")
    init_method += f"{master_ip}:{master_port}"
    torch.distributed.init_process_group(backend=backend, world_size=world_size, rank=rank, init_method=init_method)

    return local_rank, rank, world_size
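

# A minimal usage sketch, not part of the module: the source only requires that
# `args` exposes a `local_rank` attribute, so the argparse setup and the
# `--local_rank` flag name below are assumptions for illustration. The RANK,
# WORLD_SIZE, MASTER_ADDR, and MASTER_PORT environment variables are typically
# set by the launcher (e.g. torchrun).
#
#     import argparse
#
#     from mridc.utils.distributed import initialize_distributed
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument("--local_rank", type=int, default=None)  # assumed flag name
#     args = parser.parse_args()
#
#     local_rank, rank, world_size = initialize_distributed(args, backend="nccl")
#     print(f"rank {rank}/{world_size}, local_rank {local_rank}")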