Given a tensor `b`, I would like to extract `N` elements in each row that satisfy a specific condition. For example, suppose `a` is a matrix that indicates whether each element in `b` satisfies the condition. Now I would like to extract `N` elements in each row whose corresponding value in `a` is 1.
There are two scenarios: (1) I extract the first `N` satisfying elements in each row, in order; (2) among all the elements that satisfy the condition, I randomly sample `N` elements in each row.
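For scenario (2), one possible vectorized sketch treats the mask `a` as sampling weights for `torch.multinomial`. This assumes every row of `a` has at least `N` nonzero entries. The helper name `sample_n_per_row` is hypothetical, not a PyTorch API.

```python
import torch

def sample_n_per_row(b, a, n):
    """Randomly pick n elements per row of b where a is nonzero.

    Assumes every row of a has at least n nonzero entries.
    """
    # Equal weight for every element that satisfies the condition, zero otherwise,
    # so columns that fail the condition are never drawn.
    weights = a.float()
    # Draw n distinct column indices per row (sampling without replacement).
    idx = torch.multinomial(weights, n, replacement=False)
    # Optional: sort the sampled indices so values keep their original column order.
    idx = idx.sort(dim=1).values
    # Pick the corresponding values from b, row by row.
    return b.gather(1, idx)

a = torch.tensor([[1, 0, 0, 1, 1, 1],
                  [0, 1, 0, 1, 1, 1],
                  [1, 1, 1, 1, 1, 0]])
b = torch.arange(18).view(3, 6)

torch.manual_seed(0)  # only for reproducibility
c_random = sample_n_per_row(b, a, 3)
print(c_random)
```

With equal weights, sampling without replacement makes every size-`N` subset of a row's satisfying columns equally likely. Dropping the `sort` line returns the sampled elements in random order instead.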
Is there an efficient way to handle both cases in PyTorch? A for loop is very slow when `b` has many rows. Thanks!
Below is an example of the first case.
```python
import torch

# given
a = torch.tensor([[1, 0, 0, 1, 1, 1],
                  [0, 1, 0, 1, 1, 1],
                  [1, 1, 1, 1, 1, 0]])
b = torch.arange(18).view(3, 6)

# suppose N = 3
# output:
c = torch.tensor([[0, 3, 4],
                  [7, 9, 10],
                  [12, 13, 14]])
```
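For the first case, a loop-free sketch keeps a running count of satisfying elements along each row with `cumsum`, marks only the first `N` of them, and collects the values with boolean indexing. It assumes every row of `a` has at least `N` nonzero entries; otherwise the final `view` fails or values from different rows get mixed. The helper name `first_n_per_row` is hypothetical.

```python
import torch

def first_n_per_row(b, a, n):
    """Return the first n elements (in column order) of each row of b where a is nonzero.

    Assumes every row of a has at least n nonzero entries.
    """
    mask = a.bool()
    # Running count of satisfying elements along each row.
    count = mask.long().cumsum(dim=1)
    # Keep a satisfying element only if it is among the first n in its row.
    keep = mask & (count <= n)
    # Boolean indexing flattens in row-major order, so each row contributes exactly n values.
    return b[keep].view(-1, n)

a = torch.tensor([[1, 0, 0, 1, 1, 1],
                  [0, 1, 0, 1, 1, 1],
                  [1, 1, 1, 1, 1, 0]])
b = torch.arange(18).view(3, 6)

c_first = first_n_per_row(b, a, 3)
print(c_first)
```

On the example above, `c_first` equals the expected `c`. Everything is done with whole-tensor operations, so the cost does not involve a Python-level loop over rows.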