I have a model built on top of multiple nn.Modules, and it has two loss functions. The error from the first loss has to backpropagate through the whole network, while the error from the second loss should update only one of the modules that make up the model. My forward method looks like this:
```python
def forward(self, input):
    x = self.first_part(input)
    y = self.second_part(x)
    z = some_operations(y)  # placeholder for the remaining operations
    return z, y
```
In my main code I have:
```python
network = model()
output_1, output_2 = network(image)
loss_1 = error_fn(output_1, labels)  # main loss
loss_2 = error_fn(output_2, ref)     # I want this loss to update only the layers in second_module
```
In other words, second_module should be updated jointly by loss_1 and loss_2, while the other modules should be updated only by loss_1.
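Written out, with θ₁ denoting the parameters of first_part and θ₂ those of second_module (notation introduced here for illustration), the desired gradients are:

$$
g_{\theta_1} = \frac{\partial \mathcal{L}_1}{\partial \theta_1},
\qquad
g_{\theta_2} = \frac{\partial \mathcal{L}_1}{\partial \theta_2} + \frac{\partial \mathcal{L}_2}{\partial \theta_2}
$$

Calling `(loss_1 + loss_2).backward()` would also add ∂L₂/∂θ₁ to the first part, which is exactly the term that has to be left out.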
You can use the `inputs` argument of `backward()` to restrict the gradient computation to these parameters:
```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.first_part = nn.Linear(10, 10)
        self.second_part = nn.Linear(10, 10)

    def forward(self, input):
        x = self.first_part(input)
        y = self.second_part(x)
        z = y * 2.
        return z, y

network = Model()
image = torch.randn(1, 10)
output_1, output_2 = network(image)
loss_1 = output_1.mean()
loss_2 = output_2.mean()

# make sure gradients are empty (nothing should be printed here)
for name, param in network.named_parameters():
    if param.grad is not None:
        print('{}.grad.abs().sum(): {}'.format(name, param.grad.abs().sum()))

# calculate gradients for `second_part` only;
# retain_graph=True keeps the graph alive for loss_1.backward() below
loss_2.backward(inputs=list(network.second_part.parameters()), retain_graph=True)

# check grads: only second_part.weight and second_part.bias should be printed
for name, param in network.named_parameters():
    if param.grad is not None:
        print('{}.grad.abs().sum(): {}'.format(name, param.grad.abs().sum()))

# calculate all gradients; second_part's grads accumulate on top of loss_2's
loss_1.backward()
for name, param in network.named_parameters():
    if param.grad is not None:
        print('{}.grad.abs().sum(): {}'.format(name, param.grad.abs().sum()))
```
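To double-check that this produces the gradients described in the question, a small sketch can compare them against `torch.autograd.grad`, which returns gradients without writing to `.grad`. It redefines the toy `Model` from above so it runs on its own; the batch size of 4, the seed, and the helper `losses()` are arbitrary choices for illustration, and it assumes a PyTorch version whose `backward()` accepts `inputs` (added around 1.8).

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.first_part = nn.Linear(10, 10)
        self.second_part = nn.Linear(10, 10)

    def forward(self, input):
        x = self.first_part(input)
        y = self.second_part(x)
        return y * 2., y

torch.manual_seed(0)
network = Model()
image = torch.randn(4, 10)
first = list(network.first_part.parameters())
second = list(network.second_part.parameters())

def losses():
    # hypothetical helper: fresh forward pass, returns (loss_1, loss_2)
    out_1, out_2 = network(image)
    return out_1.mean(), out_2.mean()

# Reference gradients, computed functionally (nothing is stored in .grad)
loss_1, loss_2 = losses()
ref_first = torch.autograd.grad(loss_1, first, retain_graph=True)
ref_second = torch.autograd.grad(loss_1 + loss_2, second)

# Approach from the answer: loss_2 restricted to second_part, loss_1 everywhere
loss_1, loss_2 = losses()
loss_2.backward(inputs=second, retain_graph=True)
loss_1.backward()

first_ok = all(torch.allclose(p.grad, r) for p, r in zip(first, ref_first))
second_ok = all(torch.allclose(p.grad, r) for p, r in zip(second, ref_second))
print(first_ok, second_ok)
```

The first_part check is the important one: it confirms that no ∂L₂/∂θ₁ term leaked into those parameters.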
Many thanks for the reply. That is very helpful, and it worked.
One follow-up question that might be simple: given that I call optimizer.zero_grad() at the start of each iteration of my training loop, is there anything else I need to do after loss_2.backward(…)?
I'm confused about why you check that the gradients are empty after every step above. What if they aren't (which is indeed the case after the first backward call)?
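Because `backward()` adds into `.grad` rather than overwriting it, a sketch of one training iteration can call `optimizer.zero_grad()` once, before both backward passes; zeroing between the two calls would discard the loss_2 contribution. The layer sizes, the `SGD` optimizer, and the learning rate below are assumptions for illustration. The toy losses are chosen so that loss_1 equals 2 × loss_2, which makes the accumulated gradient of second_part exactly 3× the loss_2 gradient and easy to check.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Toy stand-ins for the two sub-modules (hypothetical sizes)
first_part = nn.Linear(10, 10)
second_part = nn.Linear(10, 10)
optimizer = torch.optim.SGD(
    list(first_part.parameters()) + list(second_part.parameters()), lr=0.1
)
image = torch.randn(4, 10)

checks = []  # (first_part untouched by loss_2, second_part grads accumulated)
for step in range(3):
    optimizer.zero_grad()  # once per iteration, before any backward call

    y = second_part(first_part(image))
    loss_1 = (y * 2.).mean()  # "main" loss, reaches every parameter
    loss_2 = y.mean()         # auxiliary loss, meant for second_part only

    second_params = list(second_part.parameters())
    loss_2.backward(inputs=second_params, retain_graph=True)
    grads_from_loss_2 = [p.grad.clone() for p in second_params]

    # first_part got no gradient from loss_2 (None, or all zeros on older
    # versions where zero_grad() fills with zeros instead of setting None)
    untouched = all(p.grad is None or not p.grad.any()
                    for p in first_part.parameters())

    loss_1.backward()  # adds onto the .grad already stored in second_part

    # loss_1 == 2 * loss_2 here, so the accumulated grad is 3x the loss_2 grad
    accumulated = all(torch.allclose(p.grad, 3 * g)
                      for p, g in zip(second_params, grads_from_loss_2))
    checks.append((untouched, accumulated))

    optimizer.step()
```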