Hi, I have a problem with loss.backward() after torch.stack.
Even as I write this question, I don’t know why this isn’t working.
The output of my model is a list of tensors:

output = net(inp)  # output is a list of tensors
When I calculate the loss with loss_fn(output[-1], answer), there is no problem:

loss = loss_fn(output[-1], answer)
loss.backward()  # Success!
However, when I stack the outputs first, it fails:

output = torch.stack(output)
loss = loss_fn(output[-1], answer)
loss.backward()  # Fails... why?
With torch.autograd.set_detect_anomaly(True) enabled, I got the following message:
[W python_anomaly_mode.cpp:104] Warning: Error detected in MulBackward0. Traceback of forward call that caused the error:
File "/home/dngusdnr1/openfold/openfold/model/grad_debug.py", line 364, in <module>
main()
File "/home/dngusdnr1/openfold/openfold/model/grad_debug.py", line 335, in main
output = net(s=seqs,z=pair_feat,ulr_mask=(masks|~(str_mask).bool()),init_rigids=init_rigids,mask=str_mask,aatype=aatype)
File "/home/dngusdnr1/anaconda3/envs/pytorch3d/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/dngusdnr1/openfold/openfold/model/debug_module.py", line 650, in forward
rot_mats=rigids.get_rots().get_rot_mats(),
File "/home/dngusdnr1/openfold/openfold/model/rigid_utils.py", line 514, in get_rot_mats
rot_mats = quat_to_rot(self._quats)
File "/home/dngusdnr1/openfold/openfold/model/rigid_utils.py", line 196, in quat_to_rot
quat = quat[..., None] * quat[..., None, :]
(function _print_stack)
Traceback (most recent call last):
File "/home/dngusdnr1/openfold/openfold/model/grad_debug.py", line 364, in <module>
main()
File "/home/dngusdnr1/openfold/openfold/model/grad_debug.py", line 347, in main
(loss / GRADIENT_ACCUMULATE_EVERY).backward()
File "/home/dngusdnr1/anaconda3/envs/pytorch3d/lib/python3.9/site-packages/torch/_tensor.py", line 307, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/dngusdnr1/anaconda3/envs/pytorch3d/lib/python3.9/site-packages/torch/autograd/__init__.py", line 154, in backward
Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 610, 4, 1]], which is output 0 of AsStridedBackward0, is at version 1; expected version 0 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
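For context, here is a tiny standalone snippet (nothing to do with my model, all names made up) that triggers the same kind of error, just to illustrate what the version-counter message means:

import torch

a = torch.randn(3, requires_grad=True)
b = a * 2           # non-leaf tensor
y = a * b           # MulBackward0 saves a and b for the backward pass
b.add_(1)           # in-place op bumps b's version counter to 1
y.sum().backward()  # RuntimeError: ... is at version 1; expected version 0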
My model is a bit complicated, so I can’t share a full code snippet, sorry about that. Below is my best attempt at a sketch.
class MyModel(nn.Module):
    def __init__(self, ...):
        ....

    def forward(self, ...):
        ....
        traj_s = []
        for i in range(self.no_blocks):
            s = self.block(s)
            x = self.interpreter(s)
            traj_s.append(x)
        ....
        return traj_s  # output is a list of tensors
# run the model
net = MyModel()
output = net(tmp_input)

# Case 1: backward succeeds
loss = loss_fn(output[-1], answer)
loss.backward()

# Case 2: backward fails
output = torch.stack(output)
loss = loss_fn(output[-1], answer)
loss.backward()
What confuses me even more is that there is no problem with simple models.
class Simple_test(torch.nn.Module):
    def __init__(self):
        super(Simple_test, self).__init__()
        self.act_fn = torch.nn.ReLU()
        self.l1 = torch.nn.Linear(10, 10)
        self.l2 = torch.nn.Linear(10, 10)
        self.l3 = torch.nn.Linear(10, 10)
        self.l4 = torch.nn.Linear(10, 10)

    def forward(self, x):
        memo = []
        for _ in range(5):
            x = self.l1(x)
            x = self.act_fn(x)
            x = self.l2(x)
            x = self.act_fn(x)
            x = self.l3(x)
            x = self.act_fn(x)
            x = self.l4(x)
            x = self.act_fn(x)
            memo.append(x)
        return memo
net = Simple_test()
net.train()
tmp_inp = torch.randn(1, 10)
output = net(tmp_inp)

# Case 1: success (retain_graph=True so Case 2 can backward through the same graph)
loss = torch.sum(output[-1])
loss.backward(retain_graph=True)

# Case 2: also success
output = torch.stack(output)
loss = torch.sum(output[-1])
loss.backward()
I know that torch.stack itself supports backpropagation, so I’m sure my own code is causing this problem. But I can’t understand why everything is fine before applying torch.stack. If there were an in-place problem in my forward code, I would expect loss.backward() to fail even before torch.stack is applied…
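The only explanation I can think of (I’m not sure it’s right) is that after torch.stack, indexing output[-1] backpropagates through the stack node, so autograd walks the subgraphs of all the earlier outputs too, not just the last one. Here is a toy sketch, completely made up and not my real model, where that difference alone decides whether backward succeeds:

import torch

a = torch.randn(5)
w = torch.randn(5, requires_grad=True)

memo = []
h = a
for _ in range(3):
    s = (h * w).sigmoid()  # SigmoidBackward0 saves its output s
    memo.append(s)
    h = s.detach()         # stop-gradient between "blocks"
memo[0].add_(1)            # in-place edit bumps memo[0]'s version counter

# Case 1 works: backward only visits the last block's subgraph
memo[-1].sum().backward(retain_graph=True)

# Case 2 fails: stack's backward sends (zero) grads into every element's
# subgraph, including memo[0]'s, whose saved output is now at version 1
stacked = torch.stack(memo)
stacked[-1].sum().backward()  # RuntimeError: ... inplace operation

If that mental model is right, some later block in my real model must be modifying, in place, a tensor that an earlier block saved for backward, and Case 1 never notices because those earlier subgraphs are never traversed.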
Does anyone know the cause of this issue?
It’s probably a problem in my own code, of course, so I’d appreciate advice on where to start looking.
Thanks