Tuning nn.Embedding weight with constraint on norm


I have a really simple model which uses only an nn.Embedding module. The goal is to minimize a specific loss function, but with the additional constraint that the L2 norm of each embedding is 1.
I found two options to normalize the embeddings:

  1. Reassign weight at each forward call: self.embeddings.weight.data = F.normalize(self.embeddings.weight.data, p=2, dim=1)
  2. Use param max_norm of nn.Embedding: self.embeddings = nn.Embedding(..., max_norm=1, norm_type=2)
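The two options above can be sketched as follows (a minimal illustration; the model classes, embedding sizes, and indices are made up for the example):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelA(nn.Module):
    """Option 1: renormalize the weights in place at every forward call."""
    def __init__(self, num_embeddings=10, dim=4):
        super().__init__()
        self.embeddings = nn.Embedding(num_embeddings, dim)

    def forward(self, idx):
        # .data reassignment: happens outside autograd
        self.embeddings.weight.data = F.normalize(self.embeddings.weight.data, p=2, dim=1)
        return self.embeddings(idx)

class ModelB(nn.Module):
    """Option 2: let nn.Embedding clip rows to max_norm=1 on lookup."""
    def __init__(self, num_embeddings=10, dim=4):
        super().__init__()
        self.embeddings = nn.Embedding(num_embeddings, dim, max_norm=1.0, norm_type=2.0)

    def forward(self, idx):
        return self.embeddings(idx)

idx = torch.tensor([0, 1, 2])
outA = ModelA()(idx)   # every row has unit L2 norm after the reassignment
outB = ModelB()(idx)   # rows whose norm exceeded 1 are clipped to norm 1
print(outA.norm(dim=1), outB.norm(dim=1))
```

Note one behavioral difference: option 1 forces every row to norm exactly 1, while max_norm only rescales rows whose norm exceeds 1 and leaves shorter rows untouched.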

The problem here is that, as far as I can see, each of these solutions is ignored during backprop. For the first option, reassigning weight.data simply isn't tracked by autograd. The second option is implemented inside a torch.no_grad() context, as we can see here.
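To make the backprop issue concrete, here is a small sketch (my own construction, not from any library) comparing the gradient you get when the normalization is done through .data versus when it is kept inside the autograd graph:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
emb = nn.Embedding(5, 3)
idx = torch.tensor([0])

# Untracked: the .data reassignment is invisible to autograd, so the
# gradient is just that of a plain embedding lookup.
emb.weight.data = F.normalize(emb.weight.data, p=2, dim=1)
loss = emb(idx).sum()
loss.backward()
grad_untracked = emb.weight.grad[0].clone()  # equals all-ones here
emb.weight.grad = None

# Tracked: normalize the looked-up row inside the graph, so backward
# differentiates through the normalization itself.
vec = F.normalize(emb.weight[idx], p=2, dim=1)
loss = vec.sum()
loss.backward()
grad_tracked = emb.weight.grad[0].clone()

# The tracked gradient is projected onto the tangent of the unit sphere
# (it is orthogonal to the weight row); the untracked one is not.
print(torch.allclose(grad_untracked, grad_tracked))
```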

Actually, I see many examples where people use the first option, but IMO it is not correct (if someone can explain to me why it is correct, I would be grateful).
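One differentiable alternative I'm aware of (assuming PyTorch >= 1.9) is to register a parametrization on the weight: the stored tensor stays unconstrained, but every access returns a row-normalized view, and gradients flow through the normalization. A hypothetical sketch:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import parametrize

class UnitNorm(nn.Module):
    """Parametrization that maps the raw weight to its row-normalized form."""
    def forward(self, W):
        return F.normalize(W, p=2, dim=1)

emb = nn.Embedding(10, 4)
parametrize.register_parametrization(emb, "weight", UnitNorm())

out = emb(torch.tensor([0, 5]))
print(out.norm(dim=1))  # each looked-up embedding has unit L2 norm
out.sum().backward()    # gradient reaches the raw, unconstrained weight
```

With this setup the optimizer updates the raw tensor (emb.parametrizations.weight.original), while the model always sees unit-norm rows, so the constraint holds exactly at every step.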

The other possible solution is to modify the loss function in some way, but I don't really want to do that right now…
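For completeness, the loss-modification route would look roughly like this: add a soft penalty that pushes every row toward unit norm (the penalty weight lam and the stand-in task loss are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

emb = nn.Embedding(10, 4)
idx = torch.tensor([1, 3])
lam = 0.1  # penalty weight (illustrative)

task_loss = emb(idx).pow(2).sum()        # stand-in for the real loss
norms = emb.weight.norm(p=2, dim=1)
penalty = ((norms - 1.0) ** 2).mean()    # pushes each row toward norm 1
loss = task_loss + lam * penalty
loss.backward()                          # gradients include the constraint term
```

Unlike the hard renormalization above, this only enforces the constraint approximately, which is presumably why it feels less appealing.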