Hi there, is there a way in PyTorch to calculate the intersection of two 2D tensors? The post here, Intersection between to vectors/tensors, only covers 1D tensors. Many thanks!
I think the approach in the link works for 2D as well. Does this answer your question?
import torch
a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
b = torch.tensor([[1, 2, 9], [8, 7, 6]], dtype=torch.float32)
(a == b).nonzero()  # → tensor([[0, 0], [0, 1], [1, 2]])
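Note that a == b only matches values that sit at the same position in both tensors. If "intersection" is meant in the set sense (values that occur anywhere in both tensors), a minimal sketch could flatten both tensors and use torch.isin, assuming PyTorch 1.10 or later. The helper set_intersection and the tensor b_shuffled are hypothetical names chosen for illustration.

```python
import torch

def set_intersection(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: unique values present anywhere in both x and y.

    Position is ignored. Assumes torch.isin is available (PyTorch >= 1.10).
    """
    flat = x.flatten()
    # True for every element of x whose value also appears somewhere in y
    mask = torch.isin(flat, y.flatten())
    # torch.unique sorts its output by default
    return torch.unique(flat[mask])

a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
b_shuffled = torch.tensor([[6, 5, 9], [8, 7, 1]], dtype=torch.float32)

print((a == b_shuffled).nonzero())       # no positional matches: shape (0, 2)
print(set_intersection(a, b_shuffled))   # → tensor([1., 5., 6.])
```

Dropping the torch.unique call keeps duplicates and returns the shared values in a's row-major order instead.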
Thanks a lot. But it seems to return only the element positions. Can I get a tensor that holds the intersecting values of the two 2D tensors?
(1) (a==b) gives a boolean mask:
tensor([[ True,  True, False],
        [False, False,  True]])
(2) What I want instead is c = torch.tensor([[1, 2, 0], [0, 0, 6]], dtype=torch.float32), i.e. the matching values with zeros elsewhere.
Can I get the float values of the intersection (c in case (2)) rather than the boolean mask in case (1)?
Thanks!
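Try (a == b) * a. This keeps the values where the tensors intersect and puts zeros everywhere else, which gives exactly c. If you don't want the zeros, use a[a == b], which returns a 1D tensor of only the matching values (result.nonzero() would give you their indices, not the values).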
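A minimal sketch of these variants on the tensors from this thread, using only standard PyTorch calls: the zero-filled 2D result c, an equivalent torch.where form that makes the fill value explicit, the compact 1D form, and how the indices from nonzero() map back to the same values.

```python
import torch

a = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32)
b = torch.tensor([[1, 2, 9], [8, 7, 6]], dtype=torch.float32)

mask = a == b  # bool mask of positions where the values agree

# Keep the 2D shape: matching values stay, everything else becomes 0.
# Multiplying a bool tensor by a float tensor promotes the result to float.
c = mask * a  # tensor([[1., 2., 0.], [0., 0., 6.]])

# Equivalent, but the fill value (here 0) is chosen explicitly.
c_where = torch.where(mask, a, torch.zeros_like(a))

# Drop the non-matching positions instead: boolean indexing returns a 1D
# tensor of just the matching values, in row-major order.
values = a[mask]  # tensor([1., 2., 6.])

# The (row, col) indices from nonzero() select the same values.
idx = mask.nonzero()
from_indices = a[idx[:, 0], idx[:, 1]]
```

Unlike filtering the zero-filled result with something like c[c != 0], a[mask] also keeps genuine matches whose value is 0.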