Train a local module based on a global loss

I am trying to train an embedding model with a loss function that depends on a clustering-based metric. My dataset consists of n clustering instances. In each instance, the model has to cluster m data points, each represented by its embedding vector, and a single metric is computed from the quality of the resulting clustering. My training loop looks like this:

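        # Excerpt from a trainer method; data_iter and the (hypothetical) self.optimizer are defined elsewhere
        for epoch in trange(epochs, desc="Epoch"):  # trange comes from tqdm
            self.model.train()  # embedding model
            self.loss_model.train()  # clustering loss as a PyTorch module, no trainable params
            for _ in trange(steps_per_epoch, desc="Iteration", smoothing=0.05):
                data, labels = next(data_iter)  # assumes the iterator yields (data, labels)
                batch_cand_labels = self.model(data)
                batch_true_labels = torch.stack([self.get_gt(label) for label in labels])
                loss = self.loss_model(batch_cand_labels, batch_true_labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()  # parameters change here, once per iteration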

And here is my embedding model:

    from typing import List

    import torch
    import torch.nn as nn

    class DeepClusterEmbedder(nn.Module):

        def __init__(self, doc_tf: MyEmbedding, lambda_val: float):
            super().__init__()
            self.doc_tf = doc_tf
            self.lambda_val = lambda_val
            self.optim = OptimCluster()  # clustering solver, defined elsewhere

        def forward(self, data: List):
            doc_features = self.doc_tf.tokenize(data)
            self.doc_tf.eval()  # intended to freeze the embedding layer
            doc_embeddings = self.doc_tf(doc_features)['sentence_embedding']
            self.doc_tf.train()  # intended to let the embedding layer update again
            # batch_size and maxlen are assumed to be defined outside this snippet
            doc_embeddings = doc_embeddings.reshape((batch_size, maxlen, -1))
            embeddings_dist_mats = torch.stack([euclid_dist(doc_embeddings[i]) for i in range(batch_size)])
            batch_adj_mats = self.optim.apply(embeddings_dist_mats, self.lambda_val)  # some optimizer for clustering
            return batch_adj_mats
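euclid_dist is not shown in the post. Assuming it returns the pairwise Euclidean distance matrix of one instance's embeddings, torch.cdist can compute the matrices for every instance at once, because it treats the leading dimension as a batch dimension. This is a sketch under that assumption; the helper name batched_euclid_dist is hypothetical.

```python
import torch


def batched_euclid_dist(embeddings: torch.Tensor) -> torch.Tensor:
    """Pairwise Euclidean distances for each clustering instance in a batch.

    embeddings: (batch_size, maxlen, dim) -> returns (batch_size, maxlen, maxlen)
    """
    # torch.cdist accepts (B, P, M) and (B, R, M) inputs and returns (B, P, R)
    return torch.cdist(embeddings, embeddings, p=2)


torch.manual_seed(0)
emb = torch.randn(2, 3, 4)  # 2 instances, 3 points each, 4-dim embeddings
dists = batched_euclid_dist(emb)
print(dists.shape)  # torch.Size([2, 3, 3])
```

With this, the list comprehension over range(batch_size) in forward could be replaced by a single call on doc_embeddings.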

As the forward method shows, I want to freeze my embedding model while I compute the embedding vectors for all samples in a clustering instance, and only update its parameters after I have calculated the final clustering loss. That is why I surround the embedding computation with model.eval() and model.train(). If I do not, the model parameters get updated for every embedding vector computed: with m samples in a single clustering instance, the parameters would be updated m times, which I do not want. I inferred this from the observation that, without freezing the embedding layer, the model produces different embeddings for the same data samples.

Now my questions: is this the correct way to freeze the layer parameters in this case? And is my understanding of how the embedding layer's parameters are updated correct?
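A small experiment can check what eval()/train() actually toggle. In PyTorch, module.train() and module.eval() switch layers such as dropout and batch norm between training and inference behavior; a forward pass by itself does not modify parameters, which change only when an optimizer's step() is called. The sketch below uses a toy nn.Sequential with dropout as a stand-in for doc_tf (an assumption: the real model is not shown, although many transformer encoders contain dropout).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for an embedding model that contains dropout (assumption)
encoder = nn.Sequential(nn.Linear(8, 16), nn.Dropout(p=0.5), nn.Linear(16, 4))
x = torch.randn(5, 8)

# Snapshot the parameters before any forward pass
before = [p.detach().clone() for p in encoder.parameters()]

encoder.train()  # dropout active: each forward pass samples a new mask
train_a, train_b = encoder(x), encoder(x)

encoder.eval()   # dropout disabled: repeated forward passes agree
eval_a, eval_b = encoder(x), encoder(x)

# Forward passes alone never modify the parameters, in either mode
unchanged = all(torch.equal(b, p) for b, p in zip(before, encoder.parameters()))

print(torch.allclose(train_a, train_b))  # False: different dropout masks
print(torch.equal(eval_a, eval_b))       # True
print(unchanged)                         # True
```

If doc_tf contains dropout, that alone could explain different embeddings for identical inputs in train mode, without any parameter update taking place.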

Could you define the parameters passed to your optimizer so that it skips updating the unwanted layers?
That would be the easiest solution.
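A minimal sketch of that suggestion, with a hypothetical encoder/head pair: hand only the parameters that should be trained to the optimizer, and optionally call requires_grad_(False) on the rest so autograd does not compute gradients for them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical two-part model: an encoder to keep fixed and a trainable head
encoder = nn.Linear(4, 4)
head = nn.Linear(4, 1)

# Only the head's parameters are registered with the optimizer
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)

# Optional: also stop autograd from computing gradients for the encoder
encoder.requires_grad_(False)

enc_before = encoder.weight.detach().clone()
head_before = head.weight.detach().clone()

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = nn.functional.mse_loss(head(encoder(x)), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

print(torch.equal(encoder.weight, enc_before))  # True: encoder untouched
print(torch.equal(head.weight, head_before))    # False: head was updated
```

Note that this keeps the excluded module fixed for the whole of training.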

Thank you @CedricLy for your response. If I exclude the doc_tf layer when I create the optimizer, it will never learn the embeddings. Is there a way to tell the optimizer, from within the forward function of DeepClusterEmbedder, not to update doc_tf until I have calculated all the embeddings, and to update it only after that? Currently I achieve this by calling doc_tf.eval() before the embedding calculation and doc_tf.train() after it. Is that the correct way?
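One way to check the timing of updates is sketched below with a toy encoder standing in for doc_tf and a hypothetical loss (mean squared pairwise distance) standing in for the clustering loss; both are assumptions. It computes m embeddings with separate forward passes under eval(), then calls backward() and step() once. eval() does not stop autograd from recording the graph (torch.no_grad() is what does that), so gradients still reach the encoder, and its weights change only at optimizer.step().

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy encoder with dropout, standing in for doc_tf (assumption)
encoder = nn.Sequential(nn.Linear(8, 4), nn.Dropout(p=0.1))
optimizer = torch.optim.SGD(encoder.parameters(), lr=0.1)

m = 6                        # data points in one clustering instance
points = torch.randn(m, 8)
before = encoder[0].weight.detach().clone()

optimizer.zero_grad()
encoder.eval()               # disables dropout only; the graph is still recorded
embeddings = torch.stack([encoder(p) for p in points])  # m separate forward passes
encoder.train()

# Hypothetical stand-in for the clustering loss: mean squared pairwise distance
diffs = embeddings[:, None, :] - embeddings[None, :, :]
loss = (diffs ** 2).sum(-1).mean()
loss.backward()

# After m forward passes and backward(), the weights are unchanged but grads exist
still_same = torch.equal(encoder[0].weight, before)
has_grad = encoder[0].weight.grad is not None

optimizer.step()             # the single parameter update for this instance
changed = not torch.equal(encoder[0].weight, before)

print(still_same, has_grad, changed)  # True True True
```

Under these assumptions, the eval()/train() calls are not needed to prevent per-sample updates; they only decide whether dropout is applied while the embeddings are computed.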