Hi everyone,
I have a network made of two parts, A and B, and two losses, `loss_A` and `loss_B`. The following code shows how they are related:
```python
# input -> A -> B -> output
x = A(input)
output = B(x)
loss_A = lossA(output, label)
loss_B = lossB(output, label)
```
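To make the setup concrete, here is a runnable version with hypothetical stand-ins: single `nn.Linear` layers for A and B, `MSELoss` and `L1Loss` for `lossA` and `lossB`, and random tensors for the data. It shows why the two updates need separating: a plain `loss_A.backward()` writes gradients into the parameters of both A and B.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: the real A, B, lossA and lossB are not shown here.
A = nn.Linear(4, 8)
B = nn.Linear(8, 2)
lossA = nn.MSELoss()
lossB = nn.L1Loss()

inputs = torch.randn(16, 4)
label = torch.randn(16, 2)

# input -> A -> B -> output
x = A(inputs)
output = B(x)
loss_A = lossA(output, label)
loss_B = lossB(output, label)

# A plain backward pass on loss_A reaches every parameter on the path,
# so B's parameters receive gradients from loss_A as well.
loss_A.backward()
print(A.weight.grad is not None, B.weight.grad is not None)  # True True
```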
Note that `loss_A` should update only the parameters of part A (its gradient still has to flow back through B to reach A), and `loss_B` should update only the parameters of part B. How can I do this?
Here is my attempt, but I am not sure whether it is correct:
```python
from torch.optim import SGD

# in forward()
x = A(input)
output = B(x)                  # graph runs through both A and B
output_detach = B(x.detach())  # x.detach() cuts the graph, so only B is reachable

# in main()
optimizer_A = SGD([{'params': A.parameters()}], lr=lr, weight_decay=weight_decay)
optimizer_B = SGD([{'params': B.parameters()}], lr=lr, weight_decay=weight_decay)

# update B with loss_B
loss_B = lossB(output_detach, label)
optimizer_B.zero_grad()
loss_B.backward(retain_graph=True)
optimizer_B.step()

# update A with loss_A
loss_A = lossA(output, label)
optimizer_A.zero_grad()
loss_A.backward()
optimizer_A.step()
```
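Below is a self-contained sketch of the same detach-based idea that, as far as I can tell, avoids two problems. Everything in it is a hypothetical stand-in: A and B as single `nn.Linear` layers, `MSELoss`/`L1Loss` for `lossA`/`lossB`, random data, and the values of `lr` and `weight_decay`. First, it runs both backward passes before either optimizer steps, because `optimizer_B.step()` updates B's weights in place while `loss_A.backward()` still needs the old weights saved in the graph of `output`, which can raise a RuntimeError about a variable "modified by an inplace operation". Second, it passes `inputs=list(A.parameters())` to `loss_A.backward()` (assuming PyTorch 1.8 or later) so that gradients from `loss_A` are accumulated only into A.

```python
import torch
import torch.nn as nn
from torch.optim import SGD

torch.manual_seed(0)

# Hypothetical stand-ins for the modules, losses and hyperparameters.
A = nn.Linear(4, 8)
B = nn.Linear(8, 2)
lossA = nn.MSELoss()
lossB = nn.L1Loss()
lr, weight_decay = 0.1, 1e-4

optimizer_A = SGD(A.parameters(), lr=lr, weight_decay=weight_decay)
optimizer_B = SGD(B.parameters(), lr=lr, weight_decay=weight_decay)


def train_step(inputs, label):
    x = A(inputs)
    output = B(x)                  # path for loss_A: reaches A (through B)
    output_detach = B(x.detach())  # path for loss_B: detach() stops at A

    loss_A = lossA(output, label)
    loss_B = lossB(output_detach, label)

    optimizer_A.zero_grad()
    optimizer_B.zero_grad()
    # loss_B's graph starts at x.detach(), so only B gets gradients.
    loss_B.backward()
    # Accumulate loss_A's gradients into A's parameters only; B's .grad
    # keeps the values from loss_B.
    loss_A.backward(inputs=list(A.parameters()))
    # Step only after both backward passes, so no weight that a graph
    # still needs is modified in place before backward runs.
    optimizer_A.step()
    optimizer_B.step()
    return loss_A.item(), loss_B.item()
```

Calling `train_step(inputs, label)` once per batch performs one update of each part, with A driven only by `loss_A` and B only by `loss_B`. Separate forward passes through B (`output` and `output_detach`) keep the two graphs independent, so no `retain_graph=True` is needed.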
Could anyone help with this? Thanks in advance!