Channel Max Pooling

I am trying to replicate a technique from a paper that adds a channel max pooling (CMP) layer between the last max-pooling layer and the first FC layer of the VGG16 model. The paper can be found at

You may not be able to read it, but the important lines are:

‘VGG16 with CMP (VGG16-CMP): Similar as DenseNet161-CMP, we applied the CMP operation to the VGG16 by implementing the CMP layer between the last max-pooling layer and the first FC layer. The dimension of the pooled features was changed from 512 × 7 × 7 to c × 7 × 7. All the other components remained unchanged’


CMP is the channel max pooling layer.
c is the number of channels of the pooled feature maps.
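The quote fixes only the shapes (512 × 7 × 7 in, c × 7 × 7 out), not the kernel or stride. The simplest reading, which is an assumption and not something the quote states, is that each of the c output channels is the maximum over a non-overlapping group of 512 / c consecutive input channels. A minimal PyTorch sketch of that reading; `channel_max_pool_groups` is a hypothetical name:

```python
import torch


def channel_max_pool_groups(x: torch.Tensor, out_channels: int) -> torch.Tensor:
    """Hypothetical helper: max over non-overlapping channel groups.

    (n, C, h, w) -> (n, out_channels, h, w); assumes C is divisible by out_channels.
    """
    n, c, h, w = x.shape
    if c % out_channels != 0:
        raise ValueError("channel count must be divisible by out_channels")
    group = c // out_channels
    # (n, out_channels, group, h, w): output channel g sees input channels g*group .. g*group+group-1
    return x.reshape(n, out_channels, group, h, w).amax(dim=2)


features = torch.rand(2, 512, 7, 7)  # stand-in for the output of VGG16's last max-pooling layer
pooled = channel_max_pool_groups(features, out_channels=256)
print(pooled.shape)  # torch.Size([2, 256, 7, 7])
```

Spatial positions are never mixed, so the first FC layer then needs c * 7 * 7 input features instead of 512 * 7 * 7 = 25088.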

1) From the statement above, is the pooling layer a newly added layer or the max-pooling layer that is already present?
2) Is the code below even right?

I have created a custom max pooling layer and added it to the model, but I have no idea whether it is what the paper is talking about.

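import torch
import torch.nn as nn
import torchvision


class ChannelPoolingNetwork(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.base = torchvision.models.vgg16(pretrained=True, progress=True)
        # Keep the VGG16 conv stack and append CMP after its last max-pooling layer.
        # kernel_size=2, stride=2 halves the channels: 512 x 7 x 7 -> 256 x 7 x 7
        self.base.features = nn.Sequential(
            *self.base.features,
            ChannelPool(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False),
        )
        # The first FC layer must accept the pooled size (256 * 7 * 7 instead of 25088),
        # the last FC layer outputs num_classes; the ReLU/Dropout layers stay unchanged
        self.base.classifier[0] = nn.Linear(256 * 7 * 7, 4096)
        self.base.classifier[6] = nn.Linear(4096, num_classes)

    def forward(self, x):
        fc = self.base(x)
        return fc

# Code from
class ChannelPool(nn.MaxPool1d):
    def forward(self, input):
        n, c, w, h = input.size()
        # (n, c, w, h) -> (n, w*h, c) so that MaxPool1d slides over the channels
        input = input.reshape(n, c, w * h).permute(0, 2, 1)
        # Run the parent MaxPool1d with the kernel_size/stride/padding given to the constructor
        pooled = super().forward(input)
        _, _, c = pooled.size()
        # (n, w*h, c_out) -> (n, c_out, w, h)
        pooled = pooled.permute(0, 2, 1)
        return pooled.reshape(n, c, w, h)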

I’m not sure which paper you are linking to, but could you explain how channel max pooling is supposed to work?
Would it be a top-k operation on the channel dimension?
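For comparison, a top-k over the channel dimension (the alternative raised in the question, not necessarily what the paper does) keeps the k largest channel values at each spatial position. A minimal sketch with `torch.topk`, using k = 256 purely as an example:

```python
import torch

x = torch.rand(2, 512, 7, 7)
# topk along dim=1 returns the k largest values per (n, h, w) position, sorted in
# descending order, so output channel i no longer corresponds to a fixed input channel
topk = torch.topk(x, k=256, dim=1).values
print(topk.shape)  # torch.Size([2, 256, 7, 7])
```

Unlike windowed max pooling, the channel order is lost here: output channel 0 is always the per-position maximum over all 512 channels.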

CMP does max pooling across the channel dimension of the feature maps. The image below is a visual representation given in the paper.
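Written out, and assuming CMP follows the usual 1D max-pooling definition applied along the channel axis (kernel size K, stride S, padding P, dilation 1; an assumption, since the quoted text gives only the shapes), output channel k at spatial position (i, j) is the maximum over one window of channels, where the tilde marks the input padded with P entries of minus infinity on each side of the channel axis. The output channel count follows the standard pooling formula; for example, with the kernel 7 / stride 2 / padding 3 setting used in this thread, 512 channels become 256:

```latex
y_{n,k,i,j} = \max_{0 \le m < K} \tilde{x}_{n,\,kS + m,\,i,\,j},
\qquad k = 0, \dots, C_{\text{out}} - 1

C_{\text{out}} = \left\lfloor \frac{C + 2P - K}{S} \right\rfloor + 1
\quad\Rightarrow\quad
\left\lfloor \frac{512 + 2 \cdot 3 - 7}{2} \right\rfloor + 1 = \lfloor 255.5 \rfloor + 1 = 256
```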

Turns out I made some mistakes in the code above, but I think I have it correct now. My main problem is that it takes a long time to run. Is there any way to speed it up?
The input tensor has 4 dimensions: batch index, channel, width and height. I have to go through each image (input[x]) and do max pooling across the channels with a kernel size of 7 and a stride of 2. The input is [32, 512, 7, 7], and I have hard-coded these hyperparameters to work on this data.

I don’t know much about tensors, but I am sure there is a better way than one monolithic for loop.

class ChannelPool(nn.Module):

    def __init__(self, kernel_size=7, stride=2, padding=3, dilation=1,
                 return_indices=False, ceil_mode=False):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding
        self.dilation = dilation
        self.return_indices = return_indices
        self.ceil_mode = ceil_mode
        self.compression = 2
        self.output = None

    def forward(self, input):

        n, c, w, h = input.size()
        # Pad the channel dimension so the kernel fits at both ends
        # (zero padding is safe here because VGG16 features are non-negative after ReLU)
        input = torch.nn.functional.pad(input, (0, 0, 0, 0, self.padding, self.padding), "constant", 0)
        # Get output: one max per window of kernel_size channels
        output = torch.stack([
            torch.stack([torch.max(input[x][index:index + self.kernel_size], dim=0)[0]  # max over the channels in the window
                         for index in range(0, input.size(1) - self.kernel_size + 1, self.stride)])  # move by stride
            for x in range(n)])  # do the work for each image in the batch

        return output
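On the speed question: the nested Python loops can be replaced by tensor operations. A minimal sketch, assuming the same hyperparameters (kernel 7, stride 2, padding 3); `channel_max_pool` is a hypothetical name. It pads with minus infinity, as the built-in `nn.MaxPool` layers do, which gives the same result as zero padding for non-negative post-ReLU features. `Tensor.unfold` exposes every channel window as a view, so the max is a single vectorized call that runs on whatever device the input is on:

```python
import torch
import torch.nn.functional as F


def channel_max_pool(x: torch.Tensor, kernel_size: int = 7, stride: int = 2,
                     padding: int = 3) -> torch.Tensor:
    """Windowed max over the channel dimension of an (n, c, h, w) tensor, without Python loops."""
    # Pad dim 1 (channels) on both sides with -inf so padded entries never win the max
    x = F.pad(x, (0, 0, 0, 0, padding, padding), mode="constant", value=float("-inf"))
    # unfold(1, k, s) -> (n, c_out, h, w, k): every stride-th window of k channels
    windows = x.unfold(1, kernel_size, stride)
    # Reduce each window to its maximum -> (n, c_out, h, w)
    return windows.amax(dim=-1)


x = torch.rand(32, 512, 7, 7)
print(channel_max_pool(x).shape)  # torch.Size([32, 256, 7, 7])
```

The same result can be obtained by permuting channels to the last axis and calling `F.max_pool1d`, which is what the `nn.MaxPool1d`-based ChannelPool earlier in the thread does.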

I haven’t looked at your code, but it sounds like you could just use a 3D maxpool to do this. Just size the kernel so it covers all channels.


Using 3D MaxPool for 2D channel-wise pooling

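import torch

ff = torch.rand(32, 256, 256, 256)
# Kernel (8, 1, 1) with stride (8, 1, 1): max over 8 consecutive entries of dim 1, never across height/width
b = torch.nn.MaxPool3d((8, 1, 1), stride=(8, 1, 1))
gg = b(ff)
print(gg.size())  # torch.Size([32, 32, 256, 256])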
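The same trick applied to the question's [32, 512, 7, 7] input, assuming kernel 7, stride 2 and padding 3 along the channel axis. This sketch calls `unsqueeze(1)` so the input is an explicit 5D (N, C, D, H, W) batch with the 512 channels as the depth axis; in recent PyTorch versions the 4D call above also works, because `MaxPool3d` treats 4D input as one unbatched (C, D, H, W) volume and pools each leading slice independently:

```python
import torch
import torch.nn as nn

x = torch.rand(32, 512, 7, 7)  # (batch, channels, h, w), e.g. VGG16's last max-pool output
# The channel axis becomes the "depth" axis; the (7, 1, 1) kernel never mixes spatial positions
cmp = nn.MaxPool3d(kernel_size=(7, 1, 1), stride=(2, 1, 1), padding=(3, 0, 0))
# unsqueeze(1) adds a singleton channel dim: (32, 1, 512, 7, 7); squeeze(1) removes it again
y = cmp(x.unsqueeze(1)).squeeze(1)
print(y.shape)  # torch.Size([32, 256, 7, 7])
```

Since both this and the permute + `MaxPool1d` formulation take the max over the same channel windows with implicit minus-infinity padding, they should give identical outputs, and this version runs as a single native pooling call on CPU or GPU.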