Hello, I have a problem with the torch.max operation.
When I use torch.max on a 1x1x1xn tensor, like

```python
import torch
a = torch.randn(1, 1, 1, 5)
a.max(0)
```

it produces a tuple of two 1x1x1xn tensors (one for the maximum values and another for the indices).
However, if I try the same operation with n = 1, like

```python
import torch
b = torch.randn(1, 1, 1, 1)
b.max(0)
```

torch.max produces a tuple of two 1x1x1 tensors.
On the other hand, with

```python
import torch
c = torch.randn(2, 1, 1, 1)
c.max(0)
```

it produces a tuple of two 1x1x1 tensors.
I think b and c work correctly, and a should also produce a tuple of two 1x1xn tensors. What's wrong with it?
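For comparison, here is a minimal pure-Python sketch of the shape rule that recent PyTorch releases document for `torch.max(input, dim, keepdim=False)`: the reduced dimension is dropped, unless `keepdim=True`, in which case it is kept with size 1. The helper `reduced_shape` is hypothetical (it is not part of PyTorch) and only computes shapes; it assumes a recent PyTorch version, since older releases may have handled `keepdim` differently. Under this rule `a` would give 1x1x5 outputs, matching the shape expected above, while a 1x1x1x5 result corresponds to the `keepdim=True` shape.

```python
def reduced_shape(shape, dim, keepdim=False):
    """Hypothetical helper: shape of the (values, indices) outputs of torch.max(t, dim)."""
    ndim = len(shape)
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d shape")
    dim %= ndim  # normalize negative dims, e.g. -1 -> last dimension
    if keepdim:
        # the reduced dimension stays, with size 1
        return shape[:dim] + (1,) + shape[dim + 1:]
    # default: the reduced dimension is removed
    return shape[:dim] + shape[dim + 1:]


for shape in [(1, 1, 1, 5), (1, 1, 1, 1), (2, 1, 1, 1)]:
    print(shape, "->", reduced_shape(shape, 0), "| keepdim:", reduced_shape(shape, 0, keepdim=True))
```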