Full Log Poisson Loss for PyTorch

Hello, I need a function that computes the full log Poisson loss, but I can't find anything useful for PyTorch.
The TF equivalent would be tf.nn.log_poisson_loss(targets, log_input, compute_full_loss=True). Is there anything already implemented for PyTorch?
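As far as I know, PyTorch already ships this as `torch.nn.functional.poisson_nll_loss` (module form: `torch.nn.PoissonNLLLoss`). Below is a minimal sketch, assuming a reasonably recent PyTorch. `log_input=True` treats the first argument as a log-rate, and `full=True` adds the Stirling term. `reduction` defaults to `"mean"`, so `reduction="none"` is needed to get TF-style per-element values. Note the argument order is `(input, target)`, the reverse of TF's `(targets, log_input)`. My understanding is that PyTorch zeroes the Stirling term for `target <= 1`, whereas TF zeroes it for `0 <= target <= 1`. The two only differ for negative targets, which are not valid Poisson counts anyway. The tensor values are arbitrary illustration data, and the hand-written `expected` tensor encodes that assumption purely as a sanity check.

```python
import math

import torch
import torch.nn.functional as F

# Illustrative log-rates (raw model output) and non-negative count targets.
log_input = torch.tensor([0.5, -1.0, 2.0, 0.0])
targets = torch.tensor([0.0, 1.0, 3.0, 7.0])

# Input first, target second (the reverse of TF).
# log_input=True -> loss = exp(c) - z * c; full=True -> add Stirling term.
# reduction="none" keeps one loss value per element, like TF.
loss = F.poisson_nll_loss(log_input, targets, log_input=True, full=True,
                          reduction="none")

# Hand-written reference; the Stirling term is masked out for targets <= 1
# (it is nan at z = 0, but torch.where selects the zero branch there).
stirling = (targets * torch.log(targets) - targets
            + 0.5 * torch.log(2 * math.pi * targets))
expected = torch.exp(log_input) - targets * log_input + torch.where(
    targets > 1, stirling, torch.zeros_like(targets))

print(torch.allclose(loss, expected))  # True
```

For use inside a model, `torch.nn.PoissonNLLLoss(log_input=True, full=True, reduction="none")` wraps the same computation as a module.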

I implemented it following the TF version:

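```python
import math

import torch


def log_poisson_loss(targets, log_input, compute_full_loss=False):
    # Element-wise Poisson loss given log-rates, ported from tf.nn.log_poisson_loss.
    if targets.size() != log_input.size():
        raise ValueError(
            "log_input and targets must have the same shape (%s vs %s)" %
            (log_input.size(), targets.size()))

    # exp(c) - z * c, with c = log_input and z = targets
    result = torch.exp(log_input) - log_input * targets
    if compute_full_loss:
        point_five = 0.5
        two_pi = 2 * math.pi

        # Stirling approximation of log(z!); the outer parentheses let the
        # expression continue onto the next line.
        stirling_approx = ((targets * torch.log(targets)) - targets
                           + (point_five * torch.log(two_pi * targets)))
        zeros = torch.zeros_like(targets)
        ones = torch.ones_like(targets)
        # As in TF, zero the term for 0 <= z <= 1 (log 0! = log 1! = 0,
        # and torch.log(0) would otherwise produce -inf / nan).
        cond = (targets >= zeros) & (targets <= ones)
        result = result + torch.where(cond, zeros, stirling_approx)
    return result
```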
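For reference, here is the quantity the function computes, read off the code with c = `log_input` and z = `targets`. Interpreting c as the log of a Poisson rate λ is the usual reading of this loss and an assumption on my part, not something stated above.

```latex
% Negative log-likelihood of a Poisson count z with rate \lambda = e^{c}
-\log p(z \mid \lambda) = \lambda - z \log \lambda + \log z!
                        = e^{c} - z\,c + \log z!

% compute_full_loss=True approximates \log z! with Stirling's formula,
% and uses 0 for 0 \le z \le 1 (since \log 0! = \log 1! = 0):
\log z! \approx z \log z - z + \tfrac{1}{2}\log(2\pi z)

% Example, z = 3:
3\log 3 - 3 + \tfrac{1}{2}\log(6\pi) \approx 1.764
\quad\text{vs.}\quad \log 3! = \log 6 \approx 1.792
```

Because the \log z! term does not depend on c, `compute_full_loss` changes only the reported loss value, not the gradient with respect to `log_input`.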