What will happen with multiple gradients?


This is probably not a common use case, but I am running into this issue now. I am trying to understand how the network will be updated when there are multiple gradients.

In my case, there are two networks, a_network and b_network, and the loss function is defined as

loss = b - a

where b is computed by b_network (using outputs from a_network) and a is computed by a_network.
We also need to perform sampling:
(1) sample a_network M times
(2) for each sample from a_network, sample b_network N times

The pseudo-code (sorry, the actual code is very large):

#b_network is pre-trained and is not updated during this training
for param in b_network.parameters():
    param.requires_grad = False


a_sampling_times = M
b_sampling_times = N

#do sampling with a_network
for _ in range(a_sampling_times):
    a_network_outputs = a_network(inputs)
    a_network_samples = a_network_sampling(a_network_outputs)
    a_network_loss = a_network_compute_loss(a_network_outputs)

    #do sampling with b_network
    for b_sampling_time in range(b_sampling_times):
        b_network_outputs = b_network(a_network_outputs)
        b_network_loss = b_network_compute_loss(b_network_outputs, a_network_outputs)
        #the loss function
        loss = b_network_loss - a_network_loss
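        #compute gradients; keep the a_network graph alive until the last b sample
        #(range() stops at b_sampling_times - 1, so that is the last iteration)
        if b_sampling_time == b_sampling_times - 1:
            loss.backward()
        else:
            loss.backward(retain_graph=True)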
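A minimal runnable sketch of the same loop, assuming tiny nn.Linear layers in place of the real networks and simple .mean() losses in place of the hypothetical a_network_compute_loss / b_network_compute_loss (the sampling step is left out). Because zero_grad() is not called between the backward passes, a_network's .grad ends up holding the sum of all M * N gradients before optimizer.step() uses it.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the real networks.
a_network = nn.Linear(3, 2)
b_network = nn.Linear(2, 1)

# b_network is frozen, as in the pseudo-code.
for param in b_network.parameters():
    param.requires_grad = False

optimizer = torch.optim.SGD(a_network.parameters(), lr=0.1)
inputs = torch.randn(4, 3)
M, N = 2, 3  # a_sampling_times, b_sampling_times

optimizer.zero_grad()
for _ in range(M):
    a_network_outputs = a_network(inputs)
    a_network_loss = a_network_outputs.mean()  # hypothetical loss

    for b_sampling_time in range(N):
        b_network_outputs = b_network(a_network_outputs)
        b_network_loss = b_network_outputs.mean()  # hypothetical loss
        loss = b_network_loss - a_network_loss
        # The a_network graph is shared by all N inner passes, so it must be
        # retained until the last one.
        loss.backward(retain_graph=(b_sampling_time < N - 1))

# Sum of all M * N gradients, captured before the update.
accumulated_grad = a_network.weight.grad.clone()
optimizer.step()
```

With these toy losses every pass produces the same gradient, so accumulated_grad is exactly M * N times a single-pass gradient; with real sampling the N terms would differ, but they are still summed.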


With retain_graph=True, the gradients of a_network_loss will be computed N times. What will happen when optimizer.step() is called? How are the network’s parameters going to be updated with those N gradients?


By default the gradients will be accumulated.
If you call backward N times on your output, the N gradients will be summed into the .grad attribute of your parameters.
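Applied to the loss in the question, and assuming zero_grad() is not called between the N backward calls, the gradient stored for the a_network parameters after one outer sample can be written as follows (θ_a, b_n for the n-th b_network_loss, and a for a_network_loss are notation introduced here). Since b_network is frozen, its parameters receive nothing, and the a_network_loss term is counted N times:

$$
g = \sum_{n=1}^{N} \nabla_{\theta_a} \left( b_n - a \right) = \sum_{n=1}^{N} \nabla_{\theta_a} b_n \; - \; N \, \nabla_{\theta_a} a
$$

A plain SGD step with learning rate η then applies θ_a ← θ_a − η g.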

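import torch
import torch.nn as nn

w = nn.Parameter(torch.ones(1))
x = torch.ones(1) * 2

for _ in range(3):
    output = x * w
    output.backward()  # d(output)/dw = x = 2, added to w.grad each time
    print(w.grad)

> tensor([2.])
> tensor([4.])
> tensor([6.])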
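The example above builds a new graph in every iteration, so a plain backward() is enough. In the question, the same a_network graph is reused for all N backward calls, which is why retain_graph=True is needed. A minimal sketch of that case with the same toy tensors: backpropagating twice through one graph also sums the gradients, and a further call fails once the graph has been freed.

```python
import torch
import torch.nn as nn

w = nn.Parameter(torch.ones(1))
x = torch.ones(1) * 2

# Build the graph once and backpropagate through it twice.
output = x * w
output.backward(retain_graph=True)  # keep the graph for another pass
output.backward()                   # this pass frees the graph
print(w.grad)  # 2 + 2, summed into w.grad

# The graph's saved tensors are gone now, so another pass raises.
try:
    output.backward()
    third_call_failed = False
except RuntimeError:
    third_call_failed = True
print(third_call_failed)
```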

If you need to compute the gradients several times, keep in mind that the update uses their sum, so you might want to scale your learning rate down a bit, but of course this depends on your use case, the model, etc.
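One way to act on this, assuming plain SGD (where the update is linear in the gradient), is to divide the loss by N so that the summed gradient becomes an average; for plain SGD this is equivalent to dividing the learning rate by N. A small check of that equivalence with the toy example (accumulated_grad is a hypothetical helper introduced here). Adaptive optimizers such as Adam are largely insensitive to a constant loss scale, so this equivalence is specific to plain SGD.

```python
import torch
import torch.nn as nn

N = 3
lr = 0.1
x = torch.ones(1) * 2


def accumulated_grad(scale):
    """Call backward N times on scale * (x * w) and return w.grad."""
    w = nn.Parameter(torch.ones(1))
    for _ in range(N):
        (scale * (x * w)).backward()
    return w.grad


summed = accumulated_grad(1.0)        # N * 2
averaged = accumulated_grad(1.0 / N)  # approximately 2, the mean gradient

# For plain SGD, stepping with lr on the averaged gradient matches
# stepping with lr / N on the summed gradient.
print(summed, averaged)
print(torch.allclose(lr * averaged, (lr / N) * summed))
```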


Thanks, I understand it now.