Gradients being returned as zeros in custom likelihood function

I have created a likelihood function (given below) in which I manually load vectors of candidate parameters into a neural network (self.model in the code) and return the log-likelihood.

def log_likelihood(self, x):
    ll = torch.zeros(len(x))

    for i in range(len(x)):
        new_params = x[i, :]
        index = 0
        for p in self.model.parameters():
            size = p.numel()
            p.data = new_params[index:index+size].view(p.shape)
            index += size

        logits = self.model(self.X)
        labels = self.Y 
        crit = nn.CrossEntropyLoss(reduction='sum')
        ll[i] = - crit(logits, labels)

    return ll 

I am also trying to calculate the gradient of the log-likelihood with respect to the parameter vectors I pass in, using torch.autograd.functional.jacobian():

def log_prob_grad(x):
    grad = torch.autograd.functional.jacobian(log_likelihood, x, vectorize=True)
    return grad

This gradient function returns a tensor full of zeros. Any help on this problem would be much appreciated.

Hi Armal!

Two related comments:

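You shouldn’t be using .data. It is deprecated (in the public API) and can lead to errors. In particular, assigning to p.data is not recorded by autograd, so the loss is never connected to x in the graph, which is why the Jacobian comes back as all zeros.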

PyTorch does not support modifying Parameters while gradients are being tracked.

You should probably do something like:

index = 0
for p in self.model.parameters():
    p.grad = None
    size = p.numel()
    with torch.no_grad():
        # load this slice of new_params into the Parameter without tracking it
        p.copy_(new_params[index:index + size].view(p.shape))
    index += size

logits = self.model(self.X)
labels = self.Y
crit = nn.CrossEntropyLoss(reduction='sum')
loss = -crit(logits, labels)
loss.backward()   # this will populate the grads of model.parameters()

After backward(), the .grad attributes of model.parameters() hold the gradient of loss with respect to new_params (that is, one row of x); they just won’t be attached directly to new_params (or x).
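A minimal self-contained sketch of the approach above, assuming a small hypothetical nn.Sequential model and random data in place of self.model, self.X and self.Y: each row of x is copied into the parameters under torch.no_grad(), backward() fills the .grad fields, and the gradients are flattened back into a vector with the same layout as that row.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins for self.model, self.X and self.Y
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
X = torch.randn(10, 4)
Y = torch.randint(0, 3, (10,))
crit = nn.CrossEntropyLoss(reduction='sum')

n_params = sum(p.numel() for p in model.parameters())
x = torch.randn(5, n_params)  # 5 candidate parameter vectors, one per row


def log_likelihood_and_grad(x):
    lls = torch.zeros(len(x))
    grads = torch.zeros_like(x)
    for i in range(len(x)):
        index = 0
        for p in model.parameters():
            p.grad = None
            size = p.numel()
            with torch.no_grad():
                # load this row's slice into the Parameter without tracking it
                p.copy_(x[i, index:index + size].view(p.shape))
            index += size

        loss = -crit(model(X), Y)
        loss.backward()  # populates p.grad for every parameter
        lls[i] = loss.detach()
        # flatten the per-parameter gradients in the same order as x[i]
        grads[i] = torch.cat([p.grad.reshape(-1) for p in model.parameters()])
    return lls, grads


lls, grads = log_likelihood_and_grad(x)
print(grads.shape)  # torch.Size([5, 67])
```

This works, but it runs one forward/backward pass per row of x and mutates the model in place, which is what the functional approaches discussed below avoid.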

If you really want to do this using functional.jacobian(), you will likely be better off implementing your model with purely functional calls (e.g., torch.nn.functional.linear() instead of torch.nn.Linear). This way you can pass the parameter values in your x into the functional calls, rather than work around the fact that PyTorch doesn’t want you to load new values into Parameters while gradients are being tracked.
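A sketch of the purely functional route, assuming a hypothetical one-hidden-layer network whose weights are sliced out of a flat parameter vector. Because the forward pass is built from torch.nn.functional.linear() applied to views of x, autograd can trace the log-likelihood back to x and functional.jacobian() returns non-zero values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical layer sizes and data
D_IN, D_HID, D_OUT = 4, 8, 3
X = torch.randn(10, D_IN)
Y = torch.randint(0, D_OUT, (10,))
SHAPES = [(D_HID, D_IN), (D_HID,), (D_OUT, D_HID), (D_OUT,)]
N_PARAMS = sum(torch.Size(s).numel() for s in SHAPES)


def unflatten(theta):
    """Slice a flat vector into W1, b1, W2, b2 (views keep the autograd link)."""
    out, index = [], 0
    for shape in SHAPES:
        size = torch.Size(shape).numel()
        out.append(theta[index:index + size].view(shape))
        index += size
    return out


def log_likelihood(x):
    ll = []
    for theta in x:  # one candidate parameter vector per row
        W1, b1, W2, b2 = unflatten(theta)
        logits = F.linear(torch.tanh(F.linear(X, W1, b1)), W2, b2)
        ll.append(-F.cross_entropy(logits, Y, reduction='sum'))
    return torch.stack(ll)


x = torch.randn(5, N_PARAMS)
J = torch.autograd.functional.jacobian(log_likelihood, x)
print(J.shape)  # torch.Size([5, 5, 67])
```

Building the output with torch.stack() rather than writing into a preallocated tensor is a stylistic choice here; the essential change from the original code is that the parameters are views of x instead of values copied into Parameters.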

Best.

K. Frank


Hi Armal (and @KFrank)

Another way (expanding upon @KFrank’s original answer) would be to use the torch.func namespace to compute derivatives with respect to the parameters in a functional manner.

For example,

import torch
from torch import nn
from torch.func import functional_call, jacrev, vmap

net = Model(*args, **kwargs)

crit = nn.CrossEntropyLoss(reduction='sum')

def fcall_model(params, inputs):  # functionalized call (params is now an input)
    return functional_call(net, params, inputs)

def log_likelihood(params, inputs, labels):
    logits = fcall_model(params, inputs)
    ll_loss = -1.0 * crit(logits, labels)
    return ll_loss

# vmap over the batch dimension of inputs and labels; params are shared (None)
grads_ll_wrt_params = vmap(jacrev(log_likelihood, argnums=0), in_dims=(None, 0, 0))
# You could also use forward-mode AD with jacfwd instead.

Also, I believe torch.autograd.functional.jacobian also computes the derivatives between different samples in the batch, and those cross-sample terms should be zero by definition (unless your network has normalization layers such as BatchNorm).
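To make the shape of the result concrete: for a function mapping an (N, P) batch to N outputs, functional.jacobian() returns an (N, N, P) tensor whose off-diagonal blocks J[i, j] (i != j) are the cross-sample derivatives. The toy function below is a hypothetical stand-in (the sum of squares of each row); it shows that those blocks are zero and how torch.diagonal() recovers the per-sample gradients.

```python
import torch


def f(x):
    # each output depends only on its own row of x
    return (x ** 2).sum(dim=1)


x = torch.arange(6.0).reshape(2, 3)
J = torch.autograd.functional.jacobian(f, x)  # shape (N, N, P) = (2, 2, 3)

# take the diagonal over the two batch dims -> shape (P, N), then transpose to (N, P)
per_sample = torch.diagonal(J, dim1=0, dim2=1).T
print(per_sample)  # equals 2 * x
```

For large N this wastes memory on the zero blocks, which is one reason the vmap-based version above scales better.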


Hi Alpha!

Thanks for that. It looks like functional_call() is the right tool for
painlessly turning a “Module” model that contains its parameters into
a functional model into which you can pass the parameters.
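A sketch of how functional_call() could cover the original question directly, assuming a hypothetical nn.Sequential model and random data: each flat row of x is split into a name-to-tensor dict, and torch.func.grad combined with vmap gives the log-likelihood gradient for every row at once, in the same layout as x.

```python
import torch
from torch import nn
from torch.func import functional_call, grad, vmap

torch.manual_seed(0)

# Hypothetical model and data
net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
X = torch.randn(10, 4)
Y = torch.randint(0, 3, (10,))
crit = nn.CrossEntropyLoss(reduction='sum')

names = [n for n, _ in net.named_parameters()]
shapes = [p.shape for _, p in net.named_parameters()]
sizes = [p.numel() for _, p in net.named_parameters()]


def log_likelihood(theta):
    # split the flat vector into named tensors, in named_parameters() order
    chunks = torch.split(theta, sizes)
    params = {n: c.reshape(s) for n, c, s in zip(names, chunks, shapes)}
    logits = functional_call(net, params, (X,))
    return -crit(logits, Y)


x = torch.randn(5, sum(sizes))
lls = vmap(log_likelihood)(x)          # shape (5,)
grads = vmap(grad(log_likelihood))(x)  # shape (5, 67), same layout as x
```

Here vmap maps over parameter vectors with the data shared, which matches the original log_likelihood(x); the torch.func example earlier in the thread instead maps over data samples with the parameters shared.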

Best.

K. Frank


Hi @KFrank,

Also, I forgot to mention that the parameters passed to functional_call should be a dictionary mapping parameter names to tensors, rather than just a list.

For example,

net = Model(*args, **kwargs)

params = dict(net.named_parameters())  # changed from net.parameters()
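Putting the two posts together, a runnable sketch of the per-sample-gradient version, assuming a hypothetical nn.Sequential in place of Model(*args, **kwargs). The result is a dict with one gradient per sample for each named parameter. The unsqueeze(0) calls restore a batch dimension of size 1 inside vmap; that is a precaution added here, not part of the original snippet.

```python
import torch
from torch import nn
from torch.func import functional_call, jacrev, vmap

torch.manual_seed(0)

# Hypothetical model standing in for Model(*args, **kwargs)
net = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
crit = nn.CrossEntropyLoss(reduction='sum')

# name -> tensor dict; detached so the results carry no autograd history
params = {k: v.detach() for k, v in net.named_parameters()}


def log_likelihood(params, inputs, labels):
    # vmap strips the batch dim, so add a batch of size 1 back for the model and loss
    logits = functional_call(net, params, (inputs.unsqueeze(0),))
    return -crit(logits, labels.unsqueeze(0))


grads_ll_wrt_params = vmap(jacrev(log_likelihood, argnums=0), in_dims=(None, 0, 0))

inputs = torch.randn(10, 4)
labels = torch.randint(0, 3, (10,))
per_sample = grads_ll_wrt_params(params, inputs, labels)
for name, g in per_sample.items():
    print(name, tuple(g.shape))  # e.g. 0.weight (10, 8, 4)
```

With reduction='sum', summing the per-sample gradients over the batch dimension gives the full-batch gradient, which is a handy sanity check.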

Thank you so much for this. All the suggestions from @AlphaBetaGamma96 and @KFrank helped me achieve what I needed, and the vmap function was especially useful for vectorizing the approach.

This forum is worth its weight in gold due to the community.