Different datasets and different loss terms in PyTorch

Hello,
I am working on a problem where the input space has three regions: R1, R2, and R3, and I have three datasets (D1, D2, and D3) sampled from the corresponding regions. The catch is that the loss function also has three terms: L1, L2, and L3 (one per region of the input space), and each term is defined differently from the other two. The problem is to be solved in PyTorch.

I am not sure how to proceed with sampling/mini-batching. Any advice is appreciated.

Below I include a minimal working example, which does not use mini-batching. The actual problem has much larger datasets, so mini-batching is a must.

import torch

# dummy inputs sampled from the three regions R1, R2, and R3
input1 = torch.rand((100,2))
input2 = torch.rand((120,2))
input3 = torch.rand((140,2))

# corresponding targets for each region
output1 = torch.rand((100,1))
output2 = torch.rand((120,1))
output3 = torch.rand((140,1))

# a different loss term for each region
L1 = torch.nn.L1Loss()
L2 = torch.nn.MSELoss()
L3 = torch.nn.HuberLoss(delta=0.1)

D_in  = 2
Hidd  = 4
D_out = 1

class MLP(torch.nn.Module):
    def __init__(self, D_in, Hidd, D_out):
        super(MLP, self).__init__()
        
        self.linear1  = torch.nn.Linear(D_in, Hidd)
        self.linear2  = torch.nn.Linear(Hidd, D_out)
        
    def forward(self, x):
        
        y = torch.tanh(self.linear1(x))
        y = self.linear2(y)
        return y
mlp = MLP(D_in, Hidd, D_out)
epochs = 5000
optimizer_name = "Adam"
lr = 0.001
optimizer = getattr(torch.optim, optimizer_name)(mlp.parameters(), lr=lr)
# full-batch training loop (no mini-batching yet)
loss_history = []
for epoch in range(epochs):
    pred1 = mlp(input1)
    pred2 = mlp(input2)
    pred3 = mlp(input3)
    loss1 = L1(pred1, output1)
    loss2 = L2(pred2, output2)
    loss3 = L3(pred3, output3)
    loss  = loss1 + loss2 + loss3
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    loss_history.append(loss.item())
    print(loss.item())
    optimizer.step()
    

Many thanks in advance for the help!

You could wrap each Dataset into a DataLoader, create the iterators manually, and iterate them by calling next. Something like this should work:

import torch
from torch.utils.data import TensorDataset, DataLoader

input1 = torch.rand((100,2))
input2 = torch.rand((120,2))
input3 = torch.rand((140,2))

output1 = torch.rand((100,1))
output2 = torch.rand((120,1))
output3 = torch.rand((140,1))

dataset1 = TensorDataset(input1, output1)
dataset2 = TensorDataset(input2, output2)
dataset3 = TensorDataset(input3, output3)

loader1 = DataLoader(dataset1, batch_size=5, shuffle=True)
loader2 = DataLoader(dataset2, batch_size=5, shuffle=True)
loader3 = DataLoader(dataset3, batch_size=5, shuffle=True)

iter1 = iter(loader1)
iter2 = iter(loader2)
iter3 = iter(loader3)

# iterate until the shortest loader is exhausted
try:
    while True:
        data1 = next(iter1)
        data2 = next(iter2)
        data3 = next(iter3)
except StopIteration:
    print("iter exhausted")

Alternatively, you could use the itertools library and iterate e.g. until the longest loader is exhausted:

import itertools

iter1 = iter(loader1)
iter2 = iter(loader2)
iter3 = iter(loader3)
for a, b, c in itertools.zip_longest(iter1, iter2, iter3):
    print("a ", a)
    print("b ", b)
    print("c ", c)

In the end it depends on your use case, since the datasets have different lengths: the next approach stops at the shortest loader, while zip_longest keeps going and yields None for the loaders that are already exhausted.
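For example, a rough sketch of the zip_longest variant that simply skips the already exhausted loaders (reusing the loaders, mlp, the loss functions, and the optimizer from your first post) could look like this:

import itertools

# one pass over the batches; zip_longest yields None for loaders that are
# already exhausted, so only the available loss terms are accumulated
for a, b, c in itertools.zip_longest(loader1, loader2, loader3):
    loss = 0.
    if a is not None:
        loss = loss + L1(mlp(a[0]), a[1])
    if b is not None:
        loss = loss + L2(mlp(b[0]), b[1])
    if c is not None:
        loss = loss + L3(mlp(c[0]), c[1])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()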


Many thanks for the reply and the detailed response, @ptrblck. I have a follow-up question: how can I incorporate this mini-batching into the training loop below?

loss_history = []
for epoch in range(epochs):
    pred1 = mlp(input1)
    pred2 = mlp(input2)
    pred3 = mlp(input3)
    loss1 = L1(pred1, output1)
    loss2 = L2(pred2, output2)
    loss3 = L3(pred3, output3)
    loss  = loss1 + loss2 + loss3
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    loss_history.append(loss.item())
    print(loss.item())
    optimizer.step()

Thanks again

You could move the next calls or the for loop inside the epoch loop, for example:
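Here is a rough sketch (assuming the loaders, mlp, losses, and optimizer defined above); note that it uses plain zip instead of zip_longest to keep the example short, so each epoch stops at the shortest loader and, since shuffle=True, different samples from the longer datasets are dropped in each epoch:

for epoch in range(epochs):
    # zip stops once the shortest loader is exhausted for this epoch
    for (x1, y1), (x2, y2), (x3, y3) in zip(loader1, loader2, loader3):
        loss = L1(mlp(x1), y1) + L2(mlp(x2), y2) + L3(mlp(x3), y3)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_history.append(loss.item())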


Many thanks! Much appreciated

@ptrblck I tried the following, but it does not seem to loop over the epochs:

for epoch in range(epochs):
    try:
        while True:
            data1 = next(iter1)
            data2 = next(iter2)
            data3 = next(iter3)
            pred1 = mlp(data1[0])
            pred2 = mlp(data2[0])
            pred3 = mlp(data3[0])
            loss1 = L1(pred1, data1[1])
            loss2 = L2(pred2, data2[1])
            loss3 = L3(pred3, data3[1])
            loss  = loss1 + loss2 + loss3
            optimizer.zero_grad()
            loss.backward(retain_graph=True)
            loss_history.append(loss.item())
            print(epoch, loss.item())
            optimizer.step()
    except StopIteration:
        print("iter exhausted")

It’s working for me:

import torch
import torch.nn as nn

epochs = 3
mlp = nn.Linear(2, 1)
L1 = nn.MSELoss()
L2 = nn.MSELoss()
L3 = nn.MSELoss()

optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)

# initial iterators (recreated below once they are exhausted)
iter1 = iter(loader1)
iter2 = iter(loader2)
iter3 = iter(loader3)

for epoch in range(epochs):
    try:
        while True:
            data1 = next(iter1)
            data2 = next(iter2)
            data3 = next(iter3)
            pred1 = mlp(data1[0])
            pred2 = mlp(data2[0])
            pred3 = mlp(data3[0])
            loss1 = L1(pred1, data1[1])
            loss2 = L2(pred2, data2[1])
            loss3 = L3(pred3, data3[1])
            loss  = loss1 + loss2 + loss3
            optimizer.zero_grad()
            loss.backward()
            print(epoch, loss.item())
            optimizer.step()
    except StopIteration:
        print("iter exhausted")
        iter1 = iter(loader1)
        iter2 = iter(loader2)
        iter3 = iter(loader3)

Note that you need to recreate the iterators once they are exhausted (here inside the except block); otherwise next raises StopIteration right away in every following epoch, which is why your version effectively trained for a single epoch only.
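An equivalent sketch, if you prefer it, creates fresh iterators at the top of every epoch instead of inside the except block, so each epoch starts with full, reshuffled loaders:

for epoch in range(epochs):
    # fresh iterators so no epoch starts with exhausted loaders
    iter1, iter2, iter3 = iter(loader1), iter(loader2), iter(loader3)
    try:
        while True:
            data1, data2, data3 = next(iter1), next(iter2), next(iter3)
            loss = (L1(mlp(data1[0]), data1[1])
                    + L2(mlp(data2[0]), data2[1])
                    + L3(mlp(data3[0]), data3[1]))
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    except StopIteration:
        # the shortest loader is exhausted; continue with the next epoch
        pass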


Indeed! I did not recreate the iterators. Thanks!