Update multiple heads during training

I have a relatively simple model. I use a base BERT model with 10 linear classifiers on top of it. It looks like this:

import torch
import torch.nn as nn

class BERT(nn.Module):
    def __init__(self, bert):
        super(BERT, self).__init__()

        self.bert = bert
        self.fc1 = nn.Linear(768, 768)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.dropout_head = nn.Dropout(p=0.25)
        # 10 independent classifier heads on top of the shared fc1 layer
        self.classifier_heads = nn.ModuleList(
            [nn.Linear(768, dataset.number_of_classes()) for i in range(10)]
        )

    def forward(self, **kwargs):
        cls_hs = self.bert(**kwargs)
        hidden_state = cls_hs.last_hidden_state
        pooler = hidden_state[:, 0]  # [CLS] token representation
        x = self.fc1(pooler)
        x = self.relu(x)
        x = self.dropout(x)
        # Collect one prediction tensor per head
        output = []
        for layer in self.classifier_heads:
            x_out = layer(x)
            x_out = self.dropout_head(x_out)
            output.append(x_out)

        return output

The model returns a list of length 10; each item is a tensor of output values of shape (batch_size, number_of_classes), in my case (768, 1000). I want to dynamically update the heads of the model during training, hoping that the individual heads will specialize in parts of the data (e.g. outliers, or text about images).
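For reference, a forward pass then looks roughly like this (a sketch; `tokenizer`, `texts` and `bert` are placeholders assumed here, not part of the original code):

# Hypothetical usage -- `tokenizer`, `texts` and `bert` are assumed
encoded = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
model = BERT(bert)
preds = model(**encoded)   # list of 10 tensors, one per head
print(len(preds))          # 10
print(preds[0].shape)      # torch.Size([batch_size, number_of_classes])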

My approach was to stack the outputs into a tensor of shape (batch_size, number_of_classes, number_of_heads), and to stack the labels accordingly so they match. After this, I use CrossEntropyLoss WITHOUT reduction to get the loss per head, resulting in a tensor of shape (batch_size, number_of_heads).

CEloss = nn.CrossEntropyLoss(reduction='none')
# During the training loop...
preds_stacked = torch.stack(preds, dim=-1)      # (batch_size, number_of_classes, number_of_heads)
labels_stacked = torch.stack([labels for i in range(len(model.classifier_heads))], dim=-1)
losses = CEloss(preds_stacked, labels_stacked)  # (batch_size, number_of_heads)

I have two approaches I want to try:

  • For each sample, only backpropagate the best (lowest) loss across all heads;
  • For each head, only backpropagate that head's own loss through that head.

For option 1, I simply thought to loop over the batch, select the best loss per sample, and backpropagate it:

for sample in range(losses.size(0)):  # Loop over the batch dimension
    best_loss = losses[sample].min()  # Lowest loss across the heads for this sample
    best_loss.backward(retain_graph=True)
    optimizer.step()

However, it apparently is not as simple as that. I could of course use a batch size of 1, but that is not very efficient. Maybe someone can offer a solution?
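A minimal sketch of how this might work for a full batch (my assumption about the problem: calling optimizer.step() between backward passes modifies the shared parameters in-place while the graph is still needed) is to accumulate the gradients of every sample's best loss first and only step once per batch:

optimizer.zero_grad()
for sample in range(losses.size(0)):
    # Lowest loss across the heads for this sample
    best_loss = losses[sample].min()
    # Scale so the accumulated gradient equals the mean of the best losses;
    # retain_graph=True because the shared forward graph is reused per sample
    (best_loss / losses.size(0)).backward(retain_graph=True)
optimizer.step()

This is equivalent to a single backward call on the mean of the per-sample best losses, which would avoid the Python loop entirely.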

My approach for the second option was to rebuild the optimizer during training so that each step only updates the shared layer and a single head.

# Stack predictions and labels
preds_stacked = torch.stack(preds, dim=-1)
labels_stacked = torch.stack([labels for i in range(len(model.classifier_heads))], dim=-1)
# Calculate loss per sample and per head, shape (batch_size, number_of_heads)
losses = CEloss(preds_stacked, labels_stacked)
# Mean loss per head
losses_per_head = torch.mean(losses, dim=0)
# Backpropagate each head's loss through the shared layer and that head only
for head in range(len(model.classifier_heads)):
    # Select the parameters this head is allowed to update
    params = list(model.fc1.parameters()) + list(model.classifier_heads[head].parameters())
    optimizer = torch.optim.Adam(params=params, lr=3e-4)
    optimizer.zero_grad()
    loss = losses_per_head[head]
    loss.backward(retain_graph=True)
    optimizer.step()

However, I get the error that a variable needed for gradient computation has been modified by an in-place operation. I could create multiple models instead, but I want to do this within a single model.
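My guess (an assumption, not verified against the full training loop) is that optimizer.step() updates fc1's weights in-place after the first head, invalidating tensors that the remaining heads' backward passes still need. A sketch of a possible workaround: build one optimizer once over all parameters and call backward/step only once per batch on the sum of the per-head losses; because each head's loss only produces gradients for that head (plus the shared layers), every head is still trained on its own loss:

# Built once, outside the training loop, over the shared layer and all heads
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

# Inside the training loop:
losses_per_head = losses.mean(dim=0)   # shape: (number_of_heads,)
optimizer.zero_grad()
# One backward on the sum routes each head only its own gradient
losses_per_head.sum().backward()
optimizer.step()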

I hope someone can offer some suggestions to improve this.

For those interested: I did not use reduction during the loss calculation, so each batch I end up with a (batch_size, number_of_heads) tensor of losses. I took the best loss per sample, averaged those, and backpropagated the result. Seems to work like a charm.
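In code, that amounts to something like this sketch (variable names follow the snippets above; losses has shape (batch_size, number_of_heads)):

optimizer.zero_grad()
best_losses = losses.min(dim=1).values   # best (lowest) loss per sample across heads
best_losses.mean().backward()            # average, then backpropagate once
optimizer.step()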