Combining Multiple Models and Datasets

One possible approach is to create two separate DataLoaders and use their iterators directly:

import torch
from torch.utils.data import DataLoader, TensorDataset

# dummy datasets, one per head
dataset_head1 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))
dataset_head2 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))

for epoch in range(2):
    print('epoch ', epoch)
    # recreate the loaders and their iterators at the start of each epoch
    loader1 = DataLoader(dataset_head1)
    iter_loader1 = iter(loader1)
    loader2 = DataLoader(dataset_head2)
    iter_loader2 = iter(loader2)
    
    try:
        while True:
            # train head1
            data, target = next(iter_loader1)
            print('training head1')
    
            # train head2
            data, target = next(iter_loader2)
            print('training head2')
    except StopIteration:
        # one of the loaders is exhausted, so move on to the next epoch
        pass

To avoid code duplication, you could of course write a train method and pass the current data into it, if that fits your use case.
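
A minimal sketch of such a train method could look like the code below. The names `train_step`, `head1`, `head2`, `opt1`, `opt2`, and `criterion` are placeholders I made up for illustration, and the `nn.Linear` layers, SGD optimizers, and MSE loss are just stand-ins for whatever your heads actually use. I also swapped the manual `next()` calls for `zip`, which stops as soon as either loader is exhausted:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def train_step(model, optimizer, criterion, data, target):
    """Run one optimization step on a single batch and return the loss value."""
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()
    return loss.item()


dataset_head1 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))
dataset_head2 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))

# hypothetical stand-ins for the two heads and their optimizers
head1 = nn.Linear(1, 1)
head2 = nn.Linear(1, 1)
opt1 = torch.optim.SGD(head1.parameters(), lr=0.01)
opt2 = torch.optim.SGD(head2.parameters(), lr=0.01)
criterion = nn.MSELoss()

losses = {"head1": [], "head2": []}
for epoch in range(2):
    loader1 = DataLoader(dataset_head1)
    loader2 = DataLoader(dataset_head2)
    # zip ends the epoch once either loader runs out of batches
    for (data1, target1), (data2, target2) in zip(loader1, loader2):
        losses["head1"].append(train_step(head1, opt1, criterion, data1, target1))
        losses["head2"].append(train_step(head2, opt2, criterion, data2, target2))
```

If both heads share a backbone or a single optimizer, you would probably want to adapt `train_step` accordingly (e.g. sum the two losses before calling `backward()`), but that depends on your setup.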

Let me know if this approach works for you.
