I have a 3-dimensional tensor of shape (10_000, 10, 3): 10_000 examples x 10 predicted labels x outputs from 3 models.
I also have a tensor, best_model_idx_per_label, holding the index of the model whose output I want to use for each label. It has 10 values between 0 and 2.
Is there any way to index into the 3D tensor so that, for each label, I pick the output of a specified model? In the end I would like a (10_000, 10) tensor (examples x labels) where each label's predictions come from the model of my choosing.
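One vectorized option is advanced indexing, sketched here under the assumption that `preds` has shape `(10_000, 10, 3)` and `best_model_idx_per_label` is an integer (`int64`) tensor of 10 values in `[0, 2]`. Random data stands in for the real tensors. One index tensor covers the label dimension and one covers the model dimension. The two broadcast together to shape `(10,)`, so the result keeps every example and has one column per label.

```python
import torch

# Hypothetical stand-in data: 10_000 examples x 10 labels x 3 models
num_examples, num_labels, num_models = 10_000, 10, 3
preds = torch.randn(num_examples, num_labels, num_models)
# One model index (0, 1 or 2) per label
best_model_idx_per_label = torch.randint(0, num_models, (num_labels,))

# Pair label i with model best_model_idx_per_label[i]; keep all examples
label_idx = torch.arange(num_labels)
selected = preds[:, label_idx, best_model_idx_per_label]

print(selected.shape)  # torch.Size([10000, 10])
```

Column `i` of `selected` equals `preds[:, i, best_model_idx_per_label[i]]`, which is what the loop computes. No permute or Python loop is needed.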
I am currently doing this by permuting the original tensor to shape (10, 3, 10_000) and looping over the labels. I store the outputs in a list and stack them into a tensor at the end.
I tried using tensor.gather and tensor.index_select, but couldn't get either to work. Intuitively, I feel there must be a better way of doing this.
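Both functions mentioned above can be made to work. This is a sketch on random stand-in data with the shapes from the question. `gather` needs an index tensor with the same number of dimensions as its input, so the 10 model ids are reshaped to `(1, 10, 1)` and expanded over the examples. `index_select` applies the same indices to every row. To use it, the label and model dimensions can be flattened into one, so that the output of model `m` for label `l` sits in column `l * 3 + m`.

```python
import torch

# Hypothetical stand-in data with the shapes from the question
num_examples, num_labels, num_models = 10_000, 10, 3
preds = torch.randn(num_examples, num_labels, num_models)
best_model_idx_per_label = torch.randint(0, num_models, (num_labels,))

# gather: index must have the same number of dims as preds, so reshape
# the 10 model ids to (1, 10, 1) and expand across the example dimension
gather_idx = best_model_idx_per_label.view(1, num_labels, 1).expand(num_examples, num_labels, 1)
via_gather = preds.gather(2, gather_idx).squeeze(2)  # (10_000, 10)

# index_select: same indices for every row, so flatten (label, model)
# into one dimension and compute one flat column per label
flat_idx = torch.arange(num_labels) * num_models + best_model_idx_per_label
via_index_select = preds.reshape(num_examples, num_labels * num_models).index_select(1, flat_idx)

print(torch.equal(via_gather, via_index_select))  # True
```

The flattening trick relies on `preds` being laid out row-major with models as the last dimension. `reshape` handles this correctly even for a non-contiguous input (it copies if needed).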
This is the code I have:
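```python
import torch

labels = []
# permute (10_000, 10, 3) -> (10, 3, 10_000): one (3, 10_000) slice per label
for label, idx in zip(preds.permute(1, 2, 0), best_model_idx_per_label):
    labels.append(label[idx])  # (10_000,) outputs of the chosen model
torch.stack(labels, dim=1)  # -> (10_000, 10)
```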
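To check that a vectorized replacement matches the loop, one option is to wrap both in functions and compare them on random data of the same shape. The function names `select_with_loop` and `select_vectorized` are only illustrative. The loop uses `permute(1, 2, 0)`, which produces the (10, 3, 10_000) layout described above.

```python
import torch

def select_with_loop(preds, best_model_idx_per_label):
    # Loop over labels on a (labels, models, examples) view of preds
    labels = []
    for label, idx in zip(preds.permute(1, 2, 0), best_model_idx_per_label):
        labels.append(label[idx])  # (examples,) from the chosen model
    return torch.stack(labels, dim=1)  # (examples, labels)

def select_vectorized(preds, best_model_idx_per_label):
    # Pair each label index with its chosen model index
    label_idx = torch.arange(preds.size(1), device=preds.device)
    return preds[:, label_idx, best_model_idx_per_label]

# Hypothetical stand-in data
preds = torch.randn(10_000, 10, 3)
best_model_idx_per_label = torch.randint(0, 3, (10,))
print(torch.equal(select_with_loop(preds, best_model_idx_per_label),
                  select_vectorized(preds, best_model_idx_per_label)))  # True
```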
Would be grateful for any help. Thank you!