Is this the correct way to use DistributedDataParallel on a single node with multiple GPUs?

I use the torch.nn.parallel.DistributedDataParallel API in PyTorch 1.1 to replicate my multi-card model (2 GPUs per replica) across 8 GPUs. According to the official tutorial, Getting Started with Distributed Data Parallel, DistributedDataParallel is the recommended way to parallelize a model. I am not confident about my implementation and I can’t find other useful tutorials, so I have come here for help.

My approach is as follows:

  1. Split my 3D CNN model across 2 GPUs (called dev_in and dev_out).
  2. Use DistributedDataParallel() to replicate the 2-GPU model in 4 processes. Each replica uses the same random seed to initialize its weights, and no process shares GPUs with another process.
  3. Wrap my dataset with the Dataset() and DataLoader() APIs and manually split each batch equally along the batch dimension, so each process (with 2 GPUs) handles different data with the SAME weights.
  4. After the forward pass in each process, collect the loss values from all processes and average them, then use this averaged loss to compute gradients and update all 4 models in the 4 processes.
  5. After each epoch of training and validation, calculate ACC and AUC scores for the training and validation datasets respectively.
  6. After each training epoch, validate the model on the validation dataset. Currently I use the single-process model (before it is wrapped by the DistributedDataParallel() API) for validation, because something went wrong that I couldn’t resolve when I used the model after wrapping it in DistributedDataParallel().
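
For step 3, one common alternative to slicing each batch by hand is torch.utils.data.distributed.DistributedSampler, which gives every rank its own disjoint share of the dataset. The sketch below is only an illustration under assumptions: make_loader and toy_dataset are hypothetical names, the toy tensors stand in for the real dataset, and the ranks are simulated in a loop in one process rather than spawned.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def make_loader(dataset, rank, world_size, batch_size_per_process, epoch=0):
    """Build a DataLoader that yields only this rank's share of the dataset.

    Hypothetical helper for illustration; not part of the code in this post.
    """
    # Passing num_replicas and rank explicitly means no process group is
    # needed just to construct the sampler.
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    # Call set_epoch at the start of every epoch so the shuffle order changes
    # between epochs while staying consistent across ranks.
    sampler.set_epoch(epoch)
    return DataLoader(dataset, batch_size=batch_size_per_process, sampler=sampler)


# Toy stand-in for the real dataset: 16 samples whose "image" is just its index.
toy_dataset = TensorDataset(torch.arange(16).float().unsqueeze(1), torch.arange(16) % 2)

world_size = 4
seen_per_rank = []
for rank in range(world_size):  # simulate the 4 processes one after another
    loader = make_loader(toy_dataset, rank, world_size, batch_size_per_process=2)
    indices = torch.cat([img.flatten() for img, _ in loader]).long().tolist()
    seen_per_rank.append(indices)

for rank, indices in enumerate(seen_per_rank):
    print("rank", rank, "sees samples", sorted(indices))
```

With this approach each process loads only the samples it will actually use, instead of loading a world_size-times-larger batch and discarding most of it.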

Here is the relevant code:

# sample/train.py

import tempfile
import torch.distributed as dist
import torch.nn as nn
from torch import optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import Backend
from time import time
from sample.networks.XXXXNet import XXXXNet
from sample.data import XXXXDataSet
from torch.utils.data import DataLoader
import yaml
import torch
import os
from yaml import CLoader as Loader
from sklearn.metrics import accuracy_score, auc, roc_curve
import numpy as np
from torch.utils.tensorboard import SummaryWriter

cfg = yaml.load(
    open(os.path.join(os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../config/config.yml")))),
    Loader=Loader
)["DATASET"][0]

np.random.seed(cfg["SEED"])
torch.random.manual_seed(cfg["SEED"])
# expanduser resolves "~" to the home directory; abspath alone would not.
tempfile.tempdir = os.path.expanduser("~/tmp")
NAME = "%dGPUs_1e-6" % cfg["WORLD_SIZE"]
writer = SummaryWriter(log_dir=os.path.join(os.path.dirname(__file__), "logs/tb_logs/%s" % NAME))


def setup_env(rank, world_size):
    """
    Initialize the distributed environment.

    :param rank: Rank of the current process.
    :param world_size: Number of processes participating in the job.
    :return:
    """
    assert isinstance(world_size, int) and world_size > 0
    assert isinstance(rank, int) and 0 <= rank < world_size

    os.environ['MASTER_ADDR'] = cfg["MASTER_ADDR"]
    os.environ['MASTER_PORT'] = cfg["MASTER_PORT"]

    # Initialize the process group
    dist.init_process_group(Backend.NCCL, rank=rank, world_size=world_size)

    # Explicitly setting seed to make sure that models created in two processes
    # start from same random weights and biases.
    torch.manual_seed(cfg["SEED"])


def cleanup_env():
    """
    Destroy the default process group.

    :return:
    """
    dist.destroy_process_group()


def train_model(rank, world_size, offset=1, ):
    """
    Training model.

    :param rank: Rank of the current process.
    :param world_size: The number of processes in the current process group.
    :param offset: The index of first GPU to use.
    :return:
    """
    assert isinstance(world_size, int) and world_size > 0
    assert isinstance(rank, int) and 0 <= rank < world_size
    assert isinstance(offset, int) and offset >= 0

    setup_env(rank, world_size)

    # Setup mp_model and devices for this process
    dev_in = rank * 2 + offset
    dev_out = rank * 2 + 1 + offset
    mp_model = XXXXNet(dev_in=dev_in, dev_out=dev_out)
    ddp_mp_model = DistributedDataParallel(mp_model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(
        ddp_mp_model.parameters(),
        lr=float(cfg["LEARNING_RATE"]) * world_size,
        weight_decay=float(cfg["L2"]),
    )
    old_lr = float(cfg["LEARNING_RATE"]) * world_size
    batch_size = world_size * cfg["BATCH_SIZE_PER_CARD"]

    # Training dataset
    dataset_train = XXXXDataSet(val=False, shape=cfg["CUBE_SIZE"][1:])
    data_loader_train = DataLoader(dataset_train, batch_size=batch_size, num_workers=cfg["NUM_WORKERS"])
    # Validation dataset
    dataset_val = XXXXDataSet(val=True, shape=cfg["CUBE_SIZE"][1:])
    data_loader_val = DataLoader(dataset_val, batch_size=cfg["BATCH_SIZE_PER_CARD"], num_workers=cfg["NUM_WORKERS"])

    with open(os.path.join(os.path.dirname(__file__), "logs/" + NAME + ".log"), "w") as log_file:
        def _print(string, file=log_file, target_rank=0):
            """
            Print to cmd and log file simultaneously.

            :param string: Content need to print.
            :param file: Log file object.
            :return:
            """
            if target_rank == -1:
                print(string, file=file)
                print(string)
                file.flush()
            elif target_rank == rank:
                print(string, file=file)
                print(string)
                file.flush()

        _print(str(cfg))
        no_optim = 0
        total_epoch = cfg["EPOCHS"]
        epoch_best_loss_train = 100.
        epoch_best_loss_val = 100.

        for epoch in range(1, total_epoch + 1):
            # ==TRAINING====TRAINING====TRAINING====TRAINING====TRAINING====TRAINING==
            tic_train = time()
            # =====ACC&AUC start=====
            prods_train, gts_train = [], []
            # ======ACC&AUC end======
            data_loader_iter_train = iter(data_loader_train)
            train_epoch_loss = 0
            for img, label in data_loader_iter_train:
                inp = img[rank * cfg["BATCH_SIZE_PER_CARD"]: (rank + 1) * cfg["BATCH_SIZE_PER_CARD"]]
                label = label[rank * cfg["BATCH_SIZE_PER_CARD"]: (rank + 1) * cfg["BATCH_SIZE_PER_CARD"]].to(dev_out)
                # Calculate loss
                if inp.size()[0] < 2:
                    _print("inp is None!!!!!!!!!!!!", target_rank=-1)
                    train_loss = torch.tensor(0.)
                else:
                    optimizer.zero_grad()
                    pred = ddp_mp_model(inp)
                    train_loss = loss_fn(pred, label)
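                # Build distinct buffers: "[t] * n" would repeat one tensor object n times.
                train_loss_lst = [torch.zeros_like(train_loss) for _ in range(world_size)]
                prods_train_lst = [torch.zeros_like(pred) for _ in range(world_size)]
                label_train_lst = [torch.zeros_like(label) for _ in range(world_size)]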

                dist.all_gather(prods_train_lst, pred)  # Sync between all processes
                dist.all_gather(label_train_lst, label)  # Sync between all processes
                dist.all_gather(train_loss_lst, train_loss)  # Sync between all processes
                dist.all_reduce(train_loss, op=dist.ReduceOp.SUM)  # Sync between all processes
                train_loss /= torch.tensor(train_loss_lst).nonzero().size(0)
                # Backward propagate and update weights
                train_loss.backward()
                optimizer.step()
                train_epoch_loss += train_loss.item()

                # =====ACC&AUC start=====
                prods_train.append(torch.cat(prods_train_lst, dim=0).cpu().detach().numpy())
                gts_train.append(torch.cat(label_train_lst, dim=0).cpu().numpy())

            prods_train = np.concatenate(tuple(prods_train))
            gts_train = np.concatenate(tuple(gts_train))
            prods_train = prods_train[:, 1]
            prods_01 = np.where(prods_train > 0.5, 1, 0)  # Turn probability to 0-1 binary output
            acc_NN = accuracy_score(gts_train, prods_01)
            false_positive_rate, recall, thresholds = roc_curve(gts_train, prods_train, pos_label=1)
            roc_auc = auc(false_positive_rate, recall)
            # ======ACC&AUC end======
            train_epoch_loss /= len(data_loader_iter_train)
            _print("******************************")
            _print("epoch[%03d/%03d], time: %02dm:%02ds" %
                   (epoch, cfg["EPOCHS"], int(time() - tic_train) // 60, int(time() - tic_train) % 60))
            _print("train loss = %6.4f" % train_epoch_loss)
            _print("CUBE_SIZE: %s" % str(cfg["CUBE_SIZE"]))
            _print("ACC = %6.4f, AUC = %6.4f" % (acc_NN, roc_auc))

            # ==Validation====Validation====Validation====Validation====Validation====Validation==
            _print("------------------------------")
            mp_model.eval()
            tic_val = time()
            # =====code for ACC&AUC start=====
            prods_val = []
            gts_val = []
            # ======code for ACC&AUC end======
            data_loader_iter_val = iter(data_loader_val)
            val_epoch_loss = 0
            with torch.no_grad():
                for val_img, val_label in data_loader_iter_val:
                    val_label = val_label.to(dev_out)
                    # Calculate predicts and loss
                    val_pred = ddp_mp_model(val_img)
                    val_loss = loss_fn(val_pred, val_label)
                    val_epoch_loss += val_loss.item()
                    # =====code for ACC&AUC start=====
                    val_pred = val_pred.cpu().detach().numpy()
                    val_label = val_label.cpu().numpy()
                    prods_val.append(val_pred)
                    gts_val.append(val_label)
            prods_val = np.concatenate(tuple(prods_val))
            gts_val = np.concatenate(tuple(gts_val))

            prods_val = prods_val[:, 1]
            prods_01_val = np.where(prods_val > 0.5, 1, 0)  # Turn probability to 0-1 binary output
            acc_NN_val = accuracy_score(gts_val, prods_01_val)
            false_positive_rate_val, recall_val, thresholds_val = roc_curve(gts_val, prods_val, pos_label=1)
            roc_auc_val = auc(false_positive_rate_val, recall_val)
            # ======code for ACC&AUC end======
            val_epoch_loss /= len(data_loader_iter_val)
            _print("validation time: %02dm:%02ds" % (int(time() - tic_val) // 60, int(time() - tic_val) % 60))
            _print("validation loss = %6.4f" % val_epoch_loss)
            _print("validation ACC = %6.4f, validation AUC = %6.4f" % (acc_NN_val, roc_auc_val))
            if rank == 0:
                writer.add_scalars(main_tag="lr", tag_scalar_dict={"train": old_lr}, global_step=epoch)
                writer.add_scalars(main_tag="time",
                                   tag_scalar_dict={"train": time() - tic_train,
                                                    "val": time() - tic_val}, global_step=epoch)
                writer.add_scalars(main_tag="loss",
                                   tag_scalar_dict={"train": train_epoch_loss,
                                                    "val": val_epoch_loss}, global_step=epoch)
                writer.add_scalars(main_tag="ACC",
                                   tag_scalar_dict={"train": acc_NN,
                                                    "val": acc_NN_val}, global_step=epoch)
                writer.add_scalars(main_tag="AUC",
                                   tag_scalar_dict={"train": roc_auc,
                                                    "val": roc_auc_val}, global_step=epoch)
            mp_model.train()
            # ==Validation End====Validation End====Validation End====Validation End====Validation End==

            if train_epoch_loss >= epoch_best_loss_train:
                no_optim += 1
            else:
                no_optim = 0
                epoch_best_loss_train = train_epoch_loss
                torch.save(ddp_mp_model.state_dict(),
                           os.path.join(os.path.dirname(__file__), "weights/" + NAME + ".th"))
            if no_optim > 6:
                _print("early stop at [%03d] epoch" % epoch)
                break
            if no_optim > 3:
                if old_lr < 5e-7:
                    break
                ddp_mp_model.load_state_dict(torch.load(
                    os.path.join(os.path.dirname(__file__), "weights/" + NAME + ".th")))
                new_lr = old_lr / 5.0
                for param_group in optimizer.param_groups:
                    param_group['lr'] = new_lr
                _print("update learning rate: %f -> %f" % (old_lr, new_lr))
                old_lr = new_lr
            _print("******************************")
        _print("Finish!")

    cleanup_env()


def ddp_train(demo_fn, world_size):
    """

    :param demo_fn: Function.
    :param world_size: The number of processes in the current process group.
    :return:
    """
    mp.spawn(demo_fn,
             args=(world_size,),
             nprocs=world_size,
             join=True)


if __name__ == '__main__':
    ddp_train(
        train_model,
        world_size=cfg["WORLD_SIZE"],
    )
    writer.close()

And here is the other .py file:

# sample/networks/XXXXNet.py

import yaml
from yaml import CLoader as Loader
import torch.nn.functional as F
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import Backend

cfg = yaml.load(
    open(os.path.join(os.path.abspath(
        os.path.join(os.path.dirname(__file__), "../../config/config.yml")))),
    Loader=Loader
)["DATASET"][0]

non_linearity = nn.LeakyReLU


class FireModule3D(nn.Module):
    """
    FireModule3D module

    (Tested 5.10)
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 dilation=1, bias=False,
                 squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                 use_bn=True, momentum=0.1, use_dp=False, use_bypass=True):
        """
        Init function

        :param in_channels: Number of input channels.
        :param out_channels: Number of output channels.
        :param kernel_size: Kernel size.
        :param dilation: Dilation rate of dilated convolution.
        :param bias: Whether to use bias.
        :param squeeze_ratio: Squeeze ratio of Fire Module.
        :param pct_3x3: Percent of 3x3 convolution in expand layer.
        :param activation: Activation function.
        :param use_bn: Whether to use batch normalization.
        :param momentum: The value used for the running_mean and running_var computation.
        :param use_dp: Whether to use dropout.
        :param use_bypass: Whether to use bypass connection.
        """
        super(FireModule3D, self).__init__()
        self.use_bn = use_bn
        self.use_dp = use_dp
        self.use_bypass = use_bypass

        e_i = out_channels
        s_1x1 = int(squeeze_ratio * e_i)  # number of channels in squeeze 1x1 layer
        e_3x3 = int(pct_3x3 * e_i)  # number of channels in expand 3x3 layer
        e_1x1 = e_i - e_3x3

        self.activation = activation(inplace=True)
        self.squeeze1x1 = nn.Conv3d(in_channels=in_channels, out_channels=s_1x1,
                                    kernel_size=1, dilation=1, groups=1, bias=bias)
        self.expand1x1 = nn.Conv3d(in_channels=s_1x1, out_channels=e_1x1,
                                   kernel_size=1, dilation=1, groups=1, bias=bias)
        self.expand3x3 = nn.Conv3d(in_channels=s_1x1, out_channels=e_3x3, kernel_size=kernel_size,
                                   padding=1, dilation=dilation, bias=bias)

        # Bypass connection
        if self.use_bypass:
            if in_channels != out_channels:
                self.bypass = nn.Conv3d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=1, bias=bias)
            else:
                self.bypass = None

        if self.use_bn:
            self.bn_s1x1 = nn.BatchNorm3d(num_features=s_1x1, momentum=momentum)
            self.bn_e1x1 = nn.BatchNorm3d(num_features=e_1x1, momentum=momentum)
            self.bn_e3x3 = nn.BatchNorm3d(num_features=e_3x3, momentum=momentum)
        if self.use_dp:
            self.dp = nn.Dropout2d(0.5)

    def forward(self, x):
        """
        Forward computation function.

        :param x: Input tensor.
        :return: Result tensor.
        """
        # Squeeze 1x1 layer
        squeeze = self.squeeze1x1(x)
        if self.use_bn:
            squeeze = self.bn_s1x1(squeeze)
        squeeze = self.activation(squeeze)

        # Expand 1x1 layer
        expand1x1 = self.expand1x1(squeeze)
        if self.use_dp:
            expand1x1 = self.dp(expand1x1)
        if self.use_bn:
            expand1x1 = self.bn_e1x1(expand1x1)

        # Expand 3x3 layer
        expand3x3 = self.expand3x3(squeeze)
        if self.use_dp:
            expand3x3 = self.dp(expand3x3)
        if self.use_bn:
            expand3x3 = self.bn_e3x3(expand3x3)

        merge = self.activation(torch.cat([expand1x1, expand3x3], dim=1))

        if self.use_bypass:  # Bypass connection
            if self.bypass is not None:
                x = self.bypass(x)
            merge = merge + x
        return merge


class XXXXNet(nn.Module):
    def __init__(self, nb_class=2, dev_in=None, dev_out=None):
        super(XXXXNet, self).__init__()
        self.device1 = dev_in
        self.device2 = dev_out

        self.conv0 = nn.Sequential(nn.Conv3d(1, 8, kernel_size=7, stride=2, padding=3, bias=False),
                                   non_linearity(inplace=True)).to(self.device1)
        self.conv1 = FireModule3D(in_channels=8, out_channels=8, kernel_size=3,
                                  dilation=1, bias=False,
                                  squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                  use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device1)
        self.conv2 = FireModule3D(in_channels=8, out_channels=8, kernel_size=3,
                                  dilation=1, bias=False,
                                  squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                  use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device1)
        self.conv3 = FireModule3D(in_channels=8, out_channels=8, kernel_size=3,
                                  dilation=1, bias=False,
                                  squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                  use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device1)
        self.conv4 = FireModule3D(in_channels=8, out_channels=16, kernel_size=3,
                                  dilation=1, bias=False,
                                  squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                  use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.mp1 = nn.MaxPool3d(kernel_size=2, stride=2).to(self.device2)

        self.conv5 = FireModule3D(in_channels=16, out_channels=16, kernel_size=3,
                                  dilation=1, bias=False,
                                  squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                  use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.conv6 = FireModule3D(in_channels=16, out_channels=16, kernel_size=3,
                                  dilation=1, bias=False,
                                  squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                  use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.conv7 = FireModule3D(in_channels=16, out_channels=16, kernel_size=3,
                                  dilation=1, bias=False,
                                  squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                  use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.conv8 = FireModule3D(in_channels=16, out_channels=32, kernel_size=3,
                                  dilation=1, bias=False,
                                  squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                  use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.mp2 = nn.MaxPool3d(kernel_size=2, stride=2).to(self.device2)

        self.conv9 = FireModule3D(in_channels=32, out_channels=32, kernel_size=3,
                                  dilation=1, bias=False,
                                  squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                  use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.conv10 = FireModule3D(in_channels=32, out_channels=32, kernel_size=3,
                                   dilation=1, bias=False,
                                   squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                   use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.conv11 = FireModule3D(in_channels=32, out_channels=32, kernel_size=3,
                                   dilation=1, bias=False,
                                   squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                   use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.conv12 = FireModule3D(in_channels=32, out_channels=64, kernel_size=3,
                                   dilation=1, bias=False,
                                   squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                   use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.mp3 = nn.MaxPool3d(kernel_size=2, stride=2).to(self.device2)

        self.conv13 = FireModule3D(in_channels=64, out_channels=64, kernel_size=3,
                                   dilation=1, bias=False,
                                   squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                   use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.conv14 = FireModule3D(in_channels=64, out_channels=64, kernel_size=3,
                                   dilation=1, bias=False,
                                   squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                   use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.conv15 = FireModule3D(in_channels=64, out_channels=64, kernel_size=3,
                                   dilation=1, bias=False,
                                   squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                   use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.conv16 = FireModule3D(in_channels=64, out_channels=128, kernel_size=3,
                                   dilation=1, bias=False,
                                   squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                   use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.mp4 = nn.MaxPool3d(kernel_size=2, stride=2).to(self.device2)

        self.conv17 = FireModule3D(in_channels=128, out_channels=128, kernel_size=3,
                                   dilation=1, bias=False,
                                   squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                   use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.conv18 = FireModule3D(in_channels=128, out_channels=128, kernel_size=3,
                                   dilation=1, bias=False,
                                   squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                   use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.conv19 = FireModule3D(in_channels=128, out_channels=128, kernel_size=3,
                                   dilation=1, bias=False,
                                   squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                   use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.conv20 = FireModule3D(in_channels=128, out_channels=256, kernel_size=3,
                                   dilation=1, bias=False,
                                   squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                                   use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(self.device2)
        self.mp5 = nn.MaxPool3d(kernel_size=2, stride=2).to(self.device2)

        self.fc1 = nn.Sequential(
            nn.Linear(256 * 7 * 3 * 5, 256, bias=False),
            non_linearity(inplace=True),
            nn.Dropout2d(p=0.5),
        ).to(self.device2)
        self.fc3 = nn.Linear(256, nb_class, bias=False).to(self.device2)

    def forward(self, x):
        x = x.to(self.device1)
        x = self.conv0(x)

        x = self.conv3(self.conv2(self.conv1(x)))
        x = x.to(self.device2)
        x = self.mp1(self.conv4(x))

        x = self.mp2(self.conv8(self.conv7(self.conv6(self.conv5(x)))))
        x = self.mp3(self.conv12(self.conv11(self.conv10(self.conv9(x)))))
        x = self.mp4(self.conv16(self.conv15(self.conv14(self.conv13(x)))))
        x = self.mp5(self.conv20(self.conv19(self.conv18(self.conv17(x)))))

        # flatten
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        x = F.softmax(self.fc3(x), dim=1)  # .squeeze().contiguous()
        return x

And my config/config.yml file is:

DATASET:
  - NAME: "XXXX"
    ROOT: "/home/aaa/organized_data"  # "/Users/aaa/fsdownload" #
    IMAGE_FOLDER: "raw_scans"
    NPY_FOLDER: "pre_result"
    NPY_FOLDER2: "my_npy" #
    TRAIN_CSV: "train.csv"
    VAL_CSV: "val.csv"
    SEED: 1
    BATCH_SIZE_PER_CARD: 2
    NUM_WORKERS: 4
    MOMENTUM: 0.01
    LEARNING_RATE: 1e-6
    NUM_CLASSES: 2
    WORLD_SIZE: 4
    CUBE_SIZE: [1, 450, 220, 325] #(C, D, H, W)
    EPOCHS: 100
    MASTER_ADDR: "localhost"
    MASTER_PORT: "12355"
    VIS_PORT: 8097
    L2: 5e-3

In my case:
Process0 using GPU1 and GPU2
Process1 using GPU3 and GPU4
Process2 using GPU5 and GPU6
Process3 using GPU7 and GPU8
My server has 10 GPUs, and I didn’t use GPU0 or GPU9.

When I monitored the program with nvidia-smi, I found that GPUs 2, 4, 6 and 8 often do not finish their work at the same time. The GPUs that finish first wait for the straggler, so my overall GPU utilization is low.

I think there are a lot of things in my code that can be improved, so where should I start optimizing? I look forward to any suggestions.

You don’t need to average the loss before calling loss.backward(). The gradients that are computed on each process are reduced across processes, and upon returning from loss.backward() each process has identical gradients for their model parameters.
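
To make that concrete, here is a small single-process sketch that only simulates what happens, under assumptions: a tiny nn.Linear stands in for the real network, random tensors stand in for the data, and the averaging that DistributedDataParallel performs during backward() is imitated with a plain mean over per-rank gradients. It shows that calling backward() on each local loss and averaging the gradients gives the same result as a single backward() on the whole batch.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)

world_size = 4
per_process_batch = 2

# Tiny stand-in model and data; the real network and dataset are not needed here.
model = nn.Linear(8, 2)
inputs = torch.randn(world_size * per_process_batch, 8)
labels = torch.randint(0, 2, (world_size * per_process_batch,))
loss_fn = nn.CrossEntropyLoss()

# Reference: one process sees the whole batch.
reference = copy.deepcopy(model)
loss_fn(reference(inputs), labels).backward()

# Simulated ranks: each one runs backward() on its own local loss only.
local_grads = []
for rank in range(world_size):
    replica = copy.deepcopy(model)
    shard = slice(rank * per_process_batch, (rank + 1) * per_process_batch)
    loss_fn(replica(inputs[shard]), labels[shard]).backward()
    local_grads.append(replica.weight.grad)

# Imitate the gradient averaging that DDP performs across processes.
averaged_grad = torch.stack(local_grads).mean(dim=0)
max_difference = (averaged_grad - reference.weight.grad).abs().max().item()
print("max difference between averaged and full-batch gradient:", max_difference)
```

In the training loop from the question this means the all_gather and all_reduce calls on train_loss can go, leaving just train_loss.backward() followed by optimizer.step().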

Regarding the utilization, check out torchgpipe. It might be useful here.
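
The idea behind GPipe-style pipelining is to split each mini-batch into micro-batches, so that the second stage can start on the first micro-batch while the first stage is already working on the next one. The sketch below illustrates only that splitting step with plain PyTorch on CPU. It does not use torchgpipe, it runs the chunks one after another rather than overlapping them, and stage_in and stage_out are hypothetical stand-ins for the parts of the model on dev_in and dev_out.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two stand-in stages; in the question these would live on dev_in and dev_out.
stage_in = nn.Sequential(nn.Linear(8, 16), nn.LeakyReLU())
stage_out = nn.Sequential(nn.Linear(16, 2))

mini_batch = torch.randn(8, 8)

# Without micro-batches: stage_out idles until stage_in has finished the whole batch.
full_output = stage_out(stage_in(mini_batch))

# With micro-batches: each chunk can be handed to stage_out as soon as stage_in
# is done with it, which is what lets two GPUs work at the same time.
micro_outputs = [stage_out(stage_in(chunk)) for chunk in mini_batch.chunk(4)]
pipelined_output = torch.cat(micro_outputs, dim=0)

matches = torch.allclose(full_output, pipelined_output, atol=1e-6)
print("micro-batched output matches full-batch output:", matches)
```

Note that the outputs match here only because the stand-in stages contain no batch normalization. With BatchNorm layers, as in the FireModule3D blocks, statistics are computed per micro-batch, so results can differ from the full-batch case.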