Compute gradient of bitwise >

I need to compare two tensors element-wise and get a binary mask. For example:

c = a > b
c = c.float()

However, this operation is not differentiable. Does anyone know how I can compute the element-wise > and still be able to backprop?
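A minimal sketch of where the graph breaks, assuming two small float tensors `a` and `b` created with `requires_grad=True` (the values are made up for illustration). The comparison returns a `torch.bool` tensor with no `grad_fn`, so autograd has nothing to propagate back to `a` or `b`.

```python
import torch

# Example inputs we would like to backprop into (illustrative values)
a = torch.tensor([0.5, 2.0, -1.0], requires_grad=True)
b = torch.tensor([1.0, 1.0, -1.0], requires_grad=True)

c = a > b          # element-wise comparison -> torch.bool tensor
c = c.float()      # convert the mask to 0.0 / 1.0

print(c)                # values: 0.0, 1.0, 0.0
print(c.requires_grad)  # False: the comparison is not tracked by autograd
print(c.grad_fn)        # None

try:
    c.sum().backward()
except RuntimeError as err:
    # backward() fails because c is not connected to a or b in the graph
    print("backward failed:", err)
```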

torch.sign() works for me
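A caveat worth checking, assuming the suggestion means something like `torch.sign(a - b)`: PyTorch keeps `torch.sign` in the graph, so `backward()` runs without an error, but the gradient autograd defines for `sign` is zero everywhere. The sketch below (illustrative values, and an assumed shift/scale to turn -1/0/+1 into a 0/0.5/1 mask) shows that nothing actually flows back to the inputs.

```python
import torch

a = torch.tensor([0.5, 2.0, -1.0], requires_grad=True)
b = torch.tensor([1.0, 1.0, -1.0], requires_grad=True)

# sign(a - b) is -1, 0 or +1; shift and scale it to 0, 0.5 or 1
c = (torch.sign(a - b) + 1) / 2
print(c)  # values: 0.0, 1.0, 0.5 (ties map to 0.5, not to 0)

c.sum().backward()
# The graph is connected, but sign's gradient is zero, so a.grad and b.grad are all zeros
print(a.grad)
print(b.grad)
```

So `torch.sign` avoids the "does not require grad" error, but any parameters upstream of it will not be updated through this path.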

Hi Xianshun!

The greater-than operator (>) is basically a step function.
A common differentiable approximation to a step function
is a sigmoid. (There are others.)
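For reference, the standard definitions behind this, writing $x = a - b$ and $H$ for the step function that `a > b` computes (with $H(0) = 0$, since the comparison is strict):

```latex
H(x) = \begin{cases} 1 & x > 0 \\ 0 & x \le 0 \end{cases},
\qquad \frac{dH}{dx} = 0 \quad \text{for } x \neq 0

\sigma(x) = \frac{1}{1 + e^{-x}},
\qquad \sigma'(x) = \sigma(x)\,\bigl(1 - \sigma(x)\bigr) > 0

\lim_{s \to \infty} \sigma(s x) = H(x) \quad \text{for } x \neq 0
```

The step's derivative is zero wherever it exists, which is why no useful gradient reaches `a` or `b`; the sigmoid's derivative is strictly positive, so it always passes some signal back.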

You could try:

c = torch.sigmoid(a - b)
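A small sketch of the sigmoid version with the same illustrative inputs as above. Unlike the hard comparison, `c` now has a `grad_fn`, and the gradient with respect to `a` is $\sigma(a-b)\,(1-\sigma(a-b))$, with the opposite sign for `b`.

```python
import torch

a = torch.tensor([0.5, 2.0, -1.0], requires_grad=True)
b = torch.tensor([1.0, 1.0, -1.0], requires_grad=True)

# Soft version of (a > b): values in (0, 1) instead of exactly 0 or 1
c = torch.sigmoid(a - b)

c.sum().backward()

print(c)       # close to 0 where a < b, close to 1 where a > b, 0.5 on ties
print(a.grad)  # sigmoid(a - b) * (1 - sigmoid(a - b)), nonzero everywhere
print(b.grad)  # the negative of a.grad
```

Note that the output is a soft mask: entries are in (0, 1) rather than exactly 0 or 1, and a tie gives 0.5.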

You can sharpen (or smooth out) the sigmoid with a scale parameter s: larger values of s bring it closer to a hard step, smaller values make it smoother:

c = torch.sigmoid(s * (a - b))
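A sketch of how the scale changes things, using a hypothetical helper `soft_greater` (not a PyTorch function) and two points just below and just above `b`. The derivative is $s\,\sigma(s(a-b))\,(1-\sigma(s(a-b)))$, so as s grows the output approaches 0/1, while the gradient becomes large near the boundary and tiny away from it.

```python
import torch

def soft_greater(a, b, s=1.0):
    """Hypothetical helper: differentiable stand-in for (a > b).float(); s sets the sharpness."""
    return torch.sigmoid(s * (a - b))

a = torch.tensor([0.9, 1.1], requires_grad=True)  # one point below b, one above
b = torch.tensor([1.0, 1.0])

results = {}
for s in (1.0, 10.0, 100.0):
    c = soft_greater(a, b, s)
    # Gradient of sum(c) with respect to a, computed on a fresh graph for each s
    (grad,) = torch.autograd.grad(c.sum(), a)
    results[s] = (c.detach(), grad)
    print(f"s={s:>5}: c={c.tolist()}, dc/da={grad.tolist()}")
```

In practice this is a trade-off: a very large s behaves almost like the hard comparison and gives vanishing gradients for points that are not close to the threshold, while a small s keeps gradients flowing but produces a blurrier mask.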

(By the way, are you sure you mean bitwise >?)

Good luck.

K. Frank


Sorry for the mistake; it should be element-wise. Thanks for your help.