The following code, inside the forward
of a network module, results in a module that learns correctly:
def forward(self, x, state_cell_tuple):
    ...
    xdot = x @ self.Wx                                     # (batch_size, 4 * embedding_size)
    xdot = xdot.view(batch_size, 4, self.embedding_size)   # (batch_size, 4, embedding_size)
    ...
    i = F.tanh(xdot[:, 0] + hdot[:, 0] + self.bias[0])     # slice 0 for every batch element
    ... (etc) ...
However, transposing the view and the accessor gives a module that fails to learn correctly:
xdot = x @ self.Wx                                     # (batch_size, 4 * embedding_size)
xdot = xdot.view(4, batch_size, self.embedding_size)   # (4, batch_size, embedding_size)
...
i = F.tanh(xdot[0] + hdot[0] + self.bias[0])           # slice 0 along the first dim
... (etc) ...
(This is from an LSTM cell implementation, of course.)
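For reference, these are the shapes I am assuming in the snippets above (toy sizes here; hdot is the analogous projection of the previous hidden state, which I have elided):

import torch

batch_size, input_size, embedding_size = 8, 10, 16

x = torch.randn(batch_size, input_size)
Wx = torch.randn(input_size, 4 * embedding_size)   # one block of columns per gate
bias = torch.zeros(4, embedding_size)              # one bias row per gate

xdot = x @ Wx
print(xdot.shape)   # torch.Size([8, 64]), i.e. (batch_size, 4 * embedding_size)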
So, the question is: why does the transposed view fail to learn? A matrix multiplication is effectively a fully-connected layer, so I would have thought it should not matter which way around I do the .view(). What am I missing here? A toy version of the two reshapes is below.
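To make the comparison concrete, here are the two reshapes applied to the same toy data (tiny sizes, just to show which elements each variant picks out):

import torch

batch_size, embedding_size = 2, 3
xdot = torch.arange(batch_size * 4 * embedding_size).view(batch_size, 4 * embedding_size)

a = xdot.view(batch_size, 4, embedding_size)[:, 0]   # the variant that learns
b = xdot.view(4, batch_size, embedding_size)[0]      # the variant that does not

print(a)   # tensor([[ 0,  1,  2],
           #         [12, 13, 14]])
print(b)   # tensor([[0, 1, 2],
           #         [3, 4, 5]])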