Why does PyTorch's max pooling layer store input tensors?

I made a simple model like the one below. It looks a bit odd, but it has one convolutional layer followed by two max pooling operations (the same MaxPool2d module is applied twice).

```python
import torch
import torch.nn as nn

class simple_model(nn.Module):
    def __init__(self):
        super(simple_model, self).__init__()
        self.maxpool2D = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv1 = nn.Conv2d(3, 20, (5, 5))

    def forward(self, x):
        # conv -> max pool -> the same max pool module applied a second time
        x = self.maxpool2D(self.maxpool2D(self.conv1(x)))
        return x
```

Then I checked which tensors are saved during forward propagation, using saved tensor hooks.

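```python
# Assumed setup (not shown in the original post): device, model, input, loss.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = simple_model().to(device)
input_tensors = torch.randn(64, 3, 224, 224, device=device)
criterion = nn.MSELoss()  # assumption: any loss works for this experiment

pack_saved_tensors = []
def pack_hook(x):
    pack_saved_tensors.append(x)  # every tensor autograd saves during forward
    return x

unpack_used_tensors = []
def unpack_hook(x):
    unpack_used_tensors.append(x)  # every saved tensor read back during backward
    return x

with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    model_output = model(input_tensors)

label = torch.randn(model_output.size()).to(device)
loss = criterion(model_output, label)
loss.backward()
```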

The result is shown below. I also checked, with the unpack hook function, that every saved tensor is used in backward propagation. My interpretation is:

tensors 0 and 1 are saved for the convolutional layer,
tensors 2 and 3 are saved for the first max pooling layer,
tensors 4 and 5 are saved for the second max pooling layer.

```text
pack hook saved tensors:
0 tensor size: torch.Size([64, 3, 224, 224]), tensor type: torch.float32
1 tensor size: torch.Size([20, 3, 5, 5]), tensor type: torch.float32
2 tensor size: torch.Size([64, 20, 220, 220]), tensor type: torch.float32
3 tensor size: torch.Size([64, 20, 110, 110]), tensor type: torch.int64
4 tensor size: torch.Size([64, 20, 110, 110]), tensor type: torch.float32
5 tensor size: torch.Size([64, 20, 55, 55]), tensor type: torch.int64

unpack hook used tensors:
6 tensor size: torch.Size([64, 20, 110, 110]), tensor type: torch.float32
7 tensor size: torch.Size([64, 20, 55, 55]), tensor type: torch.int64
8 tensor size: torch.Size([64, 20, 220, 220]), tensor type: torch.float32
9 tensor size: torch.Size([64, 20, 110, 110]), tensor type: torch.int64
10 tensor size: torch.Size([64, 3, 224, 224]), tensor type: torch.float32
11 tensor size: torch.Size([20, 3, 5, 5]), tensor type: torch.float32
```
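As a sanity check on the layer-by-layer interpretation, the standard output-size formula for convolution and pooling (dilation 1; $k$ = kernel size, $s$ = stride, $p$ = padding) reproduces every spatial size in the listing:

```latex
\begin{aligned}
H_{\text{out}} &= \left\lfloor \frac{H_{\text{in}} + 2p - k}{s} \right\rfloor + 1 \\
\text{conv1 } (k=5,\ s=1,\ p=0)&:\quad \lfloor (224 - 5)/1 \rfloor + 1 = 220 \\
\text{first max pool } (k=2,\ s=2,\ p=0)&:\quad \lfloor (220 - 2)/2 \rfloor + 1 = 110 \\
\text{second max pool } (k=2,\ s=2,\ p=0)&:\quad \lfloor (110 - 2)/2 \rfloor + 1 = 55
\end{aligned}
```

So the float32 [64, 20, 220, 220] tensor is the conv output (the first pool's input), the int64 [64, 20, 110, 110] tensor holds one index per element of the first pool's output, the float32 [64, 20, 110, 110] tensor is the first pool's output (the second pool's input), and the int64 [64, 20, 55, 55] tensor holds the second pool's indices.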

My question is:

Why does PyTorch store the input tensor for a max pooling layer? I would expect max pooling to need only the int64 tensors for backward propagation (tensors 3 and 5, which store the indices of the max values).

Any help is appreciated.

Hi Core!

My speculation (assuming that built-in PyTorch functions work the same
way as do custom autograd functions):

Because MaxPool2d downsamples, it loses some information about the
shape of its input. When its static backward() method is called, the only
things it knows are grad_output (and hence the shape of its output) and
whatever it saved in ctx. MaxPool2d, as you noted, saves the argmax
index tensor in ctx to avoid recomputing it in the backward pass. But
this index tensor is not enough to reconstruct the shape of the input to
MaxPool2d (which is the same as the shape of the gradient that backward()
needs to return).
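A small illustration of that shape ambiguity, using the same pooling configuration as in the question but with return_indices=True so the index tensor is visible: a 220x220 input and a 221x221 input (the last row and column are simply never covered by a window) produce an output and an index tensor of identical shape, so neither one alone tells backward() which input size to rebuild.

```python
import torch
import torch.nn as nn

# Same kernel/stride/padding as in the question, but also returning the indices.
pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, return_indices=True)

out_a, idx_a = pool(torch.randn(1, 1, 220, 220))
out_b, idx_b = pool(torch.randn(1, 1, 221, 221))

# Both inputs map to a 110x110 output and a 110x110 index tensor.
print(out_a.shape, out_b.shape)
print(idx_a.shape, idx_b.shape)
```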

MaxPool2d chooses to save its input, which gives it input.shape. Together
with the argmax index tensor, that is enough to compute the gradient.
Saving the whole input might seem like overkill, but if that input (often the
output of a previous layer) is already being saved by the previous layer for
its backward pass, then saving it again costs essentially no extra memory,
because autograd keeps a reference to the same tensor rather than a copy.
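One way to check the "essentially free" point is to record the data_ptr() of every tensor autograd saves. This sketch uses a hypothetical conv → ReLU → max pool stack rather than the model from the question: there, ReLU saves its own output for backward, and that same tensor is the max pool's input. If two saved entries share one data pointer, no extra copy was made.

```python
import torch
import torch.nn as nn

# Hypothetical stack (not the model from the question).
model = nn.Sequential(
    nn.Conv2d(3, 20, 5),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
)

packed = []  # (data_ptr, shape, dtype) of every tensor autograd saves

def pack_hook(t):
    packed.append((t.data_ptr(), tuple(t.shape), t.dtype))
    return t

def unpack_hook(t):
    return t

x = torch.randn(2, 3, 32, 32)
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
    y = model(x)  # keep y alive so the saved tensors stay alive too

for ptr, shape, dtype in packed:
    print(hex(ptr), shape, dtype)

# Saved entries that point at the same memory as another saved entry.
ptr_counts = {}
for ptr, _, _ in packed:
    ptr_counts[ptr] = ptr_counts.get(ptr, 0) + 1
shared = [ptr for ptr, n in ptr_counts.items() if n > 1]
print("saved more than once without a copy:", [hex(p) for p in shared])
```

Note that in the model from the question, Conv2d saves its input and weight but not its output, so there the first max pool's saved input does keep the [64, 20, 220, 220] conv output alive until backward.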

This is surely not the only way of implementing MaxPool2d, but it seems
economical, so it’s probably a reasonable choice.
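To make the alternative concrete, here is a minimal sketch of a hypothetical custom autograd Function (`IndicesOnlyMaxPool2d`, not part of PyTorch) that saves only the int64 indices and records the input shape as plain metadata on ctx. It assumes a 4-D (N, C, H, W) input, handles only kernel_size and stride, and relies on max_pool2d's indices being flat positions within each H x W plane of the input.

```python
import torch
import torch.nn.functional as F

class IndicesOnlyMaxPool2d(torch.autograd.Function):
    """Hypothetical max pool that keeps the indices plus the input *shape*,
    but not the input tensor itself."""

    @staticmethod
    def forward(ctx, x, kernel_size, stride):
        out, indices = F.max_pool2d(x, kernel_size, stride, return_indices=True)
        ctx.save_for_backward(indices)  # the only tensor saved for backward
        ctx.input_shape = x.shape       # a torch.Size, i.e. plain ints
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        n, c, h, w = ctx.input_shape
        # Route each output gradient to the flat (h * w) position of its max;
        # scatter_add_ also accumulates correctly if windows overlap.
        grad_input = grad_output.new_zeros(n, c, h * w)
        grad_input.scatter_add_(2, indices.flatten(2), grad_output.flatten(2))
        # One gradient per forward argument; kernel_size and stride get None.
        return grad_input.view(n, c, h, w), None, None

# Usage: compare against the built-in max pooling gradient.
x = torch.randn(2, 3, 8, 8, dtype=torch.double, requires_grad=True)
IndicesOnlyMaxPool2d.apply(x, 2, 2).sum().backward()

x_ref = x.detach().clone().requires_grad_(True)
F.max_pool2d(x_ref, 2, 2).sum().backward()
print(torch.equal(x.grad, x_ref.grad))  # True
```

The trade-off is that the input shape has to travel through ctx as separate metadata, whereas saving the input gets the shape for free and, as argued above, often costs no extra memory.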

Best.

K. Frank


Thank you very much for your kind explanation!