1) Confusion about training loss; 2) Why is the validation loss so high?

I am constructing a DNN model using PyTorch and came across two ‘loss’ terms during training (see code below). My questions:

  1. Which training loss should I compare against the validation loss: loss or running_loss?
  2. In both cases, the training loss differs several-fold from the validation loss. What am I doing wrong?
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(1105, 1024)
        self.fc2 = nn.Linear(1024, 256)
        self.fc3 = nn.Linear(256, 64)
        self.fc4 = nn.Linear(64, 16)
        self.fc5 = nn.Linear(16, 2)
        self.bn_1 = nn.BatchNorm1d(1024, momentum=0.5)
        self.bn_2 = nn.BatchNorm1d(256, momentum=0.5)
        self.bn_3 = nn.BatchNorm1d(64, momentum=0.5)
        self.bn_4 = nn.BatchNorm1d(16, momentum=0.5)
        self.dropout = nn.Dropout(p=0.2)
    
    def forward(self, inputs):
        x = inputs
        x = self.dropout(self.bn_1(F.relu(self.fc1(x))))
        x = self.dropout(self.bn_2(F.relu(self.fc2(x))))
        x = self.dropout(self.bn_3(F.relu(self.fc3(x))))
        x = self.dropout(self.bn_4(F.relu(self.fc4(x))))
        x = self.fc5(x)
        return F.log_softmax(x, dim=1)

my_model = Net().to(device)

# data shape:: Train: (163082, 1105) Validation: (40771, 1105)
BATCH_SIZE = 32
nSamples = [157544, 5538]  # class counts in my dataset for class 0 and 1, respectively
normedWeights = [1 - (x / sum(nSamples)) for x in nSamples]
normedWeights = torch.FloatTensor(normedWeights).to(device, dtype=torch.float64)
loss_function = nn.NLLLoss(weight=normedWeights)
optimizer = optim.Adam(my_model.parameters(), lr=0.0001, eps=0.0001, amsgrad=True, weight_decay=1e-5)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=10, verbose=True)

def train(my_model):
    EPOCHS = 100
    for epoch in range(EPOCHS):
        train_batch_loss = 0.0
        val_loss = 0.0
        correct = 0
        total = 0
        val_correct = 0
        val_total = 0
        for i, (x, y) in enumerate(train_loader):
            x, y = x.to(device, dtype=torch.float32), y.to(device, dtype=torch.long)
            optimizer.zero_grad()  
            output = my_model(x)
            loss = loss_function(output.double(), y)
            train_batch_loss += F.nll_loss(output.double(), y, reduction='sum')
            loss.backward()
            #torch.nn.utils.clip_grad_norm_(my_model.parameters(), 0.5)
            optimizer.step()
        _, predicted = torch.max(output.data, 1)
        correct += predicted.eq(y).sum().item()
        total += len(y)
        scheduler.step(i)
        running_loss = train_batch_loss/len(train_loader)
            
        for j, (x_val, y_val) in enumerate(val_loader):
            x_val, y_val = x_val.to(device, dtype=torch.float32), y_val.to(device, dtype=torch.long)
            with torch.no_grad():
                my_model.eval()
                val_outputs = my_model(x_val)
                val_loss1 = loss_function(val_outputs.double(), y_val)
                val_loss += val_loss1.item()
        _, val_predicted = torch.max(val_outputs.data, 1)
        val_correct += val_predicted.eq(y_val).sum().item()
        val_total += len(y_val)
        print(f"Epoch {epoch}. LOSS:: Train_loss {loss:.4f}  Train_running_loss {running_loss:.4f}   Val_loss {val_loss:.4f}")
        print(f"Accuracy:: Train {correct/total:.4f}  Validation {val_correct/val_total:.4f}")

Output:

Epoch 0. LOSS:: Train_loss 0.1771  Train_running_loss 6.0901   Val_loss 385.8817
         Accuracy:: Train 0.9000  Validation 1.0000
Epoch 1. LOSS:: Train_loss 0.0815  Train_running_loss 5.8791   Val_loss 450.4003
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 2. LOSS:: Train_loss 0.1021  Train_running_loss 5.5039   Val_loss 442.1833
         Accuracy:: Train 0.9000  Validation 1.0000
Epoch 3. LOSS:: Train_loss 0.2438  Train_running_loss 5.4403   Val_loss 415.6329
         Accuracy:: Train 0.9000  Validation 1.0000
Epoch 4. LOSS:: Train_loss 0.0914  Train_running_loss 5.1918   Val_loss 380.8049
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 5. LOSS:: Train_loss 0.0973  Train_running_loss 5.0323   Val_loss 401.0601
         Accuracy:: Train 0.9000  Validation 1.0000
Epoch 6. LOSS:: Train_loss 0.1308  Train_running_loss 4.8982   Val_loss 409.4360
         Accuracy:: Train 0.9000  Validation 1.0000
Epoch 7. LOSS:: Train_loss 0.0257  Train_running_loss 4.7714   Val_loss 456.5145
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 8. LOSS:: Train_loss 0.0241  Train_running_loss 4.9047   Val_loss 533.7724
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 9. LOSS:: Train_loss 0.0700  Train_running_loss 4.7236   Val_loss 423.4640
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 10. LOSS:: Train_loss 0.0446  Train_running_loss 4.4727   Val_loss 410.2520
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch    12: reducing learning rate of group 0 to 1.0000e-04.
Epoch 11. LOSS:: Train_loss 0.0138  Train_running_loss 4.2515   Val_loss 492.4581
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 12. LOSS:: Train_loss 0.0056  Train_running_loss 3.7988   Val_loss 496.0938
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 13. LOSS:: Train_loss 0.0163  Train_running_loss 3.5083   Val_loss 493.1467
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 14. LOSS:: Train_loss 0.0043  Train_running_loss 3.4447   Val_loss 522.9645
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 15. LOSS:: Train_loss 0.0061  Train_running_loss 3.3147   Val_loss 533.8937
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 16. LOSS:: Train_loss 0.0708  Train_running_loss 3.3023   Val_loss 505.2642
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 17. LOSS:: Train_loss 0.0039  Train_running_loss 3.1817   Val_loss 529.2422
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 18. LOSS:: Train_loss 0.0055  Train_running_loss 3.1025   Val_loss 548.8464
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 19. LOSS:: Train_loss 0.0276  Train_running_loss 3.0549   Val_loss 560.4675
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 20. LOSS:: Train_loss 0.0037  Train_running_loss 3.0281   Val_loss 586.8056
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 21. LOSS:: Train_loss 0.0060  Train_running_loss 2.9651   Val_loss 603.7631
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch    23: reducing learning rate of group 0 to 1.0000e-05.
Epoch 22. LOSS:: Train_loss 0.0027  Train_running_loss 2.9114   Val_loss 487.0047
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 23. LOSS:: Train_loss 0.0024  Train_running_loss 2.8859   Val_loss 565.6329
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 24. LOSS:: Train_loss 0.0097  Train_running_loss 2.8546   Val_loss 591.2556
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 25. LOSS:: Train_loss 0.0021  Train_running_loss 2.8961   Val_loss 588.1403
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 26. LOSS:: Train_loss 0.0045  Train_running_loss 2.8124   Val_loss 506.0154
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 27. LOSS:: Train_loss 0.0070  Train_running_loss 2.7557   Val_loss 505.5475
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 28. LOSS:: Train_loss 0.0048  Train_running_loss 2.8138   Val_loss 626.8176
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 29. LOSS:: Train_loss 0.0009  Train_running_loss 2.7847   Val_loss 540.7521
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 30. LOSS:: Train_loss 0.0098  Train_running_loss 2.7887   Val_loss 539.4330
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 31. LOSS:: Train_loss 0.0185  Train_running_loss 2.7797   Val_loss 574.4599
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 32. LOSS:: Train_loss 0.0026  Train_running_loss 2.7819   Val_loss 527.2850
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch    34: reducing learning rate of group 0 to 1.0000e-06.
Epoch 33. LOSS:: Train_loss 0.0035  Train_running_loss 2.7485   Val_loss 533.4987
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 34. LOSS:: Train_loss 0.0033  Train_running_loss 2.7217   Val_loss 548.1089
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 35. LOSS:: Train_loss 0.0086  Train_running_loss 2.7513   Val_loss 528.2917
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 36. LOSS:: Train_loss 0.0017  Train_running_loss 2.7649   Val_loss 544.4396
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 37. LOSS:: Train_loss 0.0131  Train_running_loss 2.7475   Val_loss 570.9015
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 38. LOSS:: Train_loss 0.0010  Train_running_loss 2.7525   Val_loss 586.9201
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 39. LOSS:: Train_loss 0.0047  Train_running_loss 2.7542   Val_loss 568.6170
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 40. LOSS:: Train_loss 0.0269  Train_running_loss 2.7772   Val_loss 551.5580
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 41. LOSS:: Train_loss 0.0059  Train_running_loss 2.7409   Val_loss 618.6259
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 42. LOSS:: Train_loss 0.0051  Train_running_loss 2.7635   Val_loss 597.0938
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 43. LOSS:: Train_loss 0.0018  Train_running_loss 2.7605   Val_loss 528.7895
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch    45: reducing learning rate of group 0 to 1.0000e-07.
Epoch 44. LOSS:: Train_loss 0.0014  Train_running_loss 2.7500   Val_loss 565.9728
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 45. LOSS:: Train_loss 0.0008  Train_running_loss 2.7274   Val_loss 650.7986
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 46. LOSS:: Train_loss 0.0032  Train_running_loss 2.7719   Val_loss 640.0589
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 47. LOSS:: Train_loss 0.0015  Train_running_loss 2.7303   Val_loss 514.7046
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 48. LOSS:: Train_loss 0.0015  Train_running_loss 2.7317   Val_loss 545.7140
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 49. LOSS:: Train_loss 0.0058  Train_running_loss 2.7294   Val_loss 536.0183
         Accuracy:: Train 1.0000  Validation 1.0000
Epoch 50. LOSS:: Train_loss 0.0033  Train_running_loss 2.7457   Val_loss 568.7015
         Accuracy:: Train 1.0000  Validation 1.0000

NOTE: Several measures have been taken so far to reduce the loss:

  1. Data is standardized
  2. Batch normalization
  3. Architectures from the simplest (single layer, 4 neurons) up to a complex one (as shown above)
  4. Dropout layers
  5. BatchNorm momentum increased to 0.5
  6. Gradient clipping (clip 0.5)

My data consist of numerical values, so no data augmentation can be done!

Hello Swapnil!

I don’t really understand what you are trying to do nor what your issue is.

But I have a couple of comments:

When you calculate loss and val_loss you use normedWeights in
loss_function, but when you calculate running_loss you do not.
The normalizations also differ: running_loss accumulates
F.nll_loss(..., reduction='sum') (a sum over the samples in each
batch) and divides by len(train_loader) (the number of batches),
while val_loss is a sum of per-batch mean losses that you never
divide by the length of your enumerate(val_loader) loop.

So your three loss values are normalized and/or class-weighted
differently, and are not directly comparable.
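For example (a minimal sketch, not your exact code – it assumes the
train_loader, val_loader, optimizer, and weighted loss_function you
defined above): apply the same weighted, batch-mean criterion to both
splits and average over the number of batches:

    train_loss = 0.0
    my_model.train()
    for x, y in train_loader:
        x, y = x.to(device, dtype=torch.float32), y.to(device, dtype=torch.long)
        optimizer.zero_grad()
        loss = loss_function(my_model(x).double(), y)  # weighted mean over the batch
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    val_loss = 0.0
    my_model.eval()
    with torch.no_grad():
        for x, y in val_loader:
            x, y = x.to(device, dtype=torch.float32), y.to(device, dtype=torch.long)
            val_loss += loss_function(my_model(x).double(), y).item()  # same criterion

    print(f"train {train_loss / len(train_loader):.4f}   val {val_loss / len(val_loader):.4f}")

With this, both printed losses are per-batch averages of the same
weighted criterion, so they can be read on the same scale.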

As an aside, you are performing a binary (two-class) classification.

You can treat this as a multi-class classification (as you are doing), but
you will be a little bit better off if you treat it as a binary classification.

So I would recommend that you have a single output from your final
layer, get rid of the log_softmax(), and use BCEWithLogitsLoss as
your loss function.

Something like this:

        self.fc5 = nn.Linear(16, 1)
...
        x = self.fc5(x)
        return x
...
# pos_weight should be the ratio of negative to positive samples, so
# that the rare positive class gets weighted up
posWeight = torch.FloatTensor([nSamples[0] / nSamples[1]]).to(device)
loss_function = nn.BCEWithLogitsLoss(pos_weight=posWeight)
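Note that BCEWithLogitsLoss expects floating-point targets with the
same shape as its input, so your training step would change along
these lines (again just a sketch, reusing the x and y from your loop):

    output = my_model(x)                                # shape (batch, 1): raw logits
    loss = loss_function(output.squeeze(1), y.float())  # float targets, shape (batch,)
    predicted = (output.squeeze(1) > 0).long()          # positive class when logit > 0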

Good luck.

K. Frank
