Hyper-net cannot change the weights of the backbone-net using state_dict()

Hi,

I want to use a hyper-net to generate the weights of a backbone-net. The following snippet is only a toy example, but it illustrates concisely what I want to do:


import torch
from torch import nn, optim

# our backbone-net
modelV = nn.Linear(1, 1, bias=False)
nn.init.zeros_(modelV.weight)
# our hyper-net
modelW = nn.Linear(1, 1, bias=False)
nn.init.constant_(modelW.weight, 0.1)

optimizer = optim.Adam(modelW.parameters(), lr=1e-3)
loss_func = nn.MSELoss()

# our data
# once trained, the hyper-net should map z = 1 to the weight 2
# the backbone-net should then map X to Y with Y = 2 * X
Z = torch.tensor([1.]).reshape(-1, 1)
X = torch.tensor([-1., 0., 1., 2.]).reshape(-1, 1)
y = torch.tensor([-2., 0., 2., 4.]).reshape(-1, 1)

for ite in range(10):
    hypernet_prediction = modelW(Z)
    # build a backbone-net
    backbone_state_dict = modelV.state_dict()
    backbone_state_dict["weight"] = hypernet_prediction[0].view(backbone_state_dict["weight"].size())
    modelV.load_state_dict(backbone_state_dict)
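    # note: load_state_dict() copies the values, so the link to modelW's graph is lost here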
    # predict with backbone-net
    backbone_prediction = modelV(X)
    # compute loss
    loss = loss_func(backbone_prediction, y)
    print("ite = {}, loss = {}".format(ite, loss.item()))
    # optimize
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

This way of changing the weights is very elegant, but unfortunately it does not work: as mentioned in other threads, load_state_dict() copies the values into the module's existing parameters, so the operation is not tracked by autograd and no gradient ever reaches the hyper-net, whose weights therefore never change.
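
For a toy case like this I can of course bypass the state dict entirely and push the generated weight through the functional API, e.g. torch.nn.functional.linear (a minimal sketch, reusing the setup above and only replacing the training loop):

import torch.nn.functional as F

for ite in range(10):
    # the generated weight stays attached to the hyper-net's graph
    predicted_weight = modelW(Z).view(1, 1)
    # same computation as modelV, but differentiable w.r.t. predicted_weight
    backbone_prediction = F.linear(X, predicted_weight)
    loss = loss_func(backbone_prediction, y)
    print("ite = {}, loss = {}".format(ite, loss.item()))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

This trains the hyper-net as expected, but it means re-implementing the whole forward pass of the backbone-net with functional calls by hand, which is not realistic for anything bigger than this example.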

What alternatives are there (keeping in mind that the real models will be much, much larger)?
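
The closest thing I have found so far is torch.func.functional_call (in recent PyTorch versions, if I understand correctly), which seems to run a module's forward pass with an externally supplied dict of parameters while keeping autograd intact. A sketch of what I mean, again reusing the setup above:

from torch.func import functional_call

for ite in range(10):
    predicted_weight = modelW(Z).view_as(modelV.weight)
    # run modelV with the generated weight instead of its own parameter;
    # the call is differentiable, so the gradient flows back into modelW
    backbone_prediction = functional_call(modelV, {"weight": predicted_weight}, (X,))
    loss = loss_func(backbone_prediction, y)
    print("ite = {}, loss = {}".format(ite, loss.item()))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Is this the recommended pattern, or is there something better suited when the backbone-net has many parameters?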
Thanks for any help.