DDP leads to Out of Memory error

I’m hitting an out-of-memory error while running a distributed PyTorch script for model evaluation and metrics reporting under the PyCharm debugger. The script uses DistributedDataParallel (DDP) for multi-GPU evaluation, and I suspect the problem lies in the distributed initialization, the rank assignments, or the dataset loading.

The command I use to run the script is:

/home/user/envs/my_project_env/bin/python -m torch.distributed.run \
    --nproc_per_node=8 \
    --master_port=12346 \
    /specific/a/home/users/username/.pycharm_helpers/pydev/pydevd.py \
    --multiprocess \
    --qt-support=auto \
    --client localhost \
    --port 34779 \
    --file /home/user/projects/my_project/eval/evaluate.py \
    --model_path /home/user/models/checkpoints/2024_11_21_model \
    --data_path /home/user/data/dataset \
    --batch_size 4
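
As far as I understand, torch.distributed.run spawns --nproc_per_node worker processes and hands each one its rank through environment variables (LOCAL_RANK, RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT), which is what my script reads. A quick sanity check I could put at the top of the script to see what each worker actually gets (the print is just illustrative):

import os

# these are expected to be set by torch.distributed.run / torchrun for every worker
local_rank = int(os.environ.get("LOCAL_RANK", -1))
rank = int(os.environ.get("RANK", -1))
world_size = int(os.environ.get("WORLD_SIZE", -1))
print(f"LOCAL_RANK={local_rank} RANK={rank} WORLD_SIZE={world_size} "
      f"MASTER_PORT={os.environ.get('MASTER_PORT')}")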

Here is a broad overview of my code:

import argparse
import torch
import os
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from expir.utils import disable_torch_init
from expir.model.builder import load_custom_model
from expir.utils.collators import DataCollatorForSupervisedDataset
from expir.dataset import CustomDataset
from expir.eval import evaluate_dataset

def evaluate_model(args):
    # torch.distributed.run sets LOCAL_RANK in the environment of every worker it spawns
    if 'LOCAL_RANK' in os.environ:
        args.local_rank = int(os.environ['LOCAL_RANK'])
    else:
        args.local_rank = -1

    if args.local_rank != -1:
        # pin this process to the GPU that matches its local rank
        torch.cuda.set_device(args.local_rank)
        device = torch.device(f'cuda:{args.local_rank}')
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if args.local_rank != -1:
        # NCCL backend for GPU-to-GPU communication; rank and world size come from the env vars
        dist.init_process_group(backend='nccl')
        args.distributed = True
    else:
        args.distributed = False

    model_path = os.path.expanduser(args.model_path)
    model = load_custom_model(model_path).to(device)
    if args.distributed:
        # each process wraps its own replica, pinned to its local GPU
        model = DDP(model, device_ids=[args.local_rank], output_device=args.local_rank)

    collator = DataCollatorForSupervisedDataset()
    dataset = CustomDataset(args.data_path)
    evaluate_dataset(model, dataset, collator, device, args)

    if args.distributed:
        dist.destroy_process_group()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, required=True, help="Path to the pre-trained model.")
    parser.add_argument("--data_path", type=str, required=True, help="Path to the dataset directory.")
    parser.add_argument("--batch_size", type=int, default=4, help="Batch size for evaluation.")
    parser.add_argument("--local_rank", type=int, default=-1, help="Local rank for distributed training.")
    args = parser.parse_args()

    evaluate_model(args)
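
Inside evaluate_dataset the dataset eventually goes into a DataLoader; I haven't shown that part, but for reference a standard per-rank setup with DistributedSampler looks roughly like this (a generic sketch, not my exact code):

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# generic sketch: with a DistributedSampler each rank only iterates over its own shard
sampler = DistributedSampler(dataset, shuffle=False) if args.distributed else None
loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    sampler=sampler,
    collate_fn=collator,
    num_workers=2,
    pin_memory=True,
)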

The strange behaviour I see is that even before calling DDP, the model already shows up on all GPUs. I'm not able to make sense of this behaviour.
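
For reference, this is the kind of per-rank memory check I could add right after model = load_custom_model(model_path).to(device) and again after the DDP wrap, to see which process is allocating on which GPU (the helper name and print format are just placeholders):

import torch

def log_gpu_memory(tag, local_rank):
    # report how much memory this process has allocated / reserved on every visible GPU
    for i in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(i) / 1024 ** 2
        reserved = torch.cuda.memory_reserved(i) / 1024 ** 2
        print(f"[rank {local_rank}] {tag}: cuda:{i} "
              f"allocated={allocated:.0f} MiB, reserved={reserved:.0f} MiB")

# e.g. log_gpu_memory("after model load", args.local_rank)
#      log_gpu_memory("after DDP wrap", args.local_rank)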