Hi all, I’m trying to implement a custom max unpooling layer. The difference from torch.nn.MaxUnpool2d
is that all the max indices within the pooling window are used for unpooling in case of a repeated max value. Here is an example,
pooled_inp = torch.tensor([[[[2., 4],
[6, 8]]]], requires_grad=True)
inds = [[[[torch.tensor(3), torch.tensor([0, 1])],
[torch.tensor([1, 2, 3]), torch.tensor(0)]]]]
up = MultiMaxUnpool2d(kernel_size=2, stride=2)
unpooled = up(pooled_inp, inds)
print('unpooled:\n', unpooled)
tensor([[[[0., 0., 4., 4.],
[0., 2., 0., 0.],
[0., 6., 8., 0.],
[6., 6., 0., 0.]]]], grad_fn=<CopySlices>)
Unpooling works as I expected but the gradients are not calculated correctly. Here are the gradients computed for the above graph,
print('unpooled.grad:\n', unpooled.grad)
print('pooled_inp.grad:\n', pooled_inp.grad)
tensor([[[[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]]]])
tensor([[[[1., 2.],
[3., 1.]]]])
Here is what I expected,
tensor([[[[1., 1.],
[1., 1.]]]])
I don’t understand why the gradients are calculated like that but I’ve learned that the in-place operations should be avoided in Pytorch, so that might be the reason for it. What would be the proper way of implementation without performing in-place operations ? Here is my custom unpooling implementation,
import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair
class MultiMaxUnpool2d(nn.Module):
def __init__(self, kernel_size, stride):
super(MultiMaxUnpool2d, self).__init__()
self.k = kernel_size if type(kernel_size) is tuple else _pair(kernel_size)
self.stride = stride if type(stride) is tuple else _pair(stride)
""" Returns unpooled output size """
def _calc_size(self, xSize):
h_in, w_in = xSize[2:]
h_out = (h_in - 1) * self.stride[0] + self.k[0]
w_out = (w_in - 1) * self.stride[1] + self.k[1]
return h_out, w_out
""" Takes an index in 1d, converts it into 2d. Returns its tuple. """
def _to_2d(self, ind, tl):
_h, _w = tl
return (int(ind / self.k[1]) + _h, int(ind % self.k[1]) + _w)
""" Takes a tensor of local indices in a pooling window and that window's
indices in the pooled input. Returns the list of global indices for
unpooled output."""
def _local_to_global(self, local_inds, i, j):
tl = (i * self.stride[0], j * self.stride[1]) # Calculate tl of pooling window in unpooled output.
if local_inds.dim() == 0: # torch.tensor() format
return [self._to_2d(local_inds.item(), tl)]
else: # torch.tensor([...]) format
return [self._to_2d(ind, tl) for ind in local_inds]
def forward(self, x, inds):
batch_size, num_ch, h_in, w_in = x.size()[:4]
h_out, w_out = self._calc_size(x.size())
unpooled = torch.zeros(batch_size, num_ch, h_out, w_out)
for b in range(0, batch_size):
for c in range(0, num_ch):
for i in range(0, h_in):
for j in range(0, w_in):
max_val = x[b][c][i][j]
local_inds = inds[b][c][i][j]
glob_inds = self._local_to_global(local_inds, i, j) # List of local indices converted to global indices
for g in glob_inds:
h_g, w_g = g
unpooled[b][c][h_g][w_g] = max_val
return unpooled