EDIT: skip to the bottom
I'm training SRGAN on VOC2012 using NVIDIA DALI and comparing DataParallel with DistributedDataParallel (I was using apex too, but I've removed it in order to figure out what's going wrong).
Here is my run.sh:
python -m torch.distributed.launch \
--nproc_per_node=1 \
train_srgan_dali.py \
--train-mx-path=/home/maksim/data/VOC2012/voc_train.rec \
--train-mx-index-path=/home/maksim/data/VOC2012/voc_train.idx \
--val-mx-path=/home/maksim/data/VOC2012/voc_val.rec \
--val-mx-index-path=/home/maksim/data/VOC2012/voc_val.idx \
--checkpoint-dir=/home/maksim/dev_projects/atlas_sr/checkpoints \
--experiment-name=srgan_dali_pascal_3_channel_icnr_dp \
--batch-size=64 \
--lr=1e-3 \
--crop-size=88 \
--upscale-factor=2 \
--epochs=100 \
--workers=1 \
--channels=3
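(For context, a rough sketch of what torch.distributed.launch hands each worker process, assuming the standard env:// setup: it passes a --local_rank argument on the command line and sets WORLD_SIZE, RANK, MASTER_ADDR, and MASTER_PORT in the environment, which is what the script below keys off. With --nproc_per_node=1, WORLD_SIZE is 1, so this particular run takes the DataParallel branch.)

# rough sketch of what the launcher provides to each worker process:
#   * a --local_rank=<n> command-line argument (consumed by argparse below)
#   * WORLD_SIZE / RANK / MASTER_ADDR / MASTER_PORT environment variables
import os

world_size = int(os.environ.get("WORLD_SIZE", 1))  # 1 here, since --nproc_per_node=1
distributed = world_size > 1                       # False, so the DataParallel path runs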
Here is my model, and here is my train script:
import argparse
import os
import time
from math import log10
import pandas as pd
import torch
import torch.backends.cudnn
import torch.distributed
from nvidia.dali import types
from torch import nn
from data_utils.dali import StupidDALIIterator, SRGANMXNetPipeline
from metrics.metrics import AverageMeter
from metrics.ssim import ssim
from models.SRGAN import (
Generator,
Discriminator,
GeneratorLoss,
)
from util.util import monkey_patch_bn
monkey_patch_bn()
parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument("--channels", type=int, default=3)
parser.add_argument("--experiment-name", type=str, default="test")
parser.add_argument(
"--train-mx-path", default="/home/maksim/data/VOC2012/voc_train.rec"
)
parser.add_argument(
"--train-mx-index-path", default="/home/maksim/data/VOC2012/voc_train.idx"
)
parser.add_argument("--val-mx-path", default="/home/maksim/data/VOC2012/voc_val.rec")
parser.add_argument(
"--val-mx-index-path", default="/home/maksim/data/VOC2012/voc_val.idx"
)
parser.add_argument("--checkpoint-dir", default="/home/maksim/data/checkpoints")
parser.add_argument("--upscale-factor", type=int, default=2)
parser.add_argument("--epochs", type=int, default=100)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--prof", action="store_true", default=False)
parser.add_argument("--lr", type=float, default=1e-3)
parser.add_argument("--crop-size", type=int, default=88)
parser.add_argument("--workers", type=int, default=4)
args = parser.parse_args()
local_rank = args.local_rank
train_mx_path = args.train_mx_path
train_mx_index_path = args.train_mx_index_path
val_mx_path = args.val_mx_path
val_mx_index_path = args.val_mx_index_path
experiment_name = args.experiment_name
checkpoint_dir = args.checkpoint_dir
upscale_factor = args.upscale_factor
epochs = args.epochs
batch_size = args.batch_size
crop_size = args.crop_size
prof = args.prof
workers = args.workers
lr = args.lr
channels = args.channels
print_freq = 10
assert os.path.exists(train_mx_path)
assert os.path.exists(train_mx_index_path)
assert os.path.exists(val_mx_path)
assert os.path.exists(val_mx_index_path)
assert experiment_name
assert os.path.exists(checkpoint_dir)
distributed = False
world_size = 1
if local_rank == 0:
    checkpoint_dir = os.path.join(checkpoint_dir, experiment_name)
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)
if "WORLD_SIZE" in os.environ:
    world_size = int(os.environ["WORLD_SIZE"])
    distributed = world_size > 1
netG = Generator(scale_factor=upscale_factor, in_channels=channels)
netD = Discriminator(in_channels=channels)
g = GeneratorLoss()
if distributed:
    gpu = local_rank % torch.cuda.device_count()
    torch.cuda.set_device(gpu)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    assert world_size == torch.distributed.get_world_size()
    netG = nn.SyncBatchNorm.convert_sync_batchnorm(netG)
    netD = nn.SyncBatchNorm.convert_sync_batchnorm(netD)
    netG.cuda(gpu)
    netD.cuda(gpu)
    g.cuda(gpu)
    netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
    netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
    lr /= world_size
else:
    netG = Generator(scale_factor=upscale_factor, in_channels=channels)
    netD = Discriminator(in_channels=channels)
    netG = nn.DataParallel(netG)
    netD = nn.DataParallel(netD)
netG = netG.cuda()
netD = netD.cuda()
g = g.cuda()
# because VGG expects 3 channels
if channels == 1:
    generator_loss = lambda fake_out, fake_img, hr_image: g(
        fake_out,
        torch.cat([fake_img, fake_img, fake_img], dim=1),
        torch.cat([hr_image, hr_image, hr_image], dim=1),
    )
else:
    generator_loss = g
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr)
optimizerD = torch.optim.Adam(netD.parameters(), lr=lr)
train_pipe = SRGANMXNetPipeline(
batch_size=batch_size,
num_gpus=world_size,
num_threads=workers,
device_id=local_rank,
crop=crop_size,
mx_path=train_mx_path,
mx_index_path=train_mx_index_path,
upscale_factor=upscale_factor,
image_type=types.DALIImageType.RGB,
)
train_pipe.build()
train_loader = StupidDALIIterator(
pipelines=[train_pipe],
output_map=["lr_image", "hr_image"],
size=int(train_pipe.epoch_size("Reader") / world_size),
auto_reset=False,
)
val_pipe = SRGANMXNetPipeline(
batch_size=batch_size,
num_gpus=world_size,
num_threads=workers,
device_id=local_rank,
crop=crop_size,
mx_path=val_mx_path,
mx_index_path=val_mx_index_path,
upscale_factor=upscale_factor,
random_shuffle=False,
image_type=types.DALIImageType.RGB,
)
val_pipe.build()
val_loader = StupidDALIIterator(
pipelines=[val_pipe],
output_map=["lr_image", "hr_image"],
size=int(val_pipe.epoch_size("Reader") / world_size),
auto_reset=False,
)
g_loss_meter = AverageMeter("g_loss")
d_loss_meter = AverageMeter("d_loss")
sample_speed_meter = AverageMeter("sample_speed")
def train(epoch):
    g_loss_meter.reset()
    d_loss_meter.reset()
    sample_speed_meter.reset()
    netG.train()
    netD.train()
    for i, (lr_image, hr_image) in enumerate(train_loader):
        start = time.time()
        batch_size = lr_image.shape[0]
        if prof and i > 10:
            break
        ############################
        # (1) Update D network: maximize D(x)-1-D(G(z))
        ##########################
        fake_img = netG(lr_image)
        netD.zero_grad()
        real_out = netD(hr_image).mean()
        fake_out = netD(fake_img).mean()
        d_loss = 1 - real_out + fake_out
        d_loss_meter.update(d_loss.item())
        d_loss.backward(retain_graph=True)
        optimizerD.step()
        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ###########################
        netG.zero_grad()
        g_loss = generator_loss(fake_out, fake_img, hr_image)
        g_loss_meter.update(g_loss.item())
        g_loss.backward()
        optimizerG.step()
        sample_speed_meter.update(world_size * batch_size / (time.time() - start))
        if local_rank == 0 and i % print_freq == 0:
            print(
                "\t".join(
                    [
                        f"epoch {epoch}",
                        f"step {i + 1}/{train_loader.size // batch_size}",
                        str(sample_speed_meter),
                        str(d_loss_meter),
                        str(g_loss_meter),
                    ]
                )
            )
mse_meter = AverageMeter("mse")
ssim_meter = AverageMeter("ssim")
psnr_meter = AverageMeter("psnr")
def validate(epoch):
    mse_meter.reset()
    ssim_meter.reset()
    psnr_meter.reset()
    netG.eval()
    for i, (lr_image, hr_image) in enumerate(val_loader):
        batch_size = lr_image.shape[0]
        if prof and i > 10:
            break
        with torch.no_grad():
            sr_image = netG(lr_image)
        batch_mse = ((sr_image - hr_image) ** 2).mean()
        batch_ssim = ssim(sr_image, hr_image)
        mse_meter.update(batch_mse.item(), batch_size)
        ssim_meter.update(batch_ssim.item(), batch_size)
        psnr_meter.update(10 * log10(1 / mse_meter.avg))
    if local_rank == 0:
        print(
            "\t".join(
                [
                    "\033[1;31m" f"epoch {epoch}",
                    str(mse_meter),
                    str(ssim_meter),
                    str(psnr_meter),
                    "\033[1;0m",
                ]
            )
        )
epoch_time_meter = AverageMeter("epoch")
running_meters = {
"g_loss": [],
"d_loss": [],
"sample_speed": [],
"mse": [],
"ssim": [],
"psnr": [],
"epoch_time": [],
}
def update_running_meters():
    global running_meters
    running_meters["g_loss"].append(g_loss_meter.avg)
    running_meters["d_loss"].append(d_loss_meter.avg)
    running_meters["sample_speed"].append(sample_speed_meter.avg)
    running_meters["mse"].append(mse_meter.avg)
    running_meters["ssim"].append(ssim_meter.avg)
    running_meters["psnr"].append(psnr_meter.avg)
    running_meters["epoch_time"].append(epoch_time_meter.val)
def main():
    for epoch in range(epochs):
        start = time.time()
        train(epoch)
        validate(epoch)
        if local_rank == 0:
            torch.save(
                netG.state_dict(),
                f"{checkpoint_dir}/netG_epoch_{upscale_factor}_{epoch}.pth",
            )
            torch.save(
                netD.state_dict(),
                f"{checkpoint_dir}/netD_epoch_{upscale_factor}_{epoch}.pth",
            )
            epoch_time_meter.update(time.time() - start)
            update_running_meters()
            if epoch != 0 and not prof:
                data_frame = pd.DataFrame(data=running_meters)
                data_frame.to_csv(
                    os.path.join(checkpoint_dir, "metrics.csv"), index_label="Epoch"
                )
        val_loader.reset()
        train_loader.reset()


if __name__ == "__main__":
    main()
When switching between DataParallel and DistributedDataParallel I get drastically different PSNR performance. I've found this post, but none of the solutions seems to work; that's one matter (the difference between averaging and summing gradients — more on my understanding of that further down). The other matter, the thing I can't for the life of me figure out, is the difference in how my losses behave in the two cases. Here is a trace of my losses (for one epoch) when training with DataParallel:
epoch 7 step 1/241 d_loss 0.999999 g_loss 0.005589
epoch 7 step 11/241 d_loss 0.999999 g_loss 0.005433
epoch 7 step 21/241 d_loss 0.999998 g_loss 0.004887
epoch 7 step 31/241 d_loss 1.000002 g_loss 0.004837
epoch 7 step 41/241 d_loss 1.000000 g_loss 0.004958
epoch 7 step 51/241 d_loss 1.000000 g_loss 0.004784
epoch 7 step 61/241 d_loss 1.000000 g_loss 0.005808
epoch 7 step 71/241 d_loss 0.999979 g_loss 0.005283
epoch 7 step 81/241 d_loss 1.000003 g_loss 0.005585
epoch 7 step 91/241 d_loss 0.999999 g_loss 0.004718
epoch 7 step 101/241 d_loss 0.999999 g_loss 0.006046
epoch 7 step 111/241 d_loss 0.999978 g_loss 0.005157
epoch 7 step 121/241 d_loss 1.000007 g_loss 0.006780
epoch 7 step 131/241 d_loss 1.000001 g_loss 0.005851
epoch 7 step 141/241 d_loss 1.000000 g_loss 0.005644
epoch 7 step 151/241 d_loss 0.999986 g_loss 0.005973
epoch 7 step 161/241 d_loss 1.000002 g_loss 0.005687
epoch 7 step 171/241 d_loss 1.000012 g_loss 0.006535
epoch 7 step 181/241 d_loss 0.999999 g_loss 0.005457
epoch 7 step 191/241 d_loss 0.999999 g_loss 0.005313
epoch 7 step 201/241 d_loss 1.000000 g_loss 0.006094
epoch 7 step 211/241 d_loss 1.000000 g_loss 0.006187
epoch 7 step 221/241 d_loss 1.000116 g_loss 0.005385
epoch 7 step 231/241 d_loss 0.999931 g_loss 0.005718
epoch 7 step 241/241 d_loss 0.999774 g_loss 0.005635
From my understanding (and from watching the PSNR), this is how the losses should trend for SRGAN. Now here are my losses when using DistributedDataParallel (across several epochs to show the trend):
epoch 0 step 1/60 d_loss 1.000204 (1.000204) g_loss 0.153849 (0.153849)
epoch 0 step 11/60 d_loss 0.974728 (0.965737) g_loss 0.019822 (0.058211)
epoch 0 step 21/60 d_loss 0.468546 (0.831723) g_loss 0.015897 (0.038876)
epoch 0 step 31/60 d_loss 0.230158 (0.677370) g_loss 0.014611 (0.031437)
epoch 0 step 41/60 d_loss 0.077666 (0.544439) g_loss 0.014681 (0.027434)
epoch 0 step 51/60 d_loss 0.020034 (0.447585) g_loss 0.011524 (0.024474)
epoch 0 step 61/60 d_loss 0.013507 (0.378103) g_loss 0.011936 (0.022396)
epoch 0 mse 0.006693 (0.007545) ssim 0.661945 (0.645649) psnr 21.223185 (21.223185)
epoch 1 step 1/60 d_loss 0.019439 (0.019439) g_loss 0.010366 (0.010366)
epoch 1 step 11/60 d_loss 0.009224 (0.010984) g_loss 0.009906 (0.010792)
epoch 1 step 21/60 d_loss 0.003987 (0.008465) g_loss 0.011732 (0.010643)
epoch 1 step 31/60 d_loss 0.007867 (0.007535) g_loss 0.009154 (0.010312)
epoch 1 step 41/60 d_loss 0.003442 (0.006837) g_loss 0.010357 (0.010266)
epoch 1 step 51/60 d_loss 0.003987 (0.005997) g_loss 0.010241 (0.010080)
epoch 1 mse 0.004144 (0.004839) ssim 0.746634 (0.726122) psnr 23.152690 (23.152690)
epoch 2 step 1/60 d_loss 0.006586 (0.006586) g_loss 0.009223 (0.009223)
epoch 2 step 11/60 d_loss 0.859120 (0.566964) g_loss 0.008221 (0.008524)
epoch 2 step 21/60 d_loss 0.876267 (0.731556) g_loss 0.008248 (0.008669)
epoch 2 step 31/60 d_loss 0.665335 (0.739961) g_loss 0.010071 (0.008873)
epoch 2 step 41/60 d_loss 0.508060 (0.741789) g_loss 0.007758 (0.009077)
epoch 2 step 51/60 d_loss 0.533404 (0.670928) g_loss 0.007410 (0.008923)
epoch 2 mse 0.004435 (0.004207) ssim 0.733819 (0.747270) psnr 23.760117 (23.760117)
epoch 3 step 1/60 d_loss 0.976557 (0.976557) g_loss 0.008353 (0.008353)
epoch 3 step 11/60 d_loss 0.873007 (0.948327) g_loss 0.010379 (0.008218)
epoch 3 step 21/60 d_loss 0.688478 (0.868267) g_loss 0.006677 (0.008104)
epoch 3 step 31/60 d_loss 0.256862 (0.726863) g_loss 0.007438 (0.008090)
epoch 3 step 41/60 d_loss 0.101930 (0.586502) g_loss 0.008943 (0.007990)
epoch 3 step 51/60 d_loss 0.073482 (0.483037) g_loss 0.009807 (0.007858)
epoch 3 mse 0.003936 (0.003998) ssim 0.749274 (0.763466) psnr 23.981862 (23.981862)
Notice that in this case d_loss goes to zero rather than hovering around 1. I've been wrestling with this for several days and I can't for the life of me figure out what I'm doing wrong in switching from DataParallel to DistributedDataParallel that causes this kind of behavior.
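For what it's worth, here is my working mental model of the averaging/summing question mentioned above (a toy sketch, not my SRGAN code): each process computes the loss on its own shard of the batch, and during backward DistributedDataParallel all-reduces and averages the gradients across ranks, so the loss itself shouldn't need any extra scaling by world_size.

# toy sketch, not my SRGAN code; assumes this was launched with
# python -m torch.distributed.launch --nproc_per_node=<num_gpus> ...
import os
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel

dist.init_process_group(backend="nccl", init_method="env://")
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # or the --local_rank argparse argument
torch.cuda.set_device(local_rank)
model = DistributedDataParallel(nn.Linear(10, 1).cuda(), device_ids=[local_rank])
x = torch.randn(8, 10).cuda()      # this rank's shard of the global batch
loss = model(x).pow(2).mean()      # local loss -- no manual division by world_size
loss.backward()                    # DDP all-reduces and *averages* the gradients
                                   # across ranks here, so the resulting update
                                   # already accounts for every process's shard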
EDIT:
Life lesson: this is what happens when you copy-paste code without completely understanding it. I copied this code from https://github.com/NVIDIA/DALI/blob/master/docs/examples/pytorch/resnet50/main.py and adapted it for my needs. The problem turned out to be that in the original code there's a reduce only when logging the metrics. I misread and misinterpreted simultaneously: I missed that the backward pass is actually run on the unreduced loss further down, and I misinterpreted delay_allreduce to mean that apex.DistributedDataParallel wouldn't be doing any reduce at all and that the reduce_tensor call therefore had to be done by hand. So, in summary, I was dividing my loss by world_size unnecessarily.
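For anyone else who copies that example: the pattern I misread looks roughly like this (a sketch from memory, not a verbatim copy of the DALI file; names like criterion, output, losses, and optimizer are placeholders from a generic training loop). The all-reduced value is used only for logging, and backward still runs on the local, unreduced loss.

import torch
import torch.distributed as dist

def reduce_tensor(tensor, world_size):
    # average a tensor across all processes -- used only for reporting metrics
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= world_size
    return rt

# inside the training loop:
loss = criterion(output, target)
reduced_loss = reduce_tensor(loss.detach(), world_size)   # logging only
losses.update(reduced_loss.item(), batch_size)
optimizer.zero_grad()
loss.backward()    # the backward pass uses the *unreduced* local loss
optimizer.step()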