Compute maxima and minima of a 4D tensor

Suppose that we have a 4-dimensional tensor, for instance

import torch
X = torch.randn(1, 3, 4, 4)  # standard normal samples; torch.rand would only give values in [0, 1)
print(X)
tensor([[[[-0.9951,  1.6668,  1.3140,  1.4274],
          [ 0.2614,  2.6442, -0.3041,  0.7337],
          [-1.2690,  0.0125, -0.3885,  0.0535],
          [ 1.5270, -0.1186, -0.4458,  0.1389]],

         [[ 0.9125, -1.2998, -0.4277, -0.2688],
          [-1.6917, -0.8855, -0.2784, -0.6717],
          [ 1.1417,  0.4574,  0.4803, -1.6637],
          [ 0.7322,  0.2654, -0.1525,  1.7285]],

         [[ 1.8310, -1.5765,  0.1392,  1.3431],
          [-0.6641, -1.5090, -0.4893, -1.4110],
          [ 0.5875,  0.7528, -0.6482, -0.2547],
          [-2.3133,  0.3888,  2.1428,  0.2331]]]])

I want to compute the maximum and the minimum values of X over dimensions 2 and 3, that is, two tensors of size (N, C, 1, 1) (here (1, 3, 1, 1)): one holding the maximum and one holding the minimum of each 4x4 block.

I started with torch.max() and torch.min(), but could not get them to reduce over both dimensions at once. I expected their dim argument to accept a tuple, but it only accepts a single integer, so I don't know how to proceed.
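For reference, a minimal sketch of two single-call alternatives, assuming PyTorch 1.7 or newer (where torch.amax and torch.amin were added; unlike torch.max and torch.min, their dim argument accepts a tuple). The second variant only uses torch.max/torch.min and works by flattening the two spatial dimensions into one before reducing. The random input shape is an illustrative assumption.

```python
import torch

# Illustrative input; any (N, C, H, W) tensor works the same way.
X = torch.randn(2, 3, 4, 4)

# Assumes PyTorch >= 1.7: amax/amin reduce over a tuple of dimensions.
X_max = torch.amax(X, dim=(2, 3), keepdim=True)  # shape (2, 3, 1, 1)
X_min = torch.amin(X, dim=(2, 3), keepdim=True)  # shape (2, 3, 1, 1)

# Alternative with plain torch.max / torch.min: merge dims 2 and 3 into a
# single dimension of size H*W, then reduce over that one dimension.
X_flat = X.flatten(start_dim=2)                  # shape (2, 3, 16)
X_max_flat = X_flat.max(dim=2).values            # shape (2, 3)
X_min_flat = X_flat.min(dim=2).values            # shape (2, 3)
```

The flattened variant drops the reduced dimensions; indexing with `[..., None, None]` restores the (N, C, 1, 1) layout if needed.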

However, for the maximum values specifically, I used torch.nn.MaxPool2d() with kernel_size=4 and stride=4, which does the job:

from torch import nn
max_pool = nn.MaxPool2d(kernel_size=4, stride=4)  # one 4x4 window covers each whole block
X_max = max_pool(X)
print(X_max)
tensor([[[[2.6442]],
         [[1.7285]],
         [[2.1428]]]])

But as far as I know, there is no equivalent layer for min-pooling. Could you please help me compute the minima in the same way as the maxima?

Thank you.
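On the min-pooling question specifically: one common workaround, sketched here rather than taken from the original thread, uses the identity min(x) = -max(-x), so max-pooling the negated input and negating the result gives min-pooling. The sketch assumes a 4x4 spatial size for MaxPool2d(kernel_size=4); for an arbitrary H x W, nn.AdaptiveMaxPool2d(1) pools the whole plane to a single value.

```python
import torch
from torch import nn

# Illustrative input with a 4x4 spatial size.
X = torch.randn(2, 3, 4, 4)

# Assumes H = W = 4, so one window covers the entire block.
max_pool = nn.MaxPool2d(kernel_size=4, stride=4)
X_max = max_pool(X)     # shape (2, 3, 1, 1)
X_min = -max_pool(-X)   # min(x) == -max(-x); shape (2, 3, 1, 1)

# Variant for any spatial size: adaptive pooling to a 1x1 output.
adaptive_max = nn.AdaptiveMaxPool2d(output_size=1)
X_min_any = -adaptive_max(-X)  # shape (2, 3, 1, 1)
```

Negation is exact in floating point, so the result contains the original minimum values unchanged.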

Based on this answer on Stack Overflow, reducing one dimension at a time works:

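# keepdim=True keeps each reduced dimension as size 1, so dims 2 and 3
# stay in place and the results have shape (N, C, 1, 1)
X_max = X
X_min = X
for dim in (2, 3):
    X_max = torch.max(X_max, dim=dim, keepdim=True).values
    X_min = torch.min(X_min, dim=dim, keepdim=True).values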
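The keepdim=True flag matters for this loop: without it, each reduction removes a dimension and the later index shifts, so reducing dim 2 and then dim 3 of a 4D tensor would fail because only three dimensions remain. A minimal sketch of the keepdim=False variant, assuming the same (N, C, H, W) layout, reduces the highest dimension first so the remaining index stays valid:

```python
import torch

# Illustrative input.
X = torch.randn(2, 3, 4, 4)

# Reduce dim 3 first; the old dim 2 is still dim 2 afterwards.
X_max = X
for dim in (3, 2):
    X_max = torch.max(X_max, dim=dim).values  # shape (2, 3, 4), then (2, 3)

# Re-insert singleton dims to match the keepdim=True result, (N, C, 1, 1).
X_max_4d = X_max[..., None, None]
```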