Torch.max with tuple axes

I think this has come up a few times but has never been addressed directly in its own topic: is there a way to get the maximum values of a tensor over several axes at once, by passing a tuple of axes?

(If this has come up as a full topic before, please let me know and I’ll edit the topic accordingly)

For example, in NumPy one can do the following to get the maximum value in each 3x3 slice of a datacube called “a”:

import numpy as np

a = np.random.uniform(0,1,(5,1,3,3))
max_vals = a.max(axis=(1,2,3))
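To make the semantics concrete: reducing over a tuple of axes gives the same result as flattening those axes into one and reducing once. A minimal NumPy sketch (the name `max_vals_flat` is only for illustration):

```python
import numpy as np

# Same shape as the example above: 5 slices of shape 1x3x3
a = np.random.uniform(0, 1, (5, 1, 3, 3))

# Reducing over axes (1, 2, 3) keeps only axis 0, so the result has shape (5,)
max_vals = a.max(axis=(1, 2, 3))

# Equivalent formulation: flatten each slice into one row, then reduce along that row
max_vals_flat = a.reshape(a.shape[0], -1).max(axis=1)

print(max_vals.shape)                           # (5,)
print(np.array_equal(max_vals, max_vals_flat))  # True
```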

However, the equivalent call in PyTorch fails with an error:

import torch 

at = torch.tensor(a)
max_vals = at.max(axis=(1,2,3))  # fails: torch.max reduces over a single int dim only
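One single-call workaround, sketched under the assumption that the dims to reduce are the trailing ones (as here): merge them into a single dim with `flatten`, then call `torch.max` over that one dim.

```python
import numpy as np
import torch

a = np.random.uniform(0, 1, (5, 1, 3, 3))
at = torch.tensor(a)

# torch.max(input, dim) reduces over one int dim and returns (values, indices).
# Merge dims 1..3 into a single dim first, then reduce over that one dim.
max_vals = at.flatten(start_dim=1).max(dim=1).values

print(max_vals.shape)  # torch.Size([5])
# Matches the NumPy result computed on the same data
print(np.array_equal(max_vals.numpy(), a.max(axis=(1, 2, 3))))  # True
```

If the dims to reduce are not adjacent trailing dims, they would first need to be moved to the end (e.g. with `permute`) before flattening.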

Have I missed a simple workaround, or does PyTorch not support computing the maximum over several dimensions in a single call?

Many thanks in advance!

Not sure if this is the best way to solve the problem, but you can chain the reductions, one dimension at a time:

# torch.max(dim) returns (values, indices); [0] keeps the values
max_vals = at.max(1)[0].max(1)[0].max(1)[0]
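The chaining idea generalizes to any tuple of dims. Below is a minimal sketch with a hypothetical helper `max_over_dims` (not part of PyTorch): it reduces the highest dim first, so the positions of the dims still to be reduced do not shift after each call.

```python
import torch

def max_over_dims(t: torch.Tensor, dims) -> torch.Tensor:
    """Hypothetical helper: max over several dims by chaining single-dim torch.max calls."""
    # Normalize negative dims against the original rank, then reduce the
    # highest dim first so the remaining dim indices stay valid.
    for d in sorted({d % t.dim() for d in dims}, reverse=True):
        t = t.max(dim=d).values
    return t

at = torch.rand(5, 1, 3, 3)
chained = at.max(1)[0].max(1)[0].max(1)[0]
print(torch.equal(max_over_dims(at, (1, 2, 3)), chained))  # True
```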

I don't think `torch.max` supports a tuple of dimensions; according to the documentation it only reduces along a single dimension. It might be a good idea to request this feature.
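For reference, PyTorch 1.7 and later provide `torch.amax`, whose `dim` argument accepts a tuple of ints and which returns only the values (no indices). A minimal sketch, assuming a PyTorch version that includes it:

```python
import torch

at = torch.rand(5, 1, 3, 3)

# torch.amax reduces over every dim in the tuple and returns only the values
max_vals = torch.amax(at, dim=(1, 2, 3))
# The method form is equivalent
max_vals_method = at.amax(dim=(1, 2, 3))

# Same result as reducing one dim at a time with torch.max
chained = at.max(1)[0].max(1)[0].max(1)[0]
print(max_vals.shape)                  # torch.Size([5])
print(torch.equal(max_vals, chained))  # True
```

Passing `keepdim=True` keeps the reduced dims as size 1, which can be convenient for broadcasting the maxima back against `at`.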

Really silly question, but is there a channel through which I could request a feature like this?

Sure! Feel free to open a feature request in the PyTorch GitHub issues and briefly explain the use case with your sample code.
