I’m trying to build a text classification example using LSTM + a linear layer + softmax. After embedding, the input to the LSTM is a [8, 33, 300] tensor, where 8 is the batch size, 33 is the max sentence length, and 300 is the embedding size. However, I’m having trouble getting the dimensions right in the forward() method. I’ll paste the code first:
```python
class LSTMMasker(nn.Module):
    def __init__(self, embedding_dim, hidden_dim, vocab_size, target_size, args, corpus, layer=1):
        super(LSTMMasker, self).__init__()
        self.hidden_dim = hidden_dim  # hidden_dim = 512
        self.corpus = corpus
        self.embedding = nn.Embedding(vocab_size, embedding_dim)  # embedding_dim = 300
        # when batch_first is True, input is of shape (batch, seq_len, features)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, layer, batch_first=True)  # layer = 1
        self.hidden = self.init_hidden(args.batch_size)  # batch_size = 8
        # The linear layer that maps from hidden state space to tag space
        self.output = nn.Linear(hidden_dim, target_size)  # hidden_dim = 512, target_size = 2
        self.use_cuda = args.cuda

    def init_hidden(self, batch_size=300):
        return (torch.zeros(1, batch_size, self.hidden_dim),
                torch.zeros(1, batch_size, self.hidden_dim))

    def forward(self, inputs):
        d_word, d_len = inputs
        # d_word has size [8, 33]: 8 is batch_size, 33 is max sentence length
        embeds = self.embedding(d_word)
        # embeds has size [8, 33, 300], 300 is the embedding size of each word
        lstm_out, self.hidden = self.lstm(embeds, self.hidden)
        print(lstm_out.size())  # which is [8, 33, 512]
        output = self.output(lstm_out)  # output has size [8, 33, 2]
        result = F.log_softmax(output, dim=1)  # will cause error
        return result
```
Then I get an error at the softmax line:

```
Expected target size (8, 2), got torch.Size()
```
However, as you can see, the output is of size [8, 33, 2]. How can I reduce it to [8, 2]? As far as I know, the .view() method cannot do that, since it only reorganizes the existing elements rather than selecting or pooling over the time dimension.
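One thing I tried (not sure if it is the idiomatic way): taking only the last timestep’s output before the linear layer, which does give me [8, 2]. Here is a standalone sketch of that idea, with random data standing in for my real embedded batch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, seq_len, embed_dim, hidden_dim, target_size = 8, 33, 300, 512, 2

lstm = nn.LSTM(embed_dim, hidden_dim, num_layers=1, batch_first=True)
output_layer = nn.Linear(hidden_dim, target_size)

embeds = torch.randn(batch_size, seq_len, embed_dim)  # stand-in for the embedded batch
lstm_out, (h_n, c_n) = lstm(embeds)                   # lstm_out: [8, 33, 512]

last_step = lstm_out[:, -1, :]                        # [8, 512], last timestep only
logits = output_layer(last_step)                      # [8, 2]
result = F.log_softmax(logits, dim=1)                 # [8, 2], one log-prob pair per sentence
print(result.shape)  # torch.Size([8, 2])
```

For a single-layer unidirectional LSTM, `lstm_out[:, -1, :]` should equal the final hidden state `h_n[0]`, so either would work here.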
Also, I find the documentation around the batch dimension a little vague: I have to rely on options like batch_first=True to control where the batch dimension goes, which feels error-prone.
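For reference, here is a quick standalone check I ran to convince myself what batch_first actually changes (shapes only, random data):

```python
import torch
import torch.nn as nn

# with batch_first=True, input and output are (batch, seq, feature)
x_bf = torch.randn(8, 33, 300)
lstm_bf = nn.LSTM(300, 512, batch_first=True)
out_bf, _ = lstm_bf(x_bf)
print(out_bf.shape)  # torch.Size([8, 33, 512])

# default (batch_first=False): input and output are (seq, batch, feature)
x_sf = torch.randn(33, 8, 300)
lstm_sf = nn.LSTM(300, 512)
out_sf, _ = lstm_sf(x_sf)
print(out_sf.shape)  # torch.Size([33, 8, 512])
```

Note that in both cases the hidden state tuple stays (num_layers, batch, hidden), which is part of what confused me.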
But anyway, my focus is still on the LSTM itself. Does anyone have any clue about my issue? Thanks!