RuntimeError due to inplace operation

I’m still trying to wrap my head around PyTorch’s Autograd engine. I wanted to implement a toy network architecture but keep getting the same error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

Looking at my code, I can’t figure out what exactly causes the error. It would be great if someone could give me a hint. Here is my code:

import time
import torch
import random


def test_toy_network():

    torch.autograd.set_detect_anomaly(True)

    # Parameters
    n_samples = 2000
    lr = 1e-4
    n_inputs = 2
    n_outputs = 2
    sigma = 0.1
    n_steps = 1000

    # Fake data
    x_data = torch.rand(size=(n_samples, n_inputs))
    y_data = torch.rand(size=(n_samples, n_outputs))

    matrix_height = 3
    matrix_width = 1

    # Trainable parameters
    w_in = torch.normal(mean=0.0, std=sigma, size=(matrix_height, n_inputs), requires_grad=True)
    b_in = torch.normal(mean=0.0, std=sigma, size=(matrix_height,), requires_grad=True)

    w = [[torch.normal(mean=0.0, std=sigma, size=(1,), requires_grad=True)
          for _ in range(matrix_width)] for _ in range(matrix_height)]

    w_out = torch.normal(mean=0.0, std=sigma, size=(n_outputs, matrix_height), requires_grad=True)
    b_out = torch.normal(mean=0.0, std=sigma, size=(n_outputs,), requires_grad=True)

    # Placeholder
    a = [[torch.zeros(size=(1,)) for _ in range(matrix_width + 1)] for _ in range(matrix_height)]
    a_tmp = torch.zeros(size=(matrix_height,))

    h = torch.nn.Sigmoid()

    for n in range(n_steps):

        t0 = time.time()

        # Draw data points
        rand_idx = random.randint(0, n_samples - 1)
        x = x_data[rand_idx]
        y = y_data[rand_idx]

        # Feedforward
        a_in = h(w_in.matmul(x) + b_in)

        for i in range(matrix_height):
            a[i][0] = a_in[i]

        a[0][1] = h(a[0][0] * w[0][0] + a[1][0] * w[1][0])
        a[1][1] = h(a[0][0] * w[0][0] + a[1][0] * w[1][0] + a[2][0] * w[2][0])
        a[2][1] = h(a[1][0] * w[1][0] + a[2][0] * w[2][0])  # <--- Error

        for i in range(matrix_height):
            a_tmp[i] = a[i][1]

        a_out = torch.sigmoid(w_out.matmul(a_tmp) + b_out)
        loss = torch.nn.MSELoss(reduction="mean")(a_out, y)

        # Backpropagation
        loss.backward(retain_graph=True)

        # Gradient descent
        with torch.no_grad():
            w_in.sub_(lr * w_in.grad)
            b_in.sub_(lr * b_in.grad)

            w_out.sub_(lr * w_out.grad)
            b_out.sub_(lr * b_out.grad)

            w[0][0].sub_(lr * w[0][0].grad)  # <--- Error
            w[1][0].sub_(lr * w[1][0].grad)  # <--- Error
            w[2][0].sub_(lr * w[2][0].grad)  # <--- Error

        t1 = time.time()

        if n % 100 == 0:
            print(f"n {n} loss {loss.item()} time {(t1-t0)}")


if __name__ == "__main__":
    test_toy_network()

The error you had before the present one, namely

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling .backward() or autograd.grad() the first time.

is a bit incomplete. Between the two sentences written there, it should read

Think about whether you actually want to take backward of an expression involving gradients and if that is the case specify …

Using anomaly detection mode (with torch.autograd.detect_anomaly()) gives you the line that is causing this: a_tmp[i] = a[i][1].

Indeed, this is bad: the grad_fn is kept per tensor, not per element, so autograd does not know that the backward of the old iteration isn’t actually needed anymore. Every slice assignment into the same a_tmp chains the new graph onto the one from the previous iteration.
If you change a_tmp to be allocated in each iteration, you won’t have the problem.

So I’m not terribly smart, and as such, I never try this type of clever optimization in the first go of writing a program as you have here with a_tmp, and instead optimize for expressivity. This is particularly true here because PyTorch caches memory allocations, so the overhead of re-allocating is very small (you might use torch.empty to also avoid zeroing if you must). Of course, if you insist on being this clever and preallocate, you could also call a_tmp.detach_() in the no-grad block.
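
To make the three variants concrete, here is a minimal sketch. It is not the original network: w1, w2, buf and the train helper are made-up names, and the "loss" is just a sum. It only mimics the pattern of writing into a buffer element by element and then updating the parameters in place.

```python
import torch


def train(mode: str, n_steps: int = 3) -> bool:
    """Return True if all steps run, False if autograd raises a RuntimeError.

    mode: "reuse"  - one buffer shared across iterations (the failing pattern)
          "fresh"  - allocate the buffer anew in every iteration
          "detach" - share the buffer, but call detach_() after each update
    """
    torch.manual_seed(0)
    w1 = torch.randn(3, requires_grad=True)
    w2 = torch.randn(3, requires_grad=True)
    buf = torch.zeros(3)  # plays the role of a_tmp
    try:
        for _ in range(n_steps):
            x = torch.rand(3)
            hidden = torch.sigmoid(w1 * x)
            if mode == "fresh":
                buf = torch.zeros(3)
            for i in range(3):
                # Slice assignment is an in-place op that autograd records on buf.
                buf[i] = hidden[i] * w2[i]
            loss = buf.sum()
            # retain_graph is only needed to get past the first error in "reuse".
            loss.backward(retain_graph=(mode == "reuse"))
            with torch.no_grad():
                for p in (w1, w2):
                    p.sub_(0.1 * p.grad)
                    p.grad.zero_()
                if mode == "detach":
                    buf.detach_()  # drop the history attached to buf
    except RuntimeError:
        return False
    return True


for mode in ("reuse", "fresh", "detach"):
    print(mode, "ok" if train(mode) else "RuntimeError")
```

In the "reuse" case the second backward walks through buf into the graph of the first iteration, where w2 was saved and has since been changed by sub_, so I would expect it to fail there. The other two variants never reach the old graph.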

Now, smart or not, I have spent quite a while thinking about what it means to write efficient PyTorch code. If you are not content with the performance of your code, here is one of my rules of thumb that looks applicable: avoid working with lists/loops containing/processing 1-element tensors and very small operands. The things you spell out element by element do look like you would be much happier if you just used a bunch of matmuls or somesuch.
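
As a rough illustration of the matmul idea, the three hand-written lines for a[0][1], a[1][1], a[2][1] can be read as a fixed 0/1 connectivity matrix applied to the weighted inputs. The sketch below assumes the three scalar weights are packed into one tensor w of shape (3,); hidden_loop, hidden_matmul and mask are names I made up for this example.

```python
import torch


def hidden_loop(a0: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Element-by-element version, following the three lines in the question."""
    rows = [
        torch.sigmoid(a0[0] * w[0] + a0[1] * w[1]),
        torch.sigmoid(a0[0] * w[0] + a0[1] * w[1] + a0[2] * w[2]),
        torch.sigmoid(a0[1] * w[1] + a0[2] * w[2]),
    ]
    return torch.stack(rows)


def hidden_matmul(a0: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Same computation as one matrix-vector product."""
    # Row i marks which weighted inputs feed unit i.
    mask = torch.tensor([[1.0, 1.0, 0.0],
                         [1.0, 1.0, 1.0],
                         [0.0, 1.0, 1.0]])
    return torch.sigmoid(mask @ (w * a0))


torch.manual_seed(0)
a0 = torch.rand(3)
w = torch.randn(3, requires_grad=True)
print(torch.allclose(hidden_loop(a0, w), hidden_matmul(a0, w)))
```

The matmul version returns a fresh tensor each time, so no preallocated buffer is needed, and it extends to a whole batch with (w * a0_batch) @ mask.T.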

Best regards

Thomas

Thanks! Both of your suggestions worked! However, I do not understand why a_tmp.detach_() works. What exactly happens in the background if I call .detach_()?

Basically, .detach_() removes the backward edges that go into the tensor at the moment you call it (print .grad_fn before and after .detach_() to see it). After your next slice assignment to a_tmp, it will have new ones (print before and after there, too, to see it).
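
Here is a small sketch of that print experiment. buf is a hypothetical stand-in for a_tmp and w for any parameter; the exact repr of grad_fn may differ between PyTorch versions, so the comments only say whether a grad_fn is present.

```python
import torch

w = torch.ones(3, requires_grad=True)
buf = torch.zeros(3)  # hypothetical stand-in for a_tmp

fn_initial = buf.grad_fn         # no history yet
buf[0] = w[0] * 2.0              # slice assignment records a backward edge
fn_after_assign = buf.grad_fn
buf.detach_()                    # cut the edges, keep the values
fn_after_detach = buf.grad_fn
buf[1] = w[1] * 3.0              # a new assignment creates new edges
fn_after_reassign = buf.grad_fn

for label, fn in [
    ("initial", fn_initial),
    ("after assign", fn_after_assign),
    ("after detach_", fn_after_detach),
    ("after reassign", fn_after_reassign),
]:
    print(f"{label:>14}: {fn}")
```

The values stored in buf are untouched by detach_(); only the link to the graph that produced them goes away.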

In another tutorial on PyTorch’s autograd engine, I saw that after updating the weights, the gradients are set to zero with my_tensor.grad.zero_() or they are just replaced by None. Is this also necessary in my example above? I would be very happy if you could give me an answer to this.

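Yes. backward() adds to .grad rather than overwriting it, so without resetting you would step with the sum of all gradients so far. Usually you would use opt.zero_grad() for that, but doing it on the tensors would also work.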
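
A tiny sketch of the accumulation and the ways to reset it, using a made-up two-element tensor w and torch.optim.SGD only as an example optimizer. Whether opt.zero_grad() zeroes the gradient or sets it to None depends on the PyTorch version, so the code checks for either.

```python
import torch

w = torch.tensor([1.0, 2.0], requires_grad=True)

(w ** 2).sum().backward()
grad_first = w.grad.clone()    # gradient of sum(w^2) is 2 * w

(w ** 2).sum().backward()
grad_second = w.grad.clone()   # second backward was added on top

w.grad.zero_()                 # option 1: zero in place
grad_zeroed = w.grad.clone()

w.grad = None                  # option 2: drop the gradient tensor

opt = torch.optim.SGD([w], lr=0.1)
(w ** 2).sum().backward()
opt.zero_grad()                # option 3: let the optimizer reset all its params
cleared = w.grad is None or bool((w.grad == 0).all())

print(grad_first, grad_second, grad_zeroed, cleared)
```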
