AMP for DCGAN training

Hello,

I’m trying to enable automatic mixed precision (AMP) for a DCGAN model, but after profiling the model with DLProf, it still reports that AMP is not enabled.

I read everything in the documentation about Automatic Mixed Precision, but can’t figure out what I did wrong. Can anyone help with this?

Thanks a lot.

from os import path, makedirs

from PIL import Image
from torch import nn, load, optim, ones, zeros, randn, mean, FloatTensor
from torch.autograd import Variable
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader

import model
from text2image_dataset import Text2ImageDataset
from utils import Utils


class Trainer(object):
    def __init__(self, dataset, split, lr, save_path, l1_coef, l2_coef, pre_trained_gen,
                 pre_trained_disc, batch_size, num_workers, epochs):

        self.generator = nn.DataParallel(model.generator().cuda())
        self.discriminator = nn.DataParallel(model.discriminator().cuda())

        if pre_trained_disc:
            self.discriminator.load_state_dict(load(pre_trained_disc))
        else:
            self.discriminator.apply(Utils.weights_init)

        if pre_trained_gen:
            self.generator.load_state_dict(load(pre_trained_gen))
        else:
            self.generator.apply(Utils.weights_init)

        self.dataset = Text2ImageDataset(dataset, split=split)

        self.noise_dim = 100
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.lr = lr
        self.beta1 = 0.5
        self.num_epochs = epochs

        self.l1_coef = l1_coef
        self.l2_coef = l2_coef

        self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True,
                                      num_workers=self.num_workers)

        self.optimizerD = optim.Adam(self.discriminator.parameters(), lr=self.lr, betas=(self.beta1, 0.999))
        self.optimizerG = optim.Adam(self.generator.parameters(), lr=self.lr, betas=(self.beta1, 0.999))

        # AMP
        # Creates a GradScaler once at the beginning of training.
        self.scalerD = GradScaler(enabled=True)
        self.scalerG = GradScaler(enabled=True)

        self.checkpoints_path = 'checkpoints'
        self.save_path = save_path

    def train(self):
        criterion = nn.BCEWithLogitsLoss()
        l2_loss = nn.MSELoss()
        l1_loss = nn.L1Loss()
        iteration = 0

        for epoch in range(self.num_epochs):
            for sample in self.data_loader:
                iteration += 1
                right_images = sample['right_images']
                right_embed = sample['right_embed']

                right_images = Variable(right_images.float()).cuda()
                right_embed = Variable(right_embed.float()).cuda()

                real_labels = ones(right_images.size(0))
                fake_labels = zeros(right_images.size(0))

                smoothed_real_labels = FloatTensor(Utils.smooth_label(real_labels.numpy(), -0.1))

                real_labels = Variable(real_labels).cuda()
                smoothed_real_labels = Variable(smoothed_real_labels).cuda()
                fake_labels = Variable(fake_labels).cuda()

                # =============================================
                # ========= Train the discriminator ===========
                # =============================================
                self.optimizerD.zero_grad()
                self.discriminator.zero_grad()

                with autocast(enabled=True):
                    outputs, activation_real = self.discriminator(right_images, right_embed)
                    real_loss = criterion(outputs, smoothed_real_labels)
                    real_score = outputs

                    noise = Variable(randn(right_images.size(0), 100)).cuda()
                    noise = noise.view(noise.size(0), 100, 1, 1)
                    fake_images = self.generator(right_embed, noise)
                    outputs, _ = self.discriminator(fake_images, right_embed)
                    fake_loss = criterion(outputs, fake_labels)
                    fake_score = outputs

                    d_loss = real_loss + fake_loss

                # AMP
                self.scalerD.scale(d_loss).backward()
                self.scalerD.step(optimizer=self.optimizerD)
                self.scalerD.update()

                # =============================================
                # =========== Train the generator =============
                # =============================================
                self.optimizerG.zero_grad()
                self.generator.zero_grad()

                with autocast(enabled=True):
                    noise = Variable(randn(right_images.size(0), 100)).cuda()
                    noise = noise.view(noise.size(0), 100, 1, 1)
                    fake_images = self.generator(right_embed, noise)
                    outputs, activation_fake = self.discriminator(fake_images, right_embed)
                    _, activation_real = self.discriminator(right_images, right_embed)

                    activation_fake = mean(activation_fake, 0)
                    activation_real = mean(activation_real, 0)

                    # ========= Generator loss function ===========
                    g_loss = criterion(outputs, real_labels) \
                             + self.l2_coef * l2_loss(activation_fake, activation_real.detach()) \
                             + self.l1_coef * l1_loss(fake_images, right_images)

                # AMP
                self.scalerG.scale(g_loss).backward()
                self.scalerG.step(optimizer=self.optimizerG)
                self.scalerG.update()

                if iteration % 5 == 0:
                    print("Epoch: %d, iteration: %d, d_loss= %f, g_loss= %f, D(X)= %f, D(G(X))= %f" % (
                        epoch, iteration, d_loss.data.cpu().mean(), g_loss.data.cpu().mean(),
                        real_score.data.cpu().mean(),
                        fake_score.data.cpu().mean()))

            if epoch >= 0:
                Utils.save_checkpoint(self.discriminator, self.generator, self.checkpoints_path, self.save_path, epoch)

DLProf is not supported anymore, so I would be careful about its ability to detect AMP usage with a current PyTorch release.
Are you seeing any speedups with AMP, and could you share which GPU you are using? The actual speedup should only be visible if your device has Tensor Cores.
Also, remove the Variable usage, as this class has been deprecated since PyTorch 0.4. :wink:
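
As a rough, untested sketch (reusing the names from your code), the inner loop could create the tensors directly on the GPU without wrapping anything in Variable:

right_images = sample['right_images'].float().cuda()
right_embed = sample['right_embed'].float().cuda()

# Plain tensors replace Variable; inputs and labels don't need requires_grad.
real_labels = ones(right_images.size(0), device='cuda')
fake_labels = zeros(right_images.size(0), device='cuda')
smoothed_real_labels = FloatTensor(Utils.smooth_label(real_labels.cpu().numpy(), -0.1)).cuda()

# The noise can also be created directly in its final shape on the device.
noise = randn(right_images.size(0), 100, 1, 1, device='cuda')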

Thank you so much for the tips! :smiley:

I’m not seeing any speedups; the running time is exactly the same for both the AMP version and the original code. Since you mentioned that DLProf is no longer supported, I also profiled the code with the PyTorch Profiler, but got the same results.

I tested the code on an NVIDIA GeForce RTX 2080 Ti.

Did you see where the bottleneck is coming from in the profile, or were you able to compare the different kernel times? If so, did you see any change in the kernels and their runtimes?
Also, could you rerun the profile with a static random input tensor to check whether the data loading might be creating a bottleneck?
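
Something along these lines could be used to take the data loading out of the measurement (the shapes are placeholders and discriminator stands for your model instance, so adjust both to your setup):

import torch
from torch.profiler import profile, ProfilerActivity

# Static random inputs instead of the DataLoader output
# (placeholder shapes - adjust to your dataset/model).
right_images = torch.randn(64, 3, 64, 64, device='cuda')
right_embed = torch.randn(64, 1024, device='cuda')

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.cuda.amp.autocast():
        outputs, _ = discriminator(right_images, right_embed)

# Check which kernels were executed and how long they took.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))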

You’re right! The data loading does create a bottleneck: after increasing the number of workers, I see a clear difference between the AMP version and the original one (initially the timings were the same). I can also see the Tensor Core operations in the PyTorch Profiler now, thank you! :smile:
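
For reference, the DataLoader tweak was just along these lines (the worker count is an example value, and pin_memory is an optional extra I’m showing here):

# More workers keep the GPU fed; values below are examples, not recommendations.
self.data_loader = DataLoader(self.dataset, batch_size=self.batch_size, shuffle=True,
                              num_workers=8, pin_memory=True)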
