Runtime Error: variable modified by an inplace operation

The forward function of my torch.nn module initially looked like this:

def forward(self, a, b, c):
    # blah blah blah
    inds = (c < 0)
    c_aug = c  # not a copy: c_aug and c are the same tensor
    c_aug[inds] = a.shape[0]  # so this in-place write also modifies c
    # more blah blah blah

This was giving me an error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.LongTensor [82, 4]] is at version 2; expected version 1 instead.
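The sketch below is a minimal, hypothetical reproduction of this kind of error. The real forward pass is elided, so `torch.gather` stands in as an assumed example of an op that saves the integer tensor `c` for backward. Assigning `c_aug = c` only creates a second name for the same tensor, so the masked write bumps `c`'s version counter and the backward check fails. It also wraps the computation in `torch.autograd.detect_anomaly()`, which is one way to enable the anomaly detection mentioned below.

```python
import torch

def reproduce_inplace_error():
    # Hypothetical stand-ins for the question's tensors: `a` needs gradients,
    # `c` is an int64 (LongTensor) index tensor.
    a = torch.randn(4, 3, requires_grad=True)
    c = torch.tensor([[0, 2], [1, 1], [2, 0], [0, 1]])
    # Anomaly mode records the forward traceback of the op whose backward fails.
    with torch.autograd.detect_anomaly():
        # gather saves `c` for backward, together with its current version (0).
        out = torch.gather(a, 1, c)
        c_aug = c              # alias: same tensor, same version counter
        c_aug[c_aug == 1] = 2  # in-place write bumps c's version
        try:
            out.sum().backward()  # autograd notices `c` changed since it was saved
        except RuntimeError as err:
            return str(err)
    return None

print(reproduce_inplace_error())
```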
After some debugging using anomaly detection and looking at this question, I changed the forward function as follows:

def forward(self, a, b, c):
    # blah blah blah
    inds = (c < 0)
    c_aug = torch.clone(c)  # independent copy of c
    c_aug[inds] = a.shape[0]  # only the copy is modified
    # more blah blah blah

This does not give me that error. Is this the correct workaround, or is there a better way to handle it?

Hi,

Yes, this is the correct workaround. You just need to make sure that you don't modify the input c in place, so you clone it before changing it in place.
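A minimal sketch of why cloning works, under the same assumptions as a toy example (`torch.gather` as the hypothetical op that saves `c`, and `c == 1` as a hypothetical mask in place of the question's `c < 0`, since gather cannot take negative indices). The clone has its own storage and version counter, so writing into it leaves the saved `c` untouched and backward succeeds.

```python
import torch

def clone_then_modify():
    a = torch.randn(4, 3, requires_grad=True)
    c = torch.tensor([[0, 2], [1, 1], [2, 0], [0, 1]])
    out = torch.gather(a, 1, c)  # autograd saves `c` at version 0
    inds = c == 1                # hypothetical condition (the question uses c < 0)
    c_aug = c.clone()            # new storage with its own version counter
    c_aug[inds] = a.shape[0]     # in-place write touches only the copy
    out.sum().backward()         # succeeds: `c` is still at version 0
    return c, c_aug, a.grad

c, c_aug, grad = clone_then_modify()
print(c_aug)
print(grad)
```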

Note that we usually write c_aug = c.clone() instead of torch.clone(c); the two are equivalent.
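For illustration, the sketch below uses the method form `c.clone()` with the question's `c < 0` mask. It also shows one possible out-of-place alternative using `torch.where`, which builds a new tensor and never writes into `c` at all. `replace_negatives` is a hypothetical helper name, not part of the original code.

```python
import torch

def replace_negatives(c, fill_value):
    inds = c < 0
    # Method form used in the answer; same result as torch.clone(c).
    via_clone = c.clone()
    via_clone[inds] = fill_value
    # Out-of-place alternative: torch.where returns a new tensor, c is only read.
    via_where = torch.where(inds, torch.full_like(c, fill_value), c)
    return via_clone, via_where

c = torch.tensor([[0, -1], [3, -2]])
via_clone, via_where = replace_negatives(c, 5)
print(via_clone)
print(via_where)
```

Both versions leave `c` itself unmodified, so either is safe when `c` has been saved for backward.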
