Loss hovers around 50 and doesn't decrease

I’m trying to build a CNN that classifies images of cats and dogs. I wrote the following network:

import torch
import torch.nn as nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def conv2d(in_channels, out_channels, kernel_size):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding='valid'),
        nn.BatchNorm2d(out_channels),
        nn.ReLU()
    )


def initialise_kernel(m):
    if isinstance(m, nn.Conv2d):
        # Initialize kernels of Conv2d layers as kaiming normal
        nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
        # Initialize biases of Conv2d layers at 0
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

class ConvNet(nn.Module):
    def __init__(self, channels=3):
        super(ConvNet, self).__init__()

        self.conv1 = conv2d(channels, 32, 3)
        self.conv2 = conv2d(32, 64, 3)
        self.conv3 = conv2d(64, 128, 3)
        self.conv4 = conv2d(128, 128, 3)

        self.conv1.apply(initialise_kernel)
        self.conv2.apply(initialise_kernel)
        self.conv3.apply(initialise_kernel)
        self.conv4.apply(initialise_kernel)

        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.linear = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(6272, 512),
            # nn.BatchNorm1d(512),
            nn.ReLU(True),
            # nn.Dropout(0.5),
            nn.Linear(512, 1),
            # nn.Sigmoid()
        )

    def forward(self, x):
        x = self.max_pool(self.conv1(x))
        x = self.max_pool(self.conv2(x))
        x = self.max_pool(self.conv3(x))
        x = self.max_pool(self.conv4(x))
        x = torch.flatten(x, 1)
        x = self.linear(x)
        return x
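The 6272 passed to the first nn.Linear has to equal the number of features left after conv4 and the last max-pool. A quick sketch of that arithmetic, assuming square inputs of 150x150 pixels (the post doesn't state the image size; 150 is one size that produces 6272). The helper names are hypothetical:

```python
def conv_pool_size(size, kernel_size=3, pool=2):
    # One block: a 'valid' 3x3 conv (stride 1, no padding) shrinks the side
    # by kernel_size - 1, then MaxPool2d(2, 2) halves it, rounding down.
    return (size - kernel_size + 1) // pool


def flattened_features(image_size, out_channels=128, n_blocks=4):
    # Assumes square inputs and the four conv + max-pool blocks of ConvNet.
    size = image_size
    for _ in range(n_blocks):
        size = conv_pool_size(size)
    return out_channels * size * size


print(flattened_features(150))  # 6272 = 128 * 7 * 7
print(flattened_features(224))  # 18432
```

If the input size changes, the in_features of the first Linear layer has to change with it (for example, 224x224 inputs would need nn.Linear(18432, 512)).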

I’m using the following criterion and optimizer, and I compute the loss like this:

criterion = nn.BCELoss()
optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)

# Forward pass
output = model(images)
loss = criterion(torch.sigmoid(output), labels.view(-1, 1).float())

optimizer.zero_grad()
loss.backward()
optimizer.step()

My problem is that my losses hover around 50:

Epoch: 1/1000:
	 Training: |================| --> Train_loss: 50.0992
	 Testing:  |================| --> Test loss: 50.2930
Epoch: 20/1000:
	 Training: |================| --> Train_loss: 50.2480
	 Testing:  |================| --> Test loss: 49.1211
Epoch: 31/1000:
	 Training: |================| --> Train_loss: 50.0000
	 Testing:  |================| --> Test loss: 50.0000
Epoch: 40/1000:
	 Training: |================| --> Train_loss: 50.0496
	 Testing:  |================| --> Test loss: 50.5859
Epoch: 61/1000:
	 Training: |================| --> Train_loss: 49.8512
	 Testing:  |================| --> Test loss: 50.2930

I suspect the cause is that the raw output of the network is much greater than one, so after the sigmoid every output value becomes 1. Here is the raw output:

tensor([[60.7288],
        [44.7729],
        [46.9832],
        [37.3297],
        [49.7330],
        [37.3509],
        [38.2135],
        [37.5487],
        [42.5708],
        [55.1753],
        [41.3743],
        [50.1724],
        [96.6568],
        [35.2697],
        [39.7889],
        [58.9670]], device='cuda:0', grad_fn=<AddmmBackward0>)

How can I fix this, and what did I do wrong?

Hi sbht!

As you’ve seen, the sigmoid() you are using can easily saturate, after
which useful training stops. This is a well-known numerical problem with
using BCELoss.
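A minimal sketch of why the loss sits at about 50, using two logits of the size printed in the question and one label of each class (illustrative values, not taken from the original run). It relies on nn.BCELoss clamping its log outputs to be at least -100, which is documented behavior:

```python
import torch
import torch.nn as nn

# Illustrative logits in the range printed in the question (roughly 35 to 97).
logits = torch.tensor([[60.0], [45.0]], requires_grad=True)
labels = torch.tensor([[1.0], [0.0]])  # one sample of each class

probs = torch.sigmoid(logits)  # both round to exactly 1.0 in float32

loss = nn.BCELoss()(probs, labels)
# Label 1: -log(1.0) = 0.
# Label 0: -log(1 - 1.0) = -log(0) = inf, clamped by BCELoss to 100.
# Mean over the batch: (0 + 100) / 2 = 50.
print(loss.item())  # 50.0

loss.backward()
# The saturated sigmoid has zero derivative, so no gradient reaches the logits.
print(logits.grad)
```

With roughly balanced classes in every batch, the average therefore stays near 50 no matter what the optimizer does, which matches the training log above.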

Instead, use BCEWithLogitsLoss and drop the sigmoid() call, since
BCEWithLogitsLoss has logsigmoid() built into it.

If you still have problems, try training with the plain-vanilla SGD optimizer,
with no momentum or weight decay and a small learning rate. I find that
when my training is unstable or diverges, it is helpful to get it working with
SGD and only then try to speed up training by adding things like momentum
or using a fancier optimizer.
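A minimal sketch of that plain-SGD setup. It assumes a hypothetical linear stand-in for the question's ConvNet and random fake data; lr=1e-3 is only an illustrative small value, not a tuned one:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for ConvNet; any nn.Module is used the same way.
model = nn.Linear(6272, 1)

# Plain-vanilla SGD: momentum and weight decay explicitly switched off.
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3,
                            momentum=0.0, weight_decay=0.0)
criterion = nn.BCEWithLogitsLoss()

# Fake data: 16 flattened feature vectors with random 0/1 labels.
images = torch.randn(16, 6272)
labels = torch.randint(0, 2, (16,))

losses = []
for step in range(100):
    output = model(images)
    loss = criterion(output, labels.view(-1, 1).float())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(losses[0], losses[-1])  # the loss should go down steadily
```

Once the loss decreases reliably this way, momentum or RMSprop can be reintroduced one change at a time.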

Best.

K. Frank


Thank you so much! Your advice helped. But now I don’t understand: for what tasks would I use plain BCELoss?

Hi sbht!

I can’t think of a use case where using BCELoss – with its numerical
instability – would ever be preferable to using BCEWithLogitsLoss.

I suppose that if your network naturally produced probabilities (zero to
one) rather than logits (-inf to inf), then you would be better off using
BCELoss rather than converting your probabilities into logits and using
BCEWithLogitsLoss. But I can’t think of a use case where your network
would naturally produce probabilities. Predictions are typically the output
of a final Linear (or convolutional) layer whose output does range from
-inf to inf.
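A sketch comparing the two losses on hand-picked moderate logits (illustrative values, assumed small enough that the sigmoid does not saturate). They agree numerically, so BCELoss only makes sense when the model's output already is a probability, for example if the commented-out nn.Sigmoid() at the end of the question's model were enabled. The last lines show one common way to turn logits into class predictions:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[-3.0], [-0.5], [0.0], [1.2], [4.0]])  # illustrative
labels = torch.tensor([[0.0], [1.0], [1.0], [1.0], [0.0]])

with_logits = nn.BCEWithLogitsLoss()(logits, labels)

# torch.sigmoid(logits) stands in for a model whose last layer is nn.Sigmoid().
probs = torch.sigmoid(logits)
plain = nn.BCELoss()(probs, labels)

# Same value here; the two only diverge once the sigmoid saturates.
print(with_logits.item(), plain.item())

# At inference, "logit > 0" is the same decision as "probability > 0.5".
preds = (logits > 0).long()
print(preds.view(-1))
```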

Best.

K. Frank
