Alternatives to torch.isin()

Hi, I am looking for an alternative to torch.isin(). I'm using the code snippet below to compare instances in two tensors (both 1D, but of different lengths) and to mask out the elements that are not common to both, so that each tensor keeps only the shared instances. This works on my local machine, but fails on the Slurm cluster at my university.

        valid_instances = np.intersect1d(instance_idx.cpu().numpy(), instance_idx_pair.cpu().numpy())
        valid_instances = torch.tensor(valid_instances, device=device)
        # intersect masks: remove instances that are present in A but not in B, and vice versa
        intersect_mask = torch.isin(instance_idx, valid_instances)
        intersect_mask_pair = torch.isin(instance_idx_pair, valid_instances)
        instance_idx = instance_idx[intersect_mask]
        instance_idx_pair = instance_idx_pair[intersect_mask_pair]

I am guessing that a PyTorch version mismatch between the two machines could be the issue. Is there an alternative operator to torch.isin()?

For the moment, I have reverted to using a for loop for this task:

        intersect_mask = torch.tensor([idx in valid_instances for idx in instance_idx], device=device, dtype=torch.bool)
        intersect_mask_pair = torch.tensor([idx in valid_instances for idx in instance_idx_pair], device=device, dtype=torch.bool)

It works, but it is too slow. Could someone suggest a faster or more elegant way to do the same thing?

Hi Advait!

You can avoid the explicit Python (list-comprehension) loops at the cost
of materializing what could be a large intermediate tensor (one entry for
every pair of elements being compared). But this still uses an inefficient
brute-force algorithm, so figuring out how to get torch.isin() working
would really be the right way to go.
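
As a first step towards getting torch.isin() working, it may help to confirm
what the cluster actually has installed. Here is a minimal sketch of a feature
test. The variable name has_isin is just for illustration, and I believe (but
have not checked against your setup) that torch.isin() first appeared in
PyTorch 1.10, so an older install would simply not have the attribute:

```python
import torch

# torch.isin() is missing on older builds, where calling it raises AttributeError.
# has_isin is an illustrative name, not part of PyTorch.
has_isin = hasattr(torch, "isin")

print("PyTorch version:", torch.__version__)
print("torch.isin available:", has_isin)
```

Running this both locally and inside a job on the cluster should show whether
the two environments differ.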

I haven’t waded into the rest of your code, but here is a loop-free PyTorch
replacement for isin():

>>> import torch
>>> print (torch.__version__)
>>> _ = torch.manual_seed (2023)
>>> e = torch.randint (128, (64,))
>>> t = torch.randint (128, (128,))
>>> isin = torch.isin (e, t)
>>> isinB = (e.unsqueeze (1) == t).sum (dim = 1).bool()
>>> torch.equal (isin, isinB)
True
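
If the pairwise comparison above turns out to be too large for your tensors,
a sort-plus-binary-search version avoids the big intermediate. This is only a
sketch under some assumptions: isin_sorted() is a hypothetical helper name,
both tensors are 1D with the same dtype, and torch.searchsorted() is available
in your install. The small tensors a and b are made-up stand-ins for
instance_idx and instance_idx_pair:

```python
import torch

def isin_sorted(elements, test_elements):
    """Hypothetical helper: 1D membership test via sort + binary search."""
    if test_elements.numel() == 0:
        # Nothing to match against, so no element can be a member.
        return torch.zeros_like(elements, dtype=torch.bool)
    sorted_test, _ = torch.sort(test_elements)
    # For each element, index of the first entry in sorted_test that is >= it.
    pos = torch.searchsorted(sorted_test, elements)
    # Elements larger than every test value get index numel(); clamp into range.
    pos = pos.clamp(max=sorted_test.numel() - 1)
    # An element is a member exactly when the entry found equals it.
    return sorted_test[pos] == elements

# Made-up stand-ins for instance_idx and instance_idx_pair.
a = torch.tensor([3, 7, 1, 9, 4])
b = torch.tensor([4, 8, 3, 3])

mask_a = isin_sorted(a, b)   # which entries of a also occur in b
mask_b = isin_sorted(b, a)   # which entries of b also occur in a

print(a[mask_a])  # tensor([3, 4])
print(b[mask_b])  # tensor([4, 3, 3])
```

Note that testing each tensor directly against the other gives the same masks
as testing against their intersection, so this route would not need the
np.intersect1d() round trip through NumPy.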


K. Frank