Why is my parameter not changing and its gradient 0?

I am building a really simple model to learn the rate parameter of a Poisson distribution, and I am not sure where I am going wrong. I am using PyTorch (torch.nn) and doing the following.

I made some really simple fake data:

import torch

# This is the value I am trying to estimate

x = torch.tensor(2.0)


# This is a value drawn from the Poisson(x) distribution 
# In this example it is 4

y = torch.poisson(x).reshape(1)

Then I just set up a really simple model

# I initialised the parameter that is going to estimate x with an arbitrary value (0.2)
# and set it to require a gradient

a = torch.tensor([0.2], requires_grad = True)


# I define the loss function with log_input set to False

loss_function = torch.nn.PoissonNLLLoss(log_input = False)


# Define the model

def model(a):
    return torch.poisson(a)

# And the optimizer for the parameter to be optimised
# I chose SGD arbitrarily, maybe this is the problem?

optimizer = torch.optim.SGD([a], lr = 0.1)

Then I run iterations to update a:

for i in range(2000):
    
    # Forward pass

    y_pred = model(a)


    # Compute the loss

    loss = loss_function(y_pred, y)
    

    # Backprop

    optimizer.zero_grad()
    
    loss.backward()
    
    
    # Update parameters

    optimizer.step()

The problem is that after this, a is still 0.2, and if I call a.grad it is 0. Where am I going wrong?

I have also tried defining the model as a class that inherits from nn.Module. However, the same problem persists:

from torch import nn

class learning_model(nn.Module):

    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.rand(1))
        self.a.requires_grad = True  # redundant: nn.Parameter already requires grad

    def forward(self):
        return torch.poisson(self.a)

model = learning_model()

loss_function = nn.PoissonNLLLoss(log_input = False)

optimizer = torch.optim.SGD(model.parameters(), lr = 0.1)

print(model.a)

Outputs:

Parameter containing:
tensor([0.1402], requires_grad=True)

Then:

for i in range(20):
    
    # Forward pass
    y_pred = model()
    
    # Compute the loss
    loss = loss_function(y_pred, y)
    
    # Backprop
    optimizer.zero_grad()
    
    loss.backward()    
    
    # Update parameters
    optimizer.step()
    
print(model.a, '\n gradient:', model.a.grad)

Outputs:

Parameter containing:
tensor([0.1402], requires_grad=True) 
 gradient: tensor([0.])

Hi Charlie!

This is because, although torch.poisson() returns floating-point
numbers and attaches a grad_fn=<PoissonBackward> to the
computation graph, it gives, by design, zero gradients.
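A quick way to see this for yourself (a minimal sketch, assuming a reasonably recent PyTorch version; the exact grad_fn name that gets printed varies between versions):

```python
import torch

a = torch.tensor([0.2], requires_grad=True)

sample = torch.poisson(a)   # one random draw from Poisson(0.2)
print(sample.grad_fn)       # not None: a backward node is attached...

sample.sum().backward()
print(a.grad)               # ...but the gradient it passes back is zero: tensor([0.])
```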

First is the question of what the derivative of a random sample
with respect to some distribution parameter ought to be. Pytorch
chooses zero. See the discussion of “gradient estimators” in the
torch.distributions documentation.
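To make the "gradient estimators" point concrete, here is a small sketch using torch.distributions (my own illustration, not code from the documentation). Poisson reports has_rsample = False, so there is no reparameterized, differentiable sampler for it; Normal, for comparison, has one. What you can differentiate for a Poisson is log_prob(), which is the ingredient of the score-function estimator the documentation describes:

```python
import torch
from torch.distributions import Normal, Poisson

# Poisson: no reparameterized (pathwise) sampler
rate = torch.tensor([2.0], requires_grad=True)
pois = Poisson(rate)
print(pois.has_rsample)  # False

# Score-function ingredient: log_prob() of a sample is differentiable w.r.t. the rate.
# For a Poisson, d/d(rate) log p(k) = k / rate - 1.
k = pois.sample()                  # sample() itself carries no gradient
pois.log_prob(k).sum().backward()
print(rate.grad, k / 2.0 - 1.0)    # the two agree

# Normal, for comparison: rsample() = loc + scale * noise, so it is differentiable
loc = torch.tensor([2.0], requires_grad=True)
Normal(loc, 1.0).rsample().sum().backward()
print(loc.grad)  # tensor([1.])
```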

Second is the fact that torch.poisson() returns values that are
discrete integers (even though they are packaged as floating-point
numbers). Since the result of torch.poisson() is stuck at an integer,
it can’t change by “just a little bit,” so it makes sense to define its
derivative as zero.
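For example (a small sketch), the samples come back as float32, yet every value is a whole number:

```python
import torch

torch.manual_seed(0)
samples = torch.poisson(torch.full((8,), 2.0))
print(samples.dtype)                          # torch.float32
print(torch.equal(samples, samples.round()))  # True: every value is an integer
```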

Best.

K. Frank


Hi Kirk,

Thanks so much for the great reply! Is there an alternative you could suggest instead?
I really need to use a Poisson distribution for my model.

Thanks,
Charlie

Hi Charlie!

The best I can think of is:

Find a continuous approximation to the Poisson distribution that
adequately meets your needs. (Search for “cumulative Poisson
distribution.”) This won’t be exactly the Poisson distribution, so
you won’t get exactly the same results as you would have with the
Poisson distribution.
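One possible continuous approximation (an assumption on my part, not necessarily the one meant above): the Poisson CDF at an integer k equals the regularized upper incomplete gamma function Q(k + 1, a), which PyTorch provides as torch.special.gammaincc. Evaluating Q(x + 1, a) at non-integer x gives a smooth, increasing function that agrees with the Poisson CDF at every integer. This sketch checks that agreement against the summed Poisson pmf; the helper name is hypothetical:

```python
import torch
from torch.distributions import Poisson

def continuous_poisson_cdf(x, a):
    # Hypothetical helper: F(x, a) = Q(x + 1, a) for x > -1.
    # At integer x = k this is exactly P(X <= k) for X ~ Poisson(a).
    return torch.special.gammaincc(x + 1.0, a)

a = torch.tensor(2.0, dtype=torch.float64)
k = torch.arange(0, 8, dtype=torch.float64)

pmf = Poisson(a).log_prob(k).exp()
cdf_from_pmf = torch.cumsum(pmf, dim=0)
cdf_smooth = continuous_poisson_cdf(k, a)
print(torch.allclose(cdf_from_pmf, cdf_smooth))  # True

# Between integers the smooth version interpolates instead of jumping
print(continuous_poisson_cdf(torch.tensor([0.5, 1.5], dtype=torch.float64), a))
```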

Implement the inverse cumulative probability distribution for the
approximation you chose. Let’s call this function f (u, a), where
u is understood to be a random value sampled uniformly between
zero and one, and a is your Poisson-distribution parameter
(sometimes called the Poisson “rate”). Holding a fixed, for a uniform
deviate u, f (u, a) is a random deviate sampled from your
(approximate) Poisson distribution.
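A minimal sketch of f(u, a), assuming the continuous stand-in F(x, a) = Q(x + 1, a) built from torch.special.gammaincc (my choice of approximation; a different approximation would need a different inverse). This F is increasing in x but has no closed-form inverse, so the sketch inverts it numerically by bisection. The function names and the upper search bound are illustrative assumptions:

```python
import torch

def continuous_poisson_cdf(x, a):
    # F(x, a) = Q(x + 1, a): a continuous stand-in for the Poisson CDF (x > -1)
    return torch.special.gammaincc(x + 1.0, a)

def f(u, a, iters=60):
    """Hypothetical inverse CDF: returns x with F(x, a) = u, found by bisection.

    F is increasing in x, so bisection converges. No gradient flows through
    this function; providing one is the job of a custom autograd function.
    """
    a = torch.as_tensor(a, dtype=u.dtype).detach().expand_as(u)
    lo = torch.full_like(u, -1.0)       # F -> 0 as x -> -1
    hi = a + 10.0 * a.sqrt() + 10.0     # assumed generous upper bound
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        below = continuous_poisson_cdf(mid, a) < u
        lo = torch.where(below, mid, lo)
        hi = torch.where(below, hi, mid)
    return 0.5 * (lo + hi)

u = torch.rand(5, dtype=torch.float64)  # uniform deviates
print(f(u, 2.0))                         # approximate-Poisson deviates with rate 2
```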

Package f() as a pytorch “custom autograd function” class. It will
have a forward() function that calculates f (u, a). u will be a
tensor with requires_grad = False that consists of one or more
uniform deviates, while a will be a tensor that wraps the rate
parameter and has requires_grad = True.

Now the important part: Your custom autograd function will also have
a backward() function that calculates the derivative of f (u, a) with
respect to a. (You have to work out the math for this and implement
it.)
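Here is one way such a custom autograd function could look, using the continuous stand-in F(x, a) = Q(x + 1, a) built from torch.special.gammaincc (an assumption; another approximation would change F but not the structure). The math for backward() comes from differentiating F(f(u, a), a) = u implicitly, which gives df/da = -(dF/da) / (dF/dx). dF/da has a closed form, but PyTorch does not provide the derivative of gammaincc with respect to its first argument, so this sketch approximates dF/dx with a finite difference. The class name, the toy loss, and the learning rate are hypothetical choices for illustration:

```python
import torch

def cdf(x, a):
    # F(x, a) = Q(x + 1, a), a continuous stand-in for the Poisson CDF (x > -1)
    return torch.special.gammaincc(x + 1.0, a)

def inverse_cdf(u, a, iters=60):
    # Bisection for x with cdf(x, a) = u (autograd is off inside forward())
    a = a.expand_as(u)
    lo = torch.full_like(u, -1.0)
    hi = a + 10.0 * a.sqrt() + 10.0
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        below = cdf(mid, a) < u
        lo = torch.where(below, mid, lo)
        hi = torch.where(below, hi, mid)
    return 0.5 * (lo + hi)

class ContinuousPoisson(torch.autograd.Function):
    """Hypothetical f(u, a): approximate-Poisson deviates with a gradient w.r.t. a."""

    @staticmethod
    def forward(ctx, u, a):
        x = inverse_cdf(u, a)
        ctx.save_for_backward(x, a)
        return x

    @staticmethod
    def backward(ctx, grad_out):
        x, a = ctx.saved_tensors
        s = x + 1.0
        # Implicit differentiation of cdf(x, a) = u:  dx/da = -(dF/da) / (dF/dx)
        # Closed form: d/da Q(s, a) = -a^(s-1) e^(-a) / Gamma(s)
        dF_da = -torch.exp((s - 1.0) * torch.log(a) - a - torch.lgamma(s))
        # d/ds Q(s, a) is not provided by torch: finite difference, kept in x > -1
        h = 1e-5
        x_lo = torch.clamp(x - h, min=-1.0 + 1e-12)
        dF_dx = (cdf(x + h, a) - cdf(x_lo, a)) / (x + h - x_lo)
        dx_da = -dF_da / dF_dx
        # Assumes a single rate shared by all samples: sum their contributions
        grad_a = (grad_out * dx_da).sum().reshape(a.shape)
        return None, grad_a  # no gradient for the uniform deviates u

torch.manual_seed(0)
a = torch.nn.Parameter(torch.tensor([0.2], dtype=torch.float64))
y = torch.tensor(4.0, dtype=torch.float64)  # observed count, as in the question
optimizer = torch.optim.SGD([a], lr=0.05)

for step in range(150):
    u = torch.rand(256, dtype=torch.float64).clamp(1e-6, 1 - 1e-6)
    x = ContinuousPoisson.apply(u, a)
    loss = (x.mean() - y) ** 2  # toy loss: match the average deviate to y
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(a)  # should settle somewhere between 4 and 5
```

You can compare backward() against numerical differentiation with torch.autograd.gradcheck (in double precision). The fitted a lands a bit above 4 rather than at exactly 4 because, with this particular approximation, a sample falls in (k - 1, k] exactly when the true Poisson draw would be k, so its mean sits slightly below a. That is one concrete way the approximation does not give exactly the Poisson results.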

You can now feed random approximate-Poisson deviates into the
calculations that lead to your loss function, and then use pytorch’s
autograd system to backpropagate and optimize your loss with respect
to your rate parameter a (as well as any other parameters you like).

Of course, any single random deviate is “noisy,” so your loss could
also be noisy, jump around, and not optimize well. But if any given
optimization step involves averaging over many deviates, that noise
will smooth out, and you will be able to optimize in a sensible and
mathematically-legitimate way with respect to a.
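A rough illustration of that smoothing, using plain torch.poisson since no gradient is needed to see it: the spread of a batch average shrinks roughly like sqrt(a / N) as the batch size N grows. The batch sizes and the number of trials are arbitrary choices:

```python
import torch

torch.manual_seed(0)
rate = 2.0

def batch_mean_spread(batch_size, trials=2000):
    # Draw `trials` independent batches and average each one, as a loss
    # built from `batch_size` deviates would; return the spread of those averages.
    draws = torch.poisson(torch.full((trials, batch_size), rate))
    return draws.mean(dim=1).std().item()

for n in (1, 10, 100, 1000):
    # Roughly sqrt(rate / n): about 1.41, 0.45, 0.14, 0.045
    print(n, batch_mean_spread(n))
```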

Best.

K. Frank