What exactly would you like to index?
In your current approach, you will select from dimension 0, which has only one entry.
So basically you are repeating your Tensor, resulting in a size of [131072, 1, 256, 512].
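A minimal sketch of that behavior, assuming the source tensor has shape [1, 1, 256, 512] and the index tensor holds 131072 zeros (scaled down here so it runs quickly):

```python
import torch

# Stand-in for a [1, 1, 256, 512] tensor; dimension 0 has only one entry.
t = torch.arange(6.0).reshape(1, 1, 2, 3)

# Stand-in for 131072 indices; every index can only select entry 0.
idx = torch.zeros(4, dtype=torch.long)

# index_select along dim 0 therefore just repeats the tensor.
out = torch.index_select(t, 0, idx)
print(out.shape)  # torch.Size([4, 1, 2, 3])
```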
I also ran into a similar issue when using torch.index_select. I have two tensors, lut (64 x 7281) and idx (64 x 943), where the values in idx range from 0 to 7280, and I need to do something like res = torch.stack([torch.index_select(l_, 0, i_) for l_, i_ in zip(lut, idx)]). I wonder if there is a more memory-efficient way to do this.
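For reference, this per-row lookup can be done in a single call with torch.gather, which avoids the Python loop and the intermediate list of tensors. A sketch, using the shapes from the question:

```python
import torch

B, N, K = 64, 7281, 943
lut = torch.randn(B, N)
idx = torch.randint(0, N, (B, K))

# Loop version from the question: one index_select per row, then stack.
res_loop = torch.stack(
    [torch.index_select(l_, 0, i_) for l_, i_ in zip(lut, idx)]
)

# Batched version: gather picks lut[b, idx[b, k]] along dim 1 in one call.
res = torch.gather(lut, 1, idx)

print(torch.equal(res, res_loop))  # True
```

Note that torch.gather requires idx to have the same number of dimensions as lut, which is already the case here (both are 2-D).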