Embedding with very small values generates NaN gradients

Has anyone gotten NaN values during training with nn.Embedding?

I recently got NaN gradients in an embedding layer. I saw a NaN issue reported for nn.Embedding, but I don’t know whether that issue has been resolved.

My use case is this: I use nn.Embedding as the weight of a linear layer, and during the forward pass a few selected rows of the embedding are used in the computation. I also apply weight_decay in the optimizer.
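
Roughly, the setup looks like this (a minimal sketch; the module name and sizes are made up, not my actual model):

import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbeddingAsLinear(nn.Module):
    # Rows of an nn.Embedding act as the weight of a linear layer.
    def __init__(self, num_rows, dim):
        super().__init__()
        self.input_embedding = nn.Embedding(num_rows, dim)

    def forward(self, x, row_ids):
        w = self.input_embedding(row_ids)   # (num_selected, dim) rows used as weights
        return F.linear(x, w)               # (batch, num_selected)

model = EmbeddingAsLinear(num_rows=10000, dim=64)
# weight_decay applies to every parameter, including the embedding rows
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=1e-4)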

I investigated where the NaN gradients are generated and found that they originate in the embedding layer.

I saw an issue saying that when an embedding weight goes to zero, a NaN gradient is generated. I used pdb to inspect the row vector whose gradient is NaN, and I found that some of its values are extremely small, around 1e-41.
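
Instead of stepping through with pdb, a check like this right after loss.backward() can also surface the problematic parameters (just a sketch; model is whatever module holds the embedding):

import torch

def report_bad_grads(model):
    # Call right after loss.backward(): flags NaN gradients and weights that
    # have already shrunk into the float32 subnormal range (below ~1e-38).
    for name, p in model.named_parameters():
        if p.grad is not None and torch.isnan(p.grad).any():
            print(f"{name}: {int(torch.isnan(p.grad).sum())} NaN gradient entries")
        tiny = (p.data.abs() > 0) & (p.data.abs() < 1e-38)
        if tiny.any():
            print(f"{name}: {int(tiny.sum())} weights in the subnormal range")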

Has anyone else encountered this kind of situation? How can it be resolved? I will replace the nn.Embedding with an nn.Parameter to try to avoid the NaN issue, but I would like to understand why it happens.

I seem to have a very similar issue… Going to try disabling the weight decay on the weights of nn.Embedding (similar to https://github.com/pytorch/pytorch/issues/1402).

Just to confirm that the weight decay may have been the culprit: I added the following snippet and things are training smoothly…

def group_weight(module):
    """Split parameters into two groups so weight decay can be disabled on embeddings."""
    group_decay = []
    group_no_decay = []
    for n, m in module.named_parameters():
        if 'input_embedding' in n:
            group_no_decay.append(m)   # embedding weights: no weight decay
        else:
            group_decay.append(m)      # everything else: regular weight decay
    assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay)
    groups = [dict(params=group_decay), dict(params=group_no_decay, weight_decay=.0)]
    return groups
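
The returned groups go straight into the optimizer; the global weight_decay then applies only to the first group (this assumes your embedding parameter names contain 'input_embedding'):

optimizer = torch.optim.SGD(group_weight(model), lr=0.1, momentum=0.9, weight_decay=1e-4)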

I see… Your point is that the weight decay is not numerically stable when the weight is extremely small.

Thanks, I will apply this in my code.

Exactly. In my case, my guess is the following:

  • I have some embedding tokens that are very uncommon in the training data (they may not even be present in the whole training dataset).

  • In this case, the weight decay essentially updates these embedding weights by
    w(t+1) = w(t) - 2 * decay_factor * w(t),
    so they slowly approach zero.

  • When training runs long enough, these weights become so small that numerical underflow occurs, resulting in NaN (see the sketch below).
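
A rough illustration of that decay-only update in float32 (toy numbers, not a real training run):

import torch

decay_factor = 1e-4
steps = torch.arange(0, 700_001, 100_000, dtype=torch.float32)
w0 = torch.tensor(1e-3, dtype=torch.float32)
# An unused embedding weight is only touched by w(t+1) = w(t) - 2 * decay_factor * w(t),
# i.e. it shrinks geometrically: w(t) = w0 * (1 - 2 * decay_factor)^t.
w = w0 * torch.pow(torch.tensor(1.0 - 2 * decay_factor, dtype=torch.float32), steps)
for s, v in zip(steps.tolist(), w.tolist()):
    print(f"step {int(s):>7d}: w = {v:.3e}")
# After a few hundred thousand steps the value falls into the float32 subnormal
# range and eventually underflows to exactly 0; any later op that divides by it
# (or by a norm built from it) can then produce inf/NaN.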
