For example:
tensors A and B have the same shape, and m is a 0-1 tensor mask of the same shape.
I want C = A * m + B * (1 - m), but implemented efficiently in low-level C++ code.
Thank you!
It should work the way you wrote it, given the sizes are the same.
Thanks! I’m looking for a faster low-level operation.
I think NumPy is as fast as any low-level language (PyTorch as well).
Hello,
You’re indeed right that the 2 multiplies + 1 add (+ 1 subtract for 1 - m), plus the intermediate memory allocations, are completely unnecessary (and also look horrible), when one could just use a CUDA kernel in which each thread looks at the boolean in m[thread_idx] and chooses either A[thread_idx] or B[thread_idx], which should be much faster. Fortunately, there’s already a native torch function that does just that: torch.where
(torch.where — PyTorch 1.10.0 documentation)
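For the original question, that looks like this (a minimal sketch; torch.where expects a boolean condition, so a 0-1 float mask needs a .bool() conversion first):

```python
import torch

# The masked blend C = A * m + B * (1 - m), expressed as a single
# fused selection with torch.where instead of arithmetic on the mask.
A = torch.tensor([1.0, 2.0, 3.0, 4.0])
B = torch.tensor([10.0, 20.0, 30.0, 40.0])
m = torch.tensor([1.0, 0.0, 1.0, 0.0])  # 0-1 mask

# torch.where takes a boolean condition, so convert the 0-1 mask.
C = torch.where(m.bool(), A, B)

# Equivalent to the arithmetic version, without the extra
# intermediate tensors for A * m, 1 - m, and B * (1 - m).
assert torch.equal(C, A * m + B * (1 - m))
```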
Another similar function is torch.lerp, if your mask is a float tensor.
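Concretely, torch.lerp(input, end, weight) computes input + weight * (end - input), so with the argument order torch.lerp(B, A, m) it reproduces the blend from the question (a sketch; note that unlike torch.where, this also works with fractional mask values):

```python
import torch

# torch.lerp(B, A, m) = B + m * (A - B) = A * m + B * (1 - m)
A = torch.tensor([1.0, 2.0, 3.0, 4.0])
B = torch.tensor([10.0, 20.0, 30.0, 40.0])
m = torch.tensor([1.0, 0.0, 0.5, 0.0])  # may also be fractional

C = torch.lerp(B, A, m)
assert torch.allclose(C, A * m + B * (1 - m))
```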
Thanks, that’s precisely it! This function seems to have been introduced back in v0.4; how could I have missed it?
Thanks! This is also a good solution, but I can only mark one as solution.
Hi Zfzhang!
Please note – not necessarily relevant to your use case – that neither your original mask expression nor the more-or-less equivalent torch.lerp() protects against singularities in the values (nor in the gradients). And torch.where() – perhaps surprisingly – doesn’t protect against singularities in the gradients.
Consider:
>>> import torch
>>> torch.__version__
'1.9.0'
>>> x = torch.arange (-0.9, 1.0, 0.2, requires_grad = True)
>>> m = (x > 0).float()
>>> y = m * torch.sqrt (x) + (1 - m) * torch.sqrt (-x)
>>> y
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
grad_fn=<AddBackward0>)
>>> y.sum().backward()
>>> x.grad
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
>>> x = torch.arange (-0.9, 1.0, 0.2, requires_grad = True)
>>> y = torch.where (x > 0, torch.sqrt (x), torch.sqrt (-x))
>>> y
tensor([0.9486832619, 0.8366600275, 0.7071067691, 0.5477225780, 0.3162277639,
0.3162277639, 0.5477225780, 0.7071067691, 0.8366600275, 0.9486832619],
grad_fn=<SWhereBackward>)
>>> y.sum().backward()
>>> x.grad
tensor([nan, nan, nan, nan, nan, nan, nan, nan, nan, nan])
>>> x = torch.arange (-0.9, 1.0, 0.2, requires_grad = True)
>>> y = torch.sqrt (x.abs())
>>> y
tensor([0.9486832619, 0.8366600275, 0.7071067691, 0.5477225780, 0.3162277639,
0.3162277639, 0.5477225780, 0.7071067691, 0.8366600275, 0.9486832619],
grad_fn=<SqrtBackward>)
>>> y.sum().backward()
>>> x.grad
tensor([-0.5270463228, -0.5976142883, -0.7071067691, -0.9128708839,
-1.5811388493, 1.5811388493, 0.9128708839, 0.7071067691,
0.5976142883, 0.5270463228])
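When the two branches can’t be merged into a single safe expression like sqrt (x.abs()), a common workaround is to clamp each branch’s input so that the unselected branch is finite everywhere – then the zero the backward of torch.where() multiplies in stays zero instead of becoming 0 * nan. A minimal sketch (the eps value here is an arbitrary small constant, not from the examples above):

```python
import torch

eps = 1e-12  # arbitrary small constant to keep sqrt away from 0
x = torch.arange(-0.9, 1.0, 0.2, requires_grad=True)

# Both branches are finite (values and gradients) for every x,
# so the unselected branch can no longer poison the gradient.
safe_pos = torch.sqrt(x.clamp(min=0.0) + eps)
safe_neg = torch.sqrt((-x).clamp(min=0.0) + eps)
y = torch.where(x > 0, safe_pos, safe_neg)

y.sum().backward()
assert torch.isfinite(x.grad).all()
```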
Best.
K. Frank
Thanks! This is what I expected.