# Is there a Pytorch low-level function to efficiently combine two tensors using a 0-1 mask?

For example,
tensors A & B are of the same shape, m is a 0-1 tensor mask of the same shape,
I want `C = A * m + B * (1-m)`, but implemented efficiently in low-level C++ code.

Thank you!

It should work the way you wrote it, given that the sizes are the same.

Thanks! I’m looking for a faster low-level operation.

I think NumPy is as fast as any low-level language (PyTorch as well).

Hello,
You’re indeed right that the 2 multiplies + 1 add (+ 1 subtract for `1-m`), plus the intermediate memory allocations, are completely unnecessary (and also look horrible). A single CUDA kernel could instead have each thread look at the boolean in `m[thread_idx]` and choose either `A[thread_idx]` or `B[thread_idx]`, which should be much faster. Fortunately, there’s already a native torch function that does just that: `torch.where` (torch.where — PyTorch 1.10.0 documentation).
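For instance, a small sketch (the tensors `A`, `B`, and `m` here are illustrative, not from the original question) showing that `torch.where` on a boolean mask reproduces the arithmetic expression without the intermediate multiplies:

```python
import torch

A = torch.tensor([1.0, 2.0, 3.0, 4.0])
B = torch.tensor([10.0, 20.0, 30.0, 40.0])
m = torch.tensor([1.0, 0.0, 1.0, 0.0])  # 0-1 float mask

# Arithmetic version: two multiplies, one subtract, one add,
# each producing a temporary tensor.
C_arith = A * m + B * (1 - m)

# torch.where: a single element-wise select on a bool mask.
C_where = torch.where(m.bool(), A, B)

print(torch.equal(C_arith, C_where))  # True
```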


Another similar function is `torch.lerp`, if your mask is a float tensor.
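As a sketch (illustrative tensors, assuming a float 0-1 mask): `torch.lerp(start, end, weight)` computes `start + weight * (end - start)`, so `torch.lerp(B, A, m)` is exactly `A * m + B * (1 - m)`:

```python
import torch

A = torch.tensor([1.0, 2.0, 3.0, 4.0])
B = torch.tensor([10.0, 20.0, 30.0, 40.0])
m = torch.tensor([1.0, 0.0, 0.0, 1.0])  # float 0-1 mask

# lerp(B, A, m) = B + m * (A - B), i.e. A * m + B * (1 - m) in one call.
C = torch.lerp(B, A, m)

print(torch.allclose(C, A * m + B * (1 - m)))  # True
```

Unlike a boolean select, `lerp` also accepts fractional weights, so the same call blends `A` and `B` smoothly if the mask is not strictly 0-1.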


Thanks, that’s precisely it! This function seems to have been available since v0.4; how could I have missed it?

Thanks! This is also a good solution, but I can only mark one reply as the solution.

Hi Zfzhang!

Please note – not necessarily relevant to your use case – that
neither the multiply-by-mask expression nor the (essentially)
equivalent `torch.lerp()` protects against singularities in the
backward pass.

And `torch.where()` – perhaps surprisingly – doesn’t protect
against such backward-pass singularities, either.
Consider:

```
>>> import torch
>>> torch.__version__
'1.9.0'
>>> x = torch.arange (-0.9, 1.0, 0.2, requires_grad = True)
>>> m = (x > 0).float()
>>> y = m * torch.sqrt (x) + (1 - m) * torch.sqrt (-x)
>>> y
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       grad_fn=<AddBackward0>)
>>> y.sum().backward()
>>> x.grad
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
>>> x = torch.arange (-0.9, 1.0, 0.2, requires_grad = True)
>>> y = torch.where (x > 0, torch.sqrt (x), torch.sqrt (-x))
>>> y
tensor([0.9486832619, 0.8366600275, 0.7071067691, 0.5477225780, 0.3162277639,
        0.3162277639, 0.5477225780, 0.7071067691, 0.8366600275, 0.9486832619],
       grad_fn=<SWhereBackward>)
>>> y.sum().backward()
>>> x.grad
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
>>> x = torch.arange (-0.9, 1.0, 0.2, requires_grad = True)
>>> y = torch.sqrt (x.abs())
>>> y
tensor([0.9486832619, 0.8366600275, 0.7071067691, 0.5477225780, 0.3162277639,
        0.3162277639, 0.5477225780, 0.7071067691, 0.8366600275, 0.9486832619],
       grad_fn=<SqrtBackward>)
```
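Building on the `torch.sqrt (x.abs())` rewrite above, a quick sketch (illustrative, using the same inputs) confirms that the `torch.where` version leaks nans into the gradient while the rewritten version stays finite:

```python
import torch

# torch.where version: forward pass is fine, but the nan gradient of the
# untaken sqrt branch still poisons x.grad (nan * 0 == nan in autograd).
x = torch.arange(-0.9, 1.0, 0.2, requires_grad=True)
y_where = torch.where(x > 0, torch.sqrt(x), torch.sqrt(-x))
y_where.sum().backward()
print(torch.isnan(x.grad).all().item())  # True

# Rewriting so the singular branch is never evaluated keeps grads finite
# (for these nonzero inputs; sqrt(|x|) is still singular at x == 0).
x2 = torch.arange(-0.9, 1.0, 0.2, requires_grad=True)
y_abs = torch.sqrt(x2.abs())
y_abs.sum().backward()
print(torch.isfinite(x2.grad).all().item())  # True
```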