This is probably not a common use case, but I am running into it now. I am trying to understand how the network is going to be updated when there are multiple accumulated gradients.
In my case, given two networks, a_network and b_network, the loss function is defined as
loss = b - a
where b is computed by b_network (using outputs from a_network) and a is computed by a_network.
We also need to perform sampling:
(1) sampling a_network M times
(2) for each sample from a_network, sampling b_network N times
Here is the pseudo-code (sorry, the actual code is quite large):
# b_network is pre-trained; its parameters are frozen during this training
for param in b_network.parameters():
    param.requires_grad = False

optimizer.zero_grad()

a_sampling_times = M
b_sampling_times = N

# do sampling with a_network
for _ in range(a_sampling_times):
    a_network_outputs = a_network(inputs)
    a_network_samples = a_network_sampling(a_network_outputs)
    a_network_loss = a_network_compute_loss(a_network_outputs)

    # do sampling with b_network for this a_network output
    for b_sampling_time in range(b_sampling_times):
        b_network_outputs = b_network(a_network_outputs)
        b_network_loss = b_network_compute_loss(b_network_outputs, a_network_outputs)

        # the loss function
        loss = b_network_loss - a_network_loss

        # compute gradients; keep the graph until the last b_network sample
        # (range() goes from 0 to N-1, so the last iteration is b_sampling_times - 1)
        if b_sampling_time == b_sampling_times - 1:
            loss.backward()
        else:
            loss.backward(retain_graph=True)

# updates
optimizer.step()
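To make this easier to reproduce, here is a minimal, runnable toy version of the loop above. The module sizes, the stand-in loss functions, and the sampling here are made up purely for illustration and are not my actual code:

import torch
import torch.nn as nn

a_network = nn.Linear(4, 3)
b_network = nn.Linear(3, 1)

# b_network is pre-trained and frozen
for param in b_network.parameters():
    param.requires_grad = False

optimizer = torch.optim.SGD(a_network.parameters(), lr=0.01)
optimizer.zero_grad()

M, N = 2, 3
inputs = torch.randn(8, 4)

for _ in range(M):
    a_out = a_network(inputs)
    a_loss = a_out.pow(2).mean()                 # stand-in for a_network_compute_loss
    for j in range(N):
        b_out = b_network(a_out)
        b_loss = (b_out - float(j)).pow(2).mean()  # stand-in for b_network_compute_loss
        loss = b_loss - a_loss
        # free the graph only on the last inner sample of this a_network output
        loss.backward(retain_graph=(j != N - 1))

# .grad now holds the sum over all M * N backward passes
print(a_network.weight.grad)

optimizer.step()  # one update using the accumulated gradients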
Because of retain_graph=True, the gradient of a_network_loss (with respect to a_network's parameters) is computed and accumulated N times for each a_network sample. What exactly happens when optimizer.step() is called? How are the network's parameters updated with those N gradients?
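For reference, my current understanding (which I would like confirmed) is that each backward() call simply adds into .grad, so optimizer.step() performs a single update using the summed gradient. A tiny standalone check with a made-up scalar parameter:

import torch

w = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.SGD([w], lr=0.1)

opt.zero_grad()
(2 * w).backward()   # contributes gradient 2
(3 * w).backward()   # contributes gradient 3; .grad accumulates
print(w.grad)        # tensor(5.) -- the sum of both gradients

opt.step()           # one SGD update with the summed gradient: 1.0 - 0.1 * 5 = 0.5
print(w.data)        # tensor(0.5000)

If that is correct, then in my loop the gradient of a_network_loss is effectively scaled by N (it is added once per inner sample), which is exactly the behavior I want to confirm.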