Average pool an image without decreasing the size?

For a given n x n tensor and window size k, is there a quick way to set every k x k window to the average value of that window in the original tensor? This would basically be average pooling without resizing; the effect would just be to blur the image.

Thanks

Hi,

If you use average pooling with a stride of 1, a kernel size of k (k must be odd), and a padding of floor(k/2), then you will maintain the image size.
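For example, a minimal sketch of that recipe (passing count_include_pad=False keeps the border averages from being dragged toward zero by the padding):

```python
import torch
import torch.nn.functional as F

k = 3  # kernel size, must be odd so the padding is symmetric
x = torch.arange(16.).view(1, 1, 4, 4)  # (N, C, H, W)

# stride 1 + padding k // 2 keeps the spatial size unchanged
blurred = F.avg_pool2d(x, kernel_size=k, stride=1, padding=k // 2,
                       count_include_pad=False)
print(blurred.shape)  # torch.Size([1, 1, 4, 4])
```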

I wasn’t clear enough in my first post. What I mean is the following:

Given a 4 x 4 image [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] and a window size of 2, what I’d like to get is the following:

[[2.5, 2.5, 4.5, 4.5], [2.5, 2.5, 4.5, 4.5], [10.5, 10.5, 12.5, 12.5], [10.5, 10.5, 12.5, 12.5]]

Thanks

Hi,

I’m not sure there is a very clean way to do this, but the following will work:

import torch

kern_size = 2

# Add a leading dimension because avg_pool2d expects a (C, H, W) or (N, C, H, W) input
inp = torch.arange(16).float().view(1, 4, 4)
pooled = torch.nn.functional.avg_pool2d(inp, kern_size)
# remove 0th dimension
pooled.squeeze_(0)
rep_col = []
for c in range(pooled.size(1)):
  col = pooled.select(1, c)
  rep_col += [col] * kern_size
repeated = torch.stack(rep_col, 1)
rep_row = []
for r in range(repeated.size(0)):
  row = repeated.select(0, r)
  rep_row += [row] * kern_size
output = torch.stack(rep_row, 0)

print(output)
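On newer PyTorch versions (an assumption about your setup), the same row/column repetition can be written without the Python loops using repeat_interleave:

```python
import torch
import torch.nn.functional as F

kern_size = 2
inp = torch.arange(16).float().view(1, 4, 4)  # add a channel dim for avg_pool2d
pooled = F.avg_pool2d(inp, kern_size).squeeze(0)

# repeat every row, then every column, kern_size times
output = pooled.repeat_interleave(kern_size, dim=0).repeat_interleave(kern_size, dim=1)
print(output)
```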

Thanks a lot for this

Another solution would be to average pool, and then use nn.UpsamplingNearest2d()
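A minimal sketch of that idea (note that nn.UpsamplingNearest2d still works but has since been superseded by nn.Upsample / F.interpolate with mode='nearest'):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

k = 2
x = torch.arange(16.).view(1, 1, 4, 4)

pooled = F.avg_pool2d(x, k)                           # (1, 1, 2, 2)
out = nn.UpsamplingNearest2d(scale_factor=k)(pooled)  # back to (1, 1, 4, 4)
print(out.squeeze())
```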

This upsampling is a Kronecker product with a matrix of ones. Unfortunately, torch does not have an implementation of it. If you are not dependent on the GPU, it might be faster to do it with NumPy:

import numpy as np
import torch
from torch.nn.functional import avg_pool2d

image = torch.arange(16, dtype=torch.float).view(4, 4)

kernel_size = 2
stride = 2

image_pooled = image.view(1, 1, *image.size())
image_pooled = avg_pool2d(image_pooled, kernel_size=kernel_size, stride=stride)
image_pooled = image_pooled.squeeze()

image_blured = image_pooled.numpy()
# repeat each pooled value as a kernel_size x kernel_size block
image_blured = np.kron(image_blured, np.ones((kernel_size, kernel_size)))
image_blured = torch.from_numpy(image_blured)

print(image_blured)

Also, you cannot use this if you need gradients! Converting to NumPy and back detaches the result from the autograd graph.

You are right, I totally forgot about that.
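For completeness, a small sketch showing that a pure-torch pipeline (avg_pool2d followed by nearest-neighbor interpolation) does stay differentiable:

```python
import torch
import torch.nn.functional as F

x = torch.arange(16., requires_grad=True)
img = x.view(1, 1, 4, 4)

blurred = F.interpolate(F.avg_pool2d(img, 2), scale_factor=2, mode="nearest")
blurred.sum().backward()  # gradients flow back to x
print(x.grad)
```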