Thank you for bearing with me; this post will be a bit long. The values returned by max don’t need to be variable in shape, since each window has a single (possibly repeated) maximum value, but the returned indices will be variable in shape. There is an example in the MultiMax part below.
I’ve implemented the custom max (named MultiMax) as a custom autograd function, and the max pooling layer (named MultiMaxPool2d) using it. As I asked in this thread, I don’t know how to implement the dim parameter of torch.max for my MultiMax, so MultiMax takes a 1-d tensor and is called inside the nested for loops of the MultiMaxPool2d layer. I’ve just learned that this approach might be problematic, since in-place operations should be avoided in PyTorch.
My question is: what is the proper approach, instead of using in-place operations, for creating the pooled output? I obviously need to select max values from the input and assign them to their new positions in the pooled output. I’ve added the source code and examples below.
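To make the question concrete, here is a minimal 1-d sketch of the kind of non-in-place alternative I am asking about — collecting per-window results in a Python list and building the output with torch.stack (using plain torch.max for illustration, not my MultiMax):

```python
import torch

# 1-d pooling with window size 2, stride 2: collect each window's max
# in a list and stack, instead of assigning into a preallocated tensor.
inp = torch.tensor([6., 2, 5, 6], requires_grad=True)
windows = inp.unfold(0, 2, 2)        # two windows: [6, 2] and [5, 6]
vals = [w.max() for w in windows]    # one scalar per window
pooled = torch.stack(vals)           # differentiable, no in-place writes
pooled.sum().backward()
print(pooled)    # values: [6., 6.]
print(inp.grad)  # values: [1., 0., 0., 1.]
```

Since torch.stack records the operation in the autograd graph, no gradient information is lost the way it can be with slice assignment into a preallocated tensor.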
Here is my custom max operation MultiMax,
import torch
class MultiMax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        mmax = torch.max(input)
        inds = torch.nonzero(mmax == input).squeeze()  # Get max inds
        ctx.save_for_backward(input, mmax, inds)
        ctx.mark_non_differentiable(inds)
        return mmax, inds

    @staticmethod
    def backward(ctx, grad_output, for_inds):
        input, mmax, inds = ctx.saved_tensors
        inds_shape = 1 if inds.dim() == 0 else inds.shape[0]
        grad_input = torch.ones(input.shape[0]) * grad_output
        grad_input /= inds_shape  # Grad is shared among max values.
        grad_input[input < mmax] = 0
        return grad_input, None
An example using MultiMax,
inp = torch.tensor([6., 2, 5, 6], requires_grad=True)
mmax = MultiMax.apply
pooled, inds = mmax(inp)
pooled.sum().backward()
print('pooled:\n', pooled)
print('inds:\n', inds)
print('inp grads:\n', inp.grad)
pooled:
tensor(6., grad_fn=<MultiMaxBackward>)
inds:
tensor([0, 3])
inp grads:
tensor([0.5000, 0.0000, 0.0000, 0.5000])
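Note that inds above has two entries only because of the tie; with a unique maximum, the squeezed index tensor is zero-dimensional. This is why the index shapes vary and cannot be preallocated. A quick illustration with plain tensor ops (the same nonzero/squeeze pattern MultiMax uses):

```python
import torch

# The number of tied maxima, and hence the index tensor's shape,
# depends on the input values.
a = torch.tensor([6., 2, 5, 6])   # tied maxima at positions 0 and 3
b = torch.tensor([1., 7, 3, 2])   # unique maximum at position 1
ia = torch.nonzero(a == a.max()).squeeze()
ib = torch.nonzero(b == b.max()).squeeze()
print(ia, ia.shape)  # tensor([0, 3]), shape (2,)
print(ib, ib.shape)  # tensor(1), shape () -- zero-dimensional
```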
Here is the implementation of custom max pooling MultiMaxPool2d using MultiMax (adapted from here),
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _quadruple
from MultiMax import MultiMax
class MultiMaxPool2d(nn.Module):
    def __init__(self, kernel_size, stride, padding=0, same=False):
        super(MultiMaxPool2d, self).__init__()
        self.k = kernel_size if type(kernel_size) is tuple else _pair(kernel_size)
        self.stride = stride if type(stride) is tuple else _pair(stride)
        self.padding = padding if type(padding) is tuple else _quadruple(padding)
        self.same = same

    def _init_pool_inds(self, s):
        pool = torch.zeros(s[0], s[1], s[2], s[3])
        inds = []
        for x in range(0, s[0]):
            inds.append([])
            for y in range(0, s[1]):
                inds[x].append([])
                for z in range(0, s[2]):
                    inds[x][y].append([])
                    for t in range(0, s[3]):
                        inds[x][y][z].append(0)
        return pool, inds

    def _padding(self, x):
        if self.same:
            ih, iw = x.size()[2:]
            if ih % self.stride[0] == 0:
                ph = max(self.k[0] - self.stride[0], 0)
            else:
                ph = max(self.k[0] - (ih % self.stride[0]), 0)
            if iw % self.stride[1] == 0:
                pw = max(self.k[1] - self.stride[1], 0)
            else:
                pw = max(self.k[1] - (iw % self.stride[1]), 0)
            pl = pw // 2
            pr = pw - pl
            pt = ph // 2
            pb = ph - pt
            padding = (pl, pr, pt, pb)
        else:
            padding = self.padding
        return padding

    def forward(self, x):
        x = F.pad(x, self._padding(x), mode='reflect')
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        x = x.contiguous().view(x.size()[:4] + (-1,))
        mmax = MultiMax.apply
        s = x.shape[:4]
        pool, inds = self._init_pool_inds(s)
        for i in range(0, s[0]):
            for j in range(0, s[1]):
                for k in range(0, s[2]):
                    for l in range(0, s[3]):
                        _max, _is = mmax(x[i][j][k][l])
                        pool[i][j][k][l] = _max
                        inds[i][j][k][l] = _is
        return pool, inds
An example using MultiMaxPool2d,
inp = torch.tensor([[[[ 6., 2, 3, 9],
[ 5, 6, 7, 8],
[ 9, 10, 11, 16],
[14, 14, 16, 16]]]], requires_grad=True)
mp = MultiMaxPool2d(kernel_size=2, stride=2)
pooled, inds = mp(inp)
pooled.sum().backward()
print('pooled:\n', pooled)
print('inds:\n', inds)
print('inp grads:\n', inp.grad)
pooled:
tensor([[[[ 6., 9.],
[14., 16.]]]], grad_fn=<CopySlices>)
inds:
[[[[tensor([0, 3]), tensor(1)], [tensor([2, 3]), tensor([1, 2, 3])]]]]
inp grads:
tensor([[[[0.5000, 0.0000, 0.0000, 1.0000],
[0.0000, 0.5000, 0.0000, 0.0000],
[0.0000, 0.0000, 0.0000, 0.3333],
[0.5000, 0.5000, 0.3333, 0.3333]]]])
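For comparison, here is a sketch of what I think a non-in-place version of the forward loop would look like: append each window's max to a flat list, then stack and reshape at the end. It uses plain torch.max instead of MultiMax.apply so it runs standalone; MultiMax would slot into the same place, collecting the index tensors in a nested list as before.

```python
import torch

# Same unfold-based windowing as in MultiMaxPool2d.forward, but the
# pooled output is built by stacking instead of slice assignment.
x = torch.tensor([[[[ 6., 2, 3, 9],
                    [ 5, 6, 7, 8],
                    [ 9, 10, 11, 16],
                    [14, 14, 16, 16]]]], requires_grad=True)
k, stride = 2, 2
u = x.unfold(2, k, stride).unfold(3, k, stride)
u = u.contiguous().view(u.size()[:4] + (-1,))  # (N, C, H', W', k*k)
s = u.shape[:4]
vals = []
for i in range(s[0]):
    for j in range(s[1]):
        for h in range(s[2]):
            for w in range(s[3]):
                vals.append(u[i, j, h, w].max())  # MultiMax.apply would go here
pooled = torch.stack(vals).view(s)  # differentiable, no in-place writes
pooled.sum().backward()
print(pooled)  # values: [[[[6., 9.], [14., 16.]]]]
```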
The gradients are calculated as I expected in both of the operations above. However, I also implemented an unpooling layer, which I believe suffers from the same in-place problem, and there the gradients are not calculated as I expected.
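For the unpooling direction, one non-in-place pattern I am considering is the out-of-place torch.scatter, sketched here in 1-d with made-up values and indices (not my actual unpooling layer):

```python
import torch

# Non-in-place 1-d "unpooling": scatter pooled values back to their
# saved indices with out-of-place scatter, so autograd tracks the op.
vals = torch.tensor([6., 6.], requires_grad=True)  # pooled values
inds = torch.tensor([0, 3])                        # their original positions
unpooled = torch.zeros(4).scatter(0, inds, vals)   # out-of-place, differentiable
unpooled.sum().backward()
print(unpooled)   # values: [6., 0., 0., 6.]
print(vals.grad)  # values: [1., 1.]
```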