Computational Graph continuity when replacing weights of a model

Hi all,

I have a question about keeping the computational graph connected when replacing the weights of a model. In my setup, one network predicts a set of weights for a second neural net, and that second net is not optimized directly. I only need it to compute a loss so that gradients can flow back to the network that predicted the weights. However, when I write the predicted weights into the second model with the .copy_ method, PyTorch throws an error about an in-place modification of leaf nodes. Also, copying only the values of the weights does not carry over the autograd history, so the graph is not completed. Does anyone have a suggestion for how to handle this?
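A minimal sketch of the situation described above may help. The module names `predictor` and `target`, the layer sizes, and the random inputs are all hypothetical and chosen only for illustration. It shows both symptoms: the in-place copy fails while autograd is recording, and copying under `torch.no_grad()` succeeds but leaves the predictor without gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: `target` is the net that is not optimized directly,
# `predictor` is the net that outputs its weights.
target = nn.Linear(4, 2)
predictor = nn.Linear(3, 2 * 4)

z = torch.randn(3)                      # hypothetical conditioning input
predicted_w = predictor(z).view(2, 4)   # same shape as target.weight

# Attempt 1: in-place copy into an nn.Parameter while autograd is enabled.
# target.weight is a leaf tensor that requires grad, so this raises.
try:
    target.weight.copy_(predicted_w)
    inplace_error = None
except RuntimeError as err:
    inplace_error = str(err)

# Attempt 2: copy under no_grad. The values are transferred, but the copy is
# not recorded, so the graph between predictor and loss is cut.
with torch.no_grad():
    target.weight.copy_(predicted_w)

loss = target(torch.randn(5, 4)).sum()
loss.backward()

# Gradients reach target.weight, but never reach the predictor.
predictor_got_grad = predictor.weight.grad is not None
print("in-place copy raised:", inplace_error is not None)
print("predictor received gradients:", predictor_got_grad)
```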


If the predicted weights aren’t really parameters of a model that are meant to be repeatedly updated, then you might want to consider using the torch.nn.functional API (see the torch.nn.functional page in the PyTorch 2.0 documentation) for the forward function/layers of the model that isn’t explicitly optimized, and pass the predicted weights to it as an input.

@eqy thanks for your response. The weights are in fact parameters. However, rather than being updated directly by an optimizer, they are supposed to be predicted by another model. Theoretically there is a complete computational graph, but I am not sure how to handle this in PyTorch.

Yes, they are parameters, but your use case suggests that they should not be treated as the nn.Parameter type (see the Parameter page in the PyTorch 2.0 documentation). Instead, you can pass the weights as inputs via the functional API (see the torch.nn.functional page in the PyTorch 2.0 documentation).
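As a rough sketch of that suggestion, the predicted weights can be treated as ordinary tensors and fed to `torch.nn.functional` layers. Everything specific here is an assumption for illustration: the two-layer target network, the `predictor` module, the helper `target_forward`, and all sizes are hypothetical and not taken from the original poster's code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes for a small two-layer target network.
in_dim, hidden, out_dim = 4, 8, 2
sizes = [hidden * in_dim, hidden, out_dim * hidden, out_dim]

# Hypothetical predictor: maps a conditioning vector to all target weights.
predictor = nn.Linear(3, sum(sizes))

def target_forward(x, flat_weights):
    """Forward pass of the target net with weights passed in as plain tensors."""
    w1, b1, w2, b2 = torch.split(flat_weights, sizes)
    h = F.relu(F.linear(x, w1.view(hidden, in_dim), b1))
    return F.linear(h, w2.view(out_dim, hidden), b2)

z = torch.randn(3)            # hypothetical conditioning input
x = torch.randn(5, in_dim)    # hypothetical batch
y = torch.randn(5, out_dim)   # hypothetical regression targets

flat_weights = predictor(z)   # still attached to the predictor's graph
out = target_forward(x, flat_weights)
loss = F.mse_loss(out, y)
loss.backward()               # gradients flow through the weights to predictor

print(out.shape)
print(predictor.weight.grad is not None)
```

Because no nn.Parameter is overwritten, nothing is modified in place, and an optimizer built over `predictor.parameters()` alone would then update the weight-predicting network.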

Thanks @eqy. Your solution worked. You are right: since those weights are only intermediate values in a computational graph and nothing more, the functional API is the way to go.