Why use avgpool2d and avgpool3d in local_response_norm?

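The following is the formula for local response normalization (across channels):

$$b_c = a_c \left(k + \alpha \sum_{c'=\max(0,\, c-n/2)}^{\min(N-1,\, c+n/2)} a_{c'}^2\right)^{-\beta}$$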

When I looked at the implementation, I saw that an average pool is applied after the input tensor is squared elementwise. I don’t understand why the squared activations are averaged when the equation involves only a sum of the squared activations.

def local_response_norm(input, size, alpha=1e-4, beta=0.75, k=1.):
    # type: (Tensor, int, float, float, float) -> Tensor
    r"""Applies local response normalization over an input signal composed of
    several input planes, where channels occupy the second dimension.
    Applies normalization across channels.
    See :class:`~torch.nn.LocalResponseNorm` for details.
    """
    if not torch.jit.is_scripting():
        if type(input) is not Tensor and has_torch_function((input,)):
            return handle_torch_function(
                local_response_norm, (input,), input, size, alpha=alpha, beta=beta, k=k)
    dim = input.dim()
    if dim < 3:
        raise ValueError('Expected 3D or higher dimensionality \
                         input (got {} dimensions)'.format(dim))
    div = input.mul(input).unsqueeze(1)
    if dim == 3:
        div = pad(div, (0, 0, size // 2, (size - 1) // 2))
        div = avg_pool2d(div, (size, 1), stride=1).squeeze(1)
    else:
        sizes = input.size()
        div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
        div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
        div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1)
        div = div.view(sizes)
    div = div.mul(alpha).add(k).pow(beta)
    return input / div

According to the PyTorch docs here, alpha is divided by n (the kernel size), so the use of avg_pool makes sense, I guess.
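To see why the average pool lines up with a formula that divides alpha by n, here is a minimal pure-Python sketch. It is not the library code: the helper names `lrn_div_avg` and `lrn_div_sum` are made up for illustration, and the input is a plain list holding the activations of the channels at one spatial position. It mimics the pad, average, `mul(alpha).add(k).pow(beta)` steps from the snippet above and compares them with a windowed sum scaled by alpha / n.

```python
import math


def _padded_squares(x, size):
    # Square each activation, then zero-pad along the channel axis the same
    # way the snippet does: size // 2 in front, (size - 1) // 2 behind.
    squared = [v * v for v in x]
    return [0.0] * (size // 2) + squared + [0.0] * ((size - 1) // 2)


def lrn_div_avg(x, size, alpha=1e-4, beta=0.75, k=1.0):
    # Hypothetical helper: the divisor computed via a windowed *average*.
    # Padding is explicit, so every window holds exactly `size` values.
    padded = _padded_squares(x, size)
    out = []
    for c in range(len(x)):
        window = padded[c:c + size]
        avg = sum(window) / size
        out.append((avg * alpha + k) ** beta)
    return out


def lrn_div_sum(x, size, alpha=1e-4, beta=0.75, k=1.0):
    # Hypothetical helper: the divisor computed via a windowed *sum*,
    # with alpha divided by the window size n.
    padded = _padded_squares(x, size)
    out = []
    for c in range(len(x)):
        window = padded[c:c + size]
        out.append((k + (alpha / size) * sum(window)) ** beta)
    return out


if __name__ == "__main__":
    channels = [1.0, -2.0, 3.0, 0.5, 4.0]  # made-up activations
    via_avg = lrn_div_avg(channels, size=3)
    via_sum = lrn_div_sum(channels, size=3)
    # The two divisors agree up to floating-point rounding.
    print(all(math.isclose(p, q) for p, q in zip(via_avg, via_sum)))
```

Because the zero padding is added before pooling, each window always contains `size` entries, so the average is exactly the sum divided by n, including at the first and last channels.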


I see. Yeah, that makes sense. I guess I was confused because I looked at the code in the context of the equation I gave in the original post. Maybe alpha / n can be treated as a single new constant alpha_prime (equal to alpha in the equation in my question), and that’s why torch uses that formula? Either way, thanks for resolving my doubt!
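Written out, and assuming the notation of the PyTorch docs (a_c is the input at channel c, N the number of channels, n the window size), the substitution would look like this:

$$b_c = a_c \left(k + \frac{\alpha}{n} \sum_{c'=\max(0,\, c-n/2)}^{\min(N-1,\, c+n/2)} a_{c'}^2\right)^{-\beta} = a_c \left(k + \alpha' \sum_{c'=\max(0,\, c-n/2)}^{\min(N-1,\, c+n/2)} a_{c'}^2\right)^{-\beta}, \qquad \alpha' = \frac{\alpha}{n}$$

So the two forms differ only in how the constant is parameterized: an alpha taken from the sum-only equation would have to be multiplied by n before being passed to the torch function.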