Hey guys,
In my code I am trying to train a model that moves a given sample to a given target distribution. The next step is to introduce intermediate distributions and to use a loop so that the particles (the samples) are moved from one distribution to the next iteratively. Unfortunately, at the second iteration I get the following error message when running my code: “Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward”
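For reference, the first error can be reproduced in a few lines. This is a minimal sketch with made-up tensors, not the code from this post: `ldj_old` is a hypothetical tensor that still carries its autograd graph, and it is reused in a second loss after `backward()` has already freed that graph. Using `ldj_old.detach()` instead makes it a constant for the second loss, which avoids the error.

```python
import torch

w = torch.ones(2, requires_grad=True)
ldj_old = (w ** 2).sum()        # still references the graph that produced it
ldj_old.backward()              # frees the tensors that graph saved (here: w for pow)

try:
    (ldj_old * 3).backward()    # a new loss that reaches back into the freed graph
    first_error = ""
except RuntimeError as err:
    first_error = str(err)

w.grad = None                                # discard the gradient of the first backward
(ldj_old.detach() * 3 + w.sum()).backward()  # the detached copy acts as a constant
print("backward through the graph a second time" in first_error)  # True
print(w.grad)  # tensor([1., 1.])
```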
I don’t think that retain_graph=True fits my problem, since I would rather clear the model after every iteration than retain it. However, I gave it a shot, and the result is the following error:
one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1, 2]] is at version 2251; expected version 2250 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
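One plausible reading of the second error (an assumption, since the code that later uses the returned `ldj` is not shown): `retain_graph=True` keeps the old graph alive, but that graph still points at the parameters it saved, and `optim.step()` updates those parameters in place, which bumps their version counter. In `train_flow` the returned `ldj` is computed before the last `optim.step()`, so backpropagating through it afterwards would hit exactly this check. A minimal sketch with a single made-up `[1, 2]` parameter `w`:

```python
import torch

w = torch.nn.Parameter(torch.ones(1, 2))  # same shape as in the error message
optim = torch.optim.Adam([w], lr=1e-2)

ldj = (w ** 2).sum()             # pow saves w, together with its current version
ldj.backward(retain_graph=True)  # the graph is kept alive instead of being freed
optim.step()                     # in-place update of w bumps its version counter

try:
    ldj.backward()               # reuses a graph that expects the old version of w
    second_error = ""
except RuntimeError as err:
    second_error = str(err)

print("modified by an inplace operation" in second_error)  # True
```

So `retain_graph=True` does not reset anything; it only postpones the failure from "graph already freed" to "saved parameter already changed".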
Here are the relevant parts of my code:
for k in range(1, K_intermediate + 1):
    flow = BasicFlow(dim=d, n_flows=n_flows, flow_layer=flow_layer)
    ldj = train_flow(
        flow, x, W, f_intermediate(x, k - 1), lambda x: f_intermediate(x, k), epochs=2500
    )
    x, xtransp = flow(x)
    x = xtransp.data
def train_flow(flow, sample, weights, f0, f1, epochs=1000):
    optim = torch.optim.Adam(flow.parameters(), lr=1e-2)
    for i in range(epochs):
        x0, xtransp = flow(sample)
        ldj = accumulate_kl_div(flow).reshape(sample.size(0))
        loss = det_loss(
            x_0=x0,
            x_transp=xtransp,
            weights=weights,
            ldj=ldj,
            f0=f0,
            f1=f1,
        )
        loss.backward(retain_graph=True)
        optim.step()
        optim.zero_grad()
        reset_kl_div(flow)
        if i % 250 == 0:
            if i > 0 and previous_loss - loss.item() < 1e-06:
                break
            print(loss.item())
            previous_loss = loss.item()
        if torch.isnan(loss):
            break
    return ldj
Note that the problem only appeared once I started capturing the ldj value (the log-determinant of the Jacobian, for those who wonder). Since this value is crucial for further computations, I cannot just delete it.
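If `ldj` is only needed as a value after training, one option is to hand it on as a plain tensor without a graph. The sketch below is a toy under explicit assumptions, not a fix of the code above: `ToyFlow` and `train_stage` are hypothetical stand-ins for `BasicFlow` and `train_flow`, and the loss is a placeholder (reverse KL to a Gaussian `N(target_mean, I)` up to a constant) rather than `det_loss`. After training, it recomputes the moved samples and the ldj under `torch.no_grad()` with the final parameters, so nothing leaving a stage carries a graph, and the ldj is not one optimizer step stale.

```python
import torch
import torch.nn as nn


class ToyFlow(nn.Module):
    """Hypothetical stand-in for BasicFlow: elementwise affine map y = x * exp(s) + t."""

    def __init__(self, dim):
        super().__init__()
        self.s = nn.Parameter(torch.zeros(1, dim))
        self.t = nn.Parameter(torch.zeros(1, dim))

    def forward(self, x):
        y = x * torch.exp(self.s) + self.t
        # log|det J| of this map is sum(s), the same for every sample
        ldj = self.s.sum().expand(x.size(0))
        return y, ldj


def train_stage(flow, sample, target_mean, epochs=200):
    optim = torch.optim.Adam(flow.parameters(), lr=1e-2)
    for _ in range(epochs):
        y, ldj = flow(sample)
        # placeholder loss (not det_loss): reverse KL to N(target_mean, I) up to a constant
        loss = (0.5 * ((y - target_mean) ** 2).sum(dim=1) - ldj).mean()
        optim.zero_grad()
        loss.backward()  # a fresh graph is built every epoch, so no retain_graph
        optim.step()
    # recompute with the final parameters, without recording a graph
    with torch.no_grad():
        y, ldj = flow(sample)
    return y, ldj


torch.manual_seed(0)
x = torch.randn(500, 2)
total_ldj = torch.zeros(500)
for k in range(1, 4):
    flow = ToyFlow(dim=2)  # a new flow per intermediate distribution
    x, ldj = train_stage(flow, x, target_mean=torch.tensor([float(k), 0.0]))
    total_ldj = total_ldj + ldj  # plain tensors, no graph attached

print(x.requires_grad, total_ldj.requires_grad)  # False False
print(total_ldj.shape)  # torch.Size([500])
```

If the later computations really need gradients of ldj with respect to an earlier flow's parameters, detaching is not enough; that flow's forward pass would have to be rerun inside the new loss instead.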
I'd be happy about any advice,
Christian