I have a tensor/vector of 1s and 0s indicating which examples are 'alive', like:

```
a = torch.FloatTensor([1, 0, 1])
```
Then I have a tensor from which I want to retrieve the rows at those indices, as efficiently as possible, while ideally keeping the code not too convoluted. E.g. I have the tensor:

```
In : b = torch.rand(3, 2)
In : b
Out:
 0.5253  0.4571
 0.9760  0.0465
 0.3184  0.7277
[torch.FloatTensor of size 3x2]
```
I came up with two candidate ways to index into this:
```
In : b.index_select(0, a.nonzero().long().view(-1))
Out:
 0.5253  0.4571
 0.3184  0.7277
[torch.FloatTensor of size 2x2]
```

This one needs two ops: the `nonzero` and then the `index_select`. There's also a `long` cast, which will presumably involve a data copy.
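As a runnable sketch of this first approach (assuming a recent PyTorch, where `nonzero()` already returns `int64` indices, so the `long()` cast may be unnecessary):

```python
import torch

a = torch.tensor([1., 0., 1.])
b = torch.tensor([[0.5253, 0.4571],
                  [0.9760, 0.0465],
                  [0.3184, 0.7277]])

# nonzero() returns an (n, 1) tensor of indices; flatten it for index_select
idx = a.nonzero().view(-1)      # tensor([0, 2]), dtype int64
rows = b.index_select(0, idx)   # rows 0 and 2 of b, shape (2, 2)
print(rows)
```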
```
In : b.masked_select(a.view(-1, 1).byte())
Out:
 0.5253
 0.4571
 0.3184
 0.7277
[torch.FloatTensor of size 4]
```

This is only a single op (I'm happy to store the mask vector as `byte`, so that cast could be removed). On the downside, the output needs to be reshaped back again. That reshape is free, but it does mean an extra variable somewhere to store the original shape.
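A sketch of this second approach, including the reshape step (assuming a recent PyTorch, where `bool` masks replace the deprecated `byte` masks; `masked_select` broadcasts the `(3, 1)` mask against `b`):

```python
import torch

a = torch.tensor([1., 0., 1.])
b = torch.tensor([[0.5253, 0.4571],
                  [0.9760, 0.0465],
                  [0.3184, 0.7277]])

mask = a.view(-1, 1).bool()       # (3, 1) mask, broadcast across b's columns
flat = b.masked_select(mask)      # 1-D result: the kept rows, flattened
rows = flat.view(-1, b.size(1))   # reshape back; needs the original column count
print(rows)
```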
It looks like the second is probably more efficient?
- what are standard ways of doing this?
- most concise ways of writing?
- most data efficient?
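For comparison with the candidates above, plain boolean (mask) indexing along dim 0 is a third option in current PyTorch that keeps the row structure, so no reshape is needed; whether it is the "standard" way is part of what I'm asking:

```python
import torch

a = torch.tensor([1., 0., 1.])
b = torch.rand(3, 2)

# Indexing with a 1-D bool mask selects whole rows and preserves shape (n_kept, 2)
kept = b[a.bool()]
print(kept.shape)  # torch.Size([2, 2])
```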