Replace values in tensor

Given a before tensor and an after tensor, I want to replace every occurrence of before[i] in another tensor A with after[i], without using loops.


```python
import torch

before = torch.Tensor([2,4,5])
after  = torch.Tensor([20,40,50])
A = torch.Tensor([1,2,3,4,5,6])

# replace() is the function I am looking for
result = replace(A, before, after)
```

The result should be torch.Tensor([1,20,3,40,50,6]).


You can use torch.where to replace the values one pair at a time:

```python
# One torch.where pass per (before, after) pair
for b, a in zip(before, after):
    A = torch.where(A == b, a, A)
```
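For reference, here is the loop as a self-contained script on the example tensors from the question (written with torch.tensor and float values, matching the question's torch.Tensor inputs). It writes into a copy so that the original A stays available; that is a choice made for this sketch, not part of the answer above.

```python
import torch

before = torch.tensor([2., 4., 5.])
after = torch.tensor([20., 40., 50.])
A = torch.tensor([1., 2., 3., 4., 5., 6.])

# One torch.where call per (before, after) pair: each pass replaces every
# occurrence of b with a, so the loop runs len(before) times over all of A.
result = A.clone()
for b, a in zip(before, after):
    result = torch.where(result == b, a, result)

# expected values: [1, 20, 3, 40, 50, 6]
print(result)
```

Because the passes run one after another, an after value that also appears later in before gets replaced a second time (for example, before = [2, 20] with after = [20, 30] turns every 2 into 30).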

Hello @vardan, thank you for the reply. I want to avoid loops: I have to replace a lot of values, so one torch.where pass per value may be inefficient.

If every element of before is guaranteed to be present in A, you can do:

```python
# Row i of the (len(A), len(before)) mask marks which value of before A[i] equals
idxs = torch.nonzero(A.unsqueeze(1) == before, as_tuple=True)[0]
A[idxs] = after
```
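A small sketch that runs this snippet on the question's tensors and prints the matched indices. torch.nonzero returns the matches in A's own order, and `A[idxs] = after` pairs the k-th match with after[k]. The second half shows what that pairing does when before lists the same values in a different order (before_shuffled and after_shuffled are illustrative names, not from the thread).

```python
import torch

before = torch.tensor([2., 4., 5.])
after = torch.tensor([20., 40., 50.])
A = torch.tensor([1., 2., 3., 4., 5., 6.])

# A.unsqueeze(1) has shape (6, 1) and before has shape (3,), so the
# broadcast comparison produces a (6, 3) boolean mask.
mask = A.unsqueeze(1) == before
idxs = torch.nonzero(mask, as_tuple=True)[0]  # row indices, in A's order
print(idxs)  # tensor([1, 3, 4])

out = A.clone()
out[idxs] = after  # k-th match in A receives after[k]

# Same values, different order in before: the rows found are still [1, 3, 4],
# so A[1] (which is 2) receives 50 instead of 20.
before_shuffled = torch.tensor([5., 2., 4.])
after_shuffled = torch.tensor([50., 20., 40.])
idxs_shuffled = torch.nonzero(A.unsqueeze(1) == before_shuffled, as_tuple=True)[0]
wrong = A.clone()
wrong[idxs_shuffled] = after_shuffled
```

The assignment also needs exactly one match per element of before; if a value occurred twice in A, idxs would be longer than after and the assignment would fail with a shape mismatch.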

Yes, all elements in before are present in A. This is a good solution, thank you! However, I am dealing with large 2D tensors, and the unsqueeze-and-compare step builds a boolean mask with A.numel() × len(before) entries, which may be heavy on memory if before is large.
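If memory is the concern, one alternative not proposed in this thread is to sort before once and look up every element of A with torch.searchsorted, which needs only a few temporaries the size of A instead of an A.numel() × len(before) mask. A minimal sketch, assuming the values in before are unique, before is non-empty, and A, before and after share a floating-point dtype; replace_sorted is a hypothetical name.

```python
import torch

def replace_sorted(A, before, after):
    """Hypothetical helper: replace each value of `before` found in A with the
    matching value of `after`, without building a numel(A) x len(before) mask.
    Assumes `before` is non-empty, its values are unique, and all three
    tensors share a dtype."""
    # Sort `before` once and reorder `after` the same way.
    sorted_before, order = torch.sort(before)
    sorted_after = after[order]

    # For each element of A (any shape), find its insertion point in sorted_before.
    pos = torch.searchsorted(sorted_before, A.contiguous())
    # Elements larger than every value land at len(before); clamp to a valid index.
    pos = pos.clamp(max=sorted_before.numel() - 1)

    # An element is a match only if the value at its insertion point equals it.
    found = sorted_before[pos] == A
    return torch.where(found, sorted_after[pos], A)
```

Each lookup is a binary search, so the cost grows with A.numel() × log(len(before)); before does not need to be in any particular order, and A can have any shape, including the large 2D tensors mentioned above.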

There is a minor bug when the values of before do not appear in A in the same order (for example, when the tensors are not sorted). This fixes it:

```python
# idx_0: positions in A that match; idx_1: which element of before each one matches
idx_0, idx_1 = torch.nonzero(A.unsqueeze(1) == before, as_tuple=True)
A[idx_0] = after[idx_1]
```
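The corrected version can be wrapped in the replace(A, before, after) function the question asks for. A sketch, assuming A, before and after share a dtype; it flattens A so it also works on 2D (or any N-D) tensors, and it returns a new tensor instead of modifying A in place.

```python
import torch

def replace(A, before, after):
    """Sketch of the question's replace() built on the corrected nonzero
    approach. Works for A of any shape; A itself is left unchanged."""
    # reshape(-1) copies non-contiguous inputs; clone() guarantees a fresh
    # contiguous 1-D buffer either way.
    flat = A.reshape(-1).clone()
    # idx_0: positions in flat that match; idx_1: which element of before matched.
    idx_0, idx_1 = torch.nonzero(flat.unsqueeze(1) == before, as_tuple=True)
    flat[idx_0] = after[idx_1]
    return flat.view(A.shape)

A = torch.tensor([1., 2., 3., 4., 5., 6.])
result = replace(A, torch.tensor([5., 2., 4.]), torch.tensor([50., 20., 40.]))
# expected values: [1, 20, 3, 40, 50, 6]
```

Because idx_1 records which element of before each position matched, the order of before no longer matters and repeated values in A are all replaced. The comparison mask still has A.numel() × len(before) entries, so this does not by itself reduce memory use.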