[Solved] Loss goes crazy for a copy task involving word embedding

I have a simple model. The input is a matrix of integers (0 < x < vocabulary size), and the output is a matrix of vectors giving a prediction for each input position. As you can see from the loss function below, I am asking the network to copy its own input. The network is applied element-wise to every integer.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
  def __init__(self):
    super(Net, self).__init__()

    self.embedding = nn.Embedding(vocabSize, hiddenSize)
    self.linear    = nn.Linear(hiddenSize, hiddenSize)
    self.layerNorm = nn.LayerNorm(hiddenSize)
    self.output    = nn.Linear(hiddenSize, vocabSize, bias = False)
    self.bias      = nn.Parameter(torch.zeros(vocabSize))

    # Tie the output projection weights to the embedding matrix.
    self.output.weight = self.embedding.weight

  def forward(self, input):
    hiddens = self.embedding(input)
    hiddens = self.linear(hiddens)
    hiddens = F.relu(hiddens)
    hiddens = F.dropout(hiddens, p = 0.1, training = self.training)
    hiddens = self.layerNorm(hiddens)
    preds   = self.output(hiddens)
    return preds + self.bias

  def loss(self, target):
    # Copy task: the target is also the input; flatten the batch and sequence
    # dimensions so every position is an independent classification example.
    copy  = self.forward(target)
    count = batchSize * seqSize
    return F.cross_entropy(copy.view(count, -1), target.view(count))
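
To reproduce the behaviour below, a minimal driver along these lines is enough (a sketch; the hyperparameter values here are placeholders, not necessarily the ones from my runs, and the full script is in the gist further down):

import torch

# Placeholder hyperparameters for this sketch.
vocabSize  = 3000
hiddenSize = 128
batchSize  = 32
seqSize    = 64

net = Net()
# Optimizer settings as described further down in the post.
optimizer = torch.optim.Adam(net.parameters(), lr = 1e-4,
                             betas = (0.9, 0.999), weight_decay = 0.01)

net.train()
for step in range(20000):
  # A batch of random token ids; the target is the input itself (copy task).
  batch = torch.randint(0, vocabSize, (batchSize, seqSize))
  loss  = net.loss(batch)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()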

I noticed a very strange behavior:

  1. vocabSize = 3000, hiddenSize = 128.
    • The loss drops quite fast to 0.01 and accuracy reaches > 99%.
    • However, it then gets worse very quickly, and accuracy can drop as low as 40%.
    • Then accuracy quickly recovers, while the loss settles around 3.
  2. vocabSize = 30000, hiddenSize = 128/256/512.
    • The loss drops quite fast to 0.01 and accuracy reaches > 99%.
    • However, it then gets worse very quickly.
    • But this time it does not recover at all. The loss can go as high as 50 and keeps increasing.

Does anyone know why this might happen? In particular, how could the second case happen? I have tried changing relu to sigmoid; it converges and diverges more slowly, but still shows a similar pattern.
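
(For the sigmoid variant the only change is the activation line in forward, i.e. roughly:)

    hiddens = torch.sigmoid(hiddens)   # instead of F.relu(hiddens)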

Here is runnable code on gist. For me, the divergence starts at around 9k batches.

Learning rate does not seem to be the issue. Since the network is applied element-wise, every batch effectively contains thousands of independent training examples. Anyway, I used Adam with learning rate 1e-4, betas (0.9, 0.999), and weight decay 0.01.

About the dropout and layer normalization: they are part of a bigger network. Removing them delays the divergence but does not prevent it.

I am reproducing BERT, and right now I am trying to get my implementation to learn the copy task. My stack of transformers can ace that problem quite easily, yet the whole model cannot, so I tried skipping the transformers and found this problem. Just to make sure my BERT implementation was correct, I coded the above model from scratch and found that the problem persists.

For my BERT, the network just produced “the” all the time. Its embedding weight matrix is quite degenerate: it has a few very large singular values and everything else is quite small. I suspect that is related to the phenomenon I am observing here, but I checked the two weight matrices in this model and, while the performance is getting worse, their singular values remain evenly distributed. Any advice would be helpful!
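
For reference, this is roughly how one can check the spectra (a sketch; net is the Net instance above, and torch.linalg.svdvals needs a reasonably recent PyTorch):

import torch

# Inspect the singular values of the embedding and hidden weight matrices.
# svdvals returns them in descending order.
with torch.no_grad():
  for name, weight in [("embedding", net.embedding.weight),
                       ("linear",    net.linear.weight)]:
    s = torch.linalg.svdvals(weight)
    print(name, "largest:", s[:5].tolist(), "smallest:", s[-5:].tolist())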

Thanks.

It turns out that the weight decay term I used with Adam was the cause. My best guess is that once the loss is essentially zero, the gradient signal becomes tiny and the decay term dominates the updates, steadily shrinking the tied embedding weights until the predictions collapse.
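
The simplest fix is to drop the weight decay entirely (weight_decay = 0). If you still want some regularisation, one common alternative (this is just a sketch of the usual BERT-style recipe, not necessarily what the gist ends up doing) is decoupled weight decay that skips the tied embedding, the LayerNorm parameters, and the biases; net is the Net instance from above:

import torch

# Split parameters into those that should and should not be decayed.
# The tied output weight is the same tensor as embedding.weight, so it is
# covered by the "embedding" check.
decay, no_decay = [], []
for name, param in net.named_parameters():
  if name.endswith("bias") or "layerNorm" in name or "embedding" in name:
    no_decay.append(param)
  else:
    decay.append(param)

optimizer = torch.optim.AdamW(
  [{"params": decay,    "weight_decay": 0.01},
   {"params": no_decay, "weight_decay": 0.0}],
  lr = 1e-4, betas = (0.9, 0.999))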