bilibili
(Bo Liu)
March 15, 2022, 12:31am
#1
Imagine the following example:

import torch
import torch.nn as nn

A = nn.Linear(3, 4)
B = torch.ones(3).requires_grad_()

loss1 = A(B).sum()
# create_graph=True so the gradient itself stays differentiable
grad1 = torch.autograd.grad(loss1, A.parameters(), create_graph=True)

# then update A's parameters with grad1 (e.g. one SGD step)

loss2 = F(A)  # F: some scalar function of the updated A

My question is: how should I update A so that I can then compute the gradient of loss2 with respect to B?
If both A and B were plain tensors, I would know how to do it. But how do I do it when A is an nn.Module?
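For reference, the tensor-only case alluded to above can be sketched like this (the learning rate `lr` and the final `sum()` are placeholders for the unspecified update rule and F):

```python
import torch

# Tensor-only analogue: W plays the role of A's weight.
W = torch.randn(4, 3, requires_grad=True)
B = torch.ones(3, requires_grad=True)

loss1 = (W @ B).sum()
# create_graph=True: the gradient itself remains a node in the graph
(gW,) = torch.autograd.grad(loss1, W, create_graph=True)

lr = 0.1  # placeholder step size
W_new = W - lr * gW  # out-of-place update keeps W_new differentiable w.r.t. B

loss2 = W_new.sum()  # stand-in for the unspecified F
(gB,) = torch.autograd.grad(loss2, B)
```

Because the update is done out-of-place, W_new stays a function of B and the second `grad` call succeeds.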

It’s not possible to compute the gradient of loss2 if loss2 isn’t a tensor. Do you mean something like out = A(B); loss1 = out.sum(); loss2 = F(out)?

bilibili
(Bo Liu)
March 15, 2022, 4:21pm
#3
Thanks for your reply! Unfortunately, that is not the case. What I mean is:

We first update A; the gradient used in the update is a function of B.

After the update, I want to compute a loss on the updated A, and then take the derivative w.r.t. B (since the updated A is a function of B).

I want to achieve the above, but I don’t know how to do it when A is an nn.Module.
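One way to do this when A is an nn.Module is to build the updated parameters out-of-place and evaluate the module functionally. The sketch below assumes PyTorch ≥ 2.0 (`torch.func.functional_call`); `lr`, `x`, and the final `sum()` are placeholders for the unspecified update rule, input, and F:

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # PyTorch >= 2.0

A = nn.Linear(3, 4)
B = torch.ones(3).requires_grad_()

loss1 = A(B).sum()
params = dict(A.named_parameters())
# create_graph=True keeps the graph, so the grads are themselves functions of B
grads = torch.autograd.grad(loss1, list(params.values()), create_graph=True)

# Differentiable "update": build new parameter tensors out-of-place
# instead of mutating A.weight / A.bias in place.
lr = 0.1  # placeholder step size
updated = {name: p - lr * g for (name, p), g in zip(params.items(), grads)}

# Evaluate A with the updated parameters; sum() stands in for the
# unspecified F.
x = torch.ones(3)  # placeholder input
loss2 = functional_call(A, updated, (x,)).sum()

# loss2 depends on B through the updated parameters
(gB,) = torch.autograd.grad(loss2, B)
```

The key point is that `optimizer.step()` mutates parameters in place and breaks the graph; constructing `updated` as fresh tensors and routing them through `functional_call` keeps the dependence on B intact. (On older PyTorch versions, `torch.nn.utils.stateless.functional_call` provides the same mechanism.)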

bilibili
(Bo Liu)
March 15, 2022, 6:23pm
#4