Loss function(s) not reflecting optimum values and increasing with time

I’ve encountered some strange behavior in my model’s training outputs. The steps the optimizer takes do not seem to help with convergence, and in many cases move the model in the opposite direction from what the loss function would suggest. As a sanity check, I set every label in the dataset to 0.86 (an arbitrary value) to see whether the model can at least learn to predict a constant. Using MSELoss, here are the outputs at each training step:

              Loss                       |                 Model Output                |       Label

tensor(0.3838, grad_fn=<MseLossBackward0>) tensor([[0.2405]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.2788, grad_fn=<MseLossBackward0>) tensor([[0.3320]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.1464, grad_fn=<MseLossBackward0>) tensor([[0.4774]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0539, grad_fn=<MseLossBackward0>) tensor([[0.6278]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0121, grad_fn=<MseLossBackward0>) tensor([[0.7500]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0006, grad_fn=<MseLossBackward0>) tensor([[0.8354]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0009, grad_fn=<MseLossBackward0>) tensor([[0.8908]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0043, grad_fn=<MseLossBackward0>) tensor([[0.9257]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0077, grad_fn=<MseLossBackward0>) tensor([[0.9475]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0104, grad_fn=<MseLossBackward0>) tensor([[0.9619]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0124, grad_fn=<MseLossBackward0>) tensor([[0.9713]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0138, grad_fn=<MseLossBackward0>) tensor([[0.9776]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0149, grad_fn=<MseLossBackward0>) tensor([[0.9820]], grad_fn=<MeanBackward1>) tensor([[0.8600]])

The same thing also happens when changing the loss to L1Loss:

              Loss                    |                 Model Output              |       Label
tensor(0.8063, grad_fn=<SumBackward0>) tensor([[0.0537]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.7222, grad_fn=<SumBackward0>) tensor([[0.1378]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.5698, grad_fn=<SumBackward0>) tensor([[0.2902]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.3777, grad_fn=<SumBackward0>) tensor([[0.4823]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.1891, grad_fn=<SumBackward0>) tensor([[0.6709]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0419, grad_fn=<SumBackward0>) tensor([[0.8181]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0509, grad_fn=<SumBackward0>) tensor([[0.9109]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.0926, grad_fn=<SumBackward0>) tensor([[0.9526]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.1121, grad_fn=<SumBackward0>) tensor([[0.9721]], grad_fn=<MeanBackward1>) tensor([[0.8600]])
tensor(0.1224, grad_fn=<SumBackward0>) tensor([[0.9824]], grad_fn=<MeanBackward1>) tensor([[0.8600]])

Here the steps the optimizer takes are not lowering the loss; instead, the model output keeps moving toward 1.

My model is defined as follows:


import torch
import torch.nn as nn


class DualBertForClassification(nn.Module):
    def __init__(self, bert_model_a, bert_model_b):
        super().__init__()
        
        self.bert_model_wt = bert_model_a
        self.bert_model_mutant = bert_model_b
        self.layer_1 = nn.Linear(1024, 512)
        self.layer_2 = nn.Linear(512, 128)
        self.layer_3 = nn.Linear(128, 16)
        self.layer_4 = nn.Linear(16, 1)

    def forward(self, x):
        x_a = x[0]  
        x_b = x[1]
        x = torch.cat(
            (
                self.bert_model_wt(**x_a).last_hidden_state, # [batch_sz, sequence_sz, 1024]
                self.bert_model_mutant(**x_b).last_hidden_state # [batch_sz, sequence_sz, 1024]
            ), 
            1
        )  # [batch_sz, 2*sequence_sz, 1024]
        x = torch.tanh(self.layer_1(x))  # [batch_sz, 2*sequence_sz, 512]
        x = torch.tanh(self.layer_2(x))  # [batch_sz, 2*sequence_sz, 128]
        x = torch.tanh(self.layer_3(x))  # [batch_sz, 2*sequence_sz, 16]
        x = torch.tanh(self.layer_4(x))  # [batch_sz, 2*sequence_sz, 1]
        x = torch.mean(x, dim = 1)  # [batch_sz, 1]
        return x

bert_model_a and bert_model_b are identical pretrained BERT models, whose output sizes are shown in the comments. Here is the training loop that produced the output printed above.

    from transformers import BertModel

    model1 = BertModel.from_pretrained(model_name)
    model2 = BertModel.from_pretrained(model_name)
    model = DualBertForClassification(model1, model2)

    
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

    
    loss_fct = nn.L1Loss(reduction="sum")  # also tried with nn.MSELoss(reduction="sum")
    model.train()
    for index, (input_a, input_b) in enumerate(zip(wt_inputs, alt_inputs)):
        label_val = torch.tensor([[0.86]], dtype=torch.float32)  # constant target for every sample
        model_output = model((input_a, input_b))
        loss = loss_fct(model_output, label_val)
        print(loss, model_output, label_val)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

My thought is that this could be an issue with the way autograd handles the two input BERT models, but that still would not explain the loss growing over time. Any help is greatly appreciated, thank you.

Hi Jimmy!

Your code looks okay (but who knows what goblins might be lurking in your
BertModel).

Try turning off momentum in your optimizer and starting with a very small learning
rate. Plain-vanilla SGD takes steps in the direction of smaller loss (that’s what
gradient descent is), and if the steps are small (small learning rate), they should
indeed reduce the loss rather than “jumping” to a point with a higher loss.
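
For example, a minimal sketch of that optimizer setup (the 1e-3 learning rate is just an illustrative starting point, not a prescribed value):

    # plain-vanilla SGD: no momentum, small learning rate
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, momentum=0.0)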

Note, however, that if your batches are “noisy,” then even if the optimizer step you
take using the gradients for, say, batch #12 would reduce the loss for batch #12,
it is possible that that step could increase the loss for batch #13. I don’t think that
this is your issue, but to rule it out you could try passing in the same batch over
and over again.
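
As a sketch of that sanity check (assuming wt_inputs and alt_inputs can be indexed the same way they are iterated in your loop):

    # sanity-check sketch: reuse one fixed batch on every step
    fixed_a, fixed_b = wt_inputs[0], alt_inputs[0]  # assumes the inputs are indexable
    label_val = torch.tensor([[0.86]], dtype=torch.float32)
    for step in range(100):
        model_output = model((fixed_a, fixed_b))
        loss = loss_fct(model_output, label_val)
        print(loss.item(), model_output.item())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()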

To rule out something weird in BertModel try training (with SGD, no momentum,
and a small learning rate) with the BertModel parameters frozen. The “head”
of your model (layer_1, …, layer_4) is very standard and should easily train
to spit out your fixed target value of 0.86 and train down to a loss of zero. (With
L1Loss, your training could end up jumping back and forth across the minimum,
necessitating an additional reduction in your learning rate, but with MSELoss this
shouldn’t happen.)
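
One way to freeze the BertModel parameters (a sketch; calling requires_grad_(False) on the two submodules works just as well):

    # freeze both BERT encoders so only the linear head is trained
    for p in model.bert_model_wt.parameters():
        p.requires_grad = False
    for p in model.bert_model_mutant.parameters():
        p.requires_grad = False
    # rebuild the optimizer over the remaining trainable parameters
    optimizer = torch.optim.SGD(
        [p for p in model.parameters() if p.requires_grad],
        lr=1e-3,
        momentum=0.0,
    )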

Best.

K. Frank

Thanks a lot for the response and suggestions! It seems there isn’t anything suspicious going on under the hood with the loss computation, backprop, etc. I tried all of the things you suggested:

  • lower learning rate
  • no gradient computation on the two BertModel instances
  • using the same batch each step as input

And the model did in fact converge to 0.86 as expected. It looks like your comment about noise across batches was right, since using the same batch each time resolved it. I also ran the exact same configuration as described above but without freezing the BertModel weights, and it again converged to predictions of 0.86, so it doesn’t seem like there is anything unusual about those components either.

Thanks again for the help!