I have indices of length N where each index in this list corresponds to a channel. I need to select the channels based on these indices. The size of the output should be Nx1xWxH or NxWxH. How can I do this?
I think you want to expand your indices as Nx1xWxH by doing ind.view(N, 1, 1, 1).expand(N, 1, W, H).
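Then you can use gather along the channel dimension to get the values: res = value.gather(1, ind), with ind being the expanded index from the previous step (gather takes the dimension first, then the index).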
Then if you want to remove the dimension of size 1, you can do res = res.squeeze(1).
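Put together, the three steps above might look like the following. This is only a minimal sketch: the sizes, the random `value` tensor, and the name `index` are assumptions made for illustration, and `value` is assumed to have shape NxCxWxH.

```python
import torch

# Illustrative sizes (assumed, not from the original question)
N, C, W, H = 5, 4, 3, 2
value = torch.randn(N, C, W, H)      # assumed layout: N x C x W x H
ind = torch.tensor([3, 0, 3, 1, 2])  # one channel index per sample, length N

# Expand the indices to N x 1 x W x H (a view, no copy of the index data)
index = ind.view(N, 1, 1, 1).expand(N, 1, W, H)

# gather takes the dimension first, then the index tensor
res = value.gather(1, index)         # shape: N x 1 x W x H

# Optionally drop the channel dimension of size 1
res_squeezed = res.squeeze(1)        # shape: N x W x H
```

For each n, `res_squeezed[n]` should hold the same values as `value[n, ind[n]]`.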
import torch
torch.__version__
torch.manual_seed(2020)
N = 5
C = 4
W = 3
t = torch.randn(N, C, W)  # N x C x W values
t
i = torch.tensor([3, 0, 3, 1, 2])  # one channel index per sample
i
t[torch.arange(N), i, :]  # advanced indexing: picks t[n, i[n], :] for each n
That works as well.
Not sure which one will be faster and use less memory…
My personal preference is usually to avoid advanced indexing because I’m not used to its semantics. Also, allocating the arange() Tensor can be expensive depending on the sizes involved. Note that expand does not allocate memory, so my solution doesn’t do any extra allocation.
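To make the comparison concrete, here is a small sketch that checks the two approaches return the same values and that the expanded index is a view over the original index data. The sizes and variable names are assumptions for illustration, and this does not measure speed or peak memory.

```python
import torch

# Illustrative sizes (assumed)
N, C, W, H = 5, 4, 3, 2
value = torch.randn(N, C, W, H)
ind = torch.tensor([3, 0, 3, 1, 2])

# expand returns a view: the expanded dimensions get stride 0,
# and the view points at the same memory as the original indices
index = ind.view(N, 1, 1, 1).expand(N, 1, W, H)
shares_memory = index.data_ptr() == ind.data_ptr()
strides = index.stride()

# Approach 1: gather with the expanded index
via_gather = value.gather(1, index).squeeze(1)

# Approach 2: advanced indexing, which allocates a length-N arange tensor
via_indexing = value[torch.arange(N), ind]

same_result = torch.equal(via_gather, via_indexing)
```

Both results have shape NxWxH. Which one is faster in practice would need to be timed on the actual sizes and device.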