Vectorized Max Pooling without nn module

I am interested in implementing max pooling in PyTorch without the nn.MaxPool functions, in an efficient way (i.e. one that can run on the GPU), for the sake of learning. My input is a standard batched tensor of shape (N, C, X, X). For simplicity, I will assume that the stride is equal to the kernel size, and that the kernel size divides X.
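
For concreteness, the behaviour I want to reproduce is that of the built-in layer (the shapes and kernel size below are just an illustration):

import torch
import torch.nn as nn

kernel_size = 2
x = torch.randn(4, 3, 8, 8)               # (N, C, X, X), with X divisible by kernel_size
reference = nn.MaxPool2d(kernel_size)(x)  # stride defaults to kernel_size
print(reference.shape)                    # torch.Size([4, 3, 4, 4])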

I am quite stuck; I imagine the unfold function is a good place to start, but I’m not really sure how to process its output efficiently:

import torch
import torch.nn as nn

N, C, X, _ = x.shape
u = nn.Unfold(kernel_size, stride=kernel_size)
# windows has shape (N, C * kernel_size * kernel_size, L), where L = (X // kernel_size) ** 2
windows = u(x)

# How do I process windows?

If anyone could shed some light on this it would be great!

@fmassa has created an example of max pooling using unfold here.
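
Building on that idea, here is a minimal sketch of one way the unfolded windows could be processed: reshape to separate the channel dimension from the positions inside each window, take the max over each window, and reshape the flattened spatial dimension back into a grid. The helper name max_pool2d_unfold and the test shapes are my own choices for illustration, not taken from the linked example.

import torch
import torch.nn as nn

def max_pool2d_unfold(x, kernel_size):
    # x: (N, C, X, X); assumes stride == kernel_size and kernel_size divides X
    N, C, X, _ = x.shape
    out_size = X // kernel_size
    u = nn.Unfold(kernel_size, stride=kernel_size)
    windows = u(x)  # (N, C * kernel_size * kernel_size, L), L = out_size ** 2
    # separate channels from the kernel_size * kernel_size positions inside each window
    windows = windows.view(N, C, kernel_size * kernel_size, -1)
    # max over each window, then restore the (out_size, out_size) spatial grid
    pooled = windows.max(dim=2).values
    return pooled.view(N, C, out_size, out_size)

# quick check against the built-in layer
x = torch.randn(2, 3, 8, 8)
assert torch.equal(max_pool2d_unfold(x, 2), nn.MaxPool2d(2)(x))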