Intersection between 2D tensors

Hi there, is there a way in PyTorch to calculate the intersection of two 2D tensors? The post here, Intersection between to vectors/tensors, only covers 1D tensors. Many thanks!
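If "intersection" means the same thing as in the 1D case, that is, the values that appear in both tensors regardless of their position, one possible sketch uses `torch.isin` (assumed available, PyTorch 1.10 or newer) and then `torch.unique`. The example tensors are the ones used later in this thread. A broadcasting fallback for older versions is included. Both are illustrations, not the method from the linked post.

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
b = torch.tensor([[1., 2., 9.], [8., 7., 6.]])

# Boolean mask, same shape as `a`: True wherever an element of `a`
# appears anywhere in `b` (position does not matter).
in_b = torch.isin(a, b)

# Fallback without torch.isin: compare every element of `a` with every
# element of `b` via broadcasting, then reduce over `b`'s elements.
in_b_fallback = (a.reshape(-1, 1) == b.reshape(1, -1)).any(dim=1).reshape(a.shape)

# Keep the shared values and drop duplicates (torch.unique sorts by default).
common = torch.unique(a[in_b])
print(common)
```

The broadcasting fallback builds an `a.numel() x b.numel()` comparison matrix, so it is only reasonable for small tensors.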

I think the link answers the 2D case as well. Does this answer your question?
>>> import torch
>>> a = torch.FloatTensor([[1,2,3],[4,5,6]])
>>> b = torch.FloatTensor([[1,2,9],[8,7,6]])
>>> (a == b).nonzero()
tensor([[0, 0],
        [0, 1],
        [1, 2]])
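`nonzero()` on the equality mask returns one `(row, col)` pair per matching position. A small sketch (using `torch.tensor` instead of the legacy `torch.FloatTensor` constructor) of how those coordinates can be turned back into the matching values, using `nonzero(as_tuple=True)` for advanced indexing:

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
b = torch.tensor([[1., 2., 9.], [8., 7., 6.]])

# (row, col) coordinates where both tensors hold the same value,
# as an (num_matches, 2) tensor.
idx = (a == b).nonzero()

# The same coordinates split into one index tensor per dimension.
rows, cols = (a == b).nonzero(as_tuple=True)

# Advanced indexing with those index tensors picks out the matching values.
values = a[rows, cols]

for (r, c), v in zip(idx.tolist(), values.tolist()):
    print(f"a[{r}][{c}] == b[{r}][{c}] == {v}")
```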

Thanks a lot. But it only seems to return the positions of the matching elements. Can I get a tensor holding the intersecting values instead?

(1) (a==b)

tensor([[ True,  True, False],
        [False, False,  True]])

(2) c = torch.FloatTensor([[1,2,0],[0,0,6]])

Can I get the intersection as a FloatTensor of values (c in case (2)) rather than the boolean tensor in case (1)?

Thanks!

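Try (a == b) * a. This keeps each value where the two tensors match and puts zero everywhere else, which gives c from case (2). If you don't want the zeros, index with the mask instead: a[a == b] returns only the matching values as a 1D tensor. (result.nonzero() would give the positions of the non-zero entries, not the values.)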
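A minimal sketch of the masking approach, reproducing c from the question. It also shows `torch.where` as an explicit equivalent of the multiplication and boolean indexing `a[a == b]` for getting only the matching values:

```python
import torch

a = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
b = torch.tensor([[1., 2., 9.], [8., 7., 6.]])
mask = a == b  # bool tensor, True where the values match

# Multiplying by the bool mask promotes it to a's dtype (True -> 1, False -> 0),
# so matching positions keep their value and the rest become 0.
c = mask * a

# Same result, stated explicitly: take `a` where the mask is True, else 0.
c_where = torch.where(mask, a, torch.zeros_like(a))

# Boolean indexing drops the non-matching positions entirely (1D result).
matches = a[mask]

print(c)
print(matches)
```

One caveat with the zero-filled form: a position where both tensors hold 0 looks the same as a non-match, so keep `mask` around (or use `a[mask]`) if that case matters.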
