Calculate gradients using chain rule

I am new to PyTorch and am implementing the following pipeline:

  1. training of first model (model1) on a dataset (d1)
  2. generation of new dataset (d2) using model1 on unlabeled examples
  3. training of model2 on the new dataset d2
  4. estimating A (a weight applied to model1's loss) by reducing the validation loss of model2 on the validation set of d1

I have produced a distilled version of my use case:
import torch
import torch.nn as nn

ip1= torch.rand(1,1,3)
ip2 =  torch.rand(1,1,3)
ip3 = torch.rand(1,1,3)

cnn1d_1 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, stride=1)
cnn1d_2 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=1, stride=2)
A=torch.rand(1, requires_grad=True)

optimizer1 = torch.optim.Adam(cnn1d_1.parameters(), lr=0.001)
optimizer2 = torch.optim.Adam(cnn1d_2.parameters(), lr=0.001)
optimizer3 = torch.optim.SGD([A], lr=0.001)

# step 1
output=cnn1d_1(ip1) #forward pass through first model of sample ip1
loss1=output.sum()*A # calculating some loss * A
loss1.backward(retain_graph=True, create_graph=True) #should calculate gradients for model1's parameters and A
optimizer1.step() #update model1's params

#step 2
inp=cnn1d_1(ip2) #forward pass of unlabeled example ip2 through model1
out=cnn1d_2(inp) #output of previous step fed into model 2
loss2=out.sum() #calculating some loss.

loss2.backward(retain_graph=True, create_graph=True)  # should calculate gradients for model 1 and model 2's parameters.
optimizer2.step() #update model2's params

#step 3
out_new=cnn1d_2(ip3) # fake sample from validation set
loss3=out_new.sum() #calculating some loss.

loss3.backward(retain_graph=True, create_graph=True) # want to calculate the gradients of loss 3 wrt A
optimizer3.step() # update A based on calculated gradients.

So my aim is to calculate the gradients of loss3 wrt A and update A based on them. If I am not wrong:
→ model1’s params are updated using loss1, which in turn depends on A.
→ model2’s parameters depend on model1’s params, since ‘inp’ (generated using model1) is used to train model2.
→ Hence I believe that on calling loss3.backward(), A’s gradients should be updated automatically via the chain rule through the dependencies above.
But when I print A’s gradients after loss3.backward(), the values don’t change. Why is that? Please correct me if I am wrong.

I think the issue becomes clearer if we write down what each function depends on:
output1 = model1(model1_params, input1)
loss1 = loss1_fn(output1, A)
output2 = model2(model2_params, model1(model1_params, input2))
loss2 = loss2_fn(output2)
output3 = model2(model2_params, input3)
loss3 = loss3_fn(output3)
Note that while A influences how model1 is optimized, it does not appear explicitly in the function that computes loss3. The optimizer's step() updates the parameters in place, outside the autograd graph, so no recorded path leads from loss3 back to A. So I don’t think its gradient should be updated on loss3.backward(). For example, we would also not expect the hyperparameters of optimizer1, such as the learning rate, to have a gradient computed, even though model1’s params also “depend” on them.
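To make this concrete, here is a small check (a sketch, not the original poster's code; the names model1, model2 and opt1, the SGD optimizer and the learning rate are my own choices for illustration). It asks autograd directly for the gradient of loss3 with respect to A after an ordinary optimizer step:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the two models in the question.
model1 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3)
model2 = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=1)
A = torch.rand(1, requires_grad=True)
opt1 = torch.optim.SGD(model1.parameters(), lr=0.1)

# Step 1: loss1 depends on A, so A receives a gradient here.
loss1 = model1(torch.rand(1, 1, 3)).sum() * A
loss1.backward()
opt1.step()  # in-place update, not recorded in any autograd graph

# The updated weight is still a leaf tensor with no history.
print(model1.weight.is_leaf, model1.weight.grad_fn)

# Step 3: loss3 is built only from model2's parameters and the input.
loss3 = model2(torch.rand(1, 1, 3)).sum()
(grad_A,) = torch.autograd.grad(loss3, A, allow_unused=True)
print(grad_A)  # None: A is not part of loss3's graph
```

With allow_unused=True, torch.autograd.grad returns None for an input that the output does not depend on, instead of raising an error. That is the same situation as loss3.backward() leaving A.grad unchanged.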

Thanks for replying, @eqy, and for putting the problem so lucidly. I was thinking along the same lines. Do you have any idea how I can find the gradients of loss3 wrt A?

In the end, I used torch.autograd.grad(), passing it the desired parameters, to compute the second-order derivatives and from them the gradients I needed. The solution is working for now.
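The code for that solution was not posted, so the following is only a sketch of one way it could look, under assumptions of my own: the optimizer steps are replaced by hand-written plain SGD updates, the two models are reduced to bias-free convolutions with explicit weight tensors (w1, w2), stride is left at 1, and lr is an arbitrary value. The point is that each update is kept inside the autograd graph via create_graph=True, so loss3 really does become a function of A:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
ip1, ip2, ip3 = (torch.rand(1, 1, 3) for _ in range(3))

# Explicit weights standing in for cnn1d_1 and cnn1d_2 (no bias, assumed).
w1 = torch.rand(1, 1, 3, requires_grad=True)  # kernel_size=3
w2 = torch.rand(1, 1, 1, requires_grad=True)  # kernel_size=1
A = torch.rand(1, requires_grad=True)
lr = 0.1  # assumed learning rate for the hand-written SGD steps

# Step 1: differentiable SGD step on model1; w1_new is a function of A.
loss1 = F.conv1d(ip1, w1).sum() * A
(g1,) = torch.autograd.grad(loss1, w1, create_graph=True)
w1_new = w1 - lr * g1

# Step 2: train model2 on the updated model1's output, keeping the graph.
loss2 = F.conv1d(F.conv1d(ip2, w1_new), w2).sum()
(g2,) = torch.autograd.grad(loss2, w2, create_graph=True)
w2_new = w2 - lr * g2

# Step 3: validation loss of the updated model2, now connected to A.
loss3 = F.conv1d(ip3, w2_new).sum()
(grad_A,) = torch.autograd.grad(loss3, A)
print(grad_A)

# Apply the update to A by hand (plain gradient descent).
with torch.no_grad():
    A_new = A - lr * grad_A
```

For this toy setup the result can be checked by hand: g1 = A * ip1, g2 = sum(w1_new * ip2), and so d(loss3)/dA = lr² · sum(ip3) · sum(ip1 * ip2). With optimizers such as Adam the update rule is more involved, and the same idea would need the update written out (or a library that provides differentiable optimizers).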