I think this has come up a few times but has never been addressed directly in its own topic: is there a way to get the maximum values of a tensor over several axes at once by passing a tuple of axes?
(If this has already been covered in a dedicated topic, please let me know and I'll edit this one accordingly.)
For example, with NumPy one can do the following to get the maximum value of each 1x3x3 slice of a datacube called `a`:
```python
import numpy as np
a = np.random.uniform(0, 1, (5, 1, 3, 3))
max_vals = a.max(axis=(1, 2, 3))  # shape (5,): one max per slice
```
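To spell out what the tuple does here: reducing over axes (1, 2, 3) leaves only axis 0, so the result holds one maximum per slice. A minimal sketch checking that this is the same as flattening the trailing axes into one and reducing over that single axis (`flat_max` is just an illustrative name):

```python
import numpy as np

a = np.random.uniform(0, 1, (5, 1, 3, 3))

# Reduce over axes 1, 2 and 3 at once; only axis 0 survives.
max_vals = a.max(axis=(1, 2, 3))

# Equivalent: merge the trailing axes into one, then reduce over axis 1.
flat_max = a.reshape(a.shape[0], -1).max(axis=1)

print(max_vals.shape)  # (5,)
print(np.array_equal(max_vals, flat_max))  # True
```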
However, the equivalent call in PyTorch fails with an error:
```python
import torch
at = torch.tensor(a)
max_vals = at.max(axis=(1, 2, 3))  # fails: max() expects a single int dim, not a tuple
```
Have I missed a simple workaround, or does PyTorch not support getting the maximum over several axes in a single call like this?
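For reference, a sketch of two possible workarounds, assuming a PyTorch version that provides `torch.amax` (I believe it was added in 1.7). `torch.amax` takes a tuple for `dim`; on older versions, flattening the reduced dims into one and calling `max` over that single dim should give the same values:

```python
import numpy as np
import torch

a = np.random.uniform(0, 1, (5, 1, 3, 3))
at = torch.tensor(a)  # copies the NumPy data into a float64 tensor

# Option 1: torch.amax reduces over a tuple of dims and returns values only.
max_vals = torch.amax(at, dim=(1, 2, 3))

# Option 2: merge dims 1..3 into one, then reduce over that single dim.
# Tensor.max(dim=...) returns a (values, indices) named tuple.
max_vals_flat = at.flatten(start_dim=1).max(dim=1).values

print(max_vals.shape)  # torch.Size([5])
print(torch.equal(max_vals, max_vals_flat))  # True
```

Unlike `max(dim=...)`, `amax` does not return indices, which is why it can reduce over several dims at once.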
Many thanks in advance!