torch.index_select out of memory

I ran out of memory while using torch.index_select:

>>> print(input.size())
(1, 1, 256, 512)
>>> print(idx.size())
131072
>>> pix = torch.index_select(input, 0, idx)
RuntimeError: cuda runtime error(2): out of memory at /pytorch/torch/lib/THC/generic/THCStorage.cu:58

Any fix? Thank you.

What exactly would you like to index?
In your current approach you select along dimension 0, which has only one entry, so every valid index must be 0.
Basically you are repeating your tensor 131072 times, which gives a result of size [131072, 1, 256, 512].
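
To see why this runs out of memory, here is a small sketch. It assumes a float32 input and an idx that holds only zeros (any other value would be out of range for a dimension of size 1). It reproduces the repetition on a tiny tensor, then computes the size of the full result without allocating it:

```python
import torch

# Tiny version of the original call: dim 0 has a single entry,
# so index_select just copies the whole tensor once per index.
small = torch.randn(1, 1, 4, 8)
small_idx = torch.zeros(5, dtype=torch.long)
rep = torch.index_select(small, 0, small_idx)
print(rep.shape)  # torch.Size([5, 1, 4, 8])

# Memory needed for the shapes in the question, computed without allocating it.
n_idx = 131072
out_shape = (n_idx, 1, 256, 512)
numel = n_idx * 1 * 256 * 512
bytes_per_elem = torch.finfo(torch.float32).bits // 8  # 4 bytes for float32
bytes_needed = numel * bytes_per_elem
print(out_shape, bytes_needed / 2**30, "GiB")  # (131072, 1, 256, 512) 64.0 GiB
```
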

I want to gather all input pixel values at the indices in idx.

# Flatten to 1D so that each index in idx picks a single pixel
input_flat = input.contiguous().view(-1)
pix = torch.gather(input_flat, 0, idx)

Does this work?
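
The flattening is what fixes the memory problem: after view(-1), dimension 0 has 256 * 512 entries, so each index selects one scalar instead of a whole 256 x 512 slice. A sketch, assuming idx holds flat (row-major) indices in [0, 256 * 512 - 1]; `x` stands in for `input` and the random idx is only illustrative. It shows that gather, index_select on the flat view, and plain indexing all return the same 131072-element result:

```python
import torch

x = torch.randn(1, 1, 256, 512)
# Assumption: idx contains flat indices into the 256 x 512 image (int64 by default).
idx = torch.randint(0, 256 * 512, (131072,))

x_flat = x.contiguous().view(-1)  # shape: (131072,)
pix_gather = torch.gather(x_flat, 0, idx)
pix_select = torch.index_select(x_flat, 0, idx)
pix_index = x_flat[idx]

print(pix_gather.shape)  # torch.Size([131072])
print(torch.equal(pix_gather, pix_select), torch.equal(pix_gather, pix_index))  # True True
```

Each result holds only 131072 float32 values (512 KiB), rather than a copy of the image per index.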

How did you define idx? Does it store the “flat” indices, i.e. values in the range [0, 256*512 - 1]?

Could you try this:

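import torch

x = torch.randn(1, 1, 256, 512)
# Flat indices covering every pixel, in shuffled order
idx = torch.arange(256*512).long()
idx = idx[torch.randperm(256*512)]

# Index the flattened view: each index picks a single pixel
y = x.view(-1)[idx]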

Let me know if this suits your use case.
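
If the pixel locations are stored as separate row and column coordinates rather than flat indices, the flat index into a contiguous (1, 1, H, W) tensor is row * W + col. A sketch with hypothetical `rows` and `cols` tensors (they are not from the thread):

```python
import torch

H, W = 256, 512
x = torch.randn(1, 1, H, W)

# Hypothetical coordinate tensors, one entry per pixel to look up.
rows = torch.randint(0, H, (1000,))
cols = torch.randint(0, W, (1000,))

# Row-major flattening: element (r, c) sits at position r * W + c.
idx = rows * W + cols
assert idx.min() >= 0 and idx.max() < H * W  # valid range is [0, H*W - 1]

pix = x.view(-1)[idx]
# Same values as indexing the image with the 2D coordinates directly.
print(torch.equal(pix, x[0, 0, rows, cols]))  # True
```
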

I am running into a similar issue with torch.index_select. I have two tensors, lut (64 x 7281) and idx (64 x 943), where the values in idx range from 0 to 7280, and I currently compute res = torch.stack([torch.index_select(l_, 0, i_) for l_, i_ in zip(lut, idx)]). Is there a more memory-efficient way to do this?

Interestingly, res = torch.stack([l_[i_] for l_, i_ in zip(lut, idx)]) works fine without an OOM error, and I don't understand why.

Instead of the list comprehension, you could also index lut directly with advanced indexing:

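import torch

lut = torch.randn(64, 7281)
idx = torch.randint(0, 7281, (64, 943))

# Reference result: index each row of lut with the matching row of idx
res = torch.stack([l_[i_] for l_, i_ in zip(lut, idx)])

# Same result in one call: the (64, 1) row index broadcasts against idx (64, 943)
out = lut[torch.arange(lut.size(0)).unsqueeze(1), idx]
print((out == res).all())
> tensor(True)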
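
Another option for this row-wise lookup is torch.gather along dim 1, which computes out[i][j] = lut[i][idx[i][j]] and therefore yields the same (64, 943) result without building the row-index tensor. A sketch reusing the shapes from this thread, compared against the advanced-indexing version above:

```python
import torch

lut = torch.randn(64, 7281)
idx = torch.randint(0, 7281, (64, 943))

# gather along dim 1: out[i][j] = lut[i][idx[i][j]]
out_gather = torch.gather(lut, 1, idx)

# Advanced-indexing version, for comparison
out_index = lut[torch.arange(lut.size(0)).unsqueeze(1), idx]

print(out_gather.shape)  # torch.Size([64, 943])
print(torch.equal(out_gather, out_index))  # True
```

gather requires idx to be an int64 tensor with the same number of dimensions as lut, which torch.randint provides here.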