Weight Decay for tied weights (embedding and linear layers)

I have a model with an embedding layer (nn.Embedding) and a final nn.Linear projection layer that share weights via weight tying.

It seems like the best practice is not to apply weight decay to embedding weights, but to apply it to linear layer weights. What should I do in this situation, where the same weight matrix plays both roles?
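For concreteness, here is a minimal sketch of the kind of setup I mean (class, attribute, and dimension names are just illustrative):

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Minimal tied-embedding language model head (illustrative only)."""
    def __init__(self, vocab_size: int = 1000, d_model: int = 128):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        # Both weights have shape (vocab_size, d_model), so one tensor works.
        self.lm_head.weight = self.embed.weight

    def forward(self, idx: torch.Tensor) -> torch.Tensor:
        # A real model would have transformer blocks in between; omitted here.
        return self.lm_head(self.embed(idx))
```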

Here are the pages I have already checked without finding an answer:

  1. Weight decay in the optimizers is a bad idea (especially with BatchNorm)
  2. Weight decay exclusions by michaellavelle · Pull Request #24 · karpathy/minGPT · GitHub
  3. https://github.com/karpathy/minGPT/blob/3ed14b2cec0dfdad3f4b2831f2b4a86d11aef150/mingpt/model.py#L136
  4. regularization - Why not perform weight decay on layernorm/embedding? - Cross Validated
  5. https://github.com/pytorch/examples/blob/main/word_language_model/model.py#L28
  6. python - Tying weights in neural machine translation - Stack Overflow
  7. Weight decay only for weights of nn.Linear and nn.Conv*

If you are weight tying, you are effectively creating a linear layer that points to the embedding weight matrix rather than owning its own weights. Because `model.named_parameters()` de-duplicates shared tensors, the tied matrix is reported only once, under the embedding's name (assuming the embedding is registered before the head, which is the usual pattern), so grouping code that works from those names will never see it as a separate linear-layer weight.
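You can check this directly with a tied setup like the sketch in the question above (reusing the TinyLM, embed, and lm_head names from that sketch):

```python
model = TinyLM()  # the illustrative tied model sketched in the question

# The Linear head and the Embedding share one Parameter object.
assert model.lm_head.weight is model.embed.weight

# named_parameters() de-duplicates shared tensors (remove_duplicate=True by
# default), so the tied matrix is reported once, under the name of the module
# that registered it first -- the embedding.
print([name for name, _ in model.named_parameters()])
# ['embed.weight']  -- 'lm_head.weight' does not appear as a separate entry
```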

In practice this depends on how you implement the weight tying. In the case of something like HuggingFace’s PreTrainedModel (with tie_word_embeddings=True in the config and a correctly set _tied_weights_keys = [ 'lm_head.weight' ] class attribute), the weight of self.lm_head will NOT show up as a separate entry when you iterate over the model’s parameters; it only appears under the nn.Embedding’s name. So as long as your code knows not to apply weight decay to embedding modules, the tied linear layer will not receive weight decay either.
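As a rough sketch of what such grouping code can look like (this is the common pattern, not HuggingFace’s or minGPT’s exact implementation; the build_param_groups name and the chosen module types are my own assumptions):

```python
import torch
import torch.nn as nn

def build_param_groups(model: nn.Module, weight_decay: float = 0.01):
    """Sketch: decay Linear weights; skip decay for embeddings, biases,
    and norm layers. Adapt the module types to your own model."""
    decay_names, no_decay_names = set(), set()
    for mod_name, module in model.named_modules():
        for p_name, _ in module.named_parameters(recurse=False):
            full_name = f"{mod_name}.{p_name}" if mod_name else p_name
            if p_name.endswith("bias") or isinstance(module, (nn.Embedding, nn.LayerNorm)):
                no_decay_names.add(full_name)
            else:
                decay_names.add(full_name)

    # model.named_parameters() reports each tied tensor exactly once (under
    # the embedding's name), so the shared weight can only land in the
    # no-decay group; the head's alias never reaches the optimizer twice.
    params = dict(model.named_parameters())
    return [
        {"params": [p for n, p in params.items() if n in decay_names],
         "weight_decay": weight_decay},
        {"params": [p for n, p in params.items() if n in no_decay_names],
         "weight_decay": 0.0},
    ]

# model can be any tied model, e.g. the sketch from the question
optimizer = torch.optim.AdamW(build_param_groups(model), lr=3e-4)
```

The key design choice is classifying by module type but selecting the actual tensors through model.named_parameters(): the shared tensor exists only once in that dict, so it cannot be decayed through the head by accident.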