# How can I control gradient computation in cascaded operation with shared parameters

Hello,

I was designing a structure with cascaded operations sharing the same parameters, i.e., something like `loss = g(f(x, w), w)`. Normally, in a backward pass, both `f` and `g` will contribute gradients to `w`. I was wondering whether I can control the backward pass so that only one of `f` or `g` contributes to the gradient of `w`.

I know a naive way of doing this is to copy `w` as `w1` and compute `loss = g(f(x, w), w1)`. However, since `f` and `g` might be complicated networks and `w` a whole collection of parameters, I was wondering if we can do this without copying parameters, just by setting some flags.
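For concreteness, the copy baseline might look like the following sketch (module names are illustrative, assuming `f` is an `nn.Module`): a full `deepcopy` whose parameters are frozen, at the cost of doubling parameter memory and having to re-sync the copy after every optimizer step.

```python
import copy

import torch.nn as nn

# Naive baseline: a full copy of the shared module, playing the role of w1.
# Only the original's parameters receive gradients; the copy is frozen.
f = nn.Linear(4, 4)

f_copy = copy.deepcopy(f)       # w1 = an independent copy of w
for p in f_copy.parameters():
    p.requires_grad_(False)     # the copy contributes no gradients

# Note: after each optimizer step on f, the copy goes stale and must be
# refreshed, e.g. f_copy.load_state_dict(f.state_dict()).
```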

It is possible to do this by detaching before one of the uses, i.e. `g(f(x, w.detach()), w)`. I'm not sure why you would want to do this, though, because your gradient would be wrong. Consider an example: if `f` and `g` are both mul operations, your whole function would be `xw^2` and `dw` would be `2wx`; but if you detach before one of the uses, one of the `w`'s is treated as a constant, and your `dw` would be `wx`, a factor of 2 off from the correct gradient.
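The two-mul example above can be checked directly in a few lines of PyTorch:

```python
import torch

# Toy example: f and g are both multiplications by the same parameter w,
# so loss = (x * w) * w = x * w**2 and d(loss)/dw = 2 * w * x.
x = torch.tensor(3.0)
w = torch.tensor(2.0, requires_grad=True)

# Full gradient: both uses of w contribute.
loss = (x * w) * w
loss.backward()
print(w.grad)  # 2 * w * x = 12

w.grad = None

# Detach the inner use: only the outer multiplication contributes,
# giving w * x = 6 -- half the true gradient.
loss = (x * w.detach()) * w
loss.backward()
print(w.grad)  # w * x = 6
```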

Yes, detaching `w` could be a solution. But as mentioned in the original post, `f` and `g` might be complicated networks containing many layers, with `w` comprising many parameters. I was wondering if there is a way to detach all the parameters of a network in just one operation?
My use case comes from invertible networks. The computation takes the form `f_w^{-1}(h(f_w(x)))`, and I would like to observe how `f^{-1}` and `f` each contribute to training. It would be great if you know of any references on similar bidirectional training and could share them with me :)
Perhaps the easiest way to do this is to have two copies of each `f` or `g` module, one as-is and one with all the parameters detached, and swap in the detached one whenever you want only one use to contribute to the next backward pass. (If you are asking how to iterate through the parameters of a module, you can use `mod.parameters()`, which iterates recursively through the current module's and its submodules' parameters.)
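A variant that avoids keeping a second copy in sync: run one use of the shared module through `torch.func.functional_call` (available in PyTorch 2.0+) with detached views of all its parameters, built in a single dict comprehension. This is a sketch, not the only way to do it:

```python
import torch
import torch.nn as nn
from torch.func import functional_call  # requires torch >= 2.0

# Cascade f(f(x)) with a single shared module f: the inner use runs with
# detached parameters, so only the outer use contributes parameter gradients.
f = nn.Linear(4, 4)
x = torch.randn(2, 4)

# Detached views of every parameter -- shares storage, no copies to re-sync.
detached = {name: p.detach() for name, p in f.named_parameters()}

# Inner use: parameters are constants here (gradient still flows to x).
h = functional_call(f, detached, (x,))
# Outer use: only this application of f accumulates gradients into f.
loss = f(h).sum()
loss.backward()

print(all(p.grad is not None for p in f.parameters()))  # True
```

Swapping which call receives `detached` flips which use of `f` trains, without duplicating the network.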
Thanks. Copying `f` and `g` is a baseline solution. I will do it this way if no better solution is found.