I’m trying to make the output of the code below differentiable with respect to epsilon. I know that the `torch.where` operator is not differentiable with respect to its condition.
Is there any way to make the output differentiable with respect to epsilon?
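```python
import torch
import torch.nn as nn

eps = nn.Parameter(torch.ones(1,))
nn.init.constant_(eps, 0.5)  # threshold initialized to 0.5

x = torch.rand(5, 5)
y = torch.zeros(5, 5)
z = torch.ones(5, 5)

# 1 where x exceeds the threshold, 0 elsewhere; no gradient reaches eps
out = torch.where(x > eps, z, y)
```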
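One common workaround, sketched here under the assumption that an approximate threshold is acceptable, is to replace the hard comparison with a sigmoid of `(x - eps) / temperature`. The name `temperature` and its value `0.05` are illustrative assumptions, not part of the original code. The second variant is a straight-through estimator: the forward pass keeps the exact 0/1 output of the hard threshold, while gradients flow through the sigmoid. Neither makes the step function itself differentiable; both substitute a surrogate gradient.

```python
import torch
import torch.nn as nn

eps = nn.Parameter(torch.ones(1,))
nn.init.constant_(eps, 0.5)
x = torch.rand(5, 5)

# Hard threshold: the comparison x > eps returns a bool tensor with no grad_fn,
# so the result is disconnected from eps in the autograd graph.
out_hard = torch.where(x > eps, torch.ones(5, 5), torch.zeros(5, 5))
print(out_hard.requires_grad)  # False

# Soft relaxation (assumption: a smooth approximation of the step is acceptable).
# `temperature` is a hypothetical hyperparameter: smaller values approach the
# hard step but make gradients vanish for x far from eps.
temperature = 0.05
out_soft = torch.sigmoid((x - eps) / temperature)

# Straight-through estimator: the forward value equals the hard 0/1 mask exactly,
# but the backward pass uses the gradient of the sigmoid.
hard = (x > eps).float()
out_st = hard + (out_soft - out_soft.detach())

out_st.sum().backward()
print(eps.grad)  # non-None; negative, since raising eps lowers the soft output
```

The soft version changes the forward values (they lie strictly between 0 and 1), whereas the straight-through version keeps the original 0/1 output and only alters the backward pass; which one fits depends on whether downstream code needs exact binary values.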