After training a word embedding model on a large-ish corpus, my embeddings converge to nan
values.
The model is very simple (skipgram with negative sampling)
full model:
class NEG_loss(nn.Module):
def __init__(self, vocab_size, embed_size, neg_sampling_table=None):
"""
:param vocab_size: An int. The number of possible classes.
:param embed_size: An int. EmbeddingLockup size
:param num_sampled: An int. The number of sampled from noise examples
:param neg_sampling_table: A list of non negative floats. Class neg_sampling_table. None if
using uniform sampling. The neg_sampling_table are calculated prior to
estimation and can be of any form, e.g equation (5) in [1]
"""
super(NEG_loss, self).__init__()
self.device = t.device("cuda:0" if t.cuda.is_available() else "cpu")
self.vocab_size = vocab_size
self.embed_size = embed_size
self.out_embed.weight = nn.Parameter(
t.cat(
[
t.zeros(1, self.embed_size),
FT(self.vocab_size - 1, self.embed_size).uniform_(
-0.5 / self.embed_size, 0.5 / self.embed_size
),
]
)
)
self.in_embed.weight = nn.Parameter(
t.cat(
[
t.zeros(1, self.embed_size),
FT(self.vocab_size - 1, self.embed_size).uniform_(
-0.5 / self.embed_size, 0.5 / self.embed_size
),
]
)
)
self.neg_sampling_table = neg_sampling_table
if self.neg_sampling_table is not None:
assert min(self.neg_sampling_table) >= 0, "Each weight should be >= 0"
self.neg_sampling_table = Variable(t.from_numpy(neg_sampling_table)).float()
# TODO this is bad - find more elegant solution
def sample(self, num_sample):
"""
draws a sample from classes based on neg_sampling_table
"""
return self.neg_sampling_table[
t.randint(0, len(self.neg_sampling_table), (num_sample,))
]
def forward(self, input_labels, out_labels, num_sampled):
"""
:param input_labels: Tensor with shape of [batch_size] of Long type
:param out_labels: Tensor with shape of [batch_size, window_size] of Long type
:param num_sampled: An int. The number of sampled from noise examples
:return: Loss estimation with shape of [1]
loss defined in Mikolov et al. Distributed Representations of Words and Phrases and their Compositionality
papers.nips.cc/paper/5021-distributed-representations-of-words-and-phrases-and-their-compositionality.pdf
"""
input_labels = input_labels.to("cuda:0")
out_labels = out_labels.to("cuda:0")
batch_size = out_labels.size()[0]
input_ = self.in_embed(input_labels.view(-1))
output = self.out_embed(out_labels.view(-1))
if self.neg_sampling_table is not None:
noise_sample_count = batch_size * num_sampled
draw = self.sample(noise_sample_count)
noise = draw.view(batch_size, num_sampled).long()
else:
noise = Variable(
t.Tensor(batch_size, num_sampled)
.uniform_(0, self.vocab_size - 1)
.long()
)
# if use_cuda:
noise = noise.to(self.device)
noise = self.out_embed(noise).neg()
log_target = (input_ * output).sum(1).squeeze().sigmoid().log()
""" ā[batch_size * window_size, num_sampled, embed_size] * [batch_size * window_size, embed_size, 1] ->
ā[batch_size, num_sampled, 1] -> [batch_size] """
sum_log_sampled = (
t.bmm(noise, input_.unsqueeze(2)).sigmoid().log().sum(1).squeeze()
)
loss = log_target + sum_log_sampled
return -loss.mean()
def input_embeddings(self):
return self.in_embed.weight.detach().cpu().numpy()
and training loop and optimizer code
# NEG loss and optim
neg = NEG_loss(vocab_size, dim, neg_sampling_table=neg_sampling_dist)
neg.to("cuda:0")
optimizer = Adam(neg.parameters(), 0.01)
sys.stdout.write("BEGINNING TRAINING\n")
for i in range(epochs):
sys.stdout.write("-" * 35 + "\n")
sys.stdout.write(f"EPOCH {i+1}\n")
for i, batch in enumerate(tqdm(dataloader)):
input_, output_ = batch
optimizer.zero_grad()
loss = neg(input_, output_, neg_samples)
loss.backward()
optimizer.step()
sys.stdout.write("-" * 35 + "\n")
This model is about as simple as it gets so I am a bit surprised I am having this issue. Everything is fine on a smaller corpus (7k batches of size 128) but gets nan
-y with a larger corpus (200k batches of size 128).
Anyone see anything immediately wrong or have any tips to figure out what is going?
Any help is appreciated. Also pretty new to pytorch - so if you see anything that is a bit dumb feel free to point it out