Intersection between tensors

Given two tensors:

A = torch.tensor([[0, 2],
                  [1, 2],
                  [2, 0],
                  [2, 1],
                  [2, 3],
                  [3, 2],
                  [1, 8],
                  [8, 1]])

B = torch.tensor([[0, 3],
                  [1, 8],
                  [8, 1]])

How do I get the row-wise intersection of the two tensors, so that the result is tensor([[1, 8], [8, 1]])?

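One possible approach, shown with slightly different example tensors: compare every row of A with every row of B by broadcasting, then keep the rows of A whose columns all match some row of B.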

import torch

A = torch.tensor([[0, 2], [1, 2], [2, 0], [2, 1],
                  [2, 8], [3, 2], [1, 8], [8, 2]])

B = torch.tensor([[0, 3], [1, 8], [8, 2]])

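# Broadcast to shape (len(A), len(B), 2): C[i, j, k] is True when A[i, k] == B[j, k]
C = (A[:, None] == B[None, :])
# Rows match only if both columns are equal; nonzero() returns the matching (i, j) index pairs
C1 = (C[..., 0] & C[..., 1]).nonzero()
# Keep the rows of A that appear in B
intersection = A[C1[:, 0], :]
print(intersection)  # rows [1, 8] and [8, 2]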
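A variant of the same broadcasting idea, sketched as a small helper. The name `row_intersection` is hypothetical, not a PyTorch API. Instead of collecting index pairs with `nonzero()`, it reduces the comparison with `all` over the columns and `any` over the rows of B. Under these assumptions it works for any number of columns, keeps the rows of A in their original order, and returns each matching row of A only once even if B contains duplicate rows:

```python
import torch


def row_intersection(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Return the rows of A that also occur as rows of B, in A's order.

    Hypothetical helper: compares every row of A against every row of B
    by broadcasting, so memory grows as len(A) * len(B) * A.shape[1].
    """
    # Shape (len(A), len(B), d): elementwise equality of every A row with every B row
    eq = A[:, None, :] == B[None, :, :]
    # A[i] matches B[j] only if all of their columns are equal
    row_match = eq.all(dim=-1)
    # Keep A[i] if it matches at least one row of B (duplicate B rows do not repeat it)
    mask = row_match.any(dim=1)
    return A[mask]


A = torch.tensor([[0, 2], [1, 2], [2, 0], [2, 1],
                  [2, 3], [3, 2], [1, 8], [8, 1]])
B = torch.tensor([[0, 3], [1, 8], [8, 1]])

print(row_intersection(A, B))  # rows [1, 8] and [8, 1]
```

Both this version and the `nonzero()` version above build a len(A) x len(B) x d boolean tensor, so for very large inputs memory can become the limiting factor.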