Can't figure out what i'm doing wrong

EDIT: skip to the bottom

i’m training SRGAN on VOC2012 using NVIDIA DALI and experimenting between DataParallel and DistributedDataParallel (I was using apex too but I’ve removed it in order to figure out what’s going wrong).

here is my run.sh

# Launch train_srgan_dali.py through torch.distributed.launch, one process per node.
python -m torch.distributed.launch \
  --nproc_per_node=1 \
  train_srgan_dali.py \
  --train-mx-path=/home/maksim/data/VOC2012/voc_train.rec \
  --train-mx-index-path=/home/maksim/data/VOC2012/voc_train.idx \
  --val-mx-path=/home/maksim/data/VOC2012/voc_val.rec \
  --val-mx-index-path=/home/maksim/data/VOC2012/voc_val.idx \
  --checkpoint-dir=/home/maksim/dev_projects/atlas_sr/checkpoints \
  --experiment-name=srgan_dali_pascal_3_channel_icnr_dp \
  --batch-size=64 \
  --lr=1e-3 \
  --crop-size=88 \
  --upscale-factor=2 \
  --epochs=100 \
  --workers=1 \
  --channels=3

here is my model and here is my train script

import argparse
import os
import time
from math import log10

import pandas as pd
import torch
import torch.backends.cudnn
import torch.distributed
from nvidia.dali import types
from torch import nn

from data_utils.dali import StupidDALIIterator, SRGANMXNetPipeline
from metrics.metrics import AverageMeter
from metrics.ssim import ssim
from models.SRGAN import (
    Generator,
    Discriminator,
    GeneratorLoss,
)
from util.util import monkey_patch_bn

# Project helper from util.util; what it patches is not visible in this file.
monkey_patch_bn()

# Command-line options as (flag, add_argument keyword options) pairs. They are
# registered with the parser in exactly this order.
_CLI_OPTIONS = (
    ("--local_rank", {"default": 0, "type": int}),
    ("--channels", {"type": int, "default": 3}),
    ("--experiment-name", {"type": str, "default": "test"}),
    ("--train-mx-path", {"default": "/home/maksim/data/VOC2012/voc_train.rec"}),
    (
        "--train-mx-index-path",
        {"default": "/home/maksim/data/VOC2012/voc_train.idx"},
    ),
    ("--val-mx-path", {"default": "/home/maksim/data/VOC2012/voc_val.rec"}),
    ("--val-mx-index-path", {"default": "/home/maksim/data/VOC2012/voc_val.idx"}),
    ("--checkpoint-dir", {"default": "/home/maksim/data/checkpoints"}),
    ("--upscale-factor", {"type": int, "default": 2}),
    ("--epochs", {"type": int, "default": 100}),
    ("--batch-size", {"type": int, "default": 64}),
    ("--prof", {"action": "store_true", "default": False}),
    ("--lr", {"type": float, "default": 1e-3}),
    ("--crop-size", {"type": int, "default": 88}),
    ("--workers", {"type": int, "default": 4}),
)

parser = argparse.ArgumentParser()
for _flag, _options in _CLI_OPTIONS:
    parser.add_argument(_flag, **_options)

args = parser.parse_args()

# Unpack the parsed namespace into module-level settings used by the rest of
# the script.
local_rank = args.local_rank
channels = args.channels
experiment_name = args.experiment_name
train_mx_path = args.train_mx_path
train_mx_index_path = args.train_mx_index_path
val_mx_path = args.val_mx_path
val_mx_index_path = args.val_mx_index_path
checkpoint_dir = args.checkpoint_dir
upscale_factor = args.upscale_factor
epochs = args.epochs
batch_size = args.batch_size
prof = args.prof
lr = args.lr
crop_size = args.crop_size
workers = args.workers
# Print a progress line every `print_freq` training steps.
print_freq = 10

# Fail early if any input file or the checkpoint root is missing.
for _required_path in (
    train_mx_path,
    train_mx_index_path,
    val_mx_path,
    val_mx_index_path,
):
    assert os.path.exists(_required_path)
assert experiment_name
assert os.path.exists(checkpoint_dir)

# Defaults for a single-process run; overridden below when the launcher
# exports WORLD_SIZE.
distributed = False
world_size = 1

if local_rank == 0:
    # Only rank 0 saves checkpoints, so only rank 0 switches to (and creates)
    # the per-experiment directory. exist_ok makes the call safe when the
    # directory already exists or is created concurrently.
    checkpoint_dir = os.path.join(checkpoint_dir, experiment_name)
    os.makedirs(checkpoint_dir, exist_ok=True)

if "WORLD_SIZE" in os.environ:
    world_size = int(os.environ["WORLD_SIZE"])
    distributed = world_size > 1

# Build the networks exactly once, before choosing a parallelisation strategy,
# so both branches wrap the same freshly constructed models (no discarded
# second construction in the DataParallel branch).
netG = Generator(scale_factor=upscale_factor, in_channels=channels)
netD = Discriminator(in_channels=channels)
g = GeneratorLoss()
if distributed:
    # One process per GPU: pin this process to its device before creating the
    # process group.
    gpu = local_rank % torch.cuda.device_count()
    torch.cuda.set_device(gpu)
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
    assert world_size == torch.distributed.get_world_size()
    # Batch-norm layers must be converted before the models are wrapped.
    netG = nn.SyncBatchNorm.convert_sync_batchnorm(netG)
    netD = nn.SyncBatchNorm.convert_sync_batchnorm(netD)
    netG.cuda(gpu)
    netD.cuda(gpu)
    g.cuda(gpu)
    netG = nn.parallel.DistributedDataParallel(netG, device_ids=[gpu])
    netD = nn.parallel.DistributedDataParallel(netD, device_ids=[gpu])
    # NOTE(review): DistributedDataParallel already averages gradients across
    # ranks, so this additionally shrinks the effective step size — confirm
    # that scaling the learning rate down is intended.
    lr /= world_size
else:
    netG = nn.DataParallel(netG)
    netD = nn.DataParallel(netD)
    netG = netG.cuda()
    netD = netD.cuda()
    g = g.cuda()

# The VGG-based loss expects 3-channel input, so single-channel images are
# tiled three times along the channel axis before being passed to it.
if channels == 1:

    def generator_loss(fake_out, fake_img, hr_image):
        """Evaluate ``g`` on 3-channel copies of single-channel images."""
        tiled_fake = torch.cat([fake_img] * 3, dim=1)
        tiled_real = torch.cat([hr_image] * 3, dim=1)
        return g(fake_out, tiled_fake, tiled_real)

else:
    generator_loss = g

# Generator and discriminator each get their own Adam optimizer, both at `lr`.
optimizerG = torch.optim.Adam(netG.parameters(), lr=lr)
optimizerD = torch.optim.Adam(netD.parameters(), lr=lr)

# Keyword options common to the training and validation DALI pipelines.
_shared_pipe_options = dict(
    batch_size=batch_size,
    num_gpus=world_size,
    num_threads=workers,
    device_id=local_rank,
    crop=crop_size,
    upscale_factor=upscale_factor,
    image_type=types.DALIImageType.RGB,
)

# Training data: pipeline first, then the iterator that wraps it.
train_pipe = SRGANMXNetPipeline(
    mx_path=train_mx_path,
    mx_index_path=train_mx_index_path,
    **_shared_pipe_options,
)
train_pipe.build()
# Each process is given an equal share of the reader's epoch size.
_train_share = int(train_pipe.epoch_size("Reader") / world_size)
train_loader = StupidDALIIterator(
    pipelines=[train_pipe],
    output_map=["lr_image", "hr_image"],
    size=_train_share,
    auto_reset=False,
)

# Validation data: same options, but with shuffling turned off.
val_pipe = SRGANMXNetPipeline(
    mx_path=val_mx_path,
    mx_index_path=val_mx_index_path,
    random_shuffle=False,
    **_shared_pipe_options,
)
val_pipe.build()
_val_share = int(val_pipe.epoch_size("Reader") / world_size)
val_loader = StupidDALIIterator(
    pipelines=[val_pipe],
    output_map=["lr_image", "hr_image"],
    size=_val_share,
    auto_reset=False,
)

# Training statistics; train() resets all three at the start of every epoch.
g_loss_meter, d_loss_meter, sample_speed_meter = (
    AverageMeter(label) for label in ("g_loss", "d_loss", "sample_speed")
)


def train(epoch):
    """Run one training epoch of the generator/discriminator pair.

    Iterates over ``train_loader`` once. For every batch it performs one
    discriminator update followed by one generator update, and records both
    losses and the throughput in the module-level meters. Rank 0 prints a
    progress line every ``print_freq`` steps.

    Args:
        epoch: Zero-based epoch index, used only in the progress line.
    """
    g_loss_meter.reset()
    d_loss_meter.reset()
    sample_speed_meter.reset()
    netG.train()
    netD.train()

    # Ceil division on the configured batch size: a trailing partial batch
    # still counts as a step, so the counter cannot overshoot its total
    # (the floor-divided total produced lines such as "step 61/60").
    steps_per_epoch = (train_loader.size + batch_size - 1) // batch_size

    for i, (lr_image, hr_image) in enumerate(train_loader):
        start = time.time()
        # Separate local name so the configured global batch_size is not
        # shadowed inside this function.
        samples_in_batch = lr_image.shape[0]

        # Profiling runs only need a handful of iterations.
        if prof and i > 10:
            break

        fake_img = netG(lr_image)

        ############################
        # (1) Update D network: minimize 1 - D(x) + D(G(z))
        ############################
        netD.zero_grad()
        real_out = netD(hr_image).mean()
        # detach(): the discriminator step must not backpropagate into the
        # generator, which also removes the need for retain_graph.
        fake_out = netD(fake_img.detach()).mean()
        d_loss = 1 - real_out + fake_out
        d_loss_meter.update(d_loss.item())
        d_loss.backward()
        optimizerD.step()

        ############################
        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
        ############################
        netG.zero_grad()
        # Re-score the fake batch with the freshly updated discriminator.
        # Backpropagating through the pre-step output would use a graph whose
        # weights optimizerD.step() has since modified in place.
        fake_out = netD(fake_img).mean()
        g_loss = generator_loss(fake_out, fake_img, hr_image)
        g_loss_meter.update(g_loss.item())
        g_loss.backward()
        optimizerG.step()

        # Throughput in samples/second, scaled by the number of processes.
        sample_speed_meter.update(
            world_size * samples_in_batch / (time.time() - start)
        )

        if local_rank == 0 and i % print_freq == 0:
            print(
                "\t".join(
                    [
                        f"epoch {epoch}",
                        f"step {i + 1}/{steps_per_epoch}",
                        str(sample_speed_meter),
                        str(d_loss_meter),
                        str(g_loss_meter),
                    ]
                )
            )


# Validation statistics; validate() resets all three on every call.
mse_meter, ssim_meter, psnr_meter = (
    AverageMeter(label) for label in ("mse", "ssim", "psnr")
)


def validate(epoch):
    """Evaluate the generator on ``val_loader`` and report MSE, SSIM and PSNR.

    MSE and SSIM are accumulated per batch, weighted by the number of samples
    in the batch. PSNR is derived once from the epoch-average MSE. Rank 0
    prints a summary line.

    Args:
        epoch: Zero-based epoch index, used only in the summary line.
    """
    mse_meter.reset()
    ssim_meter.reset()
    psnr_meter.reset()
    netG.eval()
    for i, (lr_image, hr_image) in enumerate(val_loader):
        samples_in_batch = lr_image.shape[0]
        # Profiling runs only need a handful of iterations.
        if prof and i > 10:
            break

        with torch.no_grad():
            sr_image = netG(lr_image)

        batch_mse = ((sr_image - hr_image) ** 2).mean()
        batch_ssim = ssim(sr_image, hr_image)

        mse_meter.update(batch_mse.item(), samples_in_batch)
        ssim_meter.update(batch_ssim.item(), samples_in_batch)

    # PSNR with a peak value of 1 — assumes images are scaled to [0, 1];
    # TODO confirm against the data pipeline.
    mean_mse = mse_meter.avg
    if mean_mse == 0:
        # 1 / 0 would raise ZeroDivisionError; a zero mean error corresponds
        # to unbounded PSNR.
        psnr_meter.update(float("inf"))
    else:
        psnr_meter.update(10 * log10(1 / mean_mse))

    if local_rank == 0:
        print(
            "\t".join(
                [
                    "\033[1;31m" f"epoch {epoch}",
                    str(mse_meter),
                    str(ssim_meter),
                    str(psnr_meter),
                    "\033[1;0m",
                ]
            )
        )


# Wall-clock duration of an epoch; main() updates it on rank 0.
epoch_time_meter = AverageMeter("epoch")

# Per-epoch history, one list per metric. main() turns this mapping into the
# DataFrame that is written to metrics.csv.
running_meters = {
    metric: []
    for metric in (
        "g_loss",
        "d_loss",
        "sample_speed",
        "mse",
        "ssim",
        "psnr",
        "epoch_time",
    )
}


def update_running_meters():
    """Append the current value of every meter to ``running_meters``.

    Averages are recorded for the loss, speed and quality meters; the epoch
    time is recorded as the most recent value (``.val``).
    """
    snapshot = {
        "g_loss": g_loss_meter.avg,
        "d_loss": d_loss_meter.avg,
        "sample_speed": sample_speed_meter.avg,
        "mse": mse_meter.avg,
        "ssim": ssim_meter.avg,
        "psnr": psnr_meter.avg,
        "epoch_time": epoch_time_meter.val,
    }
    for metric, value in snapshot.items():
        running_meters[metric].append(value)


def main():
    """Train and validate for ``epochs`` epochs.

    After every epoch, rank 0 saves both networks' state dicts, records the
    epoch's metrics, and (from the second epoch on, unless profiling) rewrites
    ``metrics.csv`` in the checkpoint directory. Both loaders are reset at the
    end of every epoch on all ranks.
    """
    for epoch in range(epochs):
        epoch_started = time.time()
        train(epoch)
        validate(epoch)

        if local_rank == 0:
            generator_path = f"{checkpoint_dir}/netG_epoch_{upscale_factor}_{epoch}.pth"
            discriminator_path = f"{checkpoint_dir}/netD_epoch_{upscale_factor}_{epoch}.pth"
            torch.save(netG.state_dict(), generator_path)
            torch.save(netD.state_dict(), discriminator_path)

            epoch_time_meter.update(time.time() - epoch_started)
            update_running_meters()

            # The metrics file is not written after the first epoch or
            # during profiling runs.
            skip_metrics_file = epoch == 0 or prof
            if not skip_metrics_file:
                metrics_path = os.path.join(checkpoint_dir, "metrics.csv")
                pd.DataFrame(data=running_meters).to_csv(
                    metrics_path, index_label="Epoch"
                )

        # auto_reset is disabled on both iterators, so rewind them by hand.
        val_loader.reset()
        train_loader.reset()


if __name__ == "__main__":
    main()

when switching between using DataParallel and DistributedDataParallel I get drastically different psnr performance. I’ve found this post but none of the solutions seems to work. That’s one matter (the difference between averaging and summing gradients). The other matter, the thing that I can’t for the life of me figure out, is the difference in how my losses behave in both cases. Here is a trace of my losses (for one epoch) if I train using DataParallel

epoch 7 step 1/241        d_loss 0.999999      g_loss 0.005589 
epoch 7 step 11/241       d_loss 0.999999      g_loss 0.005433 
epoch 7 step 21/241       d_loss 0.999998      g_loss 0.004887 
epoch 7 step 31/241       d_loss 1.000002      g_loss 0.004837 
epoch 7 step 41/241       d_loss 1.000000      g_loss 0.004958 
epoch 7 step 51/241       d_loss 1.000000      g_loss 0.004784 
epoch 7 step 61/241       d_loss 1.000000      g_loss 0.005808 
epoch 7 step 71/241       d_loss 0.999979      g_loss 0.005283 
epoch 7 step 81/241       d_loss 1.000003      g_loss 0.005585 
epoch 7 step 91/241       d_loss 0.999999      g_loss 0.004718 
epoch 7 step 101/241      d_loss 0.999999      g_loss 0.006046 
epoch 7 step 111/241      d_loss 0.999978      g_loss 0.005157 
epoch 7 step 121/241      d_loss 1.000007      g_loss 0.006780 
epoch 7 step 131/241      d_loss 1.000001      g_loss 0.005851 
epoch 7 step 141/241      d_loss 1.000000      g_loss 0.005644 
epoch 7 step 151/241      d_loss 0.999986      g_loss 0.005973 
epoch 7 step 161/241      d_loss 1.000002      g_loss 0.005687 
epoch 7 step 171/241      d_loss 1.000012      g_loss 0.006535 
epoch 7 step 181/241      d_loss 0.999999      g_loss 0.005457 
epoch 7 step 191/241      d_loss 0.999999      g_loss 0.005313 
epoch 7 step 201/241      d_loss 1.000000      g_loss 0.006094 
epoch 7 step 211/241      d_loss 1.000000      g_loss 0.006187 
epoch 7 step 221/241      d_loss 1.000116      g_loss 0.005385 
epoch 7 step 231/241      d_loss 0.999931      g_loss 0.005718 
epoch 7 step 241/241      d_loss 0.999774      g_loss 0.005635 

From my understanding (and by watching the psnr) this is how the losses should trend for SRGAN.

Now here are my losses when using DistributedDataParallel (across several epochs to show the trend)

epoch 0 step 1/60       d_loss 1.000204 (1.000204)      g_loss 0.153849 (0.153849)
epoch 0 step 11/60      d_loss 0.974728 (0.965737)      g_loss 0.019822 (0.058211)
epoch 0 step 21/60      d_loss 0.468546 (0.831723)      g_loss 0.015897 (0.038876)
epoch 0 step 31/60      d_loss 0.230158 (0.677370)      g_loss 0.014611 (0.031437)
epoch 0 step 41/60      d_loss 0.077666 (0.544439)      g_loss 0.014681 (0.027434)
epoch 0 step 51/60      d_loss 0.020034 (0.447585)      g_loss 0.011524 (0.024474)
epoch 0 step 61/60      d_loss 0.013507 (0.378103)      g_loss 0.011936 (0.022396)
epoch 0 mse 0.006693 (0.007545) ssim 0.661945 (0.645649)        psnr 21.223185 (21.223185)
epoch 1 step 1/60       d_loss 0.019439 (0.019439)      g_loss 0.010366 (0.010366)
epoch 1 step 11/60      d_loss 0.009224 (0.010984)      g_loss 0.009906 (0.010792)
epoch 1 step 21/60      d_loss 0.003987 (0.008465)      g_loss 0.011732 (0.010643)
epoch 1 step 31/60      d_loss 0.007867 (0.007535)      g_loss 0.009154 (0.010312)
epoch 1 step 41/60      d_loss 0.003442 (0.006837)      g_loss 0.010357 (0.010266)
epoch 1 step 51/60      d_loss 0.003987 (0.005997)      g_loss 0.010241 (0.010080)
epoch 1 mse 0.004144 (0.004839) ssim 0.746634 (0.726122)        psnr 23.152690 (23.152690)
epoch 2 step 1/60       d_loss 0.006586 (0.006586)      g_loss 0.009223 (0.009223)
epoch 2 step 11/60      d_loss 0.859120 (0.566964)      g_loss 0.008221 (0.008524)
epoch 2 step 21/60      d_loss 0.876267 (0.731556)      g_loss 0.008248 (0.008669)
epoch 2 step 31/60      d_loss 0.665335 (0.739961)      g_loss 0.010071 (0.008873)
epoch 2 step 41/60      d_loss 0.508060 (0.741789)      g_loss 0.007758 (0.009077)
epoch 2 step 51/60      d_loss 0.533404 (0.670928)      g_loss 0.007410 (0.008923)
epoch 2 mse 0.004435 (0.004207) ssim 0.733819 (0.747270)        psnr 23.760117 (23.760117)
epoch 3 step 1/60       d_loss 0.976557 (0.976557)      g_loss 0.008353 (0.008353)
epoch 3 step 11/60      d_loss 0.873007 (0.948327)      g_loss 0.010379 (0.008218)
epoch 3 step 21/60      d_loss 0.688478 (0.868267)      g_loss 0.006677 (0.008104)
epoch 3 step 31/60      d_loss 0.256862 (0.726863)      g_loss 0.007438 (0.008090)
epoch 3 step 41/60      d_loss 0.101930 (0.586502)      g_loss 0.008943 (0.007990)
epoch 3 step 51/60      d_loss 0.073482 (0.483037)      g_loss 0.009807 (0.007858)
epoch 3 mse 0.003936 (0.003998) ssim 0.749274 (0.763466)        psnr 23.981862 (23.981862)

Notice that in this case d_loss goes to zero rather than to 1.

I’ve been wrestling with it for several days and I can’t for the life of me figure out what I’m doing wrong in switching from DataParallel to DistributedDataParallel that causes this kind of behavior.

EDIT:

life lesson: this is what happens when you copy paste code without understanding completely. i copied this code from https://github.com/NVIDIA/DALI/blob/master/docs/examples/pytorch/resnet50/main.py

and adapted it for my needs. the problem turned out to be that in the original code

there’s a reduce when logging the metrics. i misread and misinterpreted simultaneously: i missed that the backwards pass is actually run on the unreduced loss further down and misinterpreted delay_allreduce

to mean that apex.DistributedDataParallel wouldn’t be doing any reduce at all and that the reduce_tensor call was necessarily done by hand.

so in summary i was dividing my loss by world_size unnecessarily.

1 Like

Hi @makslevental,

I’m glad that you solved it! :slightly_smiling_face:

I am curious, which Reader are you using in DALI to load VOC2012? Are you using ExternalSource or a custom operator?

@spanev

@spanev since it looks like you work on DALI: can you tell me why the operators aren’t more orthogonal? for example I really would like to be able to normalize without cropping, to change imagetypes outside of the decoder, or resize but after crop (i.e. I’d like to use decoder, crop, and resize independently of one another but I can’t because resize expects uint8). I’m not complaining (thanks for the toolkit!) I’m just wondering if it’s something having to do with how the cuda kernels are compiled or it’s just a design choice.

1 Like

It is great to see that you are able to use it to accelerate your training. :slight_smile:

DALI is still under active development (and technically still in beta).
The team is currently reworking the whole architecture and working on some major missing features (such as pointwise operations).

A few notes about the operators you mentioned:

I know it may seem a little bit counter-intuitive but you should be using CropMirrorNormalize (minus the crop and mirror options). The DALI and CUDA kernels have compile time mechanisms to reduce the runtime overhead (to none).

We may add a Normalize op in the future but it would still use the same kernels under the hood.

You can actually do it with the Cast operator.

I guess you mean Crop after Resize.

This limitation comes from the fact that Crop and Resize operations are (sorta) commutative, but performance-wise it is better to crop first, so as not to apply the interpolation algorithm to data that will be cropped anyway.

I know it may seem a little bit counter-intuitive but you should be using CropMirrorNormalize (minus the crop and mirror options). The DALI and CUDA kernels have compile time mechanisms to reduce the runtime overhead (to none).

how can you use CropMirrorNormalize without crop? not setting any values for any of the crop parameters simply crops to half because of default = 0.5 for both of crop_pos_x and crop_pos_y.

You can actually do it with the Cast operator.

how? i tried to cast prior to resize (because resize expects uint8) but i just got … a cast, i.e. my images were all black because i got truncation of the floats that came out of the decoder.

I guess you mean Crop after Resize.

sorry yes you’re right.

but performance-wise it is better to crop first, so as not to apply the interpolation algorithm to data that will be cropped anyway.

that makes sense but i feel like i should be able to choose to take the penalty. just so we’re concrete: i would like to crop a high resolution image randomly, then resize to half (or quarter scale) in order to produce a low resolution image (i’m working on super resolution networks). right now i do this using decodercrop then resize, and i have to normalize in plain pytorch

not ideal. i want to compose decoder -> cropmirrornormalize -> resize but i can’t because of data width mismatches (if i recall correctly resize complains about not getting uint8s for this composition).

It is great to see that you are able to use it to accelerate your training.

yes it’s definitely great - i’m able to saturate my gpus (sometimes a little too much and they get hot :slight_smile:).