Intersection between two vectors/tensors

Can pytorch find the intersection of two vectors?

Similar to the Matlab “intersect” function:
https://www.mathworks.com/help/matlab/ref/double.intersect.html;jsessionid=0f4b1f0e0ff110c50bb8a0076587

What about other set operations? e.g., union, setdiff?

Thanks


Do you have any answers? Thanks.

I used np.intersect1d like this:

>>> import torch
>>> import numpy as np
>>> a = torch.tensor([1, 2, 3, 6])
>>> b = torch.tensor([0, 2, 3, 7])
>>> np.intersect1d(a, b)
array([2, 3])
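np.intersect1d works on host memory, so CUDA tensors have to be copied to the CPU first (the call above works because a and b already live there). A minimal sketch of that round trip, assuming NumPy 1.15 or later for return_indices; intersect_via_numpy is a hypothetical helper name:

```python
import numpy as np
import torch


def intersect_via_numpy(a, b):
    """Hypothetical helper: intersect two 1-D tensors with np.intersect1d.

    Returns the sorted common values plus, for each of them, the index of its
    first occurrence in a and in b, all as tensors on a's device.
    """
    # NumPy only reads host memory, so copy (and detach) to the CPU first
    a_np = a.detach().cpu().numpy()
    b_np = b.detach().cpu().numpy()
    # return_indices=True requires NumPy >= 1.15
    common, idx_a, idx_b = np.intersect1d(a_np, b_np, return_indices=True)
    # Wrap the NumPy results as tensors and move them back to a's device
    return tuple(torch.from_numpy(x).to(a.device) for x in (common, idx_a, idx_b))


a = torch.tensor([1, 2, 3, 6])
b = torch.tensor([0, 2, 3, 7])
common, idx_a, idx_b = intersect_via_numpy(a, b)
print(common, idx_a, idx_b)
```

The extra outputs from return_indices give the position of each common value in a and in b, which is often what is needed for indexing other tensors afterwards.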

It seems that TensorFlow 2.0 has a function named tf.sets.intersection.


I’ve also been trying to figure this out lately, and I think I’ve found a good solution. It works without numpy, so there is no need to transfer your tensors to the CPU (which can be much slower during training). Check this out:

>>> import torch
>>> a = torch.tensor([1, 2, 3, 6]).cuda()
>>> b = torch.tensor([0, 2, 3, 7]).cuda()
>>> intersection = (a * (a == b).float()).nonzero().flatten()
>>> intersection
tensor([1, 2], device='cuda:0')

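I think this only works if the matching elements are at the same indices. Note also that nonzero() returns the positions of the matches (1 and 2 above), not the values (2 and 3). If I rearrange a, I get: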

>>> import torch
>>> a = torch.tensor([6, 1, 2, 3]).cuda()
>>> b = torch.tensor([0, 2, 3, 7]).cuda()
>>> intersection = (a * (a == b).float()).nonzero().flatten()
>>> intersection
tensor([], device='cuda:0')

I’ve had some success with the second answer on this SO post: https://stackoverflow.com/questions/55110047/finding-non-intersection-of-two-pytorch-tensors . The repeat call isn’t ideal memory-wise but it works at least.
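The repeat step can be avoided with broadcasting, which only materialises a boolean comparison matrix. A minimal sketch of the non-intersection (elements of t1 not in t2) in that style; it is not the code from the linked post, and setdiff_broadcast is a hypothetical name:

```python
import torch


def setdiff_broadcast(t1, t2):
    """Hypothetical helper: values of 1-D t1 that do not occur in 1-D t2."""
    # matches[i, j] is True when t1[i] == t2[j]. Broadcasting builds only this
    # (len(t1), len(t2)) bool matrix instead of repeated copies of the inputs,
    # but memory still grows as len(t1) * len(t2).
    matches = t1.unsqueeze(1) == t2.unsqueeze(0)
    # Keep the entries of t1 that matched nothing in t2 (order is preserved)
    return t1[~matches.any(dim=1)]


t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])
print(setdiff_broadcast(t1, t2))
```

Flipping the final mask (matches.any(dim=1) without the ~) gives the intersection instead.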


You can also modify the first answer like so to get the intersection:

>>> t1 = torch.tensor([1, 9, 12, 5, 24], device='cuda')
>>> t2 = torch.tensor([1, 24], device='cuda')
>>> indices = torch.zeros_like(t1, dtype=torch.uint8, device='cuda')
>>> for elem in t2:
...     indices = indices | (t1 == elem)
...
>>> intersection = t1[indices]
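On PyTorch 1.10 or later, the Python loop can be replaced by a single vectorised call to torch.isin. A minimal sketch on CPU tensors (pass device='cuda' as in the snippet above to run it on the GPU):

```python
import torch

t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])

# torch.isin returns a bool mask shaped like t1 that is True wherever the
# element of t1 appears anywhere in t2, replacing the per-element loop
# with one vectorised call (available from PyTorch 1.10 onwards)
mask = torch.isin(t1, t2)
intersection = t1[mask]
print(intersection)
```

Because the mask is already boolean, this also sidesteps the uint8 mask used in the loop version.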

That’s better. Thanks

Not sure this helps, as the code iterates over all elements of t2 in a Python for-loop and hence does not benefit from the GPU. In fact, np.intersect1d is much faster:

import time

import numpy as np
import torch

def tensor_intersect(t1, t2):
    t1 = t1.cuda()
    t2 = t2.cuda()
    indices = torch.zeros_like(t1, dtype=torch.bool, device='cuda')
    # Each iteration launches separate GPU operations from a Python loop
    for elem in t2:
        indices = indices | (t1 == elem)
    # Index once, after the loop has finished building the mask
    intersection = t1[indices]
    return intersection

t1 = np.random.randint(1, 10**9, 10000)
t2 = np.random.randint(1, 10**9, 10000)

tic = time.time()
np.intersect1d(t1, t2)
print(time.time() - tic)
# prints 0.0009970664978027344

tic = time.time()
tensor_intersect(torch.tensor(t1), torch.tensor(t2))
print(time.time() - tic)
# prints 1.426218032836914

NB: in the earlier snippet, indices should be created as a bool mask rather than uint8 (indexing with a uint8 mask is deprecated):
indices = torch.zeros_like(t1, dtype=torch.bool, device='cuda')
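Part of a first GPU timing can be one-off setup (for example CUDA context creation), and in general CUDA work runs asynchronously, so time.time() may stop before queued kernels finish. A sketch of a fairer measurement, with time_call a hypothetical helper and repeats an arbitrary default:

```python
import time

import torch


def time_call(fn, *args, repeats=5):
    """Hypothetical helper: average wall-clock seconds per call of fn(*args)."""
    # One untimed warm-up call, so one-off costs such as CUDA context
    # creation are not charged to the measured runs
    fn(*args)
    if torch.cuda.is_available():
        # Wait for any queued CUDA kernels before starting the clock
        torch.cuda.synchronize()
    tic = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    if torch.cuda.is_available():
        # ...and again before stopping it
        torch.cuda.synchronize()
    return (time.perf_counter() - tic) / repeats
```

With the definitions above, time_call(tensor_intersect, torch.tensor(t1), torch.tensor(t2)) and time_call(np.intersect1d, t1, t2) can then be compared on equal terms.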

Here’s a tweak that is 15 to 20 times faster than np.intersect1d for large sets. For small sets it is better to stick with numpy, at least until someone writes a better torch algorithm based on search:


import time

import numpy as np
import torch

t1 = torch.randint(0, 99999, (1000000,))
t2 = torch.randint(0, 99999, (100000000,))

def torch_intersect(t1, t2):
    # Deduplicate on the GPU first: values lie in [0, 99999), so each input
    # shrinks to at most 99999 elements before being copied back to the CPU
    t1 = t1.cuda().unique()
    t2 = t2.cuda().unique()
    return torch.tensor(np.intersect1d(t1.cpu().numpy(), t2.cpu().numpy()))

def torch_intersect2(t1, t2):
    # Same idea, but intersecting Python sets instead of calling numpy
    t1 = set(t1.cuda().unique().cpu().numpy())
    t2 = set(t2.cuda().unique().cpu().numpy())
    return t1.intersection(t2)

tic = time.time()
res = np.intersect1d(t1.cpu().numpy(), t2.cpu().numpy())
print('Numpy Intersect Time:', time.time() - tic)

tic = time.time()
res = torch_intersect(t1, t2)
print('Tensor Intersect Time:', time.time() - tic)

# Output:
# Numpy Intersect Time: 6.87971830368042
# Tensor Intersect Time:  0.3944394588470459
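In the spirit of the search-based torch algorithm mentioned above, a minimal sketch that keeps the whole computation on the tensors' device by combining unique with torch.searchsorted (PyTorch 1.6+). It assumes both inputs are 1-D with the same dtype, it has not been benchmarked against the numbers above, and torch_intersect_search is a hypothetical name:

```python
import torch


def torch_intersect_search(t1, t2):
    """Hypothetical helper: sorted unique values present in both 1-D tensors."""
    # torch.unique sorts by default, which searchsorted needs for t2
    t1 = t1.unique()
    t2 = t2.unique()
    if t2.numel() == 0:
        return t1[:0]
    # Binary search: pos[i] is where t1[i] would be inserted into sorted t2
    pos = torch.searchsorted(t2, t1)
    # pos == len(t2) means "larger than everything in t2"; clamp so the
    # lookup below stays in range (such values can never match anyway)
    pos = pos.clamp(max=t2.numel() - 1)
    # A value is shared exactly when t2 holds it at its insertion point
    return t1[t2[pos] == t1]


a = torch.tensor([1, 2, 3, 6, 6])
b = torch.tensor([0, 2, 3, 7])
print(torch_intersect_search(a, b))
```

The cost is dominated by the two unique calls (sorts) plus one binary search per unique value of t1, with no host round trip.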

Nifty solution I worked out today:

import torch
a = torch.tensor([1, 2, 3, 6]).cuda()
b = torch.tensor([0, 2, 3, 4]).cuda()
a_cat_b, counts = torch.cat([a, b]).unique(return_counts=True)
intersection = a_cat_b[torch.where(counts.gt(1))]

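a and b can be of different lengths, and common values can be at different positions. This assumes neither a nor b contains repeated values on its own; otherwise a value repeated within one tensor also gets a count above 1, so call .unique() on each input first.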
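The same counting trick also covers the union and setdiff asked about in the original question. A minimal sketch, with set_ops a hypothetical helper and the MATLAB-style names (setxor, setdiff) used only as labels:

```python
import torch


def set_ops(a, b):
    """Hypothetical helper: MATLAB-style set operations on two 1-D tensors.

    Returns sorted (union, intersection, setxor, setdiff(a, b)).
    """
    # Deduplicate each input first, so every value occurs at most once per tensor
    a, b = a.unique(), b.unique()
    vals, counts = torch.cat([a, b]).unique(return_counts=True)
    union = vals                      # in a or b
    intersection = vals[counts == 2]  # in both
    setxor = vals[counts == 1]        # in exactly one of them
    # setdiff(a, b): with b appended twice, a value only in a is counted once,
    # a value only in b twice and a value in both three times
    vals3, counts3 = torch.cat([a, b, b]).unique(return_counts=True)
    setdiff = vals3[counts3 == 1]
    return union, intersection, setxor, setdiff


a = torch.tensor([1, 2, 3, 6, 6])
b = torch.tensor([0, 2, 3, 4])
union, intersection, setxor, setdiff = set_ops(a, b)
```

Everything stays in torch.unique and torch.cat, so it runs on whichever device a and b live on.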


For a one-line answer, I would do:

import torch
first = torch.Tensor([1, 2, 3, 4, 5, 6])
second = torch.Tensor([7, 3, 9, 1])
intersection = first[(first.view(1, -1) == second.view(-1, 1)).any(dim=0)]
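The one-liner builds a len(second) x len(first) boolean matrix, which can get large. A sketch that bounds that memory by comparing against slices of second; intersect_chunked and chunk_size are hypothetical names, and the chunk size is an arbitrary default:

```python
import torch


def intersect_chunked(first, second, chunk_size=1024):
    """Hypothetical helper: elements of first that also occur in second."""
    # Comparing against all of second at once builds a
    # len(second) x len(first) bool matrix; looping over slices of second
    # caps it at chunk_size x len(first)
    mask = torch.zeros_like(first, dtype=torch.bool)
    for chunk in second.split(chunk_size):
        mask |= (first.view(1, -1) == chunk.view(-1, 1)).any(dim=0)
    return first[mask]


first = torch.Tensor([1, 2, 3, 4, 5, 6])
second = torch.Tensor([7, 3, 9, 1])
print(intersect_chunked(first, second, chunk_size=2))
```

With chunk_size at least len(second), the loop runs once and reduces to the original one-liner.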

Thanks for sharing. Your solution could even be extended to N-dimensional vectors, e.g., 3D coordinates. In addition, the indices of the intersecting elements can also be obtained this way by setting return_inverse=True in torch.unique.
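A minimal sketch of that extension for rows of 2-D tensors (e.g. N x 3 coordinates): passing dim=0 makes torch.unique treat each row as one element, and return_inverse=True recovers which input rows are shared. It assumes no row repeats within a single input, and row_intersect is a hypothetical name:

```python
import torch


def row_intersect(a, b):
    """Hypothetical helper: rows shared by 2-D tensors a (n, d) and b (m, d).

    Returns the shared rows (sorted) plus the indices of those rows in a and
    in b, each listed in that input's own order.
    Assumes no row is repeated within a or within b.
    """
    n = a.shape[0]
    # dim=0 makes unique compare whole rows
    rows, inverse, counts = torch.unique(
        torch.cat([a, b], dim=0), dim=0, return_inverse=True, return_counts=True
    )
    shared = counts > 1  # one flag per unique row
    # inverse maps every input row to its unique row, so look its flag up
    in_a = shared[inverse[:n]]
    in_b = shared[inverse[n:]]
    return rows[shared], in_a.nonzero().flatten(), in_b.nonzero().flatten()


a = torch.tensor([[0, 0, 0], [1, 2, 3], [4, 5, 6]])
b = torch.tensor([[4, 5, 6], [7, 8, 9], [1, 2, 3]])
common, idx_a, idx_b = row_intersect(a, b)
```

Note that a[idx_a] and b[idx_b] contain the same rows as common, but not necessarily in the same order.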