Thank you for bearing with me; this post will be a bit long. The values returned by max don’t need to be variable in shape, since they hold the max value and there is only one (possibly repeated) maximum within a window, but the returned indices will be variable in shape. A fuller example is in the MultiMax part below.
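Just to illustrate the shape issue up front, here is a tiny standalone snippet (separate from my actual code) showing how the index tensor from torch.nonzero changes shape with the number of tied maxima:

import torch

w1 = torch.tensor([6., 2, 5, 6])   # two tied maxima
w2 = torch.tensor([1., 2, 5, 6])   # a single maximum
print(torch.nonzero(w1 == w1.max()).squeeze())  # tensor([0, 3])  -> shape (2,)
print(torch.nonzero(w2 == w2.max()).squeeze())  # tensor(3)       -> 0-d tensor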
I’ve implemented the custom max (named MultiMax) as a custom autograd function, and the max pooling layer (named MultiMaxPool2d) that uses it. As I asked in this thread, I don’t know how to implement the dim parameter of torch.max for my MultiMax, so MultiMax takes a 1d tensor and is called inside the for loop of the MultiMaxPool2d layer. I’ve just learned that this approach might be problematic, since in-place operations should be avoided in PyTorch.
My question is: what would be the proper approach for creating the pooled output instead of using in-place operations? I obviously need to select the max values from the input and assign them to their new positions in the pooled output. I’ve added the source code and examples below, plus a rough sketch of one alternative I’ve been considering at the very end of the post.
Here is my custom max operation MultiMax:
import torch

class MultiMax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        mmax = torch.max(input)
        inds = torch.nonzero(mmax == input).squeeze()  # Get max inds
        ctx.save_for_backward(input, mmax, inds)
        ctx.mark_non_differentiable(inds)
        return mmax, inds

    @staticmethod
    def backward(ctx, grad_output, for_inds):
        input, mmax, inds = ctx.saved_tensors
        inds_shape = 1 if inds.dim() == 0 else inds.shape[0]
        grad_input = torch.ones(input.shape[0]) * grad_output
        grad_input /= inds_shape  # Grad is shared among max values.
        grad_input[input < mmax] = 0
        return grad_input, None
An example using MultiMax:
inp = torch.tensor([6., 2, 5, 6], requires_grad=True)
mmax = MultiMax.apply
pooled, inds = mmax(inp)
pooled.sum().backward()
print('pooled:\n', pooled)
print('inds:\n', inds)
print('inp grads:\n', inp.grad)
pooled:
tensor(6., grad_fn=<MultiMaxBackward>)
inds:
tensor([0, 3])
inp grads:
tensor([0.5000, 0.0000, 0.0000, 0.5000])
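For comparison, this is why I can’t just use the dim argument of torch.max: it only gives back one index per reduced row, so the tied positions are lost (a small standalone illustration):

import torch

x = torch.tensor([[6., 2, 5, 6]])   # one row with a tied maximum at positions 0 and 3
vals, idx = torch.max(x, dim=1)
print(vals)  # tensor([6.])
print(idx)   # a single index per row, so only one of the two tied positions is reported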
Here is the implementation of the custom max pooling MultiMaxPool2d, which uses MultiMax (adapted from here):
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _quadruple

from MultiMax import MultiMax

class MultiMaxPool2d(nn.Module):
    def __init__(self, kernel_size, stride, padding=0, same=False):
        super(MultiMaxPool2d, self).__init__()
        self.k = kernel_size if type(kernel_size) is tuple else _pair(kernel_size)
        self.stride = stride if type(stride) is tuple else _pair(stride)
        self.padding = padding if type(padding) is tuple else _quadruple(padding)
        self.same = same

    def _init_pool_inds(self, s):
        pool = torch.zeros(s[0], s[1], s[2], s[3])
        inds = []
        for x in range(0, s[0]):
            inds.append([])
            for y in range(0, s[1]):
                inds[x].append([])
                for z in range(0, s[2]):
                    inds[x][y].append([])
                    for t in range(0, s[3]):
                        inds[x][y][z].append(0)
        return pool, inds

    def _padding(self, x):
        if self.same:
            ih, iw = x.size()[2:]
            if ih % self.stride[0] == 0:
                ph = max(self.k[0] - self.stride[0], 0)
            else:
                ph = max(self.k[0] - (ih % self.stride[0]), 0)
            if iw % self.stride[1] == 0:
                pw = max(self.k[1] - self.stride[1], 0)
            else:
                pw = max(self.k[1] - (iw % self.stride[1]), 0)
            pl = pw // 2
            pr = pw - pl
            pt = ph // 2
            pb = ph - pt
            padding = (pl, pr, pt, pb)
        else:
            padding = self.padding
        return padding

    def forward(self, x):
        x = F.pad(x, self._padding(x), mode='reflect')
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        x = x.contiguous().view(x.size()[:4] + (-1,))
        mmax = MultiMax.apply
        s = x.shape[:4]
        pool, inds = self._init_pool_inds(s)
        for i in range(0, s[0]):
            for j in range(0, s[1]):
                for k in range(0, s[2]):
                    for l in range(0, s[3]):
                        _max, _is = mmax(x[i][j][k][l])
                        pool[i][j][k][l] = _max
                        inds[i][j][k][l] = _is
        return pool, inds
An example using MultiMaxPool2d:
inp = torch.tensor([[[[ 6.,  2,  3,  9],
                      [ 5,  6,  7,  8],
                      [ 9, 10, 11, 16],
                      [14, 14, 16, 16]]]], requires_grad=True)
mp = MultiMaxPool2d(kernel_size=2, stride=2)
pooled, inds = mp(inp)
pooled.sum().backward()
print('pooled:\n', pooled)
print('inds:\n', inds)
print('inp grads:\n', inp.grad)
pooled:
tensor([[[[ 6.,  9.],
          [14., 16.]]]], grad_fn=<CopySlices>)
inds:
[[[[tensor([0, 3]), tensor(1)], [tensor([2, 3]), tensor([1, 2, 3])]]]]
inp grads:
tensor([[[[0.5000, 0.0000, 0.0000, 1.0000],
          [0.0000, 0.5000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.3333],
          [0.5000, 0.5000, 0.3333, 0.3333]]]])
The gradients are calculated as I expected for both operations. However, I also implemented an unpooling layer which I believe suffers from the same in-place operation problem, and there the gradients are not calculated as I expected.
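To make the question about avoiding in-place writes more concrete, here is the kind of rewrite of MultiMaxPool2d.forward I had in mind: collect the per-window maxima in a Python list and let torch.stack assemble the pooled output, so autograd flows through stack instead of through the indexed copies. This is only a rough sketch under the same unfold layout as above (flat_pool and window are just my own names), not something I’m sure is the recommended approach:

    def forward(self, x):
        x = F.pad(x, self._padding(x), mode='reflect')
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        x = x.contiguous().view(x.size()[:4] + (-1,))  # (N, C, H', W', k*k) windows
        mmax = MultiMax.apply
        n, c, h, w = x.shape[:4]
        flat_pool, inds = [], []
        for window in x.reshape(-1, x.shape[-1]):  # one 1d tensor per pooling window
            _max, _is = mmax(window)
            flat_pool.append(_max)
            inds.append(_is)
        # Stack the 0-d maxima and restore the (N, C, H', W') layout.
        # inds stays a flat list here; the nested layout from _init_pool_inds is omitted.
        pool = torch.stack(flat_pool).view(n, c, h, w)
        return pool, inds

As far as I understand, torch.stack keeps the graph connected without any indexed in-place copies, but I’d like to confirm whether this (or a more vectorised variant) is the proper way to build the pooled output.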