How to mask for each batch?

Currently I have an index tensor of shape batch_size * candidate_size and a desired_candidate tensor of shape batch_size * top_5.

For example, with batch_size == 64 and candidate_size == 15:

>> index.size()
torch.Size([64, 15])

>> desired_candidate.size()
torch.Size([64, 5])

>> index[0]
tensor([104, 171, 182,   3,  56, 178,   6,   6,   4,  21,  30, 182,  27,  39,  56], device='cuda:0', dtype=torch.int32)

>> desired_candidate[0]
tensor([171,   4, 182, 102,  61], device='cuda:0')

I’d like to mask index against desired_candidate, replacing every entry that appears in the corresponding row of desired_candidate with -1, so the desired output is:

>> index_[0]
tensor([104, -1, -1,   3,  56, 178,   6,   6,   -1,  21,  30, -1,  27,  39,  56], device='cuda:0', dtype=torch.int32)

for each batch, but I can’t figure out how to do this.
If someone knows how to implement this, I’d appreciate it.
Thanks.

Hi,

It might depend a bit on the memory limitations you have, but the following should work (the sizes are kept very small for printing purposes):

import torch

b = 2
c = 3

# Small random integer tensors standing in for index and desired_candidate
index = torch.rand(b, c).mul(5).long()
desired_candidate = torch.rand(b, 5).mul(5).long()

print(index)
print(desired_candidate)

# Broadcast both tensors to (b, c, 5) so every index entry can be compared
# against every desired candidate in the same row
expanded_size = (b, c, 5)
expanded_index = index.unsqueeze(2).expand(expanded_size)
expanded_desired_candidate = desired_candidate.unsqueeze(1).expand(expanded_size)

# True wherever an index entry matches any of the row's desired candidates
mask = expanded_index.eq(expanded_desired_candidate).any(-1)

# Overwrite the matching entries with -1
index[mask] = -1
print(index)

Let me know if it fits your needs!
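
For what it's worth, here is a minimal sketch of the same idea at the sizes from your printouts (64 x 15 index, 64 x 5 desired_candidate), writing the result into a separate index_ tensor instead of overwriting index in place. The random data and the 0–200 value range are just placeholders:

import torch

batch_size, candidate_size, top_k = 64, 15, 5  # taken from the printouts above

index = torch.randint(0, 200, (batch_size, candidate_size), dtype=torch.int32)
desired_candidate = torch.randint(0, 200, (batch_size, top_k))

# (batch, candidate, 1) compared against (batch, 1, top_k) broadcasts to
# (batch, candidate, top_k); any(-1) marks entries matching any candidate in their row
mask = index.unsqueeze(2).eq(desired_candidate.unsqueeze(1)).any(-1)

# Keep index untouched and write the masked copy into index_
index_ = index.clone()
index_[mask] = -1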


Thanks for the great advice! I’ll try it.