Different backward passes through different branches

I have a modified version of Inception ResNet v2 in which I have created two branches for an additional task; the output of one branch then goes into a decoder. With these branches in place, I do not want every loss to backpropagate through the whole network: specific losses should affect only specific layers. So I created 3 different optimisers, each with its own set of network parameters that I want it to train. But I get the following error:

File "irv2_m_trainer.py", line 220, in <module>
    totalLossM.backward()
  File "/home/azhan/.local/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/azhan/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

The model is given as:

class Inception_ResNetv2(nn.Module):
    """Inception-ResNet-v2 variant with a shared stem and two branches.

    The stem output feeds two branches:

    * the "t" branch (``t_features`` -> ``conv`` -> global average pooling ->
      ``t_feature_layer`` -> ``t_classification_layer``), and
    * the "m" branch (``m_net`` -> ``m_feature_layer`` ->
      ``m_classification_layer``).

    Args:
        in_channels: Passed to ``Stem``.
        feature_size: Output width of ``t_feature_layer`` (and therefore the
            input width of ``t_classification_layer``).
        n_classes: Output width of both classification layers.
        k, l, m, n: Forwarded unchanged to ``Reduction_A``.
    """

    def __init__(self, in_channels=3, feature_size=256, n_classes=10785, k=256, l=256, m=384, n=384):
        super(Inception_ResNetv2, self).__init__()
        stem_blocks = [Stem(in_channels)]

        # The block lists deliberately do NOT reuse the names k/l/m/n: the
        # original code bound the m-branch list to ``m``, which shadowed the
        # ``m`` constructor argument and handed a list of modules to
        # Reduction_A instead of the channel count.
        m_blocks = [Inception_ResNet_A(320, 0.17) for _ in range(6)]

        t_blocks = [Inception_ResNet_A(320, 0.17) for _ in range(10)]
        t_blocks.append(Reduction_A(320, k, l, m, n))
        t_blocks.extend(Inception_ResNet_B(1088, 0.10) for _ in range(20))
        t_blocks.append(Reduciton_B(1088))
        t_blocks.extend(Inception_ResNet_C(2080, 0.20) for _ in range(9))
        t_blocks.append(Inception_ResNet_C(2080, activation=False))

        self.stem = nn.Sequential(*stem_blocks)
        self.t_features = nn.Sequential(*t_blocks)
        self.m_net = nn.Sequential(*m_blocks)
        self.conv = Conv2d(2080, 1536, 1, stride=1, padding=0, bias=False)
        self.global_average_pooling = nn.AdaptiveAvgPool2d((1, 1))
        self.t_feature_layer = nn.Linear(1536, feature_size)
        # NOTE(review): the final Linear(1024, 256) after Flatten presumably
        # relies on the spatial size being 1x1 after the max-pool - confirm
        # against the input resolution used by the caller.
        self.m_feature_layer = nn.Sequential(nn.Conv2d(320, 768, 3, stride=1),
                                             nn.Conv2d(768, 768, 3, stride=2),
                                             nn.Conv2d(768, 896, 3, stride=1),
                                             nn.Conv2d(896, 1024, 3, stride=2),
                                             nn.Conv2d(1024, 1024, 3, stride=1),
                                             nn.MaxPool2d(2, stride=2),
                                             nn.Flatten(),
                                             nn.Linear(1024, 256))
        # m_feature_layer always emits 256 features, so 256 is correct here.
        self.m_classification_layer = nn.Linear(256, n_classes)
        # The t classifier consumes t_feature, whose width is feature_size
        # (identical to the previous hard-coded 256 at the default).
        self.t_classification_layer = nn.Linear(feature_size, n_classes)

    def forward(self, x):
        """Run both branches on ``x``.

        Returns:
            A 5-tuple ``(m_intermediate_output, m_feature, m_classes,
            t_feature, t_classes)``: the raw ``m_net`` activation, followed by
            the feature vector and class logits of the m branch, then the
            feature vector and class logits of the t branch.
        """
        intermediate_output = self.stem(x)

        # m branch
        m_intermediate_output = self.m_net(intermediate_output)
        m_feature = self.m_feature_layer(m_intermediate_output)
        m_classes = self.m_classification_layer(m_feature)

        # t branch
        t_intermediate_output = self.t_features(intermediate_output)
        t_intermediate_output = self.conv(t_intermediate_output)
        t_intermediate_output = self.global_average_pooling(t_intermediate_output)
        # Collapse the pooled (batch, C, 1, 1) map to (batch, C) for the Linear layer.
        t_intermediate_output = t_intermediate_output.view(t_intermediate_output.size(0), -1)
        t_feature = self.t_feature_layer(t_intermediate_output)
        t_classes = self.t_classification_layer(t_feature)
        return m_intermediate_output, m_feature, m_classes, t_feature, t_classes

The decoder is given as:

class M_DecoderNetwork(nn.Module):
    """Decoder mapping a 320-channel feature map to a 12-channel 256x256 map.

    The layers alternate (transposed) convolutions with fixed-size upsampling
    (70x70, 128x128, 256x256). The final activation is a tanh rescaled from
    [-1, 1] to [0, 1].
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in the order they are applied in forward().
        self.deconv1 = nn.ConvTranspose2d(
            320, 128, kernel_size=3, stride=2, padding=0, bias=True
        )
        self.up1 = nn.Upsample(size=(70, 70))
        self.conv1 = nn.Conv2d(128, 128, kernel_size=7, stride=1, padding=0)
        self.up2 = nn.Upsample(size=(128, 128))
        self.deconv2 = nn.ConvTranspose2d(
            128, 32, kernel_size=3, stride=1, padding=1, bias=True
        )
        self.up3 = nn.Upsample(size=(256, 256))
        self.conv2 = nn.Conv2d(32, 12, kernel_size=3, stride=1, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Decode ``x`` and return a tensor with values in [0, 1]."""
        pipeline = (
            self.deconv1,
            self.up1,
            self.conv1,
            self.up2,
            self.deconv2,
            self.up3,
            self.conv2,
        )
        decoded = x
        for layer in pipeline:
            decoded = layer(decoded)
        # Shift tanh's [-1, 1] range to [0, 1].
        return (self.tanh(decoded) + 1.) / 2.

I want the following branches of the network to be trained using the mentioned losses:

  • Loss 1: Trains model.stem, model.t_features, model.conv, model.t_feature_layer and model.t_classification_layer.
  • Loss 2: Trains model.stem, model.m_net and decoder.
  • Loss 3: Trains model.stem, model.m_net, model.m_feature_layer and model.m_classification_layer.
    The optimisers are declared as given below:
# Parameters trained by loss 1 (t branch). model.stem is shared with the
# other two groups, so its parameters are held by all three optimisers.
t_params = [
    param
    for module in (
        model.stem,
        model.t_features,
        model.conv,
        model.t_feature_layer,
        model.t_classification_layer,
    )
    for param in module.parameters()
]

# Parameters trained by loss 3 (m branch features + classifier).
m_params = [
    param
    for module in (
        model.stem,
        model.m_net,
        model.m_feature_layer,
        model.m_classification_layer,
    )
    for param in module.parameters()
]

# Parameters trained by loss 2 (m branch + decoder).
m_decoder_params = [
    param
    for module in (model.stem, model.m_net, decoder)
    for param in module.parameters()
]

# One Adam optimiser per parameter group, all with the same learning rate.
optimizer_t = torch.optim.Adam(t_params, lr=1e-3)
optimizer_m_decoder = torch.optim.Adam(m_decoder_params, lr=1e-3)
optimizer_m = torch.optim.Adam(m_params, lr=1e-3)

I have called the prediction and backward functions as shown:

# One training iteration: a single forward pass produces the activations used
# by all three losses; each loss is then backpropagated and stepped with its
# own optimiser, one after another.
m_intermediate_output, m_features, m_classes, t_features, t_classes = model(images)

# --- Loss 1: t branch (triplet + cross-entropy) ---
tripletLossT = criterion_triplet_loss(labels_ver, t_features, device)
crossEntropyLossT = criterion_cross_entropy(t_classes, labels_cls)
totalLossT = tripletLossT + crossEntropyLossT
        
# NOTE(review): no optimizer_t.zero_grad() is visible before this backward();
# unless it is called elsewhere, gradients accumulate across iterations.
totalLossT.backward()
# NOTE(review): this step updates the shared model.stem parameters in place,
# so the activations from the forward pass above are stale for any later
# backward() that goes through the stem.
optimizer_t.step()

# --- Loss 2: decoder reconstruction ---
optimizer_m_decoder.zero_grad()

# detach() cuts the graph at m_intermediate_output: the reconstruction loss
# only produces gradients for the decoder, not for model.stem / model.m_net,
# even though optimizer_m_decoder was given their parameters as well.
predicted_m = decoder(m_intermediate_output.detach())

reconstructionLoss = criterion_bce(predicted_m, gt_m)
totalLossMDecoder = 1e4*reconstructionLoss 

totalLossMDecoder.backward()
optimizer_m_decoder.step()

# --- Loss 3: m branch (triplet + cross-entropy) ---
optimizer_m.zero_grad()

# Both inputs are detached, so totalLossM is not connected to any parameter
# and has no grad_fn; backward() below raises "element 0 of tensors does not
# require grad and does not have a grad_fn".
tripletLossM= criterion_triplet_loss(labels_ver, m_features.detach(), device)
crossEntropyLossM = criterion_cross_entropy(m_classes.detach(), labels_cls)
totalLossM = tripletLossM + crossEntropyLossM

totalLossM.backward()
optimizer_m.step()

Can someone please help me in understanding what is wrong?

It seems as if you are explicitly detaching the tensors from the computation graph in:

# m_features and m_classes are detached here, so neither loss term is
# connected to the computation graph.
tripletLossM= criterion_triplet_loss(labels_ver, m_features.detach(), device)
crossEntropyLossM = criterion_cross_entropy(m_classes.detach(), labels_cls)
totalLossM = tripletLossM + crossEntropyLossM

totalLossM.backward()
optimizer_m.step()

which will yield this error.
Could you explain why you want to detach() them?

If I do not detach, I get another error which is:

Traceback (most recent call last):
  File "irv2_m_trainer.py", line 217, in <module>
    totalLossM.backward()
  File "/home/azhan/.local/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/azhan/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

So I set retain_graph=True in the backward calls of both losses, and now I get the error:

Traceback (most recent call last):
  File "irv2_m_trainer.py", line 224, in <module>
    totalLossMDecoder.backward()
  File "/home/azhan/.local/lib/python3.8/site-packages/torch/_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/azhan/.local/lib/python3.8/site-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64]] is at version 3; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I don’t know how the different optimizers etc. are initialized but it generally seems you are trying to calculate multiple losses, compute the gradients with them, and step() with different optimizers.
For this to work, make sure the computation graph is alive if you need to reuse it later and also make sure the forward activations aren’t “stale” after an optimizer.step() operation was performed.
This post explains these errors in a GAN setup.

1 Like