Update model weights with another model and add them both to a common computation graph


Let’s say I have model A and model B. I store model A’s parameters in a ParameterDict, compute model B’s new parameters from model A’s parameters, and load those new weights into model B’s state_dict. Is there a way to ensure model A’s parameters (the ones in the ParameterDict) stay in the computation graph, so that a single backward pass updates both model A and model B?

Essentially, it would be great if I can do something like

for (name, param_A), (_, param_B) in zip(model_A.named_parameters(), model_B.named_parameters()):
    param_B = param_B - param_A * constant

such that param_B’s value is updated in place and param_A is added to the computational graph of param_B.

(For anyone familiar with MetaSGD, I want to implement a version of it from scratch).
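For concreteness, here is a minimal sketch of the kind of update I mean. Loading the computed tensors via `load_state_dict` would detach them, so this sketch instead runs model B functionally with `torch.func.functional_call`, which keeps model A’s parameters in the graph. The model shapes, `constant`, and the loss are placeholders just for illustration, and it assumes A and B share the same architecture:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

# Toy models with identical architectures (placeholder shapes)
model_A = nn.Linear(4, 2)
model_B = nn.Linear(4, 2)

constant = 0.1
x = torch.randn(8, 4)

# Build B's new weights as differentiable functions of A's parameters.
new_params = {
    name: param_B - constant * param_A
    for (name, param_B), (_, param_A) in zip(
        model_B.named_parameters(), model_A.named_parameters()
    )
}

# Run B's forward pass with the computed tensors instead of its stored
# parameters, so the graph reaches back to both A and B.
out = functional_call(model_B, new_params, (x,))
loss = out.pow(2).mean()  # placeholder loss
loss.backward()

# Gradients now flow to the original parameters of both models
print(model_A.weight.grad is not None)  # True
print(model_B.weight.grad is not None)  # True
```

This avoids in-place mutation of `param_B.data` (which would break the graph), but I’m not sure it’s the cleanest way to structure it.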


@ptrblck if you get the time, would love to hear your thoughts on this - thank you so much!