Hello, I have a problem with the `torch.max` operation.
When I use `torch.max` on a 1x1x1xn tensor, like
```python
import torch
a = torch.randn(1, 1, 1, 5)
a.max(0)
```
it produces a tuple of two 1x1x1xn tensors (one for the maximum values and another for the indices).
However, if I try the same operation with n = 1, like
```python
import torch
b = torch.randn(1, 1, 1, 1)
b.max(0)
```
`torch.max` produces a tuple of two 1x1x1 tensors.
On the other hand, with
```python
import torch
c = torch.randn(2, 1, 1, 1)
c.max(0)
```
it produces a tuple of two 1x1x1 tensors.
I think `b` and `c` behave correctly, and that `a.max(0)` should produce a tuple of two 1x1xn tensors. What's wrong with it?
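For comparison, a minimal sketch that prints the result shapes directly. It assumes a recent PyTorch release in which `Tensor.max(dim, keepdim=...)` returns `(values, indices)` and `keepdim` defaults to `False`, so the reduced dimension 0 is dropped; passing `keepdim=True` keeps it as a size-1 dimension. Older releases may report different shapes.

```python
import torch

# Same shapes as in the question; only dimension 0 (reduced) and the last dimension differ.
a = torch.randn(1, 1, 1, 5)
b = torch.randn(1, 1, 1, 1)
c = torch.randn(2, 1, 1, 1)

for name, t in [("a", a), ("b", b), ("c", c)]:
    # keepdim=False (assumed default): dimension 0 is removed from the result.
    values, indices = t.max(0)
    # keepdim=True: dimension 0 is kept with size 1.
    values_kd, indices_kd = t.max(0, keepdim=True)
    print(name, tuple(values.shape), tuple(values_kd.shape))
```

Under that assumption, `a.max(0)` yields 1x1x5 tensors, while `b.max(0)` and `c.max(0)` both yield 1x1x1 tensors.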