Poisson loss function

I want to predict count data using a simple fully connected network. I could in principle frame it as a classification problem where each class corresponds to the event count, but I would like to do it properly using a Poisson loss function. Two quick questions:

  1. I can’t seem to find an implementation of this loss function. Am I missing anything?
  2. I also can’t find any examples of using PyTorch to predict count data. Any pointers?

Thank you.

This article provides the definition of the Poisson loss. You can calculate it using a few PyTorch functions.
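For reference, the negative log-likelihood of an observed count y under a Poisson distribution with rate lambda is lambda - y * log(lambda) + log(y!). The last term does not depend on lambda, so it is usually dropped when training. The minimal sketch below assumes the rates are already strictly positive. It checks the formula against `torch.distributions.Poisson`, using `torch.lgamma(y + 1)` for log(y!). The function name `poisson_nll_full` and the example tensors are illustrative only.

```python
import torch

def poisson_nll_full(rate, y):
    """Full Poisson negative log-likelihood, averaged over elements.

    rate: strictly positive tensor of predicted rates (lambda)
    y:    tensor of observed counts, same shape as rate
    """
    # log(y!) == lgamma(y + 1); this term is constant with respect to rate
    return torch.mean(rate - y * torch.log(rate) + torch.lgamma(y + 1))

rate = torch.tensor([0.5, 2.0, 3.0])
y = torch.tensor([0.0, 1.0, 4.0])

manual = poisson_nll_full(rate, y)
# Reference value from PyTorch's Poisson distribution
reference = -torch.distributions.Poisson(rate).log_prob(y).mean()
print(manual.item(), reference.item())
```

Dropping the `lgamma` term changes the loss value by a constant but leaves the gradients with respect to `rate` unchanged.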


Super simple question: since lambda must be > 0 (because of the log(lambda_hat) term), how can I define the loss function so that it always receives valid lambdas? In principle, implementing it with PyTorch functions is straightforward:

    import torch

    def poissonLoss(predicted, observed):
        """Custom loss for a Poisson model; `predicted` is the rate lambda (must be > 0)."""
        loss = torch.mean(predicted - observed * torch.log(predicted))
        return loss

But I obviously need to force the output to be strictly positive, otherwise I’ll get -inf and NaNs. What sort of correction could I use for the last layer? Something like a ReLU with a floor above 0?


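Lambda should be exp(theta * x): let the network output the log-rate theta * x and exponentiate it, which gives a strictly positive lambda for any real output.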
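Here is a minimal sketch of this log-link parameterization. It assumes a single linear layer stands in for theta * x (plus a bias). The class name `PoissonRegression` is hypothetical.

```python
import torch
import torch.nn as nn

class PoissonRegression(nn.Module):
    """Hypothetical example model: linear predictor with a log link."""

    def __init__(self, n_features):
        super().__init__()
        self.linear = nn.Linear(n_features, 1)

    def forward(self, x):
        # The linear layer produces the log-rate, which can be any real number.
        log_rate = self.linear(x).squeeze(-1)
        # exp maps it to a strictly positive Poisson rate lambda.
        return torch.exp(log_rate)

torch.manual_seed(0)
model = PoissonRegression(n_features=3)
rates = model(torch.randn(5, 3))
print(rates)
```

Because log(exp(eta)) = eta, the loss can also be written directly in terms of the log-rate eta. This avoids ever calling `torch.log` on a rate that has underflowed to zero in floating point.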


I think the Poisson loss function is not supported by default in PyTorch. You may try this:

BTW, I think one should focus on loss functions that support self-adaptive learning rates rather than on probabilities.

I’ve implemented the following Poisson loss function using the first suggested link:

    import torch

    def poissonLoss(xbeta, y):
        """Custom loss for a Poisson model; `xbeta` is the log-rate, so lambda = exp(xbeta)."""
        loss = torch.mean(torch.exp(xbeta) - y * xbeta)
        return loss
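Recent PyTorch releases also include `torch.nn.PoissonNLLLoss`. With `log_input=True` (the default), it treats its input as the log-rate and computes exp(input) - target * input. That is the same quantity as `poissonLoss` above. With `full=False` (also the default), it omits the Stirling-approximation term. The quick check below uses made-up tensors.

```python
import torch
import torch.nn as nn

def poissonLoss(xbeta, y):
    """Custom Poisson loss; xbeta is the log-rate."""
    return torch.mean(torch.exp(xbeta) - y * xbeta)

xbeta = torch.tensor([0.1, -0.5, 1.2, 2.0])  # example log-rates
y = torch.tensor([1.0, 0.0, 3.0, 7.0])       # example observed counts

builtin = nn.PoissonNLLLoss(log_input=True, full=False)
print(poissonLoss(xbeta, y).item(), builtin(xbeta, y).item())
```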

Training:

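    # note: older PyTorch versions require an explicit lr argument to SGD
    optimizer = torch.optim.SGD(net.parameters())
    lossFn = poissonLoss

    output_ = net(input_)            # network output, used as the log-rate xbeta
    loss = lossFn(output_, target_)
    net.zero_grad()                  # clear gradients from the previous step
    loss.backward()                  # backpropagate
    optimizer.step()                 # update the parameters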

Since my loss is built from plain PyTorch operations, is this the right way to train, backpropagate, and optimize? I’m not sure, because my model is not learning.

Thanks,

You should use optimizer.zero_grad() instead of net.zero_grad(), as in the official tutorial:

https://pytorch.org/tutorials/beginner/pytorch_with_examples.html#pytorch-optim
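The end-to-end sketch below uses synthetic data drawn from a Poisson model with a log-linear rate. The network size, learning rate, and step count are arbitrary values for illustration, not tuned ones. The network outputs the log-rate, the loss is the `poissonLoss` from earlier in the thread, and gradients are cleared with `optimizer.zero_grad()` before each backward pass.

```python
import torch
import torch.nn as nn

def poissonLoss(xbeta, y):
    """Poisson loss with xbeta interpreted as the log-rate."""
    return torch.mean(torch.exp(xbeta) - y * xbeta)

torch.manual_seed(0)

# Synthetic data (assumed setup): counts drawn with log(lambda) = x @ true_theta
n, d = 1000, 3
x = torch.randn(n, d)
true_theta = torch.tensor([0.5, -0.3, 0.2])
y = torch.poisson(torch.exp(x @ true_theta))

# Small fully connected network whose single output is the log-rate
net = nn.Sequential(nn.Linear(d, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = torch.optim.SGD(net.parameters(), lr=0.05)

losses = []
for step in range(300):
    log_rate = net(x).squeeze(-1)  # shape (n,), matching y
    loss = poissonLoss(log_rate, y)
    optimizer.zero_grad()          # clear gradients from the previous step
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"initial loss {losses[0]:.3f}, final loss {losses[-1]:.3f}")
```

The `squeeze(-1)` matters. Without it, an output of shape (n, 1) broadcasts against targets of shape (n,) into an (n, n) matrix, and the loss no longer matches the model.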
