Understanding graphs and state

I’ve done some work on understanding graphs and state, and how these are freed on backward calls. Notebook here.

There are two questions remaining - the second question is more important.

1) No guarantee that second backward will fail?

import torch
from torch.autograd import Variable

x = Variable(torch.ones(2,3), requires_grad=True)
y = x.mean(dim=1).squeeze() + 3 # size (2,)
z = y.pow(2).mean() # size 1
y.backward(torch.ones(2))
z.backward() # should fail! But only fails on second execution
y.backward(torch.ones(2)) # still fine, though we're calling it for the second time
z.backward() # this fails (finally!)

My guess: it’s not guaranteed that an error is raised on the second backward pass through part of the graph. But of course, if we need to keep buffers on part of the graph, we have to supply retain_variables=True, because the buffers could have been freed.

Probably the specific simple operations for y (mean, add) don’t need buffers for backward, while z = y.pow(2).mean() does need a buffer to store the result of y.pow(2). Correct?

2) Using a net twice on the same input Variable makes a new graph with new state?

out = net(inp)
out2 = net(inp) # same input
out.backward(torch.ones(1,1,2,2))
out2.backward(torch.ones(1,1,2,2)) # doesn't fail -> has a different state than the first fw pass?!

Am I right to think that fw-passing the same variable twice constructs a second graph, keeping the state of the first graph around?

The problem I see with this design is that often (during testing, or when you detach() to cut off gradients, or anytime you add an extra operation just for monitoring) there’s just a fw-pass on part of the graph - so is that state then kept around forever and just starts consuming more memory on every new fw-pass of the same variable?

I understand that the volatile flag was probably introduced for this problem, and I see it’s used during testing in most example code.

But I think there are some examples where there’s just a fw-pass without the volatile flag:

But in general, if I understand this design correctly, this means that anytime you have a part of a network which isn’t backpropped through, you need to supply the volatile flag? And then, when you use that intermediate volatile variable in another part of the network which is backpropped through, you need to re-wrap it and turn volatile off?
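For concreteness, here’s roughly the pattern I mean (just a sketch; feature_net and head are made-up stand-ins for the untrained and trained parts):

import torch
import torch.nn as nn
from torch.autograd import Variable

feature_net = nn.Linear(4, 8)  # made-up: part we never backprop through
head = nn.Linear(8, 1)         # made-up: part we do train

inp = Variable(torch.randn(2, 4), volatile=True)  # volatile: no graph state is kept
features = feature_net(inp)                       # the output is volatile too

features = Variable(features.data)  # re-wrap to turn volatile off
out = head(features)
out.backward(torch.ones(2, 1))      # gradients only for head's parameters

Is that the intended usage?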

PS
If there’s interest, I could update & adapt the notebook to your answers, or merge the content into the existing “for torchies” notebook, and submit a PR to the tutorials repo.

  1. Yes. We don’t guarantee that the error will be raised, but if you want to be sure that you can backprop multiple times, you need to specify retain_variables=True (a sketch follows after point 2). It only fails to raise an error for very simple ops like the ones you have here (e.g. the grad_input of add is just grad_output, so there’s no need for any buffers, and that’s why it also doesn’t check if they were freed). Not sure if we should add these checks or not. It probably doesn’t matter: when it does fail it raises a clear error, and otherwise it still computes correct gradients.

  2. Yes, when you use the same net with the same input twice, it will construct a new graph that shares all the leaves, but all the other nodes will be exact copies of the first ones, with separate state and buffers. Every operation you do on Variables adds one more node, so if you compute a function 4 times, you’ll have 4x more nodes around (assuming all the outputs stay in scope).
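Regarding the first point, a minimal sketch based on your snippet - keeping the buffers around makes a second backward through z fine:

import torch
from torch.autograd import Variable

x = Variable(torch.ones(2, 3), requires_grad=True)
y = x.mean(dim=1).squeeze() + 3
z = y.pow(2).mean()

z.backward(retain_variables=True)  # keep the buffers for another pass
z.backward()                       # works now; without retain_variables this is the call that can raise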

Now, here’s some description of when we keep the state around:

  1. When finetuning a net, all the nodes before the first operation with trained weights won’t even require the gradient, and because of that they won’t keep the buffers around. So no memory wasted in this case.

  2. Testing on non-volatile inputs, or detaching the outputs, will keep the bottom part of the graph around, and it will require grad because the params do, so it will keep the buffers. In both cases it would help to set requires_grad to False on all the generator parameters for a moment, or to use a volatile input and then switch the flag off on the generator output (a sketch follows below). Still, I wouldn’t say that it consumes more memory on every fw-pass - it increases the memory usage, but by a constant factor, not like a leak. The graph state gets freed as soon as the outputs go out of scope (unlike Lua, Python uses refcounting).
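A minimal sketch of the requires_grad workaround for the second case (netG/netD are just placeholder modules here, following the GAN examples):

import torch
import torch.nn as nn
from torch.autograd import Variable

netG = nn.Linear(10, 10)  # placeholder generator
netD = nn.Linear(10, 1)   # placeholder discriminator

# Temporarily mark the generator parameters as not requiring grad, so the
# generator part of the graph doesn't keep buffers while we only need netD.
for p in netG.parameters():
    p.requires_grad = False

noise = Variable(torch.randn(4, 10))
fake = netG(noise)              # no buffers kept for netG's part now
out = netD(fake.detach())
out.backward(torch.ones(4, 1))  # only netD accumulates gradients

for p in netG.parameters():     # switch the flag back for netG's own update
    p.requires_grad = True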

There’s however one change that we’ll be rolling out soon - variables that don’t require_grad won’t keep a reference to the creator. This won’t help with inference without volatile, and it will still make the generator graph allocate the buffers, but they will be freed as soon as the output is detached. This won’t have any impact on the mem usage, since that memory would already be allocated after the first pass, and it can be reused by the discriminator afterwards.

Anyway, the examples will need to be fixed. Hope this helps; if something’s unclear, just let me know. Also, you can read more about the flags in this note in the docs.


Thanks for the elaborate answer - the graph going out of scope with the output variable is the essential part I was missing here.

If these fixes are what you had in mind, then I’ll send a PR.

Let me know if you think it’s useful to make the notebook with your answer into a full tutorial, I think these autograd graph/state mechanics are a bit underdocumented atm. Or maybe some explanation could be added to the autograd note in the docs.

Yeah, they look good, only nit is to not put spaces around the equals in volatile=True.

I agree the materials we have right now aren’t very detailed, but we didn’t have a lot of time to expand them. If you have a moment to write that down and submit a notebook or a PR to the notes, I’ll merge it. Thanks!

Thanks, this post is a savior. I had been getting “RuntimeError: Trying to backward through the graph second time, but the buffers have already been freed.” in my net, then recreated the exact graph structure in a simplified toy class to diagnose the problem and think of a solution (I know setting retain_variables=True would have done it, but I wanted to overthink a bit), and could not reproduce the problem at all (which gave me quite a headache). I finally understand that for very simple operations the backward pass is not required to fail :+1:.

Hi,

I am still a little bit confused regarding the detach() method.

We do fake = netG(noise).detach() to prevent backpropping through netG. Now, for the netG training, we do output = netD(fake). If we had, let’s say,

fake = netG(inputs)
Loss1 = criterion(netD1(fake), real_label)
Loss2 = criterion(netD2(fake), real_label)
Loss3 = criterion(netD3(fake), real_label)
Loss_other = criterion_other(fake, target)
Loss = Loss1 + Loss2 + Loss3 + Loss_other
Loss.backward()

does this create the graphs for each of the netDs? Would it be wrong if I did

Loss1 = criterion(netD1(fake).detach(), real_label)
Loss2 = criterion(netD2(fake).detach(), real_label)
Loss3 = criterion(netD3(fake).detach(), real_label)
Loss_other = criterion_other(fake, target)
Loss = Loss1 + Loss2 + Loss3 + Loss_other
Loss.backward()

to save some memory, since I don’t need to backprop through netD? Will there be any difference in backpropping?

Regards
Nabarun

Hi,

if I understand you correctly, you want to train netD1 with Loss1, netD2 with Loss2, netD3 with Loss3, and netG with Loss_other?

Right now, what you do is: you calculate the output of netD1 and then detach this output, and then calculate Loss1 with the detached output (so basically the loss between a detached variable and a target). So nothing will propagate back to netD1 or netG, since you detached the variable (the output) after you passed it through netD1.
What you probably want to do is:

Loss1 = criterion(netD1(fake.detach()), real_label)
Loss2 = criterion(netD2(fake.detach()), real_label)
Loss3 = criterion(netD3(fake.detach()), real_label)
Loss_other = criterion_other(fake, target)
Loss = Loss1 + Loss2 + Loss3 + Loss_other
Loss.backward()

Here you detach fake, so Loss1 etc. can be propagated back through netD1 etc. (but still not through netG). If you want to propagate through netG and not netD1, you can try to set .requires_grad=False for all parameters in netD1, but I’m not sure if it will work, since it only works on leaves.

Hope what I just told you is mostly correct and does not confuse you more.
Cheers

Hi,

Actually no, I don’t want to train netD* at all; all the losses are for training netG, and all the netD* are fixed.

The last point of your reply kind of hit the mark: I want to propagate through netG without propagating through netD*.

The reason I thought it might work is this post: Freezing parameters - #2 by ebetica

.requires_grad=False looks promising, but I am not sure either; it would be great if someone could clarify that.

@apaszke any thoughts on this?

Regards
Nabarun

Nabarun,
what you want is fundamentally impossible: you need to backprop through D wrt its inputs because of the chain rule. You don’t need to compute gradients of D wrt its parameters, though, which is what you avoid by setting requires_grad=False, as in the GAN PyTorch example code.
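To illustrate (a small sketch, not the actual example code - netG/netD here are just tiny stand-ins): freeze D’s parameters and backward still flows through D, wrt its inputs, all the way to G’s parameters.

import torch
import torch.nn as nn
from torch.autograd import Variable

netG = nn.Linear(5, 5)  # stand-in generator
netD = nn.Linear(5, 1)  # stand-in, fixed discriminator

for p in netD.parameters():
    p.requires_grad = False  # no gradients wrt D's parameters

noise = Variable(torch.randn(3, 5))
fake = netG(noise)       # no detach: we want gradients to reach netG
out = netD(fake)
out.backward(torch.ones(3, 1))
# netG's parameters now have gradients accumulated in .grad; netD's don't,
# but the backward pass still had to go through netD (chain rule) to reach netG.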
Tom


Adam, you said that:

Is it only valid for the same net with the same input? I am interested in this specific detail because I am trying to implement an RNN, unrolling the input sequence in a loop and accumulating the gradient. In my case the input is a new Tensor (taken from the sequence) at every step of the iteration. Would I be able to achieve BPTT in this case?
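For concreteness, roughly what I have in mind (just a sketch - rnn_cell, the sizes and the loss are made up):

import torch
import torch.nn as nn
from torch.autograd import Variable

rnn_cell = nn.RNNCell(4, 8)                       # made-up sizes
hidden = Variable(torch.zeros(1, 8))
sequence = [torch.randn(1, 4) for _ in range(5)]  # a new Tensor at every step

loss = Variable(torch.zeros(1))
for step in sequence:
    inp = Variable(step)                # wrap each new Tensor as a fresh Variable
    hidden = rnn_cell(inp, hidden)      # hidden carries the graph across steps
    loss = loss + hidden.pow(2).mean()  # made-up loss, accumulated over steps

loss.backward()  # does this backprop through all the unrolled steps (BPTT)?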