I need to compare two tensors element-wise and get a binary mask, e.g.:
c = a > b
c = c.float()
However, this operation is not differentiable. Does anyone know how I can compute the element-wise > and still be able to backprop?
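A minimal sketch of the problem, using small float tensors `a` and `b` made up for illustration: the comparison yields a mask with no connection to the autograd graph, so no gradient can reach `a` through it.

```python
import torch

# Illustrative inputs (assumed); only `a` requires a gradient
a = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
b = torch.tensor([0.0, 0.0, 4.0])

# The comparison returns a bool tensor detached from the autograd graph
c = (a > b).float()

print(c.tolist())       # [1.0, 0.0, 0.0]
print(c.requires_grad)  # False: no grad_fn, so c.sum().backward() would raise an error
```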
torch.sign() works for me.
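A quick check of the `torch.sign()` suggestion, with made-up tensors: mapping `sign(a - b)` to a 0/1 mask keeps a `grad_fn`, so `backward()` runs, but PyTorch defines the derivative of `sign` as zero, so `a` receives only zero gradients. Whether that is acceptable depends on what the mask feeds into.

```python
import torch

# Illustrative inputs (assumed)
a = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
b = torch.tensor([0.0, 0.0, 4.0])

# Map sign(a - b) from {-1, 0, 1} to {0, 0.5, 1} to get a mask-like tensor
c = (torch.sign(a - b) + 1) / 2
print(c.tolist())  # [1.0, 0.0, 0.0]

c.sum().backward()
# The graph stays connected, but the gradient through sign is zero everywhere
print(a.grad.tolist())  # [0.0, 0.0, 0.0]
```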
Hi Xianshun!
The greater-than operator (>) is basically a step function.
A common differentiable approximation to a step function
is a sigmoid. (There are others.)
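In symbols, with x = a − b: the hard mask is the Heaviside step H (taking H(0) = 0 to match `a > b` being false on ties), whose derivative is zero wherever it is defined. The logistic sigmoid σ has a strictly positive derivative everywhere, which is what lets gradients flow back to `a` and `b`.

```latex
H(x) = \begin{cases} 1 & x > 0 \\ 0 & x \le 0 \end{cases},
\qquad H'(x) = 0 \quad (x \neq 0)

\sigma(x) = \frac{1}{1 + e^{-x}},
\qquad \sigma'(x) = \sigma(x)\,\bigl(1 - \sigma(x)\bigr) > 0
```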
You could try:
c = torch.sigmoid(a - b)
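A minimal sketch of the sigmoid suggestion, with illustrative tensors (the functional form is `torch.sigmoid`; `nn.Sigmoid` is the module version). The result is a soft mask in (0, 1) that is above 0.5 exactly where `a > b`, and every element of `a` gets a nonzero gradient.

```python
import torch

# Illustrative inputs (assumed)
a = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
b = torch.tensor([0.0, 0.0, 4.0])

# Soft mask: values in (0, 1), above 0.5 exactly where a > b
c = torch.sigmoid(a - b)

c.sum().backward()
# d/da sigmoid(a - b) = sigmoid(a - b) * (1 - sigmoid(a - b)), strictly positive
expected = (c * (1 - c)).detach()
print(torch.allclose(a.grad, expected))  # True
```

Thresholding the soft mask at 0.5 recovers the hard mask for the forward pass if you ever need it.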
You can sharpen (or smooth out) the sigmoid with a scale parameter s (larger s gives a sharper step):
c = torch.sigmoid(s * (a - b))
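A small sketch of the scale parameter, with the values of `s` and the sample tensors chosen only for illustration: as `s` grows, `torch.sigmoid(s * (a - b))` approaches the hard 0/1 mask.

```python
import torch

# Illustrative inputs (assumed)
a = torch.tensor([0.5, -0.5, 2.0])
b = torch.tensor([0.0, 0.0, 1.0])
hard = (a > b).float()

# Largest deviation of the soft mask from the hard mask, for increasing s
gaps = []
for s in (1.0, 10.0, 100.0):
    soft = torch.sigmoid(s * (a - b))
    gaps.append((soft - hard).abs().max().item())

print(gaps)  # shrinks toward 0 as s grows
```

The trade-off: the gradient with respect to `a` is s·σ(s(a − b))·(1 − σ(s(a − b))), so a large s concentrates it near a == b and lets it vanish elsewhere, while a small s gives smoother gradients but a blurrier mask.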
(By the way, are you sure you mean bitwise >?)
Good luck.
K. Frank
Sorry for the mistake. It should be element-wise. Thanks for your help.