Thank you for bearing with me; this post will be a bit long. The values returned by `max` don't need to vary in shape, since they hold the max value and there is only one (possibly repeated) max within each window. The returned indices, however, will vary in shape. An example is in the `MultiMax` part.
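
To illustrate the shape issue, here are two 4-element windows taken from the 2×2 pooling example further down. Each yields exactly one max value, but a different number of max positions. A boolean mask of the tied positions is one fixed-shape alternative; it is an illustrative choice, not what `MultiMax` returns.

```python
import torch

# Two flattened 2x2 windows from the MultiMaxPool2d example input.
windows = torch.tensor([[6., 2, 5, 6], [3, 9, 7, 8]])

maxima = windows.max(dim=1).values      # shape (2,): one max per window
tied = windows == maxima.unsqueeze(1)   # shape (2, 4): same shape for every window
per_window = [torch.nonzero(row).flatten() for row in tied]  # lengths 2 and 1
print(maxima, tied, per_window, sep="\n")
```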

I've implemented the custom `max` (named `MultiMax`) as a custom autograd function, and the max pooling layer (named `MultiMaxPool2d`) on top of it. As I asked in this thread, I don't know how to implement the `dim` parameter of `torch.max` for my `MultiMax`. So `MultiMax` takes a 1-d tensor and is called inside the `for` loop of the `MultiMaxPool2d` layer. I've just learned that this approach might be problematic, since in-place operations should be avoided in PyTorch.
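
For the `dim` question, one possible sketch (not the original `MultiMax`) is to rely on `torch.amax`, whose documentation states that it evenly distributes the gradient between equal values, whereas `torch.max(dim)` sends it to a single index only. The helper name `multi_max` is hypothetical, and it returns a boolean mask of the tied positions instead of a list of indices so that the output shape stays fixed.

```python
import torch


def multi_max(x, dim):
    """Hypothetical dim-aware variant of MultiMax built on torch.amax.

    Returns the max along `dim` and a boolean mask (same shape as `x`)
    marking every position that holds that max. The gradient of the
    returned values is split evenly among tied maxima by torch.amax.
    """
    values = torch.amax(x, dim=dim, keepdim=True)  # keepdim so it broadcasts against x
    mask = x == values                             # fixed-shape stand-in for the indices
    return values.squeeze(dim), mask


x = torch.tensor([[6., 2, 5, 6], [3, 9, 7, 8]], requires_grad=True)
values, mask = multi_max(x, dim=1)
values.sum().backward()
print(values, mask, x.grad, sep="\n")
```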

My question is: instead of in-place operations, what is the proper approach for creating the pooled output? I obviously need to select the max values from the input and assign them to their new positions in the pooled output. I've added the source code and examples.
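
On the in-place question, one common out-of-place pattern is to collect each window's result in a Python list and build the output once with `torch.stack`, rather than writing into a tensor preallocated with `torch.zeros`. The helper `pool_with_stack` and its `(N, C, H_out, W_out, K)` input layout are assumptions that mirror the unfolded tensor in `MultiMaxPool2d.forward`; `reduce_fn` stands in for something like `lambda w: MultiMax.apply(w)[0]`.

```python
import torch


def pool_with_stack(windows, reduce_fn):
    """Hypothetical helper: reduce every window, then assemble the output out-of-place.

    `windows` is assumed to have shape (N, C, H_out, W_out, K); `reduce_fn`
    maps a 1-D window of length K to a 0-dim tensor.
    """
    n, c, h, w, k = windows.shape
    flat = windows.reshape(-1, k)                   # one row per window
    values = [reduce_fn(row) for row in flat]       # no writes into a preallocated tensor
    return torch.stack(values).reshape(n, c, h, w)  # output built in a single op


base = torch.tensor([6., 2, 5, 6, 3, 9, 7, 8], requires_grad=True)
out = pool_with_stack(base.reshape(1, 1, 1, 2, 4), lambda w: torch.amax(w, dim=0))
out.sum().backward()
print(out, base.grad, sep="\n")
```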

Here is my custom `max` operation, `MultiMax`:

```
import torch

class MultiMax(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        mmax = torch.max(input)
        inds = torch.nonzero(mmax == input).squeeze()  # Get max inds (0-dim if only one)
        ctx.save_for_backward(input, mmax, inds)
        ctx.mark_non_differentiable(inds)
        return mmax, inds

    @staticmethod
    def backward(ctx, grad_output, grad_inds):
        # grad_inds is the (unused) gradient w.r.t. the non-differentiable inds output.
        input, mmax, inds = ctx.saved_tensors
        inds_shape = 1 if inds.dim() == 0 else inds.shape[0]
        # ones_like keeps the dtype and device of the input.
        grad_input = torch.ones_like(input) * grad_output
        grad_input /= inds_shape  # Grad is shared among max values.
        grad_input[input < mmax] = 0
        return grad_input  # One gradient per forward input.
```

An example using `MultiMax`:

```
inp = torch.tensor([6., 2, 5, 6], requires_grad=True)
mmax = MultiMax.apply
pooled, inds = mmax(inp)
pooled.sum().backward()
print('pooled:\n', pooled)
print('inds:\n', inds)
print('inp grads:\n', inp.grad)
```

Output:

```
pooled:
tensor(6., grad_fn=<MultiMaxBackward>)
inds:
tensor([0, 3])
inp grads:
tensor([0.5000, 0.0000, 0.0000, 0.5000])
```

Here is the implementation of the custom max pooling layer `MultiMaxPool2d` using `MultiMax` (adapted from here):

```
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.utils import _pair, _quadruple
from MultiMax import MultiMax

class MultiMaxPool2d(nn.Module):
    def __init__(self, kernel_size, stride, padding=0, same=False):
        super(MultiMaxPool2d, self).__init__()
        self.k = kernel_size if type(kernel_size) is tuple else _pair(kernel_size)
        self.stride = stride if type(stride) is tuple else _pair(stride)
        self.padding = padding if type(padding) is tuple else _quadruple(padding)
        self.same = same

    def _init_pool_inds(self, s):
        pool = torch.zeros(s[0], s[1], s[2], s[3])
        inds = []
        for x in range(0, s[0]):
            inds.append([])
            for y in range(0, s[1]):
                inds[x].append([])
                for z in range(0, s[2]):
                    inds[x][y].append([])
                    for t in range(0, s[3]):
                        inds[x][y][z].append(0)
        return pool, inds

    def _padding(self, x):
        if self.same:
            ih, iw = x.size()[2:]
            if ih % self.stride[0] == 0:
                ph = max(self.k[0] - self.stride[0], 0)
            else:
                ph = max(self.k[0] - (ih % self.stride[0]), 0)
            if iw % self.stride[1] == 0:
                pw = max(self.k[1] - self.stride[1], 0)
            else:
                pw = max(self.k[1] - (iw % self.stride[1]), 0)
            pl = pw // 2
            pr = pw - pl
            pt = ph // 2
            pb = ph - pt
            padding = (pl, pr, pt, pb)
        else:
            padding = self.padding
        return padding

    def forward(self, x):
        x = F.pad(x, self._padding(x), mode='reflect')
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        x = x.contiguous().view(x.size()[:4] + (-1,))
        mmax = MultiMax.apply
        s = x.shape[:4]
        pool, inds = self._init_pool_inds(s)
        for i in range(0, s[0]):
            for j in range(0, s[1]):
                for k in range(0, s[2]):
                    for l in range(0, s[3]):
                        _max, _is = mmax(x[i][j][k][l])
                        pool[i][j][k][l] = _max
                        inds[i][j][k][l] = _is
        return pool, inds
```
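
For comparison, the Python loops in `forward` can be replaced by a single reduction over the unfolded windows. The sketch below is a hypothetical loop-free variant (`multi_max_pool2d` is not part of the original code). It swaps the custom `MultiMax` for the built-in `torch.amax`, relies on its documented behavior of splitting the gradient evenly between equal maxima, and stores the tied positions as a fixed-shape boolean mask instead of a nested list of variable-length index tensors. It assumes an int `kernel_size` and `stride` and omits the padding and `same` logic.

```python
import torch


def multi_max_pool2d(x, kernel_size, stride):
    """Hypothetical loop-free sketch of MultiMaxPool2d.forward (no padding).

    Returns pooled values of shape (N, C, H_out, W_out) and a boolean mask of
    shape (N, C, H_out, W_out, k*k) marking every tied max in each window,
    flattened row-major within the window like the original indices.
    """
    # (N, C, H, W) -> (N, C, H_out, W_out, k, k) -> (N, C, H_out, W_out, k*k)
    windows = x.unfold(2, kernel_size, stride).unfold(3, kernel_size, stride)
    windows = windows.reshape(*windows.shape[:4], -1)
    pooled = windows.amax(dim=-1)           # gradient shared evenly among ties
    mask = windows == pooled.unsqueeze(-1)  # fixed-shape replacement for `inds`
    return pooled, mask


inp = torch.tensor([[[[6., 2, 3, 9],
                      [5, 6, 7, 8],
                      [9, 10, 11, 16],
                      [14, 14, 16, 16]]]], requires_grad=True)
pooled, mask = multi_max_pool2d(inp, kernel_size=2, stride=2)
pooled.sum().backward()
print(pooled, inp.grad, sep="\n")
```

On the 4×4 example input used below, this should give the same pooled values and input gradients as `MultiMaxPool2d`, with the per-window indices recoverable via `torch.nonzero(mask[n, c, i, j])` when needed.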

An example using `MultiMaxPool2d`:

```
inp = torch.tensor([[[[ 6., 2, 3, 9],
                      [ 5, 6, 7, 8],
                      [ 9, 10, 11, 16],
                      [14, 14, 16, 16]]]], requires_grad=True)
mp = MultiMaxPool2d(kernel_size=2, stride=2)
pooled, inds = mp(inp)
pooled.sum().backward()
print('pooled:\n', pooled)
print('inds:\n', inds)
print('inp grads:\n', inp.grad)
```

Output:

```
pooled:
tensor([[[[ 6., 9.],
          [14., 16.]]]], grad_fn=<CopySlices>)
inds:
[[[[tensor([0, 3]), tensor(1)], [tensor([2, 3]), tensor([1, 2, 3])]]]]
inp grads:
tensor([[[[0.5000, 0.0000, 0.0000, 1.0000],
          [0.0000, 0.5000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.3333],
          [0.5000, 0.5000, 0.3333, 0.3333]]]])
```

The gradients are calculated as I expected in both operations. However, I also implemented an unpooling layer, which I believe also suffers from in-place operations, and its gradients are not calculated as I expected.
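
The unpooling layer itself isn't shown, so the following is only a hypothetical sketch of one out-of-place way to invert the pooling in the non-overlapping case (stride equal to kernel size, no padding): every tied maximum in a window receives the pooled value and all other positions are zero. The function name `multi_max_unpool2d` and the `(N, C, H_out, W_out, k*k)` boolean `mask` layout are assumptions, not the original design.

```python
import torch


def multi_max_unpool2d(pooled, mask, kernel_size):
    """Hypothetical out-of-place unpooling for non-overlapping windows.

    Assumes stride == kernel_size, no padding, and a boolean `mask` of shape
    (N, C, H_out, W_out, k*k) that is True at every tied max of each window.
    """
    n, c, h, w = pooled.shape
    k = kernel_size
    # Broadcast each pooled value over its window and zero the non-max positions.
    windows = pooled.unsqueeze(-1) * mask.to(pooled.dtype)  # (N, C, H, W, k*k)
    windows = windows.reshape(n, c, h, w, k, k)
    # Interleave window rows/cols back into the (H*k, W*k) spatial layout.
    return windows.permute(0, 1, 2, 4, 3, 5).reshape(n, c, h * k, w * k)
```

Because the output is built only from a multiplication and reshapes, autograd gives each pooled value a gradient equal to the number of positions it was copied to, without any in-place writes.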