PyTorch squeeze doesn't work

    def forward(self, input_seq, encoder_outputs, hidden=None):
        outputs, hidden = self.gru(input_seq, hidden)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        attn_weights = self.attn(outputs, encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))
        context = context.squeeze(1)
        new_outputs = outputs
        new_outputs = new_outputs.squeeze(0)
        concat_input = torch.cat((new_outputs, context), 1)
        concat_output = torch.tanh(self.concat(concat_input))
        outputs = self.out(concat_output)
        return outputs, hidden

The output:

torch.Size([5, 1, 50])
torch.Size([5, 5, 50])
torch.Size([5, 5, 50])
torch.Size([5, 50])

The context tensor is squeezed, but new_outputs is not: its size should be [5, 50], not [5, 5, 50]. Why is this happening?


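Squeeze only removes a dimension if its size is 1. Calling squeeze(0) on a tensor of size [5, 5, 50] does not raise an error. It returns the tensor unchanged, because dimension 0 has size 5. context is squeezed only because the dimension you pass to squeeze(1) has size 1.
If you want to remove a dimension of size greater than 1, you need to reduce over it with a function such as sum, max, or mean.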
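A minimal sketch of both points, using random tensors whose shapes are assumed to match the sizes printed in the question. Reducing over dim 0 is chosen only to mirror the squeeze(0) call. Whether sum, mean, or max is appropriate, or whether dim 0 should be collapsed at all, depends on your model, so treat that choice as an assumption.

```python
import torch

# Shaped like new_outputs in the question: no dimension has size 1
x = torch.randn(5, 5, 50)

# squeeze(0) is a no-op here because dim 0 has size 5, not 1;
# PyTorch returns the tensor unchanged instead of raising an error
print(x.squeeze(0).shape)  # torch.Size([5, 5, 50])

# A tensor with a size-1 dimension at index 1 squeezes as expected
c = torch.randn(5, 1, 50)
print(c.squeeze(1).shape)  # torch.Size([5, 50])

# To collapse a dimension of size > 1, reduce over it instead
# (dim=0 is an assumption, chosen to mirror squeeze(0))
print(x.sum(dim=0).shape)         # torch.Size([5, 50])
print(x.mean(dim=0).shape)        # torch.Size([5, 50])
print(x.max(dim=0).values.shape)  # torch.Size([5, 50]); max also returns .indices
```

If self.gru uses the default batch_first=False, its output is laid out as (seq_len, batch, hidden), so dim 0 is the time dimension and squeeze(0) can only succeed when the decoder is fed a single time step.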