My network is as follows (sorry for the simple drawing):
where:
G1 and G2 are two generators; fake_B_temp is the output of G1, and fake_B is the output of G2.
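To make the setup concrete, here is a minimal sketch of the two generators chained in series. It assumes that G2 takes fake_B_temp as its input, which is inferred from the detach() description rather than from the drawing. `make_generator`, its tiny conv layers, and the `real_A` batch are hypothetical stand-ins for the real networks and data.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for G1/G2; any image-to-image network would do here.
def make_generator(channels: int = 3) -> nn.Module:
    return nn.Sequential(
        nn.Conv2d(channels, 16, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(16, channels, kernel_size=3, padding=1),
        nn.Tanh(),
    )

G1 = make_generator()
G2 = make_generator()

real_A = torch.randn(2, 3, 32, 32)  # hypothetical input batch (N, C, H, W)
fake_B_temp = G1(real_A)            # output of G1
fake_B = G2(fake_B_temp.detach())   # assumed: G2 refines G1's output; detach() cuts the graph here
print(fake_B_temp.shape, fake_B.shape)
```

Because of the `detach()`, `fake_B` still depends on G2's parameters but no longer on G1's, so a loss computed on `fake_B` can only update G2.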
The loss of G1 is:
and the loss of G2 is also an L1 loss.
Thus, I call backward() twice in my network, once to update G1 and once to update G2. I feed G2 with fake_B_temp.detach() instead of fake_B_temp so that the gradient of G2's loss does not flow back into G1.
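The two-update scheme described above can be sketched as follows. This is only an illustration under stated assumptions: each generator has its own Adam optimizer (the hyperparameters are hypothetical, pix2pix-style values), both L1 losses compare against a hypothetical target `real_B`, and G2 consumes the detached `fake_B_temp`. `make_generator` is a hypothetical stand-in for the real networks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_generator(channels: int = 3) -> nn.Module:
    # Hypothetical stand-in for G1/G2.
    return nn.Sequential(
        nn.Conv2d(channels, 16, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(16, channels, 3, padding=1),
    )

G1, G2 = make_generator(), make_generator()
# One optimizer per generator, so each step only touches its own parameters.
optimizer_G1 = torch.optim.Adam(G1.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_G2 = torch.optim.Adam(G2.parameters(), lr=2e-4, betas=(0.5, 0.999))
criterion_L1 = nn.L1Loss()

def optimize_parameters(real_A, real_B):
    # Forward pass through both generators.
    fake_B_temp = G1(real_A)
    fake_B = G2(fake_B_temp.detach())  # G2's loss cannot reach G1

    # Update G1 with its own L1 loss (target assumed to be real_B).
    optimizer_G1.zero_grad()
    loss_G1 = criterion_L1(fake_B_temp, real_B)
    loss_G1.backward()
    optimizer_G1.step()

    # Update G2 with its own L1 loss (target assumed to be real_B).
    optimizer_G2.zero_grad()
    loss_G2 = criterion_L1(fake_B, real_B)
    loss_G2.backward()
    optimizer_G2.step()
    return loss_G1.item(), loss_G2.item()
```

Usage would be `loss_G1, loss_G2 = optimize_parameters(real_A, real_B)` once per batch. Because G2's graph does not contain G1, stepping `optimizer_G1` before `loss_G2.backward()` is safe here; without the `detach()`, updating G1's weights in place before the second backward pass would typically raise an autograd "modified by an inplace operation" error, and G1 would also receive gradients from G2's loss.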
My optimize_parameter function is:
Is what I'm doing correct? Are there any problems with it, or is there a better way to do this? Thanks!