Deriving part of the path with a smaller magnitude - gradients question

Hi,
So first, a simple pre-question before the main one:

In train mode (with a model that has dropout, BatchNorm, etc.), is it true that if y = model(x), then model((x, x)) = (y, y') with y = y', and that everything else behaves identically as well (e.g. with respect to dropout, the backward path, etc.)?

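(A toy way I could check this empirically, with a made-up stand-in model containing just a Linear and a Dropout, not my actual one:)

import torch
import torch.nn as nn

torch.manual_seed(0)
toy = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))
toy.train()

x = torch.randn(1, 4)
y_pair = toy(torch.cat([x, x], dim=0))       # the (x, x) input as a batch of two identical rows
print(torch.allclose(y_pair[0], y_pair[1]))  # tells me whether y == y' for this toy model
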
So to the main question:
I have the following scheme of forward passes; the light-blue and blue paths are parts of one forward pass of my GAIN model (see Guided Attention Inference Network for more details, which are not necessary for my question):

Note that the second input image is differentiable through the first blue path.

Now I want the magnitude of the gradients along the light-blue path to be 1/10 of the original, while the regular blue path stays at the usual 1.

So I'll do the following:

(y, y_prime) = model.forward((x, x))  # y == y_prime from the assumption in the pre-question above
y, y_prime = y, y_prime.detach()
light_blue_loss = loss_fn(y_prime, _)  # _ is a wildcard for something, irrelevant here
light_blue_loss.backward(gradient=-0.9)
total_loss = loss_fn(y, _)  # _ is a wildcard for something, irrelevant here
total_loss.backward()  # total_loss.backward(gradient=1) is the default

Now, is it true that, for the light-blue path on the detached input (according to the scheme above), the gradients were accumulated with a constant factor of -0.9, so that after accumulating gradients from the total_loss backward with the default gradient argument of 1, the light-blue path ends up influenced by gradients of magnitude 0.1 (1 + (-0.9)) while the rest of the path sees gradients of magnitude 1?
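
As a toy sanity check of the accumulation part only (a single made-up scalar weight, not my GAIN model), I would expect something like:

import torch

w = torch.tensor(2.0, requires_grad=True)
loss = w * 3.0
# two backward passes into the same .grad buffer accumulate:
loss.backward(gradient=torch.tensor(-0.9), retain_graph=True)
loss.backward()  # default gradient=1
print(w.grad)    # 3 * (-0.9 + 1) = 0.3, i.e. 1/10 of the plain gradient 3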

I'll be grateful for your answer, thank you.

A better way to accomplish this might be to use hooks:

import torch
import torch.nn as nn

def grads(mod):
    return list(x.grad for x in mod.parameters())

torch.manual_seed(0)

x = torch.tensor([1., 2.])
mod = nn.Linear(2, 2)
mod2 = nn.Linear(2, 2)

y = mod2(mod(x))
y.sum().backward()

print(grads(mod))
print(grads(mod2))

print("--------------")

torch.manual_seed(0)

x = torch.tensor([1., 2.])
mod = nn.Linear(2, 2)
mod2 = nn.Linear(2, 2)

for p in mod2.parameters():
    # the hook fires when p's grad is computed and replaces it with grad / 10
    p.register_hook(lambda grad: grad / 10)

y = mod2(mod(x))
y.sum().backward()

print(grads(mod))
print(grads(mod2))

Output:

[tensor([[-0.0768, -0.1535],
        [ 0.7478,  1.4955]]), tensor([-0.0768,  0.7478])]
[tensor([[ 0.4810, -1.4331],
        [ 0.4810, -1.4331]]), tensor([1., 1.])]
--------------
[tensor([[-0.0768, -0.1535],
        [ 0.7478,  1.4955]]), tensor([-0.0768,  0.7478])]
[tensor([[ 0.0481, -0.1433],
        [ 0.0481, -0.1433]]), tensor([0.1000, 0.1000])]

In train mode (with a model that has dropout, BatchNorm, etc.), is it true that if y = model(x), then model((x, x)) = (y, y') with y = y', and that everything else behaves identically as well (e.g. with respect to dropout, the backward path, etc.)?

I guess it would be the same, but a better way of doing that would just be to clone the output y and then detach it?
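
Something like this, for example (a rough sketch with a stand-in Linear instead of your model):

import torch
import torch.nn as nn

model = nn.Linear(2, 2)        # stand-in for the real model
x = torch.tensor([1., 2.])

y = model(x)                   # single forward pass
y_ref = y.clone().detach()     # a copy that is cut off from the autograd graph
y.sum().backward()             # backprop only through y; y_ref carries no graph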

@soulitzer, that's an interesting point about the hooks.
What would you do if mod and mod2 have to be the same network with exactly the same parameters (as in my case)? Can you think of an easy way to handle that case?
As far as I understand, I need some variable to distinguish between the places where the hooks should apply and where they shouldn't, on the same parameters, right?
For example, to support a flag for some specific layer: suppose I know the last layer is named "last_layer" (and during the backward pass I should see it only twice in total), so my hook would look like:

use_hook = 0

def backward_hook(x):
    global use_hook  # the counter lives outside the hook
    if cur_layer_name == 'last_layer':
        use_hook += 1
    if use_hook < 2:
        return x / 10
    else:
        return x

And that is possible only if I can indeed check the current layer name from within the hook during the backward pass (can I?).

I need to think about that.

And another thing, correct me if I'm wrong: if you had two nn.Linear layers, would you still have to register the hook only once because of the chain rule (and the multiplication of all the gradients according to it)?

Or should it be done for all the differentiable layers (wouldn't that give a factor of 10^{num_layers} in that case)?

And another thing, correct me if I'm wrong: if you had two nn.Linear layers, would you still have to register the hook only once because of the chain rule (and the multiplication of all the gradients according to it)?

Hmm, I probably misunderstood your original question. In my example I used two modules in sequence to demonstrate that it is possible to use hooks to selectively alter the gradients of a subgraph of the original network, in this particular case so that the first module's parameters keep their original gradients while the second module's gradients are a tenth of what they would've been.

We have to register a hook individually for each tensor (i.e., parameter) in module 2 because a hook registered directly on a tensor triggers when the grad for that tensor is computed and alters only the gradient computation for that individual tensor.

I actually don't quite understand what you are going for yet, but if you wanted to alter the gradients of both nn.Linear layers you'd need to register hooks individually for four tensors, or loop over mod.parameters() for each of the two modules.

And that is possible only if I can indeed check the current layer name from within the hook during the backward pass

Not for hooks registered directly on a tensor. But you do still have access to the name when you loop through the parameters of the module itself.
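
For example, something like this (a rough sketch; the layer names are made up):

import torch
import torch.nn as nn

mod = nn.Sequential()
mod.add_module('first_layer', nn.Linear(2, 2))
mod.add_module('last_layer', nn.Linear(2, 2))

for name, p in mod.named_parameters():
    # the default argument captures each parameter's name inside its own hook
    p.register_hook(lambda grad, name=name: grad / 10 if name.startswith('last_layer') else grad)

x = torch.tensor([1., 2.])
mod(x).sum().backward()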

I'll clarify a bit more. I see what you've done, and it's fine if you have two modules, but in my case, as in the image, it is the same module and thus the same parameters. So if I register a hook on them, it will actually be called twice in the backward pass of the loss called 'am_loss', and what I want to achieve is the gradient divided by 10 only on the first of the two times the backward for that parameter is called.

So if I have only one module and I register a backward hook:

x = torch.tensor([1., 2.])
mod = nn.Linear(2, 2)

for p in mod.parameters():
    p.register_hook(lambda grad: grad / 10)

y_hat = mod(x)
y = mod(y_hat)
y.sum().backward()

In that case the grads will unintentionally be divided by 10 twice, while I only want the gradients from y to be divided by 10; going back through y_hat, the gradients should remain as they are.

So I want to manipulate the gradients of the same parameters, but they appear twice in the backward graph (if you look at the graph of y with torchviz you will see view functions that create a second view of each parameter, but it is the same parameter, so if I register a hook it will be called twice, and I want it called only once).

That's why I thought of creating a variable as a counter of how many times that parameter has been visited (I know it is 2 at most, so I increase it every time the hook is called; if it is 1, I'll divide the grads by 10, and if it is 2, I won't).

Ok, got you on that. I asked this because I know the chain rule is like Dy/Dz * Dz/Dw ...
So I thought that although the hook is individual for each tensor-parameter, we would end up with Dy/(Dz*10) * Dz/(Dw*10), so it would be a mistake to do it for each tensor-parameter because of the math; but you're saying that is not the case, and for a network of many layers the correct logic is to register a 'divide_by_10' hook for each layer's parameters, is that correct?
And the effect in this case will be:
1/10 * (Dy/Dz * Dz/Dw ...)

Ok, got you on that. I asked this because I know the chain rule is like Dy/Dz * Dz/Dw ...

Registering a hook on each of the parameters of a module would not affect future gradient computation, because parameters are leaf tensors. If you registered a hook on a non-leaf tensor (e.g., the output of your module), then yes, it would do what you describe, and you would only need to register a single hook in that case.

So I guess you could just do the below actually:

x = torch.tensor([1., 2.])
mod = nn.Linear(2, 2)
y_hat = mod(x)
y = mod(y_hat)
y.register_hook(lambda grad: grad / 10)
y.sum().backward()

In this case you avoid registering hooks directly on the parameters, so it should be okay to use the same mod twice. But the issue is that it would also affect the gradient computation further back, through the first application of mod, as well. So we can actually register another hook there to multiply by 10 again.

x = torch.tensor([1., 2.])
mod = nn.Linear(2, 2)
y_hat = mod(x)
y_hat.register_hook(lambda grad: grad * 10)  # undo the 1/10 before the gradient continues into the first application of mod
y = mod(y_hat)
y.register_hook(lambda grad: grad / 10)      # scale everything flowing back from y
y.sum().backward()

Does this resolve your issue?


That idea about multiplying back by 10 also came to my mind. I think it will resolve it.

As an alternative, suppose mod is a model with some state where I can save a counter whose role is to track which backward visit through this one layer we are in.
Then I use this counter as an indicator, with an if statement deciding what to do when the hook is called: if it is the first visit (there is a deterministic order of the backward calls because of the forward sequence), I divide by 10; if it is the second visit, I do nothing (and just zero the counter for the next forward/backward iteration). Do you get my suggestion and its logic?
That way I can avoid the multiplication on the second output, although in that solution my hook would be on the parameters of the module and not on the output (where the problem is distinguishing between the two visits in the backward pass, which I think can be resolved by a simple counter).

But this solution with a division and then a correction by multiplication should do the job, I think; that's also good and maybe even simpler.

It is indeed good to consider the alternatives; the counter method might work as well.
One way to do it is to actually save the counter as the function's default argument! This avoids maintaining an extra global set of counters for each parameter.

fns = []

for i in range(10):
    def fn(x = [0]):
        x[0] += 1
        print(x[0])
    fns.append(fn)

fns[0]() # 1
fns[1]() # 1
fns[2]() # 1

fns[0]() # 2

@soulitzer, hmm... I've never written such code with a counter as a default argument; can it be zeroed?
Like some code which would do:

fns = []

for i in range(10):
    def fn(x = [0]):
        x[0] += 1
        print(x[0])
    fns.append(fn)

fns[0]() # 1
fns[1]() # 1
fns[2]() # 1

fns[0]() # 2

'some_zero_code()'
fns[0]() # 0

Uhhh, I guess you could just take the counter mod 2, i.e. only scale on every other invocation haha. But this is all hypothetical anyway; I wouldn't recommend anyone write such code.
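
In the same toy style, the mod-2 version would be something like (still hypothetical, not real hook code):

def fn(x=[0]):
    x[0] += 1
    if x[0] % 2 == 1:
        print("odd call: would scale the gradient by 1/10")
    else:
        print("even call: would pass the gradient through unchanged")

fn()  # odd
fn()  # even
fn()  # odd again -- no explicit zeroing needed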


@albanD, hi, I'll be glad to hear your opinion about the correctness of this solution, because I find it the easiest to experiment with.

Do you think it should work?

In my case, I'll register a hook on the am_loss which will divide by 10, and another hook on the input of the second path, which is the masked_image, which will multiply the grads back by 10. That way I'll get the light-blue path gradients divided by 10 and the blue ones as usual, won't I?

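A tiny stand-in sketch of that plan (am_loss and masked_image are my tensors from the scheme; here they are replaced by a toy Linear so the snippet runs):

import torch
import torch.nn as nn

mod = nn.Linear(2, 2)                          # the shared module used in both paths
x = torch.tensor([1., 2.])

masked_image = mod(x)                          # end of the blue path / input of the light-blue path
masked_image.register_hook(lambda g: g * 10)   # undo the scaling before it reaches the blue path

am_loss = mod(masked_image).sum()              # the light-blue path ends in am_loss
am_loss.register_hook(lambda g: g / 10)        # scale everything flowing back from am_loss

am_loss.backward()
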
Another solution, as we discussed, is to register a hook for each parameter in the module, which is a bit painful (maintaining counters, or the internal-counter trick @soulitzer suggested, etc., which still isn't nice). But the idea of using many hooks for modules inside an nn.Sequential isn't appealing, especially since right now the backward_hook for nn.Sequential is broken in the sense that it only works on the gradients of the last module in it and not on every one, which means I would need to loop over each module by hand and register a hook that divides the gradients by 10.

That solution would work, yes.

For Sequential hooks, you can now use the full_backward_hook, which actually works :stuck_out_tongue:

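For instance, a quick toy check (not from the docs, just to show the container-level hook now fires):

import torch
import torch.nn as nn

seq = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))

def report(module, grad_input, grad_output):
    # fires once for the whole Sequential during backward
    print(type(module).__name__, [g.shape for g in grad_output if g is not None])

seq.register_full_backward_hook(report)

x = torch.tensor([[1., 2.]])
seq(x).sum().backward()
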

Good news. I didn't find the docs about it; do you mind sharing a link :slight_smile:?

1.9 was released 2 minutes ago: Module — PyTorch 1.9.0 documentation haha
