This code shows the training step for one batch (I left out the setup code, since it isn't relevant):
for j in range(critic_policy(epoch)):
    output = netC(train_full)  # not used below
    generator_loss = torch.mean(
        cramer_critic(train_full, generated_full_2) * w_full * w_x_2
        - cramer_critic(generated_full_1, generated_full_2) * w_x_1 * w_x_2)

    # gradient penalty on random interpolates between real and generated samples
    alpha = torch.empty(train_full.shape[0], 1, device=device).normal_(0.0, 1.0)
    interpolates = alpha * train_full + (1.0 - alpha) * generated_full_1
    disc_interpolates = cramer_critic(interpolates, generated_full_2)
    gradients = grad(outputs=disc_interpolates, inputs=interpolates,
                     grad_outputs=torch.ones_like(disc_interpolates),
                     create_graph=True)[0]  # keep the graph so the penalty itself is differentiable
    # one slope per sample: flatten each sample's gradient and take its norm
    slopes = torch.norm(gradients.reshape(gradients.shape[0], -1), dim=1)
    gradient_penalty = torch.mean(torch.pow(
        torch.max(slopes - 1, torch.zeros_like(slopes)), 2))

    critic_loss = lambda_pt(epoch) * gradient_penalty - generator_loss
    critic_loss.backward()
    optC.step()
    optC.zero_grad()
After one successful iteration it raises this error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-20-2654469844e0> in <module>
32
33 critic_loss = lambda_pt(epoch) * gradient_penalty - generator_loss
---> 34 critic_loss.backward()
35 optC.step()
36 optC.zero_grad()
/usr/local/lib/python3.6/dist-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
164 products. Defaults to ``False``.
165 """
--> 166 torch.autograd.backward(self, gradient, retain_graph, create_graph)
167
168 def register_hook(self, hook):
/usr/local/lib/python3.6/dist-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
97 Variable._execution_engine.run_backward(
98 tensors, grad_tensors, retain_graph, create_graph,
---> 99 allow_unreachable=True) # allow_unreachable flag
100
101
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
All the explanations I found are about hidden states, but netC is just 5 linear layers, so I don't think that applies (I'm new to PyTorch, so I may be wrong). I'd rather avoid retain_graph=True, because I've read it slows down training, but if it's really necessary, please help me implement it (and tell me on which iteration to set it).
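To narrow it down, I made a toy standalone sketch (my own names, not the real model) that seems to reproduce the same error mechanics: a batch produced by a network once, outside the loop, gets backward()-ed through on every iteration, and the first backward() frees its graph.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Linear(4, 1)   # stand-in for a generator
x = torch.randn(8, 4)

# the "generated" batch is built ONCE, outside the training loop,
# just like generated_full_1 / generated_full_2 in my code above
generated = net(x)

results = []
for step in range(2):
    loss = generated.sum()
    try:
        loss.backward()             # the first backward() frees the graph behind `generated`
        results.append("ok")
    except RuntimeError:
        results.append("RuntimeError")

print(results)  # the first step succeeds, the second raises
```

Moving `generated = net(x)` inside the loop, or using `generated.detach()` where the generator shouldn't receive gradients from this loss, makes both iterations succeed without retain_graph=True. If my diagnosis is right, is detaching (or regenerating) generated_full_1/generated_full_2 inside the critic loop the correct fix here?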