For example,

tensors **A** & **B** are of the same shape, and **m** is a 0-1 tensor mask of the same shape.

I want `C = A * m + B * (1-m)`, but implemented efficiently in low-level C++ code.

Thank you!

It should work the way you wrote it, given the sizes are the same.

Thanks! I’m looking for a faster low-level operation.

I think NumPy is as fast as any low-level language (PyTorch as well).

Hello,

You’re indeed right that 2 multiplies + 1 add (+1 sub for `1-m`) + the intermediary memory allocations are completely unnecessary (and also look horrible) when one could just do a CUDA kernel where the thread just looks at the boolean in `m[thread_idx]` and chooses either `A[thread_idx]` or `B[thread_idx]`, which should be much faster. Fortunately, there’s already a native torch function which does just that: `torch.where`

(torch.where — PyTorch 1.10.0 documentation)
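As a quick illustration (a sketch with made-up shapes, not from the original post), `torch.where` with a boolean condition produces the same result as the mask arithmetic in a single element-wise select:

```
import torch

# Hypothetical shapes, just for illustration.
A = torch.randn(3, 4)
B = torch.randn(3, 4)
m = (torch.rand(3, 4) > 0.5).float()  # 0-1 mask, same shape as A and B

C_arith = A * m + B * (1 - m)          # two muls, one add, one sub, plus temporaries
C_where = torch.where(m.bool(), A, B)  # single element-wise select

assert torch.equal(C_arith, C_where)
```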

Another similar function is `torch.lerp` if your mask is a float tensor.
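For example (a small sketch with assumed shapes): since `torch.lerp(start, end, weight)` computes `start + weight * (end - start)`, passing the float mask as the weight reproduces the blend:

```
import torch

# Hypothetical shapes, just for illustration.
A = torch.randn(3, 4)
B = torch.randn(3, 4)
m = (torch.rand(3, 4) > 0.5).float()  # float 0-1 mask

# lerp(B, A, m) = B + m * (A - B) = A * m + B * (1 - m)
C = torch.lerp(B, A, m)

assert torch.allclose(C, A * m + B * (1 - m))
```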

Thanks, that’s precise! This function seems to have been introduced back in v0.4, how could I miss it?

Thanks! This is also a good solution, but I can only mark one as the solution.

Hi Zfzhang!

Please note – not necessarily relevant to your use case – that neither your original mask expression nor the more-or-less equivalent `torch.lerp()` protects against singularities in the values (nor gradients).

And `torch.where()` – perhaps surprisingly – doesn’t protect against singularities in the gradients.

Consider:

```
>>> import torch
>>> torch.__version__
'1.9.0'
>>> x = torch.arange (-0.9, 1.0, 0.2, requires_grad = True)
>>> m = (x > 0).float()
>>> y = m * torch.sqrt (x) + (1 - m) * torch.sqrt (-x)
>>> y
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
       grad_fn=<AddBackward0>)
>>> y.sum().backward()
>>> x.grad
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
>>> x = torch.arange (-0.9, 1.0, 0.2, requires_grad = True)
>>> y = torch.where (x > 0, torch.sqrt (x), torch.sqrt (-x))
>>> y
tensor([0.9486832619, 0.8366600275, 0.7071067691, 0.5477225780, 0.3162277639,
        0.3162277639, 0.5477225780, 0.7071067691, 0.8366600275, 0.9486832619],
       grad_fn=<SWhereBackward>)
>>> y.sum().backward()
>>> x.grad
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
>>> x = torch.arange (-0.9, 1.0, 0.2, requires_grad = True)
>>> y = torch.sqrt (x.abs())
>>> y
tensor([0.9486832619, 0.8366600275, 0.7071067691, 0.5477225780, 0.3162277639,
        0.3162277639, 0.5477225780, 0.7071067691, 0.8366600275, 0.9486832619],
       grad_fn=<SqrtBackward>)
>>> y.sum().backward()
>>> x.grad
tensor([-0.5270463228, -0.5976142883, -0.7071067691, -0.9128708839,
        -1.5811388493,  1.5811388493,  0.9128708839,  0.7071067691,
         0.5976142883,  0.5270463228])
```
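One common workaround (a sketch of my own, not part of the post above) is to make both branch arguments of `torch.where` finite everywhere before the select, e.g. by clamping; the `1e-12` floor is an arbitrary choice:

```
import torch

# Clamp each branch's argument away from the singularity so sqrt and its
# gradient stay finite everywhere; torch.where then selects the correct
# branch, and the unselected branch can no longer feed a nan backward.
x = torch.arange(-0.9, 1.0, 0.2, requires_grad=True)
y = torch.where(x > 0,
                torch.sqrt(x.clamp(min=1e-12)),
                torch.sqrt((-x).clamp(min=1e-12)))
y.sum().backward()

assert torch.isfinite(y).all()
assert torch.isfinite(x.grad).all()  # no nans, unlike the unclamped version
```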

Best.

K. Frank

Thanks! This is what I expected.