Do you mean like this?
def forward(self, input):
    # branch 1: A1 -> B1 (feed the intermediate t into B1 so it is actually used)
    t = self.A1(input)
    res1 = self.B1(t)
    # branch 2: A2 -> B2
    res2 = self.B2(self.A2(input))
    return res1, res2
Then, in the training script:
res1, res2 = net(input)
loss1 = criterion(res1, target)
loss2 = criterion(res2, target)
# summing the losses lets a single backward() populate gradients for both branches
loss = loss1 + loss2
loss.backward()
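A minimal self-contained sketch to check the idea above. The layer sizes, the use of nn.Linear for A1/B1/A2/B2, nn.MSELoss as the criterion, and SGD as the optimizer are all assumptions for illustration only. It compares one backward() on the summed loss against two separate backward() calls (whose gradients accumulate in .grad) and confirms they give the same gradients.

```python
import torch
import torch.nn as nn


class TwoBranchNet(nn.Module):
    # Hypothetical layer sizes; A1/B1/A2/B2 mirror the names in the snippet above.
    def __init__(self, in_dim=4, hidden=8, out_dim=1):
        super().__init__()
        self.A1 = nn.Linear(in_dim, hidden)
        self.B1 = nn.Linear(hidden, out_dim)
        self.A2 = nn.Linear(in_dim, hidden)
        self.B2 = nn.Linear(hidden, out_dim)

    def forward(self, input):
        res1 = self.B1(self.A1(input))
        res2 = self.B2(self.A2(input))
        return res1, res2


torch.manual_seed(0)
net = TwoBranchNet()
criterion = nn.MSELoss()  # assumed loss
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)  # assumed optimizer

input = torch.randn(16, 4)   # random dummy batch
target = torch.randn(16, 1)

# Variant 1: sum the two losses and call backward() once.
optimizer.zero_grad()
res1, res2 = net(input)
loss = criterion(res1, target) + criterion(res2, target)
loss.backward()
summed_grads = [p.grad.clone() for p in net.parameters()]

# Variant 2: two separate backward() calls; gradients accumulate in .grad.
optimizer.zero_grad()
res1, res2 = net(input)
criterion(res1, target).backward()
criterion(res2, target).backward()
separate_grads = [p.grad.clone() for p in net.parameters()]

same = all(torch.allclose(a, b) for a, b in zip(summed_grads, separate_grads))
print(same)  # True

# One optimizer step updates the parameters of both branches.
optimizer.step()
```

Summing first is usually the simpler choice: it needs only one backward pass, and every parameter that contributed to either output receives a gradient.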