For the case you tried out with `pen` not undergoing a backward pass: we still add it to the total loss function and thus to the computation graph. Generating `y_fake` and `y_real` always takes two forward passes, and thus two stack pushes for each layer. During the computation of the gradient penalty there is one backward pass (from the call to `autograd.grad`) and one forward pass (note the `self(interpolates)`, which calls the network).
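For reference, here is a minimal sketch of what such a gradient penalty typically looks like (the `Critic` class, `gp_weight`, and the tensor shapes are my assumptions, not necessarily your exact code); the `self(interpolates)` call is the extra forward pass and `autograd.grad` is the extra backward pass:

```python
import torch
from torch import autograd, nn

class Critic(nn.Module):
    # Minimal stand-in critic; the real one would be the stack-based network
    def __init__(self, dim=8):
        super().__init__()
        self.net = nn.Linear(dim, 1)

    def forward(self, x):
        return self.net(x)

    def gradient_penalty(self, real, fake, gp_weight=10.0):
        # Random interpolation between real and fake samples
        alpha = torch.rand(real.size(0), 1, device=real.device)
        interpolates = (alpha * real + (1 - alpha) * fake).detach().requires_grad_(True)

        # self(interpolates): an extra forward pass -> one more stack push per layer
        d_interpolates = self(interpolates)

        # autograd.grad: an extra backward pass -> one stack pop per layer,
        # before backward() is ever called on the total loss
        gradients = autograd.grad(
            outputs=d_interpolates,
            inputs=interpolates,
            grad_outputs=torch.ones_like(d_interpolates),
            create_graph=True,
        )[0]

        gradients = gradients.view(gradients.size(0), -1)
        return gp_weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
```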
which calls the network). When backward()
is called on loss_d
then backward()
is expected to be called three times, as three corresponding forward passes were observed. It did not take into account the fact that an additional backward()
got called during the loss computation. We subsequently have more backward passes than forwards, and so the activation stack empties out for each layer. This might explains why if you add pen
or just do pen.backward()
the error arises - it is due to an additional grad computation which was not taken into account.
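Putting the counts together for one discriminator step (a hedged sketch that reuses the `Critic` class above; `netG`, the shapes, and the loss form are assumptions, and with plain `nn.Linear` layers there is no activation stack, so the comments only mark where stack-based layers would push and pop):

```python
import torch
from torch import nn

# Stand-ins so the snippet runs on its own; `Critic` is the sketch above.
netG = nn.Linear(4, 8)
critic = Critic(dim=8)

x_real = torch.randn(16, 8)
x_fake = netG(torch.randn(16, 4)).detach()

y_real = critic(x_real)                        # forward 1 -> one stack push per layer
y_fake = critic(x_fake)                        # forward 2 -> second push
pen = critic.gradient_penalty(x_real, x_fake)  # forward 3 -> third push,
                                               # but autograd.grad already pops once

loss_d = y_fake.mean() - y_real.mean() + pen
loss_d.backward()  # walks back through all three forward passes -> three more pops,
                   # while only two activations remain on each layer's stack
```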