torch.where() not working and returns NaN when used with exp()

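You might be running into this behavior: torch.where() still backpropagates through both branches. If exp() overflows to inf for elements that the condition masks out, the backward pass multiplies their zero gradient by inf, and 0 * inf gives NaN. Could you check whether your use case suffers from the same logic, which would create invalid gradients?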
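Here is a minimal sketch of the problem and a common workaround. It assumes float32 tensors and uses made-up input values, not the code from the original question. The workaround is sometimes called the "double where" trick: mask the *input* before calling exp(), so the unused branch never produces inf in the first place.

```python
import torch

# Hypothetical inputs: -1.0 takes the exp() branch, 100.0 is masked out.
# In float32, exp(100.0) overflows to inf.
x = torch.tensor([-1.0, 100.0], requires_grad=True)

# Naive version: the forward result is fine ([exp(-1), 0]), but autograd still
# backpropagates through exp() for every element. For the masked element the
# incoming gradient is 0 and d/dx exp(x) = exp(x) = inf, and 0 * inf = nan.
y = torch.where(x < 0, torch.exp(x), torch.zeros_like(x))
y.sum().backward()
naive_grad = x.grad.clone()  # second entry is nan

# Workaround: first replace the inputs of the unused branch with a safe value
# (0.0 here) so exp() never overflows, then select with the same mask again.
x_fixed = torch.tensor([-1.0, 100.0], requires_grad=True)
mask = x_fixed < 0
safe_x = torch.where(mask, x_fixed, torch.zeros_like(x_fixed))
y_fixed = torch.where(mask, torch.exp(safe_x), torch.zeros_like(x_fixed))
y_fixed.sum().backward()
fixed_grad = x_fixed.grad.clone()  # all entries finite

print("naive:", naive_grad)
print("fixed:", fixed_grad)
```

The safe value only has to keep exp() finite. Any value that the selected branch would accept works. The same pattern applies to other functions whose gradient blows up outside their valid range, such as log() or sqrt().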