I am using torch.max on a torch tensor along a specific dimension. When I try to use the result, I get various errors. I can display the contents of the result (see the snippet below), but I cannot use it as a tensor, either in an arithmetic operation or to query its size (errors below). However, when I call the same function on the entire tensor to get the maximum scalar value, the result behaves like a normal tensor. The same holds for torch.min().
I would appreciate any insight into this behavior. How do I get a tensor that holds the max/min values of another tensor along a given dimension? I don't want to convert the tensor to NumPy and use np.max, since I plan to use this operation inside a loss function and therefore need it to be tracked by autograd.
>>> import torch
>>> a = torch.rand((20,3,128,128,128))
>>> maxval = torch.max(a, 1)
>>> maxval[:5] # [0.8795, 0.9569, 0.5381, ..., 0.5498, 0.8159, 0.8428],
# [0.8837, 0.8009, 0.6006, ..., 0.3414, 0.9229, 0.6836],
# [0.9956, 0.5561, 0.8130, ..., 0.7098, 0.7955, 0.6614]]],
>>> maxval.size()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
AttributeError: 'torch.return_types.max' object has no attribute 'size'
>>> a_scaled = a - maxval
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: sub(): argument 'other' (position 1) must be Tensor, not torch.return_types.max
>>> maxscalar = torch.max(a)
>>> maxscalar
tensor(1.0000)
>>> maxscalar.size()
torch.Size([])
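A minimal sketch of the usual fix, assuming a reasonably recent PyTorch: when called with a dim argument, torch.max returns a named tuple of (values, indices) rather than a tensor, which matches the torch.return_types.max type in the tracebacks above. Taking .values (or unpacking the tuple) gives a plain tensor, and keepdim=True keeps the reduced dimension as size 1 so the result broadcasts against a. The small shape is only a stand-in for the original (20, 3, 128, 128, 128) tensor, and torch.amax is shown as an alternative that returns just the values (it exists in PyTorch 1.7 and later).

```python
import torch

# Small stand-in for the (20, 3, 128, 128, 128) tensor in the question,
# created with requires_grad=True so we can check that gradients flow.
a = torch.rand((2, 3, 4, 4, 4), requires_grad=True)

# With a dim argument, torch.max returns a named tuple (values, indices),
# not a tensor; this is why .size() and subtraction fail on it directly.
result = torch.max(a, dim=1, keepdim=True)
maxval = result.values      # maxima along dim 1, shape (2, 1, 4, 4, 4)
argmax = result.indices     # positions of those maxima along dim 1

# keepdim=True keeps dim 1 with size 1, so maxval broadcasts against a.
a_scaled = a - maxval       # shape (2, 3, 4, 4, 4)

# Tuple unpacking also works; torch.min(a, dim=1, ...) behaves the same way.
minval, argmin = torch.min(a, dim=1, keepdim=True)

# Alternative (PyTorch 1.7+): torch.amax returns only the values tensor.
maxval_amax = torch.amax(a, dim=1, keepdim=True)

# Gradients flow back through .values and the subtraction, so this is
# usable inside a loss function.
loss = a_scaled.sum() + minval.sum()
loss.backward()

print(maxval.size(), a_scaled.size(), a.grad.size())
```

For torch.max with a dim argument, the gradient of .values is routed back only to the positions recorded in .indices; the other elements of a receive no gradient through that term.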