For the first batch, [0, 1, 2, 1], only the 1s are duplicated, so the 1 with the larger score is kept and the other 1 is replaced by -1. Similarly, in the second batch, [3, 2, 1, 3], the 3s are duplicated, so only the 3 with the smaller score is replaced.
Oh, I missed the part about duplicates in your explanation.
I’m afraid there is no specialized function for this, so you will have to do it by hand, checking each index.
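A minimal sketch of the by-hand approach for a single batch, assuming the indices and their scores are 1-D tensors of equal length. The scores tensor and the helper name dedup_row are hypothetical, since the thread doesn't show them. A dict remembers the best-scoring position seen so far for each value, so every entry is visited once instead of being compared against every other entry. Ties keep the first occurrence.

```python
import torch

def dedup_row(indices: torch.Tensor, scores: torch.Tensor) -> None:
    """Hypothetical helper: for each value that appears more than once in
    `indices`, keep the occurrence with the highest score and overwrite the
    other occurrences with -1. Modifies `indices` in place."""
    best = {}  # value -> position of the highest-scoring occurrence so far
    for pos in range(indices.numel()):
        value = int(indices[pos])
        if value not in best:
            best[value] = pos
        elif scores[pos] > scores[best[value]]:
            # The new occurrence wins: discard the previous best.
            indices[best[value]] = -1
            best[value] = pos
        else:
            # The earlier occurrence wins (ties keep the first one).
            indices[pos] = -1

# Usage with made-up scores for the first batch from the question.
indices = torch.tensor([0, 1, 2, 1])
scores = torch.tensor([0.5, 0.2, 0.9, 0.7])
dedup_row(indices, scores)
print(indices.tolist())  # the 1 at position 1 has the lower score, so it becomes -1
```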
Thanks for the reply.
I do have an idea of how I want to implement this, but it is limited to just one batch. If I flatten all the batches and then use that function, the search space for comparison will be too large, since each duplicate value will be compared with all the other values in the flattened tensor. Is there a way I can apply a function to each batch individually?
You can use an outer for loop and, at each iteration, just work on input = full_input.select(0, batch_idx).
Note that select() does not copy memory, so any in-place change to input will be reflected in full_input!
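A small sketch of the outer loop and the view behaviour described above. The per-row function process_row_ is a hypothetical stand-in (it just sets the last entry to -1) for whatever in-place per-batch logic you use, and row plays the role of input in the answer. The last lines show that calling clone() on the selected row gives an independent copy when you do not want full_input to change.

```python
import torch

def process_row_(row: torch.Tensor) -> None:
    """Hypothetical stand-in for the real per-row logic: overwrites the
    last entry with -1, in place, so the write-through is visible."""
    row[-1] = -1

full_input = torch.tensor([[0, 1, 2, 1],
                           [3, 2, 1, 3]])

for batch_idx in range(full_input.size(0)):
    # select(0, batch_idx) returns a view of one row; no data is copied.
    row = full_input.select(0, batch_idx)
    process_row_(row)  # this in-place write lands in full_input

print(full_input.tolist())  # both rows now end with -1

# To work on a row without touching full_input, copy it first.
copy = full_input.select(0, 0).clone()
copy.fill_(7)  # modifies only the copy
print(full_input[0].tolist())  # first row is unchanged by fill_
```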