Understanding refcounts

Greetings! I have recently started reading the PyTorch source code and I'm a bit confused about how reference counting (and memory management in general) works. To make sense of PyTorch's memory-management behavior, I'm running the mnist.py example. The relevant part of the network is short:

x = F.relu(F.max_pool2d(self.conv1(x), 2))

I'm trying to understand why the input tensor of max_pool2d is not freed immediately after the forward pass finishes. Specifically, this line here https://github.com/pytorch/pytorch/blob/0257f5d19f0585f9a82bc06e0c4987e2136332c9/aten/src/THCUNN/generic/SpatialDilatedMaxPooling.cu#L157 does not actually free the underlying tensor storage, because it merely decrements the reference count from 2 to 1. My question is: why is there still a reference? The mnist network is a simple sequential model, and I cannot think of any other place where the input of the pooling could be referenced. Certainly the backward pass of a convolution does not need its own output, so I wonder where the remaining reference comes from and why the memory of the input tensor to the pooling is not released immediately.
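To make sure I'm reading the C code correctly, here is a tiny pure-Python sketch of how I understand the THStorage retain/free scheme (the `Storage` class, its fields, and the second `retain()` call are my own stand-ins, not actual PyTorch code): the kernel's `free()` only drops the count by one, so as long as something else (e.g. the autograd graph saving the tensor for backward) still holds a reference, the memory survives.

```python
class Storage:
    """Hypothetical stand-in for THStorage: memory is released
    only when the reference count reaches zero."""

    def __init__(self, nbytes):
        self.refcount = 0
        self.nbytes = nbytes
        self.freed = False

    def retain(self):
        self.refcount += 1

    def free(self):
        # Mirrors my reading of the linked line: decrement the
        # count, and release only if it hits zero.
        self.refcount -= 1
        if self.refcount == 0:
            self.freed = True  # the real code would release GPU memory here


storage = Storage(nbytes=1 << 20)
storage.retain()  # reference held by the input tensor itself
storage.retain()  # the mystery second reference (whatever holds it)

# The free() in the pooling kernel only takes the count from 2 to 1,
# so the underlying storage is NOT actually released:
storage.free()
print(storage.refcount, storage.freed)  # 1 False
```

If my mental model above is right, the question reduces to: who takes that second reference during the forward pass?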

Many thanks in advance!