Channel Max Pooling

It turns out I made some mistakes in the code above, but I think I have it correct now. My main problem is that it takes a long time to run. Is there any way to speed it up?
The input array has 4 dimensions: batch index, channel, width and height. I have to go through each image (input[x]) and do max pooling across the channels with a kernel size of 7 and a stride of 2. The input is [32, 512, 7, 7], and I have hard-coded these hyperparameters to work on that data.
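
For reference, the number of output channels should follow the usual pooling formula. The small helper below is only a sketch: the name pooled_channels is made up for illustration, and the padding of 3 is taken from the constructor defaults in the class further down, not from any library.

```python
def pooled_channels(channels, kernel_size=7, stride=2, padding=3):
    """Number of windows that fit along the padded channel dimension."""
    # Standard pooling size formula (floor mode, no dilation).
    return (channels + 2 * padding - kernel_size) // stride + 1


print(pooled_channels(512))  # 256
```

So with these settings a [32, 512, 7, 7] input should come out as [32, 256, 7, 7].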

I don’t know much about tensors but I am sure there is a better way than one monolithic for loop.

import torch
import torch.nn as nn

class ChannelPool(nn.Module):

    def __init__(self, kernel_size=7, stride=2, padding=3, dilation=1,
                 return_indices=False, ceil_mode=False):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding
        self.dilation = dilation
        self.return_indices = return_indices
        self.ceil_mode = ceil_mode
        self.compression = 2
        self.output = None


    def forward(self, input):

        n, c, w, h = input.size()
        # Zero-pad the channel dimension (dim 1) so the kernel fits at both ends.
        # The pad tuple runs from the last dimension backwards: (h, h, w, w, c, c).
        input = torch.nn.functional.pad(input, (0, 0, 0, 0, self.padding, self.padding), "constant", 0)

        # Get output
        output = torch.stack([
                        torch.stack(
                            [torch.max(input[x][index:index+self.kernel_size], dim=0)[0]  # Max over a full window of kernel_size channels
                            for index in range(0, input.size()[1]-self.kernel_size+1, self.stride)])  # Move by stride
                            for x in range(n)])  # Do work for each image in batch

        # The result is already on the same device as the input, so no .cuda() call is needed
        return output
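
One possible way to avoid the Python loops, as a rough sketch rather than a definitive answer: let PyTorch build all the channel windows at once. The function names channel_max_pool and channel_max_pool_3d are made up for illustration. The first uses Tensor.unfold after the same zero padding as above. The second reuses F.max_pool3d by treating the channels as a depth axis; note that it pads with negative infinity instead of zero, so the two only agree when no window maximum is negative.

```python
import torch
import torch.nn.functional as F


def channel_max_pool(x, kernel_size=7, stride=2, padding=3):
    """Max pool over the channel dimension of an (N, C, H, W) tensor."""
    # Zero-pad only the channel dimension, as in the loop version.
    x = F.pad(x, (0, 0, 0, 0, padding, padding), "constant", 0)
    # unfold appends a dimension holding each window of channels:
    # (N, C_padded, H, W) -> (N, C_out, H, W, kernel_size)
    windows = x.unfold(1, kernel_size, stride)
    # Reduce each window to its maximum.
    return windows.max(dim=-1).values


def channel_max_pool_3d(x, kernel_size=7, stride=2, padding=3):
    """Same idea via F.max_pool3d, which pads with -inf rather than 0."""
    # Treat the channels as the depth axis of a single-channel 3D volume:
    # (N, C, H, W) -> (N, 1, C, H, W)
    y = F.max_pool3d(x.unsqueeze(1),
                     kernel_size=(kernel_size, 1, 1),
                     stride=(stride, 1, 1),
                     padding=(padding, 0, 0))
    # Drop the dummy channel again: (N, 1, C_out, H, W) -> (N, C_out, H, W)
    return y.squeeze(1)


# Small non-negative stand-in for the real [32, 512, 7, 7] input.
x = torch.rand(4, 32, 7, 7)
out = channel_max_pool(x)
print(out.shape)
```

Both functions keep the result on whatever device the input is on, so the same code should work for CPU and GPU tensors.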