Differentiating the where operator

Hi,

I'm trying to make the output of the code below differentiable with respect to epsilon. I know that the where operator is not differentiable with respect to its condition.

Is there any way to make the output differentiable with respect to epsilon?

```python
import torch
import torch.nn as nn

eps = nn.Parameter(torch.ones(1,))
nn.init.constant_(eps, 0.5)

x = torch.rand(5, 5)
y = torch.zeros(5, 5)
z = torch.ones(5, 5)

out = torch.where(x > eps, z, y)
```
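A quick check shows why no gradient reaches `eps`: the comparison `x > eps` produces a boolean tensor, which carries no autograd history, and since neither `z` nor `y` requires gradients, `out` ends up with no connection to `eps` at all. This is a small sketch reusing the tensors from the question; the name `mask` is introduced here only for illustration.

```python
import torch
import torch.nn as nn

eps = nn.Parameter(torch.full((1,), 0.5))
x = torch.rand(5, 5)
y = torch.zeros(5, 5)
z = torch.ones(5, 5)

# The comparison yields a bool tensor; comparisons are not differentiable,
# so the mask cannot carry any gradient back to eps.
mask = x > eps
print(mask.dtype, mask.requires_grad)  # torch.bool False

# z and y do not require grad either, so out has no graph linking it to eps.
out = torch.where(mask, z, y)
print(out.requires_grad)  # False
```

Calling `out.sum().backward()` here would raise an error, because `out` has no `grad_fn`.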

Thank you

The non-differentiability comes from the discontinuity of the function (a hard step at x = eps). So you have to relax (smooth) the expression.
For finite values (no NaN or inf), `where(x > eps, z, y)` equals `(x - eps > 0) * (z - y) + y`, which you could change to `torch.sigmoid((x - eps) * t) * (z - y) + y` for some tunable float `t`. Larger `t` makes the sigmoid sharper and closer to the original step.
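A minimal sketch of this relaxation, assuming `t` is a hand-picked sharpness hyperparameter; the value `50.0` and the names `hard` and `soft` are illustrative choices, not part of the answer above. The hard form reproduces `torch.where` exactly for finite inputs, while the sigmoid form gives `eps` a nonzero gradient.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

eps = nn.Parameter(torch.full((1,), 0.5))
x = torch.rand(5, 5)
y = torch.zeros(5, 5)
z = torch.ones(5, 5)

# Exact rewrite for finite inputs: the bool mask, cast to float, picks z or y.
hard = (x - eps > 0).float() * (z - y) + y
assert torch.equal(hard, torch.where(x > eps, z, y))

# Smooth relaxation: replace the step with a sigmoid of sharpness t (illustrative value).
t = 50.0
soft = torch.sigmoid((x - eps) * t) * (z - y) + y

# Gradients now flow back to eps.
soft.sum().backward()
print(eps.grad)  # negative here: raising eps lowers every soft output
```

The trade-off in choosing `t`: a large value keeps `soft` close to the hard selection, but the gradient then concentrates on entries of `x` close to `eps` and vanishes for the rest; a small value spreads the gradient out at the cost of a blurrier output.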

Best regards

Thomas


I see, thanks a lot, it really helps!