One possible approach would be to create two separate DataLoaders and use their iterators directly:
```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset_head1 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))
dataset_head2 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))

for epoch in range(2):
    print('epoch ', epoch)
    # create fresh iterators at the start of every epoch
    loader1 = DataLoader(dataset_head1)
    iter_loader1 = iter(loader1)
    loader2 = DataLoader(dataset_head2)
    iter_loader2 = iter(loader2)
    try:
        while True:
            # train head1
            data, target = next(iter_loader1)
            print('training head1')
            # train head2
            data, target = next(iter_loader2)
            print('training head2')
    except StopIteration:
        # one of the loaders is exhausted, so this epoch is done
        pass
```
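If you prefer to avoid the explicit `try`/`except StopIteration`, the same alternation can likely be written with the built-in `zip`, which pulls one batch from each loader per step and stops as soon as either is exhausted. This is a sketch, not a drop-in replacement: with equally sized datasets (as here) both versions do the same number of steps, but if the second loader is shorter, the manual loop above still trains head1 on one extra batch, while `zip` discards it. The `order` list is only there to record which head would be trained.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset_head1 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))
dataset_head2 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))

order = []  # records the training order, for illustration only
for epoch in range(2):
    # zip creates a fresh iterator over each DataLoader every epoch and stops
    # as soon as either one is exhausted, so no StopIteration handling is needed
    for (data1, target1), (data2, target2) in zip(DataLoader(dataset_head1),
                                                    DataLoader(dataset_head2)):
        order.append('head1')  # train head1 with data1, target1
        order.append('head2')  # train head2 with data2, target2
```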
To avoid code duplication, you could of course write a `train` method and pass the current batch into it, if that fits your use case.
Let me know if this approach works for you.
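A minimal sketch of such a helper, assuming each head is a separate module with its own optimizer. The two `nn.Linear` heads, the SGD optimizers, the MSE loss, and the `train_step` function are all hypothetical choices for illustration; your actual model (e.g. a shared backbone feeding both heads) will likely look different.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

dataset_head1 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))
dataset_head2 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))

# Hypothetical model: two independent linear heads (assumption for illustration)
head1 = nn.Linear(1, 1)
head2 = nn.Linear(1, 1)
opt1 = torch.optim.SGD(head1.parameters(), lr=0.01)
opt2 = torch.optim.SGD(head2.parameters(), lr=0.01)
criterion = nn.MSELoss()


def train_step(head, optimizer, data, target):
    """Hypothetical helper: one optimization step for a single head."""
    optimizer.zero_grad()
    loss = criterion(head(data), target)
    loss.backward()
    optimizer.step()
    return loss.item()


steps = {"head1": 0, "head2": 0}  # counts the updates per head
for epoch in range(2):
    # fresh iterators every epoch, as in the loop above
    iter_loader1 = iter(DataLoader(dataset_head1))
    iter_loader2 = iter(DataLoader(dataset_head2))
    try:
        while True:
            data, target = next(iter_loader1)
            train_step(head1, opt1, data, target)
            steps["head1"] += 1
            data, target = next(iter_loader2)
            train_step(head2, opt2, data, target)
            steps["head2"] += 1
    except StopIteration:
        pass
```

The training logic now lives in one place, and the loop only decides which head gets the next batch.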