Multi-class classification predictions come out all the same

Hi Everybody,

I’m training a net to detect which pitch classes are playing in music. I’ve trained it on the large Maestro dataset, which contains classical music recordings along with their scores. Looking at the predictions the net made, I’m pretty disappointed! Instead of showing around 1 for pitch classes that are playing and 0 for pitch classes that are not, it’s showing about -0.76 for every pitch class, as if the net were saying “maybe?” to every classification. Below is a sample of the predictions, which should detect the presence of 3 pitch classes (and the absence of the other 9). Could this be a symptom of the net getting stuck in a local minimum? How can I improve things?

The optimizer is SGD, the loss function is BCEWithLogitsLoss, and the model’s code follows the sample predictions below.

[-0.7595, -0.7598, -0.7598, -0.7599, -0.7597, -0.7598, -0.7602, -0.7596,
-0.7596, -0.7594, -0.7601, -0.7602],
[-0.7596, -0.7599, -0.7599, -0.7600, -0.7598, -0.7599, -0.7603, -0.7597,
-0.7598, -0.7596, -0.7602, -0.7603],
[-0.7594, -0.7596, -0.7596, -0.7599, -0.7595, -0.7598, -0.7601, -0.7594,
-0.7595, -0.7593, -0.7600, -0.7601],
[-0.7596, -0.7598, -0.7598, -0.7600, -0.7597, -0.7599, -0.7603, -0.7596,
-0.7597, -0.7595, -0.7602, -0.7603],
[-0.7597, -0.7600, -0.7599, -0.7600, -0.7599, -0.7599, -0.7603, -0.7598,
-0.7598, -0.7596, -0.7602, -0.7603],

import torch
import torch.nn as nn
import torch.nn.functional as F

class AudioNet1(nn.Module):

    def __init__(self, input_size=256,
                 h1_nodes=256,
                 h2_nodes=128,
                 h3_nodes=64,
                 output_size=12,
                 device='cpu'):
        super(AudioNet1, self).__init__()

        # Fully connected blocks: Linear -> LeakyReLU -> BatchNorm -> Dropout
        self.inputLayer = nn.Linear(input_size, h1_nodes).to(device)
        self.bn1 = nn.BatchNorm1d(h1_nodes).to(device)
        self.do1 = nn.Dropout(p=0.2).to(device)
        self.hiddenOne = nn.Linear(h1_nodes, h2_nodes).to(device)
        self.bn2 = nn.BatchNorm1d(h2_nodes).to(device)
        self.do2 = nn.Dropout(p=0.2).to(device)
        self.hiddenTwo = nn.Linear(h2_nodes, h3_nodes).to(device)
        self.bn3 = nn.BatchNorm1d(h3_nodes).to(device)
        self.do3 = nn.Dropout(p=0.2).to(device)
        self.hiddenThree = nn.Linear(h3_nodes, output_size).to(device)
        self.bn4 = nn.BatchNorm1d(output_size).to(device)
        self.do4 = nn.Dropout(p=0.2).to(device)
        # Two-layer LSTM over the 12 per-frame outputs; moved to `device` like the other layers
        self.lstm = nn.LSTM(input_size=output_size, hidden_size=12, num_layers=2).to(device)

    def forward(self, x):
        # Cast the input (rather than the layer output) so its dtype matches the float32 weights
        x = F.leaky_relu(self.inputLayer(x.float()))
        x = self.bn1(x)
        x = self.do1(x)
        x = F.leaky_relu(self.hiddenOne(x))
        x = self.bn2(x)
        x = self.do2(x)
        x = F.leaky_relu(self.hiddenTwo(x))
        x = self.bn3(x)
        x = self.do3(x)
        x = F.leaky_relu(self.hiddenThree(x))
        x = self.bn4(x)
        x = self.do4(x)
        bs, n = x.shape
        # nn.LSTM expects (seq_len, batch, features) by default (batch_first=False)
        x, _ = self.lstm(x.view(1, bs, n))
        x = torch.squeeze(x, 0)
        return x

This operation self.lstm(x.view(1, bs, n)) looks dangerous, since it seems you would like to permute the dimensions and are using view for it. view does not move data between dimensions; it only reinterprets the existing memory layout with a new shape.
If that’s the case, use x.permute() and pass the desired order of dimensions to it.
I might be wrong and this code might only unsqueeze dim0, in which case you could use x.unsqueeze(0) to be more explicit.
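A small illustration of the difference, using a hypothetical tensor of shape (2, 3): view keeps the element order and only changes the shape, permute actually swaps dimensions, and unsqueeze(0) adds a leading dimension of size 1, which for a contiguous tensor gives the same result as view(1, bs, n).

```python
import torch

x = torch.arange(6).reshape(2, 3)  # [[0, 1, 2], [3, 4, 5]]

viewed = x.view(3, 2)        # same element order, new shape
permuted = x.permute(1, 0)   # real transpose of the two dimensions
unsqueezed = x.unsqueeze(0)  # adds a leading dimension of size 1

print(viewed.tolist())    # [[0, 1], [2, 3], [4, 5]]
print(permuted.tolist())  # [[0, 3], [1, 4], [2, 5]]
print(unsqueezed.shape)   # torch.Size([1, 2, 3])

# For a contiguous tensor, view(1, rows, cols) and unsqueeze(0) are identical
print(torch.equal(x.view(1, 2, 3), unsqueezed))  # True
```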

I would generally recommend trying to overfit a small dataset, such as 10 samples, and making sure your model is able to do so. If your model cannot overfit this set, there might be other bugs in the code that I’m missing at the moment.
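A minimal sketch of that sanity check, assuming random stand-in data (`inputs` and `targets` are hypothetical placeholders for about 10 real frames and their pitch-class labels) and a plain MLP rather than AudioNet1. If the loss on a set this small does not approach zero, the problem is more likely in the training code than in the amount of data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny dataset: 10 frames of 256 features, 12 multi-hot pitch-class targets
inputs = torch.randn(10, 256)
targets = (torch.rand(10, 12) < 0.25).float()

# Small MLP without BatchNorm/Dropout, so nothing works against memorization
model = nn.Sequential(nn.Linear(256, 128), nn.LeakyReLU(), nn.Linear(128, 12))
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

for step in range(1000):
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    if step % 200 == 0:
        print(f"step {step}: loss {loss.item():.4f}")

final_loss = criterion(model(inputs), targets).item()
print(f"final loss: {final_loss:.4f}")
```

Once this works, the same loop can be pointed at the real model and a handful of real samples (with `model.train()` still on, dropout makes exact memorization slower, so it may be worth lowering `p` temporarily).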

Thanks for your suggestions ptrblck!

Actually, I do just want to unsqueeze the tensor, thanks for pointing that out.

The suggestion to overfit a small dataset is very helpful. Doing that, I’ve made a much quicker loop of experimentation and evaluation. I haven’t solved the problem yet, but I feel I’m working toward it in a much more purposeful way.

I’ll check back in when I have a firmer handle on what’s happening.

Might be too obvious but did you try to increase the lr?

I have tried increasing the lr, but it doesn’t seem to change the basic behavior of what’s going on.

I’ve found that with the lstm, the BCEWithLogitsLoss gets stuck at about 0.38, and the predictions for the 12 classes are very close to each other (as in the sample above). However, if I take the softmax of the output, the three highest values quite reliably correspond to the correct three classes! It’s still not useful as a classifier, though, because the values are so close that I think it would lead to a lot of false positives.

I tried removing the lstm entirely, and now the loss does not get stuck: it goes down to about 0.006 after 100 epochs. Furthermore, the outputs vary the way you would want. The softmax of the output shows very small probabilities for the “off” notes and close to 1/3 for the “on” notes.
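Since the loss is BCEWithLogitsLoss, the model’s raw outputs are logits, and the matching way to read them as probabilities is a per-class sigmoid (thresholded at 0.5, i.e. logit 0) rather than a softmax across the 12 classes, which couples the classes and forces them to sum to 1. Both functions are monotonic, so the ranking of classes is the same either way, which fits the observation that the top three softmax values were the right ones. The sketch below uses the first row of the sample predictions from the top of the thread; read with a sigmoid, every class sits at about 0.32, below the threshold.

```python
import torch

# First row of the sample predictions from the original post (raw model outputs = logits)
logits = torch.tensor([-0.7595, -0.7598, -0.7598, -0.7599, -0.7597, -0.7598,
                       -0.7602, -0.7596, -0.7596, -0.7594, -0.7601, -0.7602])

# BCEWithLogitsLoss treats each class independently, so sigmoid gives per-class probabilities
probs = torch.sigmoid(logits)
predicted_on = probs > 0.5  # equivalently: logits > 0

# Softmax makes the 12 values sum to 1, which suits single-label problems
softmax_probs = torch.softmax(logits, dim=0)

print(probs)          # every entry is about 0.32
print(predicted_on)   # all False: no pitch class is predicted as playing
print(softmax_probs)  # every entry is about 1/12
```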

Is the moral of the story that I don’t need an lstm for this problem? I doubt that, because it’s a time series, and context in time is surely meaningful.

BTW I did change the lstm code using unsqueeze; that section of the code is currently as follows:

x = x.unsqueeze(0)
x, _ = self.lstm(x)
x = x.squeeze(0)
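One thing that may be worth checking in this snippet: nn.LSTM defaults to batch_first=False, so it reads its input as (seq_len, batch, features). After x.unsqueeze(0), a batch of bs frames has shape (1, bs, n), which the LSTM treats as bs independent sequences of length 1, so no hidden state is carried from one frame to the next. If the rows of a batch are consecutive frames in time (an assumption about the data loader), unsqueeze(1) would present them as one sequence of length bs instead. A small sketch with hypothetical random frames: the first frame gets the same output in both layouts, but later frames differ because only the sequential layout passes state forward.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=12, hidden_size=12, num_layers=2)  # batch_first=False by default
frames = torch.randn(8, 12)  # hypothetical: 8 consecutive frames with 12 features each

# (seq_len=1, batch=8): eight independent one-step sequences, no temporal context
out_independent, _ = lstm(frames.unsqueeze(0))
# (seq_len=8, batch=1): one sequence of eight steps, hidden state carried forward
out_sequence, _ = lstm(frames.unsqueeze(1))

print(out_independent.shape)  # torch.Size([1, 8, 12])
print(out_sequence.shape)     # torch.Size([8, 1, 12])

# Same first frame and same zero initial state, so the same output
print(torch.allclose(out_independent[0, 0], out_sequence[0, 0], atol=1e-6))  # True
# Later frames differ, because only the sequential layout has seen earlier frames
print(torch.allclose(out_independent[0, 1], out_sequence[1, 0], atol=1e-6))  # False
```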

I have also tried using just an lstm, with pretty much identical results to the above. The BCEWithLogitsLoss stalls out pretty quickly, and the outputs are all very close to 0.76. But when I take the softmax of the output, it’s something like 10 or 20 percent for the “on” notes and about 5 percent for the “off” notes. Increasing the number of lstm layers changes things a little, but the overall behavior is the same.
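A possibly related detail: nn.LSTM returns its hidden state h_t = o_t * tanh(c_t), where the output gate o_t lies in (0, 1), so every output value lies in (-1, 1). Fed directly to BCEWithLogitsLoss as logits, such values can only map to probabilities between sigmoid(-1), about 0.27, and sigmoid(1), about 0.73, which may be one contributing factor to outputs clustered around 0.76 in magnitude. A common pattern is to put a Linear layer after the LSTM so the logits are unbounded. The sketch below is hypothetical (`LSTMWithHead`, its sizes, and the random input are illustrative, not the poster’s model).

```python
import torch
import torch.nn as nn

class LSTMWithHead(nn.Module):
    """Hypothetical sketch: LSTM over a sequence of frames, then a Linear layer
    so the per-class logits are not confined to the LSTM's (-1, 1) range."""

    def __init__(self, input_size=12, hidden_size=64, num_layers=2, output_size=12):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.head = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        # x: (batch, seq_len, input_size) because batch_first=True
        out, _ = self.lstm(x)  # (batch, seq_len, hidden_size), each value in (-1, 1)
        return self.head(out)  # (batch, seq_len, output_size), unbounded logits

torch.manual_seed(0)
model = LSTMWithHead()
x = torch.randn(4, 20, 12)  # hypothetical: 4 clips of 20 frames with 12 features each
logits = model(x)
print(logits.shape)  # torch.Size([4, 20, 12])

# The raw LSTM output itself never leaves the range [-1, 1]
raw, _ = model.lstm(x)
print(raw.abs().max().item() <= 1.0)  # True
```

The logits from `model(x)` can then go straight into BCEWithLogitsLoss with targets of shape (batch, seq_len, 12).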

Pretty sure it’s me, not the lstm that’s at fault! I think I need to go and learn more about how to use them.

This might indeed be due to a high learning rate.
nn.BCEWithLogitsLoss will internally apply a log_sigmoid and thus expects raw logits as the model output. If your model trains better with an additional softmax, this could mean that smaller gradients might be beneficial for your model.
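To make that concrete, a small check with hypothetical logits and multi-hot targets: BCEWithLogitsLoss on raw logits gives the same value as BCELoss on sigmoid(logits), and as the explicit formula -[y log sigmoid(z) + (1 - y) log(1 - sigmoid(z))] written with F.logsigmoid. This is why the model itself should end without a sigmoid or softmax when this loss is used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 12)                   # hypothetical raw model outputs
targets = (torch.rand(4, 12) < 0.25).float()  # hypothetical multi-hot pitch-class labels

loss_with_logits = nn.BCEWithLogitsLoss()(logits, targets)
loss_on_probs = nn.BCELoss()(torch.sigmoid(logits), targets)

# Explicit form, using log(1 - sigmoid(z)) == logsigmoid(-z) for numerical stability
manual = -(targets * F.logsigmoid(logits) + (1 - targets) * F.logsigmoid(-logits)).mean()

print(loss_with_logits.item(), loss_on_probs.item(), manual.item())  # three matching values
```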

To be clearer: I’m not using softmax during training, only as a way to interpret the predictions.

Ah OK, in that case forget what I said, as I misunderstood the issue. 😉