Backpropagate through one of the network modules

Hi everyone.

I have a model that is built on top of multiple nn.Modules, and it has two loss functions. The error from the first loss has to backpropagate through the whole network, while the error from the second loss should only update one of the modules that make up the model.

As an example, my modules are:

class first_module(nn.Module):
    # some code here

class second_module(nn.Module):
    # some code here

My final network is something like:

class model(nn.Module):
    def __init__(self):
        super().__init__()
        self.first_part = first_module()
        self.second_part = second_module()

    def forward(self, input):
        x = self.first_part(input)
        y = self.second_part(x)
        z = some_operations(y)

        return z, y

Now in my main code I have:

network = model()
output_1, output_2 = network(image)
loss_1 = error_fn(output_1, labels)  # main loss
loss_2 = error_fn(output_2, ref)     # I want this error to only update the layers in second_module

In other words, second_module should be updated jointly by loss_1 and loss_2, while the other modules should only be updated by loss_1.

Your help is highly appreciated.

You could use the inputs argument of backward to restrict the gradient calculation to these parameters:

import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.first_part = nn.Linear(10, 10)
        self.second_part = nn.Linear(10, 10)

    def forward(self, input):
        x = self.first_part(input)
        y = self.second_part(x)
        z = y * 2.
        return z, y

network = Model()
image = torch.randn(1, 10)
output_1, output_2 = network(image)
loss_1 = output_1.mean()
loss_2 = output_2.mean()

# make sure the gradients are empty (nothing should be printed here)
for name, param in network.named_parameters():
    if param.grad is not None:
        print('{}.grad.abs().sum(): {}'.format(name, param.grad.abs().sum()))


# calculate gradients for `second_part` only
loss_2.backward(inputs=list(network.second_part.parameters()), retain_graph=True)

# check grads: only second_part's parameters should now have gradients
for name, param in network.named_parameters():
    if param.grad is not None:
        print('{}.grad.abs().sum(): {}'.format(name, param.grad.abs().sum()))

# calculate gradients for all parameters via loss_1
loss_1.backward()
for name, param in network.named_parameters():
    if param.grad is not None:
        print('{}.grad.abs().sum(): {}'.format(name, param.grad.abs().sum()))

Many thanks for the reply, that is very helpful. It worked!

One follow-up question that might be simple: given that I call optimizer.zero_grad() at the start of each iteration during my training, is there anything else that I need to do after loss_2.backward(…)?
I am just confused about why you are making sure that the gradients are empty at every step above. What if they aren't (which is indeed the case after the first backward call)?

Checking for empty (or zero) gradients was just used in this test to show that the behavior is indeed as intended, i.e. the first

loss_2.backward(inputs=list(network.second_part.parameters()), retain_graph=True)

call would populate the gradients of second_part's parameters and leave the others untouched.
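
For example, the behavior can be verified with a few assertions (a minimal sketch that reuses the Model class from the snippet above and runs a fresh forward pass):

import torch

network = Model()
output_1, output_2 = network(torch.randn(1, 10))
loss_1, loss_2 = output_1.mean(), output_2.mean()

# loss_2 only populates the gradients of second_part's parameters
loss_2.backward(inputs=list(network.second_part.parameters()), retain_graph=True)
assert all(p.grad is not None for p in network.second_part.parameters())
assert all(p.grad is None for p in network.first_part.parameters())

# loss_1 then populates (and accumulates into) the gradients of all parameters
loss_1.backward()
assert all(p.grad is not None for p in network.parameters())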

You should be fine using optimizer.zero_grad() at the beginning.
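
Putting it together, one training iteration could look like this (a minimal sketch; the optimizer, error_fn, loader, labels, and ref are placeholders for your own setup):

import torch
import torch.nn as nn

network = Model()                        # the Model from the snippet above
optimizer = torch.optim.SGD(network.parameters(), lr=1e-2)
error_fn = nn.MSELoss()                  # placeholder loss function

for image, labels, ref in loader:        # hypothetical DataLoader
    optimizer.zero_grad()
    output_1, output_2 = network(image)
    loss_1 = error_fn(output_1, labels)  # updates the whole model
    loss_2 = error_fn(output_2, ref)     # should only update second_part

    # restrict loss_2's gradients to second_part's parameters
    loss_2.backward(inputs=list(network.second_part.parameters()), retain_graph=True)
    # loss_1 populates (and accumulates into) the gradients of all parameters
    loss_1.backward()

    optimizer.step()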
