How to implement multiple losses

Let’s say that I have two MLP networks, each with one hidden layer of size 100, that I would like to train simultaneously. I would like to implement three loss functions for this model: one for MLP1, one for MLP2, and one for 2 specific nodes from MLP1 and MLP2.
Any help is appreciated.
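The thread never shows what `Net` looks like. Below is a minimal sketch of a network matching the description (one hidden layer of 100 units), assuming an input size of 20 and 10 outputs; `Net`, `in_features`, `out_features` and `hidden` are illustrative names, not something from the original posts.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    """Hypothetical MLP: input -> one hidden layer (100 units) -> output."""

    def __init__(self, in_features: int, out_features: int, hidden: int = 100):
        super().__init__()
        self.hidden = nn.Linear(in_features, hidden)
        self.out = nn.Linear(hidden, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = torch.relu(self.hidden(x))  # hidden activations, shape (batch, hidden)
        return self.out(h)              # output, shape (batch, out_features)

# Two independent networks with the same architecture (sizes are assumptions)
model1 = Net(in_features=20, out_features=10)
model2 = Net(in_features=20, out_features=10)

x = torch.randn(8, 20)  # dummy batch of 8 samples
print(model1(x).shape)  # torch.Size([8, 10])
```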


I wrote a little code that might work. Try it out and report back :slight_smile:.

import torch

model1 = Net(...)
model2 = Net(...)

# A single optimizer over the parameters of both networks
all_params = list(model1.parameters()) + list(model2.parameters())
optimizer = torch.optim.Adam(all_params)

loss = torch.nn.CrossEntropyLoss()  # or another loss
loss_for_specific_nodes = torch.nn.MSELoss()  # placeholder: any loss that suits those nodes

outp1 = model1(inputs)
outp2 = model2(inputs)

loss1 = loss(outp1, labels)
loss2 = loss(outp2, labels)

special_node_output = ...
special_node_labels = ...
special_loss = loss_for_specific_nodes(special_node_output, special_node_labels)

# Sum the three losses and backpropagate once through both networks
final_loss = loss1 + loss2 + special_loss
optimizer.zero_grad()
final_loss.backward()
optimizer.step()
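For reference, here is one complete, runnable training step along those lines. It is only a sketch under assumptions: dummy random data, a hypothetical `make_mlp` helper, nodes 0 to 3 as the "specific nodes", and an MSE loss that pulls those nodes of the two networks toward each other. Swap in whatever special loss you actually need.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

def make_mlp(in_features, out_features, hidden=100):
    # Hypothetical MLP with one hidden layer of `hidden` units
    return nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU(),
                         nn.Linear(hidden, out_features))

model1 = make_mlp(20, 10)
model2 = make_mlp(20, 10)

# One optimizer over both models' parameters
all_params = list(model1.parameters()) + list(model2.parameters())
optimizer = torch.optim.Adam(all_params, lr=1e-3)

criterion = nn.CrossEntropyLoss()
special_criterion = nn.MSELoss()  # assumption: tie the chosen nodes of both nets together

inputs = torch.randn(8, 20)            # dummy batch (assumption)
labels = torch.randint(0, 10, (8,))    # dummy class labels (assumption)
indexes = torch.tensor([0, 1, 2, 3])   # the "specific nodes" (assumption)

optimizer.zero_grad()
outp1 = model1(inputs)
outp2 = model2(inputs)
loss1 = criterion(outp1, labels)
loss2 = criterion(outp2, labels)
special_loss = special_criterion(outp1.index_select(1, indexes),
                                 outp2.index_select(1, indexes))
final_loss = loss1 + loss2 + special_loss
final_loss.backward()   # one backward pass fills gradients for both models
optimizer.step()
print(final_loss.item())
```

If one term should matter more than the others, the sum can be weighted, e.g. `loss1 + loss2 + 0.5 * special_loss`; the weight is a hyperparameter you would have to tune.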

Thanks for the code, but my question is: how can I choose the nodes?

Well, that’s hard to answer since you never said which nodes you want to choose. Please elaborate on this :slight_smile:

Sure, sorry about that! Let’s say the nodes are numbered 0, 1, 2, 3, …, 99 in MLP1 and 0, 1, 2, 3, …, 99 in MLP2. I choose nodes 0, 1, 2, 3 from MLP1 and 0, 1, 2, 3 from MLP2, and I would like a separate loss for those specific nodes. Is there a way to select them?

There sure is :smiley:

outp1 = model1(inputs) # is of shape (batch_size, 100)
outp2 = model2(inputs)

indexes = torch.tensor([0,1,2,3])

spec1 = outp1.index_select(dim=1, index=indexes) # Selects nodes at indexes over the whole batch
spec2 = outp2.index_select(dim=1, index=indexes)
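`index_select` is not the only way to pick columns. The sketch below (using a random tensor as a stand-in for a `(batch_size, 100)` output) shows that for contiguous nodes a plain slice gives the same values, and that advanced indexing with a tensor of indices works for arbitrary, non-contiguous nodes. One difference worth knowing: basic slicing returns a view, while `index_select` and advanced indexing return a copy; all three are differentiable.

```python
import torch

outp1 = torch.randn(4, 100)            # stand-in for a (batch_size, 100) output
indexes = torch.tensor([0, 1, 2, 3])

via_index_select = outp1.index_select(dim=1, index=indexes)
via_slice = outp1[:, :4]               # contiguous nodes 0..3
via_fancy = outp1[:, indexes]          # advanced indexing with an index tensor

# Any node list works with advanced indexing, not just a contiguous range
non_contiguous = outp1[:, torch.tensor([0, 5, 42])]

print(via_index_select.shape)  # torch.Size([4, 4])
print(non_contiguous.shape)    # torch.Size([4, 3])
```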

Here is a self-contained snippet that you can run to experiment.

import torch

outp1 = torch.randn(2, 5)
outp2 = torch.randn(2, 5)
indexes = torch.tensor([0,1,2])

spec1 = outp1.index_select(dim=1, index=indexes)
spec2 = outp2.index_select(dim=1, index=indexes)

print("Output1: \n%s\n" % outp1)
print("Output2: \n%s\n" % outp2)
print("Selected1: \n%s\n" % spec1)
print("Selected2: \n%s\n" % spec2)
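Because `index_select` is differentiable, a loss computed on the selected nodes only sends gradient back to those columns. The small check below uses a toy loss (sum of squares of the selected entries), which is an illustrative choice and not part of the original answer: the gradient is `2 * x` for the selected columns and zero everywhere else.

```python
import torch

outp1 = torch.randn(2, 5, requires_grad=True)
indexes = torch.tensor([0, 1, 2])

spec1 = outp1.index_select(dim=1, index=indexes)
special_loss = spec1.pow(2).sum()  # toy loss on the selected nodes only
special_loss.backward()

# Columns 0..2 receive gradient 2 * value; columns 3..4 receive zero
print(outp1.grad)
```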


Thank you so much!!!
