When I load a checkpoint and recompute the validation loss, the value does not match the validation loss I recorded when I saved the checkpoint. Here is an example with an LSTM on MNIST (I know one would usually use CNNs on MNIST, but it's a simple toy model that does the job here):
class LSTM(nn.Module):
    def __init__(self, input_size, num_layers, hidden_size, num_classes, sequence_length, bidirectional):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.sequence_length = sequence_length
        self.bidirectional = bidirectional
        if self.bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1
        self.LSTM = nn.LSTM(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            num_layers=self.num_layers,
            batch_first=True,
            dropout=0,
            bidirectional=self.bidirectional
        )
        self.dropout = nn.Dropout(p=0.2, inplace=False)
        self.fc = nn.Linear(in_features=self.num_directions*self.hidden_size*self.sequence_length, out_features=self.num_classes)

    def forward(self, x):
        """Standard forward pass."""
        # Initialize hidden state:
        h0 = torch.zeros(self.num_layers * self.num_directions, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers * self.num_directions, x.size(0), self.hidden_size).to(device)
        # Forward prop:
        out, (hidden_state, cell_state) = self.LSTM(x, (h0, c0))
        out = self.dropout(out)
        out = out.reshape(out.shape[0], -1)
        out = self.fc(out)
        return out
So initializing the model:
model = LSTM(
    input_size=28,
    num_layers=3,
    hidden_size=256,
    num_classes=10,
    sequence_length=28,
    bidirectional=False
).to(device)
My training loop looks like this:
val_losses = []
for epoch in range(num_epochs):
    val_loss_per_batch = []
    # do the training stuff ...
    with torch.no_grad():
        model.eval()
        for val_batch_idx, (val_images, val_labels) in enumerate(val_loader):
            val_images = val_images.to(device)
            # shape: (batch_size, 28, 28), otherwise RNN throws error
            val_images = torch.squeeze(input=val_images, dim=1)
            val_labels = val_labels.to(device)
            val_output = model(val_images)
            val_loss_per_batch.append(loss_sum(val_output, val_labels).detach().cpu().item())
    # Summed loss over all batches divided by the number of validation samples
    val_losses.append(np.sum(val_loss_per_batch) / len(val_loader.dataset))
And here is what loss_sum looks like:
loss_sum = nn.CrossEntropyLoss(reduction='sum')
(I use reduction='sum' because averaging the loss within each batch and then averaging those batch means would be slightly biased: the batches do not all have the same size, so the samples in the smallest batch would be weighted more heavily. Summing per batch and dividing by the dataset size gives the exact per-sample mean. But I believe this to be a technical detail.)
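To illustrate the weighting issue with concrete numbers, here is a small sketch in plain Python. The per-sample loss values and batch sizes are made up for illustration; they are not taken from the MNIST run.

```python
# Hypothetical per-sample losses, split into batches of unequal size
# (the last batch is smaller, as happens when the dataset size is not
# a multiple of the batch size).
batches = [
    [1.0, 1.0, 1.0, 1.0],  # full batch
    [1.0, 1.0, 1.0, 1.0],  # full batch
    [5.0],                 # smaller final batch
]

num_samples = sum(len(batch) for batch in batches)

# reduction='sum' per batch, then divide by the dataset size:
# this is the exact mean over all samples.
mean_from_sums = sum(sum(batch) for batch in batches) / num_samples

# reduction='mean' per batch, then average the batch means:
# the single sample in the last batch gets the weight of a whole batch.
batch_means = [sum(batch) / len(batch) for batch in batches]
mean_of_means = sum(batch_means) / len(batch_means)

print(mean_from_sums)  # 13/9, roughly 1.44
print(mean_of_means)   # 7/3, roughly 2.33
```

With equally sized batches the two quantities coincide, so the difference only matters for the last, smaller batch.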
After training for a small number of epochs, I save a checkpoint:
checkpoint = {'state_dict' : model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(checkpoint, 'checkpoint.pth.tar')
And then when I want to load the checkpoint:
model = LSTM(...) # initialization just as above
checkpoint = torch.load('checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
model.to(device)
model.eval()
Now I repeat the validation pass exactly as in the loop, i.e. inside with torch.no_grad(): ..., but I get a different validation loss than the one I had when I saved the checkpoint! And each time I load the checkpoint again, I get yet another value. This is what my val_loader looks like:
val_loader = DataLoader(dataset=val_subset, shuffle=True, batch_size=1024)
I already tried shuffle=False, but to no avail. Can anybody confirm this weird behavior? And does anybody know why it occurs?
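For comparison, a minimal self-contained sketch of the save/load round trip described above. It is not the original setup: a tiny hypothetical model with dropout and random tensors stand in for the LSTM and the MNIST validation subset, everything runs on the CPU, and the checkpoint is written to an in-memory buffer instead of 'checkpoint.pth.tar'. The helper name validation_loss is also an assumption made for this example.

```python
import io

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical stand-ins for the real model and validation subset.
model = nn.Sequential(nn.Flatten(), nn.Dropout(p=0.2), nn.Linear(28 * 28, 10))
val_subset = TensorDataset(torch.randn(100, 28, 28), torch.randint(0, 10, (100,)))
val_loader = DataLoader(dataset=val_subset, shuffle=False, batch_size=32)
loss_sum = nn.CrossEntropyLoss(reduction='sum')


def validation_loss(model, loader):
    """Mean per-sample loss over the whole loader, computed in eval mode."""
    model.eval()  # disables dropout
    total = 0.0
    with torch.no_grad():
        for images, labels in loader:
            total += loss_sum(model(images), labels).item()
    return total / len(loader.dataset)


loss_before = validation_loss(model, val_loader)

# Save the checkpoint to an in-memory buffer instead of a file on disk.
buffer = io.BytesIO()
torch.save({'state_dict': model.state_dict()}, buffer)
buffer.seek(0)

# Rebuild the model from scratch and restore the saved weights.
restored = nn.Sequential(nn.Flatten(), nn.Dropout(p=0.2), nn.Linear(28 * 28, 10))
restored.load_state_dict(torch.load(buffer)['state_dict'])
loss_after = validation_loss(restored, val_loader)

print(loss_before, loss_after)
```

In a stripped-down setting like this, the two values would be expected to agree, which may help to narrow down whether a discrepancy comes from the checkpoint itself or from another part of the pipeline (data split, preprocessing, or the mode the model is in).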