Implementing log1mexp in PyTorch


I’m trying to implement the log1mexp function, log(1 - exp(-x)) for x > 0, in a way that computes the value accurately. This is essentially Eq. (7) in the following note:
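For reference, a commonly used accurate evaluation splits at x = log 2 and picks whichever of expm1 or log1p avoids cancellation. This is written out here as an assumption about what the referenced equation contains, since the note itself is not reproduced in this thread. Near x = 0, computing 1 - e^{-x} directly cancels catastrophically, so expm1 is used there. For large x, e^{-x} is tiny and log1p keeps full precision.

```latex
\operatorname{log1mexp}(x) = \log\bigl(1 - e^{-x}\bigr) =
\begin{cases}
\log\bigl(-\operatorname{expm1}(-x)\bigr), & 0 < x \le \log 2,\\
\operatorname{log1p}\bigl(-e^{-x}\bigr), & x > \log 2.
\end{cases}
```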

In my application, exp(-x) is sometimes very close to 1, so a numerically stable implementation is necessary. I’m thinking about how to implement this in PyTorch and am unsure which of three alternatives to choose. The first is to compose it from existing PyTorch functions and wrap that in a Python def. The second is to subclass autograd.Function and define my own forward and backward methods. The third is to write some C code and interface it with PyTorch.
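A minimal sketch of the first option, assuming x > 0 and using a hypothetical helper name `log1mexp`. It composes `torch.expm1`, `torch.log1p` and `torch.where`. One pitfall worth knowing: `torch.where` evaluates both branches, and if the unused branch is infinite, its gradient can turn into NaN during backward. Clamping the input of the log1p branch keeps that branch finite where it is not selected.

```python
import math

import torch


def log1mexp(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: log(1 - exp(-x)) for x > 0, built from stock ops."""
    log2 = math.log(2.0)
    # Branch for x <= log 2: expm1 avoids cancellation in 1 - exp(-x).
    small_branch = torch.log(-torch.expm1(-x))
    # Branch for x > log 2: log1p is accurate when exp(-x) is tiny.
    # Clamp so that for small x this (unused) branch stays finite and
    # torch.where does not propagate inf * 0 = NaN into the gradient.
    large_branch = torch.log1p(-torch.exp(-torch.clamp(x, min=log2)))
    return torch.where(x <= log2, small_branch, large_branch)
```

Because every op here is a stock PyTorch function, autograd derives the backward pass automatically; no custom backward is needed.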

Which one is most suitable? Thanks for the help.

The second approach seems simple to implement and should suit your purpose.
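A sketch of that second approach, assuming x > 0; the class name `Log1mExp` is hypothetical. The forward pass uses the same expm1/log1p split. The backward pass uses the closed-form derivative d/dx log(1 - exp(-x)) = exp(-x) / (1 - exp(-x)) = 1 / expm1(x), which is finite for every x > 0 and tends to 0 as x grows.

```python
import math

import torch


class Log1mExp(torch.autograd.Function):
    """Hypothetical custom op: log(1 - exp(-x)) for x > 0."""

    @staticmethod
    def forward(ctx, x):
        log2 = math.log(2.0)
        # No graph is recorded inside forward, so an inf in the unused
        # branch only affects the discarded values, never the gradient.
        y = torch.where(
            x <= log2,
            torch.log(-torch.expm1(-x)),
            torch.log1p(-torch.exp(-x)),
        )
        ctx.save_for_backward(x)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d/dx log(1 - exp(-x)) = 1 / expm1(x)
        return grad_output / torch.expm1(x)


log1mexp = Log1mExp.apply
```

The trade-off versus plain composition is that you maintain the backward formula yourself, but you get a single fused derivative instead of relying on autograd through `torch.where`.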

Did you ever implement this? Is there a plan to add a generic log1mexp function to torch?