How can I optimize non-leaf tensors?

Hey there,
I’m trying to implement an algorithm that should work as follows:

Given: inputA, inputB, network1, network2

  1. network1 takes inputA and generates network2 (a list of tensors used as weights and biases)
  2. network2 takes inputB and we do a gradient-descent step in which only network2 is updated
  3. network2 takes inputB again and we do a gradient-descent step in which only network1 is updated (rough code sketch after this list)

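Roughly, the setup looks like this. All names, sizes, and shapes here are placeholders I made up for illustration, not my real models:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Placeholder sizes: network2 is a 4 -> 8 -> 2 MLP.
IN2, HID2, OUT2 = 4, 8, 2
N_PARAMS = IN2 * HID2 + HID2 + HID2 * OUT2 + OUT2

# network1 is a hypernetwork mapping inputA to network2's flat parameters.
network1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, N_PARAMS))

def split_params(flat):
    # Slice network1's flat output into network2's weight/bias tensors.
    i = 0
    w1 = flat[i:i + IN2 * HID2].view(HID2, IN2); i += IN2 * HID2
    b1 = flat[i:i + HID2];                       i += HID2
    w2 = flat[i:i + HID2 * OUT2].view(OUT2, HID2); i += HID2 * OUT2
    b2 = flat[i:i + OUT2]
    return [w1, b1, w2, b2]

def network2_forward(params, x):
    # Run network2 "functionally", with parameters supplied from outside.
    w1, b1, w2, b2 = params
    return F.linear(F.relu(F.linear(x, w1, b1)), w2, b2)

inputA = torch.randn(3)
params2 = split_params(network1(inputA))  # generated by network1 -> non-leaf
print(params2[0].is_leaf)                 # False
```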
Since PyTorch optimizers update the weights inside a no_grad() block, I was planning to use higher to build a differentiable optimizer out of a standard PyTorch one.
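For reference, this is the generic higher pattern I was planning to adapt (toy model, not my actual setup):

```python
import torch
import torch.nn as nn
import higher

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(8, 4), torch.randn(8, 2)
with higher.innerloop_ctx(model, opt) as (fmodel, diffopt):
    loss = nn.functional.mse_loss(fmodel(x), y)
    diffopt.step(loss)  # differentiable update, no no_grad() involved
```

Note that this still requires building `opt` from `model.parameters()`, i.e. from leaf tensors, which is exactly what I can't do for network2.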

The only problem is that I can’t create a PyTorch optimizer initialized with network2’s weights, since they are not leaf tensors.
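Concretely, continuing the sketch above, this is where I get stuck (the error text below is what recent PyTorch versions raise for me):

```python
# params2 are outputs of network1's graph, not leaves, so this fails:
opt2 = torch.optim.SGD(params2, lr=0.1)
# -> ValueError: can't optimize a non-leaf Tensor
```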

Would implementing SGD by hand, like the sketch just below, solve the problem? Is there a simpler solution?
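This is the kind of hand-rolled differentiable SGD I have in mind, continuing the sketch above (names and shapes still placeholders):

```python
inputB, target = torch.randn(IN2), torch.randn(OUT2)
lr = 0.1

# Step 2: update the generated weights of network2; create_graph=True
# keeps the update itself differentiable w.r.t. network1.
loss_inner = F.mse_loss(network2_forward(params2, inputB), target)
grads = torch.autograd.grad(loss_inner, params2, create_graph=True)
params2 = [p - lr * g for p, g in zip(params2, grads)]

# Step 3: backprop through the updated weights into network1 only.
loss_outer = F.mse_loss(network2_forward(params2, inputB), target)
opt1 = torch.optim.SGD(network1.parameters(), lr=0.01)
opt1.zero_grad()
loss_outer.backward()
opt1.step()
```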

If I wasn’t clear enough in explaining the issue, please ask for further information.

Thank you in advance.

:cry: