How to prevent one output in a multi-output network?

My code for multi-task training (segmentation task and reconstruction task):

import torch
import torch.nn as nn

class MyNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.layer1 = nn.Conv2d(3, 32, 3)   # shared encoder
    self.layer2 = nn.Conv2d(32, 64, 3)
    self.layer3 = nn.Conv2d(64, 2, 3)   # segmentation head
    self.layer4 = nn.Conv2d(64, 1, 3)   # reconstruction head

  def forward(self, x):
    out = self.layer1(x)
    out = self.layer2(out)
    out1 = self.layer3(out)  # segmentation output
    out2 = self.layer4(out)  # reconstruction output

    return out1, out2


During training, I get the 2 outputs and compute their losses separately, so that I can train the two branches alternately.

x, y = dataloader()  # generate a batch of data
net = MyNet()
opt = torch.optim.Adam(net.parameters(), lr=1e-4)

pred1, pred2 = net(x)
loss1 = seg_loss(pred1, y)
loss2 = rec_loss(pred2, x)

opt.zero_grad()
loss1.backward(retain_graph=True)  # keep the graph so the shared layers can be backpropagated through again
opt.step()

opt.zero_grad()
loss2.backward()
opt.step()
Correct, right?

One day, I want to train just the reconstruction branch. Then I would use the following code:

x, y = dataloader()  # generate a batch of data
net = MyNet()
opt = torch.optim.Adam(net.parameters(), lr=1e-4)

pred1, pred2 = net(x)
loss2 = rec_loss(pred2, x)

opt.zero_grad()
loss2.backward()
opt.step()


It works well. pred1 does not participate in the backward pass, as I expected. But I found that the forward computation of pred1 still occupies some GPU memory. In order to save GPU memory, how can I prevent the forward of pred1?

If you want to keep this output during reconstruction-only training, but prevent autograd from storing its intermediate tensors, you can wrap the self.layer3(out) call in a torch.set_grad_enabled(False) context manager. To swap between your training regimes, you can use the following code:

class MyNet(nn.Module):
  def __init__(self):
    super().__init__()
    self.seg_train = True  # set to False to train the reconstruction branch only
    self.layer1 = nn.Conv2d(3, 32, 3)
    self.layer2 = nn.Conv2d(32, 64, 3)
    self.layer3 = nn.Conv2d(64, 2, 3)
    self.layer4 = nn.Conv2d(64, 1, 3)

  def forward(self, x):
    out = self.layer1(x)
    out = self.layer2(out)

    # if grad is already disabled globally (e.g. under torch.no_grad()), don't re-enable it
    grad_enabled = torch.is_grad_enabled()
    with torch.set_grad_enabled(self.seg_train and grad_enabled):
        out1 = self.layer3(out)

    out2 = self.layer4(out)

    return out1, out2

Then you will be able to swap training regimes by changing the model.seg_train attribute.
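
For example, here is a minimal usage sketch of the toggle (reusing dataloader, seg_loss and rec_loss from the snippets above; the joint step sums the two losses just for brevity, the alternating backward/step pattern from your snippet works the same way):

net = MyNet()
opt = torch.optim.Adam(net.parameters(), lr=1e-4)

# joint training: gradients flow through both heads
net.seg_train = True
x, y = dataloader()
pred1, pred2 = net(x)
opt.zero_grad()
(seg_loss(pred1, y) + rec_loss(pred2, x)).backward()
opt.step()

# reconstruction-only training: layer3 still runs, but without building a graph,
# so its intermediate activations are not kept for backward
net.seg_train = False
x, y = dataloader()
pred1, pred2 = net(x)
opt.zero_grad()
rec_loss(pred2, x).backward()
opt.step()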

Thanks! Another question:

If I finish the training of the “Seg+Recon” net and save the whole net to disk, and in the future I want to load the trained net and continue to fine-tune only the “Recon” branch, can your code work in this case?

Or, if I finish the training of the “Recon” branch and save the “Recon” net to disk, and in the future I want to load the trained “Recon” net and continue to fine-tune the full “Seg+Recon” net, can your code work in this case?

I am just curious whether the saved model will include self.layer3 even if I did not use it in the forward function.
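
A small sketch of how one could check this, assuming the model is saved via its state_dict with torch.save (not from the snippets above): every submodule registered in __init__ contributes its parameters to the state_dict, whether or not forward() builds a graph through it.

net = MyNet()
net.seg_train = False  # layer3 runs without grad, but it is still a registered submodule

# layer3.weight and layer3.bias appear here alongside the other layers
print(sorted(net.state_dict().keys()))
torch.save(net.state_dict(), "recon_only.pth")

# reloading later for either regime: all keys match, so strict loading is fine
net2 = MyNet()
net2.load_state_dict(torch.load("recon_only.pth"))
net2.seg_train = True  # continue with joint Seg+Recon fine-tuning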