Currently I have a tensor index of shape batch_size * candidate_size
and a tensor desired_candidate of shape batch_size * top_5.
For example, with batch_size == 64 and candidate_size == 15:
>> index.size()
torch.Size([64, 15])
>> desired_candidate.size()
torch.Size([64, 5])
>> index[0]
tensor([104, 171, 182, 3, 56, 178, 6, 6, 4, 21, 30, 182, 27, 39, 56], device='cuda:0', dtype=torch.int32)
>> desired_candidate[0]
tensor([171, 4, 182, 102, 61], device='cuda:0')
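I’d like to mask index using desired_candidate: every value in a row of index that also appears in the same row of desired_candidate should be replaced with -1, so the desired result is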
>> index_[0]
tensor([104, -1, -1, 3, 56, 178, 6, 6, -1, 21, 30, -1, 27, 39, 56], device='cuda:0', dtype=torch.int32)
for each row in the batch, but I can’t figure out how to do this.
If anyone knows how to implement this, I’d appreciate it.
Thanks.
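
One possible approach, as a minimal sketch: compare every entry of index against every entry of the matching row of desired_candidate via broadcasting, then overwrite the hits with -1. The helper name mask_candidates and its fill_value parameter are hypothetical, introduced only for illustration. The example runs on CPU with the first rows shown above, and it assumes both tensors are on the same device.

```python
import torch

def mask_candidates(index, desired_candidate, fill_value=-1):
    """Replace entries of `index` that occur in the same row of `desired_candidate`.

    index:             (batch_size, candidate_size) integer tensor
    desired_candidate: (batch_size, top_k) integer tensor
    """
    # (batch, candidate_size, 1) == (batch, 1, top_k)
    #   -> (batch, candidate_size, top_k) boolean comparison table
    matches = index.unsqueeze(-1) == desired_candidate.unsqueeze(1)
    # An entry is a hit if it equals any desired value in its own row.
    hit = matches.any(dim=-1)
    # masked_fill returns a new tensor and keeps the dtype of `index`.
    return index.masked_fill(hit, fill_value)

# First rows from the question, wrapped in a batch of size 1.
index = torch.tensor(
    [[104, 171, 182, 3, 56, 178, 6, 6, 4, 21, 30, 182, 27, 39, 56]],
    dtype=torch.int32,
)
desired_candidate = torch.tensor([[171, 4, 182, 102, 61]])

index_ = mask_candidates(index, desired_candidate)
print(index_)
```

The intermediate comparison table holds batch_size * candidate_size * top_k booleans, which is small for the shapes in the question but may be worth keeping in mind for much larger inputs.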