How can I control gradient computation in cascaded operations with shared parameters?

Hello,

I was designing a structure with cascaded operations sharing the same parameters, i.e., something like loss = g(f(x, w), w). Normally in a backward pass, both f and g will contribute gradients to w. I was wondering whether I can control the backward pass so that only one of f or g contributes to the gradient of w.

I know a naive way of doing this is to copy w as w1 and compute loss = g(f(x, w), w1). However, since f and g might be complicated networks and w a whole set of parameters, I was wondering if we can do this without copying parameters, just by setting some flags.
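For concreteness, a minimal sketch of this copy-based workaround, with f and g as plain multiplications standing in for real networks:

```python
import torch

x = torch.tensor(3.0)
w = torch.tensor(2.0, requires_grad=True)

w1 = w.detach().clone().requires_grad_(True)  # independent copy of w
loss = (x * w) * w1                           # loss = g(f(x, w), w1)
loss.backward()

print(w.grad)   # x * w1 = 6: only f's use contributes to w
print(w1.grad)  # x * w  = 6: g's contribution lands on the copy instead
```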

It is possible to do this by detaching before one of the uses, i.e., g(f(x, w.detach()), w). I’m not sure why you would want to do this, though, because your gradient would be wrong. To consider an example: if f and g are both mul operations, your whole function would be xw^2 and dw would be 2wx; but if you detach before one of the uses, that w is treated as a constant, and your dw would be wx, a factor of 2 off from the correct gradient.
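To make the arithmetic concrete, a minimal sketch of the two variants (f and g again as mul operations):

```python
import torch

# Full function: x * w * w, so the true gradient w.r.t. w is 2 * w * x.
x = torch.tensor(3.0)
w = torch.tensor(2.0, requires_grad=True)

(x * w * w).backward()
print(w.grad)    # 2 * w * x = 12: both uses of w contribute

w.grad = None
(x * w.detach() * w).backward()  # g(f(x, w.detach()), w)
print(w.grad)    # w * x = 6: only g's use of w contributes
```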

Thank you for the reply!

Yes, detaching w could be a solution. But as mentioned in the original post, f and g might be complicated networks with many layers, and w many parameters. Is there a way to detach all the parameters of a network in just one operation?

My use case comes from invertible networks. The computation is of the form f_w^{-1}(h(f_w(x))), and I would like to observe how f^{-1} and f each contribute to the training. It would be great if you know of any references on similar bidirectional training and could share them with me :)

Perhaps the easiest way to do this is to keep two copies of each f or g module, one as-is and one with all the parameters detached, and swap in the detached one whenever you want only the other use to contribute to the next backward pass. (If you are asking how to iterate through the parameters of a module, you can use mod.parameters(), which iterates recursively through the current module’s parameters as well as its submodules’.) An alternative that avoids copying is sketched below.
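Here is a sketch of that no-copy alternative, assuming PyTorch >= 2.0 so that torch.func.functional_call is available; `call_detached`, `f`, and `h` are hypothetical stand-ins for your own helper and networks:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

# `f` plays the role of the parameter-sharing module that appears twice.
f = nn.Linear(4, 4)
h = nn.Linear(4, 4)

def call_detached(module, x):
    # Detach every parameter and buffer of `module` for this call only,
    # without copying the module or its weights.
    state = {name: t.detach()
             for name, t in list(module.named_parameters()) + list(module.named_buffers())}
    return functional_call(module, state, (x,))

x = torch.randn(2, 4)
# Stand-in for f^{-1}(h(f(x))): the inner use of f is detached, so only the
# outer use contributes gradients to the shared parameters.
loss = f(h(call_detached(f, x))).sum()
loss.backward()
print(f.weight.grad)  # gradient from the outer use of f only
```

Swapping which call goes through call_detached lets you compare the two directions without duplicating any parameters.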

Invertible networks are pretty cool. Are you trying to use them for memory reduction, normalizing flows, or some other application?

Thanks. Copying f and g is a baseline solution. I will do it this way if no better solution is found.

Yes. I am playing with some normalizing-flow-like structures and would like to see whether training them in a bidirectional fashion, with both GAN and BPD objectives, produces anything new (perhaps similar to FlowGAN). Let me know if you have any experience with this :)