In some cases, we need to pick out of a tensor the rows (vectors) at the positions where another, corresponding tensor or matrix holds a specific value. I have implemented a function that does this, but I wonder whether there is a faster way.
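import torch

def pick_up_corresponding_vectors(embedding_tensor, hp, wanted):
    batch_size, sequence_length, hidden_size = embedding_tensor.shape
    # (batch index, token index) of every position where hp == wanted
    batch_is, token_is = (hp == wanted).nonzero(as_tuple=True)
    # One row per batch item, zero-padded to the largest number of matches
    indices = torch.zeros((batch_size, torch.bincount(batch_is).max().item()))
    for batch_i, token_i in zip(batch_is, token_is):
        # Positions in the flat match list that belong to this batch item
        m = (batch_is == batch_i).nonzero().flatten().tolist()
        k = len(m)
        indices[batch_i, :k] = token_is[m]
    batch_size, max_seq = indices.shape
    # Repeat each token index across the hidden dimension, as torch.gather expects
    indices_ = indices.repeat(1, hidden_size).view(batch_size, -1, max_seq).transpose(2, 1).type(torch.int64)
    return torch.gather(embedding_tensor, 1, indices_)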
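One loop-free direction might be the following. It is only a minimal sketch and has not been benchmarked. The name `pick_up_vectorized` is hypothetical, and I am assuming that `hp` has shape `(batch_size, sequence_length)` and that padded slots should point at token 0, as in the function above. The idea is to use a running count of matches per row to get each match's slot, then fill the index matrix in a single indexed assignment.

```python
import torch

def pick_up_vectorized(embedding_tensor, hp, wanted):
    # Hypothetical loop-free variant; assumes hp is (batch_size, sequence_length).
    batch_size, _, hidden_size = embedding_tensor.shape
    mask = hp == wanted                                   # (B, S) boolean
    batch_is, token_is = mask.nonzero(as_tuple=True)      # row-major order
    # Slot of each match within its own row: 0 for the first match, 1 for the second, ...
    slot = (mask.long().cumsum(dim=1) - 1)[mask]          # same row-major order
    max_k = int(mask.sum(dim=1).max())
    # Zero-initialised, so unused slots fall back to token 0
    indices = torch.zeros((batch_size, max_k), dtype=torch.long, device=hp.device)
    indices[batch_is, slot] = token_is
    # Expand each token index across the hidden dimension for torch.gather
    indices_ = indices.unsqueeze(-1).expand(-1, -1, hidden_size)
    return torch.gather(embedding_tensor, 1, indices_)

# Small usage example with made-up data
emb = torch.arange(2 * 4 * 3, dtype=torch.float32).view(2, 4, 3)
hp = torch.tensor([[1, 0, 1, 0],
                   [0, 0, 1, 0]])
picked = pick_up_vectorized(emb, hp, 1)
print(picked.shape)  # torch.Size([2, 2, 3])
```

Here the second batch item has only one match, so its second slot repeats the vector of token 0. A separate mask would be needed to tell real matches from padding.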
Here is a notebook to check the function easily. It also contains a more detailed explanation of the required process.
Thanks in advance.