Inconsistency When Loading a Checkpoint

When I load a checkpoint and recompute the validation loss that the model had at the time of saving, the two values do not agree. Here is an example with an LSTM on MNIST (I know one would usually use CNNs on MNIST, but it’s a simple toy model that does its job here):

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # used for all .to(device) calls below

class LSTM(nn.Module):
    def __init__(self, input_size, num_layers, hidden_size, num_classes, sequence_length, bidirectional):
        super(LSTM, self).__init__()
        self.input_size = input_size
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.sequence_length = sequence_length
        self.bidirectional = bidirectional

        # One direction for a standard LSTM, two for a bidirectional one:
        self.num_directions = 2 if self.bidirectional else 1

        self.LSTM = nn.LSTM(
                        input_size=self.input_size,
                        hidden_size=self.hidden_size,
                        num_layers=self.num_layers,
                        batch_first=True, 
                        dropout=0,
                        bidirectional=self.bidirectional
        )
        self.dropout = nn.Dropout(p=0.2, inplace=False)
        self.fc = nn.Linear(in_features=self.num_directions*self.hidden_size*self.sequence_length, out_features=self.num_classes)

    def forward(self, x):
        """Standard forward pass."""
        # Initialize hidden state:
        h0 = torch.zeros(self.num_layers * self.num_directions, x.size(0), self.hidden_size).to(device)
        c0 = torch.zeros(self.num_layers * self.num_directions, x.size(0), self.hidden_size).to(device)
        # Forward prop:
        out, (hidden_state, cell_state) = self.LSTM(x, (h0, c0))
        out = self.dropout(out)
        out = out.reshape(out.shape[0], -1)
        out = self.fc(out)
        return out

So I initialize the model like this:

model = LSTM(
               input_size=28, 
               num_layers=3, 
               hidden_size=256, 
               num_classes=10, 
               sequence_length=28, 
               bidirectional=False
).to(device)

My training loop looks like this:

val_losses = []

for epoch in range(num_epochs):
    val_loss_per_batch = []
    # do the training stuff ...
    with torch.no_grad():
        model.eval()
        for val_batch_idx, (val_images, val_labels) in enumerate(val_loader):
            val_images = val_images.to(device)
            val_images = torch.squeeze(input=val_images, dim=1)  # shape: (batch_size, 28, 28), otherwise the RNN throws an error
            val_labels = val_labels.to(device)
            val_output = model(val_images)
            val_loss_per_batch.append(loss_sum(val_output, val_labels).detach().cpu().item())
    val_losses.append(np.sum(val_loss_per_batch) / len(val_loader.dataset))

And here is what loss_sum looks like:

loss_sum = nn.CrossEntropyLoss(reduction='sum')

(I use reduction='sum' because if I took the mean loss of each batch and then averaged those means over all batches, the result would be slightly biased: the batches do not all have the same size, so the smallest batch would be weighted too heavily. But I believe this is just a technical detail.)
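
For illustration, here is a minimal sketch of the bias I mean, with made-up batch sizes and loss values; the mean of the per-batch means and the true per-sample mean differ when the last batch is smaller:

import numpy as np

# Hypothetical per-sample losses: two full batches of 1024 samples and a last batch of 100
batches = [np.full(1024, 0.5), np.full(1024, 0.5), np.full(100, 2.0)]

mean_of_means = np.mean([b.mean() for b in batches])                      # ~1.00, the small batch is over-weighted
true_mean = sum(b.sum() for b in batches) / sum(len(b) for b in batches)  # ~0.57, what reduction='sum' / dataset size gives
print(mean_of_means, true_mean)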

After training for a small number of epochs, I save a checkpoint:

checkpoint = {'state_dict' : model.state_dict(), 'optimizer': optimizer.state_dict()}
torch.save(checkpoint, 'checkpoint.pth.tar')
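
As a side note, a variant of this saving step (just a sketch, assuming epoch and val_losses from the loop above are still in scope) would also store the epoch and the validation loss at saving time, so the recomputed value can be compared directly after loading:

checkpoint = {
    'epoch': epoch,                        # assumes the loop variable is still in scope
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'val_loss': val_losses[-1],            # validation loss at saving time, for later comparison
}
torch.save(checkpoint, 'checkpoint.pth.tar')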

And then when I want to load the checkpoint:

model = LSTM(...) # initialized just as above
optimizer = ...   # re-created just as during training, before loading its state
checkpoint = torch.load('checkpoint.pth.tar')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])

model.to(device)
model.eval()

Now I do the whole thing as in the validation loop above, i.e. with torch.no_grad(): ..., but I get a different validation loss than the one I had when I saved the checkpoint! And every time I reload the checkpoint, I get yet another value. This is what my val_loader looks like:

val_loader = DataLoader(dataset=val_subset, shuffle=True, batch_size=1024)
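
For completeness, the recomputation after loading is essentially the same loop, just wrapped here in a small helper (evaluate is a hypothetical name, not part of my original script):

def evaluate(model, loader, device):
    """Summed cross-entropy over the whole loader, divided by the dataset size."""
    criterion = nn.CrossEntropyLoss(reduction='sum')
    model.eval()
    total = 0.0
    with torch.no_grad():
        for images, labels in loader:
            images = torch.squeeze(images.to(device), dim=1)  # shape: (batch_size, 28, 28)
            labels = labels.to(device)
            total += criterion(model(images), labels).item()
    return total / len(loader.dataset)

val_loss_reloaded = evaluate(model, val_loader, device)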

I already tried shuffle = False, but to no avail. Can anybody confirm this weird behavior? And does anybody know why it occurs?

(This issue is resolved.)