Hi, I am trying to implement a maxpool function from scratch (for fun) and call backward() on it.
In my implementation below, the output Y of maxpool is correct (I verified it).
However, the gradient of the input, X.grad, is a zero tensor, which is wrong.
```python
import torch

def maxpool_fp(X):
    pool_dim = 2
    pool_stripe = 2
    bs, cx, dx, _ = list(X.size())  # batch size; number of channels of X; spatial size of X
    dy = int(dx / pool_stripe)  # spatial size of Y
    Y = X[:, :, :dy, :dy] * 1  # *1 to avoid: RuntimeError: leaf variable has been moved into the graph interior
    for yn in range(bs):
        for yc, x in enumerate(X[yn]):
            for yh, h in enumerate(range(0, dx, pool_stripe)):
                for yw, w in enumerate(range(0, dx, pool_stripe)):
                    Y[yn, yc, yh, yw] = torch.max(x[h:h + pool_dim, w:w + pool_dim]).item()
    return Y

X = torch.randn(2, 2, 8, 8, requires_grad=True)
Y = maxpool_fp(X)
S = torch.sum(Y)
S.backward()
print("S =", S)  # ==> correct
print("X.grad\n", X.grad)  # ==> zero tensor!
```
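A likely cause: `.item()` returns a plain Python number, so every entry of Y is overwritten with a constant that has no autograd history. The only remaining path back to X is the slice `X[:, :, :dy, :dy] * 1`, and because all of its values get overwritten, the gradient through it is zero everywhere. Below is a minimal sketch, assuming the same 2x2 window and stride 2, that keeps the 0-d tensor returned by `.max()` instead of calling `.item()`. It also preallocates Y with `X.new_zeros`, which removes the need for the `* 1` trick. Turning `pool_dim` and `pool_stripe` into keyword arguments is my own choice, not part of the original code.

```python
import torch
import torch.nn.functional as F


def maxpool_fp(X, pool_dim=2, pool_stripe=2):
    bs, cx, dx, _ = X.shape  # batch size, number of channels, spatial size of X
    dy = dx // pool_stripe  # spatial size of Y
    # A fresh zero tensor that does not require grad; writing grad-tracking
    # values into it in place makes Y part of the autograd graph.
    Y = X.new_zeros(bs, cx, dy, dy)
    for yn in range(bs):
        for yc in range(cx):
            for yh in range(dy):
                for yw in range(dy):
                    h, w = yh * pool_stripe, yw * pool_stripe
                    # Keep the 0-d tensor from .max(); no .item(), so the
                    # gradient can flow back to the winning element of X.
                    Y[yn, yc, yh, yw] = X[yn, yc, h:h + pool_dim, w:w + pool_dim].max()
    return Y


X = torch.randn(2, 2, 8, 8, requires_grad=True)
Y = maxpool_fp(X)
Y.sum().backward()

# Cross-check against PyTorch's built-in max pooling.
X_ref = X.detach().clone().requires_grad_(True)
Y_ref = F.max_pool2d(X_ref, kernel_size=2, stride=2)
Y_ref.sum().backward()
print(torch.allclose(Y, Y_ref), torch.equal(X.grad, X_ref.grad))
```

With `S = Y.sum()`, each output element passes a gradient of 1 to the single input element that won its window, so X.grad should contain exactly one 1 per 2x2 window and 0 elsewhere. The Python loops are only for learning; `F.max_pool2d` computes the same result far faster.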