Maxpool from scratch

Hi, I am trying to implement a maxpool function from scratch (for fun) and use backward() on it.
In my implementation below, the output Y of maxpool is correct (I verified it).
But the gradient of the input is a zero tensor, which is wrong.

Code:

import torch
def maxpool_fp(X):
    pool_dim = 2
    pool_stripe = 2
    bs, cx, dx, _ = list(X.size())  # batch size; number of channels of X; spatial dimension of X
    dy = int(X.shape[2] / pool_stripe)  # dimension of Y

    Y = X[:, :, :dy, :dy] * 1  # *1 to avoid: RuntimeError: leaf variable has been moved into the graph interior
    for yn in range(bs):
        for yc, x in enumerate(X[yn]):
            for yh, h in enumerate(range(0, dx, pool_stripe)):
                for yw, w in enumerate(range(0, dx, pool_stripe)):
                    Y[yn, yc, yh, yw] = torch.max(x[h:h + pool_dim, w:w + pool_dim]).item()
    return Y


X = torch.randn(2, 2, 8, 8, requires_grad=True)

Y = maxpool_fp(X)
S = torch.sum(Y)
S.backward()

print("S =", S) # ==> Correct
print("X.grad\n", X.grad) # ==>  zero tensor !!!!!!!!!!

Hi,

The .item() that you use converts the result of your max into a plain Python number, and autograd cannot track gradients through it, so it is ignored.
You will need to remove this .item() to get proper gradients.
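
As a minimal sketch of the difference (my own illustration, not part of the original code):

import torch

x = torch.randn(3, requires_grad=True)

# .item() returns a plain Python float: autograd cannot see it, so no gradient flows.
m_detached = torch.max(x).item()

# Keeping the result as a tensor preserves the autograd graph.
m_tracked = torch.max(x)
m_tracked.backward()
print(x.grad)  # 1.0 at the position of the maximum, 0.0 elsewhere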


Great, it’s working now. Thanks a lot!

But in order to keep the gradient tracked, I did this:

dy = int(X.shape[2] / pool_stripe)  # dimension of Y
Y = X[:, :, :dy, :dy] * 1

Can you see a more elegant way to do this? Or does it look fine to you?

Why do you need the *1 here?
Can you share the full code sample? Maybe that will be clearer 🙂

Before removing .item(), the *1 was there to avoid the error:

RuntimeError: leaf variable has been moved into the graph interior

When executing the first assignment:

Y[yn, yc, yh, yw] = torch.max(x[h:h + pool_dim, w:w + pool_dim]).item()

the is_leaf attribute of X changed from True to False. I put the *1 there so that Y starts out as a non-leaf tensor in the first place, in order to avoid that change.
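
A small check (my own sketch, not from the original post) makes the difference visible: the plain slice shares memory with X, while the *1 version owns its own storage.

import torch

X = torch.randn(2, 2, 4, 4, requires_grad=True)
view = X[:, :, :2, :2]      # a view: shares storage with the leaf X
copy = X[:, :, :2, :2] * 1  # a new tensor with its own storage

print(view.data_ptr() == X.data_ptr())  # True: writing into view also writes into X
print(copy.data_ptr() == X.data_ptr())  # False: writing into copy leaves X untouched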

After removing .item(), the *1 is still needed to avoid another error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [2, 2]], which is output 0 of SliceBackward, is at version 64; expected version 63 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

which I can’t really explain.

You can copy/paste the code below and remove the *1 if you want to reproduce the RuntimeError. But please note that my problem was fixed by removing the .item(): I put this maxpool() function along with a hardcoded conv2d in a CNN model, and it works on MNIST.


import torch

def maxpool_fp(X):
    pool_dim = 2
    pool_stripe = 2
    bs, cx, dx, _ = list(X.size())  # batch size; number of channels of X; spatial dimension of X
    dy = int(X.shape[2] / pool_stripe)  # dimension of Y

    Y = X[:, :, :dy, :dy] * 1  # <=============
    for yn in range(bs):
        for yc, x in enumerate(X[yn]):
            for yh, h in enumerate(range(0, dx, pool_stripe)):
                for yw, w in enumerate(range(0, dx, pool_stripe)):
                    Y[yn, yc, yh, yw] = torch.max(x[h:h + pool_dim, w:w + pool_dim])
    return Y


X = torch.randn(2, 2, 8, 8, requires_grad=True)

Y = maxpool_fp(X)
S = torch.sum(Y)
S.backward()

print("S =", S)
print("X.grad\n", X.grad)

Oh right!
Yes, you can’t modify X in-place because it’s a leaf. When you do Y = X[:, :, :dy, :dy], Y is actually a view into X, so when you modify Y in-place with Y[yn, yc, yh, yw] = foo, you also modify X, which leads to the error.
The right way to fix that is to use .clone() to ensure you get new memory that is not a view of X: Y = X[:, :, :dy, :dy].clone() (it should also be slightly faster than the *1 you use).
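
For completeness, here is a sketch of the final version with .clone(), compared (a sanity check of my own, assuming no ties within a pooling window) against torch.nn.functional.max_pool2d:

import torch
import torch.nn.functional as F

def maxpool_fp(X):
    pool_dim = 2
    pool_stripe = 2
    bs, cx, dx, _ = list(X.size())  # batch size; number of channels; spatial dimension of X
    dy = dx // pool_stripe          # spatial dimension of Y

    Y = X[:, :, :dy, :dy].clone()   # .clone(): new memory, not a view of the leaf X
    for yn in range(bs):
        for yc, x in enumerate(X[yn]):
            for yh, h in enumerate(range(0, dx, pool_stripe)):
                for yw, w in enumerate(range(0, dx, pool_stripe)):
                    Y[yn, yc, yh, yw] = torch.max(x[h:h + pool_dim, w:w + pool_dim])
    return Y

X = torch.randn(2, 2, 8, 8, requires_grad=True)
maxpool_fp(X).sum().backward()

# Compare against the built-in pooling on an independent copy of X.
X2 = X.detach().clone().requires_grad_(True)
F.max_pool2d(X2, 2).sum().backward()
print(torch.allclose(X.grad, X2.grad))  # True (with random floats, ties are virtually impossible)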