Negative Binomial loss: autograd gradients do not match my numerical reproduction

(João Francisco Santos) #1


I have implemented a loss based on the probability mass function (PMF) of the Negative Binomial distribution.
I tried to validate this implementation by reproducing, by hand, its gradients with respect to the distribution's parameters (`total_count` and `probability`).
Unfortunately, the values do not match the ones autograd produces. I am not sure whether my loss implementation is at fault, because I believe the numerical reproduction is correct according to these sources:
- [Source 1: Negative Binomial distribution PMF and likelihood gradient](http://vixra.org/pdf/1211.0113v1.pdf)
- [Source 2: Negative Binomial distribution PMF and likelihood gradient (Wikipedia)](https://en.wikipedia.org/wiki/Negative_binomial_distribution)
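For reference, here are the negative log-likelihood and its gradients written out for the parameterization that the `loss` function below computes, with $k$ = `target`, $r$ = `total_count`, $p$ = `probability` and $\psi$ the digamma function. This is a sketch assuming that PMF; some references (possibly including source 2) swap the roles of $p$ and $1 - p$.

$$
P(k \mid r, p) = \frac{\Gamma(k + r)}{\Gamma(r)\,\Gamma(k + 1)}\, p^{k}\, (1 - p)^{r}
$$

$$
\mathcal{L}(r, p) = -\ln P(k \mid r, p) = -\ln\Gamma(k + r) + \ln\Gamma(r) + \ln\Gamma(k + 1) - k \ln p - r \ln(1 - p)
$$

$$
\frac{\partial \mathcal{L}}{\partial r} = -\bigl[\psi(k + r) - \psi(r) + \ln(1 - p)\bigr],
\qquad
\frac{\partial \mathcal{L}}{\partial p} = -\left[\frac{k}{p} - \frac{r}{1 - p}\right]
$$

The term $\ln\frac{r}{r + k}$ that appears in `neg_PMF_gradient_check` coincides with $\ln(1 - p)$ only when $p = \frac{k}{k + r}$, which is the value of $p$ that maximizes the likelihood for fixed $r$ and a single observation $k$. With $k = r = 5$ and $p = 0.1$ the two terms differ ($\ln 0.5$ versus $\ln 0.9$).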

Here is the code snippet, with an example that prints both sets of gradients:

```python
import torch
from scipy.special import digamma
import numpy as np

# model output: row 0 is total_count (r), row 1 is probability (p)
output_soft = torch.tensor([[5.], [0.1]], requires_grad=True)

def loss(input, target):
    total_count = input[0]
    probability = input[1]
    # NOTE: torch.tensor(...) copies its argument into a new leaf tensor,
    # so autograd does not trace the wrapped terms back to `input`
    target_p_tc_gamma = torch.tensor([target + total_count], dtype=torch.float, requires_grad=True).lgamma()
    r_gamma = total_count.lgamma()
    target_factorial = torch.tensor([target + 1], dtype=torch.float).lgamma()
    combinatorial_term = torch.tensor([target_p_tc_gamma - r_gamma - target_factorial], dtype=torch.float, requires_grad=True).exp()
    prob_term = probability.pow(target)
    comp_prob_term = torch.tensor([1 - probability], dtype=torch.float, requires_grad=True).pow(total_count)
    likelihood_target = combinatorial_term * prob_term * comp_prob_term
    return -likelihood_target.log()

target = torch.tensor([5.])
nll = loss(output_soft, target)  # negative log-likelihood of the target
nll.backward()                   # populates output_soft.grad
print("Backward gradient inspection", output_soft.grad.detach().numpy())

def neg_PMF_gradient_check(input, target):
    """Closed-form gradient of the negative log-likelihood w.r.t. (total_count, probability)."""
    target = target.detach().numpy()
    total_count = input[0].detach().numpy()
    probability = input[1].detach().numpy()
    dg = digamma
    comp_prob = 1 - probability
    # log(r / (r + k)) equals log(1 - p) only when p == k / (k + r)
    grad_tc = dg(target + total_count) - dg(total_count) + np.log(total_count / (total_count + target))
    grad_prob = target * (1 / probability) - total_count / comp_prob
    return [-grad_tc[0], -grad_prob[0]]

print("Numerical reproduction", neg_PMF_gradient_check(output_soft, target))
```
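One way to narrow this down is to build the loss purely from torch operations on the parameter tensor, without re-wrapping intermediate values in `torch.tensor(...)` (which copies them into new leaf tensors that autograd does not connect to the input), and to compare autograd against the closed-form gradient that uses $\ln(1 - p)$. The following is a minimal sketch under those assumptions: `nb_nll` and `nb_nll_grad` are hypothetical helper names, float64 is used only to keep the comparison tight, and the final cross-check assumes `torch.distributions.NegativeBinomial` uses the same $(r, p)$ parameterization as the `loss` function, which as far as I can tell it does.

```python
import torch

# Hypothetical helper: NLL of a Negative Binomial with PMF
#   P(k) = Gamma(k + r) / (Gamma(r) * k!) * p**k * (1 - p)**r
# built only from torch ops on `params`, so autograd sees every term.
def nb_nll(params: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    total_count = params[0]   # r
    probability = params[1]   # p
    log_comb = (torch.lgamma(target + total_count)
                - torch.lgamma(total_count)
                - torch.lgamma(target + 1))
    log_lik = (log_comb
               + target * torch.log(probability)
               + total_count * torch.log1p(-probability))
    return -log_lik.sum()

# Hypothetical helper: closed-form gradient of the same NLL (psi = digamma),
# using log(1 - p) rather than log(r / (r + k)).
def nb_nll_grad(params: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    r, p = params[0].detach(), params[1].detach()
    grad_r = -(torch.digamma(target + r) - torch.digamma(r) + torch.log1p(-p))
    grad_p = -(target / p - r / (1 - p))
    return torch.stack([grad_r.sum(), grad_p.sum()]).reshape(params.shape)

# Same layout as `output_soft`; float64 only to keep the comparison tight.
params = torch.tensor([[5.0], [0.1]], dtype=torch.float64, requires_grad=True)
target = torch.tensor([5.0], dtype=torch.float64)

nll = nb_nll(params, target)
nll.backward()
autograd_grad = params.grad.clone()
analytic_grad = nb_nll_grad(params, target)
print("autograd:", autograd_grad.flatten().tolist())
print("analytic:", analytic_grad.flatten().tolist())

# Cross-check the loss value against torch.distributions.NegativeBinomial.
dist = torch.distributions.NegativeBinomial(
    total_count=params[0].detach(), probs=params[1].detach())
print("log_prob agrees:", torch.allclose(-dist.log_prob(target).sum(), nll.detach()))
```

If autograd and the closed-form gradient agree in this sketch, the mismatch in the original snippet most likely comes from the `torch.tensor(...)` re-wrapping inside `loss` combined with the $\ln\frac{r}{r + k}$ term in `neg_PMF_gradient_check`.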

Any clues are very much appreciated!