What is AdaptiveAvgPool2d?

The AdaptiveAvgPool2d layers confuse me a lot.
Is there any math formula explaining it?

3 Likes

Well, the specified output size is the output size, as in the documentation.

In more detail:
What happens is that the pooling stencil size (aka kernel size) is determined to be (input_size+target_size-1) // target_size, i.e. input_size / target_size rounded up. Then the positions at which to apply the stencil are computed as rounded equidistant points between 0 and input_size - stencil_size.
Let’s have a 1d example:
Say you have an input size of 14 and a target size of 4. Then the stencil size is 4.
The four equidistant points would be 0, 3.3333, 6.6666, 10 and get rounded to 0, 3, 7, 10. And so the four items would be the means of the slices 0:4, 3:7, 7:11, 10:14 (in Python manner, so including the lower bound, excluding the upper bound). You see that the first two and the last two slices overlap by one. Something like this, occasional overlaps of 1, will generally be the case when the input size is not divisible by the target size.
For experimentation, you could use arange and backward to see what happens. In the above toy example:

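import torch

a = torch.arange(0, 14., requires_grad=True)  # input 0, 1, ..., 13
b = torch.nn.functional.adaptive_avg_pool1d(a[None, None], 4)
# weight output i by i+1 so each output's contribution to a.grad is distinguishable
b.backward(torch.arange(1., 1 + b.size(-1))[None, None])
print(b, a.grad)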

Then b is 1.5, 4.5, 8.5, 11.5 just as you would get from slicing as above and taking the mean.
The gradient a.grad shows the “receptive field of each output”:
0.2500, 0.2500, 0.2500, 0.7500, 0.5000, 0.5000, 0.5000, 0.7500, 0.7500, 0.7500, 1.7500, 1.0000, 1.0000, 1.0000
Again, you see the overlaps at items 3 and 10.
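
To make the recipe above easy to poke at without PyTorch, here is a minimal plain-Python sketch of it. The helper name rounded_stencil_slices is made up for illustration, and it follows the rounded-equidistant description in this post, which is not necessarily what PyTorch does internally. Python's round() also sends exact .5 ties to the even neighbour, which does not matter for the 14 -> 4 example.

```python
def rounded_stencil_slices(input_size, target_size):
    """Slices (start, end) following the rounded-equidistant description.

    Hypothetical helper for illustration, not a PyTorch function.
    """
    # Stencil (kernel) size: input_size / target_size, rounded up.
    stencil = (input_size + target_size - 1) // target_size
    last_start = input_size - stencil
    if target_size == 1:
        starts = [0]
    else:
        # Equidistant points between 0 and last_start, rounded to integers.
        starts = [round(i * last_start / (target_size - 1))
                  for i in range(target_size)]
    return [(s, s + stencil) for s in starts]


slices = rounded_stencil_slices(14, 4)
values = list(range(14))
means = [sum(values[a:b]) / (b - a) for a, b in slices]
print(slices)  # [(0, 4), (3, 7), (7, 11), (10, 14)]
print(means)   # [1.5, 4.5, 8.5, 11.5]
```

The printed means agree with the values of b quoted above.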

Best regards

Thomas

39 Likes

Thank you very much. It’s a very detailed explanation that helps me a lot.

By the way, if I run into a similar situation in which I want to know the math details of some layer, where should I look? Before asking about AdaptiveAvgPool2d here, I searched a lot online and in the official documentation, but none of it helped.

1 Like

You can always check the code or do experiments, but I’d venture that asking on the forums is also a fairly reliable way to get good answers. That is because of two things:

  • You really get the best experts to answer here (well, unless it’s me writing the answer). My main reason to skim the forums these days is that I’ll learn a ton from the answers.
  • The forum also is part of the eyes and ears of the developers (there are quite a few core people here, and in general the connection is very close).
    So if it turns out there is a frequent question, people will see to amending the docs. If you happen to hit a bug, it’ll get picked up, filed, and fixed. If your use case would benefit from additional features and it’s common enough to warrant addition, it’ll also be noticed as a feature request.

Best regards

Thomas

14 Likes

Thank you very much!
It’s my honor to have the chance to get your help.

1 Like

Hi Thomas, thanks a lot for your help.
My confusion has been cleared.

Kind regards,
Haimin

1 Like

Rounded down, right?

1 Like

No, this expression rounds input_size / target_size up because of the “+target_size-1” in the numerator.
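
A quick way to convince yourself, as a small plain-Python sketch (ceil_div is just an illustrative name, and the check assumes positive integers):

```python
import math


def ceil_div(a, b):
    # Integer ceiling division for positive integers a and b:
    # adding b - 1 before the floor division pushes any remainder up by one.
    return (a + b - 1) // b


print(ceil_div(14, 4), 14 // 4)  # 4 3

# Compare against math.ceil on a small range of positive integers.
agrees = all(ceil_div(a, b) == math.ceil(a / b)
             for a in range(1, 200) for b in range(1, 50))
print(agrees)  # True
```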

Best regards

Thomas

2 Likes

I am still confused about how to compute 0, 3.3333, 6.6666, 10. Hoping for your reply, thanks!

I am also confused about it. Have you understood it?

To get 4 (= target_size) sections of the same length (4 = kernel_size) from an array of length 14 (= input_size), allowing some overlaps, we first set aside the rightmost section: indices 10, 11, 12, and 13. Then we divide the range from index 0 to 10 into 3 (= 4 - 1) equal steps, which gives the start points.

Here’s the numpy snippet for Thomas’s example.

In [1]: import numpy as np

In [2]: input_size = 14

In [3]: output_size = 4

In [4]: kernel_size = (input_size + output_size - 1) // output_size

In [5]: kernel_size
Out[5]: 4

In [6]: np.linspace(0, input_size - kernel_size, output_size)
Out[6]: array([ 0.        ,  3.33333333,  6.66666667, 10.        ])

5 Likes

To understand this better, I tried implementing my own version of it:

import numpy as np
import torch


def torch_pool(inputs, target_size):
    kernel_size = (inputs.shape[-1] + target_size - 1) // target_size
    points_float = torch.linspace(0, inputs.shape[-1] - kernel_size, target_size)
    points = torch.cat([torch.squeeze(torch.round(points_float)).int(), torch.tensor([inputs.shape[-1]], dtype=torch.int32)], 0)
    pooled = []
    for idx in range(points.shape[0] - 1):
        pooled.append(torch.mean(inputs[:, :, points[idx]:points[idx + 1]], dim=-1, keepdim=False))
    pooled = torch.cat(pooled, -1)
    return pooled


inps = np.array([0, 1, 2, 3, 4, 5, 6], dtype=np.float32)[None, :, None]
inps_torch = np.transpose(inps, (0, 2, 1))
x = torch_pool(torch.tensor(inps_torch), 4)
print(x)

This gives the following output:

tensor([[0.5000, 2.0000, 3.5000, 5.5000]])

However, the actual function gives:

tensor([[[0.5000, 2.0000, 4.0000, 5.5000]]])

Can anyone help me out? Where did I go wrong?

1 Like

First off, kudos for doing your own experiment and reporting the discrepancy here!

The explanation above is imprecise relative to (today’s?) implementation in PyTorch. PyTorch actually uses floor and ceiling to compute the start and end points instead of mathematical rounding, so the formulas should be

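import numpy as np
import torch


def torch_pool(inputs, target_size):
    # start of window i: floor(i * input_size / target_size)
    start_points = (torch.arange(target_size, dtype=torch.float32) * (inputs.size(-1) / target_size)).long()
    # end of window i (exclusive): ceil((i + 1) * input_size / target_size)
    end_points = ((torch.arange(target_size, dtype=torch.float32)+1) * (inputs.size(-1) / target_size)).ceil().long()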
    pooled = []
    for idx in range(target_size):
        pooled.append(torch.mean(inputs[:, :, start_points[idx]:end_points[idx]], dim=-1, keepdim=False))
    pooled = torch.cat(pooled, -1)
    return pooled


inps = np.array([0, 1, 2, 3, 4, 5, 6], dtype=np.float32)[None, :, None]
inps_torch = np.transpose(inps, (0, 2, 1))
x = torch_pool(torch.tensor(inps_torch), 4)
print(x)
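
If you want to look at the windows themselves, the same floor/ceiling rule can be sketched in plain Python with integer arithmetic only. The names adaptive_windows and avg_pool_1d are made up for this sketch, which mirrors the formulas above rather than PyTorch's actual source:

```python
def adaptive_windows(input_size, output_size):
    """(start, end) index pairs, end exclusive, using floor and ceiling."""
    windows = []
    for i in range(output_size):
        # floor(i * input_size / output_size)
        start = (i * input_size) // output_size
        # ceil((i + 1) * input_size / output_size), via the "+ n - 1" trick
        end = ((i + 1) * input_size + output_size - 1) // output_size
        windows.append((start, end))
    return windows


def avg_pool_1d(values, output_size):
    # Mean of each window of a plain Python list.
    return [sum(values[a:b]) / (b - a)
            for a, b in adaptive_windows(len(values), output_size)]


print(adaptive_windows(7, 4))         # [(0, 2), (1, 4), (3, 6), (5, 7)]
print(avg_pool_1d(list(range(7)), 4)) # [0.5, 2.0, 4.0, 5.5]
print(adaptive_windows(14, 4))        # [(0, 4), (3, 7), (7, 11), (10, 14)]
```

For input size 7 the windows have different lengths (2, 3, 3, 2), which is why a fixed stencil size cannot reproduce the 4.0 in the third position. For 14 -> 4 the windows happen to coincide with the slices from the earlier example.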

Best regards

Thomas

3 Likes

Thank you for the clarification!

Hello sir, and excuse me for the question.
Does using nn.AdaptiveAvgPool2d((1,1)) make sense?
In this case, is the filter size equal to the input size?
Why is this used?
What happens in this case?

Double post from here with follow-up.
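
For readers landing here: assuming the floor/ceiling window rule from earlier in this thread is applied to each spatial dimension separately, an output size of (1, 1) gives a single window per dimension that spans the whole input, so the result is the mean over all positions (often called global average pooling). A minimal plain-Python sketch on nested lists, with made-up helper names and a single channel:

```python
def window(i, input_size, output_size):
    # floor / ceiling window bounds, end exclusive
    start = (i * input_size) // output_size
    end = ((i + 1) * input_size + output_size - 1) // output_size
    return start, end


def adaptive_avg_pool2d(image, output_size):
    """Adaptive average pooling of a 2d nested list (illustrative only)."""
    out_h, out_w = output_size
    in_h, in_w = len(image), len(image[0])
    result = []
    for r in range(out_h):
        r0, r1 = window(r, in_h, out_h)
        row = []
        for c in range(out_w):
            c0, c1 = window(c, in_w, out_w)
            cells = [image[y][x] for y in range(r0, r1) for x in range(c0, c1)]
            row.append(sum(cells) / len(cells))
        result.append(row)
    return result


image = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0]]
print(window(0, 3, 1))                     # (0, 3): the whole axis
print(adaptive_avg_pool2d(image, (1, 1)))  # [[3.5]]
```

So in this case the effective filter covers the full input whatever its height and width are, and every input size is reduced to one value per channel.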

Sorry to disturb you. If the result of np.linspace is [0, 3, 4.5, 6], should the 4.5 become 4 or 5?

Anything that works for you. 🙂 I’m always happy when you take bits I post here to do your own experiments and improvements.

1 Like