Turns out I made some mistakes in the code above, but I think I have it correct now. My main problem is that it takes a long time to run. Is there any way to speed it up?
The input tensor has 4 dimensions: batch index, channel, height and width. For each image (input[x]) I do max pooling across the channels with a kernel size of 7 and a stride of 2. The input is [32, 512, 7, 7], and I have hard-coded these hyperparameters to work on that data.
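As a quick sanity check on the expected output shape, the sketch below applies the output-length formula given in the PyTorch pooling docs (floor mode) to the channel axis. `pooled_length` is a hypothetical helper, and the padding of 3 is taken from the `ChannelPool` defaults in the code that follows.

```python
def pooled_length(length, kernel_size=7, stride=2, padding=3, dilation=1):
    """Number of windows a max pool produces along one dimension (floor mode)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


# 512 channels, kernel 7, stride 2, padding 3 -> 256 pooled channels,
# so a [32, 512, 7, 7] input should give a [32, 256, 7, 7] output.
print(pooled_length(512))  # 256
```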
I don’t know much about tensors, but I am sure there is a better way than one monolithic for loop.
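One tensor primitive that can replace an explicit sliding-window loop is `Tensor.unfold(dimension, size, step)`, which returns every window along a dimension as a new trailing dimension. A tiny illustration on a 1-D tensor (the values are arbitrary, chosen only to make the windows easy to read):

```python
import torch

x = torch.arange(8)            # tensor([0, 1, 2, 3, 4, 5, 6, 7])
windows = x.unfold(0, 3, 2)    # windows of size 3, step 2 along dim 0
print(windows)                 # tensor([[0, 1, 2], [2, 3, 4], [4, 5, 6]])
print(windows.amax(dim=-1))    # max of each window: tensor([2, 4, 6])
```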
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChannelPool(nn.Module):
    def __init__(self, kernel_size=7, stride=2, padding=3, dilation=1,
                 return_indices=False, ceil_mode=False):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride or kernel_size
        self.padding = padding
        # Stored for API symmetry with nn.MaxPool*, but not used by forward()
        self.dilation = dilation
        self.return_indices = return_indices
        self.ceil_mode = ceil_mode
        self.compression = 2
        self.output = None

    def forward(self, input):
        n, c, h, w = input.size()
        # Zero-pad the channel dimension (dim 1) on both sides so the kernel fits
        input = F.pad(input, (0, 0, 0, 0, self.padding, self.padding), "constant", 0)
        # For each image, slide a window of kernel_size channels with the given
        # stride and take the element-wise max over the channels in the window
        output = torch.stack([
            torch.stack([
                torch.max(input[x][index:index + self.kernel_size], dim=0)[0]
                for index in range(0, input.size(1) - self.kernel_size + 1, self.stride)
            ])
            for x in range(n)
        ])
        # output is already on the same device as input, so no .cuda() is needed
        return output
```
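A loop-free sketch, assuming the same window semantics as the loop above (zero padding of 3 on the channel axis, window of 7 channels, stride 2). `FastChannelPool` and `channel_pool_maxpool3d` are hypothetical names; `Tensor.unfold`, `Tensor.amax` (PyTorch 1.7+) and `F.max_pool3d` are standard PyTorch calls. The `max_pool3d` variant treats the channels as the depth axis of a 3-D pool, but it pads with -inf rather than 0, so it only matches the zero-padded version when the inputs are non-negative (e.g. after a ReLU).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FastChannelPool(nn.Module):
    """Max pooling across the channels of an (N, C, H, W) tensor, no Python loops."""

    def __init__(self, kernel_size=7, stride=2, padding=3):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        # Zero-pad dim 1 (channels), exactly like the loop version
        x = F.pad(x, (0, 0, 0, 0, self.padding, self.padding), "constant", 0)
        # unfold(1, k, s) -> (N, C_out, H, W, k): every channel window at once
        windows = x.unfold(1, self.kernel_size, self.stride)
        # Max over the window dimension -> (N, C_out, H, W)
        return windows.amax(dim=-1)


def channel_pool_maxpool3d(x, kernel_size=7, stride=2, padding=3):
    # Channels become the depth axis: (N, 1, C, H, W).
    # max_pool3d pads with -inf, not 0 (see note above).
    out = F.max_pool3d(x.unsqueeze(1), kernel_size=(kernel_size, 1, 1),
                       stride=(stride, 1, 1), padding=(padding, 0, 0))
    return out.squeeze(1)


x = torch.randn(32, 512, 7, 7)
y = FastChannelPool()(x)
print(y.shape)  # torch.Size([32, 256, 7, 7])
```

Both versions run as a few batched tensor operations instead of 32 × 256 separate `torch.max` calls plus two levels of `torch.stack`, which is likely where most of the loop's time goes. Neither moves data between devices, so a CUDA input produces a CUDA output.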