I’m trying to make the output of the code below differentiable with respect to epsilon (`eps`). I know that `torch.where` is not differentiable with respect to its condition.
Is there any way to make the output differentiable with respect to epsilon?
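```python
import torch
import torch.nn as nn

eps = nn.Parameter(torch.ones(1))  # learnable threshold
nn.init.constant_(eps, 0.5)
x = torch.rand(5, 5)
y = torch.zeros(5, 5)
z = torch.ones(5, 5)
out = torch.where(x > eps, z, y)  # no gradient reaches eps through the condition
```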
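To see where the gradient path breaks, here is a minimal sketch that reuses the question's tensors. The comparison `x > eps` produces a boolean tensor, which cannot carry gradients, so `out` ends up not requiring grad at all. The name `mask` is only introduced here for illustration.

```python
import torch
import torch.nn as nn

eps = nn.Parameter(torch.ones(1))
nn.init.constant_(eps, 0.5)
x = torch.rand(5, 5)
y = torch.zeros(5, 5)
z = torch.ones(5, 5)

mask = x > eps  # boolean comparison: autograd stops here
out = torch.where(mask, z, y)

print(mask.dtype, mask.requires_grad)  # torch.bool False
print(out.requires_grad)               # False: out has no grad_fn linking it to eps
```

Calling `out.sum().backward()` on this `out` fails because it has no `grad_fn`.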
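The non-differentiability comes from the discontinuity of the function: as a function of `eps`, each output element is a step, whose derivative is zero everywhere except at the jump, where it is undefined. So you would have to relax (smooth) the expression.
For finite (non-NaN, non-inf) values, `where(x > eps, z, y)` equals `(x - eps > 0) * (z - y) + y`, which you could change to `torch.sigmoid((x - eps) * t) * (z - y) + y` for some tunable float `t`. Larger `t` makes the sigmoid a sharper approximation of the step.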
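A minimal sketch of that relaxation, assuming the same tensors as in the question. `soft_where` and the value `t = 10.0` are hypothetical choices for illustration, not a PyTorch API. The sketch first checks the arithmetic identity against `torch.where`, then confirms that the sigmoid version does propagate a gradient to `eps`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

eps = nn.Parameter(torch.ones(1))
nn.init.constant_(eps, 0.5)
x = torch.rand(5, 5)
y = torch.zeros(5, 5)
z = torch.ones(5, 5)

# For finite inputs, the hard where and its arithmetic rewrite agree exactly.
hard = torch.where(x > eps, z, y)
hard_arith = (x - eps > 0) * (z - y) + y
assert torch.equal(hard, hard_arith)


def soft_where(x, eps, z, y, t):
    # Hypothetical helper: sigmoid((x - eps) * t) replaces the step (x - eps > 0).
    # Larger t gives a sharper step, closer to torch.where.
    return torch.sigmoid((x - eps) * t) * (z - y) + y


t = 10.0  # hypothetical sharpness; tune for your problem
out = soft_where(x, eps, z, y, t)
out.sum().backward()
print(eps.grad)  # negative: raising eps lowers every sigmoid term
```

As `t` grows, the output approaches the hard `where`, but the gradient also concentrates on elements with `x` close to `eps` and becomes vanishingly small elsewhere, so `t` trades fidelity to the step against useful gradient signal.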