My code for multi-task training (segmentation task and reconstruction task):
```
import torch
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(3, 32, 3)   # shared trunk
        self.layer2 = nn.Conv2d(32, 64, 3)  # shared trunk
        self.layer3 = nn.Conv2d(64, 2, 3)   # segmentation head
        self.layer4 = nn.Conv2d(64, 1, 3)   # reconstruction head

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out1 = self.layer3(out)  # segmentation output
        out2 = self.layer4(out)  # reconstruction output
        return out1, out2
```
During training, I get the two outputs and compute their losses separately, so that I can train the net on the two tasks alternately.
```
x, y = dataloader()  # generate a batch of data
net = MyNet()
opt = torch.optim.Adam(net.parameters(), lr=1e-4)

pred1, pred2 = net(x)
loss1 = seg_loss(pred1, y)
loss2 = rec_loss(pred2, x)

opt.zero_grad()
loss1.backward(retain_graph=True)  # keep the graph: loss2 shares the trunk
opt.step()

opt.zero_grad()
loss2.backward()
opt.step()
```
Is this correct?
One day, I want to train just the reconstruction branch. Then I would use the following code:
```
x, y = dataloader()  # generate a batch of data
net = MyNet()
opt = torch.optim.Adam(net.parameters(), lr=1e-4)

pred1, pred2 = net(x)
loss2 = rec_loss(pred2, x)

opt.zero_grad()
loss2.backward()
opt.step()
```
It works well: pred1 does not participate in the backward pass, as I expected. But I found that the forward computation of pred1 still occupies some GPU memory, since its activations are kept in the autograd graph. To save GPU memory, how can I prevent the forward computation of pred1 altogether?
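For reference, one idea I am considering is to pass a task flag into `forward` so that the unused head is never computed. This is only a sketch of what I have in mind; the `task` argument and its `"seg"`/`"rec"` values are hypothetical, not part of my current code:

```
import torch
import torch.nn as nn

class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(3, 32, 3)   # shared trunk
        self.layer2 = nn.Conv2d(32, 64, 3)  # shared trunk
        self.layer3 = nn.Conv2d(64, 2, 3)   # segmentation head
        self.layer4 = nn.Conv2d(64, 1, 3)   # reconstruction head

    def forward(self, x, task="both"):
        out = self.layer1(x)
        out = self.layer2(out)
        # run a head only when its task is requested, so the skipped
        # branch allocates no activations in the autograd graph
        out1 = self.layer3(out) if task in ("both", "seg") else None
        out2 = self.layer4(out) if task in ("both", "rec") else None
        return out1, out2

# usage when training only the reconstruction branch:
# _, pred2 = net(x, task="rec")  # pred1 is never computed
```

Is this kind of conditional forward the right way to do it, or is there a more standard approach?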