Backpropagate a fixed gradient through a network

I wanted to do something like this:
out1 = net1(input)
out2 = net2(out1)
other_out = net3(out1)

I want to update net1 with the weighted sum of the gradient from out2 (i.e. out2.backward()) and other_out. The equivalent implementation in Torch7 would be:
dnet2 = net2:backward(...some argument....)
net1:backward(input, l1 * dnet2 + l2 * other_out)
where l1 and l2 are some constants. The dimensions of dnet2 and other_out are the same.

How should I do something like this in pytorch? @smth

Assuming out2 and other_out are scalars, you can do this:

total_loss = out2 * l1 + other_out * l2
total_loss.backward()
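A minimal runnable sketch of this scalar case, assuming toy `nn.Linear` stand-ins for net1, net2 and net3 and reducing each output to a scalar with `.sum()` (both are illustrative assumptions, not the original networks). It also shows that a single `backward()` on the combined loss fills in parameter gradients for all three networks, which is what prompts the question about `requires_grad` that follows.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the three networks in the question.
net1 = nn.Linear(4, 3)
net2 = nn.Linear(3, 1)
net3 = nn.Linear(3, 1)
l1, l2 = 0.7, 0.3  # example weights

x = torch.randn(5, 4)
out1 = net1(x)
out2 = net2(out1).sum()       # reduce to a scalar
other_out = net3(out1).sum()  # reduce to a scalar

total_loss = out2 * l1 + other_out * l2
total_loss.backward()

# Every network that contributed to total_loss now has parameter gradients.
print(net1.weight.grad is not None,
      net2.weight.grad is not None,
      net3.weight.grad is not None)
```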

For this to work as I want, I would have to set requires_grad of net2 and net3 to False.
Also, I do not want the gradient of net3 to backpropagate into net1; I want the output of net3 to be backpropagated instead. I doubt that just setting requires_grad of net3 to False will do this?

Well, you can’t get the gradient out of nowhere. You’ll still need to differentiate through all of net3; the only thing that changing requires_grad can save you is computing the grads w.r.t. the weights.
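To illustrate the point about `requires_grad`, here is a small sketch with hypothetical `nn.Linear` stand-ins: freezing net3's parameters stops autograd from computing net3's weight gradients, but the gradient still flows *through* net3 into net1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

net1 = nn.Linear(4, 3)  # hypothetical stand-ins
net3 = nn.Linear(3, 1)

# Freeze net3's weights: autograd skips grads w.r.t. them,
# but still differentiates through net3 to reach net1.
for p in net3.parameters():
    p.requires_grad_(False)

x = torch.randn(5, 4)
out1 = net1(x)
other_out = net3(out1).sum()
other_out.backward()

print(net3.weight.grad)              # None: no grad for frozen weights
print(net1.weight.grad is not None)  # True: gradient passed through net3
```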

Ah, I think I misread what you wrote. If you want to do the backward with grad_output as a linear combination of these outputs, you can do this:

out1 = net1(input)
out2 = net2(out1)
other_out = net3(out1)
out1.backward(l1 * out2 + l2 * other_out)
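A sketch of the same idea in current PyTorch (plain tensors instead of `Variable`), assuming hypothetical `nn.Linear` stand-ins where net2 and net3 preserve out1's shape, since the gradient passed to `out1.backward` must match out1. The outputs are detached so they act as a fixed gradient; the check confirms this matches differentiating the surrogate scalar `(out1 * grad).sum()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins; net2 and net3 keep out1's shape here.
net1 = nn.Linear(4, 3)
net2 = nn.Linear(3, 3)
net3 = nn.Linear(3, 3)
l1, l2 = 0.7, 0.3

x = torch.randn(5, 4)
out1 = net1(x)
out2 = net2(out1)
other_out = net3(out1)

# Treat the combination as a constant gradient for out1.
grad = (l1 * out2 + l2 * other_out).detach()
out1.backward(grad)
w_grad = net1.weight.grad.clone()

# Same result as backpropagating the surrogate scalar (out1 * grad).sum().
net1.zero_grad()
(net1(x) * grad).sum().backward()
print(torch.allclose(w_grad, net1.weight.grad))
```

Only net1 receives gradients here; the backward starts at out1, so net2 and net3 are never differentiated.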

The thing is that out2 and other_out are not of the same dimension: other_out has the same dimension as out1, while out2 has some other dimension. So I wanted to backpropagate through just net2 and extract the gradient that would be the input to net1. Then I can take a linear combination of that gradient with other_out and do a backward on net1 with the combined gradient. How do I do this in PyTorch?
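One way to mirror the Torch7 two-step pattern literally is `torch.autograd.grad`, which is not used in this thread and is named here as a swapped-in technique. A sketch under assumed toy shapes (out2 differs from out1, other_out matches out1), with `grad_out2` a hypothetical all-ones gradient for net2's output:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical shapes: out2 differs from out1, other_out matches out1.
net1 = nn.Linear(4, 3)
net2 = nn.Linear(3, 2)
net3 = nn.Linear(3, 3)
l1, l2 = 0.7, 0.3

x = torch.randn(5, 4)
out1 = net1(x)
out2 = net2(out1)
other_out = net3(out1)

grad_out2 = torch.ones_like(out2)  # gradient fed into net2's output

# Step 1: backprop through net2 only, stopping at out1 (Torch7's dnet2).
# retain_graph=True keeps the graph's buffers for the second pass.
(dnet2,) = torch.autograd.grad(out2, out1, grad_outputs=grad_out2,
                               retain_graph=True)

# Step 2: combine with other_out (as a constant) and backprop through net1.
combined = l1 * dnet2 + l2 * other_out.detach()
out1.backward(combined)

print(dnet2.shape == out1.shape)
```

`torch.autograd.grad` returns the gradient instead of accumulating it into `.grad`, so net2's parameters are left untouched.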
Thanks

Then you can do this:

out1 = net1(input)
out2 = net2(out1)
other_out = net3(out1)
# grad_out2 can be Variable(torch.ones(1)) if out2 is a scalar, or it can be None if you use the master branch
torch.autograd.backward([out1, l1 * out2], [l2 * other_out, grad_out2])
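A sketch of the single-call, multi-root version in current PyTorch, with weights following the question's `l1 * dnet2 + l2 * other_out` and `.detach()` in place of `.data`. The networks and shapes are hypothetical stand-ins. It compares the single call against the two-step "extract dnet2, then backprop" reference; in recent PyTorch versions the two should agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins with the shapes described in the question.
net1 = nn.Linear(4, 3)
net2 = nn.Linear(3, 2)
net3 = nn.Linear(3, 3)
l1, l2 = 0.7, 0.3
x = torch.randn(5, 4)

def net1_grad(single_call):
    net1.zero_grad()
    out1 = net1(x)
    out2 = net2(out1)
    other_out = net3(out1)
    grad_out2 = torch.ones_like(out2)
    if single_call:
        # Gradients reaching out1 from both roots are summed.
        torch.autograd.backward([out1, out2],
                                [l2 * other_out.detach(), l1 * grad_out2])
    else:
        # Reference: extract dnet2 first, then backprop through net1.
        (dnet2,) = torch.autograd.grad(out2, out1, grad_outputs=grad_out2,
                                       retain_graph=True)
        out1.backward(l1 * dnet2 + l2 * other_out.detach())
    return net1.weight.grad.clone()

print(torch.allclose(net1_grad(True), net1_grad(False)))
```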

Thanks for the help. I believe I can apply the constants l1 and l2 either in the variables argument or in the grad_variables argument; is that right, or does it matter which?
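Because backpropagation is linear in the incoming gradient, scaling the tensor by a constant or scaling its grad_output by the same constant should give the same parameter gradients. A quick check with a hypothetical stand-in network and an arbitrary gradient `g`:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net = nn.Linear(3, 2)   # hypothetical stand-in for net2
x = torch.randn(5, 3)
g = torch.randn(5, 2)   # an arbitrary grad_output
l1 = 0.7

# Constant on the tensor: backprop (l1 * y) with gradient g.
net.zero_grad()
torch.autograd.backward([l1 * net(x)], [g])
on_tensor = net.weight.grad.clone()

# Constant on the gradient: backprop y with gradient (l1 * g).
net.zero_grad()
torch.autograd.backward([net(x)], [l1 * g])
on_grad = net.weight.grad.clone()

print(torch.allclose(on_tensor, on_grad))
```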

I am getting the same error as before:
RuntimeError: Trying to backward through the graph second time, but the buffers have already been freed. Please specify retain_variables=True when calling backward for the first time.

I think that when it does the backward on out1 with other_out, it frees the buffers of net1, and then it cannot backpropagate through net1 again for out2. Would using a cloned variable in the second argument work?
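The error message describes a general failure mode that is easy to reproduce: a second backward through the same graph fails once the first pass has freed the saved buffers. A sketch with a hypothetical stand-in network; note that in current PyTorch the keyword is `retain_graph`, which replaced the older `retain_variables` named in the error above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net1 = nn.Linear(4, 3)  # hypothetical stand-in
x = torch.randn(5, 4)

out1 = torch.tanh(net1(x))  # tanh saves its output for the backward pass
out1.sum().backward()       # frees the saved buffers by default

try:
    out1.sum().backward()   # second pass through the same graph
    failed = False
except RuntimeError:
    failed = True
print(failed)

# Keeping the buffers on the first pass allows a second backward.
out1 = torch.tanh(net1(x))
out1.sum().backward(retain_graph=True)
out1.sum().backward()
```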
I wrote something like this:
torch.autograd.backward([out1, out2], [l2 * other_out.data, l1 * grad_out2])
where grad_out2 is a Tensor.
Let me know if it’s wrong.

UPDATE: If I set retain_variables=True, I get this error:
RuntimeError: dependency not found for N5torch8autograd12ConvBackwardE
Please let me know whether it’s a PyTorch issue or something wrong with my implementation. I am using the master-branch version of PyTorch.

Any updates on the issue @smth @apaszke?

I have a deadline to meet. Can anyone guide me through this?

If you are using the master branch, the “dependency not found” error is a compile issue on your side.

Umm, give me a small snippet and I can try to help you (I don’t have a lot of time right now either).

Thanks for the help. I understand you must be busy. I just wanted to know what the error was.
So,
out1 = net1(input)
err1 = net2(out1)
out2 = net3(out1)
newd = out2.data - out1.data
torch.autograd.backward([err1, out1], [grad_out, -someWeight * newd], retain_=True)
Here, out2 and out1 have the same dimension, and grad_out is a tensor. Let me know if you need more info. Thanks!
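For comparison, a sketch of the same snippet in current PyTorch with hypothetical stand-in networks (net3 maps out1 back to out1's shape) and `some_weight` standing in for `someWeight`. Because `newd` is detached, net3 is never differentiated, and a single `torch.autograd.backward` call walks net1's graph once, so no retain flag should be needed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: net3 maps out1 back to out1's shape.
net1 = nn.Linear(4, 3)
net2 = nn.Linear(3, 1)
net3 = nn.Linear(3, 3)
some_weight = 0.5  # stands in for someWeight

x = torch.randn(5, 4)
out1 = net1(x)
err1 = net2(out1)
out2 = net3(out1)

newd = (out2 - out1).detach()     # constant, like the .data difference
grad_out = torch.ones_like(err1)  # must match err1's shape

# One call: gradients from err1 and out1 are summed at out1.
torch.autograd.backward([err1, out1], [grad_out, -some_weight * newd])
```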

Can you give me a script with the exact code snippet from your comment (but with net1, net2, etc. defined as something) so that I can run it?

grad_out has to have the same dimension as err1.
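A small sketch of what happens when the shapes disagree, using a hypothetical stand-in for net2: PyTorch checks each gradient against its output's shape and raises a `RuntimeError` before running the backward pass, so a correctly shaped call still works afterwards.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
net2 = nn.Linear(3, 1)          # hypothetical stand-in
err1 = net2(torch.randn(5, 3))  # shape (5, 1)

try:
    err1.backward(torch.ones(5, 2))  # wrong shape for err1
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

err1.backward(torch.ones_like(err1))  # matching shape works
```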

Yeah, it is. The codebase consists of a whole lot of things. It is a private repository, but I can upload it to GitHub and share it with you so that you can run it.

I don’t need the entire codebase. Based on the snippet in your comment, just make a small fake script that reproduces the error.

I sent you a message with the gist of the code snippet. Please take a look at it.

We’re looking into this. It is a bug on the master branch. Let me see if we can issue a quick patch or find a workaround for you.

This is a failure case because fake here is a non-leaf Variable and torch.autograd.backward has a dependency analysis bug. I am going to open a bug report for this.
For now, here is your workaround, and good luck with your deadline:

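noise.resize_(opt.batchSize, nz, 1, 1).normal_(0, 1)
noisev = Variable(noise)
fake_v = netG(noisev)

# re-wrap G's output as a new leaf Variable, so backprop from netD and netA stops here
fake = Variable(fake_v.data, requires_grad=True)
errG = netD(fake)
rec = netA(fake)
newd = rec.data - fake.data
# fills fake.grad with d(errG)/d(fake)
errG.backward()
# push D's gradient plus the weighted reconstruction term into netG
fake_v.backward(fake.grad.data + (-opt.daeWeight * newd))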
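A sketch of the same workaround in current PyTorch, assuming small `nn.Linear` stand-ins for netG, netD and netA (the thread's networks are convolutional) and reducing netD's output to a scalar with `.mean()`; `nz`, `batch_size` and `dae_weight` are hypothetical values standing in for `nz`, `opt.batchSize` and `opt.daeWeight`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins and sizes.
nz, batch_size, dae_weight = 8, 4, 0.1
netG = nn.Linear(nz, 6)
netD = nn.Linear(6, 1)
netA = nn.Linear(6, 6)

noise = torch.randn(batch_size, nz)
fake_v = netG(noise)

# Cut the graph: fake is a new leaf that collects d(errG)/d(fake).
fake = fake_v.detach().requires_grad_()
errG = netD(fake).mean()
rec = netA(fake)
newd = (rec - fake).detach()

errG.backward()  # fills fake.grad (and netD's parameter grads)
fake_v.backward(fake.grad + (-dae_weight * newd))  # combined grad into netG
```

netA is never differentiated here, because `newd` is detached before it is used.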

For reference, here’s the opened issue: https://github.com/pytorch/pytorch/issues/1605
