Imagine the following example:
import torch
import torch.nn as nn

A = nn.Linear(3, 4)
B = torch.ones(3).requires_grad_()
loss1 = A(B).sum()
grad1 = torch.autograd.grad(loss1, A.parameters(), create_graph=True)
# then update A with grad1
loss2 = F(A)  # F is a placeholder for some scalar-valued function of the updated A
My question is: how should I update A such that I can calculate the gradient of loss2 w.r.t. B?
If both A and B were tensors, I would know how to do it. But in the case where A is an nn.Module, how should I do it?
It’s not possible to compute the gradient of loss2 if it isn’t a tensor. Do you mean to do something like out = A(B); loss1 = out.sum(); loss2 = F(out)?
Thanks for your reply! Unfortunately that is not the case. What I mean is:
We first update A; the gradient of A's parameters is a function of B.
After the update, I want to calculate a loss of the updated A, but then take the derivative w.r.t. B (since the updated A is a function of B).
I want to achieve the above but do not know how to do it when A is an nn.Module.
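One way to sketch this (a common pattern in meta-learning code) is to keep the updated parameters as plain tensors instead of writing them back into the module, and run the second forward pass with torch.func.functional_call. The learning rate lr, the gradient-descent update, and using .sum() as a stand-in for F are assumptions for illustration:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

A = nn.Linear(3, 4)
B = torch.ones(3).requires_grad_()

# inner loss and differentiable gradients of A's parameters
loss1 = A(B).sum()
grads = torch.autograd.grad(loss1, A.parameters(), create_graph=True)

# build the updated parameters as plain tensors; since create_graph=True,
# they stay connected to B through the autograd graph
lr = 0.1  # assumed inner-loop step size
new_params = {
    name: p - lr * g
    for (name, p), g in zip(A.named_parameters(), grads)
}

# forward pass through A with the updated parameters, without mutating A;
# .sum() stands in for the placeholder F in the question
loss2 = functional_call(A, new_params, (B,)).sum()

# gradient of loss2 w.r.t. B flows through the parameter update
grad_B = torch.autograd.grad(loss2, B)[0]
print(grad_B.shape)
```

The key point is that new_params are ordinary tensors in the graph, so torch.autograd.grad(loss2, B) differentiates through the update step itself.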