How to reverse gradient sign during backprop?

For example, I have models net0 and net1 used like this:

out0 = net0(inp0)
out1 = net0(net1(net0(inp1)))
loss = criterion(out0, out1, target)

I want to reverse the gradient sign at the output of net1, but calculate net0(inp0) as usual.

In the simple case, I would do:

out0 = net0(inp0)
out1 = net0(net1(inp1))
loss = criterion(out0, out1, target)
loss.backward()
# flip the sign of the gradients accumulated for net1's parameters
for p in net1.parameters():
    if p.grad is not None:
        p.grad.neg_()
opt.step()

Can I use some kind of hook to reverse the gradients at the output of net1?

Can't you simply negate the loss?

loss = -criterion(out0, out1, target)

It's not my case; I want to reverse the gradient only partially, and only from one branch.

Oh sorry, I tend to read only partially :smiley:
The approach you are following is the simplest one, IMO. In fact, backward itself runs an iterative process like that internally, so I don't think your loop would hurt performance.

Still, you can use register_full_backward_hook (Module — PyTorch 1.10.0 documentation).
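A minimal sketch of such a hook, just to illustrate the API (note that a full backward hook can only replace the gradients w.r.t. the module's inputs, not the gradients of its own parameters):

def flip_grads(module, grad_input, grad_output):
    # return replacement gradients w.r.t. the module's inputs, with the sign flipped
    return tuple(-g if g is not None else g for g in grad_input)

handle = net1.register_full_backward_hook(flip_grads)
# run forward/backward as usual; call handle.remove() to detach the hook later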

My solution will only work in the simple case, without using net0 two times.
I will check the docs, thanks.

Hi Had!

As an alternative to using a hook, you could write a custom Function
whose forward() simply passes through the tensor(s) unchanged, but
whose backward() flips the sign of the gradient(s).
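A minimal sketch of such a Function (just an illustration of the idea, untested against your exact setup):

import torch

class GradientReversalFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # identity in the forward pass
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # flip the sign of the incoming gradient in the backward pass
        return grad_output.neg()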

You would then insert it at the desired place in your network, e.g.:

out1 = net0(GradientReversalFunction.apply(net1(inp1)))

I don’t really have an opinion about which method is cleaner.

Best.

K. Frank

Much cleaner with the Function, because I can't find tutorials or examples of hook usage.
So I need to apply the function to the output tensor, and all the previous calculations will then get reversed gradients, right?

Also, while googling for the Function approach, I found this example:

-x + (x*2).detach()
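If I read it right, the forward value is still x (-x + 2x), while only the -x term is visible to autograd, so the gradient gets negated. Wrapped as a helper (the name grad_reverse is just for illustration):

def grad_reverse(x):
    # forward value: -x + 2x == x
    # backward: only the -x term has a gradient, so the gradient w.r.t. x is negated
    return -x + (x * 2).detach()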

Hi,
here you have a tutorial, Debugging and Visualisation in PyTorch using Hooks, which looks nice.
And here is a tutorial about how to modify the backward function:
Extending PyTorch — PyTorch master documentation

Thanks a lot to you both!