LSTM with MFCC features as inputs not learning

Hello, a newbie here. I’m trying to implement an LSTM classifier (specifically the LSTM classifier in Table 5 of this paper), but I’m confused about a couple of things:

  1. I’ve extracted MFCC, MFCC velocity and MFCC acceleration features, using the 13 lower-order MFCCs, for the audio samples in my dataset. So at the end my input has the shape (330, 39, 500), where 330 is the number of samples. Should the LSTM’s input size be 39 or 500? And if the input size should be 500, should the hidden size then be 39? (See the layout sketch right after this list.)
  2. I have tried a couple of combinations, such as hidden size 39 with input size 500, but the LSTM is not learning and the outputs are all the same.
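
This is how I understand nn.LSTM's batch_first input layout (just a sketch of my understanding, using the 128 units mentioned below; maybe this is where I'm going wrong):

import torch
import torch.nn as nn

# nn.LSTM with batch_first=True expects input of shape (batch, seq_len, input_size).
# If the 39 MFCC-based coefficients are the features, then the 500 frames are the
# time steps, so my (330, 39, 500) array would need a permute first.
features = torch.randn(330, 39, 500)   # (samples, coefficients, frames)
x = features.permute(0, 2, 1)          # -> (330, 500, 39)

lstm = nn.LSTM(input_size=39, hidden_size=128, batch_first=True)
out, (hn, cn) = lstm(x)
print(out.shape)                       # torch.Size([330, 500, 128])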

What steps can I take from here?
Edit: I’m using the Adam optimizer and the cross-entropy loss function.

There are 2 classes and 128 LSTM units. LSTM model code:

import torch
import torch.nn as nn
import torch.nn.functional as F

class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # batch_first=True: input is expected as (batch, seq_len, input_size)
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=0.3)
        self.act = nn.ReLU()
        self.dense1 = nn.Linear(hidden_size, 8)  # in_features must match the LSTM's hidden size
        self.dense2 = nn.Linear(8, num_classes)
        
    def forward(self, x):
        # Set initial hidden and cell states: (num_layers, batch, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).requires_grad_().to(device)
        c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).requires_grad_().to(device)
        
        out, _ = self.lstm(x, (h0.detach(), c0.detach()))  
        out = self.act(out[:, -1, :])
        out = self.act(self.dense1(out))
        out = F.softmax(self.act(self.dense2(out)))
        return out
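
The model, criterion and optimizer are created roughly like this (a sketch; num_layers=2 and the default learning rate are just placeholders, and I've been varying the size combinations):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 2 classes; I've tried several size combinations, e.g. input_size=500 with
# hidden_size=39, as well as 128 hidden units. num_layers=2 is a placeholder.
model = LSTM(input_size=39, hidden_size=128, num_layers=2, num_classes=2).to(device)

criterion = nn.CrossEntropyLoss()                 # cross entropy loss
optimizer = torch.optim.Adam(model.parameters())  # Adam optimizer, default lr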

The training code:

model.train()
train_loss = 0
correct = 0
total = 0

for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        images = images.float().to(device)
        images = images.view(images.shape[0], images.shape[1], 1)
        labels = labels.to(device)
        
        optimizer.zero_grad()

        outputs = model(images)

        loss = criterion(outputs, labels)
        loss.backward()

        optimizer.step()

        train_loss += loss.item()
        currentTrainLoss=loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        
        
        print('Loss: {:.3f} | Acc: {:.3f}'.format(train_loss/(i+1), 100.*correct/total))

# Test the model
model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        images = images.float().to(device)  # move test inputs to the same device as the model
        images = images.view(images.shape[0], images.shape[1], 1)
        labels = labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Test Accuracy of the model on test images: {} %'.format(100 * correct / total)) 

Hey @chericha

Why do you detach these two from the graph?

...
out, _ = self.lstm(x, (h0.detach(), c0.detach()))
...

Doing this in forward disables the learning. Refer to the docs.

Hello @zetyquickly
I saw an LSTM implementation where they detached the states, with this explanation:

We need to detach as we are doing truncated backpropagation through time (BPTT)
If we don’t, we’ll backprop all the way to the start even after going through another batch

I turned that part into

...
out, _ = self.lstm(x, (h0, c0))
...

It still gives the same results.

I see. So, I think that was a correct use of detach.
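
To illustrate when detaching actually matters: in truncated BPTT the hidden state is carried over from one batch (or chunk) to the next, and detach() cuts the graph there. With h0/c0 freshly created as zeros on every forward call, as in your code, detaching them should not change anything, which matches what you observed. A rough sketch, assuming chunks of one long sequence:

import torch
import torch.nn as nn

# Truncated BPTT sketch: the hidden state is carried across consecutive chunks,
# and detach() stops gradients from flowing back into earlier chunks.
lstm = nn.LSTM(input_size=39, hidden_size=128, batch_first=True)
hidden = None  # defaults to zeros on the first chunk

for _ in range(3):  # pretend these are consecutive chunks of the same sequence
    x_chunk = torch.randn(4, 500, 39)  # (batch, seq_len, input_size)
    if hidden is not None:
        hidden = tuple(h.detach() for h in hidden)  # truncate the graph here
    out, hidden = lstm(x_chunk, hidden)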

Could you try reducing the number of linear layers at the end and see if there is any difference?

out, _ = self.lstm(x, (h0.detach(), c0.detach())) 
out = self.act(out[:, -1, :])
out = self.act(self.dense1(out))
out = F.softmax(self.act(self.dense2(out)))

to

out, _ = self.lstm(x, (h0.detach(), c0.detach())) 
# do the proper `.view()` of the `out` tensor 
out = F.softmax(self.dense2(out))
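
Spelled out, the simplified forward could look roughly like this (a sketch; it assumes dense2 is redefined as nn.Linear(hidden_size, num_classes) so the shapes line up, and it takes the last time step instead of a .view()):

def forward(self, x):
    h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
    c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)

    out, _ = self.lstm(x, (h0, c0))
    out = out[:, -1, :]                        # last time step: (batch, hidden_size)
    return F.softmax(self.dense2(out), dim=1)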

Two points, going from the bottom to the top of your code:

  1. softmax is already an activation function, so there is no specific need to apply ReLU before it (see the tiny example after this list),
  2. there is no particular need to apply ReLU to the LSTM's output either, because nonlinearities are already present inside the LSTM (the tanh, see here).
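
For point 1, a tiny example of what ReLU before softmax does to negative logits:

import torch
import torch.nn.functional as F

logits = torch.tensor([[-2.0, -0.5]])      # two-class logits, both negative
print(F.softmax(logits, dim=1))            # ~[[0.18, 0.82]]
print(F.softmax(F.relu(logits), dim=1))    # ReLU zeroes both -> [[0.50, 0.50]]

ReLU clamps negative logits to zero, so after softmax very different inputs can produce identical outputs, which would fit the "outputs are the same" symptom.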

Thank you, it’s learning now!