Is there a functional version of torch.nn.LSTM?
I don’t think there is a functional API in torch.nn.functional, but torch._VF.lstm is used inside the nn.LSTM module as seen here, so you might want to use it.
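A minimal sketch of calling torch._VF.lstm directly, passing the same positional arguments that nn.LSTM.forward passes to it. It borrows the weights from an existing nn.LSTM through the module's `_flat_weights` attribute. Both `_VF` and `_flat_weights` are private and undocumented, so they may change between PyTorch releases. The sizes are arbitrary.

```python
import torch
import torch.nn as nn
from torch import _VF  # private module used internally by nn.LSTM

torch.manual_seed(0)

# Arbitrary sizes for illustration
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
               bidirectional=True, batch_first=True)
x = torch.randn(3, 5, 8)  # (batch, timesteps, features)

# Zero initial state: (num_layers * num_directions, batch, hidden_size)
h0 = torch.zeros(2 * 2, 3, 16)
c0 = torch.zeros(2 * 2, 3, 16)

# Same positional arguments nn.LSTM.forward passes to _VF.lstm
output, h_n, c_n = _VF.lstm(
    x, (h0, c0), lstm._flat_weights,
    lstm.bias, lstm.num_layers, lstm.dropout,
    lstm.training, lstm.bidirectional, lstm.batch_first,
)

ref_output, (ref_h, ref_c) = lstm(x)
print(output.shape)  # torch.Size([3, 5, 32])
print(torch.allclose(output, ref_output))
```

The output's last dimension is 2 * hidden_size because the forward and reverse directions are concatenated.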
Thanks! @ptrblck
So if I wanted to use a bidirectional LSTM that receives as input a torch.Tensor of shape (batch, timesteps, features), hx should be a tuple of zero tensors (h_0, c_0), as they do in the nn.LSTM source:
num_directions = 2  # bidirectional
max_batch_size = input.size(0) if self.batch_first else input.size(1)
# h_0 uses proj_size when projections are enabled; c_0 always uses hidden_size
real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size
h_zeros = torch.zeros(self.num_layers * num_directions,
                      max_batch_size, real_hidden_size,
                      dtype=input.dtype, device=input.device)
c_zeros = torch.zeros(self.num_layers * num_directions,
                      max_batch_size, self.hidden_size,
                      dtype=input.dtype, device=input.device)
hx = (h_zeros, c_zeros)
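The same zero state can be built outside the module. This is a sketch with a hypothetical helper, `zero_lstm_state`, and arbitrary sizes. One detail worth keeping in mind: per the nn.LSTM docs, batch_first only affects the input and output tensors, so h_0 and c_0 keep the batch dimension second even when the input is (batch, timesteps, features).

```python
import torch
import torch.nn as nn


def zero_lstm_state(x, num_layers, hidden_size, proj_size=0,
                    bidirectional=True, batch_first=True):
    """Hypothetical helper: zero (h_0, c_0) for a batched LSTM input."""
    num_directions = 2 if bidirectional else 1
    # The batch size is read from dim 0 or dim 1 of the input...
    max_batch_size = x.size(0) if batch_first else x.size(1)
    # ...but the states are always (num_layers * num_directions, batch, size)
    real_hidden_size = proj_size if proj_size > 0 else hidden_size
    h_zeros = torch.zeros(num_layers * num_directions, max_batch_size,
                          real_hidden_size, dtype=x.dtype, device=x.device)
    c_zeros = torch.zeros(num_layers * num_directions, max_batch_size,
                          hidden_size, dtype=x.dtype, device=x.device)
    return h_zeros, c_zeros


lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
               bidirectional=True, batch_first=True)
x = torch.randn(3, 5, 8)  # (batch, timesteps, features)

hx = zero_lstm_state(x, num_layers=2, hidden_size=16)
print(hx[0].shape)  # torch.Size([4, 3, 16])

out_default, _ = lstm(x)       # hx=None: the module creates zeros itself
out_explicit, _ = lstm(x, hx)  # explicit zeros
print(torch.allclose(out_default, out_explicit))
```

Passing hx explicitly only matters if you want a non-zero initial state; with zeros, the module's default gives the same result.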
But I’m having trouble understanding how they flatten the weights before passing them to the layer. Is it a list of the different LSTM parameters, flattened in a certain order?
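From reading the nn.LSTM constructor, the "flat weights" do not appear to be reshaped into one vector. They are a plain Python list of the per-layer parameter tensors, ordered layer by layer, with the forward direction before the `_reverse` one. Within each (layer, direction) the order is weight_ih, weight_hh, then bias_ih, bias_hh (only when bias=True), then weight_hr (only when proj_size > 0). Each weight_ih_l{k} stacks the four gates as (W_ii|W_if|W_ig|W_io), giving shape (4*hidden_size, in_features), as documented for nn.LSTM. flatten_parameters() is a separate thing: it compacts the weights in memory for cuDNN and does not change this order.

The sketch below rebuilds that order with a hypothetical helper, `flat_weight_names`, and runs the LSTM from a plain `{name: tensor}` dict through the private `_VF.lstm`. `functional_lstm` is also hypothetical and assumes eval mode with no dropout.

```python
import torch
import torch.nn as nn
from torch import _VF  # private; may change between releases


def flat_weight_names(num_layers, bidirectional, bias=True, proj_size=0):
    """Hypothetical helper: parameter names in the order nn.LSTM lists them."""
    names = []
    for layer in range(num_layers):
        for direction in range(2 if bidirectional else 1):
            suffix = "_reverse" if direction == 1 else ""
            layer_names = ["weight_ih", "weight_hh"]
            if bias:
                layer_names += ["bias_ih", "bias_hh"]
            if proj_size > 0:
                layer_names += ["weight_hr"]
            names += [f"{n}_l{layer}{suffix}" for n in layer_names]
    return names


def functional_lstm(x, hx, params, num_layers, bidirectional=True,
                    bias=True, proj_size=0, batch_first=True):
    """Hypothetical: run an LSTM from a {name: tensor} dict (eval, no dropout)."""
    names = flat_weight_names(num_layers, bidirectional, bias, proj_size)
    flat = [params[n] for n in names]
    output, h_n, c_n = _VF.lstm(x, hx, flat, bias, num_layers, 0.0,
                                False, bidirectional, batch_first)
    return output, (h_n, c_n)


lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2,
               bidirectional=True, batch_first=True)
names = flat_weight_names(num_layers=2, bidirectional=True)
print(names[:5])
# ['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0', 'weight_ih_l0_reverse']

x = torch.randn(3, 5, 8)  # (batch, timesteps, features)
hx = (torch.zeros(4, 3, 16), torch.zeros(4, 3, 16))
params = {k: v.detach().clone() for k, v in lstm.named_parameters()}
out, _ = functional_lstm(x, hx, params, num_layers=2)

lstm.eval()
ref, _ = lstm(x, hx)
print(torch.allclose(out, ref))
```

Because the weights come from a dict rather than module attributes, the same call works with tensors produced elsewhere, as long as their names and shapes match the layout above.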