I am interested in implementing max pooling in PyTorch without the `nn.MaxPool` functions, in an efficient way (i.e. one that can run on the GPU), for the sake of learning. My input is a standard batched tensor of size `(N, C, X, X)`; for simplicity, I will assume that my stride is equal to the size of the kernel, which divides `X` evenly.

I am quite stuck; I imagine the `unfold` function is a good place to start, but I'm not really sure how to process its output efficiently:

```
import torch.nn as nn

N, C, X, _ = x.shape
u = nn.Unfold(kernel_size, stride=kernel_size)
windows = u(x)  # shape (N, C * kernel_size**2, L)
# How do I process windows?
```

If anyone could shed some light on this it would be great!
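For reference, here is a sketch of what I imagine the processing might look like, assuming `nn.Unfold` lays out its middle dimension channel-major (all `kernel_size**2` positions of channel 0 first, then channel 1, and so on) and that the flattened window dimension is in row-major spatial order; the concrete sizes below are just placeholders:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

N, C, X, kernel_size = 2, 3, 8, 2  # placeholder sizes; kernel_size divides X
x = torch.randn(N, C, X, X)

u = nn.Unfold(kernel_size, stride=kernel_size)
windows = u(x)  # (N, C * kernel_size**2, L) with L = (X // kernel_size)**2

# Separate the channel dimension from the per-window kernel positions,
# then take the max over each window.
windows = windows.view(N, C, kernel_size * kernel_size, -1)
pooled = windows.max(dim=2).values  # (N, C, L)

# Reshape the flattened windows back into a spatial grid.
out = pooled.view(N, C, X // kernel_size, X // kernel_size)

# Sanity check against the built-in implementation.
assert torch.allclose(out, F.max_pool2d(x, kernel_size))
```

I have not verified whether this is the idiomatic or most efficient approach, but every step is a reshape or a reduction, so it should at least stay on the GPU.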