I use the torch.nn.parallel.DistributedDataParallel API in PyTorch 1.1 to scale my multi-card model (2 GPUs per replica) to 8 GPUs. According to the official tutorial GETTING STARTED WITH DISTRIBUTED DATA PARALLEL, DistributedDataParallel is the recommended way to parallelize one’s model. I am not confident about my implementation and I can’t find other valuable tutorials, so I have come here for help.
My ideas are simply as follows:
- split my 3D CNN model across 2 GPUs (simply called dev_in and dev_out),
- use DistributedDataParallel() to replicate my 2-GPU model across 4 processes; each model replica uses the same random seed to initialize its weights, and no process shares GPUs with another process,
- wrap my dataset with the Dataset() and DataLoader() APIs, and manually split each batch equally along the batch dimension, so each process (with 2 GPUs) processes different data with the SAME weights,
- after forward propagation in each process, collect the loss value from every process and average them, then use this averaged loss value to compute gradients and update all 4 models in the 4 processes,
- after each epoch of training and validation, calculate ACC and AUC scores for the training dataset and the validation dataset respectively,
- after one epoch over the training dataset, use the validation dataset to validate my model. Currently I use the ONE-PROCESS model (the model before it was wrapped by the DistributedDataParallel() API) for validation, because something went wrong that I couldn’t deal with when I used the model wrapped by DistributedDataParallel().
Currently, here is my code related:
# sample/train.py
import tempfile
import torch.distributed as dist
import torch.nn as nn
from torch import optim
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel
from torch.distributed import Backend
from time import time
from sample.networks.XXXXNet import XXXXNet
from sample.data import XXXXDataSet
from torch.utils.data import DataLoader
import yaml
import torch
import os
from yaml import CLoader as Loader
from sklearn.metrics import accuracy_score, auc, roc_curve
import numpy as np
from torch.utils.tensorboard import SummaryWriter
# Load the first entry of the DATASET list in config/config.yml. The file is
# opened in a `with` block so the handle is closed right after parsing.
# NOTE: yaml.load with this Loader is not a "safe" load; only use it on a
# config file that comes from a trusted source.
_config_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "../config/config.yml"))
with open(_config_path) as _config_file:
    cfg = yaml.load(_config_file, Loader=Loader)["DATASET"][0]

# Seed NumPy and PyTorch from the config value.
np.random.seed(cfg["SEED"])
torch.random.manual_seed(cfg["SEED"])

# os.path.abspath() does not expand "~" (it would produce "<cwd>/~/tmp"), so
# expand the home directory explicitly and make sure the directory exists
# before pointing the tempfile module at it.
_tmp_dir = os.path.abspath(os.path.expanduser("~/tmp"))
os.makedirs(_tmp_dir, exist_ok=True)
tempfile.tempdir = _tmp_dir

# Run name shared by the TensorBoard directory, the log file and the weights file.
NAME = "%dGPUs_1e-6" % cfg["WORLD_SIZE"]
writer = SummaryWriter(log_dir=os.path.join(os.path.dirname(__file__), "logs/tb_logs/%s" % NAME))
def setup_env(rank, world_size):
    """
    Initialize the distributed environment.
    :param rank: Rank of the current process.
    :param world_size: Number of processes participating in the job.
    :return:
    """
    assert isinstance(world_size, int) and world_size > 0
    assert isinstance(rank, int) and 0 <= rank < world_size
    # os.environ only accepts strings; str() keeps this working when the YAML
    # value is written without quotes and is therefore parsed as an integer.
    os.environ['MASTER_ADDR'] = str(cfg["MASTER_ADDR"])
    os.environ['MASTER_PORT'] = str(cfg["MASTER_PORT"])
    # Initialize the process group
    dist.init_process_group(Backend.NCCL, rank=rank, world_size=world_size)
    # Explicitly setting seed to make sure that models created in two processes
    # start from same random weights and biases.
    torch.manual_seed(cfg["SEED"])
def cleanup_env():
    """
    Tear down the default process group of this process.
    :return:
    """
    # Counterpart of dist.init_process_group() called in setup_env().
    dist.destroy_process_group()
def _binary_acc_auc(score_chunks, label_chunks):
    """
    Compute accuracy and ROC-AUC from per-batch model outputs.
    :param score_chunks: List of 2-D arrays; column 1 is used as the positive-class score.
    :param label_chunks: List of arrays holding the matching ground-truth labels.
    :return: Tuple (accuracy, auc).
    """
    scores = np.concatenate(tuple(score_chunks))[:, 1]
    labels = np.concatenate(tuple(label_chunks))
    hard_preds = np.where(scores > 0.5, 1, 0)  # Turn probability to 0-1 binary output
    acc = accuracy_score(labels, hard_preds)
    false_positive_rate, recall, _ = roc_curve(labels, scores, pos_label=1)
    return acc, auc(false_positive_rate, recall)


def train_model(rank, world_size, offset=1, ):
    """
    Train and validate one model replica inside a DistributedDataParallel job.

    Every process loads the same global batch and trains on its own slice of
    it. Gradients are averaged across processes by DistributedDataParallel
    during backward(), so each process back-propagates its *local* loss;
    collectives are only used afterwards to gather values for logging.
    :param rank: Rank of the current process.
    :param world_size: The number of processes in the current process group.
    :param offset: The index of first GPU to use.
    :return:
    """
    assert isinstance(world_size, int) and world_size > 0
    assert isinstance(rank, int) and 0 <= rank < world_size
    assert isinstance(offset, int) and offset >= 0
    setup_env(rank, world_size)

    # Setup mp_model and devices for this process (two consecutive GPUs per rank)
    dev_in = rank * 2 + offset
    dev_out = rank * 2 + 1 + offset
    mp_model = XXXXNet(dev_in=dev_in, dev_out=dev_out)
    ddp_mp_model = DistributedDataParallel(mp_model)
    # NOTE(review): XXXXNet.forward already ends with a softmax, while
    # nn.CrossEntropyLoss applies log-softmax itself. Left unchanged here
    # because the metrics below threshold the softmax output at 0.5; consider
    # returning logits from the model and applying softmax only for metrics.
    loss_fn = nn.CrossEntropyLoss()
    old_lr = float(cfg["LEARNING_RATE"]) * world_size
    optimizer = optim.Adam(
        ddp_mp_model.parameters(),
        lr=old_lr,
        weight_decay=float(cfg["L2"]),
    )
    per_rank = cfg["BATCH_SIZE_PER_CARD"]
    batch_size = world_size * per_rank
    rank_slice = slice(rank * per_rank, (rank + 1) * per_rank)

    # Training dataset
    # NOTE(review): every process loads the full global batch and keeps only
    # its slice; a DistributedSampler would avoid the redundant loading.
    dataset_train = XXXXDataSet(val=False, shape=cfg["CUBE_SIZE"][1:])
    data_loader_train = DataLoader(dataset_train, batch_size=batch_size, num_workers=cfg["NUM_WORKERS"])
    # Validation dataset (each process evaluates the complete validation set)
    dataset_val = XXXXDataSet(val=True, shape=cfg["CUBE_SIZE"][1:])
    data_loader_val = DataLoader(dataset_val, batch_size=cfg["BATCH_SIZE_PER_CARD"], num_workers=cfg["NUM_WORKERS"])

    base_dir = os.path.dirname(__file__)
    log_path = os.path.join(base_dir, "logs/" + NAME + ".log")
    weights_path = os.path.join(base_dir, "weights/" + NAME + ".th")
    if rank == 0:
        os.makedirs(os.path.dirname(weights_path), exist_ok=True)

    # Only rank 0 owns the real log file. If every rank opened it with "w",
    # a late opener would truncate what rank 0 had already written.
    with open(log_path if rank == 0 else os.devnull, "w") as log_file:
        def _print(string, file=log_file, target_rank=0):
            """
            Print to cmd and log file simultaneously.
            :param string: Content need to print.
            :param file: Log file object.
            :param target_rank: Rank that should print; -1 means every rank.
            :return:
            """
            if target_rank in (-1, rank):
                print(string, file=file)
                print(string)
                file.flush()

        _print(str(cfg))
        no_optim = 0
        total_epoch = cfg["EPOCHS"]
        epoch_best_loss_train = 100.
        for epoch in range(1, total_epoch + 1):
            # ==TRAINING====TRAINING====TRAINING====TRAINING====TRAINING====TRAINING==
            tic_train = time()
            prods_train, gts_train = [], []
            train_epoch_loss = 0.
            num_train_batches = 0
            for img, label in data_loader_train:
                # All ranks see the same global batch, so this test gives the
                # same answer everywhere: either every rank runs the
                # collectives below or none does (no deadlock), and all
                # gathered tensors have equal shapes.
                if img.size(0) < batch_size:
                    _print("skip incomplete batch: %d of %d samples" % (img.size(0), batch_size))
                    continue
                inp = img[rank_slice]
                label = label[rank_slice].to(dev_out)

                # Forward / backward on the local loss; DDP averages the
                # gradients of all replicas during backward().
                optimizer.zero_grad()
                pred = ddp_mp_model(inp)
                train_loss = loss_fn(pred, label)
                train_loss.backward()
                optimizer.step()

                # Logging only: gather detached copies. Each output slot must
                # be its own tensor ("[t] * n" would alias a single buffer).
                pred_detached = pred.detach()
                loss_sum = train_loss.detach().clone()
                prods_train_lst = [torch.zeros_like(pred_detached) for _ in range(world_size)]
                label_train_lst = [torch.zeros_like(label) for _ in range(world_size)]
                dist.all_gather(prods_train_lst, pred_detached)  # Sync between all processes
                dist.all_gather(label_train_lst, label)  # Sync between all processes
                dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)  # Sync between all processes
                # Mean loss over all ranks; identical on every rank, which
                # keeps the early-stopping decisions below in lock-step.
                train_epoch_loss += loss_sum.item() / world_size
                num_train_batches += 1
                prods_train.append(torch.cat(prods_train_lst, dim=0).cpu().numpy())
                gts_train.append(torch.cat(label_train_lst, dim=0).cpu().numpy())
            acc_NN, roc_auc = _binary_acc_auc(prods_train, gts_train)
            train_epoch_loss /= num_train_batches
            _print("******************************")
            _print("epoch[%03d/%03d], time: %02dm:%02ds" %
                   (epoch, cfg["EPOCHS"], int(time() - tic_train) // 60, int(time() - tic_train) % 60))
            _print("train loss = %6.4f" % train_epoch_loss)
            _print("CUBE_SIZE: %s" % str(cfg["CUBE_SIZE"]))
            _print("ACC = %6.4f, AUC = %6.4f" % (acc_NN, roc_auc))

            # ==Validation====Validation====Validation====Validation====Validation====Validation==
            _print("------------------------------")
            mp_model.eval()
            tic_val = time()
            prods_val, gts_val = [], []
            val_epoch_loss = 0
            with torch.no_grad():
                for val_img, val_label in data_loader_val:
                    val_label = val_label.to(dev_out)
                    # Validate with the unwrapped replica: no gradients are
                    # needed, so the DDP wrapper (and its per-forward
                    # inter-process synchronisation) is bypassed.
                    val_pred = mp_model(val_img)
                    val_loss = loss_fn(val_pred, val_label)
                    val_epoch_loss += val_loss.item()
                    prods_val.append(val_pred.cpu().numpy())
                    gts_val.append(val_label.cpu().numpy())
            acc_NN_val, roc_auc_val = _binary_acc_auc(prods_val, gts_val)
            val_epoch_loss /= len(data_loader_val)
            _print("validation time: %02dm:%02ds" % (int(time() - tic_val) // 60, int(time() - tic_val) % 60))
            _print("validation loss = %6.4f" % val_epoch_loss)
            _print("validation ACC = %6.4f, validation AUC = %6.4f" % (acc_NN_val, roc_auc_val))
            if rank == 0:
                writer.add_scalars(main_tag="lr", tag_scalar_dict={"train": old_lr}, global_step=epoch)
                writer.add_scalars(main_tag="time",
                                   tag_scalar_dict={"train": time() - tic_train,
                                                    "val": time() - tic_val}, global_step=epoch)
                writer.add_scalars(main_tag="loss",
                                   tag_scalar_dict={"train": train_epoch_loss,
                                                    "val": val_epoch_loss}, global_step=epoch)
                writer.add_scalars(main_tag="ACC",
                                   tag_scalar_dict={"train": acc_NN,
                                                    "val": acc_NN_val}, global_step=epoch)
                writer.add_scalars(main_tag="AUC",
                                   tag_scalar_dict={"train": roc_auc,
                                                    "val": roc_auc_val}, global_step=epoch)
            mp_model.train()
            # ==Validation End====Validation End====Validation End====Validation End====Validation End==

            if train_epoch_loss >= epoch_best_loss_train:
                no_optim += 1
            else:
                no_optim = 0
                epoch_best_loss_train = train_epoch_loss
                # All replicas hold the same weights, so a single writer is
                # enough and avoids concurrent writes to one file.
                if rank == 0:
                    torch.save(ddp_mp_model.state_dict(), weights_path)
            if no_optim > 6:
                _print("early stop at [%03d] epoch" % epoch)
                break
            if no_optim > 3:
                if old_lr < 5e-7:
                    break
                # The checkpoint was written in an earlier epoch, and the
                # training collectives since then guarantee rank 0 finished
                # saving. Load via CPU so tensors are not first materialised
                # on rank 0's GPUs; load_state_dict copies them onto this
                # replica's own devices.
                ddp_mp_model.load_state_dict(torch.load(weights_path, map_location="cpu"))
                new_lr = old_lr / 5.0
                for param_group in optimizer.param_groups:
                    param_group['lr'] = new_lr
                _print("update learning rate: %f -> %f" % (old_lr, new_lr))
                old_lr = new_lr
        _print("******************************")
        _print("Finish!")
    cleanup_env()
def ddp_train(demo_fn, world_size):
    """
    Launch one worker process per rank and wait for all of them to finish.
    :param demo_fn: Function run in every worker; it is called as
        demo_fn(rank, world_size).
    :param world_size: The number of processes in the current process group.
    :return:
    """
    spawn_options = dict(args=(world_size,), nprocs=world_size, join=True)
    mp.spawn(demo_fn, **spawn_options)
if __name__ == '__main__':
    # Script entry point: start one training process per configured rank,
    # then close the TensorBoard writer owned by this (parent) process.
    num_processes = cfg["WORLD_SIZE"]
    ddp_train(train_model, world_size=num_processes)
    writer.close()
And here is another .py file
# sample/networks/XXXXNet.py
import yaml
from yaml import CLoader as Loader
import torch.nn.functional as F
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed import Backend
# Load the first entry of the DATASET list in config/config.yml. The file is
# opened in a `with` block so the handle is closed right after parsing.
# NOTE: yaml.load with this Loader is not a "safe" load; only use it on a
# config file that comes from a trusted source.
_config_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), "../../config/config.yml"))
with open(_config_path) as _config_file:
    cfg = yaml.load(_config_file, Loader=Loader)["DATASET"][0]
# Activation class used by default throughout this module.
non_linearity = nn.LeakyReLU


class FireModule3D(nn.Module):
    """
    FireModule3D module
    (Tested 5.10)

    A squeeze 1x1x1 convolution feeds two parallel expand convolutions
    (1x1x1 and kernel_size^3) whose outputs are concatenated along the
    channel axis, optionally followed by a bypass (residual) addition.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3,
                 dilation=1, bias=False,
                 squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                 use_bn=True, momentum=0.1, use_dp=False, use_bypass=True):
        """
        Init function
        :param in_channels: Number of input channels.
        :param out_channels: Number of output channels.
        :param kernel_size: Kernel size of the spatial expand branch (odd values
            keep the spatial size unchanged).
        :param dilation: Dilation rate of dilated convolution.
        :param bias: Whether to use bias.
        :param squeeze_ratio: Squeeze ratio of Fire Module.
        :param pct_3x3: Percent of 3x3 convolution in expand layer.
        :param activation: Activation class; instantiated with inplace=True.
        :param use_bn: Whether to use batch normalization.
        :param momentum: The value used for the running_mean and running_var computation.
        :param use_dp: Whether to use dropout.
        :param use_bypass: Whether to use bypass connection.
        """
        super(FireModule3D, self).__init__()
        self.use_bn = use_bn
        self.use_dp = use_dp
        self.use_bypass = use_bypass
        e_i = out_channels
        # At least one squeeze channel, otherwise small out_channels values
        # would produce a convolution with zero channels.
        s_1x1 = max(1, int(squeeze_ratio * e_i))  # number of channels in squeeze 1x1 layer
        e_3x3 = int(pct_3x3 * e_i)  # number of channels in expand 3x3 layer
        e_1x1 = e_i - e_3x3
        # "Same" padding for the spatial branch so that its output can be
        # concatenated with the 1x1 branch for any odd kernel size / dilation
        # (equals 1 for the default kernel_size=3, dilation=1).
        same_padding = dilation * (kernel_size - 1) // 2
        self.activation = activation(inplace=True)
        self.squeeze1x1 = nn.Conv3d(in_channels=in_channels, out_channels=s_1x1,
                                    kernel_size=1, dilation=1, groups=1, bias=bias)
        self.expand1x1 = nn.Conv3d(in_channels=s_1x1, out_channels=e_1x1,
                                   kernel_size=1, dilation=1, groups=1, bias=bias)
        self.expand3x3 = nn.Conv3d(in_channels=s_1x1, out_channels=e_3x3, kernel_size=kernel_size,
                                   padding=same_padding, dilation=dilation, bias=bias)
        # Bypass connection: a 1x1 projection is only needed when the channel
        # count changes; otherwise the input is added as-is.
        if self.use_bypass:
            if in_channels != out_channels:
                self.bypass = nn.Conv3d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=1, bias=bias)
            else:
                self.bypass = None
        if self.use_bn:
            self.bn_s1x1 = nn.BatchNorm3d(num_features=s_1x1, momentum=momentum)
            self.bn_e1x1 = nn.BatchNorm3d(num_features=e_1x1, momentum=momentum)
            self.bn_e3x3 = nn.BatchNorm3d(num_features=e_3x3, momentum=momentum)
        if self.use_dp:
            # The activations here are 5-D (N, C, D, H, W), so the channel-wise
            # dropout variant for volumetric data is the matching module.
            self.dp = nn.Dropout3d(0.5)

    def forward(self, x):
        """
        Forward computation function.
        :param x: Input tensor.
        :return: Result tensor.
        """
        # Squeeze 1x1 layer
        squeeze = self.squeeze1x1(x)
        if self.use_bn:
            squeeze = self.bn_s1x1(squeeze)
        squeeze = self.activation(squeeze)
        # Expand 1x1 layer
        expand1x1 = self.expand1x1(squeeze)
        if self.use_dp:
            expand1x1 = self.dp(expand1x1)
        if self.use_bn:
            expand1x1 = self.bn_e1x1(expand1x1)
        # Expand 3x3 layer
        expand3x3 = self.expand3x3(squeeze)
        if self.use_dp:
            expand3x3 = self.dp(expand3x3)
        if self.use_bn:
            expand3x3 = self.bn_e3x3(expand3x3)
        # Concatenate both expand branches along the channel axis.
        merge = self.activation(torch.cat([expand1x1, expand3x3], dim=1))
        if self.use_bypass:  # Bypass connection
            if self.bypass is not None:
                x = self.bypass(x)
            merge = merge + x
        return merge
class XXXXNet(nn.Module):
    """
    3D CNN classifier split across two devices.

    conv0..conv3 are placed on ``dev_in``; every later layer (conv4..conv20,
    the max-pooling layers and the fully connected head) is placed on
    ``dev_out``. forward() moves the activations between the two devices.
    """

    @staticmethod
    def _fire(in_channels, out_channels, device):
        """
        Build one FireModule3D with the settings shared by all blocks of this
        network and move it to the given device.
        :param in_channels: Number of input channels.
        :param out_channels: Number of output channels.
        :param device: Device the block is moved to.
        :return: The FireModule3D instance.
        """
        return FireModule3D(in_channels=in_channels, out_channels=out_channels, kernel_size=3,
                            dilation=1, bias=False,
                            squeeze_ratio=0.125, pct_3x3=0.5, activation=non_linearity,
                            use_bn=True, momentum=0.1, use_dp=False, use_bypass=True).to(device)

    def __init__(self, nb_class=2, dev_in=None, dev_out=None):
        """
        Init function
        :param nb_class: Number of output classes.
        :param dev_in: Device holding the first part of the network.
        :param dev_out: Device holding the remaining part of the network.
        """
        super(XXXXNet, self).__init__()
        self.device1 = dev_in
        self.device2 = dev_out
        # Layers are created in the same order and under the same attribute
        # names as before, so parameter initialisation and state_dict keys
        # are unchanged.
        self.conv0 = nn.Sequential(nn.Conv3d(1, 8, kernel_size=7, stride=2, padding=3, bias=False),
                                   non_linearity(inplace=True)).to(self.device1)
        self.conv1 = self._fire(8, 8, self.device1)
        self.conv2 = self._fire(8, 8, self.device1)
        self.conv3 = self._fire(8, 8, self.device1)
        self.conv4 = self._fire(8, 16, self.device2)
        self.mp1 = nn.MaxPool3d(kernel_size=2, stride=2).to(self.device2)
        self.conv5 = self._fire(16, 16, self.device2)
        self.conv6 = self._fire(16, 16, self.device2)
        self.conv7 = self._fire(16, 16, self.device2)
        self.conv8 = self._fire(16, 32, self.device2)
        self.mp2 = nn.MaxPool3d(kernel_size=2, stride=2).to(self.device2)
        self.conv9 = self._fire(32, 32, self.device2)
        self.conv10 = self._fire(32, 32, self.device2)
        self.conv11 = self._fire(32, 32, self.device2)
        self.conv12 = self._fire(32, 64, self.device2)
        self.mp3 = nn.MaxPool3d(kernel_size=2, stride=2).to(self.device2)
        self.conv13 = self._fire(64, 64, self.device2)
        self.conv14 = self._fire(64, 64, self.device2)
        self.conv15 = self._fire(64, 64, self.device2)
        self.conv16 = self._fire(64, 128, self.device2)
        self.mp4 = nn.MaxPool3d(kernel_size=2, stride=2).to(self.device2)
        self.conv17 = self._fire(128, 128, self.device2)
        self.conv18 = self._fire(128, 128, self.device2)
        self.conv19 = self._fire(128, 128, self.device2)
        self.conv20 = self._fire(128, 256, self.device2)
        self.mp5 = nn.MaxPool3d(kernel_size=2, stride=2).to(self.device2)
        # The flattened feature size 256 * 7 * 3 * 5 is hard-coded.
        # NOTE(review): presumably derived from the configured CUBE_SIZE —
        # confirm before changing the input shape.
        self.fc1 = nn.Sequential(
            nn.Linear(256 * 7 * 3 * 5, 256, bias=False),
            non_linearity(inplace=True),
            # The input here is 2-D (batch, features), so plain element-wise
            # dropout is the matching module (Dropout2d targets feature maps).
            nn.Dropout(p=0.5),
        ).to(self.device2)
        self.fc3 = nn.Linear(256, nb_class, bias=False).to(self.device2)

    def forward(self, x):
        """
        Forward computation function.
        :param x: Input tensor; it is moved to the first device before use.
        :return: Softmax over the class dimension (dim=1), on the second device.
        """
        x = x.to(self.device1)
        x = self.conv0(x)
        x = self.conv3(self.conv2(self.conv1(x)))
        # Hand the activations over to the second device.
        x = x.to(self.device2)
        x = self.mp1(self.conv4(x))
        x = self.mp2(self.conv8(self.conv7(self.conv6(self.conv5(x)))))
        x = self.mp3(self.conv12(self.conv11(self.conv10(self.conv9(x)))))
        x = self.mp4(self.conv16(self.conv15(self.conv14(self.conv13(x)))))
        x = self.mp5(self.conv20(self.conv19(self.conv18(self.conv17(x)))))
        # flatten
        x = torch.flatten(x, start_dim=1)
        x = self.fc1(x)
        # NOTE(review): the output is already a probability distribution;
        # a loss that applies (log-)softmax itself, such as
        # nn.CrossEntropyLoss, expects raw logits instead.
        x = F.softmax(self.fc3(x), dim=1)
        return x
And my config/config.yml file is
# All keys below belong to the first (and only) entry of the DATASET list;
# the training code reads them via cfg["DATASET"][0].
DATASET:
  - NAME: "XXXX"
    # A comment must be separated from the value by whitespace before "#".
    ROOT: "/home/aaa/organized_data"  # "/Users/aaa/fsdownload"
    IMAGE_FOLDER: "raw_scans"
    NPY_FOLDER: "pre_result"
    NPY_FOLDER2: "my_npy"
    TRAIN_CSV: "train.csv"
    VAL_CSV: "val.csv"
    SEED: 1
    BATCH_SIZE_PER_CARD: 2
    NUM_WORKERS: 4
    MOMENTUM: 0.01
    # PyYAML reads exponent notation without a decimal point (1e-6, 5e-3) as a
    # string; train.py converts LEARNING_RATE and L2 with float().
    LEARNING_RATE: 1e-6
    NUM_CLASSES: 2
    WORLD_SIZE: 4
    CUBE_SIZE: [1, 450, 220, 325]  # (C, D, H, W)
    EPOCHS: 100
    MASTER_ADDR: "localhost"
    MASTER_PORT: "12355"
    VIS_PORT: 8097
    L2: 5e-3
In my case:
Process0 using GPU1 and GPU2
Process1 using GPU3 and GPU4
Process2 using GPU5 and GPU6
Process3 using GPU7 and GPU8
my server has 10 GPUs and I didn’t use GPU0 and 9.
When I monitored the running program with nvidia-smi, I found that GPUs 2, 4, 6 and 8 often do not finish their work at the same time, and the GPUs that finish their calculation first have to wait for the straggler, so my overall GPU usage is low.
I think there are a lot of things in my code that can be improved, so where should I start optimizing my code? Looking forward to any suggestions.