# Efficiently get first N numbers that satisfy a condition in each row in a pytorch tensor

Given a tensor `b`, I would like to extract `N` elements from each row that satisfy a specific condition. For example, suppose `a` is a matrix of the same shape that indicates whether each element in `b` satisfies the condition. I would like to extract `N` elements from each row whose corresponding value in `a` is `1`.

There are two scenarios: (1) extract the first `N` such elements in each row, in order; (2) randomly sample `N` elements per row from all the elements that satisfy the condition.

Is there an efficient way to handle both cases in PyTorch? Using a for loop will be very slow when `b` has many rows. Thanks!

Below I give an example that shows the first case.

``````
import torch
# given
a = torch.tensor([[1, 0, 0, 1, 1, 1], [0, 1, 0, 1, 1, 1], [1, 1, 1, 1, 1, 0]])
b = torch.arange(18).view(3, 6)

# suppose N=3
# output:
c = torch.tensor([[0, 3, 4], [7, 9, 10], [12, 13, 14]])
``````

This is one way to do what you want.

However, keep in mind that you will NOT notice if a row of `a` has fewer than `N` ones.

In this example I have added a fourth row with only two `1`'s, but you still get `3` values back for it.

``````
import torch

a = torch.tensor([[1, 0, 0, 1, 1, 1], [0, 1, 0, 1, 1, 1], [1, 1, 1, 1, 1, 0], [1, 0, 0, 1, 0, 0]])
b = torch.arange(24).view(4, 6)

N = 3

mode = 1

if mode == 1:
    # Choose in order from left to right: earlier columns get larger scores
    idx = a * torch.arange(start=a.shape[1] + 1, end=1, step=-1)
if mode == 2:
    # Choose randomly: every element with a == 1 gets a random positive score
    idx = a * torch.randint(low=1, high=a.shape[1] + 1, size=a.size())

# topk returns the column indices of the N largest scores in each row
_, ind = torch.topk(input=idx, k=N, dim=1)
c = b[torch.arange(b.shape[0]), ind.T].T

print(c)
``````
``````
# Output:
tensor([[ 0,  3,  4],
        [ 7,  9, 10],
        [12, 13, 14],
        [18, 21, 22]]) # ← This row has only two 1's in the mask, so the last value is not a valid pick.
``````

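If you are sure that every row always has at least `N` valid values, this is no problem.
Otherwise, you can use `a.sum(dim=1)` to see how many values in each row of this new matrix you can trust. The trustworthy ones come first in each row, since `topk` returns the largest scores first.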
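
To make that check concrete, here is a minimal sketch that turns `a.sum(dim=1)` into a per-element validity mask. It reuses the in-order scoring idea from the code above, but swaps the fancy-indexing line for `torch.gather`, which should give the same result. The names `counts`, `valid` and `c_masked`, and the `-1` sentinel, are my own choices for illustration.

```python
import torch

a = torch.tensor([[1, 0, 0, 1, 1, 1],
                  [0, 1, 0, 1, 1, 1],
                  [1, 1, 1, 1, 1, 0],
                  [1, 0, 0, 1, 0, 0]])
b = torch.arange(24).view(4, 6)
N = 3

# In-order selection: earlier columns get larger scores, masked-out ones get 0.
scores = a * torch.arange(a.shape[1], 0, -1)
_, ind = torch.topk(scores, k=N, dim=1)
c = torch.gather(b, 1, ind)  # c[i, j] = b[i, ind[i, j]]

# How many picks per row really come from positions where a == 1 (at most N).
counts = a.sum(dim=1).clamp(max=N)

# valid[i, j] is True for the first counts[i] entries of row i.
valid = torch.arange(N).unsqueeze(0) < counts.unsqueeze(1)

# Overwrite the untrustworthy slots with a sentinel (-1 is an arbitrary choice).
c_masked = c.masked_fill(~valid, -1)

print(counts)
print(c_masked)
```

With the four-row example, `counts` is `[3, 3, 3, 2]` and the last row of `c_masked` becomes `[18, 21, -1]`, so the padded slot is easy to spot. If `-1` can be a legitimate value in `b`, keeping `valid` as a separate mask is probably safer than using a sentinel.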

Thanks! This works well!