Compute gradient of bitwise OR

I have to calculate the Intersection over Union (IoU) for two segmentations. For that, at some point I have to calculate the bitwise OR of two tensors. I am doing this with (a + b) > 0.
Unfortunately, this isn’t differentiable. Does anyone know how I can calculate the bitwise OR and still be able to backprop?
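For reference, a minimal sketch of the problem, assuming a and b are float mask tensors (the exact shapes and values here are just placeholders):

import torch

# Two soft segmentation masks with values in [0, 1]
a = torch.rand(4, 4, requires_grad=True)
b = torch.rand(4, 4, requires_grad=True)

# Hard bitwise OR: the comparison returns a bool tensor,
# so the graph is cut and no gradient can flow back to a or b.
union_hard = (a + b) > 0
print(union_hard.requires_grad)  # False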

What if you treat them as floats and use multiplication and addition to calculate the IoU? That’s how it’s done in most implementations, if I remember correctly.

I can’t think of a function that computes the OR using only non-logical operators.
EDIT:
Actually, I just used torch.max(a, b)
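
A minimal sketch of that idea, assuming a and b are soft (float) masks, so the element-wise maximum acts as a differentiable OR:

import torch

a = torch.rand(4, 4, requires_grad=True)
b = torch.rand(4, 4, requires_grad=True)

# Element-wise max behaves like a soft OR and keeps the graph intact.
union = torch.max(a, b)

loss = union.sum()
loss.backward()
print(a.grad is not None)  # True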

What I meant is that, to calculate the IoU, we can use something like the following.

Consider A and B to be two 2D tensors made up of 0s and 1s.

Intersection = (A * B).sum()
Union = (A + B - A * B).sum()

IoU = Intersection / Union
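
Putting that together, a minimal differentiable soft-IoU sketch (assuming A and B are float tensors with values in [0, 1]; the eps term is only there to avoid division by zero and is not part of the formula above):

import torch

def soft_iou(A, B, eps=1e-6):
    # Element-wise product approximates the intersection;
    # inclusion-exclusion (A + B - A*B) approximates the union.
    intersection = (A * B).sum()
    union = (A + B - A * B).sum()
    return intersection / (union + eps)

A = torch.rand(8, 8, requires_grad=True)
B = torch.rand(8, 8)

iou = soft_iou(A, B)
iou.backward()  # gradients flow back to A
print(iou.item(), A.grad.shape)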
