Loading a state dict seems to erase grad

Hi Y’all,

I am experiencing a somewhat spooky error, in which the simple act of loading a new state dictionary into my model seems to erase the graph tying parameters to an external matrix.

Here is the code, and I will explain below.

    import torch.autograd as ag

    # model, M, stepsize and the M-dependent gradients in `recovered` are defined earlier
    new_state_dict = {}
    for grad, (name, param) in zip(recovered, model.named_parameters()):
        updated_param = param - stepsize * grad  # This is 100% dependent on M
        new_state_dict[name] = updated_param

    model.load_state_dict(new_state_dict)
    l = list(model.parameters())
    test = ag.grad(l[0][0, 0], M, allow_unused=True)  # and somehow this is 100% NOT dependent on M

Essentially, I need to do some manual SGD because of other computations I am doing with M (a matrix that modifies the gradients). Every gradient (stored in “recovered”) depends on M.

Here is the dilemma:
updated_param depends on M.
For any name, the tensor stored in new_state_dict[name] depends on M.
But somehow, that last test shows that none of the parameters in my model depend on M.

Any thoughts on what is going on here would be super helpful!

Hi,

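This is “expected” given how all these functions are supposed to work.
In particular, load_state_dict completely breaks the computational graph: it copies the values from the dictionary into the model’s existing parameters without recording that copy in autograd, so you won’t be able to backpropagate through what it does.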
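A minimal reproduction of this behaviour, assuming a toy nn.Linear model and a hypothetical tensor M standing in for the gradient-modifying matrix (the fake “gradients” are also made up for illustration). The entries of the dict still depend on M, but once they are copied into the model by load_state_dict, that link is gone:

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Linear(3, 2)
# Hypothetical stand-in for the gradient-modifying matrix M.
M = torch.rand(2, 3, requires_grad=True)
stepsize = 0.1

# Fake "gradients" that depend on M, one per parameter (weight, then bias).
grads = [M * 2.0, M.sum(dim=1)]

new_state_dict = {}
for grad, (name, param) in zip(grads, model.named_parameters()):
    new_state_dict[name] = param - stepsize * grad  # still depends on M

# The dict entry is connected to M: d(W - 0.1 * 2 * M)/dM = -0.2 at [0, 0].
(g_dict,) = torch.autograd.grad(new_state_dict["weight"][0, 0], M)
print(g_dict[0, 0])

# After loading, the parameter holds the same values but no longer sees M.
model.load_state_dict(new_state_dict)
(g_model,) = torch.autograd.grad(model.weight[0, 0], M, allow_unused=True)
print(g_model)  # None
```

The values are copied correctly; only the graph connection is lost, which is why the test in the question returns nothing for M.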

The nn.Parameters returned by model.parameters() are special Tensors (leaves of the graph) that are meant to be learned when you use the nn tools.
In your case, you try to hack around that by having nn.Parameters that you don’t actually learn directly, since their values are computed from M.
Unfortunately, you cannot do it this way.

A simple way to do this is to create your own nn.Module that inherits from the original module you wanted and that has M as an nn.Parameter(), but not, for example, the original .weight. Then, inside the forward pass of this module, you first compute the new weight, set it with a simple self.weight = new_value, and then call the parent module’s forward method :slight_smile:

Hi Alban,

Thanks so much for breaking that down for me; I’m glad it’s expected and not a bug. My overall goal is for the graph to continue through the update of the weights/biases in my network by the gradient, because I want to use future losses to update my matrix M.

What I am doing is passing data through the model, getting the gradient, applying ops to it using M, updating the parameters with that transformed gradient, and then repeating. I hope to differentiate the loss computed with these updated parameters (which were updated using gradients modified by M) with respect to M, in order to update M. Will what you suggest enable me to do that?
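A minimal sketch of that loop with plain tensors, assuming a toy linear-regression model and, as a hypothetical choice, that M modifies the gradient by elementwise scaling. The important detail is create_graph=True in torch.autograd.grad, which keeps each inner gradient (and therefore each update) differentiable, so a later loss can be backpropagated all the way to M:

```python
import torch

torch.manual_seed(0)

# Toy data and a single weight matrix (plain tensors, no nn.Module).
x = torch.randn(16, 4)
y = torch.randn(16, 1)
w = torch.randn(4, 1, requires_grad=True)

# Hypothetical gradient modifier M (elementwise scaling in this sketch).
M = torch.ones(4, 1, requires_grad=True)
stepsize = 0.1

def loss_fn(weight):
    return ((x @ weight - y) ** 2).mean()

params = w
for _ in range(3):
    inner_loss = loss_fn(params)
    # create_graph=True makes the gradient itself part of the graph.
    (g,) = torch.autograd.grad(inner_loss, params, create_graph=True)
    params = params - stepsize * (M * g)  # manual SGD step that depends on M

# "Future" loss evaluated with the parameters produced by the updates.
outer_loss = loss_fn(params)
outer_loss.backward()
print(M.grad)  # populated: the loss was differentiated through the updates
```

Because the updated parameters are ordinary tensors rather than values copied into an nn.Module, nothing breaks the chain from M to the final loss.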

You can definitely do that; autograd will have no issue with it. The main issue you’re going to encounter is that the nn tools are not built for that :confused:
For example, the basic nn.Conv2d() is built to have learnable parameters. So if you want these to be intermediate results computed from other tensors rather than learnable parameters, you’ll have to hack around it.

So the best way to implement this will depend heavily on your exact application.
If you have a code sample showing what you’re trying to do, I’ll be happy to take a look at it. My advice for a first draft would be to write a toy example with limited scope that uses as few nn tools as possible. Then you can add them back one by one, checking whether each one breaks your workflow.
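Newer PyTorch releases (2.0 and later) also provide torch.func.functional_call, a different technique from the subclassing approach discussed in this thread: it runs a module’s forward with a dict of tensors substituted for its parameters, and those tensors can stay attached to the graph, so there is no need for load_state_dict. A sketch, assuming a small MLP and a hypothetical per-parameter elementwise scale M as the gradient modifier:

```python
import torch
from torch import nn
from torch.func import functional_call

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 1))
x, y = torch.randn(16, 4), torch.randn(16, 1)

# Hypothetical gradient modifier: one elementwise scale per parameter.
M = {name: torch.ones_like(p, requires_grad=True)
     for name, p in model.named_parameters()}
stepsize = 0.1

def loss_with(params):
    # Run the module's forward using `params` instead of its own parameters.
    out = functional_call(model, params, (x,))
    return ((out - y) ** 2).mean()

params = dict(model.named_parameters())
for _ in range(2):
    grads = torch.autograd.grad(loss_with(params), list(params.values()),
                                create_graph=True)
    # Manual SGD step with M-modified gradients; results stay in the graph.
    params = {name: p - stepsize * M[name] * g
              for (name, p), g in zip(params.items(), grads)}

outer_loss = loss_with(params)
outer_loss.backward()
print({name: m.grad.norm().item() for name, m in M.items()})
```

The model’s own parameters are never overwritten here; the updated values live only in the params dict, which is what keeps the graph intact.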

Damn, that’s good to know. My goal is still to update the params in the net, but via ag.grad() rather than loss.backward()… It sounds like, from what you are saying, even that will not work if I try to use nn.

Luckily, the net I am implementing this on is a super simple feed-forward MLP, so I will try implementing it by hand with Torch tensors and matrix multiplication, and post it here (for people having the same issue in the future, I will detail which parts break my code). Out of curiosity, do you think an elementwise function like nn.Tanh will also break the computational graph, or am I going to need to program my own non-linearities as well?
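On the non-linearity question: nn.Tanh has no parameters of its own and simply applies torch.tanh, which is differentiable, so it does not cut the graph. A quick check, using a hypothetical tensor M in place of the real one:

```python
import torch
from torch import nn

torch.manual_seed(0)

M = torch.rand(3, requires_grad=True)  # hypothetical tensor we want gradients for
x = torch.randn(5, 3)

act = nn.Tanh()  # parameter-free module that applies torch.tanh
out_module = act(x * M).sum()
out_function = torch.tanh(x * M).sum()

(g_module,) = torch.autograd.grad(out_module, M)
(g_function,) = torch.autograd.grad(out_function, M)
print(torch.allclose(g_module, g_function))  # True: both paths reach M
```

Stateless modules like this are safe to use; the problem is specific to parameters whose values are overwritten outside of autograd.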

Hi,
Maybe I’ve been a bit negative.
Here is a code sample showing how you can do this for nn.Linear:

You can use such a module as a drop-in replacement for nn.Linear in your model, and the rest will work as expected! (Avoid running forward from multiple threads on a single copy of this net, since forward temporarily sets self.weight and self.bias on the module :wink: )

    import torch
    from torch import nn


    class MyLinear(nn.Linear):
        def __init__(self, in_size, out_size):
            super().__init__(in_size, out_size)
            self.in_size = in_size
            self.out_size = out_size
            # Remove the original weights and bias
            del self.weight
            del self.bias

            # Add our new parameter
            self.M = nn.Parameter(torch.rand(out_size, 1))

        def forward(self, inp):
            # Repopulate the original weights with our custom version
            # You can do any differentiable op here
            self.weight = self.M.expand(self.out_size, self.in_size) * 3
            self.weight[1, 1] = 0  # Mask entry (1, 1) of weight to see some difference in the gradients
            self.bias = self.M.squeeze(-1)

            result = super().forward(inp)

            # Remove the plain attributes again to avoid side effects on the nn.Module
            del self.weight
            del self.bias

            return result


    in_size = 10
    out_size = 20

    my_layer = MyLinear(in_size, out_size)

    inp = torch.rand(5, in_size)
    out = my_layer(inp)
    out.sum().backward()

    print(my_layer.M.grad)

Hope this helps!
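To confirm that M really receives the gradient, including the effect of the masked entry, the MyLinear example can be checked against a closed form. This sketch repeats a compact version of MyLinear so it runs on its own (using nn.Linear’s in_features/out_features instead of extra attributes). The formula is derived only for this specific loss, out.sum(): since out[b, o] = sum_i inp[b, i] * 3 * M[o] + M[o], except that entry (1, 1) of the weight is zeroed, the gradient is dL/dM[o] = 3 * (sum of the kept inputs) + batch size.

```python
import torch
from torch import nn

torch.manual_seed(0)

class MyLinear(nn.Linear):  # same construction as the example above
    def __init__(self, in_size, out_size):
        super().__init__(in_size, out_size)
        del self.weight
        del self.bias
        self.M = nn.Parameter(torch.rand(out_size, 1))

    def forward(self, inp):
        self.weight = self.M.expand(self.out_features, self.in_features) * 3
        self.weight[1, 1] = 0  # masked entry: no gradient flows through it
        self.bias = self.M.squeeze(-1)
        result = super().forward(inp)
        del self.weight
        del self.bias
        return result

in_size, out_size, batch = 10, 20, 5
layer = MyLinear(in_size, out_size)
inp = torch.rand(batch, in_size)
layer(inp).sum().backward()

# Closed form for L = out.sum(): 3 * (sum of all inputs) + batch for every row...
expected = torch.full((out_size, 1), 3 * inp.sum().item() + batch)
# ...minus the contribution of column 1 for row 1, which was masked.
expected[1, 0] -= 3 * inp[:, 1].sum().item()

print(torch.allclose(layer.M.grad, expected, atol=1e-5))  # True
print([name for name, _ in layer.named_parameters()])  # ['M']
```

The only registered parameter is M, so an optimizer built from layer.parameters() will update M and nothing else.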