Currently I have an `index` tensor of shape `batch_size * candidate_size` and a `desired_candidate` tensor of shape `batch_size * top_5`.

For example, with `batch_size == 64` and `candidate_size == 15`:

```
>> index.size()
torch.Size([64, 15])
>> desired_candidate.size()
torch.Size([64, 5])
>> index[0]
tensor([104, 171, 182, 3, 56, 178, 6, 6, 4, 21, 30, 182, 27, 39, 56], device='cuda:0', dtype=torch.int32)
>> desired_candidate[0]
tensor([171, 4, 182, 102, 61], device='cuda:0')
```

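I'd like to filter `index` by `desired_candidate`: every entry in a row of `index` that also appears in the corresponding row of `desired_candidate` should be replaced with -1, so the desired result is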

```
>> index_[0]
tensor([104, -1, -1, 3, 56, 178, 6, 6, -1, 21, 30, -1, 27, 39, 56], device='cuda:0', dtype=torch.int32)
```

This should be done for every row in the batch, but I can't figure out how to do it.
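One way this could be done without a Python loop is broadcasting: compare every entry of `index` against every candidate in the same row, then collapse the candidate axis with `any`. This is a minimal sketch, assuming the two tensors share the batch dimension; the helper name `mask_candidates` and the `fill_value` parameter are hypothetical, and the desired dtype is cast to match `index` (the post shows `int32` vs. the default `int64`).

```python
import torch


def mask_candidates(index: torch.Tensor,
                    desired_candidate: torch.Tensor,
                    fill_value: int = -1) -> torch.Tensor:
    """Hypothetical helper: replace entries of `index` found in the same row of `desired_candidate`.

    index:             (batch_size, candidate_size)
    desired_candidate: (batch_size, top_k)
    """
    # Match device/dtype so the comparison does not depend on type promotion.
    desired = desired_candidate.to(device=index.device, dtype=index.dtype)
    # (B, C, 1) == (B, 1, K) -> (B, C, K): True where index[b, c] == desired[b, k]
    matches = index.unsqueeze(-1) == desired.unsqueeze(1)
    # True where index[b, c] appears anywhere in row b of desired_candidate
    mask = matches.any(dim=-1)
    # Out-of-place fill; keeps the dtype of `index`
    return index.masked_fill(mask, fill_value)


# Usage with the first row from the post (CPU tensors for illustration)
index = torch.tensor(
    [[104, 171, 182, 3, 56, 178, 6, 6, 4, 21, 30, 182, 27, 39, 56]],
    dtype=torch.int32,
)
desired_candidate = torch.tensor([[171, 4, 182, 102, 61]])
index_ = mask_candidates(index, desired_candidate)
print(index_[0])
```

The intermediate `matches` tensor has `batch_size * candidate_size * top_k` booleans (64 × 15 × 5 here), which is negligible for these sizes.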

If anyone knows how to implement this, I'd appreciate the help.

Thanks.