Can PyTorch find the intersection of two vectors?

Similar to MATLAB's `intersect` function:
https://www.mathworks.com/help/matlab/ref/double.intersect.html;jsessionid=0f4b1f0e0ff110c50bb8a0076587

What about other set operations, e.g., union or setdiff?

Thanks

Does anyone have an answer? Thanks.

WangXin93
(Wang X)
November 18, 2019, 2:58am
#3
I used `np.intersect1d`, like this:

```
>>> import torch
>>> import numpy as np
>>> a = torch.tensor([1, 2, 3, 6])
>>> b = torch.tensor([0, 2, 3, 7])
>>> np.intersect1d(a, b)
array([2, 3])
```

It seems that TensorFlow 2.0 has a function named `tf.sets.intersection`.
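For the union and setdiff asked about in the question, NumPy has sibling functions `np.union1d` and `np.setdiff1d`. A minimal sketch, assuming CPU tensors (CUDA tensors would need `.cpu()` first), with the results converted back via `torch.from_numpy`:

```python
import numpy as np
import torch

a = torch.tensor([1, 2, 3, 6])
b = torch.tensor([0, 2, 3, 7])

# NumPy's set routines return sorted, de-duplicated arrays
a_np, b_np = a.numpy(), b.numpy()
inter = torch.from_numpy(np.intersect1d(a_np, b_np))  # values in both a and b
union = torch.from_numpy(np.union1d(a_np, b_np))      # values in either
diff = torch.from_numpy(np.setdiff1d(a_np, b_np))     # values in a but not in b

print(inter.tolist(), union.tolist(), diff.tolist())  # [2, 3] [0, 1, 2, 3, 6, 7] [1, 6]
```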

I’ve also been trying to figure this out lately, and I think I’ve found a good solution. It works without NumPy, so there’s no need to transfer your tensors to the CPU (which is much slower while training models). Check this out:

```
>>> import torch
>>> a = torch.tensor([1, 2, 3, 6]).cuda()
>>> b = torch.tensor([0, 2, 3, 7]).cuda()
>>> # nonzero() returns the positions where a == b (here 1 and 2), not the shared values
>>> intersection = (a * (a == b).float()).nonzero().flatten()
>>> intersection
tensor([1, 2], device='cuda:0')
```

alokpathy
(Alok Tripathy)
April 9, 2020, 6:20pm
#5
I think this only works if the matching elements are at the same indices. If I rearrange `a`, I get:

```
>>> import torch
>>> a = torch.tensor([6, 1, 2, 3]).cuda()
>>> b = torch.tensor([0, 2, 3, 7]).cuda()
>>> intersection = (a * (a == b).float()).nonzero().flatten()
>>> intersection
tensor([], device='cuda:0')
```

I’ve had some success with the second answer on this Stack Overflow post: https://stackoverflow.com/questions/55110047/finding-non-intersection-of-two-pytorch-tensors

The `repeat` call isn’t ideal memory-wise, but at least it works.
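The same idea can be written with broadcasting instead of an explicit `repeat`: compare every element of `a` with every element of `b` in a single `len(a) x len(b)` boolean matrix. This is a sketch in that spirit, not the Stack Overflow answer verbatim; `broadcast_intersect` is a hypothetical helper name. It still needs O(len(a) * len(b)) memory, but it is position-independent and runs on whatever device the tensors are on:

```python
import torch

def broadcast_intersect(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # (len(a), 1) == (1, len(b)) broadcasts to a (len(a), len(b)) boolean matrix
    matches = a.unsqueeze(1) == b.unsqueeze(0)
    # keep each element of a that equals at least one element of b
    in_b = matches.any(dim=1)
    return a[in_b]

a = torch.tensor([6, 1, 2, 3])
b = torch.tensor([0, 2, 3, 7])
print(broadcast_intersect(a, b).tolist())  # [2, 3]
```

Repeated values in `a` are kept as-is; wrap the result in `torch.unique` if you need strict set semantics.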

alokpathy
(Alok Tripathy)
April 9, 2020, 6:30pm
#6
You can also modify the first answer like so to get the intersection:

```
>>> import torch
>>> t1 = torch.tensor([1, 9, 12, 5, 24], device='cuda')
>>> t2 = torch.tensor([1, 24], device='cuda')
>>> indices = torch.zeros_like(t1, dtype=torch.uint8, device='cuda')
>>> for elem in t2:
...     indices = indices | (t1 == elem)
...
>>> intersection = t1[indices]
```

Deeply
(Deeply)
February 17, 2021, 2:37pm
#8

Not sure this would help, as the code loops over all `t2` elements in a Python for-loop and hence would not benefit from the GPU. In fact, `numpy`'s intersect is much faster:

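```
import torch

def tensor_intersect(t1, t2):
    t1 = t1.cuda()
    t2 = t2.cuda()
    # boolean mask over t1, OR-ed with one comparison per element of t2
    indices = torch.zeros_like(t1, dtype=torch.bool, device='cuda')
    for elem in t2:  # Python-level loop: one small GPU op per element of t2
        indices = indices | (t1 == elem)
    intersection = t1[indices]
    return intersection
```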

```
import time

import numpy as np
import torch

t1 = np.random.randint(1, 10**9, 10000)
t2 = np.random.randint(1, 10**9, 10000)

tic = time.time()
np.intersect1d(t1, t2)
print(time.time() - tic)  # 0.0009970664978027344

tic = time.time()
tensor_intersect(torch.tensor(t1), torch.tensor(t2))
print(time.time() - tic)  # 1.426218032836914
```

NB: in the earlier post, `indices` should be created as a boolean mask (indexing with `uint8` masks is deprecated):
`indices = torch.zeros_like(t1, dtype=torch.bool, device='cuda')`
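In recent PyTorch releases (1.10 and later, as far as I know), the Python loop in `tensor_intersect` can be replaced by `torch.isin`, which builds the same membership mask in one vectorized call. A minimal sketch, assuming a PyTorch version that provides `torch.isin`; `isin_intersect` is a hypothetical helper name, and it works the same on CUDA tensors:

```python
import torch

def isin_intersect(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    # boolean mask: True where an element of t1 appears anywhere in t2
    mask = torch.isin(t1, t2)
    # unique() drops repeated values from t1 and returns them sorted
    return t1[mask].unique()

t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])
print(isin_intersect(t1, t2).tolist())  # [1, 24]
```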

Deeply
(Deeply)
February 17, 2021, 6:03pm
#9
Here’s a tweak that’s 15 to 20 times faster than `np.intersect1d` for large sets. For small sets, it is better to stick with `numpy` until someone writes a better `torch` algorithm based on search:

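```
import time

import numpy as np
import torch

t1 = torch.randint(0, 99999, (1000000,))
t2 = torch.randint(0, 99999, (100000000,))

def torch_intersect(t1, t2, use_unique=False):
    # move to the GPU and deduplicate there; values lie in [0, 99998],
    # so each tensor shrinks to at most 99999 distinct elements
    t1 = t1.cuda()
    t2 = t2.cuda()
    t1 = t1.unique()
    t2 = t2.unique()
    # intersect the (much smaller) deduplicated arrays with numpy on the CPU
    return torch.tensor(np.intersect1d(t1.cpu().numpy(), t2.cpu().numpy()))

def torch_intersect2(t1, t2, use_unique=False):
    # same GPU deduplication, then Python's built-in set intersection
    t1 = t1.cuda()
    t2 = t2.cuda()
    t1 = t1.unique()
    t2 = t2.unique()
    t1 = set(t1.cpu().numpy())
    t2 = set(t2.cpu().numpy())
    return t1.intersection(t2)

# baseline: numpy on the raw, non-deduplicated arrays
tic = time.time()
res = np.intersect1d(t1.cpu().numpy(), t2.cpu().numpy())
print('Numpy Intersect Time:', time.time() - tic)

tic = time.time()
res = torch_intersect(t1, t2)
print('Tensor Intersect Time:', time.time() - tic)
```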

```
Numpy Intersect Time: 6.87971830368042
Tensor Intersect Time: 0.3944394588470459
```
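A search-based, pure-`torch` version along the lines suggested above might deduplicate and sort both inputs with `torch.unique`, then binary-search each value of one in the other with `torch.searchsorted`. This is only a sketch and has not been benchmarked against the numbers above; `searchsorted_intersect` is a hypothetical helper name:

```python
import torch

def searchsorted_intersect(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    # torch.unique returns sorted, de-duplicated values by default
    t1 = torch.unique(t1)
    t2 = torch.unique(t2)
    if t2.numel() == 0:
        return t1[:0]  # nothing can match an empty set
    # for each value in t1, the index where it would be inserted into sorted t2
    idx = torch.searchsorted(t2, t1)
    # values larger than every element of t2 get idx == len(t2); clamp to stay in range
    idx = idx.clamp(max=t2.numel() - 1)
    # a value is shared exactly when t2 holds it at its insertion point
    return t1[t2[idx] == t1]

t1 = torch.tensor([6, 1, 2, 3, 3])
t2 = torch.tensor([0, 2, 3, 7])
print(searchsorted_intersect(t1, t2).tolist())  # [2, 3]
```

Sorting costs O(n log n) and each lookup is a binary search, so memory stays linear in the input sizes, unlike the broadcasting and `repeat` approaches.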

Nifty solution I worked out today:

```
import torch
a = torch.tensor([1, 2, 3, 6]).cuda()
b = torch.tensor([0, 2, 3, 4]).cuda()
a_cat_b, counts = torch.cat([a, b]).unique(return_counts=True)
intersection = a_cat_b[torch.where(counts.gt(1))]
```

`a` and `b` can be of different lengths, and common values can be at different positions.
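One caveat: the counts trick assumes neither `a` nor `b` contains repeated values, since a value repeated inside `a` alone would also get a count above 1. Deduplicating each input first avoids that, and the same building blocks give the union and setdiff from the original question. A sketch; `set_ops` is a hypothetical helper, and `torch.isin` needs a reasonably recent PyTorch:

```python
import torch

def set_ops(a: torch.Tensor, b: torch.Tensor):
    # deduplicate each input so a count of 2 can only mean "present in both"
    a_u, b_u = a.unique(), b.unique()
    vals, counts = torch.cat([a_u, b_u]).unique(return_counts=True)
    intersection = vals[counts > 1]
    union = vals  # every distinct value from either tensor
    # elements of a that do not occur in b
    difference = a_u[torch.isin(a_u, b_u, invert=True)]
    return intersection, union, difference

a = torch.tensor([1, 1, 2, 3, 6])
b = torch.tensor([0, 2, 3, 4])
inter, union, diff = set_ops(a, b)
print(inter.tolist(), union.tolist(), diff.tolist())  # [2, 3] [0, 1, 2, 3, 4, 6] [1, 6]
```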
