Hello, I need a function that computes the full log Poisson loss, but can’t find anything useful for PyTorch.
The TF equivalent would be tf.nn.log_poisson_loss(targets, log_input, compute_full_loss=True). Is there anything already implemented for PyTorch?
Thanks!
I implemented it the same way TF does:
import math

import torch


def log_poisson_loss(targets, log_input, compute_full_loss):
    # Element-wise Poisson negative log-likelihood, mirroring tf.nn.log_poisson_loss.
    if targets.size() != log_input.size():
        raise ValueError(
            "log_input and targets must have the same shape (%s vs %s)" %
            (log_input.size(), targets.size()))
    # Partial loss exp(c) - z * c, which leaves out the log(z!) term.
    result = torch.exp(log_input) - log_input * targets
    if compute_full_loss:
        # Stirling's approximation of log(z!): z*log(z) - z + 0.5*log(2*pi*z).
        # The outer parentheses keep both lines in a single expression.
        stirling_approx = (targets * torch.log(targets) - targets
                           + 0.5 * torch.log(2 * math.pi * targets))
        # Add zero instead of the approximation where 0 <= z <= 1.
        zeros = torch.zeros_like(targets)
        cond = (targets >= 0) & (targets <= 1)
        result += torch.where(cond, zeros, stirling_approx)
    return result
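As far as I can tell, PyTorch also ships a built-in that covers this case: torch.nn.functional.poisson_nll_loss with log_input=True and full=True adds the same Stirling term, and reduction="none" keeps the element-wise result like TF does. The sketch below compares it against a reference computed by hand. The tensors are made-up values for illustration, and note that the argument order (prediction first, target second) is the opposite of the TF function.

```python
import math

import torch
import torch.nn.functional as F

# Made-up example data: predicted log-rates and observed counts.
log_input = torch.tensor([0.0, 0.5, 1.0, 2.0])
targets = torch.tensor([0.0, 1.0, 3.0, 8.0])

# Built-in loss: prediction comes first, target second.
builtin = F.poisson_nll_loss(
    log_input, targets, log_input=True, full=True, reduction="none"
)

# Reference computed by hand for comparison.
partial = torch.exp(log_input) - targets * log_input
safe = targets.clamp(min=1.0)  # avoids log(0); those entries are masked out below
stirling = safe * torch.log(safe) - safe + 0.5 * torch.log(2 * math.pi * safe)
expected = partial + torch.where(targets > 1, stirling, torch.zeros_like(targets))

print(torch.allclose(builtin, expected, atol=1e-5))
```

If you want a scalar loss for training, the default reduction="mean" averages over all elements instead.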