I have some parameters that have to be pruned in place in each iteration: their shape changes, but the model and the optimizer must keep their references to them. Although the forward pass succeeds after the pruning, the backward pass fails because the gradient function still expects the old gradient shape.
The code below demonstrates the problem, resulting in the error RuntimeError: Function SumBackward0 returned an invalid gradient at index 0 - got [16, 64, 64] but expected shape compatible with [32, 64, 64].
I have tried to work around this by deleting the previous graph, i.e. removing the output from scope (setting REQUEST_GRAPH_DELETION = True in the code), but it is not guaranteed to work every time.
I have tried many things to make PyTorch correctly update the metadata, but nothing worked reliably. Should I report it on GitHub?
import torch, time

REQUEST_GRAPH_DELETION = False  # if False, the error always occurs; if True, only sometimes
TEST_REPEAT_COUNT = 0  # set to e.g. 200 to run the repeated flakiness test below

def run_loop():
    x = torch.ones((32, 64, 64), requires_grad=True)
    out = None
    while x.shape[0] > 0:
        yield x.shape  # for printing
        if REQUEST_GRAPH_DELETION:
            del out  # time.sleep(0.0001) helps, but shouldn't be relied on
        out = x.sum()
        assert out.grad_fn(out).shape == x.shape  # the graph reports the correct gradient shape
        out.backward()  # incorrect gradient shape because some metadata is not updated
        with torch.no_grad():  # change the shape in place, keeping the same tensor object
            N = x.shape[0] // 2
            grad = x.grad
            x.set_(x.data[:N])
            x.grad = grad[:N]

for shape in run_loop():
    print(shape)

if TEST_REPEAT_COUNT > 0:  # repeat the loop many times and count failures
    error_count, succ_count = 0, 0
    for i in range(TEST_REPEAT_COUNT):
        try:
            shapes = []
            for shape in run_loop():
                shapes.append(shape[0])
            print("OK   ", shapes)
            succ_count += 1
        except RuntimeError:
            print("ERROR", shapes)
            error_count += 1
    print(f"{error_count} errors, {succ_count} successes")
I am using an optimizer without running statistics, and I haven't considered that problem yet. Thank you!
In my case, the parameters have a batch dimension. Each input example of the model (we could call it a "local manifold model" or "perturbation model") has its own independent parameter slice for every parameter. The batch dimension is pruned with the same mask for the input and for all parameters. I could also prune the running stats in place with the same mask.
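Roughly, the setup looks like the sketch below; the names (params, running_stats, keep) are made up for illustration, and only the shared mask over the batch dimension matters:

import torch

# Hypothetical names, only to illustrate the setup described above: every parameter
# has a leading batch dimension, and one boolean mask prunes all of them together.
params = {"weight": torch.randn(32, 64, 64, requires_grad=True),
          "bias": torch.randn(32, 64, requires_grad=True)}
running_stats = {"running_mean": torch.zeros(32, 64)}

keep = torch.arange(32) < 16  # shared mask over the batch dimension

with torch.no_grad():
    for p in params.values():
        p.set_(p.data[keep])           # in place, so existing references stay valid
    for name, s in running_stats.items():
        running_stats[name] = s[keep]  # could also be done in place with set_() under the same mask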
The problem is that the autograd information attached to the Tensor is not reset when you call .set_() on it.
You can try adding a gc.collect() after the del out to make it more reliable.
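As a sketch, the suggestion amounts to changing the deletion branch inside run_loop() to:

import gc

if REQUEST_GRAPH_DELETION:
    del out       # drop the Python reference to the previous graph root
    gc.collect()  # force a collection pass before building the next graph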
Thank you! With gc.collect() it is still unreliable. It seems that Python already frees the object immediately when del drops the last reference, so the collector does not change much here.
Is there perhaps a way to synchronize with the PyTorch thread that checks whether something is deleted and resets autograd information?
I have also tried x.detach_().requires_grad_() and it doesn’t work either.
On my laptop, the error rate is ~0.15. On the computer I run experiments on, the error rate is ~0 if the CPU is not busy and ~0.03 when run in parallel with e.g. stress --cpu 4.
My wild guess at the moment is that, because the object we use to accumulate gradients is only held via a weak pointer by the Tensor, it is not always cleared before you run the next backward. In that case it is re-used and fails because the size changed.
@ezyang do we expect the .set_() (and .resize_()) operation to reset the grad_accumulator_ ?
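One way to poke at this from Python is the diagnostic sketch below. It relies on the undocumented grad_fn.next_functions and AccumulateGrad.variable attributes and on autograd nodes keeping a cached Python wrapper, so treat the output as a hint rather than proof:

import torch

x = torch.ones((32, 64, 64), requires_grad=True)

out1 = x.sum()
acc1 = out1.grad_fn.next_functions[0][0]  # AccumulateGrad node for x

with torch.no_grad():
    x.set_(x.data[:16])  # same in-place shape change as in the repro

out2 = x.sum()
acc2 = out2.grad_fn.next_functions[0][0]

# If the same accumulator object is handed out again after set_(), its cached
# input metadata (recorded for shape [32, 64, 64]) would explain the mismatch.
print(acc1 is acc2, acc2.variable.shape)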
After trying again, I cannot reproduce the flakiness with REQUEST_GRAPH_DELETION=True on any of the machines I have available.
Could you give more information about your OS, compiler version, library versions, your torch version and how you installed it, please?
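For reference, most of that can be collected in one go with PyTorch's bundled environment script (run it in the environment that shows the error):

import torch.utils.collect_env
torch.utils.collect_env.main()  # equivalent to: python -m torch.utils.collect_env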
I set REQUEST_GRAPH_DELETION = False and TEST_REPEAT_COUNT = 10000. To increase the likelihood of errors, I run other processes while testing, e.g. a CPU stress test or a training experiment.
The error rate seems lowest on the 4th machine, which has the most CPU cores among the i7 machines (12 logical cores, though still fewer than the machine with 2 Xeons). With stress --cpu 6 running in parallel with the test, I get an error rate of about 0.001.