Extremely large negative losses during training for a facial landmark prediction problem

This block of code reports large negative losses that keep growing in magnitude:

import torch
import torch.nn as nn

model = ConvNet(num_classes).to(device)

# Loss and optimizer
criterion = nn.MultiLabelSoftMarginLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
    for i, sample_batched in enumerate(train_loader):
        #print(i, sample_batched['image'].size(),
        # sample_batched['landmarks'].size())
        
        images_batch, landmarks_batch = \
            sample_batched['image'], sample_batched['landmarks']
        
        # Variable is deprecated since PyTorch 0.4; tensors can be used directly.
        images = images_batch.float().to(device)
        # Flatten the 68 (x, y) landmark pairs into one 136-value target per sample
        labels = landmarks_batch.reshape(-1, 68 * 2).to(device)
        
        # Forward pass
        outputs = model(images)
        
        #print("Label Shape", labels.shape, "Output Shape", outputs.shape)
        
        
        loss = criterion(outputs, labels.float())
        
        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (i+1) % 5 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                  .format(epoch+1, num_epochs, i+1, total_step, loss.item()))

The losses are:

Epoch [1/15], Step [5/17], Loss: -19443.7852
Epoch [1/15], Step [10/17], Loss: -48699.7461
Epoch [1/15], Step [15/17], Loss: -71593.1797
Epoch [2/15], Step [5/17], Loss: -126392.4062
Epoch [2/15], Step [10/17], Loss: -119797.7344
Epoch [2/15], Step [15/17], Loss: -234912.5000
Epoch [3/15], Step [5/17], Loss: -271398.0000
Epoch [3/15], Step [10/17], Loss: -428674.2500
Epoch [3/15], Step [15/17], Loss: -285817.3125
Epoch [4/15], Step [5/17], Loss: -387621.0938
Epoch [4/15], Step [10/17], Loss: -488975.9375
Epoch [4/15], Step [15/17], Loss: -593710.8750
Epoch [5/15], Step [5/17], Loss: -565681.6250
Epoch [5/15], Step [10/17], Loss: -633298.5000
Epoch [5/15], Step [15/17], Loss: -712257.2500
Epoch [6/15], Step [5/17], Loss: -708618.5000
Epoch [6/15], Step [10/17], Loss: -627527.0625
Epoch [6/15], Step [15/17], Loss: -996953.1875
Epoch [7/15], Step [5/17], Loss: -956897.9375
Epoch [7/15], Step [10/17], Loss: -1206485.2500
Epoch [7/15], Step [15/17], Loss: -1061303.5000
Epoch [8/15], Step [5/17], Loss: -1044490.9375
Epoch [8/15], Step [10/17], Loss: -1279828.0000
Epoch [8/15], Step [15/17], Loss: -1121125.0000
Epoch [9/15], Step [5/17], Loss: -1569353.0000
Epoch [9/15], Step [10/17], Loss: -1250445.8750
Epoch [9/15], Step [15/17], Loss: -1473160.0000
Epoch [10/15], Step [5/17], Loss: -1510457.2500
Epoch [10/15], Step [10/17], Loss: -2036726.8750
Epoch [10/15], Step [15/17], Loss: -1585122.0000
Epoch [11/15], Step [5/17], Loss: -2009158.3750
Epoch [11/15], Step [10/17], Loss: -1488373.8750
Epoch [11/15], Step [15/17], Loss: -1942903.0000
Epoch [12/15], Step [5/17], Loss: -2108666.7500
Epoch [12/15], Step [10/17], Loss: -2771515.0000
Epoch [12/15], Step [15/17], Loss: -2192605.7500
Epoch [13/15], Step [5/17], Loss: -2039266.5000
Epoch [13/15], Step [10/17], Loss: -3268939.7500
Epoch [13/15], Step [15/17], Loss: -2243900.0000
Epoch [14/15], Step [5/17], Loss: -3102693.5000
Epoch [14/15], Step [10/17], Loss: -2941674.0000
Epoch [14/15], Step [15/17], Loss: -3125370.7500
Epoch [15/15], Step [5/17], Loss: -2586744.7500
Epoch [15/15], Step [10/17], Loss: -1802294.8750
Epoch [15/15], Step [15/17], Loss: -4690499.0000

What are some ways to make them more sensible? This is for facial landmark prediction.

I understand that the following ConvNet is very simple, but is that the only reason for these terrible losses?

num_classes = 68 * 2 #68 coordinates X and Y flattened

class ConvNet(nn.Module):
    def __init__(self, num_classes=10):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 18, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(18),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
        self.layer2 = nn.Sequential(
            nn.Conv2d(18, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2, padding=0))
        self.fc = nn.Linear(32 * 56 * 56, num_classes)
        
    def forward(self, x):
        out = self.layer1(x)
        #print("After Layer 1", out.shape)
        out = self.layer2(out)
        #print("After Layer 2", out.shape)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        #print("Final Layer Output Shape", out.shape)
        return out

Hello Mona!

You are probably passing values for target (your labels) that are
either negative or larger than one to MultiLabelSoftMarginLoss,
causing it to return invalid negative values.

Conceptually, MultiLabelSoftMarginLoss should never return a
negative loss value. Its input (the outputs of your model) should
be logits that range from -inf to inf; they are, in effect, converted
to probabilities (via a sigmoid) inside MultiLabelSoftMarginLoss.
(This part is correct in your code, as your model returns the result
of its last Linear layer for your outputs.)

Its target (your labels), however, should be probabilities that run
from 0.0 to 1.0*. If the target values are outside of this range, you
violate the requirements of MultiLabelSoftMarginLoss, which can
then return (invalid) negative values, and your training will likely diverge
(which it appears to be doing).
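
As a quick illustration (a minimal sketch with made-up numbers, not your
actual data), you can compare the loss for targets inside [0.0, 1.0] with
the loss for targets that look like pixel coordinates. The tensors
logits, valid_targets and pixel_targets are hypothetical values chosen
only for this demonstration.

```python
import torch
import torch.nn as nn

criterion = nn.MultiLabelSoftMarginLoss()

# Hypothetical logits for one sample with three outputs.
logits = torch.tensor([[5.0, 5.0, 5.0]])

# Targets inside [0, 1]: the loss is a valid, non-negative value.
valid_targets = torch.tensor([[1.0, 0.0, 0.5]])
valid_loss = criterion(logits, valid_targets)

# Targets that look like pixel coordinates (assumed values, far above 1):
# the (1 - y) factor becomes negative and the loss drops below zero.
pixel_targets = torch.tensor([[120.0, 87.0, 200.0]])
invalid_loss = criterion(logits, pixel_targets)

print("loss with targets in [0, 1]:", valid_loss.item())
print("loss with pixel-like targets:", invalid_loss.item())
```

With out-of-range targets the optimizer can keep lowering the loss simply
by pushing the logits further out, which would be consistent with the
steadily growing negative numbers in your log.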

Check that your labels lie within [0.0, 1.0], and if not, ask what they
mean conceptually, and how (and whether) they can be converted to
probability-like values.
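
One possible way to do this check is sketched below. I am assuming that
your landmarks are (x, y) pixel coordinates and that your images are
224 x 224 (a guess based on the 32 * 56 * 56 input size of your fc layer
after two 2x poolings). The helper normalize_landmarks and the random
landmarks_batch are hypothetical stand-ins, not part of your code.

```python
import torch

def normalize_landmarks(landmarks, width, height):
    """Scale (x, y) pixel coordinates of shape (batch, 68, 2) by the image size.

    Hypothetical helper: assumes every landmark lies inside the image,
    so the result falls within [0, 1].
    """
    scale = torch.tensor([width, height], dtype=torch.float32,
                         device=landmarks.device)
    return landmarks.float() / scale

# Stand-in for sample_batched['landmarks']: random points in a 224 x 224 image.
landmarks_batch = torch.rand(4, 68, 2) * 224

# First, look at the range of the raw labels.
print("raw range:", landmarks_batch.min().item(), landmarks_batch.max().item())

# Then rescale and flatten to the (batch, 136) shape used in the training loop.
labels = normalize_landmarks(landmarks_batch, 224, 224).reshape(-1, 68 * 2)
print("scaled range:", labels.min().item(), labels.max().item())
```

If some landmarks fall outside the image, the scaled values will still
leave [0.0, 1.0], so you would need to decide how to treat them. Whether
scaled coordinates are meaningful as probability-like targets is a
separate question that this rescaling does not answer.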

*) The documentation for MultiLabelSoftMarginLoss states that

 y[i]∈{0,  1},

that is, that y (the target) is equal to either 0 or 1. But I believe
that MultiLabelSoftMarginLoss is meaningful for values of y that
range from 0.0 to 1.0, and the code accepts such values.

Best.

K. Frank
