I want to build a network with three parts, where the first part chooses which of the other two is used next:
import torch
import torch.nn as nn

model = nn.Sequential()
# hinge is only used to choose between subnetwork_1 and _2
model.add_module('0: hinge', nn.Linear(3, 2))
# based on the choice made by hinge, one of the following is used:
model.add_module('1: subnetwork_1', nn.Linear(3, 1))
model.add_module('2: subnetwork_2', nn.Linear(3, 1))
x = torch.randn((1, 3), requires_grad=True)
# run only the hinge layer; calling model(x) would chain all three layers
z = model[0](x)
_, max_indice = torch.max(z, 1)
x = model[max_indice+1](x)
Code to play with is available on Google Colab.
The problem is that the layer ‘hinge’ does not learn anything, so my network does not get any better at choosing the best subnetwork for further processing of the data. I already know that indexing with [max_indice] will not work, but what will?
That is because you have not asked it to learn anything. Your code only does a forward pass through the ‘hinge’ layer. For the layer to learn, you need to (i) define a loss function that measures how bad the computed output is with respect to the expected output, and (ii) do a backward pass so that the error from (i) is propagated back through the model and its parameters can be updated.
This is a good tutorial on how to train a model in PyTorch.
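Steps (i) and (ii) can be sketched roughly as follows. This is only a toy illustration: the single linear layer, the random data, the MSE loss and the SGD settings are all assumptions made for the example, not part of the original question.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy setup: one linear layer fitted to random targets.
layer = nn.Linear(3, 1)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.05)
loss_fn = nn.MSELoss()  # (i) measures how far the output is from the target

x = torch.randn(8, 3)
target = torch.randn(8, 1)

initial_loss = loss_fn(layer(x), target).item()
for _ in range(100):
    optimizer.zero_grad()             # clear gradients left over from the previous step
    loss = loss_fn(layer(x), target)  # forward pass + loss
    loss.backward()                   # (ii) backward pass fills .grad on the parameters
    optimizer.step()                  # update the parameters using those gradients
final_loss = loss_fn(layer(x), target).item()

print(initial_loss, final_loss)
```

Note that `loss.backward()` only computes gradients; the actual parameter update happens in `optimizer.step()`.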
I know that I do not propagate any loss in my code; it is simply not a complete application. However, even if I called loss.backward(), the hinge layer would not learn anything anyway.
If you could provide a minimal working example of this behaviour, then I can try debugging it.
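A minimal sketch of that behaviour, assuming a placeholder squared-output loss (the loss and the use of `nn.ModuleList` are assumptions for illustration, not taken from the Colab notebook): the index returned by `torch.max` is an integer tensor with no gradient, so the hinge layer never ends up in the graph of the loss.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hinge = nn.Linear(3, 2)
branches = nn.ModuleList([nn.Linear(3, 1), nn.Linear(3, 1)])

x = torch.randn(1, 3)
z = hinge(x)
_, max_indice = torch.max(z, 1)   # integer index; it carries no gradient
idx = max_indice.item()
out = branches[idx](x)            # only the selected branch is in the graph

loss = out.pow(2).sum()           # placeholder loss, just to have something to backpropagate
loss.backward()

print(hinge.weight.grad)                      # None: hinge is not part of the graph
print(branches[idx].weight.grad is not None)  # True: the chosen branch gets a gradient
```

So the selected branch is trained, but nothing tells the hinge whether its choice was good or bad.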
My wild guess is that you actually wanted something like this:
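hinge = nn.Linear(3, 2)
# assumed: the two subnetworks from the question
branches = [nn.Linear(3, 1), nn.Linear(3, 1)]
hinge_logits = hinge(x)  # shape (batch, 2)
hinge_dist = torch.distributions.categorical.Categorical(logits=hinge_logits)
hinge_probs = hinge_dist.probs  # a property, not a method; shape (batch, 2)
# weighted average of branch outputs; each p is unsqueezed to (batch, 1) to match branch(x)
x = sum(p.unsqueeze(1) * branch(x) for p, branch in zip(hinge_probs.unbind(dim=1), branches))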
I didn’t test it, so there may well be issues with batch dimensions. You may need to add some squeezes and unsqueezes here and there.
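The same weighted-average idea could be wrapped in a module along these lines. This is an untested-in-your-setting sketch: the class name `SoftRouter` and its constructor arguments are hypothetical, and `torch.softmax` over the hinge logits is used here because it gives the same probabilities as `Categorical(logits=...).probs`.

```python
import torch
import torch.nn as nn


class SoftRouter(nn.Module):  # hypothetical name, for illustration only
    def __init__(self, in_features=3, n_branches=2):
        super().__init__()
        self.hinge = nn.Linear(in_features, n_branches)
        self.branches = nn.ModuleList(
            [nn.Linear(in_features, 1) for _ in range(n_branches)]
        )

    def forward(self, x):
        probs = torch.softmax(self.hinge(x), dim=1)             # (batch, n_branches)
        outs = torch.cat([b(x) for b in self.branches], dim=1)  # (batch, n_branches)
        # weighted average of the branch outputs, one value per sample
        return (probs * outs).sum(dim=1, keepdim=True)          # (batch, 1)


torch.manual_seed(0)
model = SoftRouter()
x = torch.randn(4, 3)
y = model(x)
y.pow(2).sum().backward()  # placeholder loss

print(y.shape)
print(model.hinge.weight.grad is not None)
```

Because every branch output is multiplied by a probability that depends on the hinge, the gradient of the loss now reaches the hinge parameters. The price is that all branches are evaluated on every input.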
In general, take a look at the torch.distributions documentation (Probability distributions - torch.distributions, PyTorch 1.9.0 documentation).
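If a hard choice of a single branch is needed, one option is the score-function (REINFORCE) pattern built on `Categorical.sample()` and `Categorical.log_prob()`. Note that this is a different technique from the weighted average above, and the sketch below rests on assumptions: the zero target, the squared-error loss and the way the two loss terms are combined are placeholders for illustration.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)

hinge = nn.Linear(3, 2)
branches = nn.ModuleList([nn.Linear(3, 1), nn.Linear(3, 1)])

x = torch.randn(1, 3)
target = torch.zeros(1, 1)  # placeholder target

dist = Categorical(logits=hinge(x))
choice = dist.sample()              # hard, non-differentiable choice, shape (1,)
out = branches[choice.item()](x)    # only the sampled branch is evaluated
task_loss = (out - target).pow(2).mean()

# The branch is trained through task_loss directly. The hinge is trained through
# the surrogate term: log-probability of the choice, weighted by the detached loss,
# which lowers the probability of choices that led to a high loss.
loss = task_loss + dist.log_prob(choice).sum() * task_loss.detach()
loss.backward()

print(hinge.weight.grad is not None)
```

This estimator tends to be noisy, so in practice it usually needs many samples or a baseline subtracted from the loss.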