[Autodiff] Custom autograd Maxpooling2d

Hi, I’m trying to custom the Maxpool2d autograd using module torch.autograd.Function.

I want to change the backward output. For my gradient and I want to keep all the maximum indexes and take the mean of this values, not only the first one index. At the end, I have to compute the product with my_grad and the grad_outputs.

This is my code :

class myMaxpool(torch.autograd.Function):
This autograd function implements a MaxPooling function.

def forward(ctx, input, kernel_size, stride):

    out = None
    pool_height, pool_width = kernel_size
    ctx.pool_height, ctx.pool_width = pool_height, pool_width

    ctx.stride = stride

    N, C, H, W = input.size()
    indexes = torch.zeros(N, C, H, W, device = input.device)
    grad_input = torch.zeros(N, C, H, W, device = input.device, dtype = input.dtype)

    out_H = 1 + (H-pool_height)//stride
    out_W = 1 + (W-pool_width)//stride
    out = torch.zeros(N, C, out_H, out_W, device = input.device, dtype = input.dtype)

    for n in range(N):
        for c in range(C):
            for i in range(out_H):
                for j in range(out_W):
                    out[n, c, i, j] = torch.max(input[n, c, i*stride:i*stride+pool_height, j*stride:j*stride+pool_width])
                    tmp = ((input[n, c, i*stride:i*stride+pool_height, j*stride:j*stride+pool_width] == out[n, c, i, j])*1)
                    indexes[n, c, i*stride:i*stride+pool_height, j*stride:j*stride+pool_width] = tmp
                    grad_input[n, c, i*stride:i*stride+pool_height, j*stride:j*stride+pool_width] = torch.div(indexes[n, c, i*stride:i*stride+pool_height, j*stride:j*stride+pool_width] ,indexes[n, c, i*stride:i*stride+pool_height, j*stride:j*stride+pool_width].sum())
    indexes =  indexes.view(-1)
    return out

def backward(ctx, grad_outputs):

    input,indexes,grad_input = ctx.saved_tensors

    grad_x, grad_kernel_size, grad_stride = grad_input, None, None
    N, C, H, W = grad_outputs.size()
    grad_x = torch.zeros(grad_x.size())

    for n in range(N):
        for c in range(C):
            for i in range(H):
                for j in range(W):
                    grad_x[n, c, i*ctx.stride:i*ctx.stride+ctx.pool_height, j*ctx.stride:j*ctx.stride+ctx.pool_width] = grad_input[n, c, i*ctx.stride:i*ctx.stride+ctx.pool_height, j*ctx.stride:j*ctx.stride+ctx.pool_width]*grad_outputs[n, c, i, j]

    return grad_x, grad_kernel_size, grad_stride

It’s possible to do the same without re-compute all the loops in the backward function ?