I am training a network to predict trajectories of particles. The main training loop is:
for t, time_slice in enumerate(TrajectoryBatch(traj_batch, time_steps=time_steps)):
    time_slice.to(device)
    if t == 0:
        enc_pred = 0
    else:
        enc_pred = rollout
    # forward pass
    enc, recon, rollout, pred = mod(time_slice)
    # compute loss
    dx, dg = torch.cdist(time_slice.x, time_slice.x), torch.cdist(enc, enc)  # distance matrices for metric loss
    l1 = (1/time_steps)*ae_loss(time_slice.x, recon)
    l2 = (1/time_steps)*pred_loss(time_slice.y, pred)
    l3 = (1/time_steps)*met_loss(dx, dg)
    if t == 0:
        l4 = (1/time_steps)*lin_pred_loss(enc, enc)
    else:
        l4 = (1/time_steps)*lin_pred_loss(enc, enc_pred)
    time_slice_loss = ae_reg*l1 + pred_reg*l2 + met_reg*l3 + lin_pred_reg*l4
    # backprop
    time_slice_loss.backward(retain_graph=True)
    optimizer.step()
    optimizer.zero_grad()
    ep_loss += float(time_slice_loss)
    pred_acc_train += float(l2)
The code runs for t == 0, but then gives the error message:
RuntimeError                              Traceback (most recent call last)
<ipython-input-…> in <module>
     43 # backprop
     44
---> 45 time_slice_loss.backward(retain_graph = True)
     46 optimizer.step()
     47 optimizer.zero_grad()

~/anaconda3/envs/tf_gpu/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222
    223     def register_hook(self, hook):

~/anaconda3/envs/tf_gpu/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    128         retain_graph = create_graph
    129
--> 130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors, retain_graph, create_graph,
    132         allow_unreachable=True)  # allow_unreachable flag

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [128, 128]], which is output 0 of TBackward, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
The issue seems to be tied to the loss component l4,
l4 = (1/time_steps)*lin_pred_loss(enc, enc_pred)
since the code runs fine when l4 is left out of the total loss.
The variable enc_pred is an output of the previous forward pass. I’ve read through other posts pertaining to this error message, but I couldn’t find anything that resolved the issue. I’ve tried substituting:
enc_pred = rollout.clone()
and,
time_slice_loss = (ae_reg*l1+pred_reg*l2+met_reg*l3+lin_pred_reg*l4).clone()
with no success. Any help is greatly appreciated.
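For what it's worth, here is a minimal sketch that reproduces the same error outside my setup (the toy lin model, shapes, and losses below are made up, standing in for mod):

import torch

# hypothetical toy model standing in for `mod`
lin = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

prev = None
for t in range(2):
    out = lin(torch.randn(2, 4))                 # fresh forward pass each step
    loss = out.sum()
    if prev is not None:
        # `prev` is still attached to the graph built at step t-1
        loss = loss + ((out - prev) ** 2).sum()
    loss.backward(retain_graph=True)
    opt.step()       # updates lin.weight *in place*, bumping its version counter
    opt.zero_grad()
    prev = out       # carries the old graph into the next iteration

At t == 1 the backward pass follows prev back into the t == 0 graph, which saved lin.weight for its backward; opt.step() has since modified that weight in place, so autograd raises the same version-counter RuntimeError.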
UPDATE:
The code runs with
enc_pred = rollout.detach()
although I’ll have to think about what exactly is going on.
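My current understanding (please correct me if this is wrong): clone() copies the values but stays attached to the autograd graph, so l4 still backpropagates into the previous iteration's graph, whose saved weights optimizer.step() has since modified in place. detach() returns a tensor that shares the same values but is cut out of the graph, so enc_pred acts as a constant target and backward never re-enters the stale graph. A quick check of the difference:

import torch

w = torch.randn(3, requires_grad=True)
y = w * 2

c = y.clone()   # same values, still in the graph
d = y.detach()  # same values, graph connection severed

print(c.grad_fn)        # <CloneBackward0 ...> -> gradients would still flow to w
print(d.grad_fn)        # None -> d is treated as a constant
print(d.requires_grad)  # False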