Let’s say I have two MLP networks, each with one hidden layer of size 100, that I would like to train simultaneously. I would then like to implement three loss functions for this model: one for MLP1, one for MLP2, and one for two specific nodes from MLP1 and MLP2.
Any help is appreciated.
Hello!
I wrote a little code that might work. Try it out and report back.
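```python
import torch

model1 = Net(...)  # your MLP class; constructor arguments elided
model2 = Net(...)

# One optimizer over the parameters of both networks trains them simultaneously
all_params = list(model1.parameters()) + list(model2.parameters())
optimizer = torch.optim.Adam(all_params)

loss = torch.nn.CrossEntropyLoss()  # or other loss
loss_for_specific_nodes = torch.nn.MSELoss()  # placeholder: whatever loss you want

optimizer.zero_grad()  # clear gradients left over from the previous step
outp1 = model1(inputs)
outp2 = model2(inputs)
loss1 = loss(outp1, labels)
loss2 = loss(outp2, labels)

special_node_output = ...
special_node_labels = ...
special_loss = loss_for_specific_nodes(special_node_output, special_node_labels)

# Summing the losses lets one backward() call compute gradients for both models
final_loss = loss1 + loss2 + special_loss
final_loss.backward()
optimizer.step()
```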
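To check the idea end to end, here is a minimal, self-contained sketch that fills in the undefined pieces with assumptions: a hypothetical `MLP` class (one hidden layer of 100 units, 10 outputs), random `inputs` and `labels`, and, as an illustrative special loss, a mean squared difference that pushes output nodes 0 to 3 of the two networks to agree. Swap in your own data and special loss; only the "sum the losses, one backward, one step" structure comes from the answer above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical MLP: one hidden layer of 100 units (all sizes are assumptions)
class MLP(nn.Module):
    def __init__(self, in_features=20, hidden=100, out_features=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        return self.net(x)

model1, model2 = MLP(), MLP()
optimizer = torch.optim.Adam(
    list(model1.parameters()) + list(model2.parameters()), lr=1e-3
)
ce = nn.CrossEntropyLoss()

# Random stand-in data: 32 samples, 20 features, 10 classes
inputs = torch.randn(32, 20)
labels = torch.randint(0, 10, (32,))
indexes = torch.tensor([0, 1, 2, 3])  # output nodes used by the special loss

for step in range(5):
    optimizer.zero_grad()
    outp1, outp2 = model1(inputs), model2(inputs)
    loss1 = ce(outp1, labels)
    loss2 = ce(outp2, labels)
    # Illustrative special loss: mean squared difference between the selected nodes
    spec1 = outp1.index_select(dim=1, index=indexes)
    spec2 = outp2.index_select(dim=1, index=indexes)
    special_loss = (spec1 - spec2).pow(2).mean()
    final_loss = loss1 + loss2 + special_loss
    final_loss.backward()
    optimizer.step()
    print(f"step {step}: loss1={loss1.item():.3f} "
          f"loss2={loss2.item():.3f} special={special_loss.item():.3f}")
```

Because all three terms are summed before `backward()`, each model receives gradients from its own classification loss plus the shared special loss.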
Thanks for the code, but my question is: how can I choose the nodes?
Well, that’s hard to answer since you never said which nodes you want to choose. Please elaborate.
Sure! Sorry about that. Let’s say the nodes are numbered 0, 1, 2, 3, …, 99 in MLP1 and 0, 1, 2, 3, …, 99 in MLP2. I choose nodes 0, 1, 2, 3 from MLP1 and 0, 1, 2, 3 from MLP2, and I would like a separate loss for those specific nodes. Is there a way to select them?
There sure is:
```python
outp1 = model1(inputs)  # is of shape (batch_size, 100)
outp2 = model2(inputs)
indexes = torch.tensor([0, 1, 2, 3])
spec1 = outp1.index_select(dim=1, index=indexes)  # selects nodes at indexes over the whole batch
spec2 = outp2.index_select(dim=1, index=indexes)
```
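Two properties of `index_select` are worth confirming: it gives the same result as plain advanced indexing `outp[:, indexes]`, and gradients flow back only into the selected columns, so a loss on `spec1` only affects the computation of those nodes. The sketch below uses a random tensor as a stand-in for a model output of shape `(batch_size, 100)`.

```python
import torch

# Stand-in for a model output of shape (batch_size, 100)
outp1 = torch.randn(4, 100, requires_grad=True)
indexes = torch.tensor([0, 1, 2, 3])  # int64, as index_select expects

spec1 = outp1.index_select(dim=1, index=indexes)
# Advanced indexing selects the same columns
assert torch.equal(spec1, outp1[:, indexes])

# Backpropagate a toy loss through the selected nodes only
spec1.sum().backward()
# Gradient is 1 for the selected columns and 0 everywhere else
print(outp1.grad[:, :6])
```

When running on a GPU, create `indexes` on the same device as the output (for example `torch.tensor([0, 1, 2, 3], device=outp1.device)`).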
Here is a self-contained snippet you can run to experiment:
```python
import torch

outp1 = torch.randn(2, 5)
outp2 = torch.randn(2, 5)
indexes = torch.tensor([0, 1, 2])
spec1 = outp1.index_select(dim=1, index=indexes)
spec2 = outp2.index_select(dim=1, index=indexes)
print("Output1: \n%s\n" % outp1)
print("Output2: \n%s\n" % outp2)
print("Selected1: \n%s\n" % spec1)
print("Selected2: \n%s\n" % spec2)
```
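The snippets above select from the final model output. If the nodes you mean are the 100 hidden units instead, one option (an assumption about how `Net` is written) is to have `forward` return the hidden activations alongside the output and select from those. `HiddenMLP` and its sizes are hypothetical.

```python
import torch
import torch.nn as nn

class HiddenMLP(nn.Module):
    """Hypothetical MLP that also returns its hidden-layer activations."""

    def __init__(self, in_features=20, hidden=100, out_features=10):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, out_features)

    def forward(self, x):
        h = torch.relu(self.fc1(x))  # hidden activations, shape (batch, 100)
        return self.fc2(h), h

model1, model2 = HiddenMLP(), HiddenMLP()
inputs = torch.randn(8, 20)
indexes = torch.tensor([0, 1, 2, 3])

out1, hidden1 = model1(inputs)
out2, hidden2 = model2(inputs)
spec1 = hidden1.index_select(dim=1, index=indexes)  # hidden nodes 0 to 3 of model 1
spec2 = hidden2.index_select(dim=1, index=indexes)  # hidden nodes 0 to 3 of model 2
print(spec1.shape, spec2.shape)
```

If you would rather not change `forward`, a hook registered with `register_forward_hook` on `fc1` can capture the same layer's output, though that gives pre-activation values, since the ReLU is applied after `fc1`.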
Thank you so much!!!