Hi,
I would like to implement a multiplicative LSTM (https://arxiv.org/pdf/1609.07959.pdf) and found an implementation that seems to work with normal inputs (i.e. not packed sequences) here: https://github.com/FlorianWilhelm/mlstm4reco/blob/master/src/mlstm4reco/layers.py
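(For context, my reading of the paper is that the mLSTM first forms an intermediate state m_t = (W_im x_t + b_im) * (W_hm h_{t-1} + b_hm), an elementwise product, and then runs an ordinary LSTM cell with m_t in place of the previous hidden state h_{t-1}; that is what the forward loop below does.)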
The code is below:
import math

import torch
from torch.nn import Parameter
from torch.nn import functional as F
from torch.nn.modules.rnn import RNNBase, LSTMCell

class mLSTM(RNNBase):
    def __init__(self, input_size, hidden_size, bias=True):
        super(mLSTM, self).__init__(
            mode='LSTM', input_size=input_size, hidden_size=hidden_size,
            num_layers=1, bias=bias, batch_first=True,
            dropout=0, bidirectional=False)

        # weights/biases for the multiplicative intermediate state m_t
        w_im = torch.Tensor(hidden_size, input_size)
        w_hm = torch.Tensor(hidden_size, hidden_size)
        b_im = torch.Tensor(hidden_size)
        b_hm = torch.Tensor(hidden_size)
        self.w_im = Parameter(w_im)
        self.b_im = Parameter(b_im)
        self.w_hm = Parameter(w_hm)
        self.b_hm = Parameter(b_hm)

        self.lstm_cell = LSTMCell(input_size, hidden_size, bias)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.hidden_size)
        for weight in self.parameters():
            weight.data.uniform_(-stdv, stdv)

    def forward(self, input, hx):
        # input: (batch, seq, feature) because batch_first=True
        n_batch, n_seq, n_feat = input.size()

        hx, cx = hx
        steps = [cx.unsqueeze(1)]
        for seq in range(n_seq):
            # m_t = (W_im x_t + b_im) * (W_hm h_{t-1} + b_hm), elementwise
            mx = F.linear(input[:, seq, :], self.w_im, self.b_im) * F.linear(hx, self.w_hm, self.b_hm)
            # the LSTM cell sees m_t in place of the previous hidden state
            hx = (mx, cx)
            hx, cx = self.lstm_cell(input[:, seq, :], hx)
            steps.append(cx.unsqueeze(1))

        # (batch, seq + 1, hidden): all cell states, including the initial one
        return torch.cat(steps, dim=1)
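For reference, this is the small shape check I run it with (sizes are arbitrary; note the output has seq_len + 1 steps because the initial cell state is included):

batch, seq_len, input_size, hidden_size = 4, 7, 10, 20
rnn = mLSTM(input_size, hidden_size)
x = torch.randn(batch, seq_len, input_size)
h0 = torch.zeros(batch, hidden_size)
c0 = torch.zeros(batch, hidden_size)
out = rnn(x, (h0, c0))
print(out.shape)  # torch.Size([4, 8, 20])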
I checked the code of the current LSTM/GRU/RNNBase in https://pytorch.org/docs/master/_modules/torch/nn/modules/rnn.html, but I don’t know how I could easily replace _impl = _rnn_impls[self.mode], or whether there is a way to handle PackedSequence directly in the code of mLSTM.
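So far the only workaround I can think of is to unpack at the top of forward and repack at the end. A rough, untested sketch of that forward (pad_packed_sequence / pack_padded_sequence are from torch.nn.utils.rnn; this assumes pad_packed_sequence returns the lengths as a tensor, and it pads everything, so it gains nothing from the packing itself):

from torch.nn.utils.rnn import PackedSequence, pad_packed_sequence, pack_padded_sequence

def forward(self, input, hx):
    is_packed = isinstance(input, PackedSequence)
    if is_packed:
        # pad back to a dense (batch, seq, feature) tensor, keeping the lengths
        input, lengths = pad_packed_sequence(input, batch_first=True)

    n_batch, n_seq, n_feat = input.size()
    hx, cx = hx
    steps = [cx.unsqueeze(1)]
    for seq in range(n_seq):
        mx = F.linear(input[:, seq, :], self.w_im, self.b_im) * F.linear(hx, self.w_hm, self.b_hm)
        hx, cx = self.lstm_cell(input[:, seq, :], (mx, cx))
        steps.append(cx.unsqueeze(1))

    output = torch.cat(steps, dim=1)
    if is_packed:
        # the output has one extra step (the initial cell state),
        # hence lengths + 1 before repacking
        output = pack_padded_sequence(output, lengths + 1, batch_first=True)
    return output

But this still wastes compute on the padded positions, so I would like to know whether there is a way to iterate over the packed data directly, like the built-in implementations do.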
Thank you for your help!