Loaded model results not same as models saved in thread

I found this in my code, and I looked over literally 100 times and finally decided to make a replica in mnist version. And it seems like this happens as well.

Concept Description

Main Thread:
train data and toss the weight to the tester thread at a certain period of time

Tester Thread:
Thread waits through queue for next weight. When it gets a new thread it starts to do some test (pretend this is the evaluation during train). To make a similar environment I put time.sleep(3) right after tester takes the weights from the queue, so that the main thread still trains some more in its original copy.


  1. Saved model weights are not same as the one in loaded model
  2. Sometimes they are same but they don’t give the same results


  1. For reproducibility, I will attach the test code at the bottom
  2. This is not reproducible all the time so I gave up. I just mentioned for someone has same experience.

As you can see, the weight at count1 and its loaded weights are exactly same as below:
Weights at test count 1 while training

Weights loaded count 1

ref: https://korchris.github.io/2019/08/23/mnist/
#Importing Library

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

#--- NN
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 32)
        self.fc3 = nn.Linear(32, 10)

    def forward(self, x):
        x = x.float()
        h1 = F.relu(self.fc1(x.view(-1, 784)))
        h2 = F.relu(self.fc2(h1))
        h3 = self.fc3(h2)
        return F.log_softmax(h3, dim=1)

print("init model done")

##--- Define some inits and prepare data
batch_size = 64
test_batch_size = 1000
epochs = 10
lr = 0.1
momentum = 0.5
no_cuda = True
seed = 1
log_interval = 200

use_cuda = not no_cuda and torch.cuda.is_available()


device = torch.device("cuda" if use_cuda else "cpu")

kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

print("set vars and device done")

transform = transforms.Compose([
                 transforms.Normalize((0.1307,), (0.3081,))])

train_loader = torch.utils.data.DataLoader(
  datasets.MNIST('../data', train=True, download=True, 
    batch_size = batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, download=True,
    batch_size=test_batch_size, shuffle=True, **kwargs)

##--- Instantiate model
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

##--- define train, test
def train(log_interval, model, device, train_loader, optimizer, epoch):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = F.nll_loss(output, target)
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

from torch.multiprocessing import Queue
from threading import Thread
import time

class Tester(Thread):
    def __init__(self, model, model_name, queue, print_log=True):
        self.done = False
        self.model = model
        self.model_name = model_name
        self.queue = queue
        self.print_log = print_log
    def set_done(self):
        self.done = True
    def run(self):
        count = 0
        while not self.done:
            count += 1
            weights = self.queue.get()
            print('[count: {}] weights: {}'.format(count, self.model.state_dict()))

            test_loss = 0
            correct = 0
            with torch.no_grad():
                for data, target in test_loader:
                    data, target = data.to(device), target.to(device)
                    output = model(data)
                    test_loss += F.nll_loss(output, target, reduction='sum').item() 
                    pred = output.argmax(dim=1, keepdim=True)
                    correct += pred.eq(target.view_as(pred)).sum().item()

            test_loss /= len(test_loader.dataset)

            print('[count: {}] Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format
                  (count, test_loss, correct, len(test_loader.dataset),
                100. * correct / len(test_loader.dataset)))
            save_name = '{}_{}.pth'.format(self.model_name, count)
            print('Save path: {}'.format(save_name))
            torch.save(model, save_name)

# Train and save model starts here!

from copy import deepcopy
queue = Queue()
tester = Tester(model=Net().to(device), model_name='model', queue=queue, print_log=False)
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

count = 0
th_list = list()
for epoch in range(1, 2):
    for batch_idx, (data, target) in enumerate(train_loader):
        count += 1
        data, target = data.to(device), target.to(device)
        output = model(data)
        loss = F.nll_loss(output, target)
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
            weights = {k: deepcopy(v) for k, v in model.state_dict().items()}


# Checking loaded model!
model = torch.load('model_1.pth')
test_loss = 0
correct = 0
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        test_loss += F.nll_loss(output, target, reduction='sum').item() 
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum().item()

test_loss /= len(test_loader.dataset)

print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format
      (test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))

