The forward function of my torch.nn.Module initially looked like this:
```python
def forward(self, a, b, c):
    # blah blah blah
    inds = (c < 0)
    c_aug = c
    c_aug[inds] = a.shape
    # more blah blah blah
```
This was giving me an error:
```
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.LongTensor [82, 4]] is at version 2; expected version 1 instead.
```
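A minimal sketch that reproduces this kind of error, under assumptions: the shapes are made up, and `gather` is a hypothetical stand-in for whatever operation in the elided part of `forward` saves `c` for the backward pass. Plain assignment `c_aug = c` makes `c_aug` a second name for the same tensor, so the masked assignment writes into `c` itself and bumps its version counter (`_version`, an internal attribute that autograd checks).

```python
import torch

# Hypothetical stand-ins: a float input and an int64 index tensor playing the role of c.
a = torch.randn(3, 5, requires_grad=True)
c = torch.tensor([[0, 4], [2, 1], [3, 3]])

out = a.gather(1, c)        # gather saves c so it can route gradients in backward
version_before = c._version

c_aug = c                   # no copy: c_aug and c are the same tensor object
c_aug[c_aug > 2] = 0        # in-place masked write, which also modifies c
version_after = c._version

raised = False
message = ""
try:
    out.sum().backward()    # autograd sees that the saved c has a newer version
except RuntimeError as err:
    raised = True
    message = str(err)
    print(message)
```

The exception is only raised at `backward()`, not at the assignment, which is why anomaly detection is useful for locating the forward operation involved.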
After some debugging with anomaly detection (`torch.autograd.set_detect_anomaly`) and reading this question, I changed the forward function as follows:
```python
def forward(self, a, b, c):
    # blah blah blah
    inds = (c < 0)
    c_aug = torch.clone(c)
    c_aug[inds] = a.shape
    # more blah blah blah
```
This no longer gives me that error. Is this the correct workaround, or is there a better way to handle it?
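For comparison, a sketch of the `torch.clone` workaround next to an out-of-place alternative, `Tensor.masked_fill` (no trailing underscore), which returns a new tensor instead of writing into `c`. The helper names `augment_clone` and `augment_masked_fill` are hypothetical, and a scalar fill value stands in for the original's `a.shape`. In both cases `c` keeps its values and its version counter, which is what autograd checks.

```python
import torch

def augment_clone(c: torch.Tensor, fill: int) -> torch.Tensor:
    # The workaround from the question: copy first, then write into the copy.
    c_aug = torch.clone(c)
    c_aug[c < 0] = fill
    return c_aug

def augment_masked_fill(c: torch.Tensor, fill: int) -> torch.Tensor:
    # Out-of-place: masked_fill returns a new tensor and never writes into c.
    return c.masked_fill(c < 0, fill)

c = torch.tensor([[0, -1], [2, 1], [-1, 3]])  # hypothetical int64 tensor with negatives
original = c.clone()
version_before = c._version

r_clone = augment_clone(c, 5)
r_masked = augment_masked_fill(c, 5)

print(r_clone)
print(torch.equal(r_clone, r_masked))
```

Both versions allocate one new tensor the size of `c`, so the difference is mostly readability: `masked_fill` states the "replace where mask is true" intent in a single call.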