How to make torch.nonzero faster

I found an equivalent operation using torch.masked_select:

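```python
import torch
selected = []  # rows of X[i] kept by the boolean row mask Y[i]
for i, x in enumerate(X):
    selected.append(torch.masked_select(x, Y[i].repeat(x.shape[1], 1).T).reshape(-1, x.shape[1]))
```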
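As a sanity check, the masked_select formulation selects the same rows as plain boolean indexing, x[Y[i]]. The sketch below assumes shapes that the post does not state: X is a float tensor of shape (B, N, D) and Y is a boolean mask of shape (B, N). The function names are hypothetical.

```python
import torch

def select_rows_masked(X, Y):
    # masked_select formulation: expand the (N,) row mask to (N, D),
    # select the kept elements as a flat 1-D tensor, then restore rows of width D.
    return [
        torch.masked_select(x, Y[i].repeat(x.shape[1], 1).T).reshape(-1, x.shape[1])
        for i, x in enumerate(X)
    ]

def select_rows_indexed(X, Y):
    # Boolean indexing keeps the rows of x where Y[i] is True.
    return [x[Y[i]] for i, x in enumerate(X)]

torch.manual_seed(0)
X = torch.randn(4, 6, 3)
Y = torch.rand(4, 6) > 0.5

# Both versions should return identical per-sample tensors.
for a, b in zip(select_rows_masked(X, Y), select_rows_indexed(X, Y)):
    assert torch.equal(a, b)
```

Because each mask row is either all True or all False, the flat output of masked_select is always a whole number of rows, so the reshape is safe even when no rows are selected.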

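The operation is still slow, though. It seems the GPU is not good at this kind of task, likely because the number of selected elements depends on the mask values, so on CUDA each call has to synchronize with the CPU before the output tensor can be allocated…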
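One way to reduce the per-iteration cost, sketched here as an alternative rather than as the author's method, is to do the selection for the whole batch at once and split afterwards. X[Y] does the data-dependent selection in a single call, so on CUDA tensors the work incurs a fixed number of host-device synchronizations (one inside X[Y], one for counts.tolist()) instead of one per loop iteration. The same shape assumptions apply (X of shape (B, N, D), Y a boolean mask of shape (B, N)), and select_rows_batched is a hypothetical name.

```python
import torch

def select_rows_batched(X, Y):
    # X: (B, N, D) values; Y: (B, N) boolean row mask (assumed shapes).
    # X[Y] gathers every selected row of every sample in one call,
    # giving a (total_selected, D) tensor ordered by batch index.
    flat = X[Y]
    # Number of selected rows per sample; .tolist() copies it to the host.
    counts = Y.sum(dim=1).tolist()
    # Split the flat result back into one (count_i, D) tensor per sample.
    return list(torch.split(flat, counts, dim=0))

torch.manual_seed(0)
X = torch.randn(4, 6, 3)
Y = torch.rand(4, 6) > 0.5

# Each piece should match the per-sample boolean-indexing result.
for i, rows in enumerate(select_rows_batched(X, Y)):
    assert torch.equal(rows, X[i][Y[i]])
```

If a padded (B, N, D) layout is acceptable downstream, keeping X and multiplying by the mask avoids data-dependent shapes entirely, but that changes the output format.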