Create a model (model A) with a few linear layers that produces X, a list of tensors

Create a model (model B) with a few linear layers during the forward loop of model A (or create model B one time outside of the training loop and just reference it here each time)

Overwrite the parameters of model B with X

idx = 0
for n, p in model_B.named_parameters():
    w = X[idx].reshape(128, 128)
    idx += 1
    p.data = w.float()

Run model B on some input Y given the new parameters

out = model_B(Y)
loss = F.l1_loss(out, ground_truth)

Finally, call loss.backward() in the forward loop of model A and update the parameters of model A

Currently, the gradient of the parameters of model A is None. Maybe the graph is interrupted when I try to replace the parameters of model B with X, as p.data = w.float() is implicitly performed with no_grad()? I'm not sure if this is the issue though.
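Here is a minimal, hypothetical reproduction of what I suspect is happening (all names made up): the .data assignment copies the values into the parameter, but records nothing in the autograd graph, so the gradient never reaches the tensor that produced them:

```python
import torch

a = torch.randn(3, requires_grad=True)   # stands in for one entry of X
lin = torch.nn.Linear(3, 1, bias=False)  # stands in for a model-B layer

# copy a's values into the parameter the same way the loop above does;
# this is performed outside the autograd graph
lin.weight.data = a.reshape(1, 3).float()

out = lin(torch.ones(3)).sum()
out.backward()

print(a.grad)          # None -- the graph back to `a` was cut
print(lin.weight.grad) # populated -- gradients stop at the parameter
```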

Some additional information:

X[0]…X[t] all have requires_grad == True

all of model_B parameters have is_leaf == True and requires_grad == True

if I comment out the code for model B and run something simple such as loss = X[0].mean(), then the gradients of the parameters of model A are valid, so the problem probably isn't with model A itself

Models typically have Parameters and Parameters are typically things
that are optimized.

If I understand your use case, you don't intend to optimize the "parameters"
(weights) of model B, but rather derive ("predict") them using another model,
model A (whose Parameters you do intend to optimize).

There are a lot of ways of packaging this, but I think that writing model B
as a function, rather than structuring it as an "official model" (i.e., Module),
would be the cleanest approach and would communicate more clearly what
you are doing.

The following abbreviated pseudo-code illustrates this approach:

def functional_model_b(input, weights):  # takes an input tensor and a list of four weight tensors
    t = torch.nn.functional.linear(input, weight=weights[0], bias=weights[1])
    t = torch.nn.functional.relu(t)
    t = torch.nn.functional.linear(t, weight=weights[2], bias=weights[3])
    return t
model_b_weights = model_A(some_input)  # returns list of four tensors
pred = functional_model_b(some_other_input, model_b_weights)
loss = criterion(pred, targ)
loss.backward()  # backpropagates through functional_model_b and computes gradients for model_A's parameters
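Fleshing that pseudo-code out into a self-contained, runnable sketch (the sizes and the single-layer model A are made up for illustration; your real model A has a few linear layers and predicts 128 x 128 weights):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# hypothetical sizes for illustration
in_dim, hid = 4, 8
sizes = [hid * in_dim, hid, in_dim * hid, in_dim]  # w0, b0, w1, b1

# model A: predicts all of model B's weights as one flat vector
model_A = torch.nn.Linear(in_dim, sum(sizes))

def functional_model_b(input, weights):
    # model B as a pure function of its (predicted) weights
    t = F.linear(input, weight=weights[0], bias=weights[1])
    t = F.relu(t)
    t = F.linear(t, weight=weights[2], bias=weights[3])
    return t

some_input = torch.randn(in_dim)
some_other_input = torch.randn(in_dim)
ground_truth = torch.randn(in_dim)

flat = model_A(some_input)
pieces = flat.split(sizes)  # still attached to model A's graph
model_b_weights = [
    pieces[0].reshape(hid, in_dim),
    pieces[1],
    pieces[2].reshape(in_dim, hid),
    pieces[3],
]

pred = functional_model_b(some_other_input, model_b_weights)
loss = F.l1_loss(pred, ground_truth)
loss.backward()

# gradients flow through functional_model_b all the way back to model A
print(model_A.weight.grad is not None)  # True
```

Because the predicted weights are passed around as ordinary tensors rather than copied into Parameters via .data, nothing detaches them from model A's graph, and loss.backward() populates model A's gradients as intended.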