This is probably not a common use case, but I am running into it now. I am trying to understand how the network is going to be updated when there are multiple accumulated gradients.
In my case, given two networks, a_network and b_network, the loss function is defined as
loss = b - a
where b is computed by b_network (using outputs from a_network) and a is computed by a_network.
We also need to perform sampling:
(1) sampling a_network M times
(2) for each sample from a_network, sampling b_network N times
Here is the pseudo-code (sorry, the actual code is quite large):
# b_network is pre-trained; its parameters are frozen during this training
for param in b_network.parameters():
    param.requires_grad = False

optimizer.zero_grad()

a_sampling_times = M
b_sampling_times = N

# do sampling with a_network
for _ in range(a_sampling_times):
    a_network_outputs = a_network(inputs)
    a_network_samples = a_network_sampling(a_network_outputs)
    a_network_loss = a_network_compute_loss(a_network_outputs)

    # do sampling with b_network for this a_network output
    for b_sampling_time in range(b_sampling_times):
        b_network_outputs = b_network(a_network_outputs)
        b_network_loss = b_network_compute_loss(b_network_outputs, a_network_outputs)

        # the loss function
        loss = b_network_loss - a_network_loss

        # compute gradients; keep the graph until the last b_network sample
        # (range() goes from 0 to N-1, so the last iteration is b_sampling_times - 1)
        if b_sampling_time == b_sampling_times - 1:
            loss.backward()
        else:
            loss.backward(retain_graph=True)

# updates
optimizer.step()
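To make this easier to reproduce, here is a minimal, runnable toy version of the loop above. The module sizes, the stand-in loss functions, and the sampling here are made up purely for illustration and are not my actual code:

import torch
import torch.nn as nn

a_network = nn.Linear(4, 3)
b_network = nn.Linear(3, 1)

# b_network is pre-trained and frozen
for param in b_network.parameters():
    param.requires_grad = False

optimizer = torch.optim.SGD(a_network.parameters(), lr=0.01)
optimizer.zero_grad()

M, N = 2, 3
inputs = torch.randn(8, 4)

for _ in range(M):
    a_out = a_network(inputs)
    a_loss = a_out.pow(2).mean()                 # stand-in for a_network_compute_loss
    for j in range(N):
        b_out = b_network(a_out)
        b_loss = (b_out - float(j)).pow(2).mean()  # stand-in for b_network_compute_loss
        loss = b_loss - a_loss
        # free the graph only on the last inner sample of this a_network output
        loss.backward(retain_graph=(j != N - 1))

# .grad now holds the sum over all M * N backward passes
print(a_network.weight.grad)

optimizer.step()  # one update using the accumulated gradients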
Because of retain_graph=True, the gradient of a_network_loss (with respect to a_network's parameters) is computed and accumulated N times for each a_network sample. What exactly happens when optimizer.step() is called? How are the network's parameters updated with those N gradients?
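For reference, my current understanding (which I would like confirmed) is that each backward() call simply adds into .grad, so optimizer.step() performs a single update using the summed gradient. A tiny standalone check with a made-up scalar parameter:

import torch

w = torch.nn.Parameter(torch.tensor(1.0))
opt = torch.optim.SGD([w], lr=0.1)

opt.zero_grad()
(2 * w).backward()   # contributes gradient 2
(3 * w).backward()   # contributes gradient 3; .grad accumulates
print(w.grad)        # tensor(5.) -- the sum of both gradients

opt.step()           # one SGD update with the summed gradient: 1.0 - 0.1 * 5 = 0.5
print(w.data)        # tensor(0.5000)

If that is correct, then in my loop the gradient of a_network_loss is effectively scaled by N (it is added once per inner sample), which is exactly the behavior I want to confirm.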