Min / Max across multiple axes

Hi Everyone,
I’m trying to get the min and max of each image in a batch of images in NCHW format.
From what I could find, torch.min's dim argument takes only one dimension.

Any way to make it take multiple axes?
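One approach that works with the single-dim min()/max() is to flatten every axis except the batch axis and then reduce that one axis. Below is a minimal sketch; the batch shape (8, 3, 32, 32) is an arbitrary example, not from the original post.

```python
import torch

# Example batch: 8 RGB images, 32x32, in NCHW layout
x = torch.rand(8, 3, 32, 32)

# Collapse C, H, W into one axis so each row holds every value of one image,
# then reduce that single axis with the ordinary one-dim min/max.
flat = x.flatten(start_dim=1)        # shape (8, 3*32*32)
img_min = flat.min(dim=1).values     # shape (8,)
img_max = flat.max(dim=1).values     # shape (8,)

print(img_min.shape, img_max.shape)  # torch.Size([8]) torch.Size([8])
```

flatten() may copy the data if x is not contiguous, which is usually acceptable for a per-batch statistic.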

I am also interested in a function like this. Functions like mean() and sum() can take a list of dimensions to reduce; is similar functionality possible for min() and max()?
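For reference, newer PyTorch releases (1.7 and later, as far as I know) provide torch.amin and torch.amax, which accept a tuple of dims the same way mean() does. A small sketch, assuming such a version is installed:

```python
import torch

x = torch.rand(2, 3, 4, 4)

# mean() already accepts a tuple of dims to reduce
per_channel_mean = x.mean(dim=(2, 3))          # shape (2, 3)

# torch.amin / torch.amax accept a tuple of dims too.
# Unlike min()/max() with dim, they return only the values, not indices.
per_channel_min = torch.amin(x, dim=(2, 3))    # shape (2, 3)
per_channel_max = torch.amax(x, dim=(2, 3))    # shape (2, 3)

# keepdim=True keeps the reduced axes as size 1, handy for broadcasting
per_image_min = x.amin(dim=(1, 2, 3), keepdim=True)  # shape (2, 1, 1, 1)
```

Because no indices are returned, amin/amax avoid the question of which element's index to report when the minimum spans several axes.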

My use case is quantizing a weight tensor. If I quantize using constants computed from the whole tensor, one step along the way is to calculate the min and max of the whole tensor, which min() and max() already do. However, I want to compare this against channel-wise quantization (with different constants for zero point, scale, etc.), and it would be great if min() and max() produced the right outputs given a list of dimensions to reduce.
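A sketch of that comparison, assuming a conv weight of shape (out_channels, in_channels, kH, kW), an unsigned 8-bit range, and the common asymmetric (affine) formulas scale = (max - min) / (qmax - qmin) and zero_point = round(qmin - min / scale). The helper affine_qparams is hypothetical, and the exact rules used by any particular quantization library may differ:

```python
import torch

# Example conv weight: (out_channels, in_channels, kH, kW)
w = torch.randn(16, 8, 3, 3)

qmin, qmax = 0, 255  # assumed unsigned 8-bit range

# Per-tensor: one min/max for the whole tensor
t_min, t_max = w.min(), w.max()

# Per-channel: reduce every axis except the output-channel axis (dim 0)
c_min = w.amin(dim=(1, 2, 3))   # shape (16,)
c_max = w.amax(dim=(1, 2, 3))   # shape (16,)

def affine_qparams(lo, hi):
    # Hypothetical helper. Widen the range to include 0 so that 0.0 is
    # exactly representable after quantization.
    lo = torch.clamp(lo, max=0.0)
    hi = torch.clamp(hi, min=0.0)
    scale = (hi - lo) / (qmax - qmin)
    scale = torch.clamp(scale, min=1e-8)  # guard against an all-zero channel
    zero_point = torch.clamp(torch.round(qmin - lo / scale), qmin, qmax)
    return scale, zero_point.to(torch.int64)

t_scale, t_zp = affine_qparams(t_min, t_max)   # scalars
c_scale, c_zp = affine_qparams(c_min, c_max)   # one pair per output channel
```

Since each channel's range is contained in the whole tensor's range, the per-channel scales are never larger than the per-tensor scale, which is where the precision gain of channel-wise quantization comes from.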

I managed to handle the problem by chaining two min() calls, like this:

In [1]: x = torch.rand(2,3,4,4); x
Out[1]:
tensor([[[[0.6575, 0.9134, 0.7617, 0.1276],
          [0.3552, 0.0012, 0.0331, 0.4121],
          [0.4304, 0.8518, 0.0515, 0.4528],
          [0.7224, 0.2117, 0.7193, 0.9164]],

         [[0.9737, 0.2632, 0.4182, 0.8434],
          [0.7079, 0.3799, 0.3052, 0.8838],
          [0.9668, 0.6371, 0.7326, 0.7869],
          [0.9386, 0.0949, 0.7356, 0.3458]],

         [[0.3471, 0.2736, 0.3536, 0.4553],
          [0.4408, 0.3221, 0.1848, 0.0352],
          [0.5171, 0.1303, 0.2172, 0.4764],
          [0.3724, 0.6341, 0.7000, 0.4547]]],


        [[[0.5591, 0.5444, 0.0419, 0.1762],
          [0.9670, 0.7052, 0.4980, 0.7963],
          [0.3407, 0.5290, 0.1264, 0.7158],
          [0.9510, 0.1153, 0.8766, 0.1305]],

         [[0.9847, 0.3020, 0.2368, 0.5934],
          [0.6701, 0.9279, 0.1959, 0.8628],
          [0.9260, 0.6380, 0.1120, 0.3395],
          [0.0412, 0.0786, 0.2994, 0.5483]],

         [[0.7809, 0.1019, 0.6274, 0.9994],
          [0.9357, 0.2192, 0.6914, 0.3103],
          [0.3666, 0.3832, 0.2777, 0.1814],
          [0.4448, 0.6855, 0.3532, 0.9655]]]])
In [2]: x.min(2)[0].min(2)[0]
Out[2]: 
tensor([[0.0012, 0.0949, 0.0352],
        [0.0419, 0.0412, 0.1019]])
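The chained call works because once dim 2 (H) is reduced, the old dim 3 (W) becomes the new dim 2. The same idea generalizes to any set of dims if they are reduced from the highest index down, so removing one axis never shifts an axis still waiting to be reduced. A sketch for PyTorch versions without amin/amax; min_over_dims is a hypothetical helper name:

```python
import torch

def min_over_dims(t, dims):
    """Hypothetical helper: min of `t` over several dims by chaining min().

    Dims are normalized to non-negative indices and reduced from the highest
    index down, so each reduction leaves the remaining target dims in place.
    """
    dims = sorted({d % t.dim() for d in dims}, reverse=True)
    for d in dims:
        t = t.min(dim=d)[0]  # [0] keeps the values, drops the indices
    return t

x = torch.rand(2, 3, 4, 4)
per_channel = min_over_dims(x, (2, 3))    # same as x.min(2)[0].min(2)[0]
per_image = min_over_dims(x, (1, 2, 3))   # one value per image, shape (2,)
```

Swapping min for max gives the corresponding maximum. Note that the indices from the intermediate reductions are discarded, since they no longer point into the original tensor.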