I am trying to implement a Multi-Dimensional Recurrent Neural Network [1]. In particular, I am trying to implement a Multi-Dimensional LSTM layer to do multi-dimensional sequence-to-sequence operations.
Multi-dimensional sequences:
Consider how an LSTM takes a sequence as input and outputs a sequence by running an LSTM cell over its 1d input. A multi-dimensional LSTM does the same for inputs of arbitrary dimension. For instance, an image can be seen as a 2d sequence: the LSTM cell starts at a corner and iterates over both dimensions, so each cell has 2 successors instead of 1 for standard sequences.
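The scan described above can be sketched in plain Python. In row-major order, every position is visited after its predecessor along each axis, which is what the recurrence needs. The helper names scan_order and predecessors are illustrative and do not come from the original code.

```python
from itertools import product


def scan_order(shape):
    """Yield every position of an n-d grid in row-major (lexicographic) order.

    This order visits all predecessors of a position (one step back along
    each axis) before the position itself.
    """
    return product(*(range(n) for n in shape))


def predecessors(idx):
    """For each axis k, return the position one step back along k, or None at the border."""
    preds = []
    for k, i in enumerate(idx):
        if i == 0:
            preds.append(None)  # border: the state from this direction defaults to 0
        else:
            preds.append(idx[:k] + (i - 1,) + idx[k + 1:])
    return preds


print(predecessors((1, 1)))  # [(0, 1), (1, 0)]
```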
What I implemented:
I implemented 2 classes, MDLSTMCell(torch.nn.Module) and MDLSTM(torch.nn.Module). Here is the forward() function for the first:
```python
import torch

def forward(self, x, old):
    s_0, h_0 = old
    # Note on input dimensions:
    # - x is of size (batch_size, self.size_in). It is the value of the
    #   input sequence at a given position.
    # - s_0 and h_0 are of size (self.dim_in, batch_size, self.size_out). They
    #   are the values of the cell state and hidden state at the "previous"
    #   positions, previous along every dimension (default should be 0).

    # 1/ Forget, input, output and cell activation gates.
    # There is one forget gate per dimension.
    f = [torch.sigmoid(self.biasf[l] + torch.mm(x, self.wf[l])
                       + sum(torch.mul(h_0[k], self.uf[l][k]) for k in range(self.dim_in)))
         for l in range(self.dim_in)]
    i = torch.sigmoid(self.biasi + torch.mm(x, self.wi)
                      + sum(torch.mul(h_0[k], self.ui[k]) for k in range(self.dim_in)))
    o = torch.sigmoid(self.biaso + torch.mm(x, self.wo)
                      + sum(torch.mul(h_0[k], self.uo[k]) for k in range(self.dim_in)))
    c = torch.sigmoid(self.biasc + torch.mm(x, self.wc)
                      + sum(torch.mul(h_0[k], self.uc[k]) for k in range(self.dim_in)))
    # 2/ Cell state
    s = torch.mul(i, c) + sum(torch.mul(f[k], s_0[k]) for k in range(self.dim_in))
    # 3/ Final output
    h = torch.mul(o, torch.tanh(s))
    return (s, h)
```
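The post does not show __init__, so the sketch below fills it in with assumed parameter shapes to make the cell runnable on its own. The shapes are inferred from how forward() indexes the weights. In particular, the recurrent weights u* must be per-unit vectors, because torch.mul multiplies elementwise. Everything in __init__, and the rec helper, is an assumption rather than the original code.

```python
import torch
from torch import nn


class MDLSTMCell(nn.Module):
    # Assumed shapes (not given in the post):
    #   wf: (dim_in, size_in, size_out)  uf: (dim_in, dim_in, size_out)  biasf: (dim_in, size_out)
    #   w*: (size_in, size_out)          u*: (dim_in, size_out)          bias*: (size_out,)
    def __init__(self, size_in, size_out, dim_in):
        super().__init__()
        self.size_in, self.size_out, self.dim_in = size_in, size_out, dim_in

        def param(*shape):
            return nn.Parameter(0.1 * torch.randn(*shape))

        # One forget gate per dimension l; uf[l][k] weighs the predecessor along dimension k.
        self.wf, self.uf, self.biasf = param(dim_in, size_in, size_out), param(dim_in, dim_in, size_out), param(dim_in, size_out)
        self.wi, self.ui, self.biasi = param(size_in, size_out), param(dim_in, size_out), param(size_out)
        self.wo, self.uo, self.biaso = param(size_in, size_out), param(dim_in, size_out), param(size_out)
        self.wc, self.uc, self.biasc = param(size_in, size_out), param(dim_in, size_out), param(size_out)

    def forward(self, x, old):
        s_0, h_0 = old  # each of shape (dim_in, batch_size, size_out)

        def rec(u):
            # Sum over predecessors k of h_0[k] * u[k] (elementwise, as in the post).
            return sum(h_0[k] * u[k] for k in range(self.dim_in))

        f = [torch.sigmoid(self.biasf[l] + x @ self.wf[l] + rec(self.uf[l]))
             for l in range(self.dim_in)]
        i = torch.sigmoid(self.biasi + x @ self.wi + rec(self.ui))
        o = torch.sigmoid(self.biaso + x @ self.wo + rec(self.uo))
        # Kept as sigmoid to match the post; standard LSTMs usually use tanh here.
        c = torch.sigmoid(self.biasc + x @ self.wc + rec(self.uc))
        s = i * c + sum(f[k] * s_0[k] for k in range(self.dim_in))
        h = o * torch.tanh(s)
        return s, h


cell = MDLSTMCell(size_in=4, size_out=5, dim_in=2)
x = torch.randn(3, 4)         # (batch_size, size_in)
zeros = torch.zeros(2, 3, 5)  # (dim_in, batch_size, size_out): no predecessors yet
s, h = cell(x, (zeros, zeros))
```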
And here is the forward() function for the second class, MDLSTM:
```python
import torch

def forward(self, x):
    # x is of size (d1, ..., dn, batch_size, self.size_in).
    # **Note on states:**
    # States are stored as "1d"-tensors (save for the batch size and output
    # dimension) and concatenated at the end.
    # Uses unravel_index() and ravel_multi_index() to go from 1d indexing
    # to multi-dimensional indexing with the prev_all() function.
    # (prev_all() and self.iter_idx() are defined elsewhere, not shown.)
    shape_idx = x.shape[:-2]
    batch_size = x.shape[-2]
    shape_t = (batch_size, self.size_out)
    s = []  # Cell state
    h = []  # Hidden state
    for idx in self.iter_idx(shape_idx):
        s_new, h_new = self.mdlstmcell(
            x[idx],
            (prev_all(s, idx, shape_idx, shape_t), prev_all(h, idx, shape_idx, shape_t)),
        )
        s.append(s_new.reshape((1, *shape_t)))
        h.append(h_new.reshape((1, *shape_t)))
    h = torch.cat(h)
    h = torch.reshape(h, (*shape_idx, *shape_t))
    return h
```
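MDLSTM.forward() relies on self.iter_idx() and prev_all(), which the post does not show. One possible version is sketched below. It assumes states are appended to the flat list in row-major order and that missing predecessors default to zero. Both helpers are hypothetical reconstructions, written as standalone functions.

```python
import itertools

import numpy as np
import torch


def iter_idx(shape_idx):
    # Hypothetical stand-in for MDLSTM.iter_idx(): row-major order, so every
    # predecessor has been computed (and appended) before it is needed.
    return itertools.product(*(range(n) for n in shape_idx))


def prev_all(states, idx, shape_idx, shape_t):
    """Hypothetical prev_all(): stack the state one step back along each dimension.

    `states` is the flat list filled in row-major order; each entry has shape
    (1, *shape_t). Returns a tensor of shape (len(shape_idx), *shape_t).
    """
    prev = []
    for k in range(len(shape_idx)):
        if idx[k] == 0:
            # Border: no predecessor along k (assumes float32 states on CPU).
            prev.append(torch.zeros((1, *shape_t)))
        else:
            p = list(idx)
            p[k] -= 1
            flat = int(np.ravel_multi_index(tuple(p), tuple(shape_idx)))
            prev.append(states[flat])
    return torch.cat(prev)


# Usage: six fake states on a (2, 3) grid, each holding its own flat index.
states = [torch.full((1, 1, 1), float(j)) for j in range(6)]
prev = prev_all(states, (1, 2), (2, 3), (1, 1))  # predecessors (0, 2) and (1, 1)
```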
What I would like to do:
This code works, but I don’t like it. Ideally, the cell state s and hidden state h would be pre-allocated with shape (d1, ..., dn, batch_size, self.size_out) and filled in a loop. However, I cannot do this. Filling them in a loop means doing in-place operations, and PyTorch's autograd raises an error when a tensor it saved for the backward pass has been modified in place. In this specific case it should cause no issue, because while filling these tensors I would never overwrite a value that is still needed.
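The problem can be reproduced in a few lines. Autograd detects in-place changes through a version counter that a tensor shares with all of its views, not per element. Writing buf[1] therefore invalidates the saved view of buf[0], even though row 0 is never overwritten. This is a minimal sketch; w and buf are illustrative names, not from the original code.

```python
import torch

w = torch.randn(3, requires_grad=True)
buf = torch.zeros(4, 3)      # pre-allocated "state" tensor, one row per position

buf[0] = torch.sigmoid(w)    # fill position 0 (in-place copy into buf)
prev = buf[0]                # read it back: a view sharing buf's version counter
buf[1] = prev * w            # mul saves `prev` for backward; then buf is written in place

raised = False
try:
    buf.sum().backward()
except RuntimeError as err:  # "... modified by an inplace operation ..."
    raised = True
    print("autograd refused:", err)
```

Because the check is per tensor rather than per element, "no needed value is ever overwritten" is not enough for it to pass.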
As you can see, my solution was to store them in Python lists, and at the end concatenate/reshape everything.
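The list-based workaround can be shown end to end on a toy 2d recurrence, used here as a hypothetical stand-in for the MDLSTM cell. This sketch has two variations from the original code: it keys states by position in a dict instead of a flat list, and it uses torch.stack instead of torch.cat followed by reshape. No tensor is modified in place, so backward() succeeds.

```python
import itertools

import torch


def scan_2d(x):
    """Toy recurrence h[i, j] = tanh(x[i, j] + 0.5 * (h[i-1, j] + h[i, j-1])).

    States are kept in a dict keyed by position (missing predecessors count as
    zero) and stacked once at the end, as in the list-based approach.
    """
    n, m = x.shape[:2]
    zero = torch.zeros_like(x[0, 0])
    h = {}
    for i, j in itertools.product(range(n), range(m)):
        up = h.get((i - 1, j), zero)
        left = h.get((i, j - 1), zero)
        h[(i, j)] = torch.tanh(x[i, j] + 0.5 * (up + left))
    # Dicts keep insertion (row-major) order, so stack + reshape restores the grid.
    return torch.stack(list(h.values())).reshape(x.shape)


x = torch.randn(3, 4, 2, 5, requires_grad=True)  # (d1, d2, batch_size, features)
out = scan_2d(x)
out.sum().backward()  # works: nothing was modified in place
```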
My questions:
- Does my approach have a significant impact on performance? I’m assuming it does: allocating a tensor once and filling it should be much faster than performing list operations.
- Can I tell PyTorch to locally ignore the fact that I am doing in-place operations? This would allow me to fill s and h as tensors.
- Does anyone have a better idea as to how to approach this problem?