In some cases, we need to pick out the rows of a tensor whose corresponding entries in another tensor (or matrix) equal a specific value. I have implemented a function that does this, but I wonder whether there is a faster way.
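```python
import torch


def pick_up_corresponding_vectors(embedding_tensor, hp, wanted):
    batch_size, sequence_length, hidden_size = embedding_tensor.shape
    # (batch, token) coordinates of every position where hp == wanted
    batch_is, token_is = (hp == wanted).nonzero(as_tuple=True)
    # One row per batch element, zero-padded up to the largest match count
    max_matches = torch.bincount(batch_is).max().item()
    indices = torch.zeros((batch_size, max_matches), dtype=torch.int64,
                          device=embedding_tensor.device)
    for batch_i, token_i in zip(batch_is, token_is):
        # All match positions that belong to this batch element
        m = (batch_is == batch_i).nonzero().flatten().tolist()
        k = len(m)
        indices[batch_i, :k] = token_is[m]
    batch_size, max_seq = indices.shape
    # Repeat each index across the hidden dimension: (batch, max_seq, hidden)
    indices_ = indices.repeat(1, hidden_size).view(batch_size, -1, max_seq).transpose(2, 1)
    return torch.gather(embedding_tensor, 1, indices_)
```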
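One possible loop-free variant is sketched below, assuming `hp` has shape `(batch, seq)`, lives on the same device as `embedding_tensor`, and that the same zero padding is acceptable. The loop above recomputes the same row once for every match; here each match's slot within its row is taken from a running `cumsum` of the mask, which works because `nonzero` returns coordinates in row-major order. The name `pick_up_corresponding_vectors_vectorized` and the extra `valid` return value are hypothetical additions for illustration, not part of the original function.

```python
import torch


def pick_up_corresponding_vectors_vectorized(embedding_tensor, hp, wanted):
    # Hypothetical loop-free variant; assumes hp is (batch, seq) on the same device.
    batch_size, _, hidden_size = embedding_tensor.shape
    mask = hp == wanted                                   # (B, S) bool
    counts = mask.sum(dim=1)                              # (B,) matches per row
    max_count = int(counts.max().item()) if batch_size > 0 else 0
    # Slot of each match within its row: 0, 1, 2, ... in left-to-right order
    slots = mask.long().cumsum(dim=1) - 1                 # (B, S)
    batch_is, token_is = mask.nonzero(as_tuple=True)
    # Zero padding as in the loop version: unused slots point at token 0
    indices = torch.zeros(batch_size, max_count, dtype=torch.long, device=hp.device)
    indices[batch_is, slots[batch_is, token_is]] = token_is
    # True where a slot holds a real match rather than padding
    valid = torch.arange(max_count, device=hp.device) < counts.unsqueeze(1)  # (B, M)
    # Expand (no copy) across the hidden dimension instead of repeat/view/transpose
    gather_idx = indices.unsqueeze(-1).expand(-1, -1, hidden_size)          # (B, M, H)
    return torch.gather(embedding_tensor, 1, gather_idx), valid
```

Because padded slots still gather the embedding of token 0, the `valid` mask is one way to tell them apart from a genuine match at position 0.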
Here is a notebook to check it easily. It also contains more explanation of the required process.
Thanks in advance.