nn.LSTM using custom LSTMCells

I’ve found a lot of resources on writing a custom nn.LSTMCell, and I have successfully written one and applied it to the time_sequence_prediction example from the PyTorch examples. I would now like to use my custom LSTMCell inside a multi-layer nn.LSTM. Is there a simple way to do this, built into nn.LSTM, that I am missing? Alternatively, is there a source for a PyTorch nn.LSTM implementation that would be the equivalent of this link for nn.LSTMCell?

If anyone else comes across this, the code below seems to be working for me. It returns the same values, but it doesn’t take in and return the same format as nn.LSTM, so you’ll have to check your input and output. For example, the hidden state has to be passed as a list with one tuple per layer, [(h, c)], instead of just (h, c), and instead of _, states = lstm(…) it returns only states. I have also only tested it with num_layers=1. If anyone would like to check my work or post fixes (especially ones that make the formatting consistent with nn.LSTM, or ones needed for more than one layer), that would be greatly appreciated. I’m not marking this as the solution because I’d like a solution that can replace nn.LSTM directly without any other changes to the code, and this does not do that.
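To make the format difference concrete: nn.LSTM takes (h_0, c_0) as two tensors of shape (num_layers, batch, hidden_size), while the hand-written LSTM below takes a list with one (h, c) tuple per layer, each of shape (batch, hidden_size). A minimal sketch of converting between the two, assuming unidirectional layers; the helper names to_layer_list and from_layer_list are hypothetical:

```python
import torch


def to_layer_list(h0, c0):
    """(h_0, c_0), each (num_layers, batch, hidden) -> [(h_layer0, c_layer0), ...]"""
    return list(zip(h0.unbind(0), c0.unbind(0)))


def from_layer_list(states):
    """[(h, c), ...] with one tuple per layer -> (h_n, c_n), each (num_layers, batch, hidden)"""
    h_n = torch.stack([h for h, _ in states], dim=0)
    c_n = torch.stack([c for _, c in states], dim=0)
    return h_n, c_n


num_layers, batch, hidden = 2, 3, 5
h0 = torch.randn(num_layers, batch, hidden)
c0 = torch.randn(num_layers, batch, hidden)

states = to_layer_list(h0, c0)   # format expected by the per-layer LSTM
print(len(states), states[0][0].shape)

h_n, c_n = from_layer_list(states)  # back to the nn.LSTM format
print(torch.equal(h_n, h0) and torch.equal(c_n, c0))
```

The round trip is lossless because unbind and stack along dim 0 are inverses of each other.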

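import math

import torch
import torch.nn as nn


class LSTMCell(nn.Module):

    def __init__(self, input_size, hidden_size, bias=True):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # One projection for the input and one for the hidden state; each produces
        # all four gates at once (input, forget, cell, output), hence 4 * hidden_size
        self.x2h = nn.Linear(input_size, 4 * hidden_size, bias=bias)
        self.h2h = nn.Linear(hidden_size, 4 * hidden_size, bias=bias)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init in [-1/sqrt(hidden_size), 1/sqrt(hidden_size)], as nn.LSTMCell does
        std = 1.0 / math.sqrt(self.hidden_size)
        for w in self.parameters():
            w.data.uniform_(-std, std)

    def forward(self, x, hidden):
        # x: (batch, input_size); hidden: (hx, cx), each (batch, hidden_size)
        hx, cx = hidden
        gates = self.x2h(x) + self.h2h(hx)
        # Split along the feature dimension into the four gates
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = torch.mul(cx, forgetgate) + torch.mul(ingate, cellgate)

        hy = torch.mul(outgate, torch.tanh(cy))
        return (hy, cy)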
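Since the claim is that this cell returns the same values as nn.LSTMCell, here is a minimal sketch of how that can be checked. nn.LSTMCell stores weight_ih/bias_ih and weight_hh/bias_hh with its gates in the order input, forget, cell, output, which matches the chunk order above, so copying them into the x2h and h2h Linear layers should give identical outputs. The function custom_cell_step is a hypothetical standalone copy of LSTMCell.forward so the snippet runs on its own:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


def custom_cell_step(x, h, c, x2h, h2h):
    # Same arithmetic as LSTMCell.forward: fused projection, then split into i, f, g, o
    gates = x2h(x) + h2h(h)
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, g, o = torch.sigmoid(i), torch.sigmoid(f), torch.tanh(g), torch.sigmoid(o)
    c_next = f * c + i * g
    h_next = o * torch.tanh(c_next)
    return h_next, c_next


input_size, hidden_size, batch = 4, 6, 3
ref = nn.LSTMCell(input_size, hidden_size)
x2h = nn.Linear(input_size, 4 * hidden_size)
h2h = nn.Linear(hidden_size, 4 * hidden_size)

# Copy the reference weights so both cells compute with identical parameters
with torch.no_grad():
    x2h.weight.copy_(ref.weight_ih)
    x2h.bias.copy_(ref.bias_ih)
    h2h.weight.copy_(ref.weight_hh)
    h2h.bias.copy_(ref.bias_hh)

x = torch.randn(batch, input_size)
h = torch.randn(batch, hidden_size)
c = torch.randn(batch, hidden_size)

h_ref, c_ref = ref(x, (h, c))
h_out, c_out = custom_cell_step(x, h, c, x2h, h2h)
print(torch.allclose(h_ref, h_out, atol=1e-5), torch.allclose(c_ref, c_out, atol=1e-5))
```

A small tolerance is used because the built-in cell may sum the terms in a different order, which changes the last bits of the floating-point result.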

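class LSTM(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers=1):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # The first cell has a different number of input features (input_size);
        # every later cell takes the previous layer's hidden state (hidden_size)
        cell_list = [LSTMCell(self.input_size, self.hidden_size)]
        for idcell in range(1, self.num_layers):
            cell_list.append(LSTMCell(self.hidden_size, self.hidden_size))
        # nn.ModuleList registers the cells so their parameters are trained
        self.cell_list = nn.ModuleList(cell_list)

    def forward(self, current_input, hidden_state):
        """
        current_input: tensor of shape (seq_len, batch, input_size)
        hidden_state: list of tuples, one for every layer; each tuple is
            (h_layer_i, c_layer_i), both of shape (batch, hidden_size)
        Returns the list of final (h, c) tuples, one per layer.
        """
        next_hidden = []  # final hidden states (h and c) of every layer
        seq_len = current_input.size(0)

        for idlayer in range(self.num_layers):  # loop over layers
            hidden_c = hidden_state[idlayer]
            output_inner = []
            for t in range(seq_len):  # loop over time steps
                hidden_c = self.cell_list[idlayer](current_input[t], hidden_c)
                output_inner.append(hidden_c[0])

            next_hidden.append(hidden_c)
            # This layer's h at every time step is the input sequence of the next layer
            current_input = torch.stack(output_inner, 0)

        return next_hidden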
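Toward the goal stated above (a module that can replace nn.LSTM without other code changes), one possible sketch is a wrapper that accepts and returns nn.LSTM's formats while building one cell per layer. Assumptions: the class name StackedLSTM and the cell_cls parameter are hypothetical, only seq-first, unidirectional input without dropout is handled, and nn.LSTMCell stands in for the custom cell so the snippet is self-contained. The check copies nn.LSTM's per-layer weights into the cells and compares outputs for num_layers=2:

```python
import torch
import torch.nn as nn


class StackedLSTM(nn.Module):
    """Hypothetical nn.LSTM-style wrapper around per-layer cells.

    cell_cls must be constructible as cell_cls(input_size, hidden_size) and
    callable as cell(x_t, (h, c)) -> (h, c). Seq-first, unidirectional only.
    """

    def __init__(self, input_size, hidden_size, num_layers=1, cell_cls=nn.LSTMCell):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.cells = nn.ModuleList([
            cell_cls(input_size if i == 0 else hidden_size, hidden_size)
            for i in range(num_layers)
        ])

    def forward(self, x, hx=None):
        # x: (seq_len, batch, input_size); hx: optional (h_0, c_0), each (num_layers, batch, hidden)
        seq_len, batch = x.size(0), x.size(1)
        if hx is None:
            zeros = x.new_zeros(self.num_layers, batch, self.hidden_size)
            hx = (zeros, zeros)
        h_n, c_n = [], []
        layer_input = x
        for layer, cell in enumerate(self.cells):
            h, c = hx[0][layer], hx[1][layer]
            outputs = []
            for t in range(seq_len):
                h, c = cell(layer_input[t], (h, c))
                outputs.append(h)
            layer_input = torch.stack(outputs, dim=0)  # input sequence for the next layer
            h_n.append(h)
            c_n.append(c)
        # Same return structure as nn.LSTM: output, (h_n, c_n)
        return layer_input, (torch.stack(h_n, 0), torch.stack(c_n, 0))


torch.manual_seed(0)
ref = nn.LSTM(input_size=4, hidden_size=6, num_layers=2)
mine = StackedLSTM(4, 6, num_layers=2)

# Copy nn.LSTM's per-layer parameters (weight_ih_l0, weight_hh_l0, ...) into the cells
with torch.no_grad():
    for k, cell in enumerate(mine.cells):
        cell.weight_ih.copy_(getattr(ref, f"weight_ih_l{k}"))
        cell.weight_hh.copy_(getattr(ref, f"weight_hh_l{k}"))
        cell.bias_ih.copy_(getattr(ref, f"bias_ih_l{k}"))
        cell.bias_hh.copy_(getattr(ref, f"bias_hh_l{k}"))

x = torch.randn(7, 3, 4)  # (seq_len, batch, input_size)
out_ref, (h_ref, c_ref) = ref(x)
out, (h, c) = mine(x)
print(torch.allclose(out, out_ref, atol=1e-5), h.shape)
```

The custom LSTMCell above has the same constructor and call signature, so StackedLSTM(input_size, hidden_size, num_layers, cell_cls=LSTMCell) should accept it as well; its weights live in x2h and h2h rather than weight_ih and weight_hh, so a weight-copying check would need to map them accordingly.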