Combining Multiple Models and Datasets

I’m trying to combine multiple related datasets and models into one giant model. Here’s what the architecture looks like:

For each head, there’s a dataset with the following structure:

I’ve referred to the following 2 sources for setting up the model, loss, and optimiser:



I’ve set up my model roughly as follows:

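import torch.nn as nn
from torchvision import models

class CombinedModel(nn.Module):
    def __init__(self):
        super(CombinedModel, self).__init__()
        # keep only the convolutional trunk of MobileNetV2 (1280 output channels);
        # torchvision 0.13+ prefers weights=models.MobileNet_V2_Weights.DEFAULT
        self.features = models.mobilenet_v2(pretrained=True).features
        self.pool   = nn.Sequential(*[
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(1)
        ])
        self.head1 = nn.Linear(1280,2)
        self.head2 = nn.Linear(1280,3)

    def forward(self, x, head_name):
        assert head_name in ['head1', 'head2']
        # shared feature extractor, pooled to a (batch, 1280) vector
        x = self.pool(self.features(x))
        # only run the input for the head that is assigned
        # to that dataset
        return getattr(self, head_name)(x)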

What I’d like to do is a round-robin style training where I train one batch from one dataset, fine-tune the respective head and the feature extractor’s parameters, then do another batch from another dataset, and so on.

I’ve set up a separate loss function for each head and a single optimiser for all of them. I’ve checked that on each backward pass only the head whose head_name is passed to the forward method (along with the feature extractor) receives gradients.

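import torch.optim as optim

model = CombinedModel()

# fine-tune only the last 5 blocks of the feature extractor, with a lower lr
optimiser = optim.Adam([
    {'params': model.features[-5:].parameters(), 'lr':1e-5},
    *[{'params':getattr(model,head).parameters(), 'lr':1e-3}
      for head in ['head1', 'head2']]
])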

criterions = {
    'head1': nn.CrossEntropyLoss(),
    'head2': nn.CrossEntropyLoss()
}
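The gradient claim above can be checked directly. The sketch below is an assumption-laden stand-in, not the original model: `TinyMultiHead` is a hypothetical miniature of `CombinedModel` (a small conv trunk instead of MobileNetV2, so nothing is downloaded) and `step` is a hypothetical helper. It trains head2 once so Adam builds up momentum for it, then trains head1 and confirms head2 is left untouched.

```python
import torch
import torch.nn as nn
import torch.optim as optim


class TinyMultiHead(nn.Module):
    # Hypothetical stand-in for CombinedModel with a tiny backbone.
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU())
        self.pool = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(1))
        self.head1 = nn.Linear(8, 2)
        self.head2 = nn.Linear(8, 3)

    def forward(self, x, head_name):
        assert head_name in ('head1', 'head2')
        return getattr(self, head_name)(self.pool(self.features(x)))


torch.manual_seed(0)
model = TinyMultiHead()
optimiser = optim.Adam([
    {'params': model.features.parameters(), 'lr': 1e-5},
    *[{'params': getattr(model, h).parameters(), 'lr': 1e-3} for h in ('head1', 'head2')],
])
criterion = nn.CrossEntropyLoss()


def step(head_name, num_classes):
    # set_to_none=True keeps the unused head's .grad as None, so Adam skips it
    optimiser.zero_grad(set_to_none=True)
    x = torch.randn(4, 3, 16, 16)
    y = torch.randint(0, num_classes, (4,))
    criterion(model(x, head_name), y).backward()
    optimiser.step()


step('head2', 3)  # gives head2 some Adam momentum state
head2_before = model.head2.weight.detach().clone()
step('head1', 2)

print(model.head2.weight.grad is None)                # True
print(torch.equal(head2_before, model.head2.weight))  # True
```

With `set_to_none=False` (the default before PyTorch 2.0), the unused head would instead receive an all-zero gradient, and Adam’s running averages would still move its weights on every step.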

I can create a dataset for each head with torchvision.datasets.ImageFolder, but I’m unclear on how to combine them, since I’d like to switch between datasets after every batch.

I’d imagined the training loop to look something like this:

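# training mode (`model` and `optimiser` as created above)
model.train()

# each batch should switch (randomly) between
# head1 and head2; `dataloaders` stands for an iterable
# that yields (inputs, labels, head_name) tuples
for inputs, labels, head_name in dataloaders:
    # head_name == 'head1' or 'head2'
    # set_to_none=True (the default since PyTorch 2.0) leaves the
    # unused head's grads as None, so Adam does not update it
    optimiser.zero_grad(set_to_none=True)
    outputs = model(inputs, head_name)

    # calculate loss for respective head
    loss = criterions[head_name](outputs, labels)
    loss.backward()
    optimiser.step()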

I’d love any input on what would be the right way to implement this.
Thank you for your time!
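One way to build the `dataloaders` iterable from the loop above is a generator over a dict of per-head DataLoaders. This is a minimal sketch under assumptions: `random_head_batches` is a hypothetical helper, the TensorDatasets stand in for the ImageFolder datasets, and each loader is assumed to yield `(inputs, labels)` and to support `len()`.

```python
import random

import torch
from torch.utils.data import DataLoader, TensorDataset


def random_head_batches(loaders, seed=None):
    """Yield (inputs, labels, head_name) until every loader is exhausted.

    Hypothetical helper: each step draws a head at random, weighted by how
    many batches it still has left, so every batch is used once per call.
    """
    rng = random.Random(seed)
    iters = {name: iter(dl) for name, dl in loaders.items()}
    remaining = {name: len(dl) for name, dl in loaders.items()}
    while any(remaining.values()):
        names = [name for name, left in remaining.items() if left > 0]
        name = rng.choices(names, weights=[remaining[n] for n in names])[0]
        inputs, labels = next(iters[name])
        remaining[name] -= 1
        yield inputs, labels, name


# Toy stand-ins for the two ImageFolder datasets (10 and 6 samples).
loaders = {
    'head1': DataLoader(TensorDataset(torch.randn(10, 3), torch.randint(0, 2, (10,))),
                        batch_size=4, shuffle=True),
    'head2': DataLoader(TensorDataset(torch.randn(6, 3), torch.randint(0, 3, (6,))),
                        batch_size=4, shuffle=True),
}

for epoch in range(2):
    for inputs, labels, head_name in random_head_batches(loaders):
        print(epoch, head_name, tuple(inputs.shape))
```

Weighting by remaining batches spreads the smaller dataset across the whole epoch instead of using it up at the start, and it handles datasets of different sizes without any special casing.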


One possible approach would be to create two separate DataLoaders and use their iterators directly:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset_head1 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))
dataset_head2 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))

for epoch in range(2):
    print('epoch ', epoch)
    # in epoch loop
    loader1 = DataLoader(dataset_head1)
    iter_loader1 = iter(loader1)
    loader2 = DataLoader(dataset_head2)
    iter_loader2 = iter(loader2)
    
    try:
        while True:
            # train head1
            data, target = next(iter_loader1)
            print('training head1')
    
            # train head2
            data, target = next(iter_loader2)
            print('training head2')
    except StopIteration:
        pass

To avoid code duplication, you could of course write a train method and pass the current data into it, if that fits your use case.
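A minimal sketch of such a train method plugged into the loop above. `ToyModel` and `train_step` are hypothetical names, the model is a toy stand-in with a shared `features` trunk, and `MSELoss` is used only because the toy targets are random floats rather than class labels.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class ToyModel(nn.Module):
    # Hypothetical stand-in for CombinedModel: shared trunk + one head per dataset.
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(nn.Linear(1, 8), nn.ReLU())
        self.head1 = nn.Linear(8, 1)
        self.head2 = nn.Linear(8, 1)

    def forward(self, x, head_name):
        return getattr(self, head_name)(self.features(x))


def train_step(model, optimiser, criterions, data, target, head_name):
    """Run one optimisation step on a single batch for the given head."""
    optimiser.zero_grad(set_to_none=True)
    output = model(data, head_name)
    loss = criterions[head_name](output, target)
    loss.backward()
    optimiser.step()
    return loss.item()


model = ToyModel()
optimiser = torch.optim.Adam(model.parameters(), lr=1e-3)
# MSELoss because the toy targets are random floats, not class labels
criterions = {'head1': nn.MSELoss(), 'head2': nn.MSELoss()}

dataset_head1 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))
dataset_head2 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))

history = []  # (head_name, loss) for every step
for epoch in range(2):
    iter_loader1 = iter(DataLoader(dataset_head1))
    iter_loader2 = iter(DataLoader(dataset_head2))
    try:
        while True:
            data, target = next(iter_loader1)
            history.append(('head1', train_step(model, optimiser, criterions, data, target, 'head1')))
            data, target = next(iter_loader2)
            history.append(('head2', train_step(model, optimiser, criterions, data, target, 'head2')))
    except StopIteration:
        pass
```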

Let me know if this approach works for you.


Excellent! This is definitely the right approach.

I suppose this is more of a Python question, but with this exact approach the training loop exits on the first StopIteration, which is a problem if the two datasets have different sizes. Is the except block the only place to handle this?

Here’s a simulation of that situation, adapted from your code block:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset_head1 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))
dataset_head2 = TensorDataset(torch.randn(2, 1), torch.randn(2, 1))

for epoch in range(2):
    print('epoch ', epoch)
    # in epoch loop
    loader1 = DataLoader(dataset_head1)
    loader2 = DataLoader(dataset_head2)
    
    iter_loader1 = iter(loader1)  # 10 batches
    iter_loader2 = iter(loader2)  # 2 batches
    
    try:
        while True:
            # train head1
            data, target = next(iter_loader1)
            print('training head1')
    
            # train head2
            data, target = next(iter_loader2)
            print('training head2')
    except StopIteration:
        # train remaining iterations from `iter_loader1` here?
        # Use the iterator's private `_num_yielded` attribute?
        pass

I think I figured it out. Thanks @ptrblck!

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset_head1 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))
dataset_head2 = TensorDataset(torch.randn(2, 1), torch.randn(2, 1))

heads = {
    'head1': dataset_head1,
    'head2': dataset_head2
}

for epoch in range(2):
    print('epoch ', epoch)
    print('========')
    # in epoch loop

    dls = {name: DataLoader(ds) for name, ds in heads.items()}
    iters = {name: iter(dl) for name, dl in dls.items()}

    while True:
        # Keep looping over the dataloaders sequentially,
        # skipping a head once all of its batches have been yielded.
        # Note: `_num_yielded` is a private attribute of the DataLoader
        # iterator and may change between PyTorch versions.
        for name, iterator in iters.items():
            if len(iterator) > iterator._num_yielded:
                data, target = next(iterator)
                print(f'training {name}')

        # stop once all batches for all heads have been yielded
        pending_per_head = [len(it) - it._num_yielded for it in iters.values()]
        if sum(pending_per_head) == 0:
            break
    print()
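The same round-robin order can be had without the private `_num_yielded` attribute by letting `itertools.zip_longest` pad exhausted loaders with `None`. This is a sketch using the toy datasets from the thread; `trained` is a hypothetical list that only records which head each batch went to, where the real training step would go.

```python
from itertools import zip_longest

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset_head1 = TensorDataset(torch.randn(10, 1), torch.randn(10, 1))
dataset_head2 = TensorDataset(torch.randn(2, 1), torch.randn(2, 1))
dls = {'head1': DataLoader(dataset_head1), 'head2': DataLoader(dataset_head2)}

trained = []  # which head each batch was used for
for epoch in range(2):
    # zip_longest creates a fresh iterator per loader each epoch and pads
    # exhausted loaders with None until the longest one runs out.
    for batches in zip_longest(*dls.values()):
        for name, batch in zip(dls.keys(), batches):
            if batch is None:  # this head has no batches left this epoch
                continue
            data, target = batch
            trained.append(name)  # the training step for `name` goes here

print(trained[:5])  # ['head1', 'head2', 'head1', 'head2', 'head1']
```

It works for any number of heads, since `zip_longest` accepts an arbitrary number of iterables.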