Loss function finding extreme local minimum

I’ve created a model to predict an array of continuous values from an input sequence. An example of a label set is something in the form of

a = [0.0, 0.0, 0.3, 0.72, 0.0, 0.04, … , 0.0]

What I quickly found, however, was that the model converged to predicting a set of values (nearly) identical to one another, as I outlined in an earlier post. An example of a common prediction would be:

b = [0.231, 0.231, 0.231, 0.231, 0.231, 0.231, … , 0.231]

My loss function at the time was the torch MSE, implemented as MSELoss(reduction = "sum"). To counteract the collapse, I altered my loss function to reward a larger standard deviation of the predictions and so increase diversity. The updated loss function now looks like:

def my_loss(y_pred, y, phi = 100):
    return torch.sum(torch.square(y - y_pred)) - (torch.std(y_pred) * phi)
    #       ^ same as MSELoss(reduction = "sum")  ^negative term of std deviation

With this loss function we do see a good amount of diversity in the predictions. During training, the gap between the minimum and maximum predicted values grows and the loss drops much lower (see the output below). Despite this, given enough time the whole thing collapses within very few steps and the model goes back to predicting identical values, as in array b:

step :  100 loss :  6.5379  max prediction: 0.18211 min prediction: 0.16466 difference :  0.01744939 std dev: 0.00366005
step :  200 loss :  0.65646 max prediction: 0.11481 min prediction: 0.05075 difference :  0.06405885 std dev: 0.018370248
step :  300 loss : -1.83972 max prediction: 0.12452 min prediction: 0.01304 difference :  0.11148625 std dev: 0.043092
step :  400 loss : -3.99177 max prediction: 0.17989 min prediction: 0.00535 difference :  0.17454928 std dev: 0.08250562
step :  500 loss : -5.79127 max prediction: 0.23568 min prediction: 0.00406 difference :  0.23161854 std dev: 0.11023289
step : 1500 loss : -8.42063 max prediction: 0.667 min prediction:   0.00038 difference :  0.66662633 std dev: 0.25211474
step : 1600 loss : -9.68138 max prediction: 0.67537 min prediction: 0.00041 difference :  0.67495567 std dev: 0.21875821
step : 1700 loss : -19.4889 max prediction: 0.70115 min prediction: 0.00033 difference :  0.70081335 std dev: 0.23015715
step : 1800 loss : 1.72278  max prediction: 0.02762 min prediction: 0.02762 difference :  3.3024698e-06 std dev: 5.7693285e-07
step : 1900 loss : 2.03984  max prediction: 0.02339 min prediction: 0.02339 difference :  2.6654452e-06 std dev: 3.6730162e-07

step :18700 loss : 1.86279  max prediction: 0.02236 min prediction: 0.02236 difference :  4.189089e-06  std dev: 6.8502027e-07

Up until step 1700 it actually looks pretty good! But between steps 1700 and 1800 the entire set of predictions starts to look the same, as shown by the very small difference (max - min) and the tiny standard deviation. Even though the loss is much higher, this persists long into training (nearly 20,000 steps). It should be noted that when I train on the same batch over and over, the model eventually does converge to exactly the correct labels, so it doesn’t seem like anything super unusual is going on under the hood.

Also, I know the phi = 100 value in the loss is a bit odd, since predicting exactly the correct labels is not necessarily optimal with this loss function, but this was just to illustrate the problem at hand. I’m also using a scheduler during training, so I’m not sure whether a larger learning rate would help. My thought is that increasing the batch size (currently batch size = 1 for illustration) may help, since it’s less likely that every sequence in a batch falls into this strange behaviour.
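
To make the remark about phi = 100 concrete, here is a minimal plain-Python sketch of the same loss (no torch needed). It assumes statistics.stdev as a stand-in for torch.std, which by default also uses the sample (n - 1) standard deviation. The labels are only the visible entries of array a, and the names my_loss_py, y_true, stretched and constant are made up for illustration.

```python
import statistics

def my_loss_py(y_pred, y, phi=100):
    # Sum of squared errors, the same quantity as MSELoss(reduction="sum").
    sse = sum((t - p) ** 2 for t, p in zip(y, y_pred))
    # Sample standard deviation (n - 1), matching the default of torch.std.
    return sse - phi * statistics.stdev(y_pred)

# Visible entries of the example label array a (the elided part is left out).
y_true = [0.0, 0.0, 0.3, 0.72, 0.0, 0.04]

exact = list(y_true)                                  # perfect prediction
stretched = [1.1 * v for v in y_true]                 # slightly more spread than the labels
constant = [statistics.mean(y_true)] * len(y_true)    # collapsed prediction, like array b

loss_exact = my_loss_py(exact, y_true)
loss_stretched = my_loss_py(stretched, y_true)
loss_constant = my_loss_py(constant, y_true)

print(loss_exact, loss_stretched, loss_constant)
```

With these numbers the stretched prediction scores lower than the exact one, and the constant prediction scores highest of the three, so a collapse to a constant is not something this loss rewards on its own.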

My model is defined as:

class MyModel(nn.Module):
    def __init__(self, bert_model, output_size):
        super(MyModel, self).__init__()
        self.bert_model = bert_model
        self.linear_1 = nn.Linear(1024, 512)
        self.linear_2 = nn.Linear(512, 128)
        self.linear_3 = nn.Linear(128, 64)
        self.linear_4 = nn.Linear(64, output_size)

    def forward(self, input_ids, attn_mask):
        x = F.relu(self.bert_model(input_ids, attention_mask=attn_mask).last_hidden_state)
        x = F.relu(self.linear_1(x))
        x = F.relu(self.linear_2(x))
        x = F.relu(self.linear_3(x))
        x = torch.flatten(torch.sigmoid(self.linear_4(x)))
        return x

and my training loop looks like:

lr = 0.0005
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

scheduler1 = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
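mask_domains = True
verbose = True        # print progress every 100 steps
running_loss = 0.0    # accumulated loss between progress prints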
for step, batch in enumerate(train_data_loader):
      labels = torch.from_numpy(np.asarray(batch["labels"], dtype=np.float32)).cuda()
      inputs = batch["input_ids"].cuda()
      attn_mask = batch["attention_mask"].cuda()
      outputs = model(inputs, attn_mask).cuda()

      # create mask to exclude the padding tokens in loss calculation
      # making the outputs and labels 0 shouldn't affect loss value 
      # due to reduction = "sum"
      loss_mask = torch.where(labels == -100.0, 0, 1).cuda()
      lst1 = list(outputs.cpu().detach().numpy())
      lst2 = list(filter(lambda x : x > 0.01, lst1))

      # apply the mask for padding tokens 
      labels = torch.mul(labels, loss_mask)
      outputs = torch.mul(outputs, loss_mask)

      #creating mask to exclude additional tokens we don't want in our calculation
      if mask_domains:
            domains = np.asarray(batch["domains"])
            # making the outputs and labels 0 shouldn't affect loss value 
            # due to reduction = "sum"
            domains_mask = np.where(domains == -1, 0, 1)
            domains_mask = torch.from_numpy(domains_mask).cuda()
            labels = torch.mul(labels, domains_mask)
            outputs = torch.mul(outputs, domains_mask) 
      loss = my_loss(outputs[torch.where(domains_mask == 1)], labels[torch.where(domains_mask == 1)])

      # standard SGD update (assumed; this step is not shown in the post)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      running_loss += loss.item()
      if verbose and step > 0 and step % 100 == 0:
            print(
                "step : ", step,
                "loss :", round(loss.item(), 5),
                "max prediction:", round(np.max(lst1), 5),
                "min prediction:", round(np.min(lst1), 5),
                "difference : ", np.max(lst1) - np.min(lst1),
                "std dev:", np.std(lst1),
            )
            running_loss = 0
            # the scheduler only steps inside this verbose branch, i.e. once every 100 steps
            scheduler1.step()
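
A side note on the masking comments in the loop above: zeroing masked positions leaves the summed squared error untouched, but the std term in my_loss still sees those zeros unless the positions are removed before the loss is computed. The loop only drops the domains_mask positions by indexing; whether the padding positions are also covered by that mask is not visible from the post. A small plain-Python sketch with made-up numbers, assuming statistics.stdev as a stand-in for torch.std:

```python
import statistics

preds = [0.30, 0.28, 0.31, 0.29, 0.30, 0.27]        # hypothetical model outputs
labels = [0.25, 0.30, 0.35, 0.20, -100.0, -100.0]   # last two entries are padding
mask = [0 if t == -100.0 else 1 for t in labels]

# Variant 1: multiply by the mask (padding becomes 0 in both lists).
zeroed_preds = [p * m for p, m in zip(preds, mask)]
zeroed_labels = [t * m for t, m in zip(labels, mask)]

# Variant 2: drop the padded positions entirely.
kept_preds = [p for p, m in zip(preds, mask) if m]
kept_labels = [t for t, m in zip(labels, mask) if m]

def sse(a, b):
    # Sum of squared errors between two equally long lists.
    return sum((x - y) ** 2 for x, y in zip(a, b))

# The squared-error term is the same either way ...
print(sse(zeroed_preds, zeroed_labels), sse(kept_preds, kept_labels))
# ... but the standard deviation of the predictions is not.
print(statistics.stdev(zeroed_preds), statistics.stdev(kept_preds))
```

If zeroed padding does reach the loss, part of the reported diversity could come from the artificial zeros rather than from the real predictions.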

Thanks to anyone who took the time to read it this far.

The idea to encourage a larger standard deviation is cool, but I wouldn’t do it - usually you want the loss function to directly represent what a good output should look like, and let the network decide how to do it. Here you tell it “if your new solution is more diverse but less close to the data, I can live with that”.
I say usually because regularization (e.g. weight decay) does exactly the opposite: you add a loss term that does not enforce correct predictions but favors a simpler model or better training dynamics. So I see your std term as a regularizer, and would ask myself both whether I need a regularizer at all (one is usually added after we successfully fit the training data and want to generalize to unseen validation data) and, if I do need one, why choose the std one first.

A model that learns an overly simplistic answer points to the opposite case, underfitting, where the model is not expressive enough even for the training data. But to be frank, I wouldn’t rush to conclude it’s either of these; I would debug first (without the added std term).

What I would try first is to replace the dataset with just two batches. Does the model successfully learn them (and not return a constant)?
Also, I would keep printing the MSE rather than your combined loss value. Does it steadily decrease (i.e. the optimization works but gets stuck in a local minimum)? Or does it not necessarily decrease over time (the optimization fails and the hyperparameters need tuning)?
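
Both checks are cheap to set up. A minimal sketch, in plain Python so it runs without torch: take_batches and loss_terms are hypothetical helpers, statistics.stdev stands in for torch.std, and the toy generator and hard-coded predictions stand in for the real DataLoader and model output.

```python
import itertools
import statistics

def take_batches(loader, n=2):
    # Materialize only the first n batches so they can be replayed every epoch.
    return list(itertools.islice(loader, n))

def loss_terms(y_pred, y, phi=100):
    # Return the squared-error term and the std term separately, so the fit
    # to the data can be logged independently of the diversity bonus.
    sse = sum((t - p) ** 2 for t, p in zip(y, y_pred))
    std = statistics.stdev(y_pred)
    return sse, std, sse - phi * std

# Toy stand-in for train_data_loader: any iterable of batches works.
toy_loader = ({"labels": [0.0, 0.3, 0.72, 0.04]} for _ in range(1000))
small_set = take_batches(toy_loader, n=2)

for epoch in range(3):
    for batch in small_set:
        y = batch["labels"]
        y_pred = [0.2, 0.25, 0.5, 0.1]   # placeholder for the model output
        sse, std, total = loss_terms(y_pred, y)
        print(f"epoch {epoch} sse {sse:.4f} std {std:.4f} total {total:.4f}")
```

Watching sse on its own shows whether the fit to the labels is improving, even when the combined value is dominated by the std term.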

Hope this gives some ideas on how to continue from here.

Hi Jimmy!

I speculate that something is fishy with your data after step 1700. You are
using plain-vanilla SGD with a small learning rate so you would expect the
loss to go down, rather than jump up dramatically. (The large downward jump
from step 1600 to step 1700 does seem a bit odd, though.)

A few items of really wacky data could kick your model into a weird place, but
you train another 17,000 steps without your model going back down again.

If I am reading your code correctly, you only iterate over train_data_loader
once. That is to say, you only run one epoch and you only process any given
data item once. (From this I deduce that you have about 18,700 individual data
items in your training set.)

Is it possible that your first ~1700 data items are “sensible,” but after that the
data are flaky? Two suggestions: Try printing out (or storing in an array) your
loss for every step. You might also try no longer actually training your model
after, say, step 1600, but continue computing the loss.

Both of these suggestions are focused on seeing whether it’s your model that
has gone bad or whether the data (after step 1700) are bad.
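
A minimal sketch of those two diagnostics combined, with a toy one-parameter model standing in for the real one. The names run_diagnostic, compute_loss and update are hypothetical; in the real loop, freeze_after would be something like 1600.

```python
def run_diagnostic(batches, compute_loss, update, freeze_after=1600):
    """Record the loss at every step; stop updating the model after freeze_after."""
    losses = []
    for step, batch in enumerate(batches):
        losses.append(compute_loss(batch))   # always evaluate
        if step < freeze_after:
            update(batch)                    # only train during the early steps
    return losses

# Toy stand-in: fit a single weight w to scalar targets with plain SGD.
state = {"w": 0.0}
lr = 0.1

def compute_loss(target):
    return (state["w"] - target) ** 2

def update(target):
    grad = 2 * (state["w"] - target)
    state["w"] -= lr * grad

# The first 20 targets are "sensible", the rest are deliberately odd.
batches = [1.0] * 20 + [5.0] * 5
losses = run_diagnostic(batches, compute_loss, update, freeze_after=20)
print(losses[-5:])
```

If the loss still jumps once the weights are frozen, the data after that point are the likely culprit; if it stays flat, the problem is more likely in the model updates.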

As an aside (not that it’s relevant to the current issue), it seems you only call
scheduler1.step() if verbose is turned on, and then, only every 100 steps.
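
For reference, the arithmetic of that schedule. ExponentialLR multiplies the learning rate by gamma on each scheduler1.step() call, so the sketch below assumes exactly one call every 100 training steps; lr_after is a hypothetical helper, not part of torch.

```python
def lr_after(steps, base_lr=0.0005, gamma=0.9, steps_per_decay=100):
    # Number of scheduler calls so far, assuming one call per steps_per_decay steps.
    decays = steps // steps_per_decay
    return base_lr * gamma ** decays

for steps in (100, 1700, 18700):
    print(steps, lr_after(steps))
```

Under that assumption the learning rate is roughly 8.3e-05 by step 1700 and on the order of 1e-12 by step 18,700, so the weights can barely move late in the run, whatever the data look like.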


K. Frank