Modified by an inplace operation only if JIT compiled

Hi:)

I have written an nn.Module that works fine when used as a plain Python object, but as soon as I jit.script / compile it, it throws at runtime:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.DoubleTensor [10, 8]], which is output 0 of SelectBackward, is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

Why does this only occur after compilation? Naively, I would understand if such an error occurred without compilation and then got optimized away by compiling. But here the compiler seemingly introduces this error. Could this be a bug in the compiler?

Thanks for your help!

EDIT
I managed to bisect my code and reduce it to this snippet:

debug = False

model = MyFancyModel()
if not debug:
    model = torch.jit.script(model)
loss = torch.sum((model(x) - y)**2)
print('#1', loss)
loss.backward()  # works
print('#1 done.')

model = MyFancyModel()
if not debug:
    model = torch.jit.script(model)
loss = torch.sum((model(x) - y)**2)
print('#2', loss)
loss.backward()  # fails
print('#2 done.')

which works fine if debug is set to True but fails on the second call of loss.backward(). I am very confused. MyFancyModel() shouldn’t carry any static state, and I don’t understand (a) why the second invocation of .backward() works without compiling but fails if the module is compiled, and (b) why the first invocation works even with compiling!? Do you have any suggestions on how to attack this problem?

I faced the same error recently in a different scenario (not in JIT) and found a solution like the one below:

outer_loss = torch.tensor(0., device=args.device)
avg_loss = F.cross_entropy(test_logit, test_target)
outer_loss += avg_loss

# If I don't change the loss inside torch.no_grad(), it throws an error because the
# model states for the calculated losses are different.
with torch.no_grad():
    outer_loss.set_(torch.Tensor([InitialLoss - F.cross_entropy(test_logit, test_target)]).to(device))

This link also helped in understanding the problem

Thanks a lot for your help. But I don’t see the connection to my problem. Do you suggest adding with torch.no_grad() to my code? If so, where?

I was referring to that error. Can you paste your MyFancyModel() for debugging?

why (a) the second invocation of .backward() works w/o compiling but fails if the module is compiled, and
(b) why the first invocation works even with compiling!? Do you have any suggestions how to attack this problem?

Very likely, one of the inputs is modified in place, but the JIT fuser combines operations in a way that needs that input for the backward. This typically kicks in from the second run on (the first being an information-gathering run that profiles the shapes).

You could try whether anomaly detection helps you hunt down the in-place modification and remove it.
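
Concretely, something like this (reusing the model, x and y from your bisected snippet; anomaly mode makes autograd record the forward-pass stack trace of the op that produced the offending tensor and print it together with the error):

torch.autograd.set_detect_anomaly(True)      # global switch; slows things down, so debugging only

loss = torch.sum((model(x) - y)**2)          # model, x and y as in your snippet
loss.backward()                              # the error now also points at the forward op that created the bad tensor

# or, scoped to a single region:
with torch.autograd.detect_anomaly():
    torch.sum((model(x) - y)**2).backward()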

The model is quite involved. I will try to reduce it so that it becomes more readable. Thanks again:)

The error message
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.DoubleTensor [10, 8]], which is output 0 of SelectBackward, is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
is already the output of the anomaly detection.
Are you referring to the x and y in the example? If my module works correctly these variables should be read-only, but I will double check. All other variables should be renewed by the second compilation step.

EDIT
I am now quite sure that x and y do not change. I copy both via tensor.clone() and compare them after calling my module. All other variables are renewed by calling model = MyFancyModel() the second time…

Do you use select and indexing somewhere and then change the tensor you indexed?
As those share memory (they are “views” into the same memory in autograd lingo), they also share the same version counter, and modifying the “parent” tensor invalidates using the view in the backward (autograd doesn’t track which part was viewed and which part was modified; this is the case regardless of whether the actual values have changed).
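
As a toy example of the mechanism (nothing to do with your model; just a select view, an op that saves it for its backward, and an in-place write to the parent):

import torch

a = torch.randn(3, 4, requires_grad=True)
b = a * 2                       # non-leaf tensor, so we may modify it in place
v = b[0]                        # select -> a view of b, sharing b's storage and version counter
out = (v ** 2).sum()            # pow saves v for its backward
b[1] = 0.0                      # in-place write to the parent bumps the shared version counter
print(b._version, v._version)   # both counters have moved on
out.backward()                  # RuntimeError: ... output 0 of SelectBackward ... inplace operation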

Yes, I do use indexing. I managed to strip down my module code, unrolled a few loops and ended up with this snippet:

import torch
import torch.nn as nn


class Core(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.f = nn.Linear(4, 2)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x = torch.cat((x1, x2), 1)
        return self.f(x)


class MyFancyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.core = Core()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape == (1, 2, 2)

        y = torch.empty(1, 4, 2)
        y[:, :2] = x

        x1, x2 = x[:, 0], x[:, 1]
        y1, y2 = y[:, 0], y[:, 1]
        y[:, 2] = x2 + (x2 - x1) * self.core(y1, y2)

        x1, x2 = x[:, 1], y[:, 2]
        y1, y2 = y[:, 1], y[:, 2]
        y[:, 3] = x2 + (x2 - x1) * self.core(y1, y2)

        return y

In case my indexing is broken, can you give me a hint how to fix it?

You could try sticking your y bits in a list and then use torch.cat.
So something like this (but I might have gotten something wrong).
x[:, None] is the same as x.unsqueeze(1); that is just my taste preference here, you might have a different one.

class MyFancyModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.core = Core()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.shape == (1, 2, 2)

        y_list = [x]  # collect the pieces in a list and cat them at the end instead of writing into y in place

        x1, x2 = x[:, 0], x[:, 1]
        y1, y2 = x[:, 0], x[:, 1]
        y_new = x2 + (x2 - x1) * self.core(y1, y2)
        print(y_new.shape)
        y_list.append(y_new[:, None])

        x1, x2 = x[:, 1], y_new # ??
        y1, y2 = y2, y_new
        y_new = x2 + (x2 - x1) * self.core(y1, y2)
        y_list.append(y_new[:,None])
        y = torch.cat(y_list, 1)

        return y
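
A quick way to check it (with made-up random inputs in place of your real x and target) would be:

model = torch.jit.script(MyFancyModel())
x = torch.rand(1, 2, 2, requires_grad=True)
target = torch.rand(1, 4, 2)

for i in range(3):                           # a few runs, so the profiling/fusing executor is past its warm-up
    loss = torch.sum((model(x) - target)**2)
    loss.backward()                          # should pass now that nothing is written into y in place
    print('#%d done.' % (i + 1))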

Best regards

Thomas

Thanks for your help. Much appreciated! So you are saying assigning values to a tensor as in y[:, 2] = ... does not work with autograd, even if this value is never overwritten afterwards?

Yeah. I have this superfancy “everything about autograd” course that covers this error in full detail. Writing and recording the course and all was great fun, but I am slow to set up the delivery…

Thanks for your help:)