I’m investigating the use of a Wasserstein GAN with gradient penalty (WGAN-GP) in PyTorch. I’m borrowing heavily from Caogang’s implementation, but I’m using the discriminator and generator losses from this implementation instead, because I get `Invalid gradient at index 0 - expected shape [] but got [1]` if I call `.backward()` with the `one` and `mone` arguments used in the Caogang implementation.
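To illustrate, the error boils down to a shape mismatch between the 0-dim scalar loss and the 1-element grad argument; a minimal standalone reproduction (not my training code):

```python
import torch

one = torch.FloatTensor([1])  # shape [1], as in the Caogang code
mone = one * -1

# after torch.mean(...), the critic output is a 0-dim scalar
D_real = torch.randn(64, requires_grad=True).mean()  # shape []
# D_real.backward(mone)              # -> expected shape [] but got [1]
D_real.backward(torch.tensor(-1.0))  # a 0-dim grad argument works
```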
I’m training on a dataset of 400k 64x64 images, and I have gotten a normal WGAN (with weight clipping) to work, i.e. it produces passable images after 25 epochs, despite the fact that the D and G losses both hover around 3 for all epochs (I calculate them using `torch.mean(D_real)` etc.). However, in the WGAN-GP version, the generator loss increases dramatically: it starts at ~24, then climbs rapidly to 6000 (!) by only the 6th epoch, while the discriminator loss starts at -7, decreases to -5000, then by the 6th epoch is up to +50 (!). The WGAN-GP and LSGAN versions of my GAN both completely fail to produce passable images, even after 25 epochs. I use `nn.MSELoss()` for the LSGAN version of my GAN.
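For reference, the gradient penalty term I’m describing follows the standard recipe from the WGAN-GP paper; a sketch (the `lambda_gp = 10` default is the paper’s, and `netD` returns a `(real/fake score, style)` tuple as in my discriminator below):

```python
import torch

def gradient_penalty(netD, real, fake, lambda_gp=10.0):
    batch_size = real.size(0)
    # random interpolation points between real and fake samples
    eps = torch.rand(batch_size, 1, 1, 1, device=real.device)
    interp = (eps * real + (1 - eps) * fake).requires_grad_(True)
    d_interp, _ = netD(interp)
    # gradient of the critic score w.r.t. the interpolated inputs
    grads = torch.autograd.grad(outputs=d_interp, inputs=interp,
                                grad_outputs=torch.ones_like(d_interp),
                                create_graph=True, retain_graph=True,
                                only_inputs=True)[0]
    grads = grads.view(batch_size, -1)
    # penalize deviation of the gradient norm from 1
    return lambda_gp * ((grads.norm(2, dim=1) - 1) ** 2).mean()
```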
I don’t use any tricks like one-sided label smoothing, and I train with the default learning rates from the LSGAN and WGAN-GP papers. I use the Adam optimizer, and I train the discriminator 5 times for every generator update in my WGANs (roughly the schedule sketched below). Why does this crazy loss behavior happen, and why does the normal weight-clipping WGAN still ‘work’ while WGAN-GP and LSGAN completely fail?
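Concretely, the training schedule looks roughly like this (a sketch; `netD`, `netG`, and `dataloader` are assumed to exist, `gradient_penalty` is the sketch above, and the Adam settings are the WGAN-GP paper’s defaults):

```python
import torch
import torch.optim as optim

z_noise = 100  # latent dimension (illustrative)
optD = optim.Adam(netD.parameters(), lr=1e-4, betas=(0.0, 0.9))
optG = optim.Adam(netG.parameters(), lr=1e-4, betas=(0.0, 0.9))
n_critic = 5

for i, (real, _) in enumerate(dataloader):
    # critic/discriminator update
    optD.zero_grad()
    fake = netG(torch.randn(real.size(0), z_noise, 1, 1)).detach()
    d_real, _ = netD(real)
    d_fake, _ = netD(fake)
    d_loss = d_fake.mean() - d_real.mean() + gradient_penalty(netD, real, fake)
    d_loss.backward()
    optD.step()

    # one generator update per n_critic critic updates
    if i % n_critic == n_critic - 1:
        optG.zero_grad()
        fake = netG(torch.randn(real.size(0), z_noise, 1, 1))
        g_loss = -netD(fake)[0].mean()
        g_loss.backward()
        optG.step()
```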
This loss behavior happens with LSGAN and WGAN-GP irrespective of the architecture: both when G and D are normal DCGANs, and when using the modified DCGAN below, the Creative Adversarial Network (CAN), which requires that D be able to classify images and that G generate ambiguous images. The CAN does this through an additional K-way classification loss, for which I’m using `nn.CrossEntropyLoss` and adding the result to `D_loss`.
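In other words, the combined discriminator loss is formed roughly like this (a sketch; `real_fake_loss` and `style_labels` are illustrative names):

```python
import torch.nn as nn

style_criterion = nn.CrossEntropyLoss()

def can_d_loss(netD, real_images, style_labels, real_fake_loss):
    # style_labels is a LongTensor of K-way class indices; note that
    # nn.CrossEntropyLoss applies log-softmax to its input internally
    real_out, style = netD(real_images)
    return real_fake_loss + style_criterion(style, style_labels)
```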
I also get erratic loss behavior (G and D losses not steadily decreasing, but going up and down) in the ‘normal’ DCGAN version of my GAN, with `nn.BCELoss` and the following D and G networks:
```python
import torch.nn as nn

class Can64Discriminator(nn.Module):
    def __init__(self, channels, y_dim, num_disc_filters):
        super(Can64Discriminator, self).__init__()
        self.ngpu = 1
        # shared conv trunk: 64x64 input down to (num_disc_filters * 8) x 1 x 1
        self.conv = nn.Sequential(
            nn.Conv2d(channels, num_disc_filters // 2, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_disc_filters // 2, num_disc_filters, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_disc_filters),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_disc_filters, num_disc_filters * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_disc_filters * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_disc_filters * 2, num_disc_filters * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_disc_filters * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(num_disc_filters * 4, num_disc_filters * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(num_disc_filters * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # real/fake head; no bn and lrelu needed
        self.real_fake_head = nn.Linear(num_disc_filters * 8, 1)
        self.sig = nn.Sigmoid()
        # K-way style classification head
        self.fc = nn.Sequential()
        self.fc.add_module("linear_layer{0}".format(num_disc_filters * 16),
                           nn.Linear(num_disc_filters * 8, num_disc_filters * 16))
        self.fc.add_module("linear_layer{0}".format(num_disc_filters * 8),
                           nn.Linear(num_disc_filters * 16, num_disc_filters * 8))
        self.fc.add_module("linear_layer{0}".format(num_disc_filters),
                           nn.Linear(num_disc_filters * 8, y_dim))
        self.fc.add_module('softmax', nn.Softmax(dim=1))

    def forward(self, inp):
        x = self.conv(inp)
        x = x.view(x.size(0), -1)  # flatten to (batch, num_disc_filters * 8)
        real_out = self.sig(self.real_fake_head(x))
        real_out = real_out.view(-1, 1).squeeze(1)
        style = self.fc(x)
        return real_out, style

class Can64Generator(nn.Module):
    def __init__(self, z_noise, channels, num_gen_filters):
        super(Can64Generator, self).__init__()
        self.ngpu = 1
        # transposed-conv stack: z vector up to a 3 x 64 x 64 image
        self.main = nn.Sequential(
            nn.ConvTranspose2d(z_noise, num_gen_filters * 16, 4, 1, 0, bias=False),
            nn.BatchNorm2d(num_gen_filters * 16),
            nn.ReLU(True),
            nn.ConvTranspose2d(num_gen_filters * 16, num_gen_filters * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_gen_filters * 4),
            nn.ReLU(True),
            nn.ConvTranspose2d(num_gen_filters * 4, num_gen_filters * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_gen_filters * 2),
            nn.ReLU(True),
            nn.ConvTranspose2d(num_gen_filters * 2, num_gen_filters, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_gen_filters),
            nn.ReLU(True),
            nn.ConvTranspose2d(num_gen_filters, 3, 4, 2, 1, bias=False),  # output channels hardcoded to 3
            nn.Tanh()
        )

    def forward(self, inp):
        output = self.main(inp)
        return output
```
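For completeness, a quick shape check of the two networks as defined above (filter counts and `y_dim` are illustrative):

```python
import torch

netG = Can64Generator(z_noise=100, channels=3, num_gen_filters=64)
netD = Can64Discriminator(channels=3, y_dim=27, num_disc_filters=64)

z = torch.randn(16, 100, 1, 1)
fake = netG(z)                 # -> torch.Size([16, 3, 64, 64])
real_out, style = netD(fake)   # -> torch.Size([16]), torch.Size([16, 27])
```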
What could be causing this? I’d like to make as minimal a change as possible, since I want to compare the loss functions alone. Any help would be greatly appreciated.
Thanks in advance!