Suppose I have a tensor, `a`, that I wish to index. `a` has:

- `B` batches;
- `C` channels;
- and dimensions `X`, `Y`, `Z` that contain the values I want to index.
```
a.shape
>> torch.Size([B, C, X, Y, Z])
```
I have a list of tensors, `b`, of length `B`, each containing `N` indices I want to obtain. Each element of the list may have a different value for `N`, i.e. `N0`, `N1`, …
```
len(b)
>> B
b[0].shape
>> torch.Size([N0, 3])
b[1].shape
>> torch.Size([N1, 3])
```
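For anyone who wants to experiment, here is a minimal reproducible setup with made-up sizes (the values of `B`, `C`, `X`, `Y`, `Z`, `N0` and `N1` below are arbitrary). It assumes each row of `b[i]` is an `(x, y, z)` coordinate into the last three dimensions of `a`.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
B, C, X, Y, Z = 2, 4, 5, 6, 7
N = [3, 8]  # N0, N1: number of (x, y, z) indices per batch element

a = torch.randn(B, C, X, Y, Z)

# b[i] has shape [N_i, 3]; each row is assumed to be an (x, y, z) coordinate.
b = [
    torch.stack(
        [
            torch.randint(0, X, (n,)),
            torch.randint(0, Y, (n,)),
            torch.randint(0, Z, (n,)),
        ],
        dim=1,
    )
    for n in N
]

print(a.shape)     # torch.Size([2, 4, 5, 6, 7])
print(len(b))      # 2
print(b[0].shape)  # torch.Size([3, 3])
print(b[1].shape)  # torch.Size([8, 3])
```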
To be clear, having indexed the tensor, the final result I want is the following:
```
len(out)
>> B
out[0].shape
>> torch.Size([C, N0])
out[1].shape
>> torch.Size([C, N1])
```
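One straightforward baseline, assuming each row of `b[i]` holds an `(x, y, z)` index into the last three dimensions of `a`, is to loop over the batch and use advanced indexing on each `a[i]`. A minimal sketch (the helper name `gather_points` is hypothetical):

```python
import torch

def gather_points(a, b):
    """Return a list whose i-th entry has shape [C, N_i].

    Assumes a has shape [B, C, X, Y, Z] and b[i] is an integer tensor of
    shape [N_i, 3] whose columns are the x, y and z indices.
    """
    out = []
    for a_i, b_i in zip(a, b):       # iterating a yields a[0], a[1], ...
        x, y, z = b_i.unbind(dim=1)  # three tensors of shape [N_i]
        # a_i has shape [C, X, Y, Z]; indexing the last three dims with
        # index tensors of shape [N_i] keeps the channel dim in front.
        out.append(a_i[:, x, y, z])  # shape [C, N_i]
    return out

# Usage with small hypothetical inputs.
B, C, X, Y, Z = 2, 4, 5, 6, 7
a = torch.randn(B, C, X, Y, Z)
b = [torch.tensor([[0, 1, 2], [4, 5, 6], [1, 1, 1]]),
     torch.tensor([[3, 0, 0]])]
out = gather_points(a, b)
print([o.shape for o in out])  # [torch.Size([4, 3]), torch.Size([4, 1])]
```

This runs one indexing operation per batch element, which is usually fine when `B` is small.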
What is the best way of going about this?
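A possible vectorized alternative, again assuming each row of `b[i]` is an `(x, y, z)` index and that `a` and every `b[i]` live on the same device: concatenate all indices into one tensor, build a matching batch index with `torch.repeat_interleave`, do a single advanced-indexing call, and split the result back per batch with `torch.split`. The helper name `gather_points_batched` is hypothetical, and this is a sketch rather than a guaranteed fastest option.

```python
import torch

def gather_points_batched(a, b):
    """Gather all points with one indexing call, then split per batch.

    Assumes a is [B, C, X, Y, Z] and b[i] is an integer tensor of shape
    [N_i, 3] holding (x, y, z) rows, on the same device as a.
    """
    lengths = [len(b_i) for b_i in b]
    idx = torch.cat(b, dim=0)  # [sum(N_i), 3]
    # Batch index for every row: N0 zeros, then N1 ones, and so on.
    batch = torch.repeat_interleave(
        torch.arange(len(b), device=a.device),
        torch.tensor(lengths, device=a.device),
    )
    # Move channels last so the gathered result is [sum(N_i), C].
    a_last = a.permute(0, 2, 3, 4, 1)
    vals = a_last[batch, idx[:, 0], idx[:, 1], idx[:, 2]]  # [sum(N_i), C]
    # Split back into per-batch chunks and transpose each to [C, N_i].
    return [v.T for v in torch.split(vals, lengths, dim=0)]

# Usage with small hypothetical inputs.
B, C, X, Y, Z = 2, 4, 5, 6, 7
a = torch.randn(B, C, X, Y, Z)
b = [torch.tensor([[0, 1, 2], [4, 5, 6], [1, 1, 1]]),
     torch.tensor([[3, 0, 0]])]
out = gather_points_batched(a, b)
print(len(out), out[0].shape, out[1].shape)
# 2 torch.Size([4, 3]) torch.Size([4, 1])
```

The design trade-off: `B` separate indexing calls are replaced by one, at the cost of a `torch.cat` and a `torch.split`. Whether that is actually faster likely depends on `B`, the `N_i`, and the device, so it is worth timing both versions on real data.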