Efficiently get the first N elements that satisfy a condition in each row of a PyTorch tensor

Given a tensor b, I would like to extract N elements from each row that satisfy a specific condition. For example, suppose a is a matrix that indicates whether each element of b satisfies the condition. I would like to extract N elements from each row whose corresponding value in a is 1.

There are two scenarios: (1) extract the first N such elements of each row, in order; (2) randomly sample N of the elements that satisfy the condition in each row.

Is there an efficient way to handle both cases in PyTorch? A for loop is very slow when b has many rows. Thanks!

Below I give an example that shows the first case.

```python
import torch
# given
a = torch.tensor([[1, 0, 0, 1, 1, 1], [0, 1, 0, 1, 1, 1], [1, 1, 1, 1, 1, 0]])
b = torch.arange(18).view(3, 6)

# suppose N=3
# output:
c = torch.tensor([[0, 3, 4], [7, 9, 10], [12, 13, 14]])
```
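For checking any vectorized version, it can help to keep the slow per-row loop the question mentions as a reference. This is a minimal sketch of the first case only; the function name `first_n_loop` is hypothetical, and it assumes every row of a has at least N ones, because `torch.stack` needs all rows to have the same length.

```python
import torch

def first_n_loop(a: torch.Tensor, b: torch.Tensor, n: int) -> torch.Tensor:
    """Hypothetical reference: first n entries of each row of b where a is nonzero."""
    rows = []
    for mask_row, val_row in zip(a, b):
        # Boolean indexing keeps the selected values in left-to-right order
        selected = val_row[mask_row.bool()]
        rows.append(selected[:n])
    # Assumes every row produced exactly n values; torch.stack fails otherwise
    return torch.stack(rows)

a = torch.tensor([[1, 0, 0, 1, 1, 1], [0, 1, 0, 1, 1, 1], [1, 1, 1, 1, 1, 0]])
b = torch.arange(18).view(3, 6)
print(first_n_loop(a, b, 3))
```

This runs one Python iteration per row, which is exactly the cost the question wants to avoid, but it makes a convenient ground truth when testing the vectorized answer.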

Here is one way to do what you want.

However, note that this approach will NOT tell you when a row of a has fewer than N ones.

In this example I added a fourth row with only two 1s, but you still get 3 values.

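```python
import torch

a = torch.tensor([[1, 0, 0, 1, 1, 1], [0, 1, 0, 1, 1, 1], [1, 1, 1, 1, 1, 0], [1, 0, 0, 1, 0, 0]])
b = torch.arange(24).view(4, 6)

N = 3

mode = 1

if mode == 1:
    # Choose in order from left to right: valid columns get scores 7, 6, ..., 2
    # (larger further left), invalid columns get 0
    idx = a * torch.arange(start=a.shape[1] + 1, end=1, step=-1)
elif mode == 2:
    # Choose randomly: valid columns get a random score in [1, 2), invalid ones get 0.
    # Continuous scores avoid the ties that torch.randint could produce.
    idx = a * (1 + torch.rand(a.shape))

# Column indices of the N largest scores in each row
_, ind = torch.topk(input=idx, k=N, dim=1)
# Pick the matching elements of b, row by row
c = torch.gather(b, dim=1, index=ind)

print(c)
# Output (mode 1):
# tensor([[ 0,  3,  4],
#         [ 7,  9, 10],
#         [12, 13, 14],
#         [18, 21, 22]])  # <- only two 1s in this row of the mask; 22 comes from a column where a is 0
```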
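Two alternative sketches that avoid building a score tensor by hand. These are swapped-in techniques, not the method above: for mode 1, a stable sort of the inverted mask moves the columns where a is 1 to the front while keeping their left-to-right order; for mode 2, `torch.multinomial` treats each mask row as sampling weights. Note that `torch.multinomial` with `replacement=False` raises an error when a row has fewer than N nonzero weights, so this sketch only applies it to the first three rows.

```python
import torch

a = torch.tensor([[1, 0, 0, 1, 1, 1], [0, 1, 0, 1, 1, 1], [1, 1, 1, 1, 1, 0], [1, 0, 0, 1, 0, 0]])
b = torch.arange(24).view(4, 6)
N = 3

# Mode 1 (first N in order): 1 - a maps valid columns to 0 and invalid ones to 1.
# A stable ascending sort puts the 0s first and keeps equal keys in column order.
_, order = torch.sort(1 - a.to(torch.int64), dim=1, stable=True)
c_first = torch.gather(b, dim=1, index=order[:, :N])
print(c_first)

# Mode 2 (random N): sample N distinct columns per row, weighted by the mask.
# Only rows with at least N ones are valid input here, so the fourth row is skipped.
a3, b3 = a[:3], b[:3]
ind_rand = torch.multinomial(a3.to(torch.float), num_samples=N, replacement=False)
c_rand = torch.gather(b3, dim=1, index=ind_rand)
print(c_rand)
```

With the stable sort, the fourth row's third pick is again a column where a is 0 (here b[3, 1] = 19), so the same caveat about rows with fewer than N ones applies.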

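If you are sure every row always has at least N ones, this is no problem.
Otherwise, you can use a.sum(dim=1) to see how many values in each row of c you can trust. Because valid columns always score higher than invalid ones, they come first in each row of c, so the first min(a.sum(dim=1), N) entries of each row are the trustworthy ones.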
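A minimal sketch of that check, assuming the mode 1 scoring from the answer: it turns the per-row counts into a boolean mask over the N columns of c. The names `n_valid` and `valid`, and the choice of -1 as a fill value, are only illustrative.

```python
import torch

a = torch.tensor([[1, 0, 0, 1, 1, 1], [0, 1, 0, 1, 1, 1], [1, 1, 1, 1, 1, 0], [1, 0, 0, 1, 0, 0]])
b = torch.arange(24).view(4, 6)
N = 3

# Mode 1 scoring from the answer: leftmost valid columns get the largest scores
idx = a * torch.arange(start=a.shape[1] + 1, end=1, step=-1)
_, ind = torch.topk(idx, k=N, dim=1)  # sorted=True by default, so valid picks come first
c = torch.gather(b, dim=1, index=ind)

# Trustworthy entries per row: the row's count of ones, capped at N
n_valid = a.sum(dim=1).clamp(max=N)              # tensor([3, 3, 3, 2])
# Entry j of a row is trustworthy when j < n_valid for that row; shape (rows, N)
valid = torch.arange(N) < n_valid.unsqueeze(1)
print(valid)

# Example use: blank out untrusted entries with a sentinel value
print(c.masked_fill(~valid, -1))
```

Broadcasting a length-N `arange` against the per-row counts keeps the check vectorized, so it adds no Python loop over rows.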


Thanks! This works well!
