Hi, I am trying to implement maxpool fonction from scatch (for fun) and use backward() on it.
In my implementation below, the output Y of maxpool is correct (I verified it).
But the gradient of the input at a zero tensor, which is wrong.
code :
import torch
def maxpool_fp(X):
pool_dim = 2
pool_stripe = 2
bs, cx, dx, _ = list(X.size()) # batch size ; nb of channel of X ; dimension of X
dy = int(X.shape[2] / pool_stripe) # dimension of Y
Y = X[:, :, :dy, :dy] * 1 # *1 to avoid: RuntimeError: leaf variable has been moved into the graph interior
for yn in range(bs):
for yc, x in enumerate(X[yn]):
for yh, h in enumerate(range(0, dx, pool_stripe)):
for yw, w in enumerate(range(0, dx, pool_stripe)):
Y[yn, yc, yh, yw] = torch.max(x[h:h + pool_dim, w:w + pool_dim]).item()
return Y
X = torch.randn(2, 2, 8, 8, requires_grad=True)
Y = maxpool_fp(X)
S = torch.sum(Y)
S.backward()
print("S =", S) # ==> Correct
print("X.grad\n", X.grad) # ==> zero tensor !!!!!!!!!!