How to get the indices of the elements of one array in another array?

import torch
a = torch.randint(0, 5, (10,))
b = torch.tensor([3, 4, 1, 2, 0])

For each element of b, I want to find all of its positions in a. Currently I use this:

[a.eq(b[i]).nonzero(as_tuple=True) for i in range(b.shape[0])]

Is there any better way to do this?

You can do something like this.

((a - b.unsqueeze(0).T) == 0).nonzero()

The output format is different, but the idea is the same.

Example with these inputs:

# a = tensor([3, 3, 3, 2, 4, 2, 1, 3, 4, 0])
# b = tensor([3, 4, 1, 2, 0])

Your format:
# Output:
[(tensor([0, 1, 2, 7]),), (tensor([4, 8]),), (tensor([6]),), (tensor([3, 5]),), (tensor([9]),)]

New format:
# Output:
tensor([[0, 0],
        [0, 1],
        [0, 2],
        [0, 7],
        [1, 4],
        [1, 8],
        [2, 6],
        [3, 3],
        [3, 5],
        [4, 9]])

The first column is the index into b and the second column is the index into a.
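A small sketch of converting between the two formats, using the example tensors above. It builds the mask as `b.unsqueeze(1) == a`, which broadcasts to the same len(b) x len(a) boolean matrix as the subtraction version but skips the integer intermediate. It then regroups the pair list per element of b with `torch.bincount` and `torch.split`. `pairs_to_lists` is a hypothetical helper name, not something from the thread.

```python
import torch

a = torch.tensor([3, 3, 3, 2, 4, 2, 1, 3, 4, 0])
b = torch.tensor([3, 4, 1, 2, 0])

# Broadcast (len(b), 1) against (len(a),) -> boolean mask of shape (len(b), len(a)).
mask = b.unsqueeze(1) == a
pairs = mask.nonzero()  # rows are [index into b, index into a], in row-major order

def pairs_to_lists(pairs, num_b):
    # Hypothetical helper: regroup the a-indices by b-index.
    # nonzero() returns rows sorted by b-index, so each group is a contiguous run.
    counts = torch.bincount(pairs[:, 0], minlength=num_b)
    return list(torch.split(pairs[:, 1], counts.tolist()))

per_b = pairs_to_lists(pairs, len(b))
```

Each entry of `per_b` holds the same indices as the corresponding tensor in the list-comprehension format, without the Python loop over b.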


Is there a way to do mean pooling over the groups given by the index into b? For example, with t being the two-column output above, I want a group-by mean like this:

[t[t[:, 0] == i][:, 1].float().mean() for i in range(b.shape[0])]

I got it; you can do the following:

index = ((a - b.unsqueeze(0).T) == 0).float()  # (len(b), len(a)) matrix of 0/1 matches
r = index @ a.unsqueeze(dim=1).float() / index.sum(dim=1, keepdim=True)  # mean of the matching values of a for each b[i]
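The matrix product above needs a dense float matrix of size len(b) x len(a). As a swapped-in alternative (scatter-style ops, not taken from the thread), the same group-by mean can be computed directly from the pair list t with `Tensor.index_add_` and `torch.bincount`, using memory proportional to the number of matches. `group_mean` is a hypothetical helper; for brevity this sketch still builds t from the dense boolean mask.

```python
import torch

def group_mean(pairs, values, num_groups):
    # Hypothetical helper: mean of values[a_idx] for each b_idx in the (b_idx, a_idx) pair list.
    groups, positions = pairs[:, 0], pairs[:, 1]
    sums = torch.zeros(num_groups, dtype=values.dtype)
    sums.index_add_(0, groups, values[positions])  # per-group sums
    counts = torch.bincount(groups, minlength=num_groups).to(values.dtype)
    return sums / counts  # a group with no match gives 0/0 = nan, as in the dense version

a = torch.tensor([3, 3, 3, 2, 4, 2, 1, 3, 4, 0])
b = torch.tensor([3, 4, 1, 2, 0])
t = (b.unsqueeze(1) == a).nonzero()

# Mean of the matching values of a (what the matrix-product answer computes).
value_means = group_mean(t, a.float(), len(b))
# Mean of the matching positions in a (what the list comprehension in the question computes).
position_means = group_mean(t, torch.arange(len(a), dtype=torch.float), len(b))
```

Passing a different `values` tensor (for example, per-position features aligned with a) pools those instead.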

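Unfortunately, this approach uses a huge amount of memory, because (a - b.unsqueeze(0).T) materializes a full len(b) x len(a) matrix. It was not suitable for my purpose: a and b each had more than 150,000 elements, which gives 22.5 billion entries, roughly 21 GiB even at one byte per entry for the boolean mask (the int64 result of the subtraction itself is eight times larger).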

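My solution is below. It returns, for each element of a, its index in b, so it assumes the values of a are distinct and each occurs exactly once in b: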

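import torch

def index_of_a_in_b(a, b):
  # positions in b whose values also occur in a
  b_indices = torch.where(torch.isin(b, a))[0]
  b_values = b[b_indices]
  # sort b's matches by value, then read them off in the rank order of a's values
  return b_indices[b_values.argsort()[a.argsort().argsort()]]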
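The function above returns a single position per element, so it relies on the values being unique. If a can contain duplicates, as in the original question, one sort-based alternative (swapped in here, not from the thread) is to sort a once and use `torch.searchsorted` to locate the run of matches for each b[i]. Memory then grows with len(a) + len(b) + the number of matches instead of len(a) x len(b). `match_positions` is a hypothetical name, and the sketch assumes a PyTorch version whose `torch.sort` accepts `stable=True`.

```python
import torch

def match_positions(a, b):
    # Hypothetical helper: all (index into b, index into a) pairs with b[i] == a[j],
    # in the same order as ((a - b.unsqueeze(0).T) == 0).nonzero(), but without the
    # len(b) x len(a) matrix. Duplicates in a and in b are allowed.
    sorted_a, perm = torch.sort(a, stable=True)          # stable: equal values keep their original order
    left = torch.searchsorted(sorted_a, b)               # first slot with sorted_a >= b[i]
    right = torch.searchsorted(sorted_a, b, right=True)  # first slot with sorted_a > b[i]
    counts = right - left                                # number of matches for each b[i]
    b_idx = torch.repeat_interleave(torch.arange(len(b)), counts)
    # Offset of each output row within its own run of matches.
    run_starts = torch.cumsum(counts, 0) - counts
    offsets = torch.arange(int(counts.sum())) - torch.repeat_interleave(run_starts, counts)
    a_idx = perm[torch.repeat_interleave(left, counts) + offsets]
    return torch.stack([b_idx, a_idx], dim=1)
```

Because the sort is stable, the a-indices inside each group come out in ascending order, so the result lines up row for row with the `nonzero()` output of the broadcasting answer.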