LSTM for many-to-one multiclass classification problem

Hello Everyone,

I am very new to PyTorch. From my limited reading, the documentation seems to be really good, but it cannot answer every doubt a user has. I am coming here from this link on Example of Many-to-One LSTM, which partially helped me but left a few things unclear to me. They are as follows:

1st

rnn = nn.LSTM(10, 20, 2)
input = Variable(torch.randn(5, 3, 10))
h0 = Variable(torch.randn(2, 3, 20))
c0 = Variable(torch.randn(2, 3, 20))
output, (hn, cn) = rnn(input, (h0, c0))

In the above code the model has two LSTM layers, and the hidden size is 20. My question is: what are the output sizes of the first and the second hidden layers? Are they both 20? I ask because the output size is [torch.FloatTensor of size 5x3x20]. If they are both 20, is there a way to assign different hidden sizes to different LSTM layers?

2nd

The question Example of Many-to-One LSTM only talks about one final output. How can I use it for a multiclass classification problem? I am reproducing the code from the linked question below:

import torch
import torch.nn as nn
from torch.autograd import Variable

time_steps = 10
batch_size = 3
in_size = 5
classes_no = 7

model = nn.LSTM(in_size, classes_no, 2)

input_seq = Variable(torch.randn(time_steps, batch_size, in_size)) 
# size: [torch.FloatTensor of size 10x3x5] i.e there are 10 sequences each of 
# length 3 with 5 features. 

output_seq, _ = model(input_seq)  
# size:  [torch.FloatTensor of size 10x3x7] i.e there are 10 sequences 
# (as it should be because in input there are 10) each of length 3 and 7 features 
# (because hidden_size = 7(classes_no)).

last_output = output_seq[-1] 
# size: [torch.FloatTensor of size 3x7]  **now here is a doubt** should it not be 
# the size of 10x7? As we want the prediction of all the 10 sequences, and 7 
# predictions for 5 input-features so it should be 10x7.

loss = nn.CrossEntropyLoss() # seems OK.

target = Variable(torch.LongTensor(batch_size).random_(0, classes_no)) 
# [torch.LongTensor of size 3]  # I do not seem to understand this. How can 
# I **change** this for my multiclass classification problem on my timeseries data. 
# Say 3 classes: then my target should be 10x3, where 3 is for class 0, 1, or 2.

err = loss(last_output, target) # Then how do you calculate the loss. 

err.backward()

Please read the comments in the above code for my doubts. Any help, or a pointer to where I can find the answers, will be hugely appreciated.

Thank you.

Yes, they are both 20.

No, a single nn.LSTM uses the same hidden size for every layer.
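
If you do want different hidden sizes per layer, one workaround is to stack separate single-layer nn.LSTM modules yourself. A minimal sketch (the sizes here are made up for illustration):

import torch
import torch.nn as nn
from torch.autograd import Variable

# two single-layer LSTMs stacked manually, each with its own hidden size
lstm1 = nn.LSTM(10, 20, 1)  # input size 10 -> hidden size 20
lstm2 = nn.LSTM(20, 30, 1)  # input size 20 -> hidden size 30

input = Variable(torch.randn(5, 3, 10))
out1, _ = lstm1(input)  # size 5x3x20
out2, _ = lstm2(out1)   # size 5x3x30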

Wrong: there are 3 sequences, each of length 10, with 5 features.
That should clear up the doubts concerning output_seq and last_output.
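
To see this concretely, you can print the sizes. A quick check using the same numbers as above:

import torch
import torch.nn as nn
from torch.autograd import Variable

model = nn.LSTM(5, 7, 2)
input_seq = Variable(torch.randn(10, 3, 5))  # dim 0 is time, dim 1 is batch

output_seq, _ = model(input_seq)
print(output_seq.size())      # (10, 3, 7) - one output per time step per sequence
print(output_seq[-1].size())  # (3, 7)     - last time step only, one row per sequence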

Concerning the loss

loss = nn.CrossEntropyLoss() # defines which loss function to use.
err = loss(last_output, target) # calculates that loss.
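
Note that nn.CrossEntropyLoss expects raw scores of size (batch, classes) and a target of size (batch,) holding one class index per example, not a one-hot matrix. A minimal sketch with the numbers from above (batch of 3, 7 classes):

import torch
import torch.nn as nn
from torch.autograd import Variable

loss = nn.CrossEntropyLoss()
scores = Variable(torch.randn(3, 7))            # raw class scores, one row per sequence
target = Variable(torch.LongTensor([1, 0, 6]))  # one class index per sequence, in [0, 6]
err = loss(scores, target)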

Remarks…
You may need more hidden units in the LSTM layers, in which case you would need to add a Linear layer to squeeze last_output down to the right size for the number of classes.

@jpeg729 Thank you for your answer. The multiclass classification part is still not clear to me. Are you suggesting that I take the last output from the LSTM layer, add a linear layer whose output size is the number of classes (here 3), and squeeze it (using softmax) to predict the label?

Yep, exactly that…

Something like this should work

import torch
import torch.nn as nn
from torch.autograd import Variable

time_steps = 10
batch_size = 3
in_size = 5
classes_no = 7
hidden_size = 12 # for example

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.lstm = nn.LSTM(in_size, hidden_size, 2)
        self.linear = nn.Linear(hidden_size, classes_no)

    def forward(self, input_seq):
        output_seq, _ = self.lstm(input_seq)
        last_output = output_seq[-1]
        class_predictions = self.linear(last_output)
        return class_predictions

# initialise model and loss function
model = Model()
loss = nn.CrossEntropyLoss()

# prepare input and target
input_seq = Variable(torch.randn(time_steps, batch_size, in_size)) 
target = Variable(torch.LongTensor(batch_size).random_(0, classes_no)) # random_'s upper bound is exclusive

# calculate stuff
output = model(input_seq)
err = loss(output, target)
err.backward()
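
One detail worth noting: nn.CrossEntropyLoss applies log-softmax internally, so the model returns raw scores. If you want probabilities or class predictions at inference time, apply softmax afterwards, e.g.:

import torch.nn.functional as F

probs = F.softmax(model(input_seq), dim=1)  # class probabilities, size 3x7
_, predicted = probs.max(1)                 # index of the highest-scoring class per sequence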

Thanks a lot, @jpeg729. This really cleared up my doubts. As an addendum for any user reading this: you could benefit further from this sample tutorial here.
