Global Average Pooling in PyTorch


I am trying to use global average pooling, but I have no idea how to implement it in PyTorch. Global average pooling is described briefly as:

It means that if you have a 3D 8,8,128 tensor at the end of your last convolution, in the traditional method, you flatten it into a 1D vector of size 8x8x128. And you then add one or several fully connected layers and then at the end, a softmax layer that reduces the size to 10 classification categories and applies the softmax operator.

The global average pooling means that you have a 3D 8,8,10 tensor and compute the average over the 8,8 slices, you end up with a 3D tensor of shape 1,1,10 that you reshape into a 1D vector of shape 10. And then you add a softmax operator without any operation in between. The tensor before the average pooling is supposed to have as many channels as your model has classification categories.

So, to me it sounds like in one case you do

tensor = tensor.view(8*8*10, 10)
tensor = self.Linear(tensor)   # size 10
tensor = self.Softmax(tensor)

and, in the other, you do

tensor = self.Conv2d(output_size = 10, kernel_size=1)  #to get [10x8x8] size
tensor = self.GlobalAvgPooling(tensor)  # whatever this is, to get [10, 1, 1]
tensor = self.Squeeze_Dims(tensor)  # to just get a vector [10]
tensor = self.Softmax(tensor)

Here are the questions:

  1. Are the above examples correct, keeping in mind the description of global average pooling?
  2. How can I do the global average pooling? Should I use the functional module?
  3. The paper I am trying to reproduce (residual nets) says that:

The network ends with a global average pooling, a 10-way fully-connected layer, and softmax.

But this does not make sense to me. Why do they need the 10-way fc layer?

(Allen Ye) #2

Is nn.AvgPool2d() what you’re looking for?


That is the pooling layer, yes, but it seems to be more complicated than just using pooling. Or am I wrong?

(Anuvabh) #4

Global average pooling means that you average each feature map separately. In your case if the feature map is of dimension 8 x 8, you average each and obtain a single value. The important part here is that you do the average operation per-channel. You can think of each of the feature maps as the final feature representation per category over which you want to do classification.

To do this you can apply either nn.AvgPool2d or F.avg_pool2d with kernel_size equal to the dimensions of the feature maps (in this case, 8).
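As a minimal sketch of the suggestion above, here is F.avg_pool2d applied with a window that covers the whole feature map. The batch size and the 10-channel, 8 x 8 input are assumptions taken from the question, not anything PyTorch requires.

```python
import torch
import torch.nn.functional as F

# Hypothetical batch: 4 samples, 10 channels, 8x8 feature maps (N x C x H x W).
x = torch.randn(4, 10, 8, 8)

# A pooling window as large as the feature map averages each channel to one value.
pooled = F.avg_pool2d(x, kernel_size=8)   # shape: (4, 10, 1, 1)

# Drop the two singleton spatial dimensions to get one value per channel.
gap = pooled.flatten(1)                   # shape: (4, 10)
```

Each entry gap[n, c] is the mean of the 64 values in x[n, c], so channels are never mixed.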

The 10-way fc is because there are 10 categories. It’s like you extract features from all the preceding conv layers and feed them into a linear classifier.
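To make the role of the 10-way fc layer concrete, here is a rough sketch of such a head. The class name GapHead, the 64 input channels, and the 8 x 8 map size are hypothetical choices for illustration; the point is only that pooling leaves one feature per channel and the linear layer maps those features to 10 class scores.

```python
import torch
import torch.nn as nn

class GapHead(nn.Module):
    """Hypothetical head: global average pooling, then a 10-way linear layer."""

    def __init__(self, in_channels=64, map_size=8, num_classes=10):
        super().__init__()
        # A window as large as the feature map turns each channel into one value.
        self.pool = nn.AvgPool2d(map_size)
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, x):
        x = self.pool(x)        # N x C x 1 x 1
        x = x.flatten(1)        # N x C
        return self.fc(x)       # N x num_classes (logits)

head = GapHead()
logits = head(torch.randn(4, 64, 8, 8))
probs = torch.softmax(logits, dim=1)   # softmax over the 10 categories
```

With this layout the number of channels before pooling does not have to equal the number of classes, which is why the fc layer is still needed.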


Why not use torch.mean to achieve this?

(Ywu36) #6

torch.mean works on one dimension instead of all three dimensions.
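This may depend on the PyTorch version: recent releases accept a tuple of dimensions for dim, so the spatial axes can be reduced in one call. A small sketch, assuming an N x C x H x W input with made-up sizes:

```python
import torch

# Hypothetical N x C x H x W input.
x = torch.randn(4, 10, 8, 8)

# Average over H and W only; the batch and channel axes are kept.
gap = torch.mean(x, dim=(2, 3))                      # shape: (4, 10)

# keepdim=True preserves the singleton spatial axes if a later layer expects 4D input.
gap_4d = torch.mean(x, dim=(2, 3), keepdim=True)     # shape: (4, 10, 1, 1)
```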

(Ywu36) #7

Try this:

F.avg_pool3d(input, kernel_size=input.size()[2:]).view(input.size()[0], -1)  # input: N x C x D x H x W

(Indrayana Rustandi) #8

Just a note, the SqueezeNet architecture (available in PyTorch model zoo) uses global average pooling. Here’s global average pooling as implemented there:

final_conv = nn.Conv2d(512, self.num_classes, kernel_size=1)
self.classifier = nn.Sequential(
    nn.Dropout(p=0.5), final_conv, nn.ReLU(inplace=True), nn.AvgPool2d(13)
)

512 is the number of channels in the feature maps feeding into this layer, and 13 is the number of rows and columns in the feature maps going into this layer. You’ll need to change these depending on your network structure.
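For a quick shape check of this pattern, the following is a stripped-down sketch, not the model zoo code itself. It keeps only the 1 x 1 convolution and the pooling layer, and the value of num_classes and the batch size are assumptions.

```python
import torch
import torch.nn as nn

num_classes = 10  # hypothetical; use your own number of categories

classifier = nn.Sequential(
    nn.Conv2d(512, num_classes, kernel_size=1),  # one 13x13 feature map per class
    nn.AvgPool2d(13),                            # window covers the whole map
)

x = torch.randn(2, 512, 13, 13)      # N x 512 x 13 x 13 features from earlier layers
scores = classifier(x).flatten(1)    # N x num_classes, ready for softmax
```

This is the variant from the original question in which the channel count before pooling equals the number of classes, so no fc layer follows the pooling.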

(Wangchust) #9

x = F.avg_pool2d(x, x.size()[2:]) works fine when x has shape N * C * H * W (with F being torch.nn.functional).

(jdhao) #10

Another way to do global average pooling for each feature map is to use torch.mean as suggested by @Soumith_Chintala, but we need to flatten each feature map into a vector first. The following snippet illustrates the idea:

# suppose x is your feature map with size N*C*H*W
x = torch.mean(x.view(x.size(0), x.size(1), -1), dim=2)
# now x is of size N*C

You can also use adaptive_avg_pool2d to achieve global average pooling; just set the output size to (1, 1):

import torch.nn.functional as F
x = F.adaptive_avg_pool2d(x, (1, 1))
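As a sanity check, the two approaches in this post can be compared directly. The helper name gap_both and the input sizes are made up for illustration; the sketch also shows that the adaptive version needs no knowledge of H and W.

```python
import torch
import torch.nn.functional as F

def gap_both(x):
    """Return global average pooling of x (N x C x H x W) computed two ways."""
    n, c = x.size(0), x.size(1)
    via_mean = torch.mean(x.view(n, c, -1), dim=2)            # N x C
    via_adaptive = F.adaptive_avg_pool2d(x, (1, 1)).view(n, c)  # N x C
    return via_mean, via_adaptive

# Hypothetical feature-map sizes, including a non-square one.
for h, w in [(8, 8), (13, 13), (7, 5)]:
    a, b = gap_both(torch.randn(2, 16, h, w))
    print((h, w), torch.allclose(a, b, atol=1e-5))
```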

(Lin Eric) #11

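Use nn.AdaptiveAvgPool2d with output size 1 (nn.AdaptiveMaxPool2d is the max-pooling counterpart and would give global max pooling instead).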

(Ihor Menshykov) #12

Did anyone make any benchmarks? I’m guessing mean was probably the fastest?
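One way to find out on your own hardware is a small timing harness like the sketch below. It makes no claim about which variant wins; the tensor size and repeat count are arbitrary assumptions, and it only times CPU execution (GPU timing would additionally need torch.cuda.synchronize() around each call).

```python
import timeit
import torch
import torch.nn.functional as F

# Hypothetical input: batch of 32, 128 channels, 8x8 feature maps.
x = torch.randn(32, 128, 8, 8)
repeats = 200

candidates = {
    "mean over flattened map": lambda: x.view(x.size(0), x.size(1), -1).mean(dim=2),
    "avg_pool2d": lambda: F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1),
    "adaptive_avg_pool2d": lambda: F.adaptive_avg_pool2d(x, (1, 1)).view(x.size(0), -1),
}

# Total wall-clock seconds for `repeats` calls of each candidate.
timings = {name: timeit.timeit(fn, number=repeats) for name, fn in candidates.items()}

for name, seconds in timings.items():
    print(f"{name}: {seconds / repeats * 1e6:.1f} us per call")
```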