How to efficiently find correspondences between two point sets without nested for loops in PyTorch?

Hello,

I have two point sets (tensors), A and B, with shapes like:

A.size() >> (50, 3), example: [[0, 0, 0], [0, 1, 2], …, [1, 1, 1]]

B.size() >>(10, 3)

where the first dimension is the number of points and the second dimension holds the coordinates (x, y, z).

To some extent, the question can be simplified to "finding common elements between two tensors". Is there a quick way to do this without a nested loop like this:

```python
import torch

def is_corr(a, b):
    corr = []   # matching points
    index = []  # indices of the matching points in a
    for ele_b in b:
        for ids, ele_a in enumerate(a):
            if torch.equal(ele_a, ele_b):
                corr.append(ele_a)
                index.append(ids)
    return corr, index
```

Thanks

Hi,

I don't know of any built-in function for this, but the best approach I can think of is to loop over the elements of b and use each one to index into a.
I know this solution looks a bit weird, but maybe we can optimize it later!

```python
indices = [torch.sum(a == b[idx], dim=1) == len(b[idx]) for idx in range(len(b))]
```

With your shapes, a == b[idx] returns a (50, 3) boolean tensor, so each row looks like [x, x, x] where x is a boolean. We are looking for equality in every coordinate, so the sum along dim=1 should equal the length of b[idx] (here 3).
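A minimal sketch of turning that list of per-row masks into index pairs, on small hypothetical tensors `a` and `b` chosen only for illustration: stacking the masks gives a (len(b), len(a)) boolean matrix, and `torch.nonzero` reads off every match.

```python
import torch

# Hypothetical small point sets; a contains [0, 0, 0] twice.
a = torch.tensor([[0, 0, 0], [0, 1, 2], [1, 1, 1], [0, 0, 0]])
b = torch.tensor([[1, 1, 1], [0, 0, 0], [5, 5, 5]])

# One boolean mask per row of b: True where a row of a matches b[idx] in all coordinates.
indices = [torch.sum(a == b[idx], dim=1) == len(b[idx]) for idx in range(len(b))]

# Stack into a (len(b), len(a)) matrix and list the matching positions.
match = torch.stack(indices)     # shape: (3, 4)
pairs = torch.nonzero(match)     # each row: [index_in_b, index_in_a]
matched_points = a[pairs[:, 1]]  # the corresponding points taken from a
print(pairs.tolist())            # [[0, 2], [1, 0], [1, 3]]
```

Only the outer list comprehension is still a Python loop; the comparison against all rows of a happens in one vectorized call per row of b.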

Bests

Hi,

Thanks for your quick reply, but it's a bit tricky for me to understand. I wrote a solution like this:

```python
import torch

space = torch.zeros((100, 100, 100))  # total coordinate space
pointset1 = torch.randint(0, 10, (10240, 3))  # random sample
pointset2 = torch.randint(0, 10, (256, 3))    # random sample

space[pointset1[:, 0], pointset1[:, 1], pointset1[:, 2]] += 1
space[pointset2[:, 0], pointset2[:, 1], pointset2[:, 2]] += 1

corrs = space > 1  # get the intersection
del space
print(torch.nonzero(corrs))
```

What do you think of it ?


Your idea is interesting!
The coordinates you compute in space are correct, but there are a few problems:

  1. How can you map corrs obtained from space back to values in pointset1?
  2. It does not count an occurrence more than once.

About the second problem:

The line `space[pointset1[:, 0], pointset1[:, 1], pointset1[:, 2]] += 1` will not add 1 for each duplicate index. For instance, if every row of pointset1 is [1, 1, 1], then index [1, 1, 1] in space will only be 1, but we would expect it to be len(pointset1).
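If per-cell counts are what you want, one option is `Tensor.index_put_` with `accumulate=True`, which sums the values written to repeated indices instead of keeping only one write. A minimal sketch, using a hypothetical pointset1 made of five copies of [1, 1, 1] and a 2×2×2 space:

```python
import torch

# Five copies of the same point [1, 1, 1].
pointset1 = torch.ones((5, 3), dtype=torch.long)
space = torch.zeros((2, 2, 2))

# accumulate=True adds every value, so repeated indices are counted each time.
space.index_put_(
    (pointset1[:, 0], pointset1[:, 1], pointset1[:, 2]),
    torch.ones(len(pointset1)),
    accumulate=True,
)
print(space[1, 1, 1])  # tensor(5.)
```

Note that once duplicates are counted, `space > 1` no longer means "present in both sets": a point repeated inside pointset1 alone would also exceed 1, so the two sets would need separate grids.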

Try this test:

```python
import torch

pointset1 = torch.tensor([[1, 0, 1],
        [0, 0, 0],
        [1, 0, 0],
        [1, 0, 0],
        [0, 1, 1],
        [0, 1, 1],
        [0, 1, 0],
        [1, 1, 1],
        [1, 0, 1],
        [1, 1, 1]])
pointset2 = torch.tensor([[1, 1, 1], [1, 0, 0]])
```
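For the first problem (mapping the grid back to rows of pointset1), a minimal sketch on this test data, assuming non-negative integer coordinates: mark pointset2 in a boolean grid, then index that grid with every row of pointset1, which gives one True/False per row. Writing True to a cell twice is harmless, so duplicates do not matter here. The grid size is derived from the data rather than fixed at 100.

```python
import torch

pointset1 = torch.tensor([[1, 0, 1], [0, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 1],
                          [0, 1, 1], [0, 1, 0], [1, 1, 1], [1, 0, 1], [1, 1, 1]])
pointset2 = torch.tensor([[1, 1, 1], [1, 0, 0]])

# Smallest cube that holds every coordinate (assumes integers >= 0).
size = int(torch.cat([pointset1, pointset2]).max()) + 1

# Boolean occupancy grid for pointset2.
occupied2 = torch.zeros((size, size, size), dtype=torch.bool)
occupied2[pointset2[:, 0], pointset2[:, 1], pointset2[:, 2]] = True

# Look up each row of pointset1 in the grid: one bool per row.
in_set2 = occupied2[pointset1[:, 0], pointset1[:, 1], pointset1[:, 2]]
idx1 = torch.nonzero(in_set2).squeeze(1)  # indices into pointset1
print(idx1.tolist())  # [2, 3, 7, 9]
```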

Something like computing a histogram would help, but I still don't see a way to compute it without a for loop.

Thank you! I also found that my solution doesn't work for my real problem. Building a space map only works when all coordinates are integers, but in my real problem the coordinates are floats, like [0.5678, 1.2345, 1.3678] …
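With float coordinates, an exact `==` comparison only matches bit-identical values, so points that differ by rounding noise are missed. A minimal sketch of a swapped-in alternative, assuming a per-coordinate absolute tolerance (`atol`, a value picked here only for illustration) is acceptable for your data; the point sets are hypothetical:

```python
import torch

# Hypothetical float point sets; b[0] is a[1] shifted by 1e-4 in z.
a = torch.tensor([[0.5678, 1.2345, 1.3678],
                  [0.1000, 0.2000, 0.3000],
                  [2.0000, 2.0000, 2.0000]])
b = torch.tensor([[0.1000, 0.2000, 0.3001],
                  [2.0000, 2.0000, 2.0000]])

atol = 1e-3  # assumed tolerance; choose one that matches your data's precision

# (N1, 1, 3) - (1, N2, 3) broadcasts to (N1, N2, 3) coordinate differences.
close = ((a[:, None, :] - b[None, :, :]).abs() <= atol).all(dim=2)  # (N1, N2)
pairs = close.nonzero()  # rows: [index_in_a, index_in_b]

# For comparison: exact equality misses the slightly shifted point.
exact = (a[:, None, :] == b[None, :, :]).all(dim=2).nonzero()
print(pairs.tolist())  # [[1, 0], [2, 1]]
print(exact.tolist())  # [[2, 1]]
```

If the points in one set are exact copies of points in the other (for example, produced by indexing the same tensor), exact equality is enough and no tolerance is needed.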

Actually, I was going to ask how to determine the space size, since it looked arbitrary. So, have you solved the issue?

Yes, I think the code below could solve the problem:

```python
((pointset1[:, None, :] == pointset2[None, ...]).all(dim=2)).nonzero()
```

This returns a tensor of shape (num_correspondences, 2), like [[index1_in_set1, index1_in_set2], [], [], …, []],

where each row [index_in_set1, index_in_set2] is one correspondence: index_in_set1 and index_in_set2 are the point's indices in set 1 and set 2, respectively.
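Applied to the earlier test data, a short sketch of what the one-liner produces: the (N1, 1, 3) and (1, N2, 3) views broadcast to (N1, N2, 3), `.all(dim=2)` keeps only pairs equal in every coordinate, and the same boolean matrix, summed over dim 0, also gives how often each point of pointset2 occurs in pointset1, duplicates included.

```python
import torch

pointset1 = torch.tensor([[1, 0, 1], [0, 0, 0], [1, 0, 0], [1, 0, 0], [0, 1, 1],
                          [0, 1, 1], [0, 1, 0], [1, 1, 1], [1, 0, 1], [1, 1, 1]])
pointset2 = torch.tensor([[1, 1, 1], [1, 0, 0]])

# (N1, N2) boolean matrix: True where all three coordinates are equal.
match = (pointset1[:, None, :] == pointset2[None, ...]).all(dim=2)

pairs = match.nonzero()     # rows: [index_in_set1, index_in_set2]
counts = match.sum(dim=0)   # occurrences of each pointset2 point in pointset1
print(pairs.tolist())   # [[2, 1], [3, 1], [7, 0], [9, 0]]
print(counts.tolist())  # [2, 2]
```

The intermediate comparison tensor has N1 × N2 × 3 elements, so memory grows with the product of the two set sizes; for very large sets, processing pointset1 in chunks is one way to bound it.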