How to calculate gradients correctly (without in-place operations) for a custom unpooling layer?

Hi all, I’m trying to implement a custom max unpooling layer. Unlike torch.nn.MaxUnpool2d, it uses all the max indices within a pooling window when the max value is repeated. Here is an example:

pooled_inp = torch.tensor([[[[2., 4.],
                             [6., 8.]]]], requires_grad=True)

inds = [[[[torch.tensor(3), torch.tensor([0, 1])],
          [torch.tensor([1, 2, 3]), torch.tensor(0)]]]]

up = MultiMaxUnpool2d(kernel_size=2, stride=2)
unpooled = up(pooled_inp, inds)
print('unpooled:\n', unpooled)

unpooled:
 tensor([[[[0., 0., 4., 4.],
          [0., 2., 0., 0.],
          [0., 6., 8., 0.],
          [6., 6., 0., 0.]]]], grad_fn=<CopySlices>)

Unpooling works as intended, but the gradients don’t look right to me. Here are the gradients computed for the above graph:

unpooled.retain_grad()
unpooled.sum().backward()
print('unpooled.grad:\n', unpooled.grad)
print('pooled_inp.grad:\n', pooled_inp.grad)

unpooled.grad:
 tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])

pooled_inp.grad:
 tensor([[[[1., 2.],
          [3., 1.]]]])

Here is what I expected:

pooled_inp.grad:
 tensor([[[[1., 1.],
          [1., 1.]]]])
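For comparison, the built-in torch.nn.MaxUnpool2d takes exactly one flat index per pooled value (an index into the flattened output plane), so each value lands in a single position and its gradient is 1. The minimal sketch below shows that; the indices 5, 2, 9 and 10 are picked by hand (an assumption for illustration) to match one max position of each window in the example above.

```python
import torch
import torch.nn as nn

pooled = torch.tensor([[[[2., 4.],
                         [6., 8.]]]], requires_grad=True)
# One flat index per pooled value into the 4x4 output plane (row * 4 + col):
# 5 -> (1, 1), 2 -> (0, 2), 9 -> (2, 1), 10 -> (2, 2)
indices = torch.tensor([[[[5, 2],
                          [9, 10]]]])

unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)
out = unpool(pooled, indices)
out.sum().backward()
print(out)
print(pooled.grad)  # each value is written once, so every gradient is 1
```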

I don’t understand why the gradients are calculated like that, but I’ve read that in-place operations should be avoided in PyTorch, so that might be the reason. What would be the proper way to implement this without in-place operations? Here is my custom unpooling implementation:

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair

class MultiMaxUnpool2d(nn.Module):

    def __init__(self, kernel_size, stride):
        super().__init__()
        # _pair turns an int into a 2-tuple and leaves tuples unchanged
        self.k = _pair(kernel_size)
        self.stride = _pair(stride)

    def _calc_size(self, x_size):
        """Returns the unpooled output size (h_out, w_out)."""
        h_in, w_in = x_size[2:]
        h_out = (h_in - 1) * self.stride[0] + self.k[0]
        w_out = (w_in - 1) * self.stride[1] + self.k[1]
        return h_out, w_out

    def _to_2d(self, ind, tl):
        """Converts a 1d index inside a pooling window into a 2d (row, col)
        index in the unpooled output, given the window's top-left corner tl."""
        _h, _w = tl
        row, col = divmod(int(ind), self.k[1])
        return (row + _h, col + _w)

    def _local_to_global(self, local_inds, i, j):
        """Takes a tensor of local indices in a pooling window and that window's
        position (i, j) in the pooled input. Returns the list of global indices
        in the unpooled output."""
        tl = (i * self.stride[0], j * self.stride[1])   # Top-left of the pooling window in the unpooled output
        if local_inds.dim() == 0:                       # torch.tensor(n) format
            return [self._to_2d(local_inds.item(), tl)]
        else:                                           # torch.tensor([...]) format
            return [self._to_2d(ind, tl) for ind in local_inds]

    def forward(self, x, inds):
        batch_size, num_ch, h_in, w_in = x.size()[:4]
        h_out, w_out = self._calc_size(x.size())
        # Allocate the output with the same dtype and device as the input
        unpooled = torch.zeros(batch_size, num_ch, h_out, w_out,
                               dtype=x.dtype, device=x.device)
        for b in range(batch_size):
            for c in range(num_ch):
                for i in range(h_in):
                    for j in range(w_in):
                        max_val = x[b, c, i, j]
                        local_inds = inds[b][c][i][j]
                        glob_inds = self._local_to_global(local_inds, i, j)   # Local indices converted to global indices
                        for h_g, w_g in glob_inds:
                            unpooled[b, c, h_g, w_g] = max_val  # Indexed (in-place) assignment
        return unpooled

Hi,

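That gradient actually looks correct to me. Each input value is copied to one or more positions of the output, so its gradient is the sum of the gradients flowing back from each of those copies. With unpooled.sum(), every copy receives a gradient of 1, and 2., 4., 6. and 8. are copied 1, 2, 3 and 1 times respectively, which gives [[1., 2.], [3., 1.]].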
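The same accumulation can be reproduced with plain advanced indexing. This is a minimal stand-in for the custom layer (not the original code): when the same source element is read several times, autograd sums the incoming gradients, so each entry of pooled.grad counts how often that value was copied.

```python
import torch

pooled = torch.tensor([2., 4., 6., 8.], requires_grad=True)
# Source index for each output element: 2. is copied once, 4. twice,
# 6. three times and 8. once, mirroring the unpooling example in the question
src = torch.tensor([0, 1, 1, 2, 2, 2, 3])
out = pooled[src]          # out-of-place gather
out.sum().backward()
print(pooled.grad)         # tensor([1., 2., 3., 1.])
```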

@albanD In that case, is it correct to use for loops to index into the tensor as in the forward function?

I wrote it before learning that in-place operations should be avoided in PyTorch. This reply states that assigning a value to an index is an in-place operation, so I suspected my implementation was wrong. On the other hand, the forward and backward passes work as expected, so it might be correct after all, and I’m confused. Any elaboration would be helpful.

You don’t have to avoid them. Autograd just doesn’t support every combination of in-place operations, and it raises an error if you hit an unsupported case (for example, when an in-place operation overwrites a value that is needed to compute the backward pass).
So if your code runs without error, autograd handles this case correctly.
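A small sketch of both situations, using illustrative tensors that are not from the thread: writing into a fresh torch.zeros buffer by index (as MultiMaxUnpool2d.forward does) is recorded by autograd as CopySlices and backpropagates fine, while overwriting a tensor that a backward formula saved (here the output of sigmoid) makes backward() raise a RuntimeError.

```python
import torch

# 1) Indexed assignment into a zeros buffer, as in MultiMaxUnpool2d.forward
v = torch.tensor(3., requires_grad=True)
buf = torch.zeros(3)
buf[0] = v                 # in-place write, recorded as CopySlices
buf[2] = v                 # v is used twice
buf.sum().backward()
print(type(buf.grad_fn).__name__, v.grad)  # CopySlices tensor(2.)

# 2) Overwriting a tensor that autograd saved for the backward pass
a = torch.tensor([1., 2., 3.], requires_grad=True)
b = a.sigmoid()            # sigmoid's backward reuses its output b
b.mul_(2)                  # modifies b in place after it was saved
try:
    b.sum().backward()
    error = None
except RuntimeError as e:
    error = str(e)
print(error)
```

The first case works because nothing that autograd saved is overwritten: the zeros buffer's original contents are never needed for the gradient.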

The only concern I would have with such an implementation is the slowdown from the nested Python loops, but that is unrelated to gradient correctness.
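If the loops ever become a bottleneck, one possible loop-free and out-of-place variant is sketched below. It is not a drop-in replacement: it assumes kernel_size == stride (non-overlapping windows, input height and width divisible by k) and represents the max positions as a boolean mask of the unpooled shape instead of the nested index lists. The functions multi_max_unpool2d and tie_mask are hypothetical names introduced here.

```python
import torch
import torch.nn.functional as F

def tie_mask(inp, k):
    """Hypothetical helper: marks every position that equals the max of its
    k x k window (ties included), assuming stride == kernel_size == k."""
    pooled = F.max_pool2d(inp, kernel_size=k, stride=k)
    up = pooled.repeat_interleave(k, dim=2).repeat_interleave(k, dim=3)
    return inp == up

def multi_max_unpool2d(pooled, mask, k):
    """Hypothetical loop-free unpooling: copies each pooled value to every
    masked position of its window without in-place ops on grad tensors."""
    # Broadcast each pooled value over its k x k window
    up = pooled.repeat_interleave(k, dim=2).repeat_interleave(k, dim=3)
    # Zero everything except the max positions (out-of-place)
    return up.masked_fill(~mask, 0.)

# The example from the question, with the max positions written as a mask
pooled_inp = torch.tensor([[[[2., 4.],
                             [6., 8.]]]], requires_grad=True)
mask = torch.zeros(1, 1, 4, 4, dtype=torch.bool)
rows = torch.tensor([1, 0, 0, 2, 3, 3, 2])
cols = torch.tensor([1, 2, 3, 1, 0, 1, 2])
mask[0, 0, rows, cols] = True   # the mask itself needs no gradient

unpooled = multi_max_unpool2d(pooled_inp, mask, 2)
unpooled.sum().backward()
print(unpooled)
print(pooled_inp.grad)  # same counts as the loop version: 1, 2, 3, 1
```

When the original (pre-pooling) input is available, tie_mask(inp, k) builds such a mask directly; note that exact float equality is what defines a tie there.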

Thanks a lot for the clarification; I didn’t know that! The slowdown is not a concern for now, but I’ll look into it in the future.

All the best.