Can PyTorch find the intersection of two vectors?
Something similar to MATLAB's intersect function:
https://www.mathworks.com/help/matlab/ref/double.intersect.html;jsessionid=0f4b1f0e0ff110c50bb8a0076587
What about other set operations, e.g., union or setdiff?
Thanks
Do you have any answers? Thanks.
WangXin93
(Wang X)
November 18, 2019, 2:58am
#3
I used np.intersect1d, like this:
>>> import torch
>>> import numpy as np
>>> a = torch.tensor([1, 2, 3, 6])
>>> b = torch.tensor([0, 2, 3, 7])
>>> np.intersect1d(a, b)
array([2, 3])
It seems that TensorFlow 2.0 has a function named tf.sets.intersection.
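The original question also asked about union and setdiff. A minimal sketch of the same NumPy route for all three operations follows, assuming the tensors live on the CPU (a GPU tensor would need `.cpu()` first). The variable names are illustrative only; `np.union1d` and `np.setdiff1d` are the NumPy counterparts of `np.intersect1d`.

```python
import numpy as np
import torch

a = torch.tensor([1, 2, 3, 6])
b = torch.tensor([0, 2, 3, 7])

# Explicit conversion to NumPy; assumes CPU tensors.
a_np, b_np = a.numpy(), b.numpy()

# Each NumPy routine returns a sorted array of unique values.
intersection = torch.from_numpy(np.intersect1d(a_np, b_np))  # values in both a and b
union = torch.from_numpy(np.union1d(a_np, b_np))              # values in a or b
difference = torch.from_numpy(np.setdiff1d(a_np, b_np))       # values in a but not in b
```

The results come back as CPU tensors, so they would have to be moved to the GPU again if the rest of the pipeline runs there.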
I’ve also been trying to figure this out lately, and I think I’ve found a good solution. It works without numpy, so there is no need to transfer your tensors to the CPU (which is really slow while training models). Check this out:
>>> import torch
>>> a = torch.tensor([1, 2, 3, 6]).cuda()
>>> b = torch.tensor([0, 2, 3, 7]).cuda()
>>> intersection = (a * (a == b).float()).nonzero().flatten()
>>> intersection
tensor([1, 2], device='cuda:0')
alokpathy
(Alok Tripathy)
April 9, 2020, 6:20pm
#5
I think this only works if the matching elements are at the same indices. If I rearrange a, I’m getting
>>> import torch
>>> a = torch.tensor([6, 1, 2, 3]).cuda()
>>> b = torch.tensor([0, 2, 3, 7]).cuda()
>>> intersection = (a * (a == b).float()).nonzero().flatten()
>>> intersection
tensor([], device='cuda:0')
I’ve had some success with the second answer on this SO post: https://stackoverflow.com/questions/55110047/finding-non-intersection-of-two-pytorch-tensors . The repeat call isn’t ideal memory-wise, but it works at least.
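For illustration, here is a rough sketch of a position-independent comparison in that spirit. It is not the code from the linked answer: it swaps the explicit repeat call for broadcasting, and `intersect_broadcast` is a hypothetical helper name. The memory concern still applies, since the comparison matrix has len(a) × len(b) entries.

```python
import torch

def intersect_broadcast(a, b):
    """Return the elements of `a` that also occur anywhere in `b` (hypothetical helper)."""
    # Compare every element of a with every element of b: shape (len(a), len(b)).
    matches = a.unsqueeze(1) == b.unsqueeze(0)
    # Keep the elements of a that match at least one element of b.
    return a[matches.any(dim=1)]

a = torch.tensor([6, 1, 2, 3])
b = torch.tensor([0, 2, 3, 7])
result = intersect_broadcast(a, b)
```

Unlike np.intersect1d, this keeps the order of a and does not deduplicate repeated values in a.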
alokpathy
(Alok Tripathy)
April 9, 2020, 6:30pm
#6
You can also modify the first answer like so to get the intersection:
>>> t1 = torch.tensor([1, 9, 12, 5, 24], device = 'cuda')
>>> t2 = torch.tensor([1, 24], device = 'cuda')
>>> indices = torch.zeros_like(t1, dtype = torch.uint8, device = 'cuda')
>>> for elem in t2:
...     indices = indices | (t1 == elem)
...
>>> intersection = t1[indices]
Deeply
(Deeply)
February 17, 2021, 2:37pm
#8
Not sure this would help, as the code loops over every element of t2 in a Python for-loop and so would not benefit from the GPU. In fact, numpy's intersect1d is much faster.
import time
import numpy as np
import torch

def tensor_intersect(t1, t2):
    t1 = t1.cuda()
    t2 = t2.cuda()
    indices = torch.zeros_like(t1, dtype=torch.bool, device='cuda')
    # One comparison over all of t1 per element of t2, driven from Python
    for elem in t2:
        indices = indices | (t1 == elem)
    intersection = t1[indices]
    return intersection

t1 = np.random.randint(1, 1e9, 10000)
t2 = np.random.randint(1, 1e9, 10000)

tic = time.time()
np.intersect1d(t1, t2)
print(time.time() - tic)
# 0.0009970664978027344

tic = time.time()
tensor_intersect(torch.tensor(t1), torch.tensor(t2))
print(time.time() - tic)
# 1.426218032836914
NB: in the snippet from post #6, indices should be created with dtype torch.bool instead of torch.uint8:
indices = torch.zeros_like(t1, dtype = torch.bool, device = 'cuda')
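One way to drop the Python loop entirely is `torch.isin`, which builds the same membership mask in a single call. This is only a sketch, assuming a PyTorch version that provides `torch.isin` (it is not available in older releases); `intersect_isin` is a hypothetical helper name, and the example runs on the CPU so it can be tried without a GPU.

```python
import torch

def intersect_isin(t1, t2):
    """Return the elements of t1 that also occur in t2 (hypothetical helper)."""
    # mask[i] is True when t1[i] appears anywhere in t2; no Python-level loop.
    mask = torch.isin(t1, t2)
    return t1[mask]

t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])
result = intersect_isin(t1, t2)
```

The same function should work on CUDA tensors as long as both inputs are on the same device.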
Deeply
(Deeply)
February 17, 2021, 6:03pm
#9
Here’s a tweak that’s 15 to 20 times faster than numpy intersect (for large sets). For small sets, it is better to work with numpy, until someone writes a better torch algorithm based on search:
import time
import numpy as np
import torch

t1 = torch.randint(0, 99999, (1000000,))
t2 = torch.randint(0, 99999, (100000000,))

def torch_intersect(t1, t2, use_unique=False):  # use_unique is currently unused
    t1 = t1.cuda()
    t2 = t2.cuda()
    # Deduplicate on the GPU first, so numpy only sees the unique values
    t1 = t1.unique()
    t2 = t2.unique()
    return torch.tensor(np.intersect1d(t1.cpu().numpy(), t2.cpu().numpy()))

def torch_intersect2(t1, t2, use_unique=False):  # use_unique is currently unused
    t1 = t1.cuda()
    t2 = t2.cuda()
    t1 = t1.unique()
    t2 = t2.unique()
    # Same idea, but intersect Python sets instead of calling np.intersect1d
    t1 = set(t1.cpu().numpy())
    t2 = set(t2.cpu().numpy())
    return t1.intersection(t2)

tic = time.time()
res = np.intersect1d(t1.cpu().numpy(), t2.cpu().numpy())
print('Numpy Intersect Time:', time.time() - tic)

tic = time.time()
res = torch_intersect(t1, t2)
print('Tensor Intersect Time:', time.time() - tic)

# Output:
# Numpy Intersect Time: 6.87971830368042
# Tensor Intersect Time: 0.3944394588470459
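A side note on measuring GPU code like this: CUDA kernels are launched asynchronously, so wall-clock timings are generally more trustworthy when `torch.cuda.synchronize()` is called before starting and before stopping the clock. The `.cpu()` calls above already force a synchronization, so the numbers reported here are not necessarily affected. The sketch below uses a hypothetical `timed` helper and falls back to the CPU when no GPU is present.

```python
import time
import torch

def timed(fn, *args):
    """Run fn(*args) and return (result, elapsed seconds). Hypothetical helper."""
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.synchronize()  # finish any pending GPU work before starting the clock
    tic = time.perf_counter()
    result = fn(*args)
    if use_cuda:
        torch.cuda.synchronize()  # wait for the kernels launched by fn to complete
    return result, time.perf_counter() - tic

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randint(0, 99999, (100000,), device=device)
unique_x, elapsed = timed(torch.unique, x)
```

The first CUDA call in a process also pays a one-time initialization cost, so a warm-up run before timing is usually worthwhile.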
Nifty solution I worked out today:
import torch
a = torch.tensor([1, 2, 3, 6]).cuda()
b = torch.tensor([0, 2, 3, 4]).cuda()
a_cat_b, counts = torch.cat([a, b]).unique(return_counts=True)
intersection = a_cat_b[torch.where(counts.gt(1))]
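a and b can be of different lengths, and common values can be at different positions. Note that this assumes neither a nor b contains repeated values, since a value duplicated within one tensor would also get a count greater than 1.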
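The counting idea can be extended to inputs with repeated values, and to union and set difference, by deduplicating each tensor first. The following is a sketch under that assumption, not part of the original post; `set_ops` is a hypothetical helper name, and the example runs on the CPU so it can be tried without a GPU.

```python
import torch

def set_ops(a, b):
    """Return (intersection, union, a minus b) as sorted tensors. Hypothetical helper."""
    # Deduplicate first so repeats inside a or b cannot inflate the counts.
    a_u, b_u = a.unique(), b.unique()

    values, counts = torch.cat([a_u, b_u]).unique(return_counts=True)
    intersection = values[counts > 1]  # present in both a and b
    union = values                     # present in a or b

    # Count b twice: a value with count 1 can only have come from a.
    values2, counts2 = torch.cat([a_u, b_u, b_u]).unique(return_counts=True)
    difference = values2[counts2 == 1]
    return intersection, union, difference

a = torch.tensor([1, 2, 2, 3, 6])
b = torch.tensor([0, 2, 3, 4])
intersection, union, difference = set_ops(a, b)
```

Everything here stays on whatever device a and b are on, at the cost of a few extra unique calls.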