I am working on a problem where the input space is partitioned into three regions, R1, R2, and R3, and I have three datasets (D1, D2, and D3) sampled from the corresponding regions. The trick is that the loss function also has three terms, L1, L2, and L3 (one per region of the input space), and each term has a different definition from the other two. The problem is to be solved in PyTorch.
I am not sure how to proceed with sampling/mini-batching. Any advice is appreciated.
Below I include a minimal working example, but without mini-batching. The actual problem has much larger datasets, so mini-batching is a must.
import torch

# Toy datasets for the three regions (different sizes on purpose)
input1 = torch.rand((100, 2))
input2 = torch.rand((120, 2))
input3 = torch.rand((140, 2))
output1 = torch.rand((100, 1))
output2 = torch.rand((120, 1))
output3 = torch.rand((140, 1))

# One loss definition per region
L1 = torch.nn.L1Loss()
L2 = torch.nn.MSELoss()
L3 = torch.nn.HuberLoss(delta=0.1)

D_in = 2
Hidd = 4
D_out = 1

class MLP(torch.nn.Module):
    def __init__(self, D_in, Hidd, D_out):
        super().__init__()
        self.linear1 = torch.nn.Linear(D_in, Hidd)
        self.linear2 = torch.nn.Linear(Hidd, D_out)

    def forward(self, x):
        y = torch.tanh(self.linear1(x))
        return self.linear2(y)

mlp = MLP(D_in, Hidd, D_out)

epochs = 5000
optimizer_name = "Adam"
lr = 0.001
optimizer = getattr(torch.optim, optimizer_name)(mlp.parameters(), lr=lr)

loss_history = []
for epoch in range(epochs):
    pred1 = mlp(input1)
    pred2 = mlp(input2)
    pred3 = mlp(input3)
    loss1 = L1(pred1, output1)
    loss2 = L2(pred2, output2)
    loss3 = L3(pred3, output3)
    loss = loss1 + loss2 + loss3
    optimizer.zero_grad()
    loss.backward()  # retain_graph=True is not needed: the graph is rebuilt each iteration
    optimizer.step()
    loss_history.append(loss.item())
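One pattern I have been considering (not sure whether it is the idiomatic way, so treat it as a sketch): wrap each dataset in its own TensorDataset/DataLoader and zip the three loaders inside the epoch loop, so each optimizer step sees one mini-batch from every region. The batch size of 16 and the Sequential model here are placeholders for illustration; note that zip stops at the shortest loader, so with datasets of different sizes some samples from the larger regions are skipped each epoch (shuffling rotates which ones).

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Same toy data shapes as in the MWE above
input1, output1 = torch.rand((100, 2)), torch.rand((100, 1))
input2, output2 = torch.rand((120, 2)), torch.rand((120, 1))
input3, output3 = torch.rand((140, 2)), torch.rand((140, 1))

# One DataLoader per region; batch_size=16 is a hypothetical choice
loader1 = DataLoader(TensorDataset(input1, output1), batch_size=16, shuffle=True)
loader2 = DataLoader(TensorDataset(input2, output2), batch_size=16, shuffle=True)
loader3 = DataLoader(TensorDataset(input3, output3), batch_size=16, shuffle=True)

# Stand-in for the MLP from the MWE
mlp = torch.nn.Sequential(torch.nn.Linear(2, 4), torch.nn.Tanh(), torch.nn.Linear(4, 1))
L1, L2, L3 = torch.nn.L1Loss(), torch.nn.MSELoss(), torch.nn.HuberLoss(delta=0.1)
optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)

for epoch in range(10):
    # zip truncates to the shortest loader (here loader1, ceil(100/16) = 7 batches)
    for (x1, y1), (x2, y2), (x3, y3) in zip(loader1, loader2, loader3):
        loss = L1(mlp(x1), y1) + L2(mlp(x2), y2) + L3(mlp(x3), y3)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```

If skipping samples from the larger datasets is a concern, itertools.cycle around the shorter loaders (driven by the longest one) would be an alternative, at the cost of reusing some batches within an epoch.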
Many thanks in advance for the help!