Is there a low-level PyTorch function to efficiently combine two tensors using a 0-1 mask?

For example, tensors A and B have the same shape, and m is a 0-1 mask tensor of the same shape.
I want C = A * m + B * (1 - m), but implemented efficiently in low-level C++ code.
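For reference, here is a minimal sketch of the expression from the question, using small hypothetical tensors chosen only for illustration. It runs, but it builds several temporaries (A * m, 1 - m, B * (1 - m), and the sum) before producing C.

```python
import torch

# Hypothetical example tensors; any matching shapes would do.
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[10.0, 20.0], [30.0, 40.0]])
m = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # 0-1 mask stored as a float tensor

# Arithmetic blend: take A where m == 1 and B where m == 0.
C = A * m + B * (1 - m)
print(C)  # [[1., 20.], [30., 4.]]
```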

Thank you!

It should work the way you wrote it, provided the sizes are the same.

Thanks! I’m looking for a faster low-level operation.

I think NumPy is as fast as any low-level language (and so is PyTorch).

Hello,
You’re indeed right that two multiplies, one add (plus one subtraction for 1 - m) and the intermediate memory allocations are completely unnecessary (and look horrible, too). One could instead write a CUDA kernel in which each thread just looks at the boolean m[thread_idx] and picks either A[thread_idx] or B[thread_idx], which should be much faster. Fortunately, there’s already a native torch function that does exactly that: torch.where (see the torch.where page of the PyTorch 1.10.0 documentation).
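A minimal sketch of the same blend with torch.where, reusing hypothetical example tensors. torch.where expects a boolean condition, so the 0-1 float mask is converted with .bool() first; if your mask already comes from a comparison such as x > 0, it is boolean and no conversion is needed.

```python
import torch

# Hypothetical example tensors for illustration.
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[10.0, 20.0], [30.0, 40.0]])
m = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # 0-1 float mask

# torch.where(condition, input, other) takes input where condition is True
# and other elsewhere, as a single elementwise selection.
C = torch.where(m.bool(), A, B)

# Same result as the arithmetic blend from the question.
assert torch.equal(C, A * m + B * (1 - m))
```

Because it selects rather than multiplies, torch.where should avoid the extra temporaries of the arithmetic form (apart from the small boolean mask created by .bool()).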

Another similar function is torch.lerp, if your mask is a float tensor.
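A minimal sketch with torch.lerp, again on hypothetical example tensors. torch.lerp(input, end, weight) computes input + weight * (end - input), so passing B as input, A as end and m as weight gives B + m * (A - B) = A * m + B * (1 - m). Note the argument order: the tensor you want where m == 0 comes first. The weight must have a floating-point dtype matching the inputs.

```python
import torch

# Hypothetical example tensors for illustration.
A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[10.0, 20.0], [30.0, 40.0]])
m = torch.tensor([[1.0, 0.0], [0.0, 1.0]])  # float mask, same dtype as A and B

# lerp(B, A, m) = B + m * (A - B): A where m == 1, B where m == 0.
C = torch.lerp(B, A, m)

# Unlike torch.where, a fractional weight blends instead of selecting:
# 0.25 here gives 25% of A and 75% of B at every element.
soft = torch.lerp(B, A, torch.full_like(m, 0.25))
```

This is why lerp suits float masks: it also handles soft masks with values between 0 and 1, which torch.where cannot express.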

Thanks, that’s exactly it! This function seems to have been introduced in v0.4; how did I miss it? :)

Thanks! This is also a good solution, but I can only mark one as solution.

Hi Zfzhang!

Please note (not necessarily relevant to your use case) that
neither your original mask expression nor the more-or-less
equivalent torch.lerp() protects against singularities in the
values (or in the gradients).

And torch.where() – perhaps surprisingly – doesn’t protect
against singularities in the gradients.

Consider:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> x = torch.arange (-0.9, 1.0, 0.2, requires_grad = True)
>>> m = (x > 0).float()
>>> y = m * torch.sqrt (x) + (1 - m) * torch.sqrt (-x)
>>> y
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       grad_fn=<AddBackward0>)
>>> y.sum().backward()
>>> x.grad
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
>>> x = torch.arange (-0.9, 1.0, 0.2, requires_grad = True)
>>> y = torch.where (x > 0, torch.sqrt (x), torch.sqrt (-x))
>>> y
tensor([0.9486832619, 0.8366600275, 0.7071067691, 0.5477225780, 0.3162277639,
        0.3162277639, 0.5477225780, 0.7071067691, 0.8366600275, 0.9486832619],
       grad_fn=<SWhereBackward>)
>>> y.sum().backward()
>>> x.grad
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
>>> x = torch.arange (-0.9, 1.0, 0.2, requires_grad = True)
>>> y = torch.sqrt (x.abs())
>>> y
tensor([0.9486832619, 0.8366600275, 0.7071067691, 0.5477225780, 0.3162277639,
        0.3162277639, 0.5477225780, 0.7071067691, 0.8366600275, 0.9486832619],
       grad_fn=<SqrtBackward>)
>>> y.sum().backward()
>>> x.grad
tensor([-0.5270463228, -0.5976142883, -0.7071067691, -0.9128708839,
        -1.5811388493,  1.5811388493,  0.9128708839,  0.7071067691,
         0.5976142883,  0.5270463228])
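Why does torch.where still give NaN gradients? Both sqrt branches are evaluated on every element, so the unselected branch still computes sqrt of negative numbers. torch.where sends a zero gradient to that branch, but zero times a NaN local derivative is still NaN. The sqrt(x.abs()) rewrite avoids this by having only one branch. When two branches are genuinely needed, a commonly used workaround is to apply torch.where twice: first to replace the inputs each branch must not see with a harmless placeholder, then to select. A minimal sketch, assuming 1.0 is a safe placeholder for both branches (an arbitrary choice you would adapt to your own functions):

```python
import torch

x = torch.arange(-0.9, 1.0, 0.2, requires_grad=True)
pos = x > 0
ones = torch.ones_like(x)

# Step 1: sanitize the inputs so neither sqrt ever sees a negative value.
# The placeholder 1.0 is an assumption; any value safe for both branches works.
safe_pos = torch.where(pos, x, ones)
safe_neg = torch.where(pos, ones, -x)

# Step 2: select the branch that is actually wanted.
y = torch.where(pos, torch.sqrt(safe_pos), torch.sqrt(safe_neg))

y.sum().backward()
print(x.grad)  # finite; same values as the sqrt(x.abs()) gradient above
```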

Best.

K. Frank

Thanks! This is what I expected.