I have a `k x 2` tensor named `points` and another `k x 1` tensor named `mask`. `mask` contains 1 or 0 for each index. I want to filter `points` and remove the entire row if `mask` does not contain a 1 for that specific `k`. How can I do this?
You could use `torch.masked_select`:
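```python
import torch
k = 10
x = torch.randn(k, 2)
mask = torch.randint(0, 2, (k, 1)).bool()  # 0/1 per row; masked_select expects a bool mask
# mask broadcasts to (k, 2); the flat 1-D result is reshaped back into rows of 2
x.masked_select(mask).view(-1, 2)
```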
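An alternative that skips the flatten-and-reshape step is plain boolean indexing with a 1-D mask. This is a minimal sketch, assuming `mask` holds only 0/1 values; the random mask below is just for illustration, and `keep`/`filtered` are hypothetical names.

```python
import torch

k = 10
points = torch.randn(k, 2)
# illustrative 0/1 mask of shape (k, 1), as described in the question
mask = torch.randint(0, 2, (k, 1))

# squeeze to shape (k,) and convert to bool; indexing with a 1-D bool
# tensor keeps exactly the rows where the mask is True
keep = mask.squeeze(1).bool()
filtered = points[keep]

# should match the masked_select result
assert torch.equal(filtered, points.masked_select(mask.bool()).view(-1, 2))
print(filtered.shape)  # (number of rows where mask == 1, 2)
```

Indexing this way also works for tensors with any number of columns, since it selects whole rows instead of individual elements.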