What is the best way to pick out specific rows from a tensor?

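In some cases, we need to pick out the rows of a tensor whose corresponding entries in another tensor (or matrix) equal a specific value. Concretely, `embedding_tensor` has shape `(batch_size, sequence_length, hidden_size)` and `hp` has shape `(batch_size, sequence_length)`; for each batch item I want the hidden vectors of the tokens where `hp == wanted`, padded to the largest number of matches in the batch (padding slots currently point at token 0). I have implemented a function that does this, but I wonder whether there is a faster way.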
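To make the task concrete, here is a small toy example (the values of `embedding_tensor` and `hp` are made up for illustration). Indexing with the boolean mask `hp == wanted` selects the right vectors but flattens away the batch dimension, and each batch item can have a different number of matches. That ragged result is why the function has to build a padded index tensor.

```python
import torch

# Hypothetical toy input: 2 sequences, 4 tokens each, hidden size 3
embedding_tensor = torch.arange(2 * 4 * 3, dtype=torch.float32).view(2, 4, 3)
hp = torch.tensor([[1, 0, 1, 1],
                   [0, 1, 0, 0]])
wanted = 1

mask = hp == wanted                             # (batch, seq) boolean mask
flat = embedding_tensor[mask]                   # (num_matches, hidden): batch dimension is lost
counts = mask.sum(dim=1)                        # number of matches per batch item
per_item = torch.split(flat, counts.tolist())   # one (k_b, hidden) tensor per batch item

print(flat.shape)                       # torch.Size([4, 3])
print(counts.tolist())                  # [3, 1]
print([t.shape[0] for t in per_item])   # [3, 1]
```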

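```python
import torch

def pick_up_corresponding_vectors(embedding_tensor, hp, wanted):
    # embedding_tensor: (batch_size, sequence_length, hidden_size); hp: (batch_size, sequence_length)
    batch_size, sequence_length, hidden_size = embedding_tensor.shape
    # (batch, token) coordinates of every position where hp == wanted
    batch_is, token_is = (hp == wanted).nonzero(as_tuple=True)
    # One row per batch item, padded to the largest match count; padding keeps index 0
    max_matches = torch.bincount(batch_is).max().item()
    indices = torch.zeros((batch_size, max_matches), dtype=torch.int64, device=embedding_tensor.device)
    for batch_i in batch_is:
        # All match positions of this batch item (recomputed once per match)
        m = (batch_is == batch_i).nonzero().flatten().tolist()
        k = len(m)
        indices[batch_i, :k] = token_is[m]
    batch_size, max_seq = indices.shape
    # Broadcast each token index across the hidden dimension: (batch, max_seq, hidden)
    indices_ = indices.repeat(1, hidden_size).view(batch_size, -1, max_seq).transpose(2, 1)
    return torch.gather(embedding_tensor, 1, indices_)
```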
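One possible loop-free variant, sketched under the same assumptions as the function above (`hp` has shape `(batch_size, sequence_length)` and padding slots point at token 0). A cumulative sum over the mask gives each match its slot within its own row, so the padded index tensor can be filled with a single advanced-indexing assignment instead of a Python loop. The name `pick_up_corresponding_vectors_vectorized` is hypothetical, and whether it is actually faster depends on tensor sizes and device.

```python
import torch

def pick_up_corresponding_vectors_vectorized(embedding_tensor, hp, wanted):
    # Hypothetical helper; same inputs and padding behaviour as the loop version
    batch_size, sequence_length, hidden_size = embedding_tensor.shape
    mask = hp == wanted                              # (batch, seq) boolean mask
    max_matches = int(mask.sum(dim=1).max().item())  # largest match count in the batch
    # Slot of each match inside its own row: 0, 1, 2, ... (non-matches get unused values)
    slot = mask.long().cumsum(dim=1) - 1
    batch_is, token_is = mask.nonzero(as_tuple=True)
    # Padding slots stay 0, i.e. they point at the first token, as in the loop version
    indices = torch.zeros((batch_size, max_matches), dtype=torch.long,
                          device=embedding_tensor.device)
    indices[batch_is, slot[batch_is, token_is]] = token_is
    # Broadcast indices over the hidden dimension without copying them
    index = indices.unsqueeze(-1).expand(-1, -1, hidden_size)
    return torch.gather(embedding_tensor, 1, index)
```

Using `expand` instead of `repeat` also avoids materialising `hidden_size` copies of the index tensor before the `gather`.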

Here is a notebook to check it easily; it also includes a more detailed explanation of the required process.

Thanks in advance.