How to find the index of values in a given tensor

a = torch.IntTensor([1,3,2,1,4,2])
b=[2,1,6]
I want to find the indices in a of every value listed in b, with the resulting indices sorted.
The expected output is tensor([0, 2, 3, 5]).

I know how to do it for each value separately:
torch.nonzero(a == 1).squeeze_(1) # tensor([0, 3])
torch.nonzero(a == 2).squeeze_(1) # tensor([2, 5])
torch.nonzero(a == 6).squeeze_(1) # tensor([], dtype=torch.int64)

But how can I do it for all values at once, or in a better way?
Thanks.

According to https://stackoverflow.com/questions/64300830/pytorch-tensor-get-the-index-of-the-element-with-specific-values, this is a way to do it:
import torch

a = torch.IntTensor([1,3,2,1,4,2])
b = [2,1,6]
# Start from an all-False mask and OR in the matches for each value in b
mask = torch.zeros(a.shape, dtype=torch.bool)
for e in b:
    mask = mask | (a == e)

torch.nonzero(mask).squeeze_(1) # tensor([0, 2, 3, 5])
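
If you want to avoid the Python loop, here is a minimal sketch of two loop-free variants. It assumes a is 1-D and that b fits in the dtype of a. The first variant relies on torch.isin, which I believe is only available from PyTorch 1.10 onward; the second uses plain broadcasting and should work on older versions too. The helper names indices_of_values and indices_of_values_broadcast are made up for this example and are not part of PyTorch.

```python
import torch


def indices_of_values(a, values):
    """Sorted indices of the entries of 1-D tensor `a` that appear in `values`.

    Hypothetical helper; assumes torch.isin is available (PyTorch >= 1.10).
    """
    # Match dtype and device of `a` so the comparison needs no implicit casts.
    values_t = torch.as_tensor(values, dtype=a.dtype, device=a.device)
    mask = torch.isin(a, values_t)  # boolean mask, same shape as `a`
    # nonzero on a 1-D mask returns an (n, 1) tensor in ascending index order.
    return torch.nonzero(mask).squeeze(1)


def indices_of_values_broadcast(a, values):
    """Same result via broadcasting, for versions without torch.isin."""
    values_t = torch.as_tensor(values, dtype=a.dtype, device=a.device)
    # Compare every element of `a` with every value: shape (len(a), len(values)).
    matches = a.unsqueeze(1) == values_t.unsqueeze(0)
    # An element is selected if it equals at least one of the values.
    mask = matches.any(dim=1)
    return torch.nonzero(mask).squeeze(1)


a = torch.IntTensor([1, 3, 2, 1, 4, 2])
b = [2, 1, 6]
print(indices_of_values(a, b))            # tensor([0, 2, 3, 5])
print(indices_of_values_broadcast(a, b))  # tensor([0, 2, 3, 5])
```

Note that the broadcasting variant builds a len(a) x len(b) intermediate, so for large inputs the torch.isin version is probably the safer choice memory-wise.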