In-place operation halting backward()

Hi, my model uses an unusual loss function: an intermediate value from forward() (x5 in the example below) is used to compute an L1 regularization term, which is added to the MSE loss computed from the model's final output (x9). Calling backward() on this combined loss raises the error below, which says that a tensor needed for gradient computation was modified by an in-place operation (probably while calculating the intermediate L_TV_Loss), so backward() cannot proceed.

If I instead return x5 from forward() and compute L_TV_Loss after forward() has run, backward() doesn't seem to compute the right gradients for the regularization term, most likely because the MSE loss and L_TV_Loss overlap in the computational graph, starting from x5. In short, the loss function has multiple 'branches' to backpropagate through. Are there any good ideas for solving this? Thanks
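As a sanity check on the "multiple branches" concern, a loss built from an intermediate tensor plus the final output normally backpropagates in a single backward() call. Below is a minimal sketch with a toy two-layer model; `head`, `tail`, and the tensor shapes are hypothetical stand-ins, not part of ALIGNet.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for the real model (hypothetical names):
# `head` produces the intermediate tensor, `tail` maps it to the output.
head = nn.Linear(4, 3)
tail = nn.Linear(3, 2)

x = torch.randn(5, 4)
target = torch.randn(5, 2)

mid = torch.abs(head(x))        # intermediate value (plays the role of x5)
out = tail(mid)                 # final output (plays the role of x9)

mse = nn.functional.mse_loss(out, target)
reg = 1e-2 * mid.abs().sum()    # L1-style penalty on the intermediate
loss = mse + reg

# Gradient of each branch on its own, keeping the graph alive for reuse.
g_mse, = torch.autograd.grad(mse, head.weight, retain_graph=True)
g_reg, = torch.autograd.grad(reg, head.weight, retain_graph=True)

# One backward pass over the combined loss covers both branches.
loss.backward()
branches_sum_matches = torch.allclose(head.weight.grad, g_mse + g_reg, atol=1e-6)
print(branches_sum_matches)
```

If this holds in the toy case, the shared part of the graph is unlikely to be the cause by itself, and the in-place modification named in the error is the more likely place to look.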

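import torch
import torch.nn as nn

def L_TV_Loss(diff_grid, grid_size, lambda_):
  # Regularization term on the two channels of the differential grid.
  # Note: torch.norm without a `p` argument computes the 2-norm.
  # grid_size is currently unused.
  L_TV_Loss = torch.norm(diff_grid[:,0]) + torch.norm(diff_grid[:,1])
  L_TV_Loss = L_TV_Loss * lambda_
  return L_TV_Loss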
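Since the post describes the term as an L1 penalty while torch.norm defaults to the 2-norm, here is a small sketch of what an explicit L1 version could look like. The helper name `l1_tv_penalty` is hypothetical, and whether L1 or L2 is actually wanted is an assumption based on the wording of the question.

```python
import torch

def l1_tv_penalty(diff_grid: torch.Tensor, lambda_: float) -> torch.Tensor:
    """L1 penalty on both channels of a (batch, 2, H, W) differential grid."""
    # Sum of absolute values is the L1 norm; no in-place ops are involved.
    return lambda_ * diff_grid[:, :2].abs().sum()

# Tiny example grid of shape (1, 2, 1, 2): channel 0 = [3, -4], channel 1 = [0, 0].
grid = torch.tensor([[[[3.0, -4.0]], [[0.0, 0.0]]]])

l2_value = torch.norm(grid[:, 0]) + torch.norm(grid[:, 1])  # 2-norm: sqrt(9 + 16) = 5
l1_value = l1_tv_penalty(grid, lambda_=1.0)                 # L1: 3 + 4 = 7
print(l2_value.item(), l1_value.item())
```

The two values differ (5 versus 7 here), so the choice matters for how strongly large grid differences are penalized.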

# get_conv, init_grid, cumsum_2d, apply_checkerboard and IMAGE_SIZE are defined elsewhere (not shown)
class ALIGNet(nn.Module):

  def __init__(self, grid_size):
    super().__init__()
    self.conv = get_conv(grid_size)
    self.flatten = nn.Flatten()
    self.linear1 = nn.Sequential(nn.Linear(80,20),nn.ReLU(),)
    self.linear2 = nn.Linear(20, 2*grid_size*grid_size)
    self.upsampler = nn.Upsample(size = [IMAGE_SIZE, IMAGE_SIZE], mode = 'bilinear')
    self.linear2.bias = nn.Parameter(init_grid(grid_size).view(-1))
    # learnable scalar offsets for the warp grid
    self.grid_offset_x = nn.Parameter(torch.tensor(0.0))
    self.grid_offset_y = nn.Parameter(torch.tensor(0.0))
    self.grid_size = grid_size

  def forward(self, x, src_batch, checker_board=False, grad_list = False):
    # x0..x9 are local variables, not attributes of the module
    x0 = self.conv(x)
    x1 = self.flatten(x0)
    x2 = self.linear1(x1)
    x3 = self.linear2(x2)
    #enforce axial monotonicity using the abs operation
    x4 = torch.abs(x3)
    batch, grid = x4.shape
    x5 = x4.view(batch, 2,self.grid_size,self.grid_size)
    #apply the l_tv penalty
    L_TV_Loss_ = L_TV_Loss(x5, 8, 1e-5)
    #perform the cumsum operation to restore the original grid from the differential grid
    x6 = cumsum_2d(x5, self.grid_offset_x, self.grid_offset_y)
    #Upsample the grid_size x grid_size warp field to image_size x image_size warp field
    x7 = self.upsampler(x6)
    x8 = x7.permute(0,2,3,1)
    if checker_board:
      source_image = apply_checkerboard(src_batch, IMAGE_SIZE)
    #calculate target estimation
    x9 = nn.functional.grid_sample(src_batch.unsqueeze(0).permute([1,0,2,3]), x8, mode='bilinear')
    return x9, L_TV_Loss_
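The body of cumsum_2d is not shown, so the following is only a guess at where an in-place write could happen. If it updates slices of x5 in place (for example with `+=` or indexed assignment), an out-of-place variant built on torch.cumsum would leave x5 untouched for the penalty's backward pass. The function `cumsum_2d_out_of_place` and the choice of summation axes are assumptions for illustration, not the original implementation.

```python
import torch

def cumsum_2d_out_of_place(diff_grid, offset_x, offset_y):
    """Hypothetical stand-in for cumsum_2d: integrates a (batch, 2, H, W)
    differential grid without mutating any tensor in place."""
    grid_x = torch.cumsum(diff_grid[:, 0], dim=2) + offset_x  # accumulate along width
    grid_y = torch.cumsum(diff_grid[:, 1], dim=1) + offset_y  # accumulate along height
    return torch.stack([grid_x, grid_y], dim=1)               # freshly allocated result

raw = torch.randn(10, 2, 8, 8, requires_grad=True)
x5 = torch.abs(raw)                                   # non-leaf, like x5 in forward()
penalty = 1e-5 * (torch.norm(x5[:, 0]) + torch.norm(x5[:, 1]))

offset_x = torch.zeros((), requires_grad=True)
offset_y = torch.zeros((), requires_grad=True)
grid = cumsum_2d_out_of_place(x5, offset_x, offset_y)

# Both branches (grid and penalty) backpropagate through x5 in one call.
(grid.mean() + penalty).backward()
print(grid.shape, raw.grad is not None)
```

Because every step returns a new tensor, the slices saved by torch.norm keep the version they had when the penalty was computed.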

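# x9 and L_TV_Loss_ are the two values returned by forward()
loss = nn.MSELoss()(x9, target_image) + L_TV_Loss_
loss.backward()  # raises the error below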

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [10, 8, 8]], which is output 0 of SelectBackward, is at version 2; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
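The shape in the message, [10, 8, 8], matches a slice such as x5[:, 0] for a batch of 10 and an 8 x 8 grid. A minimal sketch that reproduces the same kind of failure in isolation is below; it assumes that some step after the norm call writes into x5 in place, which is a guess since that code is not shown. The helper `run` is hypothetical.

```python
import torch

def run(in_place: bool) -> str:
    raw = torch.randn(10, 2, 8, 8, requires_grad=True)
    x5 = torch.abs(raw)
    penalty = torch.norm(x5[:, 0])     # norm keeps its input for the backward pass
    if in_place:
        x5[:, 0] += 1.0                # writes into memory shared with that saved input
        total = x5.sum() + penalty
    else:
        total = (x5 + 1.0).sum() + penalty   # out-of-place: x5 is left unchanged
    try:
        total.backward()
        return "ok"
    except RuntimeError as err:
        # Report whether this is the in-place error discussed above.
        return "RuntimeError" if "inplace operation" in str(err) else "other"

print(run(in_place=True))
print(run(in_place=False))
```

To locate the offending line in the real model, the hint in the message applies: calling torch.autograd.set_detect_anomaly(True) before the forward pass makes the backward error include a traceback of the forward operation whose gradient could not be computed.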