# Manually scaling gradients during backward pass

Hi,
I have a network that is trained with two loss functions, loss1 and loss2. The total loss of the network is loss1 + loss2 (with no weights). What I would like to do is ensure that the gradients of both losses have the same magnitude/norm. In theory, I could determine some weight empirically for reweighting the losses, but this is really suboptimal: the relative norm of the gradients can, and most likely will, change during training, so there is no way to guarantee that a fixed weighting stays optimal throughout.
Here is a minimal example with two arbitrary loss functions to demonstrate what I would like to do:

```python
import torch
import torch.nn as nn

lin = nn.Linear(32, 16, False)
inp = torch.rand((1, 32))
out = lin(inp)

loss1 = 1 - out.std()
loss2 = out.sum()
loss = loss1 + loss2

loss.backward()
```
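To check how far apart the two gradient norms actually are, each loss's gradient can be computed separately with `torch.autograd.grad`, passing `retain_graph=True` so the shared graph survives between calls. This sketch (the fixed seed is an addition for reproducibility) also checks that a single `backward()` on the unweighted sum accumulates exactly the sum of the two per-loss gradients, which is why the individual contributions are no longer visible afterwards.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # added for reproducibility

lin = nn.Linear(32, 16, bias=False)
inp = torch.rand((1, 32))
out = lin(inp)

loss1 = 1 - out.std()
loss2 = out.sum()

# Gradient of each loss alone w.r.t. the weight; retain_graph=True keeps
# the shared graph alive so later calls can reuse it.
g1, = torch.autograd.grad(loss1, lin.weight, retain_graph=True)
g2, = torch.autograd.grad(loss2, lin.weight, retain_graph=True)
print("norm of grad(loss1):", g1.norm().item())
print("norm of grad(loss2):", g2.norm().item())

# A single backward on the unweighted sum accumulates g1 + g2 into .grad,
# so the individual contributions cannot be recovered from .grad afterwards.
(loss1 + loss2).backward()
print(torch.allclose(lin.weight.grad, g1 + g2))
```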

Basically, I need the gradients coming from loss1 and loss2 separately, on the fly during the backward pass. That would let me manually increase or decrease the scale/norm of each one so that they match between the two loss functions.
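One possible approach (a sketch, not something taken from the original post) is to skip the single backward over `loss1 + loss2`: compute each loss's gradients separately with `torch.autograd.grad`, rescale them to a common global norm, and write the combined result into `.grad` before the optimizer step. The helper name `balanced_backward`, the choice of the mean norm as the target, and the `eps` guard are assumptions made for illustration.

```python
import torch
import torch.nn as nn


def balanced_backward(losses, params, eps=1e-12):
    # Hypothetical helper (not a PyTorch API): rescale each loss's gradient
    # to a common global norm, then accumulate the sum into each param's .grad.
    params = list(params)
    per_loss = []
    for loss in losses:
        # Gradients of this loss alone; retain_graph keeps the shared graph
        # (here, everything up to `out`) alive for the next loss.
        grads = torch.autograd.grad(loss, params, retain_graph=True, allow_unused=True)
        # Parameters a loss does not touch get a zero gradient.
        per_loss.append([torch.zeros_like(p) if g is None else g
                         for p, g in zip(params, grads)])

    # Global L2 norm of each loss's gradient over all parameters.
    norms = [torch.sqrt(sum((g ** 2).sum() for g in grads)) for grads in per_loss]
    target = torch.stack(norms).mean()  # assumed target norm: the mean

    for i, p in enumerate(params):
        combined = sum(grads[i] * (target / (norm + eps))
                       for grads, norm in zip(per_loss, norms))
        if p.grad is None:
            p.grad = combined.detach()
        else:
            p.grad += combined.detach()
    return norms


torch.manual_seed(0)
lin = nn.Linear(32, 16, bias=False)
inp = torch.rand((1, 32))
out = lin(inp)

loss1 = 1 - out.std()
loss2 = out.sum()

norms = balanced_backward([loss1, loss2], lin.parameters())
print([n.item() for n in norms])  # per-loss norms before rescaling
```

This costs one backward pass per loss instead of one in total. Matching a single global norm across all parameters is only one option; matching norms per parameter tensor, or only on a shared layer, are equally plausible variants.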
My first attempt was a backward hook on the linear layer. However, this only gives the gradients after the summation has already occurred:

```python
import torch
import torch.nn as nn

def print_hook(module, grad_in, grad_out):
    print(module)
    print([i.shape if i is not None else 'None' for i in grad_in])
    print([i.shape if i is not None else 'None' for i in grad_out])

lin = nn.Linear(32, 16, False)
# register_backward_hook is deprecated in recent PyTorch versions
# in favor of register_full_backward_hook
lin.register_backward_hook(print_hook)
inp = torch.rand((1, 32))

out = lin(inp)

loss1 = 1 - out.std()
loss2 = out.sum()
loss = loss1 + loss2

loss.backward()
```

```
Linear(in_features=32, out_features=16, bias=False)
['None', torch.Size([32, 16])]
[torch.Size([1, 16])]
```
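The post-summation behavior can be confirmed with `register_full_backward_hook`, the non-deprecated replacement for `register_backward_hook`. In this sketch (`seen` and `save_grad_out` are illustrative names, and the seed is added for reproducibility), the `grad_output` the hook receives is compared against the per-loss gradients with respect to `out`; it equals their sum, because autograd adds the two branch contributions before the gradient reaches the layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # added for reproducibility

seen = {}

def save_grad_out(module, grad_input, grad_output):
    # grad_output[0] is d(loss)/d(out); by the time this hook runs,
    # the contributions of loss1 and loss2 have already been added together.
    seen["grad_out"] = grad_output[0].detach().clone()

lin = nn.Linear(32, 16, bias=False)
lin.register_full_backward_hook(save_grad_out)
inp = torch.rand((1, 32))
out = lin(inp)

loss1 = 1 - out.std()
loss2 = out.sum()

# Per-loss gradients w.r.t. the layer output, for comparison.
d1, = torch.autograd.grad(loss1, out, retain_graph=True)
d2, = torch.autograd.grad(loss2, out, retain_graph=True)

(loss1 + loss2).backward()
print(torch.allclose(seen["grad_out"], d1 + d2))  # the hook only sees the sum
```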

The gradients the hook receives have already been summed (which is to be expected), but I need the gradients that go into the sum. So my next step was to attach a backward hook to the loss tensor:

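```python
import torch
import torch.nn as nn

def print_hook(grad):
    print(grad.shape)
    print(grad)

lin = nn.Linear(32, 16, False)
inp = torch.rand((1, 32))

out = lin(inp)

loss1 = 1 - out.std()
loss2 = out.sum()
loss = loss1 + loss2

# Tensor hook: called with the gradient w.r.t. `loss` itself
loss.register_hook(print_hook)

loss.backward()
```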

```
torch.Size([])
tensor(1.)
```
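A tensor hook only sees the gradient flowing into the tensor it is registered on. A sketch of one way to see each loss's contribution separately (not from the original post): give each loss its own copy of `out` via `out.clone()` and register a hook on each copy. A tensor hook may also return a replacement gradient, which is used here to apply a per-branch scale. The names `captured` and `make_hook`, the fixed 0.5 scale, and the seed are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # added for reproducibility

lin = nn.Linear(32, 16, bias=False)
inp = torch.rand((1, 32))
out = lin(inp)

# One copy of `out` per loss, so each copy receives only its own loss's gradient.
out1 = out.clone()
out2 = out.clone()

captured = {}

def make_hook(name, scale):
    def hook(grad):
        captured[name] = grad.detach().clone()  # gradient from this loss only
        return grad * scale  # the returned tensor replaces the gradient
    return hook

out1.register_hook(make_hook("loss1", 1.0))
out2.register_hook(make_hook("loss2", 0.5))

loss1 = 1 - out1.std()
loss2 = out2.sum()
(loss1 + loss2).backward()

print(captured["loss2"])  # all ones: d(out2.sum())/d(out2)
```

Each hook sees only its own branch's gradient, so a scale that depends on both norms has to be decided before the backward pass reaches these hooks.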

Again, not what I need: the hook on `loss` only sees the gradient of the loss with respect to itself, which is 1. My next attempt was to create a custom sum module that I can attach a hook to:

```python
import torch
import torch.nn as nn

def print_hook(module, grad_in, grad_out):
    print(module)
    print([i.shape if i is not None else 'None' for i in grad_in])
    print([i.shape if i is not None else 'None' for i in grad_out])
    print([i if i is not None else 'None' for i in grad_in])
    print([i if i is not None else 'None' for i in grad_out])

class SumModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, a, b):
        return a + b

lin = nn.Linear(32, 16, False)
inp = torch.rand((1, 32))
my_sum = SumModule()
my_sum.register_backward_hook(print_hook)

out = lin(inp)

loss1 = 1 - out.std()
loss2 = out.sum()
loss = my_sum(loss1, loss2)

loss.backward()