Problem with the attention decoder while using an LSTM

This may be a newbie problem, but I've been stuck on it for a while and I can't work out exactly where the dimensions are mismatched.
The attention decoder is a variation of the one in the seq2seq tutorial:

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1, max_length=MAX_LENGTH):
        super(AttnDecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.dropout_p = dropout_p
        self.max_length = max_length

        self.embedding = nn.Embedding(self.output_size, self.hidden_size)
        self.attn = nn.Linear(self.hidden_size * 2, self.max_length)
        self.attn_combine = nn.Linear(self.hidden_size * 2, self.hidden_size)
        self.dropout = nn.Dropout(self.dropout_p)
        self.lstm = nn.LSTM(self.hidden_size, self.hidden_size)
        self.out = nn.Linear(self.hidden_size, self.output_size)

    def forward(self, input, hidden, encoder_outputs):
        embedded = self.embedding(input).view(1, 1, 1, -1)
        embedded = self.dropout(embedded)
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))
        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        output, hidden = self.lstm(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

    def initHidden(self):
        return (torch.zeros(1, 1, self.hidden_size, device=device),
                torch.zeros(1, 1, self.hidden_size, device=device))

I’m getting the following error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-175-7b1ecba569ab> in <module>()
      4 attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p=0.1).to(device)
      5 
----> 6 trainIters(encoder1, attn_decoder1, 75000, print_every=5000)

6 frames
<ipython-input-170-29eaa6fdc28a> in trainIters(encoder, decoder, n_iters, print_every, plot_every, learning_rate)
     17 
     18         loss = train(input_tensor, target_tensor, encoder,
---> 19                      decoder, encoder_optimizer, decoder_optimizer, criterion)
     20         print_loss_total += loss
     21         plot_loss_total += loss

<ipython-input-168-77785d753587> in train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length)
     30         for di in range(target_length):
     31             decoder_output, decoder_hidden, decoder_attention = decoder(
---> 32                 decoder_input, decoder_hidden, encoder_outputs)
     33             loss += criterion(decoder_output, target_tensor[di])
     34             decoder_input = target_tensor[di]  # Teacher forcing

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

<ipython-input-166-f1855786cb3b> in forward(self, input, hidden, encoder_outputs)
     39         print(embedded[0].size(), hidden[0].size())
     40         attn_weights = F.softmax(
---> 41             self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)
     42         attn_applied = torch.bmm(attn_weights.unsqueeze(0),
     43                                  encoder_outputs.unsqueeze(0))

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/linear.py in forward(self, input)
     90     @weak_script_method
     91     def forward(self, input):
---> 92         return F.linear(input, self.weight, self.bias)
     93 
     94     def extra_repr(self):

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in linear(input, weight, bias)
   1406         ret = torch.addmm(bias, input, weight.t())
   1407     else:
-> 1408         output = input.matmul(weight.t())
   1409         if bias is not None:
   1410             output += bias

RuntimeError: size mismatch, m1: [2 x 256], m2: [512 x 5] at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:268

Hope you can help me.

It seems to me that initHidden gives you a pair of 3-D tensors (the LSTM's (h, c) state), but you convert your embedded input into a 4-D tensor with .view(1, 1, 1, -1). With an LSTM, hidden[0] is the full (1, 1, hidden_size) h tensor rather than the 2-D slice you get with a GRU, so torch.cat((embedded[0], hidden[0]), 1) concatenates two (1, 1, 256) tensors into a (1, 2, 256) tensor instead of the (1, 512) input that self.attn expects, which is what the m1: [2 x 256], m2: [512 x 5] mismatch is complaining about. I may be mistaken, but this seems the most obvious cause at first glance.
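
If that is the case, one possible fix, just a sketch assuming the rest of your class and the tutorial training code stay unchanged, is to keep embedded three-dimensional as in the original tutorial and index one level deeper into the (h, c) tuple when building the attention input:

    def forward(self, input, hidden, encoder_outputs):
        # Keep embedded 3-D: (1, 1, hidden_size), as in the GRU version
        embedded = self.embedding(input).view(1, 1, -1)
        embedded = self.dropout(embedded)

        # hidden is an (h, c) tuple for the LSTM, so hidden[0] is the full
        # (1, 1, hidden_size) h tensor; index again to get a (1, hidden_size)
        # row that can be concatenated with embedded[0]
        attn_weights = F.softmax(
            self.attn(torch.cat((embedded[0], hidden[0][0]), 1)), dim=1)
        attn_applied = torch.bmm(attn_weights.unsqueeze(0),
                                 encoder_outputs.unsqueeze(0))

        output = torch.cat((embedded[0], attn_applied[0]), 1)
        output = self.attn_combine(output).unsqueeze(0)

        output = F.relu(output)
        # The LSTM itself still takes the full (h, c) tuple
        output, hidden = self.lstm(output, hidden)

        output = F.log_softmax(self.out(output[0]), dim=1)
        return output, hidden, attn_weights

Everything else, including initHidden returning the (h, c) pair and passing hidden straight to self.lstm, can stay as it is.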