Dear All,
Could anybody suggest a solution to the following problem, which occurs in the code below:
....
recon_img = None
input_img = None
all_ones = None
model = ....
....
for ....:
    weight = ....
    img_data = ....
    if recon_img is None: recon_img = img_data.clone().to(glo_device)
    if input_img is None: input_img = img_data.clone().to(glo_device)
    if all_ones is None: all_ones = img_data.new_ones(img_data.size()).to(glo_device)
    input_img[:,:,:] = img_data * weight + recon_img * (all_ones - weight)  ### <-- HERE
    recon_img = model(input_img)
    loss = F.l1_loss(weight * recon_img, weight * img_data, reduction="sum")
    optimizer.zero_grad()
    loss.backward()  ### <-- CRASHES
    optimizer.step()
....
If I feed the model the blended image tensor (input_img[:,:,:] = img_data * weight + recon_img * (all_ones - weight)), the error shows up on the second iteration:
learning rate: 0.001
loss:524955.6250, frame:10
learning rate: 0.000999
Traceback (most recent call last):
ERROR: ('Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.',)
File "/Users/albert/work/cnn/_bgs_5_history_tanh_feedback.py", line 410, in <module>
Main(args)
File "/Users/albert/work/cnn/_bgs_5_history_tanh_feedback.py", line 346, in Main
loss.backward()
File "/usr/local/lib/python3.7/site-packages/torch/tensor.py", line 102, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "/usr/local/lib/python3.7/site-packages/torch/autograd/__init__.py", line 90, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
However, if I use just “input_img[:,:,:] = img_data”, everything works fine. I understand that it is a bit bizarre to blend the output back into the input, but it should (?) be technically alright. Where am I missing “retain_graph=True”? I would appreciate any suggestions.
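For reference, here is a self-contained toy sketch of the same pattern (a hypothetical stand-in: a single nn.Linear instead of the real model, random vectors instead of images, and out-of-place assignment instead of input_img[:,:,:] = ...). Without detaching, the blend ties each iteration's graph to the previous one, and the second backward() seems to reproduce the same RuntimeError; presumably detaching the fed-back output would avoid it:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def run(detach_feedback):
    # Toy stand-ins for the real model/data (hypothetical, for illustration only).
    torch.manual_seed(0)
    model = nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    weight = torch.full((4,), 0.5)
    recon_img = None
    for _ in range(3):
        img_data = torch.randn(4)
        if recon_img is None:
            recon_img = img_data.clone()
        # Without detach, blending with the previous model output links this
        # iteration's graph to the last one, whose buffers backward() has freed.
        feedback = recon_img.detach() if detach_feedback else recon_img
        input_img = img_data * weight + feedback * (1 - weight)
        recon_img = model(input_img)
        loss = F.l1_loss(weight * recon_img, weight * img_data, reduction="sum")
        optimizer.zero_grad()
        loss.backward()  # fails on the 2nd iteration when detach_feedback=False
        optimizer.step()

# Without detach: raises the "backward through the graph a second time" error.
caught = None
try:
    run(detach_feedback=False)
except RuntimeError as e:
    caught = e

# With detach: all three iterations run without error.
run(detach_feedback=True)
print(caught)
```

If this diagnosis is right, the question reduces to: is detaching the fed-back recon_img the intended fix here, or should the graph actually be retained across iterations?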