Pow(2) inplace operation problem with gradient

Hi!
I’m trying to learn a differential equation from data. For that, I defined a net that approximates the derivative at a point. I then use the output of the net to approximate the next grid point of the solution and feed that back in as an input for the next step, so the calls to the net are recursive. I also already know the form of the differential equation I want to approximate fairly precisely, so in every step I do some feature extraction: I compute products, powers, squares, etc. of the net's output and some additional parameters.

```python
class ODEMultistep(nn.Module):

    def __init__(self, args): 
        super(ODEMultistep, self).__init__()
        self.fnn1 = Derivative_FNN().to(args.device)
        self.fnn2 = Superpos_FNN(args.steps).to(args.device) 
        self.args = args

    def forward(self, nr, iv, k, stop):
        ni_pred = torch.zeros(nr.size(0), stop).to(self.args.device)
        ni_pred[:, :self.args.steps] = iv.clone().detach()
        ni_pred.requires_grad = True

        # this loops over all available nodes and calculates the next node
        for j in range(self.args.steps, stop - self.args.steps):

            solutions = []
            # this calculates all previous derivatives
            for l in range(j - self.args.steps, j):
                t1 = nr[:, l] / ni_pred[:, l] * nr[:, l] # here calculate features
                td = nr[:, l].pow(2)
                td -= ni_pred[:, l].pow(2) # in this line the gradient gets lost
                t2 = k * nr[:, l]
                t2 *= torch.sqrt(td)
                t1, t2 = t1.view(-1, 1), t2.view(-1, 1)
                derivative = self.fnn1(t1, t2)
                solutions.append(derivative)
            this_derivative = self.fnn2(solutions)
            ni_pred[:, j] = ni_pred[:, j - 1] + self.args.stepsize * this_derivative

        return ni_pred
```

Now I get this error message when running the code:

```
sys:1: RuntimeWarning: Traceback of forward call that caused the error:
  File "/home/felixwagner/MEGAsync/Projektarbeit/Code/train_ODEMultistep.py", line 164, in <module>
    train(args, model, criterion, train_loader, optimizer, epoch)
  File "/home/felixwagner/MEGAsync/Projektarbeit/Code/lib/utils_ODEMultistep.py", line 27, in train
    stop=i)
  File "/home/felixwagner/miniconda3/envs/felix_ml/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/felixwagner/MEGAsync/Projektarbeit/Code/lib/utils_ODEMultistep.py", line 101, in forward
    td -= ni_pred[:, l].pow(2)

Traceback (most recent call last):
  File "/home/felixwagner/MEGAsync/Projektarbeit/Code/train_ODEMultistep.py", line 164, in <module>
    train(args, model, criterion, train_loader, optimizer, epoch)
  File "/home/felixwagner/MEGAsync/Projektarbeit/Code/lib/utils_ODEMultistep.py", line 36, in train
    loss.backward()
  File "/home/felixwagner/miniconda3/envs/felix_ml/lib/python3.7/site-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/felixwagner/miniconda3/envs/felix_ml/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1]], which is output 0 of SelectBackward, is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
```

Somehow the gradient gets lost when I square (pow(2)) the outputs of the net and then re-insert them to compute the next step of the differential equation. Can somebody help me find a solution to this problem? Is there anything I can do about it?

Hi,

The error means that you modified in place a Tensor whose value was needed to compute some gradients.
Given the line that caused the error, the -= operation is most likely the issue.
Can you try replacing it with an out-of-place version like td = td - xx?
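A minimal sketch of the rule described above, using a toy tensor rather than the model from the question: `exp()` keeps its output for the backward pass, so an in-place `+=` on that output breaks `backward()`, while the out-of-place `y = y + 1` works. Note that `pow(2)` keeps its *input* rather than its output, so in the question's code `td = nr[:, l].pow(2)` followed by `td -= ...` does not touch anything `pow` saved, which may be why the out-of-place rewrite alone does not change the error.

```python
import torch

x = torch.randn(3, requires_grad=True)

# exp() keeps its *output* for the backward pass (d/dx exp(x) = exp(x)),
# so changing that output in place invalidates the saved value.
y = x.exp()
y += 1  # in-place: bumps y's version counter
try:
    y.sum().backward()
    inplace_failed = False
except RuntimeError as err:
    inplace_failed = True
    print("in-place version fails:", err)

# Out-of-place: `y + 1` creates a new tensor, so the saved output stays intact.
x.grad = None
y = x.exp()
y = y + 1
y.sum().backward()
print(torch.allclose(x.grad, x.exp()))  # gradient of sum(exp(x) + 1) is exp(x)
```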

I replaced all the in-place operators (such as -= and *=) with their out-of-place versions, but I get the same error at the same line:

```
sys:1: RuntimeWarning: Traceback of forward call that caused the error:
  File "/home/felixwagner/MEGAsync/Projektarbeit/Code/train_ODEMultistep.py", line 164, in <module>
    train(args, model, criterion, train_loader, optimizer, epoch)
  File "/home/felixwagner/MEGAsync/Projektarbeit/Code/lib/utils_ODEMultistep.py", line 27, in train
    stop=i)
  File "/home/felixwagner/miniconda3/envs/felix_ml/lib/python3.7/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/felixwagner/MEGAsync/Projektarbeit/Code/lib/utils_ODEMultistep.py", line 101, in forward
    td = td - ni_pred[:, l].pow(2)

Traceback (most recent call last):
  File "/home/felixwagner/MEGAsync/Projektarbeit/Code/train_ODEMultistep.py", line 164, in <module>
    train(args, model, criterion, train_loader, optimizer, epoch)
  File "/home/felixwagner/MEGAsync/Projektarbeit/Code/lib/utils_ODEMultistep.py", line 36, in train
    loss.backward()
  File "/home/felixwagner/miniconda3/envs/felix_ml/lib/python3.7/site-packages/torch/tensor.py", line 107, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/home/felixwagner/miniconda3/envs/felix_ml/lib/python3.7/site-packages/torch/autograd/__init__.py", line 93, in backward
    allow_unreachable=True)  # allow_unreachable flag
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1]], which is output 0 of SelectBackward, is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
```

I also tried combining several of the calculations into a single line instead of one per line, which (not surprisingly) gives the same error.

Oh, actually, reading the error message in detail:

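  • “output 0 of SelectBackward”: this means that the problematic tensor is actually ni_pred (more precisely, a slice such as ni_pred[:, l] that was saved for the backward pass).
  • You do modify it in place when you do ni_pred[:, j] = xxx. All slices of ni_pred share its version counter, so writing into any column makes the slices saved earlier look modified.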
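A small, self-contained repro of the mechanism in these two points. It is a sketch, not the question's model: a hypothetical scalar parameter `w` stands in for the nets and a 1-D buffer `buf` stands in for `ni_pred`. A slice read from the buffer is saved by `pow(2)`, a later write into another slice of the same buffer bumps the shared version counter, and `backward()` then raises the same kind of error. (`_version` is an internal attribute, used here only for inspection.)

```python
import torch

w = torch.tensor(0.5, requires_grad=True)  # hypothetical stand-in for the net's weights
buf = torch.zeros(3)                        # preallocated buffer, like ni_pred
buf[0] = 1.0                                # initial value, like iv

for j in range(1, 3):
    prev = buf[j - 1]        # integer indexing -> a view of buf (SelectBackward)
    sq = prev.pow(2)         # pow saves `prev` to compute its gradient later
    buf[j] = prev + w * sq   # writing into buf bumps the version counter it shares with prev

# All views of buf share one version counter, so the saved `prev` now looks modified.
assert prev._version == buf._version

try:
    buf.sum().backward()
    failed = False
except RuntimeError as err:
    failed = True
    print(err)  # "... modified by an inplace operation ..."
```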

Can you try changing ni_pred into a list, and only convert it back to a tensor at the very end with return torch.stack(ni_pred, 1)?
Something like:

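```python
# one (N,) tensor per grid point: unbind(1) returns the same shape as ni_pred[:, j],
# and list() is needed because unbind (like split) returns a tuple, which has no append
ni_pred = list(iv.clone().detach().unbind(1))

# replace
ni_pred[:, j]
# by
ni_pred[j]

# replace
ni_pred[:, j] = xxx
# by
ni_pred.append(xxx)
```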
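A runnable sketch of the list-based version applied to the question's loop. `ToyMultistep` and the two `nn.Linear` layers are hypothetical stand-ins for `ODEMultistep`, `Derivative_FNN` and `Superpos_FNN` (whose definitions are not shown), and the loop here runs up to `stop` so the stacked result has `stop` columns. The feature computation is kept from the question; only the storage of `ni_pred` changes. The inputs are chosen so that `nr > ni_pred > 0`, which keeps the square root and the division well defined.

```python
import torch
import torch.nn as nn


class ToyMultistep(nn.Module):
    """Hypothetical, simplified stand-in for ODEMultistep (list-based storage)."""

    def __init__(self, steps, stepsize):
        super().__init__()
        self.steps = steps
        self.stepsize = stepsize
        self.fnn1 = nn.Linear(2, 1)      # stands in for Derivative_FNN
        self.fnn2 = nn.Linear(steps, 1)  # stands in for Superpos_FNN

    def forward(self, nr, iv, k, stop):
        # One (N,) tensor per grid point, same shape as ni_pred[:, l] in the question.
        ni_pred = list(iv.clone().detach().unbind(1))
        for j in range(self.steps, stop):
            derivs = []
            for l in range(j - self.steps, j):
                t1 = nr[:, l] / ni_pred[l] * nr[:, l]
                td = nr[:, l].pow(2) - ni_pred[l].pow(2)
                t2 = k * nr[:, l] * torch.sqrt(td)
                derivs.append(self.fnn1(torch.stack([t1, t2], dim=1)))  # (N, 1)
            this_derivative = self.fnn2(torch.cat(derivs, dim=1)).squeeze(1)  # (N,)
            # append creates a new tensor; nothing saved for backward is modified.
            ni_pred.append(ni_pred[j - 1] + self.stepsize * this_derivative)
        return torch.stack(ni_pred, dim=1)  # (N, stop)


torch.manual_seed(0)
model = ToyMultistep(steps=2, stepsize=1e-3)
nr = 3.0 + torch.rand(4, 6)
iv = 1.0 + 0.1 * torch.rand(4, 2)
out = model(nr, iv, k=0.1, stop=6)
loss = out.pow(2).mean()
loss.backward()  # no in-place modification error
print(out.shape)
```

Because each new grid point is a fresh tensor appended to the list, the slices that `pow`, the division and `sqrt` saved earlier are never written to again, so their version counters stay unchanged.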

Turning it into a list did solve it. Thank you, I probably wouldn’t have figured that out myself 🙂
