Create a model (model A) with a few linear layers that produces X, a list of tensors

Create a model (model B) with a few linear layers during the forward loop of model A (or create model B one time outside of the training loop and just reference it here each time)

Overwrite the parameters of model B with X

idx = 0
for n, p in model_B.named_parameters():
    w = X[idx].reshape(128, 128)
    idx += 1
    p.data = w.float()

Run model B on some input Y given the new parameters

out = model_B(Y)
loss = F.l1_loss(out, ground_truth)

Finally, call loss.backward() in the forward loop of model A and update the parameters of model A

Currently, the gradient of the parameters of model A is None. Maybe the graph is interrupted when I try to replace the parameters of model B with X, as p.data = w.float() is implicitly performed with no_grad()? I'm not sure if this is the issue though.
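Here is a minimal, hypothetical reproduction of what I suspect is happening (all names made up): the .data assignment copies the values into the parameter, but records nothing in the autograd graph, so the gradient never reaches the tensor that produced them:

```python
import torch

a = torch.randn(3, requires_grad=True)   # stands in for one entry of X
lin = torch.nn.Linear(3, 1, bias=False)  # stands in for a model-B layer

# copy a's values into the parameter the same way the loop above does;
# this is performed outside the autograd graph
lin.weight.data = a.reshape(1, 3).float()

out = lin(torch.ones(3)).sum()
out.backward()

print(a.grad)          # None -- the graph back to `a` was cut
print(lin.weight.grad) # populated -- gradients stop at the parameter
```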

Some additional information:

X[0]…X[t] all have requires_grad == True

all of model_B parameters have is_leaf == True and requires_grad == True

if I comment out the code for model B and run something simple such as loss = X[0].mean(), then the gradients of the parameters of model A are valid, so the problem probably isn't with model A itself

Models typically have Parameters and Parameters are typically things
that are optimized.

If I understand your use case, you don't intend to optimize the "parameters"
(weights) of model B, but rather derive ("predict") them using another model,
model A (whose Parameters you do intend to optimize).

There are a lot of ways of packaging this, but I think that writing model B
as a function, rather than structuring it as an "official model" (i.e., Module),
would be the cleanest approach and would communicate more clearly what
you are doing.

The following abbreviated pseudo-code illustrates this approach:

def functional_model_b(input, weights):  # takes an input tensor and a list of four weight tensors
    t = torch.nn.functional.linear(input, weight=weights[0], bias=weights[1])
    t = torch.nn.functional.relu(t)
    t = torch.nn.functional.linear(t, weight=weights[2], bias=weights[3])
    return t
model_b_weights = model_A(some_input)  # returns list of four tensors
pred = functional_model_b(some_other_input, model_b_weights)
loss = criterion(pred, targ)
loss.backward()  # backpropagates through functional_model_b and computes gradients for model_A's parameters
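Fleshing that pseudo-code out into a self-contained, runnable sketch (the sizes and the single-layer model A are made up for illustration; your real model A has a few linear layers and predicts 128 x 128 weights):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# hypothetical sizes for illustration
in_dim, hid = 4, 8
sizes = [hid * in_dim, hid, in_dim * hid, in_dim]  # w0, b0, w1, b1

# model A: predicts all of model B's weights as one flat vector
model_A = torch.nn.Linear(in_dim, sum(sizes))

def functional_model_b(input, weights):
    # model B as a pure function of its (predicted) weights
    t = F.linear(input, weight=weights[0], bias=weights[1])
    t = F.relu(t)
    t = F.linear(t, weight=weights[2], bias=weights[3])
    return t

some_input = torch.randn(in_dim)
some_other_input = torch.randn(in_dim)
ground_truth = torch.randn(in_dim)

flat = model_A(some_input)
pieces = flat.split(sizes)  # still attached to model A's graph
model_b_weights = [
    pieces[0].reshape(hid, in_dim),
    pieces[1],
    pieces[2].reshape(in_dim, hid),
    pieces[3],
]

pred = functional_model_b(some_other_input, model_b_weights)
loss = F.l1_loss(pred, ground_truth)
loss.backward()

# gradients flow through functional_model_b all the way back to model A
print(model_A.weight.grad is not None)  # True
```

Because the predicted weights are passed around as ordinary tensors rather than copied into Parameters via .data, nothing detaches them from model A's graph, and loss.backward() populates model A's gradients as intended.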