I am interested in implementing max pooling in PyTorch without the `nn.MaxPool` functions, in an efficient way (i.e. one that can run on the GPU), for the sake of learning. My input is a standard batched tensor of size `(N, C, X, X)`; for simplicity, I will assume that my stride is equal to the size of the kernel, which divides `X` evenly.

I am quite stuck; I imagine the `unfold` function is a good place to start, but I'm not really sure how to process its output efficiently:

```
import torch.nn as nn

N, C, X, _ = x.shape
u = nn.Unfold(kernel_size, stride=kernel_size)
windows = u(x)  # shape (N, C * kernel_size**2, L)
# How do I process windows?
```

If anyone could shed some light on this it would be great!
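For reference, here is a sketch of what I imagine the processing might look like, assuming `nn.Unfold` lays out its middle dimension channel-major (all `kernel_size**2` positions of channel 0 first, then channel 1, and so on) and that the flattened window dimension is in row-major spatial order; the concrete sizes below are just placeholders:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

N, C, X, kernel_size = 2, 3, 8, 2  # placeholder sizes; kernel_size divides X
x = torch.randn(N, C, X, X)

u = nn.Unfold(kernel_size, stride=kernel_size)
windows = u(x)  # (N, C * kernel_size**2, L) with L = (X // kernel_size)**2

# Separate the channel dimension from the per-window kernel positions,
# then take the max over each window.
windows = windows.view(N, C, kernel_size * kernel_size, -1)
pooled = windows.max(dim=2).values  # (N, C, L)

# Reshape the flattened windows back into a spatial grid.
out = pooled.view(N, C, X // kernel_size, X // kernel_size)

# Sanity check against the built-in implementation.
assert torch.allclose(out, F.max_pool2d(x, kernel_size))
```

I have not verified whether this is the idiomatic or most efficient approach, but every step is a reshape or a reduction, so it should at least stay on the GPU.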