I found an equivalent operation:
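```python
import torch
results = []  # one filtered (n_kept, D) tensor per sample; rebinding x alone would discard it
for i, x in enumerate(X):  # x: (N, D) float tensor, Y[i]: (N,) bool mask
    results.append(torch.masked_select(x, Y[i].repeat(x.shape[1], 1).T).reshape(-1, x.shape[1]))
```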
The operation is still slow, though. It seems the GPU is not well suited to this kind of task…
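A possible way to drop the per-sample Python loop is boolean-mask indexing on the whole batch, then splitting the result by per-sample counts. This is a minimal sketch, not a benchmarked fix. It assumes `X` has shape `(B, N, D)` and `Y` is a `(B, N)` bool mask. The function names `filter_rows_loop` and `filter_rows_batched` are hypothetical and used only for illustration.

```python
import torch

def filter_rows_loop(X, Y):
    # Loop version from the post: keep rows of X[i] where Y[i] is True.
    out = []
    for i, x in enumerate(X):
        mask = Y[i].repeat(x.shape[1], 1).T  # (N,) -> (N, D) elementwise mask
        out.append(torch.masked_select(x, mask).reshape(-1, x.shape[1]))
    return out

def filter_rows_batched(X, Y):
    # Assumed shapes: X is (B, N, D), Y is (B, N) bool.
    kept = X[Y]                     # (total_kept, D), rows ordered by sample, then row
    counts = Y.sum(dim=1).tolist()  # rows kept per sample (a single host sync)
    return list(torch.split(kept, counts))  # back to one tensor per sample
```

Because each sample keeps a different number of rows, the output cannot be a single regular tensor. The batched version still returns a list, but it performs one gather for the whole batch instead of one `masked_select` per sample. Whether that is actually faster on your GPU depends on `B`, `N`, and `D`, so it is worth timing both.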