Filtering one tensor if another tensor meets some condition

I have a k x 2 tensor named points and a k x 1 tensor named mask. mask contains a 1 or a 0 for each row index. I want to filter points and remove the entire row wherever mask does not contain a 1 for that row. How can I do this?

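You could use torch.masked_select. The k x 1 mask broadcasts against the k x 2 tensor, so both values in a row are kept or dropped together. The result comes back as a 1-D tensor, so reshape it into two columns afterwards: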

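import torch

k = 10
x = torch.randn(k, 2)
# masked_select expects a boolean mask in current PyTorch releases
mask = torch.randint(0, 2, (k, 1)).bool()
# the selected values come back flattened, so restore the two columns
filtered = x.masked_select(mask).view(-1, 2)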
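
Boolean indexing might also work here and avoids the reshape. This is only a sketch, assuming mask holds integer 0/1 values as described in the question; the sample values and the names keep and filtered are made up for illustration:

```python
import torch

# Small hand-written example: 4 points, mask keeps rows 0 and 2.
points = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]])
mask = torch.tensor([[1], [0], [1], [0]])

# Drop the trailing dimension (k x 1 -> k) and convert to bool,
# so that each entry keeps or drops one whole row of points.
keep = mask.squeeze(1).bool()
filtered = points[keep]

print(filtered)
```

Indexing with a 1-D boolean tensor selects along the first dimension, so the result already has shape (number of kept rows, 2).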