I am wondering how I am supposed to use GAP (global average pooling) in a CNN/ViT model.
As far as I understand how GAP works, it aggregates each channel's feature map by averaging it, which lets a network avoid an FC layer at the end. When the last convolutional layer has one channel per class, each averaged channel corresponds to one class category. For example, I have a tensor of shape (16, 256, 64, 64), i.e. (Batch, Channel, Height, Width), right before the GAP. The GAP layer averages each of the 256 64-by-64 maps, producing a (16, 256, 1, 1) tensor.
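As a minimal sketch (random data stands in for real feature maps), GAP in PyTorch is just a mean over the two spatial dimensions, which is what `nn.AdaptiveAvgPool2d(1)` computes. The second half illustrates the "GAP instead of FC" idea from the Network in Network design: a 1x1 convolution first maps the features to one channel per class, so the averaged channels are directly the class scores. The 10 classes are a hypothetical choice.

```python
import torch
import torch.nn as nn

# Random stand-in for real feature maps, laid out as (N, C, H, W)
x = torch.randn(16, 256, 64, 64)

# GAP: average each 64x64 map down to one value per channel
gap = nn.AdaptiveAvgPool2d(1)
pooled = gap(x)
print(pooled.shape)  # torch.Size([16, 256, 1, 1])

# The same operation written as an explicit mean over H and W
manual = x.mean(dim=(2, 3), keepdim=True)
print(torch.allclose(pooled, manual, atol=1e-6))  # True

# "GAP instead of FC" (Network-in-Network style head):
# a 1x1 conv gives one channel per class, then GAP turns each channel into a score
num_classes = 10  # hypothetical number of classes
nin_head = nn.Sequential(
    nn.Conv2d(256, num_classes, kernel_size=1),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),  # (N, num_classes, 1, 1) -> (N, num_classes)
)
logits = nin_head(x)
print(logits.shape)  # torch.Size([16, 10])
```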
So my question is:
For ImageNet-pretrained models (ResNet, ViT, etc.), what is the correct way to add a GAP layer? All of them end with a final FC layer that outputs 1000 values (one per ImageNet class), so there are no channels or feature maps left to aggregate.
I will answer my own question. I hope someone will find it useful.
In fact, I was right about the last FC layers of ImageNet-pretrained models: they have only 1000 output nodes, so their output is a flat vector with no channels or spatial maps. That makes it impossible to apply GAP, channel-wise BN, or anything else that operates on multiple channels at that point.
Assuming, for this example, that we use the PyTorch-based timm library, this is what I should have done:
```python
import timm
model = timm.create_model('eca_nfnet_l1', pretrained=True)
```
The code above initializes a pre-trained “NFNet” model. This model consists of several main sections: stem, stages, final_conv and head.
The first three (stem, stages, final_conv) form the “feature extractor” pipeline, and the last one (head) is the classifier: a global pooling layer followed by the final FC layer. This split lets you reuse the information flow for your own purpose; in my case, that was adding GAP and BN layers to a custom model.
Thus, to customize your model with additional layers, you should work with the “feature extractors”, which output the deep feature maps. These maps can then be fed into your own custom (not yet trained) layers.
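```python
# imgTensor: a preprocessed image batch of shape (N, 3, H, W)
# forward_features() chains stem -> stages -> final_conv (+ its activation), skipping the head
a = model.forward_features(imgTensor)
print(a.shape)
# prints: "torch.Size([1, 3072, 13, 13])"
```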
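This last tensor, of shape [1, 3072, 13, 13], is the one to feed into your custom layers for further processing; applying GAP to it gives a [1, 3072, 1, 1] tensor. Note that the 13x13 spatial size depends on the resolution of the input image.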