The following code eventually gives me the maximum element of a 3D tensor, but only after reducing it one dimension at a time with three successive max(0) calls.
Is there a better way?
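import torch

values = torch.randn(32, 32, 32)
# Each max(0) reduces the leading dimension: (32, 32, 32) -> (32, 32) -> (32,) -> scalar
values, indices = values.max(0)
print(values, indices)
values, indices = values.max(0)
print(values, indices)
values, indices = values.max(0)
print(values, indices)  # e.g. tensor(4.0173) tensor(31); this index only refers to the last reduction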
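A single reduction over all elements is usually enough. The sketch below assumes a standard PyTorch install: `values.max()` with no `dim` argument returns the global maximum, and `values.argmax()` with no `dim` returns an index into the flattened tensor, which can be converted back to 3D coordinates with integer arithmetic. The names `max_value`, `flat_index`, `i`, `j` and `k` are only illustrative.

```python
import torch

values = torch.randn(32, 32, 32)

# Global maximum in one reduction over every element.
max_value = values.max()

# argmax() with no dim returns an index into the flattened (row-major) tensor.
flat_index = values.argmax().item()

# Unravel the flat index back into (i, j, k) coordinates.
d0, d1, d2 = values.shape
i, rem = divmod(flat_index, d1 * d2)
j, k = divmod(rem, d2)

print(max_value, (i, j, k))
```

Unlike the repeated max(0) approach, this also recovers the full position of the maximum rather than just the index from the final reduction. Recent PyTorch releases (2.2 onward) also provide `torch.unravel_index` for the last step, if your version has it.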