Current torch.min() does not support multiple dimensions?

Let's say I have a 3-dimensional tensor:

a = torch.arange(8).reshape(2, 2, 2)

>>> a
tensor([[[0, 1],
         [2, 3]],

        [[4, 5],
         [6, 7]]])

I want to compute the minimum over dimensions 1 and 2, to get

tensor([0, 4])

However, torch.min() does not accept multiple dimensions, while torch.sum() does. Is there a particular reason for that?

So for this specific case, is the most efficient way to take the minimum over multiple dimensions currently to first reshape the tensor

a.reshape(2, 2 * 2)

and then take the minimum over the second axis?
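A quick sketch of that reshape approach on the example tensor (and, for completeness, recent PyTorch versions also provide torch.amin, which accepts a tuple of dims directly):

```python
import torch

a = torch.arange(8).reshape(2, 2, 2)

# Flatten dims 1 and 2 into a single axis, then reduce over it.
mins = a.reshape(2, 2 * 2).min(dim=1).values
print(mins)  # tensor([0, 4])

# Newer PyTorch releases offer torch.amin with a tuple of dims.
print(torch.amin(a, dim=(1, 2)))  # tensor([0, 4])
```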


Hi @Kevin96,

You should do:

a.view(-1, 2 * 2).min(axis=1)

where view guarantees no memory copy and introduces almost no overhead.
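For reference, min(dim=...) returns a (values, indices) named tuple, so pulling out just the minima might look like this:

```python
import torch

a = torch.arange(8).reshape(2, 2, 2)

# view requires a contiguous tensor; -1 infers the batch dimension.
result = a.view(-1, 2 * 2).min(dim=1)
print(result.values)   # tensor([0, 4])
print(result.indices)  # tensor([0, 0])
```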


Hi kaiwen,

Not very memory-efficient, but "easy":

def mul_min(x, axis, keepdim=False):
    # Reduce over the highest dim first so the remaining dims keep their positions.
    for i in sorted(axis, reverse=True):
        x, _ = x.min(i, keepdim)
    return x
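Repeating the helper so the snippet is self-contained, a quick check against the example tensor:

```python
import torch

def mul_min(x, axis, keepdim=False):
    # Reduce over the highest dim first so earlier dims keep their indices.
    for i in sorted(axis, reverse=True):
        x, _ = x.min(i, keepdim)
    return x

a = torch.arange(8).reshape(2, 2, 2)
print(mul_min(a, (1, 2)))  # tensor([0, 4])
```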