Custom Multidimensional RNN Module - Several Questions

Greetings, (this got kind of lengthy, but I’m just looking for some hints to get me in the right direction with this project :slightly_smiling_face:)

I’m currently struggling a lot with my attempt to implement a multidimensional (for now 2D) RNN (similar to this paper by Alex Graves et al.). I want to use it for an image semantic segmentation task, so basically at every pixel I take that pixel’s input features as well as the hidden states of the neighboring pixels above and to the left (in the current unidirectional approach). From those I compute a new hidden state and an output that is passed through a softmax function to predict the class probabilities for that pixel. The current code is at the end of the post; passing batches works and produces correctly shaped output, but I’m stuck trying to train.
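
Written out, the per-pixel recurrence I’m implementing is (with out-of-range neighbors replaced by zero vectors):

```
h[i,j]    = relu(W_ih · x[i,j] + W_hh1 · h[i-1,j] + W_hh2 · h[i,j-1])
pred[i,j] = log_softmax(W_ho · h[i,j])
```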

I have several issues and am kind of unhappy about my general approach; I hope some of you can point me in the right direction. I’ll just list my questions:

  1. First of all, I’m currently getting the error: one of the variables needed for gradient computation has been modified by an inplace operation. I understand what the problem is in general, but I can’t find the actual operation the error refers to.

  2. While iterating over the two dimensions of my image I need to store the computed hidden states as well as aggregate the class predictions for all pixels. At the moment I just create tensors of zeros and assign my values to slices of them in the process. This seems suboptimal; is there a better, more efficient way to do this? Maybe this is also the cause of the error from question 1?

  3. I tried to implement an MDRNN layer myself because I dislike having to use linear layers for all of my weight matrices, and this is going to be a fuss if I want to scale up to more dimensions. I failed mainly because I did not find much information about this. The PyTorch tutorials only show custom layers using a linear layer as the example, which, together with the RNN source code, just wasn’t enough for me to get this done. Are there more resources on this? Would a custom layer be the correct way of approaching my goal?

Again, sorry for the long post. I don’t want it to seem like I want anyone to solve my project for me, but I’m stuck and need some help to get back on track.
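
For question 3, the direction I was trying to take before getting lost in the RNN source code was roughly this: a single cell holding raw `nn.Parameter` weight matrices, with the per-dimension recurrent matrices in an `nn.ParameterList` so the number of dimensions is just a constructor argument. All names here are my own and I’m not sure the initialization is sensible:

```python
import torch
import torch.nn as nn

class MDRNNCell(nn.Module):
    """Sketch of a custom MDRNN cell: raw nn.Parameter weights, one
    recurrent matrix per dimension, so scaling to N dimensions just
    means a longer ParameterList."""

    def __init__(self, input_size, hidden_size, n_dims=2):
        super().__init__()
        self.w_ih = nn.Parameter(torch.empty(hidden_size, input_size))
        self.w_hh = nn.ParameterList(
            [nn.Parameter(torch.empty(hidden_size, hidden_size)) for _ in range(n_dims)]
        )
        self.bias = nn.Parameter(torch.zeros(hidden_size))
        for w in [self.w_ih, *self.w_hh]:
            nn.init.xavier_uniform_(w)

    def forward(self, x, hiddens):
        # x: (batch, input_size); hiddens: list of n_dims (batch, hidden_size) tensors
        net = x @ self.w_ih.t() + self.bias
        for w, h in zip(self.w_hh, hiddens):
            net = net + h @ w.t()
        return torch.relu(net)  # New hidden state
```

Is something along these lines how custom recurrent layers are usually written, or is sticking with `nn.Linear` actually the recommended way?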

import torch
import torch.nn as nn
import torch.nn.functional as F

class MDRNN(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(MDRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.w_ih = nn.Linear(input_size, hidden_size) # Weights from input to hidden layer
        self.w_hh_1 = nn.Linear(hidden_size, hidden_size) # Recurrent weights for dimension 1 
        self.w_hh_2 = nn.Linear(hidden_size, hidden_size) # Recurrent weights for dimension 2
        self.w_ho = nn.Linear(hidden_size, output_size) # Weights from hidden to output layer
        self.softmax = nn.LogSoftmax(dim=1) # Log-softmax over the class dimension (pairs with NLLLoss)

    def step(self, x, h1_last, h2_last):
        x_t = self.w_ih(x) # Compute incoming values from input
        h1 = self.w_hh_1(h1_last) # Compute recurrent values from dimension 1
        h2 = self.w_hh_2(h2_last) # Compute recurrent values from dimension 2
        net_t = x_t + h1 + h2 # Sum up inputs
        h_t = F.relu(net_t) # Activation function to get new hidden state
        out = self.w_ho(h_t) # Compute output from new hidden state
        pred = self.softmax(out) # Predict classes with softmax
        return pred, h_t

    def forward(self, batch):
        batch_size, T1, T2, _ = batch.shape
        batch = batch.permute(1,2,0,3) # Permute to dim1 x dim2 x batch x input
        context_buffer = torch.zeros(T1, T2, batch_size, self.hidden_size) # Storage for hidden states
        segmentation = torch.zeros(T1, T2, batch_size, self.output_size) # Storage for prediction tensor
        for t1 in range(T1):
            for t2 in range(T2):
                h1_last = torch.zeros(batch_size, self.hidden_size) if t1 == 0 else context_buffer[t1 - 1, t2] # Previous hidden state in dim 1 or 0 at first step
                h2_last = torch.zeros(batch_size, self.hidden_size) if t2 == 0 else context_buffer[t1, t2 - 1] # Previous hidden state in dim 2 or 0 at first step
                out, h_t = self.step(batch[t1,t2], h1_last, h2_last) # Compute prediction and new hidden state
                context_buffer[t1, t2] = h_t # Store hidden state
                segmentation[t1, t2] = out # Store prediction
        return segmentation.permute(2,3,0,1) # Permute to batch x output x dim1 x dim2 (NLLLoss wants this)