Skipgram word embeddings result in NaN values

After training a word embedding model on a large-ish corpus, my embeddings converge to NaN values.

The model is very simple (skipgram with negative sampling).

full model:

import torch as t
from torch import nn
from torch.autograd import Variable

# shorthand assumed to be torch.FloatTensor, given the FT(...).uniform_ calls below
FT = t.FloatTensor


class NEG_loss(nn.Module):
    def __init__(self, vocab_size, embed_size, neg_sampling_table=None):
        """
        :param vocab_size: An int. The number of possible classes.
        :param embed_size: An int. EmbeddingLockup size
        :param num_sampled: An int. The number of sampled from noise examples
        :param neg_sampling_table: A list of non negative floats. Class neg_sampling_table. None if
            using uniform sampling. The neg_sampling_table are calculated prior to
            estimation and can be of any form, e.g equation (5) in [1]
        """

        super(NEG_loss, self).__init__()

        self.device = t.device("cuda:0" if t.cuda.is_available() else "cpu")

        self.vocab_size = vocab_size
        self.embed_size = embed_size

        # the embedding tables referenced below; assuming plain nn.Embedding layers,
        # since their definition was missing from the snippet as pasted
        self.in_embed = nn.Embedding(self.vocab_size, self.embed_size)
        self.out_embed = nn.Embedding(self.vocab_size, self.embed_size)

        # row 0 is zeroed out; remaining rows start uniform in [-0.5/embed_size, 0.5/embed_size]
        self.out_embed.weight = nn.Parameter(
            t.cat(
                [
                    t.zeros(1, self.embed_size),
                    FT(self.vocab_size - 1, self.embed_size).uniform_(
                        -0.5 / self.embed_size, 0.5 / self.embed_size
                    ),
                ]
            )
        )


        self.in_embed.weight = nn.Parameter(
            t.cat(
                [
                    t.zeros(1, self.embed_size),
                    FT(self.vocab_size - 1, self.embed_size).uniform_(
                        -0.5 / self.embed_size, 0.5 / self.embed_size
                    ),
                ]
            )
        )

        self.neg_sampling_table = neg_sampling_table
        if self.neg_sampling_table is not None:
            assert min(self.neg_sampling_table) >= 0, "Each weight should be >= 0"

            self.neg_sampling_table = Variable(t.from_numpy(neg_sampling_table)).float()

    # TODO this is bad - find more elegant solution
    def sample(self, num_sample):
        """
        draws a sample from classes based on neg_sampling_table
        """

        return self.neg_sampling_table[
            t.randint(0, len(self.neg_sampling_table), (num_sample,))
        ]

    def forward(self, input_labels, out_labels, num_sampled):
        """
        :param input_labels: Tensor with shape of [batch_size] of Long type
        :param out_labels: Tensor with shape of [batch_size, window_size] of Long type
        :param num_sampled: An int. The number of sampled from noise examples
        :return: Loss estimation with shape of [1]
            loss defined in Mikolov et al. Distributed Representations of Words and Phrases and their Compositionality
            papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf
        """

        input_labels = input_labels.to(self.device)
        out_labels = out_labels.to(self.device)

        batch_size = out_labels.size()[0]

 
        input_ = self.in_embed(input_labels.view(-1))
        output = self.out_embed(out_labels.view(-1))

        if self.neg_sampling_table is not None:
            noise_sample_count = batch_size * num_sampled
            draw = self.sample(noise_sample_count)
            noise = draw.view(batch_size, num_sampled).long()
        else:
            noise = Variable(
                t.Tensor(batch_size, num_sampled)
                .uniform_(0, self.vocab_size - 1)
                .long()
            )
        noise = noise.to(self.device)
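        # embed the noise indices and negate them, so the bmm + sigmoid below computes sigma(-v_in . v_noise)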
        noise = self.out_embed(noise).neg()

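        # positive term: log sigma(v_in . v_out) for every (input word, context word) pair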
        log_target = (input_ * output).sum(1).squeeze().sigmoid().log()

        """ āˆ‘[batch_size * window_size, num_sampled, embed_size] * [batch_size * window_size, embed_size, 1] ->
            āˆ‘[batch_size, num_sampled, 1] -> [batch_size] """
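        # negative term: for each input word, sum log sigma(-v_in . v_noise) over its noise samples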
        sum_log_sampled = (
            t.bmm(noise, input_.unsqueeze(2)).sigmoid().log().sum(1).squeeze()
        )

        loss = log_target + sum_log_sampled

        return -loss.mean()

    def input_embeddings(self):
        return self.in_embed.weight.detach().cpu().numpy()

and the training loop and optimizer code:

    # NEG loss and optim
    neg = NEG_loss(vocab_size, dim, neg_sampling_table=neg_sampling_dist)
    neg.to("cuda:0")
    optimizer = Adam(neg.parameters(), 0.01)

    sys.stdout.write("BEGINNING TRAINING\n")
    for epoch in range(epochs):
        sys.stdout.write("-" * 35 + "\n")
        sys.stdout.write(f"EPOCH {epoch + 1}\n")
        for batch in tqdm(dataloader):
            input_, output_ = batch
            optimizer.zero_grad()
            loss = neg(input_, output_, neg_samples)
            loss.backward()
            optimizer.step()
        sys.stdout.write("-" * 35 + "\n")

This model is about as simple as it gets, so I am a bit surprised I am having this issue. Everything is fine on a smaller corpus (7k batches of size 128), but the embeddings turn to NaN with a larger corpus (200k batches of size 128).

Anyone see anything immediately wrong, or have any tips to figure out what is going on?

Any help is appreciated. Also, I am pretty new to PyTorch, so if you see anything that is a bit dumb, feel free to point it out :slight_smile:

Could you check if all inputs contain valid values via torch.isfinite(input).all() and torch.isfinite(output).all()?
If the data is alright, I would recommend setting torch.autograd.detect_anomaly(True) at the beginning of the script and posting the stack trace here. This should hopefully point to the operation that created the NaN outputs.
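Something along these lines could work as a starting point. This is just a sketch reusing the neg, dataloader, optimizer, vocab_size, and neg_samples names from your posted code; the exact placement of the checks is up to you:

import torch

# enable anomaly detection globally, once, before the training loop
torch.autograd.set_detect_anomaly(True)

for input_, output_ in dataloader:
    # the inputs are index tensors, so the more useful check is that every
    # value is finite and inside the embedding table
    assert torch.isfinite(input_.float()).all(), "non-finite values in input_"
    assert torch.isfinite(output_.float()).all(), "non-finite values in output_"
    assert input_.min() >= 0 and input_.max() < vocab_size, "input_ index out of range"
    assert output_.min() >= 0 and output_.max() < vocab_size, "output_ index out of range"

    optimizer.zero_grad()
    loss = neg(input_, output_, neg_samples)

    # stop at the first non-finite loss instead of silently training on it
    if not torch.isfinite(loss):
        raise RuntimeError(f"loss became non-finite: {loss.item()}")

    loss.backward()
    optimizer.step()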

Thanks for your response @ptrblck. I do not believe my inputs are the culprit - they are just indices that are used to look up the embeddings within the in_embed layer. Do you mean I should check this at model init for all of in_embed?

In the meantime, I am running the model with torch.autograd.detect_anomaly(True).

Thanks again

@ptrblck I included torch.autograd.detect_anomaly(True) before my training loop and my script ran to completion - still yielding embedding weights of NaN. Any other recommendations?

Are you seeing NaN outputs after your embedding weights get NaN values?
Anomaly detection should raise an error, so I'm a bit confused why your script finished.
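For reference, as far as I know anomaly detection only becomes active when it is either switched on globally or used as a context manager around the forward and backward pass; constructing torch.autograd.detect_anomaly(True) without a with statement will not enable it, which might explain why nothing was raised. A minimal sketch of both forms (reusing the neg, input_, output_, and neg_samples names from the code above):

import torch

# option 1: global switch, affects everything that runs afterwards
torch.autograd.set_detect_anomaly(True)

# option 2: context manager, only active inside the with-block, so it has to
# wrap the forward and backward passes
with torch.autograd.detect_anomaly():
    loss = neg(input_, output_, neg_samples)
    loss.backward()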