Optimize network with multiple heads

I am trying to implement the following kind of network architecture: there are two inputs which first go through the same CNN (like a siamese network), then the two outputs are concatenated and fed into two heads for different purposes.

I have two related problems, both about the best way to optimize a network with shared weights.
For the siamese network I could solve the weight sharing by passing both inputs through the same module (though I don't know if this is the best solution):

def forward(self, input1, input2):
    # the same module is applied to both inputs, so the weights are shared
    out1 = self.sharedNet(input1)
    out2 = self.sharedNet(input2)
    # concatenate the two embeddings along the feature dimension
    out = torch.cat((out1, out2), dim=1)
    out = self.header(out)
    return out
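
For reference, a minimal sketch of what the surrounding module could look like, assuming sharedNet is a small CNN and header a linear layer (the concrete layers here are made up, only the weight-sharing pattern matters): the weights are shared simply because the same sharedNet instance is called on both inputs.

import torch
import torch.nn as nn

class SiameseNet(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        # one CNN instance, reused for both inputs -> shared weights
        self.sharedNet = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        # consumes the two concatenated 16-dim embeddings
        self.header = nn.Linear(32, 10)

    def forward(self, input1, input2):
        out1 = self.sharedNet(input1)
        out2 = self.sharedNet(input2)
        out = torch.cat((out1, out2), dim=1)
        return self.header(out)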

Now, the second problem: depending on the input pair, only one of the heads or both are needed, so always computing both seems wasteful. But if I split the first part and the two heads into separate modules, I don't know how to optimize them. If I use one optimizer per head, I will train the shared part twice. The only solution I could think of was to use three optimizers, one for each head and one for the shared part, and to disable gradients in the shared part when optimizing the heads.

  1. Assuming that your sharedNet can deal with batches, you can do something like
def forward(self, input1, input2):
    # run both inputs through the shared net as a single batch
    out = self.sharedNet(torch.cat([input1, input2], 0))
    b = input1.size(0)
    if b == 1:
        out = out.view(1, -1, out.size(2), out.size(3))  # change if 1d/3d input
    else:
        # split the batch back apart and concatenate along the channel dim
        out = torch.cat((out[:b], out[b:]), dim=1)
    out = self.header(out)
    return out
  2. The set of params each optimizer updates is unrelated to how you structure your code. It is entirely decided by the collection of params you pass when constructing the optimizer, e.g. Adam(itertools.chain(net1.parameters(), net2.parameters(), ...)). All an optimizer does is update its parameters based on the computed .grad gradients, so you won't be able to easily update the shared part multiple times even if you wanted to.
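
For example, a minimal sketch assuming the model is split into three modules named shared, head1 and head2 (hypothetical names and sizes): one optimizer is handed all of their parameters at construction, and step() just applies whatever .grad values backward() produced.

import itertools
import torch.nn as nn
import torch.optim as optim

# hypothetical split into a shared trunk and two heads
shared = nn.Linear(8, 8)
head1 = nn.Linear(8, 2)
head2 = nn.Linear(8, 3)

# one optimizer over all parameters, regardless of how the modules are structured
optimizer = optim.Adam(
    itertools.chain(shared.parameters(), head1.parameters(), head2.parameters()),
    lr=1e-3,
)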

Thanks.

  1. My actual problem was that after self.header there are another two different heads, but I think I can solve that the same way.

  2. Let's assume that net1 and net2 share some parameters: if the params argument receives the same parameter twice, does the optimizer internally update it just once?

You can also include additional arguments in your forward function, such as a boolean flag that decides whether to use head1 or head2:

def forward(self, input, use_head1=True):
    x = self.sharedNet(input)
    # select the head for this forward pass
    if use_head1:
        x = self.head1(x)
    else:
        x = self.head2(x)
    return x

output = model(input, use_head1=False)  # use head2
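
Putting the pieces together, here is a hedged sketch of how this could be trained end to end with a single optimizer (the class name TwoHeadNet, the layer sizes, and the loss are assumptions, not from the thread): on a step that only uses head1, the parameters of head2 never receive a .grad and are therefore skipped by optimizer.step().

import torch
import torch.nn as nn
import torch.optim as optim

class TwoHeadNet(nn.Module):  # hypothetical name
    def __init__(self):
        super().__init__()
        # shared CNN trunk, reused for both inputs
        self.sharedNet = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(),
            nn.AdaptiveAvgPool2d(1), nn.Flatten(),
        )
        self.head1 = nn.Linear(32, 2)  # e.g. a 2-class head
        self.head2 = nn.Linear(32, 5)  # e.g. a 5-class head

    def forward(self, input1, input2, use_head1=True):
        out = torch.cat((self.sharedNet(input1), self.sharedNet(input2)), dim=1)
        return self.head1(out) if use_head1 else self.head2(out)

model = TwoHeadNet()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

input1 = torch.randn(4, 3, 32, 32)
input2 = torch.randn(4, 3, 32, 32)
target = torch.randint(0, 2, (4,))

# only head1 and the shared trunk get gradients on this step;
# head2's parameters keep .grad == None and are skipped by the optimizer
optimizer.zero_grad()
output = model(input1, input2, use_head1=True)
loss = nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()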