How to save memory in this case with torch.max

I have a tensor x of shape [100, 64, 256, 256]. Think of it as x = [x0, x1, ..., x99], so each x_i has shape [64, 256, 256].

What I want is, for each x_i, the element-wise maximum over the other 99 slices, i.e. over every x_j with j != i.

E.g. for x0 I would like the maximum over (x1, x2, x3, ..., x99), and for x1 the maximum over (x0, x2, x3, ..., x99).
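(A tiny toy example of the result I want, using scalars instead of [64, 256, 256] slices:)

import torch

x = torch.tensor([1., 5., 3., 2.])  # pretend these are the slices
# desired: y[i] = max over all x[j] with j != i
# y = [5., 3., 5., 5.]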

How can I write this in a memory-efficient way?

The way I do it now is to first build an index tensor of shape [100, 99], where row i lists every index except i:

index = torch.tensor(
[[1, 2, 3, 4, 5, ..., 99],
 [0, 2, 3, 4, 5, ..., 99],
 [0, 1, 3, 4, 5, ..., 99],
 ...
 [0, 1, 2, 3, 4, ..., 98]]
)

then I take out = torch.max(x[index], 1).
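(For reference, one way to build this index tensor programmatically rather than typing it out; just a sketch:)

import torch

n = 100
index = torch.arange(n).expand(n, n)  # each row is 0, 1, ..., n-1
index = index[~torch.eye(n, dtype=torch.bool)].view(n, n - 1)  # drop the diagonal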

BUT, this costs way too much memory. I think the problem is that x[index] materializes a huge [100, 99, 64, 256, 256] tensor, especially since I also need gradients for back-propagation.
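Back-of-the-envelope, just to show the scale of x[index] in float32 (my own numbers):

elements = 100 * 99 * 64 * 256 * 256  # about 4.15e10 elements
print(elements * 4 / 1e9)             # about 166 GB, before any gradient buffers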

Is there a better way to do this?

You could run a loop over the 100 entries in the first dimension, e.g.:

import torch

#x = torch.rand(100, 64, 256, 256)
x = torch.rand(100, 64, 3, 3) # reduced for testing

def max_exclusive(x, i):
    # element-wise max over all slices except x[i]:
    # reduce the parts before and after index i separately, then combine
    a, b = x[:i], x[i+1:]
    if a.nelement() == 0:  # i == 0, nothing before
        return torch.max(b, 0)[0]
    if b.nelement() == 0:  # i == last, nothing after
        return torch.max(a, 0)[0]
    return torch.max(torch.max(a, 0)[0], torch.max(b, 0)[0])

#y = torch.cat([max_exclusive(x, i).unsqueeze(0) for i in range(x.size(0))])

# or fill a preallocated tensor if you do not want to store the intermediate results:
y = x.new_empty(x.size())
for i in range(x.size(0)):
    y[i] = max_exclusive(x, i)

print(y.size())
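If you want to avoid the Python loop entirely (the loop above redoes overlapping reductions, roughly O(n^2) work over the first dimension), a prefix/suffix variant is also possible. This is only a sketch, assuming your PyTorch version provides torch.cummax, and it trades the loop for two extra tensors the size of x:

import torch

x = torch.rand(100, 64, 3, 3) # reduced for testing

prefix = torch.cummax(x, dim=0)[0]  # prefix[i] = max over x[0..i]
suffix = torch.flip(torch.cummax(torch.flip(x, [0]), dim=0)[0], [0])  # suffix[i] = max over x[i..n-1]

y = torch.empty_like(x)
y[0] = suffix[1]    # exclude x[0]
y[-1] = prefix[-2]  # exclude x[-1]
y[1:-1] = torch.max(prefix[:-2], suffix[2:])  # combine the parts left and right of i

print(y.size())

Here y[i] = max(prefix[i-1], suffix[i+1]), so each slice is excluded exactly once, and it should stay differentiable since cummax supports autograd.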