Generate a mask for all duplicate grey-scale pixels in a minibatch

Hello guys, I want to find all duplicate elements in a minibatch tensor. Specifically: given a tensor X=[B, 1, H, W], which can be regarded as B grey-scale images of size H*W, I want to find the duplicate pixels in each image. The desired output is a binary mask Y=[B, 1, H, W], where every unique pixel is assigned 1 and every duplicate pixel is assigned 0.

Currently I can achieve part of this via torch.unique + torch.split, but only for a single grey-scale image rather than a minibatch. My implementation is shown here:

import torch

W, H = 3, 3
image_flatten = torch.tensor([2, 3, 4, 1, 6, 2, 2, 5, 3])
print(image_flatten.view(W, H))

# Sort the pixels so equal values end up in contiguous runs.
sorted_image_flatten, idx_sort = torch.sort(image_flatten)
unique_values, _, count = torch.unique(sorted_image_flatten, return_counts=True, return_inverse=True)
# Split the sorted indices into groups of equal pixel values.
repetitive_idx_groups = torch.split(idx_sort, count.tolist())
# Collect the indices of every pixel whose value occurs more than once.
rep_idx_flatten = [rep_idx
                   for rep_idx_group in repetitive_idx_groups if len(rep_idx_group) > 1
                   for rep_idx in rep_idx_group]
rep_idx_flatten = torch.tensor(rep_idx_flatten)
# Start from an all-ones mask and zero out the duplicated positions.
mask = torch.ones((W * H, 1))
mask[rep_idx_flatten] = 0
mask = mask.view(H, W)
print(mask)

The outputs:

tensor([[2, 3, 4],
        [1, 6, 2],
        [2, 5, 3]])
tensor([[0., 0., 1.],
        [1., 1., 0.],
        [0., 1., 0.]])

The problem is that this implementation costs too much time, even when running on CUDA; the time budget is too high to allow training my network. The reasons are: 1) torch.split is very time-consuming, and 2) torch.unique does not support batched operation.
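For reference, a minimal sketch of how the cost can be timed on GPU (the helper per_image_mask and the sizes and pixel range below are just placeholders, not my real setup; torch.cuda.synchronize is needed so the asynchronous kernel launches are actually included in the timing, assuming a CUDA device is available):

import time
import torch

def per_image_mask(image_flatten):
    # Repackages the per-image routine above (sort + unique + split + mask assignment).
    sorted_image_flatten, idx_sort = torch.sort(image_flatten)
    _, _, count = torch.unique(sorted_image_flatten, return_counts=True, return_inverse=True)
    groups = torch.split(idx_sort, count.tolist())
    mask = torch.ones(image_flatten.numel(), device=image_flatten.device)
    for group in groups:
        if len(group) > 1:
            mask[group] = 0
    return mask

B, H, W = 32, 64, 64                                # placeholder sizes
x = torch.randint(0, 256, (B, 1, H, W), device="cuda")

torch.cuda.synchronize()                            # let previously queued kernels finish
start = time.time()
masks = torch.stack([per_image_mask(x[b].flatten()) for b in range(B)])
torch.cuda.synchronize()                            # include the timed kernels as well
print(f"mask generation took {time.time() - start:.4f} s")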

So I wonder if there is something in the PyTorch library that can achieve the above for a minibatch? I really appreciate your time helping me solve this problem.

Upon trying your code, it seems that sorting before unique is eating some time. unique takes a sorted argument of its own that takes less time (I think sorting before unique runs two passes through the data, while sorting during unique does it in one pass). Also, you can use the inverse indexes returned by unique to simplify the code:

import torch
import time

W, H = 3, 3
image_flatten = torch.tensor([2, 3, 4, 1, 6, 2, 2, 5, 3])
print(image_flatten.view(W, H))
# sorted_image_flatten, idx_sort = torch.sort(image_flatten)   # no longer needed
# Let unique sort internally and return the counts and inverse indexes.
unique_values, inverse_indexes, count = torch.unique(
    image_flatten, sorted=True, return_counts=True, return_inverse=True)
# A pixel is unique exactly when its value's count is 1; indexing with the
# inverse indexes broadcasts that back to every pixel position.
non_repeated_vals = (count == 1).float()
mask_flatten = non_repeated_vals[inverse_indexes]
mask = mask_flatten.view(H, W)
print(mask)
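For the full minibatch case X=[B, 1, H, W], the same idea should extend to a single unique call over the whole batch by first shifting each image's values into its own disjoint range. This is just a sketch, assuming integer-valued pixels (for small integer dtypes such as uint8, cast to long first to avoid overflow):

import torch

def unique_pixel_mask(x):
    # x: [B, 1, H, W], integer-valued grey-scale pixels (assumption).
    B = x.shape[0]
    flat = x.reshape(B, -1)
    # Shift each image's values into a non-overlapping range so that one
    # torch.unique over the whole batch never mixes pixels from different images.
    span = flat.max() - flat.min() + 1
    offsets = torch.arange(B, device=x.device).unsqueeze(1) * span
    _, inverse, count = torch.unique(flat + offsets, return_inverse=True, return_counts=True)
    # Same trick as above: count == 1 marks unique values, and the inverse
    # indexes broadcast that back to every pixel position.
    return (count == 1).float()[inverse].view_as(x)

x = torch.randint(0, 256, (4, 1, 8, 8))
print(unique_pixel_mask(x).shape)   # torch.Size([4, 1, 8, 8])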

Brilliant :+1: I tried your code and found the time budget is almost affordable now.
When tracing the time cost, I actually found that most of the time was spent on torch.split and on creating and assigning the mask.
Thank you~
