Extracting layer outputs using a binary mask


(Alex) #1

Hi all,

Couldn’t find a better title for this, so I’ll explain it here:

I have a tensor with model outputs, e.g. in the shape [5, 3].
I’m trying to retain only a few of these, e.g. indices 2 and 4.
I’ve tried using torch.masked_select() to retain these entries, but it always returns a flattened 1D tensor, so the row structure is lost.
What alternatives do I have for extracting only the entries at these indices?
I have an implementation that uses torch.cat to iteratively copy the relevant entries into a dummy tensor, but this feels hacky, and I suspect there is a more efficient solution.
What are my options to implement this?
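For reference, here is roughly how torch.masked_select behaves on this example. It is a minimal sketch that assumes whole rows (dim 0) are being selected; the name `row_mask` is illustrative. The mask only needs to be broadcastable to the input, but the result is always 1D, so the row structure has to be restored with a reshape.

```python
import torch

outputs = torch.randn(5, 3)

# Binary (boolean) mask over dim 0: keep rows 2 and 4.
row_mask = torch.tensor([False, False, True, False, True])

# masked_select broadcasts the mask against `outputs`; unsqueeze turns the
# [5] mask into [5, 1] so it covers all 3 columns of each selected row.
flat = torch.masked_select(outputs, row_mask.unsqueeze(1))
print(flat.shape)  # torch.Size([6]): the result is always flattened

# Because whole rows were selected, the flat result can be reshaped back.
rows = flat.view(-1, outputs.size(1))
print(rows.shape)  # torch.Size([2, 3])
```

This works, but it only recovers the shape because every selected row is kept in full; plain indexing avoids the flatten-and-reshape round trip.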

EDIT: Simplified the question


#2

I’m not sure I understand which indices you would like to slice.
Are the indices in dim0?
If so, you could just use:

import torch

outputs = torch.randn(5, 3)
outputs[[2, 4], :]  # keeps rows 2 and 4, result has shape [2, 3]
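The same rows can be selected in a few equivalent ways. This is a minimal sketch assuming, as above, that the indices refer to dim 0; the names `idx` and `row_mask` are illustrative, and the loop in step 4 is only a guess at the torch.cat approach described in the question, included for comparison.

```python
import torch

outputs = torch.randn(5, 3)
idx = torch.tensor([2, 4])

# 1) Integer (advanced) indexing with an index tensor, same as outputs[[2, 4], :].
by_index = outputs[idx]

# 2) torch.index_select makes the dimension explicit; pass dim=1 to pick columns.
by_select = torch.index_select(outputs, 0, idx)

# 3) A boolean ("binary") mask over dim 0 keeps the rows where the mask is True.
row_mask = torch.zeros(outputs.size(0), dtype=torch.bool)
row_mask[idx] = True
by_mask = outputs[row_mask]

# 4) Iterative torch.cat into an initially empty tensor: one concatenation per row.
by_loop = torch.empty(0, outputs.size(1))
for i in idx.tolist():
    by_loop = torch.cat([by_loop, outputs[i : i + 1]], dim=0)

for result in (by_select, by_mask, by_loop):
    assert torch.equal(result, by_index)
print(by_index.shape)  # torch.Size([2, 3])
```

One difference to keep in mind: a boolean mask always returns the kept rows in their original order, while an index list follows the order you give it (and may repeat rows), so outputs[[4, 2]] returns row 4 first.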

(Alex) #3

Thanks for your reply, this was exactly what I was looking for. Looks like I was greatly overthinking the problem!