I’m extracting patches out of a Tensor with the `unfold` method. Consider the following snippet:
```python
import torch

torch.manual_seed(0)

numel = 10
length = 3
stride = 2

input = torch.ones(numel).requires_grad_(True)
target = torch.zeros(numel)

input_patches = input.unfold(0, length, stride)
target_patches = target.unfold(0, length, stride)
```
In the subsequent processing I calculate some loss between `input_patches` and `target_patches`:
```python
loss = torch.sum((input_patches - target_patches) ** 2.0)
loss.backward()
print(input.grad)
```
With this, some elements of `input` receive a higher gradient than others, since they are included in multiple patches:

```
tensor([2., 2., 4., 2., 4., 2., 4., 2., 2., 0.])
```
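For reference, the overlap is easy to see by listing the index range each patch covers; with `length = 3` and `stride = 2`, indices 2, 4, and 6 fall into two patches each, while index 9 is never used:

```python
for i in range(input_patches.size(0)):
    start = i * stride
    print(f"patch {i}: indices {start}..{start + length - 1}")
# patch 0: indices 0..2
# patch 1: indices 2..4
# patch 2: indices 4..6
# patch 3: indices 6..8
```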
Within my application this is unwanted. I can calculate the number of times each element was used as follows:
```python
counts = torch.zeros(numel)
count_patch = torch.ones(length)
for idx in range(0, numel - length + 1, stride):
    counts[idx:idx + length].add_(count_patch)
print(counts)
print(input.grad / torch.clamp(counts, min=1.0))  # clamp to avoid zero division
```
This results in ‘stable’ gradients,

```
tensor([1., 1., 2., 1., 2., 1., 2., 1., 1., 0.])
tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 0.])
```

which is what I need.
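Incidentally, the same counts can also be obtained without an explicit loop by backpropagating through `unfold` itself on a ones tensor, since its backward accumulates a 1 for every patch an element appears in. A minimal sketch (I have not benchmarked it, and the clamp is still needed because of the trailing zero):

```python
ones = torch.ones(numel, requires_grad=True)
ones.unfold(0, length, stride).sum().backward()
print(ones.grad)  # tensor([1., 1., 2., 1., 2., 1., 2., 1., 1., 0.])
```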
- While this example is in 1D, I actually need this for 2D data. I have not tested it yet, but I imagine the double `for` loop that would be needed is costly. Is there a better / faster / more efficient way to achieve the same result with builtins from `torch`? (A sketch of what I have in mind follows this list.)
- My idea to implement this properly is to define an `autograd.Function`, as can be seen at the end of this post, and apply it to the input before I extract the patches. Do you think this is a proper way to do this?
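For the first question, this is a minimal, untested sketch of how I imagine getting the 2D counts from builtins. It assumes a single-channel image and square patches; `H`, `W`, `kernel`, and `stride2d` are placeholder names, and it relies on `F.fold` summing overlapping patches back together:

```python
import torch
import torch.nn.functional as F

H, W, kernel, stride2d = 8, 8, 3, 2  # hypothetical sizes

# Folding the unfolded ones tensor sums the overlaps, so each pixel
# of `counts` holds the number of patches it appears in.
ones = torch.ones(1, 1, H, W)
patches = F.unfold(ones, kernel_size=kernel, stride=stride2d)
counts = F.fold(patches, output_size=(H, W), kernel_size=kernel, stride=stride2d)
counts = counts.clamp(min=1.0)  # avoid division by zero for unused pixels
```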
```python
class _PatchNormalizer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, length, stride):
        # Precompute how often each element will appear in a patch.
        counts = _PatchNormalizer.get_patch_counts(input, length, stride)
        ctx.save_for_backward(counts)
        # Return a view so autograd records this Function in the graph
        # even though the values pass through unchanged.
        return input.view_as(input)

    @staticmethod
    def backward(ctx, grad_output):
        counts, = ctx.saved_tensors
        # Undo the accumulation over overlapping patches.
        grad_input = grad_output / counts
        # length and stride are non-tensor arguments, so no gradient.
        return grad_input, None, None

    @staticmethod
    def get_patch_counts(input, length, stride):
        counts = torch.zeros_like(input)
        count_patch = torch.ones(length, dtype=input.dtype, device=input.device)
        for idx in range(0, input.size(0) - length + 1, stride):
            counts[idx:idx + length].add_(count_patch)
        # Clamp so elements outside every patch do not divide by zero.
        return torch.clamp(counts, min=1.0)
```
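For completeness, this is how I intend to use it (same variables as the first snippet, with a fresh `input` so no stale gradients accumulate); based on the ‘stable’ gradients above I would expect an all-`2.` gradient except for the unused last element:

```python
input = torch.ones(numel).requires_grad_(True)
normalized = _PatchNormalizer.apply(input, length, stride)
input_patches = normalized.unfold(0, length, stride)

loss = torch.sum((input_patches - target_patches) ** 2.0)
loss.backward()
print(input.grad)  # tensor([2., 2., 2., 2., 2., 2., 2., 2., 2., 0.])
```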