Can't distribute data to all GPUs with DDP

I am trying to train the GMIC model on 2 GPUs with DistributedDataParallel (DDP). The original source code had some problems that prevented using DDP, which I fixed. However, at one convolution in the model's forward pass I get: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cuda:1! (when checking argument for argument weight in method wrapper__cudnn_convolution).

When I printed the device of the input data X and of conv1.weight in each process, I saw:

data cuda:0
weight cuda:0
data cuda:0
weight cuda:1

It looks like the first process works correctly, but the second process does not move its data to GPU 1: its weights are on cuda:1 while its data is still on cuda:0.
To create my DDP setup I followed the Medium article "A Comprehensive Tutorial to PyTorch DistributedDataParallel" by namespace-Pt (CodeX), with world_size set to 2. How do I distribute the data correctly across the 2 GPUs? Or does DDP already do this for me? Do I need to configure anything beyond the setup from that tutorial?

Could you share the code you used to initialize your model with DDP?

In general, the simplest way to ensure the model is on the right device is to move it explicitly to this process's GPU, e.g. with model.to(torch.distributed.get_rank()).

Here is the minimal code of my training setup.

"""import os
import random

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.optim as optim
from import DataLoader
from import DistributedSampler
import torchvision.transforms as transforms

from data.Datasets import CustomDataset
from models.gmic import gmic, Evaluator, Trainer, LossFunctions
from utils.Configv2 import Configv2

def setup_ddp_env(rank, world_size, master_addr='localhost', master_port='12355'):
    """Initialise the NCCL process group for this worker and pin it to its GPU.

    Args:
        rank: Index of this process; also used as its local CUDA device index
            (assumes a single-node run with one process per GPU).
        world_size: Total number of processes in the group.
        master_addr: Address of the rank-0 rendezvous host.
        master_port: Rendezvous port; stored in the environment as a string.
    """
    os.environ['MASTER_ADDR'] = master_addr
    os.environ['MASTER_PORT'] = str(master_port)
    # Make this process's GPU the *current* CUDA device. Without this every
    # process defaults to cuda:0, so any tensor created inside the model with a
    # bare `.cuda()` / `device='cuda'` lands on GPU 0 even on rank 1, producing
    # a cuda:0 vs cuda:1 mismatch against the DDP-replicated weights.
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def create_dataset(cfg):
    """Build the training and validation datasets.

    Args:
        cfg: Config object exposing ``train_path`` and ``val_path``.

    Returns:
        A ``(train, val)`` tuple of ``CustomDataset`` instances.
    """
    # Bind to new names: assigning to ``transforms`` would make it a local
    # variable, shadowing the torchvision module and raising UnboundLocalError.
    # Random augmentation is applied to the training split only.
    transform_train = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5)])
    # An empty Compose returns its input unchanged, so validation data is not
    # augmented while CustomDataset still receives a callable transform.
    transform_val = transforms.Compose([])

    train = CustomDataset(cfg.train_path, transform=transform_train)
    val = CustomDataset(cfg.val_path, transform=transform_val)

    return train, val

def create_dataloader(dataset, rank, world_size, batch_size=16, pin_memory=False, num_workers=0,
                      shuffle=False, seed=0):
    """Wrap ``dataset`` in a DataLoader that yields only this rank's shard.

    ``DistributedSampler`` splits the indices across ``world_size`` replicas so
    each process sees a disjoint subset (padded so every rank gets the same
    number of samples). It does not move batches to a GPU; that is still the
    caller's job.

    Args:
        dataset: Map-style dataset to shard.
        rank: Index of this replica, in ``[0, world_size)``.
        world_size: Number of replicas the dataset is split across.
        batch_size: Per-process batch size.
        pin_memory: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        shuffle: Whether the sampler shuffles indices. When ``True``, call
            ``loader.sampler.set_epoch(epoch)`` each epoch so the order changes.
        seed: Sampler seed; must be identical on all ranks when shuffling so
            the shards stay disjoint.

    Returns:
        A ``DataLoader`` whose ``sampler`` is the ``DistributedSampler``.
    """
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                 shuffle=shuffle, seed=seed, drop_last=False)

    # DataLoader's own ``shuffle`` must stay False when a sampler is supplied;
    # ordering is controlled entirely by the sampler above.
    return DataLoader(dataset, batch_size=batch_size, pin_memory=pin_memory,
                      num_workers=num_workers, drop_last=False, shuffle=False,
                      sampler=sampler)

def create_model(cfg, pretrained, map_location='cpu'):
    """Instantiate GMIC and load pretrained weights into it.

    Args:
        cfg: Config object exposing ``gmic_parameters``.
        pretrained: Path to a saved state dict.
        map_location: Where ``torch.load`` materialises the checkpoint tensors.
            Defaults to CPU. Without it, tensors saved from a GPU are restored
            onto that same GPU in *every* DDP process (e.g. all ranks allocating
            on cuda:0). ``load_state_dict`` copies the values into the model's
            own parameters regardless of where the checkpoint tensors live.

    Returns:
        The GMIC model, on the device it was constructed on.
    """
    model = gmic.GMIC(cfg.gmic_parameters)
    state_dict = torch.load(pretrained, map_location=map_location)
    # strict=False: missing/unexpected keys are tolerated without error.
    model.load_state_dict(state_dict, strict=False)

    return model

def create_training_setup(rank, world_size, cfg, pretrained):
    """Build the data loaders, DDP-wrapped model, optimiser, trainer and evaluator.

    Args:
        rank: Index of this process; used as its CUDA device index.
        world_size: Total number of DDP processes.
        cfg: Config object (dataset paths, model parameters, ``train.epoch``).
        pretrained: Path to the pretrained GMIC state dict.

    Returns:
        ``(train_loader, trainer, val_loader, evaluator)``.
    """
    train, val = create_dataset(cfg)

    train_loader = create_dataloader(train, rank, world_size)
    val_loader = create_dataloader(val, rank, world_size)

    model = create_model(cfg, pretrained)
    # DDP with device_ids=[rank] expects the module to already live on that
    # device, so move it explicitly to this rank's GPU before wrapping.
    model = model.to(rank)
    model = DDP(model, device_ids=[rank], output_device=rank, find_unused_parameters=False)

    criterion = LossFunctions.GMICLoss()
    # NOTE(review): the learning-rate argument was lost from this call (it read
    # ``model.parameters(),, weight_decay=...``); Adam's default lr=1e-3 applies
    # until the intended value is restored.
    optimizer = optim.Adam(model.parameters(), weight_decay=0.001)

    # ``parellel_mode`` (sic) must match the Trainer's own parameter name.
    trainer = Trainer.Trainer(criterion=criterion, model=model, optimizer=optimizer,
                              total_epochs=cfg.train.epoch, train_loader=train_loader,
                              parellel_mode=True)

    evaluator = Evaluator.Evaluator(model=model, data_loader=val_loader)

    return train_loader, trainer, val_loader, evaluator

def main(rank, world_size, cfg, pretrained, output_path, weight_path):
    """Per-process DDP entry point (the target of ``mp.spawn``).

    Args:
        rank: Process index supplied by ``mp.spawn``; used as the GPU index.
        world_size: Total number of processes.
        cfg: Config object; ``cfg.train.epoch`` is the number of epochs.
        pretrained: Path to the pretrained GMIC state dict.
        output_path: Output location (not used in this excerpt).
        weight_path: Weight save location (not used in this excerpt).
    """
    setup_ddp_env(rank, world_size)
    try:
        train_loader, trainer, val_loader, evaluator = create_training_setup(
            rank, world_size, cfg, pretrained)

        # range(1, N) would run only N - 1 epochs although the trainer is told
        # total_epochs=N, so include the last epoch.
        for epoch in range(1, cfg.train.epoch + 1):
            # Reseed the DistributedSampler so shuffled shards change per epoch
            # (no effect on ordering when shuffle=False).
            train_loader.sampler.set_epoch(epoch)

            # NOTE(review): the training call was lost from the pasted code;
            # ``trainer.train(epoch)`` is a placeholder -- confirm against Trainer.
            train_metrics = trainer.train(epoch)
            if epoch % 10 == 0:
                val_metrics = evaluator.evaluate()
    finally:
        # Release the process group even if training raises.
        dist.destroy_process_group()


if __name__ == '__main__':
    cfg = Configv2('config_path')
    pretrained = 'pretrained_model_path'
    # NOTE(review): these two values were lost from the pasted code; the
    # placeholders follow the style of the paths above.
    output_path = 'output_path'
    weight_path = 'weight_path'

    # One process per GPU; mp.spawn passes each process its rank as the first
    # positional argument to main, followed by ``args``.
    world_size = 2
    mp.spawn(main,
             args=(world_size, cfg, pretrained, output_path, weight_path),
             nprocs=world_size,
             join=True)

Using torch.distributed.get_rank() did not help; the same error was thrown.

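I solved the issue. The code broke the data flow at one line in the model: when the patches were created, they were always moved to one fixed GPU. As a result, on the other rank the patches ended up on a different device than the model weights. I changed that line so that the output is created on the same device as its input tensor, instead of a hard-coded GPU, and now it works as expected.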