Weight Decay for tied weights (embedding and linear layers)

I have a model with an embedding layer (nn.Embedding) and a final nn.Linear projection layer that share weights via weight tying.
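
For concreteness, the tying looks roughly like this (a minimal sketch; TinyLM and the layer names are placeholders, not my actual model):

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    """Toy model with an embedding tied to the output projection."""

    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.hidden = nn.Linear(d_model, d_model)  # stand-in for the rest of the network
        self.decoder = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: decoder and embedding now reference the same Parameter object
        self.decoder.weight = self.embedding.weight

    def forward(self, tokens):
        return self.decoder(torch.relu(self.hidden(self.embedding(tokens))))
```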

Best practice seems to be to exclude embedding weights from weight decay but to apply decay to the weights of nn.Linear layers. With tying, though, both layers point at the same Parameter, so it can only end up in one optimizer group (see the sketch below). What should I do in this situation?
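
To make the dilemma concrete, here is a rough decay / no-decay split in the spirit of the minGPT code linked below (item 3), building on the TinyLM sketch above. The grouping rule and the numeric values are just placeholders; the question is which group the tied weight belongs in:

```python
# Continues the TinyLM sketch above; 0.1 and 3e-4 are placeholder values.
model = TinyLM(vocab_size=1000, d_model=64)
assert model.decoder.weight is model.embedding.weight  # one Parameter, two names

decay, no_decay = [], []
for name, param in model.named_parameters():
    # named_parameters() deduplicates tied parameters by default, so the shared
    # weight shows up only once here, under the name "embedding.weight".
    if name.endswith("bias") or name.endswith("embedding.weight"):
        no_decay.append(param)  # ...or should the tied weight be decayed like a Linear weight?
    else:
        decay.append(param)

optimizer = torch.optim.AdamW(
    [
        {"params": decay, "weight_decay": 0.1},
        {"params": no_decay, "weight_decay": 0.0},
    ],
    lr=3e-4,
)
```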

Here are the pages I have already checked without finding an answer:

  1. Weight decay in the optimizers is a bad idea (especially with BatchNorm)
  2. Weight decay exclusions by michaellavelle · Pull Request #24 · karpathy/minGPT · GitHub
  3. https://github.com/karpathy/minGPT/blob/3ed14b2cec0dfdad3f4b2831f2b4a86d11aef150/mingpt/model.py#L136
  4. regularization - Why not perform weight decay on layernorm/embedding? - Cross Validated
  5. https://github.com/pytorch/examples/blob/main/word_language_model/model.py#L28
  6. python - Tying weights in neural machine translation - Stack Overflow
  7. Weight decay only for weights of nn.Linear and nn.Conv*