Hi, I’m trying to customize the MaxPool2d autograd using the torch.autograd.Function module.
I want to change the backward output: for my gradient, I want to keep all the indices of the maximum and take the mean over those values, not only the first index. At the end, I have to compute the product of my_grad and grad_outputs.
This is my code :
class myMaxpool(torch.autograd.Function):
    """Max-pooling whose backward pass distributes the gradient evenly over
    ALL positions that attain the window maximum (mean over ties), instead of
    routing it to only the first maximal index.

    forward(ctx, input, kernel_size, stride) -> pooled output of shape
        (N, C, 1 + (H - ph) // stride, 1 + (W - pw) // stride)
    backward(ctx, grad_outputs) -> (grad_input, None, None)
    """

    @staticmethod
    def forward(ctx, input, kernel_size, stride):
        pool_h, pool_w = kernel_size
        N, C, H, W = input.size()
        out_H = 1 + (H - pool_h) // stride
        out_W = 1 + (W - pool_w) // stride

        out = torch.zeros(N, C, out_H, out_W,
                          device=input.device, dtype=input.dtype)
        # Per-window gradient weights: weight[n, c, i, j] is a (pool_h, pool_w)
        # map equal to 1/count at every tied maximum and 0 elsewhere. Keeping
        # one slot PER WINDOW (rather than a single H x W map, as the original
        # did) stays correct even when stride < kernel_size and windows
        # overlap — the original overwrote earlier windows' masks.
        weight = torch.zeros(N, C, out_H, out_W, pool_h, pool_w,
                             device=input.device, dtype=input.dtype)

        # Loop only over output positions; batch and channel dims are handled
        # by tensor ops, so this is N*C times fewer Python iterations.
        for i in range(out_H):
            for j in range(out_W):
                win = input[:, :, i * stride:i * stride + pool_h,
                                  j * stride:j * stride + pool_w]
                m = win.amax(dim=(-2, -1))          # (N, C) window maxima
                out[:, :, i, j] = m
                mask = (win == m[..., None, None]).to(input.dtype)
                # Divide by the tie count so the tied maxima share the
                # gradient equally (the "mean over all max indices").
                weight[:, :, i, j] = mask / mask.sum(dim=(-2, -1), keepdim=True)

        # Only the weights are needed in backward; saving `input` or a
        # flattened index tensor (as before) was unnecessary.
        ctx.save_for_backward(weight)
        ctx.pool_h, ctx.pool_w, ctx.stride = pool_h, pool_w, stride
        ctx.input_shape = input.shape
        return out

    @staticmethod
    def backward(ctx, grad_outputs):
        # No max recomputation here: the saved per-window weights already
        # encode where (and in what proportion) the gradient flows.
        (weight,) = ctx.saved_tensors
        pool_h, pool_w, stride = ctx.pool_h, ctx.pool_w, ctx.stride
        N, C, out_H, out_W = grad_outputs.size()

        # Match device/dtype of the incoming gradient (the original used a
        # bare torch.zeros, which breaks on GPU inputs).
        grad_input = torch.zeros(ctx.input_shape,
                                 device=grad_outputs.device,
                                 dtype=grad_outputs.dtype)
        for i in range(out_H):
            for j in range(out_W):
                # += (not =) so contributions from overlapping windows
                # accumulate instead of clobbering each other.
                grad_input[:, :, i * stride:i * stride + pool_h,
                                 j * stride:j * stride + pool_w] += \
                    weight[:, :, i, j] * grad_outputs[:, :, i, j, None, None]

        # kernel_size and stride are non-tensor arguments: no gradient.
        return grad_input, None, None
Is it possible to do the same without re-computing all the loops in the backward function?
Thanks.