Couldn’t find a better title for this, so I’ll explain it here:
I have a tensor with model outputs, e.g. of shape [5, 3].
I’m trying to retain only a few of its rows (along dim 0), e.g. indices 2 and 4.
I’ve tried using torch.masked_select() to retain these entries, but it always returns a flattened 1D tensor, so the row structure is lost.
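A minimal sketch of that masked_select behaviour, assuming the indices refer to rows along dim 0 (the tensor values are made up for illustration). The mask only has to be broadcastable to the input, but the result is always 1D. Because every selected row contributes exactly x.size(1) elements, the rows can be recovered with a reshape:

```python
import torch

# Example "model outputs" of shape [5, 3] (arbitrary values)
x = torch.arange(15, dtype=torch.float32).reshape(5, 3)

# Boolean mask over dim 0 marking rows 2 and 4 as "keep"
row_mask = torch.zeros(5, dtype=torch.bool)
row_mask[[2, 4]] = True

# The [5, 1] mask broadcasts against x, but the output is flattened to 1D
flat = torch.masked_select(x, row_mask.unsqueeze(1))
print(flat.shape)  # torch.Size([6])

# Restore the row structure: each kept row has x.size(1) elements
restored = flat.view(-1, x.size(1))
print(restored.shape)  # torch.Size([2, 3])
```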
What alternatives do I have for extracting only the entries at these indices?
I have an implementation that uses torch.cat to iteratively copy the relevant entries into a dummy tensor, but this feels hacky, and I suspect there is a more efficient solution.
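The post doesn’t include the torch.cat code, so the function below (select_rows_with_cat) is a hypothetical reconstruction of that loop, not the actual implementation. For comparison it is checked against two standard PyTorch row selections, advanced indexing (x[idx]) and torch.index_select, which produce the same [2, 3] result in a single call:

```python
from typing import Sequence

import torch


def select_rows_with_cat(x: torch.Tensor, indices: Sequence[int]) -> torch.Tensor:
    """Hypothetical sketch of the iterative torch.cat approach."""
    # Start from an empty "dummy" tensor with zero rows and matching columns
    out = x.new_empty((0, x.size(1)))
    for i in indices:
        # x[i:i + 1] keeps shape [1, C] so it can be concatenated on dim 0;
        # each iteration allocates a new tensor and copies everything so far
        out = torch.cat([out, x[i:i + 1]], dim=0)
    return out


x = torch.arange(15, dtype=torch.float32).reshape(5, 3)
idx = torch.tensor([2, 4])  # int64 index tensor

via_cat = select_rows_with_cat(x, idx.tolist())
via_index = x[idx]                                # advanced indexing along dim 0
via_index_select = torch.index_select(x, 0, idx)  # explicit dim + index tensor

print(torch.equal(via_cat, via_index))           # True
print(torch.equal(via_index, via_index_select))  # True
```

The loop copies the accumulated rows on every iteration, whereas the indexing calls gather all requested rows in one operation.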
What are my options to implement this?
EDIT: Simplified the question