I’m trying to use torch.nn.CrossEntropyLoss
in the discriminator of a conditional DCGAN-based GAN, whose images each belong to one of 27 classes/categories, in addition to the usual torch.nn.BCELoss
the discriminator uses, because I want the discriminator to classify the class of the images it receives as well as discern real images from fake ones.
Though I don’t get any runtime errors, this additional loss is enormous (>400) and doesn’t decrease significantly, even after 25 epochs. The dataset I’m using is roughly 400k labeled images of size 64x64, so I don’t think it’s a data problem. I figure I must be doing something seriously wrong!
This is the structure of the discriminator:
import torch.nn as nn

# y_dim = number of image categories/classes
class CanDiscriminator(nn.Module):
    def __init__(self, channels, y_dim, num_disc_filters):
        super(CanDiscriminator, self).__init__()
        self.ngpu = 1
        self.conv = nn.Sequential(
            nn.Conv2d(channels, num_disc_filters // 2, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(num_disc_filters // 2, num_disc_filters, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_disc_filters),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(num_disc_filters, num_disc_filters * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_disc_filters * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(num_disc_filters * 2, num_disc_filters * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(num_disc_filters * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(num_disc_filters * 4, num_disc_filters * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(num_disc_filters * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )
        # real/fake head: a single sigmoid output per image
        self.real_fake_head = nn.Linear(num_disc_filters * 8, 1)
        self.sig = nn.Sigmoid()
        # class head: stack of linear layers ending in a softmax over the y_dim categories
        self.fc = nn.Sequential()
        self.fc.add_module("linear_layer{0}".format(num_disc_filters * 16), nn.Linear(num_disc_filters * 8, num_disc_filters * 16))
        self.fc.add_module("linear_layer{0}".format(num_disc_filters * 8), nn.Linear(num_disc_filters * 16, num_disc_filters * 8))
        self.fc.add_module("linear_layer{0}".format(num_disc_filters), nn.Linear(num_disc_filters * 8, y_dim))
        self.fc.add_module('softmax', nn.Softmax(dim=1))

    def forward(self, inp):
        x = self.conv(inp)
        x = x.view(x.size(0), -1)
        # real_out = real/fake classification
        real_out = self.sig(self.real_fake_head(x))
        real_out = real_out.view(-1, 1).squeeze(1)
        # category = image class prediction
        category = self.fc(x)
        return real_out, category
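For context, here is a quick shape check of the two outputs (a minimal sketch; channels=3, y_dim=27, num_disc_filters=64 and the random dummy batch are just illustrative values, not my actual config):

import torch

disc = CanDiscriminator(channels=3, y_dim=27, num_disc_filters=64)
dummy_images = torch.randn(4, 3, 64, 64)       # 4 random 64x64 RGB images
real_out, category = disc(dummy_images)
print(real_out.shape)    # torch.Size([4])     - one real/fake score per image
print(category.shape)    # torch.Size([4, 27]) - softmaxed per-class scores per image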
When training, I do the following:
class_loss = nn.CrossEntropyLoss(reduction="sum")

# predicted_labels = the class predictions the discriminator assigns to the images
#                    (the category output above)
# actual_labels    = 1D tensor of the images' real class labels
disc_class_loss = class_loss(predicted_labels, actual_labels)
disc_class_loss.backward(retain_graph=True)
and then later add disc_class_loss
to the normal discriminator BCE loss.
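To show where this sits, here is a simplified sketch of my real-image discriminator step (the optimizer settings, the stand-in tensors, and names like optimizer_d are illustrative placeholders, not my exact code, and the fake-image half of the step is omitted):

bce_criterion = nn.BCELoss()
class_loss = nn.CrossEntropyLoss(reduction="sum")
optimizer_d = torch.optim.Adam(disc.parameters(), lr=0.0002, betas=(0.5, 0.999))

real_images = torch.randn(4, 3, 64, 64)        # stand-in for a batch of real images
actual_labels = torch.randint(0, 27, (4,))     # stand-in for the batch's class labels

disc.zero_grad()
real_out, predicted_labels = disc(real_images)

disc_class_loss = class_loss(predicted_labels, actual_labels)
disc_class_loss.backward(retain_graph=True)    # this is the huge Class_D term

disc_real_loss = bce_criterion(real_out, torch.ones_like(real_out))
disc_real_loss.backward()                      # the usual real/fake Loss_D term

optimizer_d.step()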
However, something goes horribly wrong somewhere.
As an example, the output on the first batch of the first epoch is:
Loss_D: 1.4125 Loss_G: 0.0834 Class_D: 421.9417
where Loss_D
is the normal discriminator BCE loss and Class_D
is the enormous CrossEntropy
loss. Additionally, when I print out the predicted_labels
tensor, I get this kind of result:
tensor([[0.0352, 0.0435, 0.0442, ..., 0.0407, 0.0421, 0.0362],
[0.0386, 0.0392, 0.0466, ..., 0.0358, 0.0319, 0.0398],
Where am I going wrong? Do I need to use nn.CrossEntropyLoss
to have my discriminator classify the class/category of its input images? I’ve seen similar GANs in TensorFlow use a sigmoid_cross_entropy_with_logits
loss, of which the PyTorch equivalent is nn.CrossEntropyLoss
(AFAIK), but they don’t seem to suffer from this problem.
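For reference, this is how I understand nn.CrossEntropyLoss is normally called (a toy example with made-up numbers, separate from my training code):

logits = torch.randn(4, 27)               # (batch, num_classes) scores, one row per image
targets = torch.tensor([3, 0, 26, 12])    # the true class index of each image
toy_loss = nn.CrossEntropyLoss()(logits, targets)
print(toy_loss.item())                    # a small scalar, on the order of log(27) ≈ 3.3 here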
Any help would be greatly appreciated! Thanks in advance.