I've done some work on understanding the graph and its state, and how these are freed on backward calls. Notebook here.
There are two questions remaining - the second question is more important.
1) No guarantee that a second backward will fail?
```python
import torch
from torch.autograd import Variable

x = Variable(torch.ones(2, 3), requires_grad=True)
y = x.mean(dim=1).squeeze() + 3  # size (2,)
z = y.pow(2).mean()              # size 1
y.backward(torch.ones(2))
z.backward()               # should fail! But only fails on the second execution
y.backward(torch.ones(2))  # still fine, even though we're calling it for the second time
z.backward()               # this fails (finally!)
```
My guess: it's not guaranteed that an error is raised on a second backward pass through part of the graph, but since buffers could already have been freed, we have to supply retain_variables=True whenever we need to keep them around. Probably the simple operations that produce y (mean, add) don't need buffers for backward, while z = y.pow(2).mean() does need a buffer to store the result of y.pow(2). Is that correct?
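To make the second point concrete, here is a minimal sketch (same legacy Variable API as above) of keeping the buffers alive explicitly, instead of relying on which ops happen to need saved buffers:

```python
# Minimal sketch: explicitly retain the buffers so a second backward through z
# is guaranteed to work, rather than depending on which ops saved buffers.
x = Variable(torch.ones(2, 3), requires_grad=True)
z = (x.mean(dim=1).squeeze() + 3).pow(2).mean()
z.backward(retain_variables=True)  # buffers are kept alive ...
z.backward()                       # ... so a second backward is fine (gradients accumulate in x.grad)
```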
2) Using a net twice on the same input Variable makes a new graph with new state?
```python
out = net(inp)
out2 = net(inp)  # same input
out.backward(torch.ones(1, 1, 2, 2))
out2.backward(torch.ones(1, 1, 2, 2))  # doesn't fail -> has a different state than the first fw pass?!
```
Am I right to think that fw-passing the same variable twice constructs a second graph, keeping the state of the first graph around?
The problem I see with this design: often (during testing, when you detach() to cut off gradients, or anytime you add an extra operation just for monitoring) there is only a fw-pass on part of the graph. Is that state then kept around forever, consuming more memory on every new fw-pass of the same variable?
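For instance, this toy loop (same net/inp as above) is the kind of situation I'm worried about; as far as I understand, each forward pass builds its own graph, and that graph's state stays alive as long as the output Variable does:

```python
# Sketch of the concern: each forward pass builds a separate graph, and as long
# as the output Variable is referenced, that graph's state is never freed
# (no backward is ever called here to free it).
outs = []
for _ in range(10):
    outs.append(net(inp))  # 10 forward passes -> 10 graphs' worth of state kept alive
```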
I understand that the volatile flag was probably introduced for this problem, and I see it's used during testing in most example code.
But I think these are some examples where there's just a fw-pass without the volatile flag:

- `fake = netG(noise).detach()` to avoid bpropping through netG: https://github.com/pytorch/examples/blob/master/dcgan/main.py#L216
- testing on non-volatile variables: https://github.com/pytorch/examples/blob/master/super_resolution/main.py#L74
- if you finetune only the top layers of a feedforward net, the bottom layers see only fw-passes
But in general, if I understand this design correctly, this means that anytime you have a part of a network which isn't backpropped through, you need to supply the volatile flag? And then, when you use that intermediate volatile variable in another part of the network which is backpropped through, you need to re-wrap it and turn volatile off? Something like the sketch below is what I have in mind.
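Here frozen_net (not backpropped through), top_net (backpropped through), and the plain tensor x are just placeholder names for illustration:

```python
# Hypothetical pattern: run the frozen part on a volatile Variable so no graph
# state is kept, then re-wrap the result to turn volatile off for the top part.
frozen_out = frozen_net(Variable(x, volatile=True))  # no graph/state kept here
top_in = Variable(frozen_out.data)                   # re-wrap: volatile off, requires_grad=False
loss = top_net(top_in).sum()
loss.backward()                                      # only top_net's parameters get gradients
```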
PS
If there’s interest, I could update & adapt the notebook to your answers, or merge the content into the existing “for torchies” notebook, and submit a PR to the tutorials repo.