torch.max with tuple axes

I think this has come up a few times but has never been addressed directly in a topic: is there a way to get the maximum values of a tensor over several axes at once, passing the axes as a tuple?

e.g. using NumPy one can do the following to get the maximum value of each 3x3 slice of a datacube called “a”:

import numpy as np

a = np.random.uniform(0,1,(5,1,3,3))
max_vals = a.max(axis=(1,2,3))

However, the equivalent call in PyTorch raises an error:

import torch 

at = torch.tensor(a)
max_vals = at.max(axis=(1,2,3))  # fails: torch.max does not accept a tuple of dims

Have I missed a very simple workaround for this problem, or does PyTorch not support computing the maximum over several dimensions in a single call?
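One possible workaround, when the axes being reduced are adjacent as in this example, is to flatten them into a single dimension and reduce once. A minimal sketch, assuming a 4-D tensor shaped like the one above:

```python
import numpy as np
import torch

a = np.random.uniform(0, 1, (5, 1, 3, 3))
at = torch.tensor(a)

# Merge dims 1, 2 and 3 into one dimension of size 1*3*3 = 9 -> shape (5, 9)
flat = at.flatten(start_dim=1)

# torch.max with a single dim returns (values, indices); keep only the values
max_vals = flat.max(dim=1).values

# Should match NumPy's a.max(axis=(1, 2, 3))
assert np.allclose(max_vals.numpy(), a.max(axis=(1, 2, 3)))
```

This only works when the reduced dimensions are contiguous; for non-adjacent ones you would need to `permute` them to the end first.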

Not sure if this is the best way to solve the problem, but one can chain single-dimension reductions:

# each max(1) returns (values, indices) and drops dim 1; [0] keeps the values
max_vals = at.max(1)[0].max(1)[0].max(1)[0]
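The chained version generalises to any set of dimensions if you reduce the highest index first, because each reduction removes a dimension and shifts the later ones down by one. A minimal sketch; `max_over_dims` is a hypothetical helper name, not part of PyTorch:

```python
import torch

def max_over_dims(t: torch.Tensor, dims) -> torch.Tensor:
    """Hypothetical helper: maximum of `t` over every dimension listed in `dims`."""
    ndim = t.dim()
    # Normalise negative indices (and drop duplicates), then reduce from the
    # highest dim downwards so that removing one dimension never shifts the
    # positions of the dimensions still to be reduced.
    for d in sorted({x % ndim for x in dims}, reverse=True):
        # torch.max(dim=...) returns (values, indices); keep only the values
        t = t.max(dim=d).values
    return t

t = torch.rand(5, 1, 3, 3)
max_vals = max_over_dims(t, (1, 2, 3))  # shape (5,), same as the chained calls
```

This also works for non-adjacent dimensions, e.g. `max_over_dims(t, (0, 2))`.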

I don't think `torch.max` supports a tuple of dimensions; according to the documentation it only reduces along a single dimension. It might be a good idea to request this feature.
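For what it's worth, newer PyTorch releases provide `torch.amax` (and `torch.amin`), which accept a tuple of dimensions much like NumPy's `axis`. A minimal sketch, assuming a release that includes `torch.amax`; unlike `torch.max(dim=...)`, it returns only the values, not the indices:

```python
import torch

at = torch.rand(5, 1, 3, 3)

# Reduce over dims 1, 2 and 3 in a single call; result has shape (5,)
max_vals = torch.amax(at, dim=(1, 2, 3))

# keepdim=True keeps the reduced dims as size 1, giving shape (5, 1, 1, 1),
# which broadcasts back against `at`, e.g. to scale each slice by its maximum
normalised = at / torch.amax(at, dim=(1, 2, 3), keepdim=True)
```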

Really silly question, but is there a channel through which I could request a feature like this? :slight_smile:

Sure! Feel free to open a feature request in the PyTorch GitHub issues and explain the use case a bit with your sample code. :wink:

