Normalize gradient after unfolding

I’m extracting patches out of a Tensor with the unfold method. Consider the following snippet:

import torch


numel = 10
length = 3
stride = 2

input = torch.ones(numel,).requires_grad_(True)
target = torch.zeros(numel,)

input_patches = input.unfold(0, length, stride)
target_patches = target.unfold(0, length, stride)
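To see which elements end up in which patch, one option is to unfold an index tensor with the same parameters. This is just an illustrative check. The variable names `indices` and `num_patches` are not part of the original code.

```python
import torch

numel, length, stride = 10, 3, 2

# Unfold an index tensor to see which positions each patch covers
indices = torch.arange(numel).unfold(0, length, stride)
print(indices)
# tensor([[0, 1, 2],
#         [2, 3, 4],
#         [4, 5, 6],
#         [6, 7, 8]])

# Number of patches: floor((numel - length) / stride) + 1
num_patches = (numel - length) // stride + 1
print(num_patches)  # 4
```

Positions 2, 4 and 6 appear in two patches, and position 9 appears in none. That is why their gradients differ below.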

In subsequent processing I compute a loss between input_patches and target_patches and backpropagate it:

loss = torch.sum((input_patches - target_patches) ** 2.0)
loss.backward()
print(input.grad)

As a result, elements of input that are included in multiple patches receive a larger gradient than the others:

tensor([2., 2., 4., 2., 4., 2., 4., 2., 2., 0.])
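This pattern follows from the chain rule. Write $x_i$ and $t_i$ for the input and target elements, and $c_i$ for the number of patches that contain element $i$ (notation introduced here for illustration). Each patch containing $i$ contributes one copy of the squared error, so

$$
\mathcal{L} = \sum_{p} \sum_{j \in p} (x_j - t_j)^2,
\qquad
\frac{\partial \mathcal{L}}{\partial x_i} = \sum_{p \ni i} 2\,(x_i - t_i) = 2\,c_i\,(x_i - t_i).
$$

With $x_i = 1$ and $t_i = 0$, this gives $2c_i$ with $c = (1, 1, 2, 1, 2, 1, 2, 1, 1, 0)$. Dividing by $\max(c_i, 1)$ leaves $2(x_i - t_i)$ for every covered element.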

In my application this is unwanted. I can count how many patches each element appears in and divide the gradient by that count:

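counts = torch.zeros(numel)
count_patch = torch.ones(length)

# add 1 to every element covered by the patch starting at idx
for idx in range(0, numel - length + 1, stride):
    counts[idx:idx + length].add_(count_patch)

print(counts)
print(input.grad / torch.clamp(counts, min=1.0))  # clamp to avoid zero division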

This prints the counts (first line) and the resulting "stable" gradients (second line)

tensor([1., 1., 2., 1., 2., 1., 2., 1., 1., 0.])
tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 0.])

which is what I need.
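In 1D, the Python loop over patches can be avoided by unfolding an index tensor and counting how often each index occurs with `torch.bincount`. This is a minimal sketch assuming the same `numel`, `length` and `stride` as above; `patch_indices` is an illustrative name.

```python
import torch

numel, length, stride = 10, 3, 2

# Unfolding an index tensor yields the element indices contained in each patch
patch_indices = torch.arange(numel).unfold(0, length, stride)
# Count how often every index occurs; minlength keeps uncovered trailing elements as 0
counts = torch.bincount(patch_indices.flatten(), minlength=numel).float()
print(counts)  # tensor([1., 1., 2., 1., 2., 1., 2., 1., 1., 0.])
```

The result matches the loop-based counts and can be clamped and used for division in the same way.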

  1. While this example is 1D, I actually need to handle 2D data. I have not tested it yet, but I expect the required nested for loop to be slow. Is there a faster, more efficient way to compute the counts with torch built-ins?
  2. My idea for implementing this properly is to define an autograd.Function (below) that acts as the identity in the forward pass and divides the gradient by the counts in the backward pass, and to apply it to the input before extracting the patches. Is this a proper way to do it?
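class _PatchNormalizer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, length, stride):
        counts = _PatchNormalizer.get_patch_counts(input, length, stride)
        ctx.save_for_backward(counts)  # needed in backward
        return input.view_as(input)  # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        counts, = ctx.saved_tensors
        grad_input = grad_output / counts
        return grad_input, None, None  # no gradients for length and stride

    @staticmethod
    def get_patch_counts(input, length, stride):
        counts = torch.zeros_like(input)
        count_patch = torch.ones(length, dtype=input.dtype, device=input.device)

        for idx in range(0, input.size(0) - length + 1, stride):
            counts[idx:idx + length].add_(count_patch)

        return torch.clamp(counts, min=1.0)  # clamp to avoid zero division


# usage: input_patches = _PatchNormalizer.apply(input, length, stride).unfold(0, length, stride)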
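For the 2D case, one loop-free option is to let `torch.nn.functional.unfold` and `torch.nn.functional.fold` do the counting. Folding back a tensor of unfolded ones sums the overlapping blocks, so each pixel ends up holding the number of patches that cover it. The sketch below is a suggestion, not the approach from the question. It assumes a square kernel, the same stride in both dimensions, and a single-channel image. `patch_counts_2d` is a hypothetical helper. It also swaps the autograd.Function for a `Tensor.register_hook` on the input, a different technique that rescales the gradient before it is accumulated into `.grad`.

```python
import torch
import torch.nn.functional as F


def patch_counts_2d(height, width, kernel_size, stride):
    """Hypothetical helper: (height, width) tensor of how many patches cover each pixel."""
    ones = torch.ones(1, 1, height, width)
    # F.unfold extracts sliding blocks: shape (1, kernel_size**2, num_patches)
    cols = F.unfold(ones, kernel_size=kernel_size, stride=stride)
    # F.fold sums overlapping blocks back into place, so each pixel holds its count
    counts = F.fold(cols, output_size=(height, width), kernel_size=kernel_size, stride=stride)
    return counts[0, 0]


height, width, length, stride = 5, 5, 3, 2
counts = patch_counts_2d(height, width, length, stride)

image = torch.ones(height, width, requires_grad=True)
# Divide the incoming gradient by the counts (clamped to avoid division by zero)
image.register_hook(lambda grad: grad / counts.clamp(min=1.0))

# Extract 2D patches by unfolding along both dimensions
patches = image.unfold(0, length, stride).unfold(1, length, stride)
loss = (patches ** 2).sum()
loss.backward()
print(image.grad)  # every pixel is covered here, so all entries are 2
```

The counts only depend on the shapes, kernel size and stride, so they can be computed once and reused for every batch.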