Tensor.set_ seems not to correctly update metadata if the shape changes


I have some parameters that must be pruned in place in each iteration, so that their shape changes while the model and the optimizer keep their references to them. The forward pass succeeds after the pruning, but the backward pass fails because the gradient function still uses the old shape of the gradient.

The code below demonstrates the problem. It fails with the error message:
RuntimeError: Function SumBackward0 returned an invalid gradient at index 0 - got [16, 64, 64] but expected shape compatible with [32, 64, 64].

I tried to work around this by deleting the previous graph, i.e. removing the output from the scope before the next iteration (set REQUEST_GRAPH_DELETION = True in the code), but this is not guaranteed to work every time.

I have tried many things to make PyTorch correctly update the metadata, but nothing worked reliably. Should I report it on GitHub?

import time  # only needed if you try the time.sleep() variant mentioned below
import torch

REQUEST_GRAPH_DELETION = False  # if False, the error always occurs; if True, only sometimes
TEST_REPEAT_COUNT = 0  # set > 0 to repeat the test and count failures

def run_loop():
    x = torch.ones((32, 64, 64), requires_grad=True)
    out = None

    while x.shape[0] > 0:
        yield x.shape  # for printing

        if REQUEST_GRAPH_DELETION:
            del out  # time.sleep(0.0001) helps, but shouldn't be relied on

        out = x.sum()
        assert out.grad_fn(out).shape == x.shape  # correct gradient shape
        out.backward()  # incorrect gradient shape because some metadata is not updated

        with torch.no_grad():  # change the shape
            N = x.shape[0] // 2
            x.set_(x[:N])  # shrink x in place along the batch dimension
            grad = x.grad
            x.grad = grad[:N]

# single run: prints each shape until the backward pass fails
for shape in run_loop():
    print(shape)

if TEST_REPEAT_COUNT > 0:  # repeats
    error_count, succ_count = 0, 0
    for i in range(TEST_REPEAT_COUNT):
        try:
            shapes = []
            for shape in run_loop():
                shapes.append(shape)
            print("OK   ", shapes)
            succ_count += 1
        except RuntimeError:
            print("ERROR", shapes)
            error_count += 1
    print(f"{error_count} errors, {succ_count} successes")


Which optimizer do you use? Will you also update all the running stats in there accordingly?


I am using an optimizer without running stats. I haven’t considered this problem yet. Thank you!

In my case, the parameters have a batch dimension. Each input example of the model (we can call it a “local manifold model” or “perturbation model”) has its own independent slice of every parameter. The batch dimension of the input and of all parameters is pruned with the same mask. I could also prune the running stats in place with the same mask.
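To make the shared-mask idea concrete, here is a minimal sketch. The helper prune_batch, the tensor names, and the momentum buffer are hypothetical illustrations, not part of my actual model. Note that it returns new tensors instead of resizing in place, so it only shows the masking itself and does not address keeping the references alive.

```python
import torch


def prune_batch(mask, *tensors):
    """Return new tensors that keep only the batch entries where mask is True."""
    return tuple(t[mask] for t in tensors)


batch = 4
inputs = torch.randn(batch, 3)
weight = torch.randn(batch, 3, 5)    # one parameter slice per input example
momentum = torch.zeros(batch, 3, 5)  # hypothetical optimizer running stat

# The same boolean mask is applied to the input, the parameter and the stat.
keep = torch.tensor([True, False, True, False])
inputs, weight, momentum = prune_batch(keep, inputs, weight, momentum)

print(inputs.shape, weight.shape, momentum.shape)
```

Because every tensor is indexed with the same mask, slice i of the parameter still belongs to input example i after pruning.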


The problem is that the autograd information about the Tensor is not reset when you call .set_() on it.
You can try adding a gc.collect() after the del out to make it more reliable.

Thank you! gc.collect() is still unreliable. It seems that Python frees an object immediately when del is used anyway.
Is there perhaps a way to synchronize with the PyTorch thread that checks whether something is deleted and resets autograd information?

I have also tried x.detach_().requires_grad_() and it doesn’t work either.
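The remark above that del frees an object immediately can be checked in plain CPython with weakref. This is only a sketch with a hypothetical stand-in class; it says nothing about objects that PyTorch keeps alive on the C++ side.

```python
import gc
import weakref


class Node:
    """Hypothetical stand-in for any Python object, e.g. a graph output."""


def is_alive(ref):
    return ref() is not None


# Case 1: no reference cycle. CPython frees the object as soon as the
# last reference goes away, so gc.collect() adds nothing here.
out = Node()
out_ref = weakref.ref(out)
del out
freed_without_gc = not is_alive(out_ref)

# Case 2: a reference cycle keeps the objects alive after del until the
# cycle collector runs.
gc.disable()  # keep the automatic collector from interfering
a, b = Node(), Node()
a.other, b.other = b, a
cycle_ref = weakref.ref(a)
del a, b
alive_before_gc = is_alive(cycle_ref)
gc.collect()
alive_after_gc = is_alive(cycle_ref)
gc.enable()

print(freed_without_gc, alive_before_gc, alive_after_gc)
```

So gc.collect() only makes a difference when the object is part of a reference cycle.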

I am not sure why this is happening in a flaky way. On my machine, it always works when you del out.

On my laptop, the error rate is ~0.15. On the computer I run experiments on, the error rate is ~0 if the CPU is not busy and ~0.03 when run in parallel with e.g. stress --cpu 4.

My wild guess at the moment is that because the object we use to accumulate gradients is only a weak pointer into the Tensor, it is not always cleared before you run the next backward. In that case, it is re-used and fails because the size changed.

@ezyang do we expect the .set_() (and .resize_()) operations to reset the grad_accumulator_?
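The weak-pointer guess above can be pictured with a toy model in plain Python. The Accumulator and Tensor classes below are hypothetical stand-ins, not PyTorch internals. They only show how a weakly referenced helper keeps being reused, with its old shape, as long as something else still owns it.

```python
import weakref


class Accumulator:
    """Hypothetical gradient accumulator that remembers the shape it was built for."""

    def __init__(self, shape):
        self.shape = shape


class Tensor:
    """Toy tensor that holds only a weak reference to its accumulator."""

    def __init__(self, shape):
        self.shape = shape
        self._acc_ref = None

    def accumulator(self):
        acc = self._acc_ref() if self._acc_ref is not None else None
        if acc is None:  # nothing alive: build one for the current shape
            acc = Accumulator(self.shape)
            self._acc_ref = weakref.ref(acc)
        return acc


x = Tensor((32, 64, 64))
graph = [x.accumulator()]  # the "graph" owns the accumulator strongly

x.shape = (16, 64, 64)     # shape change that does not reset the weak reference
stale = x.accumulator()    # old graph still alive, so the old accumulator is reused
stale_shape = stale.shape

del stale, graph           # drop every strong reference
fresh_shape = x.accumulator().shape  # a new accumulator with the new shape

print(stale_shape, fresh_shape)
```

In this model the stale shape only shows up while the old graph is still alive, which would match the error disappearing once the previous graph is really gone.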


After trying again, I cannot reproduce the flakiness with REQUEST_GRAPH_DELETION=True on any of the machines I have available.
Could you give more information about your OS, compiler version, library versions, torch version, and how you installed it, please?


Thank you for your effort!

I can reproduce it on all of these machines:

  1. My laptop: Intel Pentium 2020M, Windows 10, Anaconda Python 3.7.6/3.8, pre-compiled PyTorch 1.3 from https://anaconda.org/pytorch/pytorch
  2. Intel® Core™ i7-4790K, Arch Linux, Python 3.8, PyTorch 1.3.1 compiled with GCC 9.2.0
  3. Intel® Core™ i7-7700K, Arch Linux, Python 3.8, PyTorch 1.3.1 compiled with GCC 9.2.0
  4. Intel® Core™ i7-5820K, Arch Linux, Python 3.8, PyTorch 1.3.1 compiled with GCC 9.2.0
  5. 2x Intel® Xeon® E5-2620, Arch Linux, Python 3.8, PyTorch 1.3.1 compiled with GCC 9.2.0

I set REQUEST_GRAPH_DELETION = False and TEST_REPEAT_COUNT = 10000. To increase the likelihood of errors, I run other processes while testing, e.g. a CPU stress test or a training experiment.

The error rate seems to be lowest on the 4th machine, which has the most CPU cores among the i7s (12 logical cores, fewer than the machine with 2 Xeons). With stress --cpu 6 running in parallel with the test, I get an error rate of about 0.001.