I’m currently working with a version of Improved WGAN. My architecture is exactly the one described in the DCGAN repo (minus the sigmoid activation in the discriminator). To calculate the gradient penalty, I’m using the following code (adapted from caogang’s GitHub):
def calc_gradient_penalty(netD, real_data, fake_data, batch_size=50, gpu=0):
    alpha = torch.rand(batch_size, 1, 1)
    alpha = alpha.expand(real_data.size())
    alpha = alpha.cuda(gpu)

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    interpolates = interpolates.cuda(gpu)
    interpolates = autograd.Variable(interpolates, requires_grad=True)

    disc_interpolates = netD(interpolates)

    # autograd.grad returns a tuple of gradients (one per input),
    # so take the first element before computing the norm
    gradients = autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                              grad_outputs=torch.ones(disc_interpolates.size()).cuda(gpu),
                              create_graph=True, retain_graph=True,
                              only_inputs=True)[0]

    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
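For reference, here is a minimal CPU-only sketch of the same penalty computation on a toy linear critic, so the math is easy to check in isolation (the critic, feature size, and batch size here are placeholders, not my actual model):

```python
import torch
from torch import autograd

# Toy stand-in for the discriminator/critic: a single linear layer.
netD = torch.nn.Linear(16, 1)

batch_size = 4
real_data = torch.randn(batch_size, 16)
fake_data = torch.randn(batch_size, 16)

# One interpolation coefficient per example, broadcast across features.
alpha = torch.rand(batch_size, 1).expand_as(real_data)
interpolates = (alpha * real_data + (1 - alpha) * fake_data).requires_grad_(True)

disc_interpolates = netD(interpolates)

# autograd.grad returns a tuple; take the first element.
gradients = autograd.grad(
    outputs=disc_interpolates,
    inputs=interpolates,
    grad_outputs=torch.ones_like(disc_interpolates),
    create_graph=True,
    retain_graph=True,
    only_inputs=True,
)[0]

# Penalize deviation of the per-sample gradient norm from 1.
penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
print(penalty.item())
```

Note this sketch uses the newer `requires_grad_()` / `torch.ones_like` API rather than `autograd.Variable`; the computation is otherwise the same.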
The code does not raise any errors; however, GPU memory keeps growing until none is available, at which point the program crashes. As soon as I remove the BatchNorm layers in the discriminator, the problem is fixed.
Side note: removing batch norm also fixed another problem I had, which was that the critic values returned were very high (on the order of 10^5, sometimes giving NaNs). That said, I’m not sure whether this is due to the same bug.