I’m extracting patches out of a `Tensor` with the `unfold` method. Consider the following snippet:

```
import torch
torch.manual_seed(0)
numel = 10
length = 3
stride = 2
input = torch.ones(numel).requires_grad_(True)
target = torch.zeros(numel)
# extract overlapping patches of size `length`, one every `stride` elements
input_patches = input.unfold(0, length, stride)
target_patches = target.unfold(0, length, stride)
```
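With these parameters `unfold` yields four overlapping patches (start indices 0, 2, 4, 6), so `input_patches` has shape `(4, 3)`:

```
print(input_patches.shape)  # torch.Size([4, 3])
```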

In the subsequent processing I calculate some loss between the `input_patches` and `target_patches`:

```
loss = torch.sum((input_patches - target_patches) ** 2.0)
loss.backward()
print(input.grad)
```

With this, some elements of `input` receive a higher gradient than others, since they are included in multiple patches: each use of an element contributes d((x − t)²)/dx = 2 to its gradient, and with `length = 3` and `stride = 2` the elements at indices 2, 4, and 6 are each covered by two patches (while index 9 is never used):

```
tensor([2., 2., 4., 2., 4., 2., 4., 2., 2., 0.])
```

Within my application this is unwanted. I can calculate the number of times each element was used as follows:

```
counts = torch.zeros(numel)
count_patch = torch.ones(length)
for idx in range(0, numel - length + 1, stride):
    counts[idx:idx + length].add_(count_patch)
print(counts)
print(input.grad / torch.clamp(counts, min=1.0))  # clamp to avoid division by zero
```

This results in ‘stable’ gradients

```
tensor([1., 1., 2., 1., 2., 1., 2., 1., 1., 0.])
tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 0.])
```

which is what I need.
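As a sanity check, the same counts fall out of autograd itself when a probe tensor of ones (the name `probe` is just for illustration) is pushed through the same `unfold`:

```
probe = torch.ones(numel, requires_grad=True)
probe.unfold(0, length, stride).sum().backward()
print(probe.grad)  # tensor([1., 1., 2., 1., 2., 1., 2., 1., 1., 0.])
```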

- While this example is in 1D, I actually need 2D data. I have not tested it yet, but I imagine the needed double `for` loop is costly. Is there a better / faster / more efficient way to achieve the same result with builtins from `torch`? (A loop-free idea I have in mind is sketched at the end of this post.)
- My idea to implement this properly is to define an `autograd.Function`, as can be seen below, and apply this to the input before I extract the patches. Do you think this is a proper way to do this?

```
class _PatchNormalizer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, length, stride):
        # pass the input through unchanged; only the backward pass is modified
        counts = _PatchNormalizer.get_patch_counts(input, length, stride)
        ctx.save_for_backward(counts)
        return input

    @staticmethod
    def backward(ctx, grad_output):
        # divide each element's gradient by the number of patches it appears in
        counts, = ctx.saved_tensors
        grad_input = grad_output / counts
        return grad_input, None, None

    @staticmethod
    def get_patch_counts(input, length, stride):
        counts = torch.zeros_like(input)
        count_patch = torch.ones(length, dtype=input.dtype, device=input.device)
        for idx in range(0, input.size(0) - length + 1, stride):
            counts[idx:idx + length].add_(count_patch)
        # clamp so that elements covered by no patch do not cause a division by zero
        return torch.clamp(counts, min=1.0)
```
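For completeness, this is how I would wire it in (the `normalized` name is just for illustration, not tested):

```
normalized = _PatchNormalizer.apply(input, length, stride)
input_patches = normalized.unfold(0, length, stride)
```

As for the loop-free idea mentioned in the first question: since `torch.nn.functional.fold` sums overlapping blocks, folding a tensor of ones shaped like the unfolded patches should give the per-element usage counts directly in 2D. An untested sketch, assuming the data is laid out as `(N, C, H, W)` and the hypothetical size `h, w`:

```
import torch.nn.functional as F

h, w = 10, 10  # hypothetical 2D size
ones = torch.ones(1, 1, h, w)
# unfold to (1, length * length, n_patches), then fold back, summing overlaps
patches = F.unfold(ones, kernel_size=length, stride=stride)
counts = F.fold(patches, output_size=(h, w), kernel_size=length, stride=stride)
print(counts.squeeze())  # how often each pixel is covered by a patch
```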