Supervised Contrastive Learning 2-stage training implementation

Hi all,

I am looking into training a model using the approach described in Supervised Contrastive Learning, where there is a metric learning loss and a classification loss.

In the paper, they first train the encoder using solely the metric learning loss. Then they freeze the encoder and train a classification layer using the cross-entropy loss. This requires the model to be trained in two stages; however, the authors mention: "Note that in practice the linear classifier can be trained jointly with the encoder and projection networks by blocking gradient propagation from the linear classifier back to the encoder, and achieve roughly the same results without requiring two-stage training."
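To make sure I understand the two-stage procedure, here is a rough sketch of how I read it (the names encoder, classifier, encoder_opt, classifier_opt, and loader are hypothetical placeholders, and MyMetricLoss stands in for the supervised contrastive loss):

import torch
import torch.nn.functional as F

# Stage 1: train the encoder (and projection head) with the metric loss only
for batch in loader:
    e = encoder(batch['x'])
    loss = MyMetricLoss(e, batch['y'])
    loss.backward()
    encoder_opt.step()
    encoder_opt.zero_grad()

# Stage 2: freeze the encoder and train the classifier with cross-entropy
for param in encoder.parameters():
    param.requires_grad = False
for batch in loader:
    with torch.no_grad():
        e = encoder(batch['x'])
    logits = classifier(e)
    loss = F.cross_entropy(logits, batch['y'])
    loss.backward()
    classifier_opt.step()
    classifier_opt.zero_grad()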

Therefore I would like to ask whether my pseudocode below would work as a way to jointly train the encoder and the classification layer, as proposed by the authors.

Some questions:

  • Is this implementation with two backward calls correct, or should I sum the losses and perform a single backward step?
  • Should I use two different optimizers, one for each part of the network (encoder and classification layer)?
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyModel(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(100, 10)
        self.lin = nn.Linear(10, 2)

    def training_step(self, batch):
        # optimizer is assumed to be defined elsewhere over all parameters
        x, y = batch['x'], batch['y']
        e = self.encoder(x)

        # update the encoder with the metric learning loss
        metric_loss = MyMetricLoss(e, y)  # placeholder for e.g. a SupCon loss
        metric_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # use detach to block gradients from the classification
        # loss from reaching the encoder
        p = self.lin(e.detach())
        classification_loss = F.cross_entropy(p, y)
        classification_loss.backward()
        optimizer.step()
        optimizer.zero_grad()


Many thanks!

I don’t fully understand this statement and how the encoder can be trained if the gradient propagation is blocked. Is the encoder used in another part of the model or is it sharing parameters with another module?

Hi @ptrblck, the encoder is trained by backpropagating the metric learning loss, which is computed before the detach (see the pseudocode example). Or at least that is my understanding.

Yes, I understand your code snippet, which I believe shows the 2-stage approach.
However, I don’t fully understand the quoted statement, as I don’t see how the encoder would be trained assuming a single backward pass is used with a detached encoder output. Or do you know if this statement refers to using a single optimizer but still two separate forward/backward passes?


Can’t you just combine metric_loss and classification_loss and then backpropagate once? You’ve detached the encoder output before feeding it into the linear layer, so e.detach() behaves like a typical input x, but self.lin will still backpropagate. The encoder will receive updates based on the metric loss only, while lin will receive updates based on the cross-entropy loss.

So can’t you just remove the .backward(), .step(), and .zero_grad() calls from the metric_loss and classification_loss steps, sum the losses, and backpropagate once:

e = self.encoder(x)
metric_loss = MyMetricLoss(e, y)
p = self.lin(e.detach())
classification_loss = F.cross_entropy(p, y)

# Unitary scalarization: sum the losses and backpropagate once
total_loss = metric_loss + classification_loss
total_loss.backward()
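If it helps, here is a quick self-contained check (toy shapes, assuming the MyModel from above) that the detach isolates the gradients as intended, followed by the usual optimizer step with a single optimizer over both modules:

import torch
import torch.nn.functional as F

model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

x = torch.randn(8, 100)
y = torch.randint(0, 2, (8,))

e = model.encoder(x)
p = model.lin(e.detach())

# backprop only the classification loss: the encoder should get no gradients
F.cross_entropy(p, y).backward()
print(model.lin.weight.grad is not None)   # True
print(model.encoder.weight.grad is None)   # True

optimizer.step()
optimizer.zero_grad()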