How to take the gradient of the updated A w.r.t. B, where A is updated using B

Imagine the following example:

import torch
import torch.nn as nn

A = nn.Linear(3, 4)
B = torch.ones(3).requires_grad_()

loss1 = A(B).sum()
grad1 = torch.autograd.grad(loss1, A.parameters(), create_graph=True)

# ... then update A's parameters using grad1 ...

loss2 = F(A)  # F is some scalar loss computed from the updated A

My question is: how should I update A so that I can then compute the gradient of loss2 w.r.t. B?
If both A and B are plain tensors, I know how to do it. But when A is an nn.Module, how should I do it?
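For reference, here is roughly what I mean in the plain-tensor case (a minimal sketch; the learning rate and the second loss are just illustrative stand-ins):

import torch

W = torch.randn(4, 3, requires_grad=True)   # W plays the role of A
B = torch.ones(3, requires_grad=True)

loss1 = (W @ B).sum()
grad_W, = torch.autograd.grad(loss1, W, create_graph=True)

W_new = W - 0.1 * grad_W            # differentiable update; W_new is a function of B
loss2 = W_new.pow(2).sum()          # stand-in for F(A)
grad_B, = torch.autograd.grad(loss2, B)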

It’s not possible to compute the gradient of loss2 if loss2 isn’t a tensor. Do you mean to do something like out = A(B); loss1 = out.sum(); loss2 = F(out)?

Thanks for your reply! Unfortunately that is not the case. What I mean is:

  1. We first update A; the gradient used in this update is a function of B.

  2. After the update, I want to compute a loss on the updated A and then take its derivative w.r.t. B (since the updated A is a function of B).

I want to achieve the above but do not know how to do it when A is an nn.Module.

I have actually found the solution: one can use the higher library for this (GitHub - facebookresearch/higher, a PyTorch library that lets users obtain higher-order gradients over losses spanning training loops rather than individual training steps).
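For anyone who finds this later, here is a rough sketch along the lines of higher's innerloop_ctx API (the optimizer, the learning rate, and the concrete loss2 below are just placeholders for F(A)):

import torch
import torch.nn as nn
import higher

A = nn.Linear(3, 4)
B = torch.ones(3, requires_grad=True)
opt = torch.optim.SGD(A.parameters(), lr=0.1)   # illustrative inner optimizer

with higher.innerloop_ctx(A, opt) as (fmodel, diffopt):
    loss1 = fmodel(B).sum()
    diffopt.step(loss1)          # differentiable update of the (fast) parameters

    # placeholder for F(A): any scalar loss on the updated parameters
    loss2 = sum(p.pow(2).sum() for p in fmodel.parameters())
    grad_B, = torch.autograd.grad(loss2, B)     # gradient flows back through the update to B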