nn.CrossEntropyLoss for conditional GAN is enormous

I’m trying to use torch.nn.CrossEntropyLoss in the discriminator of a conditional DCGAN-based GAN trained on images from 27 classes/categories. I use it in addition to the usual torch.nn.BCELoss, because I want the discriminator to classify the class of each image it receives as well as tell real images from fake ones.

Though I don’t get any runtime errors, this additional loss is enormous (>400), and doesn’t decrease significantly, even after 25 epochs. The data set I’m using is ±400k labeled images of size 64x64, so I don’t think it’s a data problem. I figure I must be doing something seriously wrong!

This is the structure of the discriminator:

    import torch.nn as nn

    class CanDiscriminator(nn.Module):
        # y_dim = number of image categories/classes
        def __init__(self, channels, y_dim, num_disc_filters):
            super(CanDiscriminator, self).__init__()
            self.ngpu = 1
            self.conv = nn.Sequential(
                nn.Conv2d(channels, num_disc_filters // 2, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(num_disc_filters // 2, num_disc_filters, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(num_disc_filters, num_disc_filters * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(num_disc_filters * 2),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(num_disc_filters * 2, num_disc_filters * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(num_disc_filters * 4),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Conv2d(num_disc_filters * 4, num_disc_filters * 8, 4, 1, 0, bias=False),
                nn.BatchNorm2d(num_disc_filters * 8),
                nn.LeakyReLU(0.2, inplace=True),
            )
            # Real/fake head: one probability per image, used with nn.BCELoss
            self.real_fake_head = nn.Linear(num_disc_filters * 8, 1)
            self.sig = nn.Sigmoid()
            # Class head: maps the features to y_dim raw logits for nn.CrossEntropyLoss
            # (no softmax here; the loss applies log_softmax itself)
            self.fc = nn.Sequential(nn.Linear(num_disc_filters * 8, y_dim))

        def forward(self, inp):
            x = self.conv(inp)
            x = x.view(x.size(0), -1)  # (N, num_disc_filters * 8) for 64x64 inputs
            real_out = self.sig(self.real_fake_head(x))
            real_out = real_out.view(-1, 1).squeeze(1)  # real/fake probabilities, shape (N,)
            category = self.fc(x)  # image class logits, shape (N, y_dim)
            return real_out, category
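With 64x64 inputs, each 4/2/1 convolution halves the spatial size and the final 4/1/0 convolution collapses the remaining 4x4 map to 1x1, so after `x.view(x.size(0), -1)` each sample is a vector of `num_disc_filters * 8` features. That is the input width a class head in `fc` (for example an `nn.Linear(num_disc_filters * 8, y_dim)`) would need. A quick check of the arithmetic, using the standard `nn.Conv2d` output-size formula with dilation 1:

```python
# Output size of nn.Conv2d along one spatial dimension (dilation = 1):
# out = floor((in + 2 * padding - kernel) / stride) + 1
def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) of the five convolutions in the discriminator above
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]

size = 64  # input images are 64x64
sizes = [size]
for k, s, p in layers:
    size = conv_out(size, k, s, p)
    sizes.append(size)

print(sizes)  # [64, 32, 16, 8, 4, 1]
```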

When training, I do the following:

    import torch.nn as nn

    class_loss = nn.CrossEntropyLoss(reduction="sum")
    # predicted_labels = 2D tensor (batch_size x 27) of class scores the discriminator assigns to the images
    # actual_labels = 1D tensor of the images' real labels (class indices)
    disc_class_loss = class_loss(predicted_labels, actual_labels)

and then later add disc_class_loss to the normal discriminator BCE loss.
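A minimal sketch of how the two discriminator losses can be combined, using random stand-in tensors instead of real model outputs (`real_out`, `class_logits`, `is_real`, `adv_loss` and `loss_D` are illustrative names). Both criteria use their default mean reduction, so both terms are on a per-sample scale, and the class head is assumed to return raw logits, which is what `nn.CrossEntropyLoss` expects:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, num_classes = 128, 27

# Stand-ins for the two discriminator outputs (random, for illustration only)
real_out = torch.sigmoid(torch.randn(batch_size))    # probabilities after nn.Sigmoid
class_logits = torch.randn(batch_size, num_classes)  # raw class scores, no softmax

is_real = torch.ones(batch_size)                               # real/fake targets (1 = real)
actual_labels = torch.randint(0, num_classes, (batch_size,))  # class indices 0..26

adv_loss = nn.BCELoss()(real_out, is_real)                            # mean over the batch
disc_class_loss = nn.CrossEntropyLoss()(class_logits, actual_labels)  # mean over the batch

# Total discriminator loss for this (real) batch
loss_D = adv_loss + disc_class_loss
print(adv_loss.item(), disc_class_loss.item(), loss_D.item())
```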

However, something goes horribly wrong somewhere.
For example, this is the output for the first batch of the first epoch:

    Loss_D: 1.4125 Loss_G: 0.0834 Class_D: 421.9417

where Loss_D is the normal discriminator BCE loss and Class_D is the enormous CrossEntropy loss. Additionally, when I print out the predicted_labels tensor, I get this kind of result:

    tensor([[0.0352, 0.0435, 0.0442,  ..., 0.0407, 0.0421, 0.0362],
            [0.0386, 0.0392, 0.0466,  ..., 0.0358, 0.0319, 0.0398],
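The printed values all sit close to 1/27 ≈ 0.037, i.e. an almost uniform prediction over the 27 classes. A quick piece of arithmetic (not taken from the original training code) shows that the cross-entropy of a uniform prediction, summed over a batch of 128, lands almost exactly on the reported Class_D value of 421.9417:

```python
import math

num_classes = 27
batch_size = 128

# Cross-entropy of a uniform prediction over C classes: -log(1/C) = log(C)
per_sample = math.log(num_classes)
print(round(per_sample, 4))  # 3.2958

# With reduction="sum", the per-sample losses are added over the whole batch
print(round(batch_size * per_sample, 2))  # 421.87
```

So the large number is consistent with a chance-level per-sample loss of about ln 27 ≈ 3.30, multiplied by the batch size.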

Where am I going wrong? Do I need nn.CrossEntropyLoss to make my discriminator classify the class/category of an image? I’ve seen similar GANs in TensorFlow use sigmoid_cross_entropy_with_logits, whose PyTorch equivalent is nn.CrossEntropyLoss (AFAIK), but they don’t seem to suffer from this problem.
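One thing worth ruling out: `nn.CrossEntropyLoss` applies `log_softmax` itself and expects raw logits. The discriminator snippet does not show a softmax, but the printed `predicted_labels` rows look like probabilities, so it is an assumption here that the real class head ends in `nn.Softmax`. If it does, the softmax is applied twice, and with 27 classes the per-sample loss can never drop below ln(e + 26) - 1 ≈ 2.36, even for a perfect prediction. A small sketch using `torch.nn.functional.cross_entropy`, the functional form of the same loss:

```python
import torch
import torch.nn.functional as F

num_classes = 27
target = torch.tensor([0])  # the correct class is index 0

# A perfectly confident, correct prediction expressed two ways
probs = torch.zeros(1, num_classes)
probs[0, 0] = 1.0  # what a final nn.Softmax would output
logits = torch.full((1, num_classes), -10.0)
logits[0, 0] = 10.0  # raw scores, no softmax applied

# cross_entropy applies log_softmax internally, so probabilities get "softmaxed" twice
loss_from_probs = F.cross_entropy(probs, target)
loss_from_logits = F.cross_entropy(logits, target)

print(loss_from_probs.item())   # about 2.3575: the floor for softmaxed inputs
print(loss_from_logits.item())  # about 0: raw logits let the loss approach zero
```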

Any help would be greatly appreciated! Thanks in advance.

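Your loss will be much higher, since you pass reduction='sum' to nn.CrossEntropyLoss, which adds up the per-sample losses over the batch instead of averaging them.
Would the default elementwise_mean reduction (called 'mean' since PyTorch 1.0) not work?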

EDIT: That being said, you should also have a look at Conditional GANs, which use a slightly different approach to create class-specific images.

I’ve tried with the default elementwise_mean and it’s still >400 unfortunately.
The Conditional GAN is a bit different, as the generator is explicitly given labels, but I could look at it again.

Is there perhaps another reason why the loss would be so high?
Also, a PyTorch forum post said that this is now the PyTorch equivalent of sigmoid_cross_entropy_with_logits. Is that true, and is this loss appropriate for a multi-class problem (i.e. >2 classes, but still only 1 class per image)?

Yes, sigmoid_cross_entropy_with_logits should be equivalent to nn.BCEWithLogitsLoss.
However, if you’re dealing with a multi-class classification, i.e. only one positive class per sample, you would usually use nn.CrossEntropyLoss. You could use nn.BCEWithLogitsLoss for a multi-label classification, i.e. more than one positive class per sample.
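To make the distinction concrete, a minimal sketch with random logits and made-up targets for 27 classes: multi-class targets are class indices for `nn.CrossEntropyLoss`, multi-label targets are a 0/1 vector per sample for `nn.BCEWithLogitsLoss`, and the latter matches `nn.Sigmoid` followed by `nn.BCELoss`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 27)  # raw scores for 4 samples, 27 classes

# Multi-class (exactly one class per image): integer class indices, shape (N,)
class_idx = torch.tensor([3, 0, 26, 11])
multi_class_loss = nn.CrossEntropyLoss()(logits, class_idx)

# Multi-label (any number of positive classes per image): 0/1 targets, shape (N, C)
multi_hot = torch.zeros(4, 27)
multi_hot[0, [3, 5]] = 1.0
multi_hot[1, 0] = 1.0
multi_label_loss = nn.BCEWithLogitsLoss()(logits, multi_hot)

# BCEWithLogitsLoss is sigmoid + BCELoss fused (and more numerically stable),
# which is what TensorFlow's sigmoid_cross_entropy_with_logits computes
manual = nn.BCELoss()(torch.sigmoid(logits), multi_hot)
print(torch.allclose(multi_label_loss, manual, atol=1e-5))  # True
```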

Are you using a batch size of 1?
If using the mean doesn’t work, you might want to weight both losses to get approx. the same range.

Forgot to say in my previous message, but thanks for the help and advice! Really appreciate it.

Yes, it is only one positive class per sample. My batch size is 128. I could weight the losses, but how would I go about doing so such that I don’t severely compromise the two objectives (real vs fake and category/class distinction)? The class distinction is pretty important, but obviously I wouldn’t want the discriminator to be terrible at recognising fakes.

Any weighting advice? Or just trial-and-error weighting parameters (that could maybe change during training if a loss threshold is achieved)?

Thanks again!

Well, I would try to scale the classification loss to match the discriminator loss.
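One way to do that, sketched under assumptions: the helper name `balanced_disc_loss` and the weight `lambda_cls` are hypothetical, the loss values are illustrative, and matching magnitudes is a heuristic starting point rather than a rule. A fixed weight is the simplest option; rescaling by the ratio of the detached loss values keeps both terms on the same scale as they change during training:

```python
import torch

def balanced_disc_loss(adv_loss, class_loss, lambda_cls=None):
    """Hypothetical helper: combine the real/fake and class losses.

    With a fixed lambda_cls, the class loss is simply scaled by it. With
    lambda_cls=None, the class loss is rescaled to the magnitude of the
    adversarial loss; the ratio is built from detached values, so it acts
    as a constant factor and receives no gradient itself.
    """
    if lambda_cls is None:
        lambda_cls = adv_loss.detach() / class_loss.detach().clamp(min=1e-8)
    return adv_loss + lambda_cls * class_loss

# Illustrative magnitudes: BCE around 1.4, class loss about 5x larger
adv = torch.tensor(1.4, requires_grad=True)
clf = torch.tensor(7.0, requires_grad=True)

fixed = balanced_disc_loss(adv, clf, lambda_cls=0.2)  # hand-picked weight
matched = balanced_disc_loss(adv, clf)                # weight = 1.4 / 7.0

print(fixed.item(), matched.item())  # both approximately 2.8
matched.backward()
print(adv.grad.item(), clf.grad.item())  # approximately 1.0 and 0.2
```

Whether a fixed or an adaptive weight works better here is an empirical question for this setup.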

But before you do that, let’s dig a bit deeper into the CrossEntropyLoss.
You’ve said that elementwise_mean doesn’t reduce the loss and it stays approx. at >400.
However, you’re using a batch size of 128, so switching from sum to mean should change the loss value by a factor of 128.
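A quick way to confirm this with random logits and labels (shapes chosen to match this thread). It uses `reduction="mean"`, the PyTorch 1.0+ name for what older releases called `elementwise_mean`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, num_classes = 128, 27
logits = torch.randn(batch_size, num_classes)
labels = torch.randint(0, num_classes, (batch_size,))

loss_sum = nn.CrossEntropyLoss(reduction="sum")(logits, labels)
loss_mean = nn.CrossEntropyLoss(reduction="mean")(logits, labels)

# The summed loss is batch_size times the averaged loss
print(torch.isclose(loss_sum, batch_size * loss_mean))  # tensor(True)
```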
Could you explain your setup a bit?
How many classes are you dealing with?

I just re-ran and double-checked it with elementwise_mean, and it is significantly lower (so sorry for the misinformation), but it is still at least 5x higher than the normal BCE loss. I have 27 classes. As for the setup, what exactly do you mean?

Yeah, by setup I meant the shapes etc., but fortunately the loss went down. Phew :sweat_smile:
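For reference, the inputs `nn.CrossEntropyLoss` expects in this setup are `(N, C)` float logits and `(N,)` int64 class indices in `[0, C)`. A small hypothetical helper (`check_class_loss_inputs` is not a PyTorch function) that checks exactly that, with N = 128 and C = 27 as in this thread:

```python
import torch

def check_class_loss_inputs(logits, targets, num_classes):
    """Hypothetical helper: sanity-check the inputs to nn.CrossEntropyLoss."""
    assert logits.dim() == 2, f"expected (N, C) logits, got {tuple(logits.shape)}"
    assert logits.size(1) == num_classes, f"expected C={num_classes}, got {logits.size(1)}"
    assert targets.shape == (logits.size(0),), f"expected (N,) targets, got {tuple(targets.shape)}"
    assert targets.dtype == torch.long, f"targets must be int64 class indices, got {targets.dtype}"
    assert 0 <= targets.min() and targets.max() < num_classes, "class index out of range"
    return tuple(logits.shape), tuple(targets.shape)

batch_size = 128
logits = torch.randn(batch_size, 27)
labels = torch.randint(0, 27, (batch_size,))
print(check_class_loss_inputs(logits, labels, num_classes=27))  # ((128, 27), (128,))
```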

If the loss is still approx. 5x higher, your model might still learn something useful.
Could you try to run the training again? If nothing happens, you should weight this loss down a bit or adapt your learning rate.

Oh okay - I’ll put in the shapes etc. next time :slight_smile:
Will do! Thanks once again!

Sorry for the late reply! Thanks very much; the loss did go down. I made the following PyTorch forum post that unfortunately gained no traction. Any chance you could have a quick peek? Many many thanks once again; you are one of the main reasons I really like the PyTorch forum :smile:
