I’m trying to build a text classification example using LSTM + a linear layer + softmax. After embedding, the input to the LSTM is a [8, 33, 300] tensor, where 8 is the batch size, 33 is the max sentence length, and 300 is the embedding size. However, I’m having trouble getting the dimensions right in the forward() method. I’ll paste the code first:
```python
class LSTMMasker(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, target_size, args, corpus, layer=1):
        super(LSTMMasker, self).__init__()
        self.hidden_dim = hidden_dim  # hidden_dim = 512
        self.corpus = corpus
        self.embedding = nn.Embedding(vocab_size, embedding_dim)  # embedding_dim = 300
        # when batch_first is True, input is of shape (batch, seq_len, features)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, layer, batch_first=True)  # layer = 1
        self.hidden = self.init_hidden(args.batch_size)  # batch_size = 8
        # The linear layer that maps from hidden state space to tag space
        self.output = nn.Linear(hidden_dim, target_size)  # hidden_dim = 512, target_size = 2
        self.use_cuda = args.cuda

    def init_hidden(self, batch_size=300):
        return (torch.zeros(1, batch_size, self.hidden_dim),
                torch.zeros(1, batch_size, self.hidden_dim))

    def forward(self, inputs):
        d_word, d_len = inputs
        # d_word has size [8, 33]: 8 is batch_size, 33 is max sentence length
        embeds = self.embedding(d_word)
        # embeds has size [8, 33, 300], 300 is the embedding size of each word
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        print(lstm_out.size())  # which is [8, 33, 512]
        output = self.output(lstm_out)  # output has size [8, 33, 2]
        result = F.log_softmax(output, dim=1)  # will cause error
        return result
```
Then I get an error at the softmax line:

```
Expected target size (8, 2), got torch.Size()
```
However, as you can see, the output is of size [8, 33, 2]. How can I reduce it to [8, 2]? As far as I know, the .view() method cannot do that, since it only reorganizes the existing elements rather than selecting or pooling over the time dimension.
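One thing I tried (not sure if it is the idiomatic way): taking only the last timestep’s output before the linear layer, which does give me [8, 2]. Here is a standalone sketch of that idea, with random data standing in for my real embedded batch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, seq_len, embed_dim, hidden_dim, target_size = 8, 33, 300, 512, 2

lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=1, batch_first=True)
output_layer = nn.Linear(hidden_dim, target_size)

embeds = torch.randn(batch_size, seq_len, embed_dim)  # stand-in for the embedded batch
lstm_out, (h_n, c_n) = lstm(embeds)                   # lstm_out: [8, 33, 512]

last_step = lstm_out[:, -1, :]                        # [8, 512], last timestep only
logits = output_layer(last_step)                      # [8, 2]
result = F.log_softmax(logits, dim=1)                 # [8, 2], one log-prob pair per sentence
print(result.shape)  # torch.Size([8, 2])
```

For a single-layer unidirectional LSTM, `lstm_out[:, -1, :]` should equal the final hidden state `h_n[0]`, so either would work here.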
Also, I find the documentation around the batch dimension a little vague: I have to rely on options like batch_first=True to control where the batch dimension goes, which feels error-prone.
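For reference, here is a quick standalone check I ran to convince myself what batch_first actually changes (shapes only, random data):

```python
import torch
import torch.nn as nn

# with batch_first=True, input and output are (batch, seq, feature)
x_bf = torch.randn(8, 33, 300)
lstm_bf = nn.LSTM(300, 512, batch_first=True)
out_bf, _ = lstm_bf(x_bf)
print(out_bf.shape)  # torch.Size([8, 33, 512])

# default (batch_first=False): input and output are (seq, batch, feature)
x_sf = torch.randn(33, 8, 300)
lstm_sf = nn.LSTM(300, 512)
out_sf, _ = lstm_sf(x_sf)
print(out_sf.shape)  # torch.Size([33, 8, 512])
```

Note that in both cases the hidden state tuple stays (num_layers, batch, hidden), which is part of what confused me.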
But anyway, my focus is still on the LSTM itself. Does anyone have any clue about my issue? Thanks!