Creating an output tensor while preserving autograd graph

Hello!
I’m trying to implement a MaxPooling2D layer from scratch without using any modules from torch.nn.

def forward(self, x, debug=False):
    # In shape
    batch, in_features, h_in, w_in = x.shape
    # Expected out dims
    # If the img has been padded the term + 2 * padding should not be considered
    h_out = int(((h_in - self.kernel) / self.stride)) + 1
    w_out = int(((w_in - self.kernel) / self.stride)) + 1
    # Preparing output
    out = torch.zeros(size=(batch, in_features, h_out * w_out))
    # Getting all the windows (basically what unfold does)
    offset = 0
    for i in range(0, h_in, self.stride):
        for j in range(0, w_in, self.stride):
            if j + self.kernel > w_in or i + self.kernel > h_in:
                continue
            window_elems = x[:, :, i:i + self.kernel, j:j + self.kernel].long()
            max_elems = torch.amax(window_elems, in_features)
            elems = torch.amax(max_elems, 2)
            out[:, :, offset] = elems
            offset += 1
    return out.reshape(batch, in_features, h_out, w_out)

The problem I’m currently having is that the out tensor is not only a leaf node but also has no grad_fn, which means I cannot call backward() through the model. Is there any way to create the out tensor while “linking” it to the input tensor in the autograd graph?
Thanks in advance!

Hi @Joseph_M,

You could try adding 0 * input and see if that works. The output has a different shape than the input, but you might be able to connect them with a trick like multiplying the input by 0 and adding it to the output.
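A quick sketch of what this trick does, with shapes chosen only for illustration. Since out and the input differ in shape, the sketch adds the scalar `0 * x.sum()` rather than `0 * x` (an assumption made so the addition broadcasts). The result gains a `grad_fn`, but the gradient that reaches `x` along this path is zero, so it links the graph without carrying any pooling gradient.

```python
import torch

x = torch.randn(1, 3, 4, 4, requires_grad=True)

# Stand-in for the detached output of the original loop (no grad_fn)
out = torch.zeros(1, 3, 2, 2)

# The "0 * input" trick; a scalar is used because x and out have different shapes
linked = out + 0 * x.sum()
print(linked.grad_fn)  # an AddBackward0 node now exists

linked.sum().backward()
print(x.grad)  # all zeros: no useful gradient flows back to x
```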

It seems you are assigning the amax indices to the output tensor instead of the max values. The indices are detached from the computation graph, since amax is not differentiable, and this differs from the nn.MaxPool2d implementation:

import torch
import torch.nn as nn

x = torch.randn(1, 1, 2, 2, requires_grad=True)
pool = nn.MaxPool2d(2, return_indices=True)
out, idx = pool(x)
print(out)
# tensor([[[[-0.0955]]]], grad_fn=<MaxPool2DWithIndicesBackward0>)
print(idx)
# tensor([[[[2]]]])

As you can see, the actual out tensor has a valid grad_fn while the corresponding idx does not.
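For comparison, a minimal vectorized sketch of the same idea: pool over the values, not the indices, and the result stays in the graph. The helper name `max_pool2d_unfold` is hypothetical; it uses `Tensor.unfold` to extract the windows and `amax` to reduce them, then checks both the output and the gradient against `F.max_pool2d`.

```python
import torch
import torch.nn.functional as F

def max_pool2d_unfold(x, kernel, stride):
    # Hypothetical helper: extract sliding windows along H, then W.
    # Shape after both unfolds: (batch, channels, h_out, w_out, kernel, kernel)
    windows = x.unfold(2, kernel, stride).unfold(3, kernel, stride)
    # Reduce over the two window dims; amax is differentiable w.r.t. the values
    return windows.amax(dim=(-2, -1))

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8, requires_grad=True)
out = max_pool2d_unfold(x, kernel=2, stride=2)
print(out.shape, out.grad_fn)

ref = F.max_pool2d(x, kernel_size=2, stride=2)
print(torch.allclose(out, ref))

# Compare gradients against the built-in op
g_custom, = torch.autograd.grad(out.sum(), x)
g_ref, = torch.autograd.grad(ref.sum(), x)
print(torch.allclose(g_custom, g_ref))
```

The gradients only agree when each window has a unique maximum (almost surely the case with `randn` input): on ties, `amax` splits the gradient evenly, while `max_pool2d` sends it to a single index.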

Thanks for the reply and for the insight! I have a quick question: you said amax is not differentiable, but if I try the following code:

import torch

x = torch.randint(0, 255, size=(1, 3, 20, 20), dtype=torch.float32, requires_grad=True)
y = torch.amax(x)
print(y.requires_grad, y.grad_fn)
# True <AmaxBackward0 object at 0x00000184988F3190>

In fact, the original code seems to run fine if I remove the .long() cast on window_elems:

window_elems = x[:, :, i:i + self.kernel, j:j + self.kernel].long()
print(window_elems.requires_grad, window_elems.grad_fn)
# False None
window_elems = x[:, :, i:i + self.kernel, j:j + self.kernel]
print(window_elems.requires_grad, window_elems.grad_fn)
# True <SliceBackward0 object at 0x00000193CB7D80A0>
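A sketch of the loop from the question with the `.long()` cast removed, written as a standalone function (the name `max_pool2d_loop` is hypothetical). Two further changes are not from the thread and are assumptions: each window is reduced over dims `(2, 3)` (the original passes `in_features` as a dim, which only works for some channel counts), and `out` is allocated with `x`'s dtype and device. Writing a grad-requiring value into a slice of the preallocated `out` records a `CopySlices` node, so `out` is no longer a leaf without a `grad_fn`.

```python
import torch
import torch.nn.functional as F

def max_pool2d_loop(x, kernel, stride):
    # Hypothetical standalone version of the forward() in the question
    batch, channels, h_in, w_in = x.shape
    h_out = (h_in - kernel) // stride + 1
    w_out = (w_in - kernel) // stride + 1
    # Preallocate on x's device/dtype; this buffer starts as a leaf without grad
    out = x.new_zeros(batch, channels, h_out * w_out)
    offset = 0
    # Only start positions whose window fits inside the input
    for i in range(0, h_in - kernel + 1, stride):
        for j in range(0, w_in - kernel + 1, stride):
            # No .long() cast, so the slice stays in the autograd graph
            window = x[:, :, i:i + kernel, j:j + kernel]
            # Max over both spatial dims of the window -> (batch, channels)
            out[:, :, offset] = torch.amax(window, dim=(2, 3))
            offset += 1
    return out.reshape(batch, channels, h_out, w_out)

torch.manual_seed(0)
x = torch.randn(2, 4, 7, 7, requires_grad=True)
out = max_pool2d_loop(x, kernel=3, stride=2)
print(out.grad_fn)  # no longer None

ref = F.max_pool2d(x, kernel_size=3, stride=2)
g_loop, = torch.autograd.grad(out.sum(), x)
g_ref, = torch.autograd.grad(ref.sum(), x)
print(torch.allclose(out, ref), torch.allclose(g_loop, g_ref))
```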

Sorry, I meant argmax; you are using amax, which does return the values.
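A small sketch of the difference between the two: `argmax` returns an integer index tensor that autograd does not track, while `amax` returns the value and sends the gradient to the maximal element(s), splitting it evenly on ties.

```python
import torch

x = torch.tensor([1.0, 3.0, 2.0], requires_grad=True)

# argmax returns an int64 index tensor with no grad_fn
idx = torch.argmax(x)
print(idx, idx.dtype, idx.requires_grad)  # tensor(1) torch.int64 False

# amax returns the value and is differentiable: only the maximum gets gradient
torch.amax(x).backward()
print(x.grad)  # tensor([0., 1., 0.])

# With ties, amax splits the gradient evenly among the equal maxima
y = torch.tensor([3.0, 3.0, 1.0], requires_grad=True)
torch.amax(y).backward()
print(y.grad)  # tensor([0.5000, 0.5000, 0.0000])
```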
