I’m trying to follow the pytorch tutorial about saving model checkpoints (https://pytorch.org/tutorials/beginner/saving_loading_models.html). I ran it with the cbow snippet from pytorch tutorial page, and it gave me an unexpected behavior. The loss after running training the model for the first time is about 7-8, but after reloading the model, the loss increased significantly to about 3-400.
Code to reproduce
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary
import os
CONTEXT_SIZE = 2 # 2 words to the left, 2 to the right
raw_text = """We are about to study the idea of a computational process.
Computational processes are abstract beings that inhabit computers.
As they evolve, processes manipulate other abstract things called data.
The evolution of a process is directed by a pattern of rules
called a program. People create programs to direct processes. In effect,
we conjure the spirits of the computer with our spells.""".split()
# By deriving a set from `raw_text`, we deduplicate the array
vocab = set(raw_text)
vocab_size = len(vocab)
word_to_ix = {word: i for i, word in enumerate(vocab)}
data = []
for i in range(2, len(raw_text) - 2):
context = [raw_text[i - 2], raw_text[i - 1],
raw_text[i + 1], raw_text[i + 2]]
target = raw_text[i]
data.append((context, target))
class CBOW(nn.Module):
def __init__(self,vocab_size, embedding_dim, context_size):
super(CBOW, self).__init__()
self.embeddings= nn.Embedding(vocab_size,embedding_dim)
self.linear1 = nn.Linear(2*context_size*embedding_dim, 100)
self.linear2 = nn.Linear(100,vocab_size)
def forward(self, inputs):
embed = self.embeddings(inputs).view(1,-1)
out = F.relu(self.linear1(embed))
out = self.linear2(out)
log_probs = F.log_softmax(out,dim=1)
return log_probs
def make_context_vector(context, word_to_ix):
idxs = [word_to_ix[w] for w in context]
return torch.tensor(idxs, dtype=torch.long)
make_context_vector(data[0][0], word_to_ix) # example
losses = []
loss_function = nn.NLLLoss()
optimizer = optim.Adam(model.parameters())
if os.path.exists("model.pt"):
checkpoint = torch.load("model.pt")
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
print("Loaded checkpoint")
print("EPOCH ",epoch)
print("PREVIOUS LOSS ",loss)
for epoch in range(EPOCHS):
total_loss = 0
for context, target in data:
context_idx = make_context_vector(context,word_to_ix)
log_probs = model(context_idx)
loss = loss_function(log_probs,torch.tensor([word_to_ix[target]],dtype=torch.long))
total_loss += loss
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"loss": total_loss.item()
print("NEW LOSSES",losses)