# How to save memory in this case with torch.max

I have a tensor `x` with shape [100, 64, 256, 256]. Think of it as `x = [x0, x1, ..., x99]`, so each `x_i` has shape [64, 256, 256].

For each `x_i`, I want the element-wise maximum over the other 99 tensors, i.e. over all `x_j` with j ≠ i.

E.g. for `x0` I would like to get `torch.max(x1, x2, x3, ..., x99)`, and for `x1` I would like to get `torch.max(x0, x2, x3, ..., x99)`.

How can I write this in a memory-efficient way?

The way I currently do it is to first build an `index` tensor of shape [100, 99]:

```python
index = torch.tensor(
    [[1, 2, 3, 4, 5, ..., 99],
     [0, 2, 3, 4, 5, ..., 99],
     [0, 1, 3, 4, 5, ..., 99],
     ...
     [0, 1, 2, 3, 4, ..., 98]]
)
```
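
For reference, this `index` tensor doesn't have to be typed out by hand. A minimal sketch of one way to build it, using a boolean mask to drop the diagonal of an index grid:

```python
import torch

n = 100
grid = torch.arange(n).expand(n, n)     # row i = [0, 1, ..., n-1]
mask = ~torch.eye(n, dtype=torch.bool)  # drop entry i from row i
index = grid[mask].view(n, n - 1)       # shape [100, 99]
```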

then, `out = torch.max(x[index], 1)[0]`.

BUT this costs far too much memory. I think the problem is that `x[index]` is too big: it materializes a [100, 99, 64, 256, 256] tensor, which is about 4.15e10 elements (roughly 166 GB in float32). And since I need gradients for back-propagation, it has to be kept around for the backward pass as well.

Is there a better way to do this?

You could run a loop over the 100 entries in the first dimension, e.g.:

```python
import torch

# x = torch.rand(100, 64, 256, 256)
x = torch.rand(100, 64, 3, 3)  # reduced size for testing

def max_exclusive(x, i):
    a, b = x[:i], x[i+1:]  # slices before and after index i
    if a.nelement() == 0:
        return torch.max(b, 0)[0]
    if b.nelement() == 0:
        return torch.max(a, 0)[0]
    return torch.max(torch.max(a, 0)[0], torch.max(b, 0)[0])

out = torch.stack([max_exclusive(x, i) for i in range(x.size(0))])
```
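
This only keeps a couple of [64, 256, 256]-sized intermediates alive at a time instead of the full `x[index]` tensor, and autograd flows through the slicing and the `torch.max` reductions.

If the Python loop over 100 entries is a concern, here is a further sketch (my addition, not part of the loop approach above): assuming your PyTorch version has `torch.cummax`, you can precompute running maxima from both ends, so the "max of everything except `x[i]`" becomes the combination of a prefix and a suffix:

```python
import torch

x = torch.rand(100, 64, 3, 3)  # reduced size for testing

# prefix[i] = max(x[0..i]), suffix[i] = max(x[i..n-1]), both along dim 0
prefix = torch.cummax(x, dim=0).values
suffix = torch.cummax(x.flip(0), dim=0).values.flip(0)

mid = torch.max(prefix[:-2], suffix[2:])  # excludes x[i] for 0 < i < n-1
out = torch.cat([suffix[1:2], mid, prefix[-2:-1]], dim=0)
```

This costs two extra tensors the size of `x` and no Python loop; `torch.cummax` supports autograd, so gradients still flow.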