I have a tensor/vector of 1s and 0s indicating which examples are ‘alive’, e.g.:
a = torch.FloatTensor([1, 0, 1])
Then I have a tensor from which I want to retrieve the rows at those indices, as efficiently as possible, while ideally keeping the code reasonably simple. E.g. I have the tensor:
In [7]: b = torch.rand(3, 2)
In [8]: b
Out[8]:
0.5253 0.4571
0.9760 0.0465
0.3184 0.7277
[torch.FloatTensor of size 3x2]
I came up with two candidate ways to index into this:
In [46]: b.index_select(0, a.nonzero().long().view(-1))
Out[46]:
0.5253 0.4571
0.3184 0.7277
[torch.FloatTensor of size 2x2]
This one needs two ops: the `nonzero` and then the `index_select`. There’s also a `long` cast, which will presumably involve a data copy.
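For reference, a minimal sketch of this two-op approach; note that in current PyTorch `nonzero` already returns `int64` indices, so the explicit `long` cast (and its copy) can be dropped:

```python
import torch

a = torch.tensor([1., 0., 1.])
b = torch.rand(3, 2)

# nonzero() returns int64 indices of shape (num_nonzero, 1),
# so flatten it before passing to index_select
idx = a.nonzero().view(-1)
rows = b.index_select(0, idx)
assert rows.shape == (2, 2)  # rows 0 and 2 of b
```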
Or:
In [48]: b.masked_select(a.view(-1, 1).byte())
Out[48]:
0.5253
0.4571
0.3184
0.7277
[torch.FloatTensor of size 4]
This is only a single op (I’m happy to store the mask vector as `byte`, so that cast could be removed). On the downside, the output needs to be reshaped back again. The reshape itself is free, but it does mean an extra variable somewhere to store the original shape.
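A sketch of that reshape step, assuming the column count is still available from `b` (modern PyTorch prefers a `bool` mask over `byte`):

```python
import torch

a = torch.tensor([1., 0., 1.])
b = torch.rand(3, 2)

# masked_select broadcasts the (3, 1) mask across b's columns
# and returns a flat 1-D tensor, so reshape using b's column count
flat = b.masked_select(a.bool().view(-1, 1))
rows = flat.view(-1, b.size(1))
```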
It looks like the second is probably more efficient?
- what are the standard ways of doing this?
- what is the most concise way to write it?
- which is the most data-efficient?
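For comparison, one more candidate worth mentioning: boolean advanced indexing does the whole selection in a single expression and keeps the 2-D shape, so no reshape or stored shape is needed. A minimal sketch:

```python
import torch

a = torch.tensor([1., 0., 1.])
b = torch.rand(3, 2)

# indexing with a bool mask along dim 0 selects whole rows
# and preserves the remaining dimensions
rows = b[a.bool()]
assert rows.shape == (2, 2)
```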