I am trying to train an embedding model with a loss function that depends on a clustering-based metric. My dataset consists of n instances of clustering problems. In each instance, the model has to cluster m data points, each represented by its embedding vector, and a single metric is computed from the quality of the resulting clustering. My training loop looks like the following:
for epoch in trange(epochs, desc="Epoch"):
    self.model.train()       # embedding model
    self.loss_model.train()  # clustering loss function as a pytorch module, no trainable params
    for _ in trange(steps_per_epoch, desc="Iteration", smoothing=0.05):
        optimizer.zero_grad()
        self.model.zero_grad()
        self.loss_model.zero_grad()
        data, labels = next(data_iter)   # one batch of clustering instances and their ground-truth labels
        batch_cand_labels = self.model(data)
        batch_true_labels = torch.stack([self.get_gt(label) for label in labels])
        loss = self.loss_model(batch_cand_labels, batch_true_labels)
        loss.backward()
        optimizer.step()
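For concreteness, these are the tensor shapes I have in mind in each iteration (the sizes below are placeholders, not my real values, and I am assuming get_gt produces adjacency matrices of the same shape as the model output):

import torch

batch_size, maxlen = 4, 16   # placeholder sizes: clustering instances per batch, samples per instance

# predicted adjacency matrices returned by self.model
batch_cand_labels = torch.rand(batch_size, maxlen, maxlen)
# ground-truth adjacency matrices stacked from get_gt (assumed shape)
batch_true_labels = torch.randint(0, 2, (batch_size, maxlen, maxlen)).float()
# self.loss_model reduces these to a single scalar, which is what loss.backward() is called on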
And the following is my embedding model:
class DeepClusterEmbedder(nn.Module):
    def __init__(self, doc_tf: MyEmbedding, lambda_val: float):
        super(DeepClusterEmbedder, self).__init__()
        self.doc_tf = doc_tf
        self.lambda_val = lambda_val
        self.optim = OptimCluster()

    def forward(self, data: List):
        doc_features = self.doc_tf.tokenize(data)
        self.doc_tf.eval()   # freeze embedding layer
        doc_embeddings = self.doc_tf(doc_features)['sentence_embedding']
        self.doc_tf.train()  # now update embedding layer
        # batch_size and maxlen are assumed to be available in this scope
        doc_embeddings = doc_embeddings.reshape((batch_size, maxlen, -1))
        embeddings_dist_mats = torch.stack([euclid_dist(doc_embeddings[i]) for i in range(batch_size)])
        batch_adj_mats = self.optim.apply(embeddings_dist_mats, self.lambda_val)  # some optimizer for clustering
        return batch_adj_mats
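For reference, euclid_dist is just a small helper; a minimal version of what I mean (assuming it does nothing beyond computing the pairwise Euclidean distance matrix for one clustering instance) would look like:

import torch

def euclid_dist(x: torch.Tensor) -> torch.Tensor:
    # x: (m, d) embeddings of one clustering instance -> (m, m) pairwise distance matrix
    return torch.cdist(x, x, p=2)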
As you can see, I want to freeze the embedding model while I compute the embedding vectors for all samples in a clustering instance, and only update its parameters after the final clustering loss has been calculated. That is why I surround the embedding calculation with model.eval() and model.train(). My impression is that if I do not do this, the model parameters get updated for every single embedding calculation (so with m samples in one clustering instance, the parameters would be updated m times, which I do not want). I inferred this from the observation that the model produces different embeddings for the same data samples when I do not "freeze" the embedding layer this way. Now my questions are: is this the correct way to freeze the layer parameters in this case, and is my understanding of when the embedding layer's parameters get updated correct? For clarity, the sketch below shows the kind of explicit freezing I am comparing my eval()/train() approach against.
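The snippet is only an illustration of the alternatives I am asking about, not code I currently use; the names mirror my forward method above, and I am unsure whether either of these is equivalent to (or better than) toggling eval()/train():

# Alternative 1: track no gradients at all for the embedding pass
with torch.no_grad():
    doc_embeddings = self.doc_tf(doc_features)['sentence_embedding']

# Alternative 2: switch off gradients on the embedding parameters themselves
for p in self.doc_tf.parameters():
    p.requires_grad_(False)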