Writing a custom LSTM cell and handling PackedSequence


I would like to implement a multiplicative LSTM (https://arxiv.org/pdf/1609.07959.pdf) and found an implementation that seems to work with normal inputs (i.e. not packed sequences) here: https://github.com/FlorianWilhelm/mlstm4reco/blob/master/src/mlstm4reco/layers.py

The code is below:

class mLSTM(RNNBase):
    """Single-layer, unidirectional, batch-first multiplicative LSTM.

    At every time step an intermediate state ``m`` is computed as the
    element-wise product of a linear map of the current input and a linear
    map of the previous hidden state.  ``m`` then takes the place of the
    hidden state fed to a regular ``LSTMCell``.

    Args:
        input_size: Number of features of each input step.
        hidden_size: Number of features of the hidden and cell states.
        bias: Forwarded to ``RNNBase`` and to the inner ``LSTMCell``.
            NOTE: the multiplicative biases ``b_im`` / ``b_hm`` are always
            created, independent of this flag.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(mLSTM, self).__init__(
            mode='LSTM', input_size=input_size, hidden_size=hidden_size,
            num_layers=1, bias=bias, batch_first=True,
            dropout=0, bidirectional=False)

        # Parameters of the multiplicative intermediate state
        #   m_t = (w_im @ x_t + b_im) * (w_hm @ h_{t-1} + b_hm)
        self.w_im = Parameter(torch.empty(hidden_size, input_size))
        self.b_im = Parameter(torch.empty(hidden_size))
        self.w_hm = Parameter(torch.empty(hidden_size, hidden_size))
        self.b_hm = Parameter(torch.empty(hidden_size))

        self.lstm_cell = LSTMCell(input_size, hidden_size, bias)

        # The parameters registered above start out as uninitialised memory;
        # any reset performed by the base-class constructor ran before they
        # existed, so initialise everything explicitly now.
        self.reset_parameters()

    def reset_parameters(self):
        """Draw every parameter from U(-1/sqrt(hidden_size), 1/sqrt(hidden_size))."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-stdv, stdv)

    def forward(self, input, hx):
        """Run the mLSTM over a padded, batch-first sequence.

        Args:
            input: 3-D tensor indexed as ``(batch, seq, feature)``.  A
                ``PackedSequence`` is not handled here; pad it first (e.g.
                with ``pad_packed_sequence(..., batch_first=True)``).
            hx: Tuple ``(h_0, c_0)`` holding the initial hidden and cell
                states, each indexed as ``(batch, hidden)``.

        Returns:
            The cell states concatenated along dim 1: the initial cell state
            followed by the cell state after every time step, i.e. a tensor
            with ``seq + 1`` entries along dim 1.
        """
        # Unpacking also enforces that ``input`` is three-dimensional.
        _, n_seq, _ = input.size()

        hidden, cell = hx
        steps = [cell.unsqueeze(1)]
        for seq in range(n_seq):
            x_t = input[:, seq, :]
            mx = (F.linear(x_t, self.w_im, self.b_im)
                  * F.linear(hidden, self.w_hm, self.b_hm))
            hidden, cell = self.lstm_cell(x_t, (mx, cell))
            # Record the state of this step; without this the result would
            # only ever contain the initial cell state.
            steps.append(cell.unsqueeze(1))

        return torch.cat(steps, dim=1)

I checked the code used by the current LSTM/GRU/RNNBase in https://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html, but I don’t know how I could easily replace `_impl = _rnn_impls[self.mode]`, or whether there is a way to handle PackedSequence directly in the code of mLSTM.

Thank you for your help