In that case you have to make sure that gradients with respect to z are propagated through your netG class (I am not sure "propagated" is the correct term; what I mean is that PyTorch must be able to compute how the output of netG.forward(z) changes with respect to z).
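As a minimal sketch of what I mean: if netG is built from standard differentiable layers, autograd should propagate gradients back to z automatically once z is created with requires_grad=True. The toy generator and loss below are placeholders, not your actual model:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for netG; your real generator will differ.
netG = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 784))

# Mark z as a leaf tensor so autograd tracks gradients w.r.t. it.
z = torch.randn(1, 10, requires_grad=True)

out = netG(z)
loss = out.pow(2).mean()  # placeholder loss for the sketch
loss.backward()

# z.grad now holds d(loss)/dz, computed automatically by autograd.
print(z.grad.shape)  # torch.Size([1, 10])
```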
This is not something that I have done (yet!), so I am not entirely clear on the details. If netG consists only of standard differentiable PyTorch operations, autograd should handle this for you with no extra code. Otherwise, I think you need to implement a custom backward function, e.g. by subclassing torch.autograd.Function, rather than adding a plain backward method to your netG class. The code in this message may help you get started.
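Here is a minimal sketch of a custom backward, in case you do need one. The ScaleByTwo operation is a made-up example chosen so the gradient is easy to verify by hand; a real use case would wrap whatever non-differentiable operation your netG contains:

```python
import torch

# Sketch of a custom autograd Function with an explicit backward;
# only needed when an op in netG is not already differentiable.
class ScaleByTwo(torch.autograd.Function):
    @staticmethod
    def forward(ctx, z):
        return z * 2.0

    @staticmethod
    def backward(ctx, grad_output):
        # d(2*z)/dz = 2, so scale the incoming gradient by 2.
        return grad_output * 2.0

z = torch.randn(3, requires_grad=True)
out = ScaleByTwo.apply(z)
out.sum().backward()
print(z.grad)  # every entry is 2.0
```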