The forward function of my torch.nn.Module subclass initially looked like this:
```python
def forward(self, a, b, c):
    # blah blah blah
    inds = (c < 0)
    c_aug = c
    c_aug[inds] = a.shape[0]
    # more blah blah blah
```
This gave me the following error:
```text
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.LongTensor [82, 4]] is at version 2; expected version 1 instead.
```
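A minimal, self-contained sketch that reproduces this kind of error, under the assumption that the elided "# blah blah blah" part passes c to an operation that saves it for the backward pass. Here torch.gather stands in for that operation; the helper run and the sample tensors are hypothetical. Because c_aug = c only binds a second name to the same tensor, the indexed assignment writes into c itself and bumps its version counter, which autograd checks when backward needs the saved index:

```python
import torch


def run(copy_indices: bool) -> str:
    """Hypothetical helper: returns "ok" or the RuntimeError message from backward()."""
    a = torch.randn(3, 4, requires_grad=True)
    c = torch.tensor([[0, 1], [2, 3], [1, 0]])  # integer indices, like the LongTensor in the error

    out = torch.gather(a, 1, c)  # gather saves c (the index) for use in backward

    # Without clone, c_aug is just another name for c (same tensor, same storage).
    c_aug = c.clone() if copy_indices else c
    c_aug[c_aug == 0] = 3  # in-place write; bumps the version counter of c_aug

    try:
        out.sum().backward()  # backward re-reads the saved index and checks its version
        return "ok"
    except RuntimeError as err:
        return str(err)


print(run(copy_indices=False))  # the "modified by an inplace operation" error
print(run(copy_indices=True))   # ok
```

With the aliasing version, the saved index no longer matches what gather saw in the forward pass, so autograd refuses to use it; with the clone, only the copy is modified.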
After some debugging using anomaly detection and looking at this question, I changed the forward function as follows:
```python
def forward(self, a, b, c):
    # blah blah blah
    inds = (c < 0)
    c_aug = torch.clone(c)
    c_aug[inds] = a.shape[0]
    # more blah blah blah
```
This no longer raises the error. Is this the correct workaround, or are there better ways to handle it?
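For comparison, a sketch of the cloned version next to two out-of-place alternatives, torch.where and Tensor.masked_fill (the non-underscore variant of masked_fill_). Each returns a new tensor and leaves c untouched, so nothing that autograd saved is overwritten. The function name replace_negatives and the sample values are hypothetical; fill plays the role of a.shape[0]:

```python
import torch


def replace_negatives(c: torch.Tensor, fill: int):
    """Hypothetical helper: three ways to replace negative entries of c without touching c."""
    # 1. Clone, then write in place into the copy only (the workaround from the question).
    clone_version = torch.clone(c)
    clone_version[c < 0] = fill

    # 2. torch.where builds a new tensor, picking fill where the mask is True.
    where_version = torch.where(c < 0, torch.full_like(c, fill), c)

    # 3. masked_fill (no trailing underscore) is the out-of-place form of masked_fill_.
    masked_version = c.masked_fill(c < 0, fill)

    return clone_version, where_version, masked_version


c = torch.tensor([[3, -1], [-2, 0]])
x, y, z = replace_negatives(c, 5)
print(x.tolist())  # [[3, 5], [5, 0]]
print(c.tolist())  # [[3, -1], [-2, 0]]  (original unchanged)
```

All three produce the same values; the out-of-place forms simply make the copy implicit instead of requiring a separate clone step.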