I am trying to implement the following kind of network architecture: there are two inputs that first go through the same CNN (like a siamese network), then the two outputs are concatenated and fed into two heads for different purposes.
I have two related problems, both concerning the best way to optimize a network with shared weights.
For the siamese part I could solve the weight sharing by passing both inputs through the same module (though I don’t know if this is the best solution):
def forward(self, input1, input2):
    # Run both inputs through the same shared CNN (weight sharing)
    out1 = self.sharedNet(input1)
    out2 = self.sharedNet(input2)
    # Concatenate the two embeddings along the feature dimension
    out = torch.cat((out1, out2), dim=1)
    out = self.header(out)
    return out
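For reference, here is a minimal self-contained version of what I mean (the concrete layers and sizes are just placeholders, not my actual CNN):

import torch
import torch.nn as nn

class SiameseNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared CNN trunk (placeholder layers)
        self.sharedNet = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        # Head operating on the concatenated pair of embeddings
        self.header = nn.Linear(2 * 16, 10)

    def forward(self, input1, input2):
        out1 = self.sharedNet(input1)
        out2 = self.sharedNet(input2)
        out = torch.cat((out1, out2), dim=1)
        return self.header(out)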
Now, the second problem: depending on the input pair, only one of the heads or both are needed, so always computing both seems wasteful. But if I split the shared part and the two heads into separate modules, I don’t know how to optimize them. If I use one optimizer per head (each also covering the shared part), the shared part gets trained twice. The only solution I could think of is to use three optimizers, one for each head and one for the shared part, and to disable gradients in the shared part while optimizing the heads; see the sketch below.
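To make the question concrete, here is a rough sketch of the split version I have in mind, with conditional head computation (SplitNet, headA, headB and the use_a/use_b flags are placeholder names, not my real code):

import torch
import torch.nn as nn

class SplitNet(nn.Module):
    def __init__(self, sharedNet, headA, headB):
        super().__init__()
        self.sharedNet = sharedNet  # shared CNN trunk
        self.headA = headA          # first task head
        self.headB = headB          # second task head

    def forward(self, input1, input2, use_a=True, use_b=True):
        # The shared trunk always runs once per input
        feat = torch.cat((self.sharedNet(input1),
                          self.sharedNet(input2)), dim=1)
        # Only compute the heads this input pair actually needs
        out_a = self.headA(feat) if use_a else None
        out_b = self.headB(feat) if use_b else None
        return out_a, out_b

With this split, my tentative plan would be three optimizers, e.g. one over sharedNet.parameters() and one per head, stepping only those whose modules were used for the current pair. Is there a cleaner way to optimize such a shared/branched setup?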