Elegant way to average values from overlapping areas of transpose convolution?

Hi everyone,

I need to reconstruct an image tensor from its patches and a one-hot tensor by transpose convolution.

Let’s say the original image tensor, I, has size 1x1x4x4:
(0 ,0 ,.,.) =
0 1 2 3
4 5 6 7
8 9 10 11
12 13 14 15

Its patches P (size 3x3, stride 1):
(0 ,0 ,.,.) =
0 1 2
4 5 6
8 9 10

(1 ,0 ,.,.) =
1 2 3
5 6 7
9 10 11

(2 ,0 ,.,.) =
4 5 6
8 9 10
12 13 14

(3 ,0 ,.,.) =
5 6 7
9 10 11
13 14 15

And I have a one-hot tensor K corresponding to P:
(0 ,0 ,.,.) =
1 0
0 0

(0 ,1 ,.,.) =
0 1
0 0

(0 ,2 ,.,.) =
0 0
1 0

(0 ,3 ,.,.) =
0 0
0 1

Transpose-convolving K with P (as the filter) reconstructs I’, but it also sums the values from overlapping patches:
(0 ,0 ,.,.) =
0 2 4 3
8 20 24 14
16 36 40 22
12 26 28 15
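
For reference, here is a minimal sketch that reproduces this summed output with F.conv_transpose2d (I, P and K are as listed above; I_sum is just my name for the result):

import torch
import torch.nn.functional as F

# The 1x1x4x4 image from above.
I = torch.arange(16.).reshape(1, 1, 4, 4)

# P: the four 3x3 patches, laid out as a (4, 1, 3, 3) weight tensor.
P = torch.stack([I[0, 0, i:i+3, j:j+3]
                 for i in range(2) for j in range(2)]).unsqueeze(1)

# K: the one-hot locations, shape (1, 4, 2, 2); channel c is 1 at patch c's position.
K = torch.eye(4).reshape(1, 4, 2, 2)

# The transpose convolution pastes each patch at its location and sums the overlaps.
I_sum = F.conv_transpose2d(K, P, stride=1)   # gives the summed I’ shown above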

To work around this, I record how many patches overlap each pixel in a tensor O of the same size as I while making the patches (using the helper function below), and at the end perform an elementwise division I’ = I’/O.

import torch
from torch.autograd import Variable


def compute_overlaps(tensor, patch_size=(3, 3), patch_stride=(1, 1)):
    # Count, for every pixel of `tensor`, how many patches cover it.
    n, c, h, w = tensor.size()
    px, py = patch_size
    sx, sy = patch_stride
    nx = ((w-px)//sx)+1   # number of patches along the width
    ny = ((h-py)//sy)+1   # number of patches along the height

    overlaps = torch.zeros(tensor.size()).type_as(tensor.data)
    for i in range(ny):
        for j in range(nx):
            # Each patch contributes one count to every pixel it covers.
            overlaps[:, :, i*sy:i*sy+py, j*sx:j*sx+px] += 1
    overlaps = Variable(overlaps)
    return overlaps
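
For the toy example above, the helper and the division are used like this (a sketch; I and I_sum refer to the earlier snippet, and on PyTorch >= 0.4 Variable is a deprecated no-op wrapper, so passing a plain tensor also works):

O = compute_overlaps(I)   # per-pixel overlap counts:
# 1 2 2 1
# 2 4 4 2
# 2 4 4 2
# 1 2 2 1

I_avg = I_sum / O         # elementwise division recovers the original I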

But I am wondering: can the averaging be done in a more efficient way?

Look at torch.unfold (it is an autograd-supported operation).
You can extract patches from an image with it, and then convolve on the extracted patches.
In the backward pass, it automatically computes the correct gradients.

For example see https://github.com/pytorch/pytorch/pull/1523#issue-227526015
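
For the 4x4 example above, extracting the 3x3 patches with Tensor.unfold might look like this (a sketch, not the code from the linked PR; the names are mine):

import torch

I = torch.arange(16.).reshape(1, 1, 4, 4)

# Unfold height (dim 2) and width (dim 3) into 3x3 windows with stride 1:
# the shape becomes (1, 1, 2, 2, 3, 3), i.e. a 2x2 grid of 3x3 patches.
patches = I.unfold(2, 3, 1).unfold(3, 3, 1)

# Flatten the grid into the (4, 1, 3, 3) layout used for P above.
# Gradients flow through unfold, so this stays autograd-friendly.
P = patches.contiguous().view(-1, 1, 3, 3)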